feat(core): expose TLS ConnectionState (SNI) for DoQ (#8129)

DoQWriter previously stored only the QUIC stream, so plugins reading
TLS state via dns.ConnectionStater (e.g. for SNI-based routing or
auditing) could not see anything for DoQ connections, even
though the underlying QUIC connection carries a full tls.ConnectionState.

This change adds a *quic.Conn reference to DoQWriter and wires it in serveQUICStream.

It implements dns.ConnectionStater on *DoQWriter, returning the TLS
state from the underlying QUIC connection (mirrors the DoT behavior
that miekg/dns already provides for *tls.Conn)

Forwards ConnectionState through request.ScrubWriter, which wraps
every response writer before the plugin chain runs; the embedded
dns.ResponseWriter interface does not promote ConnectionState (it
belongs to a separate interface), so without this plugins would
still see nil for both DoQ and DoT

Signed-off-by: Nicholas Amorim <nicholas@santos.ee>
This commit is contained in:
Nicholas Amorim
2026-05-29 01:45:48 +03:00
committed by GitHub
parent 0bcb17df06
commit 6b93363b94
6 changed files with 196 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
package dnsserver
import (
"crypto/tls"
"encoding/binary"
"errors"
"net"
@@ -13,6 +14,7 @@ type DoQWriter struct {
localAddr net.Addr
remoteAddr net.Addr
stream *quic.Stream
conn *quic.Conn
Msg *dns.Msg
tsigStatus error
}
@@ -68,3 +70,15 @@ func (w *DoQWriter) Hijack() {}
func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr }
func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr }
func (w *DoQWriter) Network() string { return "" }
// ConnectionState implements the dns.ConnectionStater interface, exposing the
// TLS state of the underlying QUIC connection (e.g. for plugins that need to
// read the SNI ServerName). Mirrors the DoT behavior already provided by the
// miekg/dns response writer.
func (w *DoQWriter) ConnectionState() *tls.ConnectionState {
if w.conn == nil {
return nil
}
state := w.conn.ConnectionState().TLS
return &state
}

View File

@@ -48,3 +48,11 @@ func TestDoQWriter_ResponseWriterMethods(t *testing.T) {
t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr)
}
}
func TestDoQWriter_ConnectionStateNilConn(t *testing.T) {
writer := &DoQWriter{}
if state := writer.ConnectionState(); state != nil {
t.Errorf("ConnectionState() = %v, want nil when conn is unset", state)
}
}

View File

@@ -221,6 +221,7 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
localAddr: conn.LocalAddr(),
remoteAddr: conn.RemoteAddr(),
stream: stream,
conn: conn,
Msg: req,
}

View File

@@ -683,3 +683,125 @@ func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) {
t.Fatal("ServeDNS() was not called")
}
}
// connectionStateCapturePlugin records the *tls.ConnectionState observed via
// the dns.ConnectionStater interface implemented by the DoQ response writer.
type connectionStateCapturePlugin struct {
t *testing.T
called chan struct{}
state chan *tls.ConnectionState
}
func (p connectionStateCapturePlugin) Name() string { return "connection-state-capture" }
func (p connectionStateCapturePlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
p.t.Helper()
if cs, ok := w.(dns.ConnectionStater); ok {
p.state <- cs.ConnectionState()
} else {
p.state <- nil
}
if p.called != nil {
p.called <- struct{}{}
}
m := new(dns.Msg)
m.SetReply(r)
if err := w.WriteMsg(m); err != nil {
p.t.Fatalf("WriteMsg() failed: %v", err)
}
return dns.RcodeSuccess, nil
}
func TestServerQUIC_ServeQUIC_ConnectionStateExposesSNI(t *testing.T) {
const sni = "doq.example.com"
called := make(chan struct{}, 1)
stateCh := make(chan *tls.ConnectionState, 1)
config := testConfig("quic", connectionStateCapturePlugin{
t: t,
called: called,
state: stateCh,
})
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServerQUIC() failed: %v", err)
}
pc, err := server.ListenPacket()
if err != nil {
t.Fatalf("ListenPacket() failed: %v", err)
}
defer pc.Close()
serveErrCh := make(chan error, 1)
go func() {
serveErrCh <- server.ServeQUIC()
}()
defer func() {
_ = server.Stop()
select {
case <-serveErrCh:
case <-time.After(2 * time.Second):
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
clientTLS := mustMakeQUICClientTLSConfig()
clientTLS.ServerName = sni
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), clientTLS, &quic.Config{})
if err != nil {
t.Fatalf("quic.DialAddr() failed: %v", err)
}
defer conn.CloseWithError(DoQCodeNoError, "")
stream, err := conn.OpenStreamSync(ctx)
if err != nil {
t.Fatalf("OpenStreamSync() failed: %v", err)
}
q := new(dns.Msg)
q.SetQuestion("example.com.", dns.TypeA)
q.Id = 0
wire, err := q.Pack()
if err != nil {
t.Fatalf("dns.Msg.Pack() failed: %v", err)
}
if _, err := stream.Write(AddPrefix(wire)); err != nil {
t.Fatalf("stream.Write() failed: %v", err)
}
if err := stream.Close(); err != nil {
t.Fatalf("stream.Close() failed: %v", err)
}
if _, err := readDOQMessage(stream); err != nil {
t.Fatalf("readDOQMessage() failed: %v", err)
}
select {
case <-called:
case <-time.After(5 * time.Second):
t.Fatal("ServeDNS() was not called")
}
select {
case state := <-stateCh:
if state == nil {
t.Fatal("ConnectionState() = nil, want non-nil TLS state for DoQ request")
}
if state.ServerName != sni {
t.Errorf("ConnectionState().ServerName = %q, want %q", state.ServerName, sni)
}
case <-time.After(5 * time.Second):
t.Fatal("did not receive connection state from plugin")
}
}

View File

@@ -1,6 +1,10 @@
package request
import "github.com/miekg/dns"
import (
"crypto/tls"
"github.com/miekg/dns"
)
// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer.
type ScrubWriter struct {
@@ -19,3 +23,16 @@ func (s *ScrubWriter) WriteMsg(m *dns.Msg) error {
state.Scrub(m)
return s.ResponseWriter.WriteMsg(m)
}
// ConnectionState forwards the TLS connection state from the wrapped
// dns.ResponseWriter, if any. Method-set promotion through the embedded
// dns.ResponseWriter does not surface ConnectionState because it belongs to
// the separate dns.ConnectionStater interface, so plugins that need TLS state
// (e.g. SNI) would otherwise lose access to it once ScrubWriter wraps the
// underlying writer.
func (s *ScrubWriter) ConnectionState() *tls.ConnectionState {
if cs, ok := s.ResponseWriter.(dns.ConnectionStater); ok {
return cs.ConnectionState()
}
return nil
}

View File

@@ -1,6 +1,7 @@
package request
import (
"crypto/tls"
"fmt"
"testing"
@@ -20,6 +21,15 @@ func (m *mockResponseWriter) WriteMsg(msg *dns.Msg) error {
return nil
}
// connStateResponseWriter implements both dns.ResponseWriter and
// dns.ConnectionStater for testing forwarding through ScrubWriter.
type connStateResponseWriter struct {
test.ResponseWriter
state *tls.ConnectionState
}
func (c *connStateResponseWriter) ConnectionState() *tls.ConnectionState { return c.state }
func TestScrubWriter(t *testing.T) {
req := new(dns.Msg)
req.SetQuestion("example.com.", dns.TypeA)
@@ -49,3 +59,26 @@ func TestScrubWriter(t *testing.T) {
t.Fatalf("Expected WriteMsg to be called with a message")
}
}
func TestScrubWriterConnectionStateForwarded(t *testing.T) {
want := &tls.ConnectionState{ServerName: "example.test"}
inner := &connStateResponseWriter{state: want}
sw := NewScrubWriter(new(dns.Msg), inner)
cs, ok := dns.ResponseWriter(sw).(dns.ConnectionStater)
if !ok {
t.Fatal("ScrubWriter does not satisfy dns.ConnectionStater")
}
if got := cs.ConnectionState(); got != want {
t.Errorf("ConnectionState() = %v, want %v", got, want)
}
}
func TestScrubWriterConnectionStateNilWhenUnsupported(t *testing.T) {
sw := NewScrubWriter(new(dns.Msg), &mockResponseWriter{})
if got := sw.ConnectionState(); got != nil {
t.Errorf("ConnectionState() = %v, want nil when wrapped writer is not a ConnectionStater", got)
}
}