mirror of
https://github.com/coredns/coredns.git
synced 2026-06-01 23:00:23 -04:00
feat(core): expose TLS ConnectionState (SNI) for DoQ (#8129)
DoQWriter previously stored only the QUIC stream, so plugins reading TLS state via dns.ConnectionStater (e.g. for SNI-based routing or auditing) could not see anything for DoQ connections, even though the underlying QUIC connection carries a full tls.ConnectionState. This change adds a *quic.Conn reference to DoQWriter and wires it in serveQUICStream. It implements dns.ConnectionStater on *DoQWriter, returning the TLS state from the underlying QUIC connection (mirrors the DoT behavior that miekg/dns already provides for *tls.Conn) Forwards ConnectionState through request.ScrubWriter, which wraps every response writer before the plugin chain runs; the embedded dns.ResponseWriter interface does not promote ConnectionState (it belongs to a separate interface), so without this plugins would still see nil for both DoQ and DoT Signed-off-by: Nicholas Amorim <nicholas@santos.ee>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package dnsserver
|
package dnsserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
@@ -13,6 +14,7 @@ type DoQWriter struct {
|
|||||||
localAddr net.Addr
|
localAddr net.Addr
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
stream *quic.Stream
|
stream *quic.Stream
|
||||||
|
conn *quic.Conn
|
||||||
Msg *dns.Msg
|
Msg *dns.Msg
|
||||||
tsigStatus error
|
tsigStatus error
|
||||||
}
|
}
|
||||||
@@ -68,3 +70,15 @@ func (w *DoQWriter) Hijack() {}
|
|||||||
func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr }
|
func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr }
|
||||||
func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr }
|
func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr }
|
||||||
func (w *DoQWriter) Network() string { return "" }
|
func (w *DoQWriter) Network() string { return "" }
|
||||||
|
|
||||||
|
// ConnectionState implements the dns.ConnectionStater interface, exposing the
|
||||||
|
// TLS state of the underlying QUIC connection (e.g. for plugins that need to
|
||||||
|
// read the SNI ServerName). Mirrors the DoT behavior already provided by the
|
||||||
|
// miekg/dns response writer.
|
||||||
|
func (w *DoQWriter) ConnectionState() *tls.ConnectionState {
|
||||||
|
if w.conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state := w.conn.ConnectionState().TLS
|
||||||
|
return &state
|
||||||
|
}
|
||||||
|
|||||||
@@ -48,3 +48,11 @@ func TestDoQWriter_ResponseWriterMethods(t *testing.T) {
|
|||||||
t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr)
|
t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDoQWriter_ConnectionStateNilConn(t *testing.T) {
|
||||||
|
writer := &DoQWriter{}
|
||||||
|
|
||||||
|
if state := writer.ConnectionState(); state != nil {
|
||||||
|
t.Errorf("ConnectionState() = %v, want nil when conn is unset", state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
|
|||||||
localAddr: conn.LocalAddr(),
|
localAddr: conn.LocalAddr(),
|
||||||
remoteAddr: conn.RemoteAddr(),
|
remoteAddr: conn.RemoteAddr(),
|
||||||
stream: stream,
|
stream: stream,
|
||||||
|
conn: conn,
|
||||||
Msg: req,
|
Msg: req,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -683,3 +683,125 @@ func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) {
|
|||||||
t.Fatal("ServeDNS() was not called")
|
t.Fatal("ServeDNS() was not called")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connectionStateCapturePlugin records the *tls.ConnectionState observed via
|
||||||
|
// the dns.ConnectionStater interface implemented by the DoQ response writer.
|
||||||
|
type connectionStateCapturePlugin struct {
|
||||||
|
t *testing.T
|
||||||
|
called chan struct{}
|
||||||
|
state chan *tls.ConnectionState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p connectionStateCapturePlugin) Name() string { return "connection-state-capture" }
|
||||||
|
|
||||||
|
func (p connectionStateCapturePlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
p.t.Helper()
|
||||||
|
|
||||||
|
if cs, ok := w.(dns.ConnectionStater); ok {
|
||||||
|
p.state <- cs.ConnectionState()
|
||||||
|
} else {
|
||||||
|
p.state <- nil
|
||||||
|
}
|
||||||
|
if p.called != nil {
|
||||||
|
p.called <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||||
|
}
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerQUIC_ServeQUIC_ConnectionStateExposesSNI(t *testing.T) {
|
||||||
|
const sni = "doq.example.com"
|
||||||
|
|
||||||
|
called := make(chan struct{}, 1)
|
||||||
|
stateCh := make(chan *tls.ConnectionState, 1)
|
||||||
|
|
||||||
|
config := testConfig("quic", connectionStateCapturePlugin{
|
||||||
|
t: t,
|
||||||
|
called: called,
|
||||||
|
state: stateCh,
|
||||||
|
})
|
||||||
|
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
|
||||||
|
|
||||||
|
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, err := server.ListenPacket()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenPacket() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer pc.Close()
|
||||||
|
|
||||||
|
serveErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
serveErrCh <- server.ServeQUIC()
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
select {
|
||||||
|
case <-serveErrCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
clientTLS := mustMakeQUICClientTLSConfig()
|
||||||
|
clientTLS.ServerName = sni
|
||||||
|
|
||||||
|
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), clientTLS, &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("quic.DialAddr() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.CloseWithError(DoQCodeNoError, "")
|
||||||
|
|
||||||
|
stream, err := conn.OpenStreamSync(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenStreamSync() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q := new(dns.Msg)
|
||||||
|
q.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
q.Id = 0
|
||||||
|
wire, err := q.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dns.Msg.Pack() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := stream.Write(AddPrefix(wire)); err != nil {
|
||||||
|
t.Fatalf("stream.Write() failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := stream.Close(); err != nil {
|
||||||
|
t.Fatalf("stream.Close() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := readDOQMessage(stream); err != nil {
|
||||||
|
t.Fatalf("readDOQMessage() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-called:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("ServeDNS() was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case state := <-stateCh:
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("ConnectionState() = nil, want non-nil TLS state for DoQ request")
|
||||||
|
}
|
||||||
|
if state.ServerName != sni {
|
||||||
|
t.Errorf("ConnectionState().ServerName = %q, want %q", state.ServerName, sni)
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("did not receive connection state from plugin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package request
|
package request
|
||||||
|
|
||||||
import "github.com/miekg/dns"
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer.
|
// ScrubWriter will, when writing the message, call scrub to make it fit the client's buffer.
|
||||||
type ScrubWriter struct {
|
type ScrubWriter struct {
|
||||||
@@ -19,3 +23,16 @@ func (s *ScrubWriter) WriteMsg(m *dns.Msg) error {
|
|||||||
state.Scrub(m)
|
state.Scrub(m)
|
||||||
return s.ResponseWriter.WriteMsg(m)
|
return s.ResponseWriter.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnectionState forwards the TLS connection state from the wrapped
|
||||||
|
// dns.ResponseWriter, if any. Method-set promotion through the embedded
|
||||||
|
// dns.ResponseWriter does not surface ConnectionState because it belongs to
|
||||||
|
// the separate dns.ConnectionStater interface, so plugins that need TLS state
|
||||||
|
// (e.g. SNI) would otherwise lose access to it once ScrubWriter wraps the
|
||||||
|
// underlying writer.
|
||||||
|
func (s *ScrubWriter) ConnectionState() *tls.ConnectionState {
|
||||||
|
if cs, ok := s.ResponseWriter.(dns.ConnectionStater); ok {
|
||||||
|
return cs.ConnectionState()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package request
|
package request
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -20,6 +21,15 @@ func (m *mockResponseWriter) WriteMsg(msg *dns.Msg) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connStateResponseWriter implements both dns.ResponseWriter and
|
||||||
|
// dns.ConnectionStater for testing forwarding through ScrubWriter.
|
||||||
|
type connStateResponseWriter struct {
|
||||||
|
test.ResponseWriter
|
||||||
|
state *tls.ConnectionState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connStateResponseWriter) ConnectionState() *tls.ConnectionState { return c.state }
|
||||||
|
|
||||||
func TestScrubWriter(t *testing.T) {
|
func TestScrubWriter(t *testing.T) {
|
||||||
req := new(dns.Msg)
|
req := new(dns.Msg)
|
||||||
req.SetQuestion("example.com.", dns.TypeA)
|
req.SetQuestion("example.com.", dns.TypeA)
|
||||||
@@ -49,3 +59,26 @@ func TestScrubWriter(t *testing.T) {
|
|||||||
t.Fatalf("Expected WriteMsg to be called with a message")
|
t.Fatalf("Expected WriteMsg to be called with a message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScrubWriterConnectionStateForwarded(t *testing.T) {
|
||||||
|
want := &tls.ConnectionState{ServerName: "example.test"}
|
||||||
|
inner := &connStateResponseWriter{state: want}
|
||||||
|
|
||||||
|
sw := NewScrubWriter(new(dns.Msg), inner)
|
||||||
|
|
||||||
|
cs, ok := dns.ResponseWriter(sw).(dns.ConnectionStater)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("ScrubWriter does not satisfy dns.ConnectionStater")
|
||||||
|
}
|
||||||
|
if got := cs.ConnectionState(); got != want {
|
||||||
|
t.Errorf("ConnectionState() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScrubWriterConnectionStateNilWhenUnsupported(t *testing.T) {
|
||||||
|
sw := NewScrubWriter(new(dns.Msg), &mockResponseWriter{})
|
||||||
|
|
||||||
|
if got := sw.ConnectionState(); got != nil {
|
||||||
|
t.Errorf("ConnectionState() = %v, want nil when wrapped writer is not a ConnectionStater", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user