From 6b93363b948967f72d916ad862a00585206c642c Mon Sep 17 00:00:00 2001 From: Nicholas Amorim Date: Fri, 29 May 2026 01:45:48 +0300 Subject: [PATCH] feat(core): expose TLS ConnectionState (SNI) for DoQ (#8129) DoQWriter previously stored only the QUIC stream, so plugins reading TLS state via dns.ConnectionStater (e.g. for SNI-based routing or auditing) could not see anything for DoQ connections, even though the underlying QUIC connection carries a full tls.ConnectionState. This change adds a *quic.Conn reference to DoQWriter and wires it in serveQUICStream. It implements dns.ConnectionStater on *DoQWriter, returning the TLS state from the underlying QUIC connection (mirrors the DoT behavior that miekg/dns already provides for *tls.Conn) Forwards ConnectionState through request.ScrubWriter, which wraps every response writer before the plugin chain runs; the embedded dns.ResponseWriter interface does not promote ConnectionState (it belongs to a separate interface), so without this plugins would still see nil for both DoQ and DoT Signed-off-by: Nicholas Amorim --- core/dnsserver/quic.go | 14 ++++ core/dnsserver/quic_test.go | 8 ++ core/dnsserver/server_quic.go | 1 + core/dnsserver/server_quic_test.go | 122 +++++++++++++++++++++++++++++ request/writer.go | 19 ++++- request/writer_test.go | 33 ++++++++ 6 files changed, 196 insertions(+), 1 deletion(-) diff --git a/core/dnsserver/quic.go b/core/dnsserver/quic.go index 7aa7aa48f..029f89206 100644 --- a/core/dnsserver/quic.go +++ b/core/dnsserver/quic.go @@ -1,6 +1,7 @@ package dnsserver import ( + "crypto/tls" "encoding/binary" "errors" "net" @@ -13,6 +14,7 @@ type DoQWriter struct { localAddr net.Addr remoteAddr net.Addr stream *quic.Stream + conn *quic.Conn Msg *dns.Msg tsigStatus error } @@ -68,3 +70,15 @@ func (w *DoQWriter) Hijack() {} func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr } func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr } func (w *DoQWriter) Network() string { return "" } + +// ConnectionState implements the dns.ConnectionStater interface, exposing the +// TLS state of the underlying QUIC connection (e.g. for plugins that need to +// read the SNI ServerName). Mirrors the DoT behavior already provided by the +// miekg/dns response writer. +func (w *DoQWriter) ConnectionState() *tls.ConnectionState { + if w.conn == nil { + return nil + } + state := w.conn.ConnectionState().TLS + return &state +} diff --git a/core/dnsserver/quic_test.go b/core/dnsserver/quic_test.go index 7e7301906..276a0607c 100644 --- a/core/dnsserver/quic_test.go +++ b/core/dnsserver/quic_test.go @@ -48,3 +48,11 @@ func TestDoQWriter_ResponseWriterMethods(t *testing.T) { t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr) } } + +func TestDoQWriter_ConnectionStateNilConn(t *testing.T) { + writer := &DoQWriter{} + + if state := writer.ConnectionState(); state != nil { + t.Errorf("ConnectionState() = %v, want nil when conn is unset", state) + } +} diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index 0deb80c37..6a5c0929e 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -221,6 +221,7 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) { localAddr: conn.LocalAddr(), remoteAddr: conn.RemoteAddr(), stream: stream, + conn: conn, Msg: req, } diff --git a/core/dnsserver/server_quic_test.go b/core/dnsserver/server_quic_test.go index 28d8931be..76d592d86 100644 --- a/core/dnsserver/server_quic_test.go +++ b/core/dnsserver/server_quic_test.go @@ -683,3 +683,125 @@ func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) { t.Fatal("ServeDNS() was not called") } } + +// connectionStateCapturePlugin records the *tls.ConnectionState observed via +// the dns.ConnectionStater interface implemented by the DoQ response writer. +type connectionStateCapturePlugin struct { + t *testing.T + called chan struct{} + state chan *tls.ConnectionState +} + +func (p connectionStateCapturePlugin) Name() string { return "connection-state-capture" } + +func (p connectionStateCapturePlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.t.Helper() + + if cs, ok := w.(dns.ConnectionStater); ok { + p.state <- cs.ConnectionState() + } else { + p.state <- nil + } + if p.called != nil { + p.called <- struct{}{} + } + + m := new(dns.Msg) + m.SetReply(r) + if err := w.WriteMsg(m); err != nil { + p.t.Fatalf("WriteMsg() failed: %v", err) + } + return dns.RcodeSuccess, nil +} + +func TestServerQUIC_ServeQUIC_ConnectionStateExposesSNI(t *testing.T) { + const sni = "doq.example.com" + + called := make(chan struct{}, 1) + stateCh := make(chan *tls.ConnectionState, 1) + + config := testConfig("quic", connectionStateCapturePlugin{ + t: t, + called: called, + state: stateCh, + }) + config.TLSConfig = mustMakeQUICServerTLSConfig(t) + + server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config}) + if err != nil { + t.Fatalf("NewServerQUIC() failed: %v", err) + } + + pc, err := server.ListenPacket() + if err != nil { + t.Fatalf("ListenPacket() failed: %v", err) + } + defer pc.Close() + + serveErrCh := make(chan error, 1) + go func() { + serveErrCh <- server.ServeQUIC() + }() + + defer func() { + _ = server.Stop() + select { + case <-serveErrCh: + case <-time.After(2 * time.Second): + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + clientTLS := mustMakeQUICClientTLSConfig() + clientTLS.ServerName = sni + + conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), clientTLS, &quic.Config{}) + if err != nil { + t.Fatalf("quic.DialAddr() failed: %v", err) + } + defer conn.CloseWithError(DoQCodeNoError, "") + + stream, err := conn.OpenStreamSync(ctx) + if err != nil { + t.Fatalf("OpenStreamSync() failed: %v", err) + } + + q := new(dns.Msg) + q.SetQuestion("example.com.", dns.TypeA) + q.Id = 0 + wire, err := q.Pack() + if err != nil { + t.Fatalf("dns.Msg.Pack() failed: %v", err) + } + + if _, err := stream.Write(AddPrefix(wire)); err != nil { + t.Fatalf("stream.Write() failed: %v", err) + } + if err := stream.Close(); err != nil { + t.Fatalf("stream.Close() failed: %v", err) + } + + if _, err := readDOQMessage(stream); err != nil { + t.Fatalf("readDOQMessage() failed: %v", err) + } + + select { + case <-called: + case <-time.After(5 * time.Second): + t.Fatal("ServeDNS() was not called") + } + + select { + case state := <-stateCh: + if state == nil { + t.Fatal("ConnectionState() = nil, want non-nil TLS state for DoQ request") + } + if state.ServerName != sni { + t.Errorf("ConnectionState().ServerName = %q, want %q", state.ServerName, sni) + } + case <-time.After(5 * time.Second): + t.Fatal("did not receive connection state from plugin") + } +} diff --git a/request/writer.go b/request/writer.go index 587b3b5d8..c560a9ac6 100644 --- a/request/writer.go +++ b/request/writer.go @@ -1,6 +1,10 @@ package request -import "github.com/miekg/dns" +import ( + "crypto/tls" + + "github.com/miekg/dns" +) // ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer. type ScrubWriter struct { @@ -19,3 +23,16 @@ func (s *ScrubWriter) WriteMsg(m *dns.Msg) error { state.Scrub(m) return s.ResponseWriter.WriteMsg(m) } + +// ConnectionState forwards the TLS connection state from the wrapped +// dns.ResponseWriter, if any. Method-set promotion through the embedded +// dns.ResponseWriter does not surface ConnectionState because it belongs to +// the separate dns.ConnectionStater interface, so plugins that need TLS state +// (e.g. SNI) would otherwise lose access to it once ScrubWriter wraps the +// underlying writer. +func (s *ScrubWriter) ConnectionState() *tls.ConnectionState { + if cs, ok := s.ResponseWriter.(dns.ConnectionStater); ok { + return cs.ConnectionState() + } + return nil +} diff --git a/request/writer_test.go b/request/writer_test.go index 2b6a918f3..b475ad3e3 100644 --- a/request/writer_test.go +++ b/request/writer_test.go @@ -1,6 +1,7 @@ package request import ( + "crypto/tls" "fmt" "testing" @@ -20,6 +21,15 @@ func (m *mockResponseWriter) WriteMsg(msg *dns.Msg) error { return nil } +// connStateResponseWriter implements both dns.ResponseWriter and +// dns.ConnectionStater for testing forwarding through ScrubWriter. +type connStateResponseWriter struct { + test.ResponseWriter + state *tls.ConnectionState +} + +func (c *connStateResponseWriter) ConnectionState() *tls.ConnectionState { return c.state } + func TestScrubWriter(t *testing.T) { req := new(dns.Msg) req.SetQuestion("example.com.", dns.TypeA) @@ -49,3 +59,26 @@ func TestScrubWriter(t *testing.T) { t.Fatalf("Expected WriteMsg to be called with a message") } } + +func TestScrubWriterConnectionStateForwarded(t *testing.T) { + want := &tls.ConnectionState{ServerName: "example.test"} + inner := &connStateResponseWriter{state: want} + + sw := NewScrubWriter(new(dns.Msg), inner) + + cs, ok := dns.ResponseWriter(sw).(dns.ConnectionStater) + if !ok { + t.Fatal("ScrubWriter does not satisfy dns.ConnectionStater") + } + if got := cs.ConnectionState(); got != want { + t.Errorf("ConnectionState() = %v, want %v", got, want) + } +} + +func TestScrubWriterConnectionStateNilWhenUnsupported(t *testing.T) { + sw := NewScrubWriter(new(dns.Msg), &mockResponseWriter{}) + + if got := sw.ConnectionState(); got != nil { + t.Errorf("ConnectionState() = %v, want nil when wrapped writer is not a ConnectionStater", got) + } +}