mirror of
https://github.com/coredns/coredns.git
synced 2026-06-02 07:10:24 -04:00
feat(core): expose TLS ConnectionState (SNI) for DoQ (#8129)
DoQWriter previously stored only the QUIC stream, so plugins reading TLS state via dns.ConnectionStater (e.g. for SNI-based routing or auditing) could not see anything for DoQ connections, even though the underlying QUIC connection carries a full tls.ConnectionState. This change adds a *quic.Conn reference to DoQWriter and wires it in serveQUICStream. It implements dns.ConnectionStater on *DoQWriter, returning the TLS state from the underlying QUIC connection (mirrors the DoT behavior that miekg/dns already provides for *tls.Conn) Forwards ConnectionState through request.ScrubWriter, which wraps every response writer before the plugin chain runs; the embedded dns.ResponseWriter interface does not promote ConnectionState (it belongs to a separate interface), so without this plugins would still see nil for both DoQ and DoT Signed-off-by: Nicholas Amorim <nicholas@santos.ee>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package dnsserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
@@ -13,6 +14,7 @@ type DoQWriter struct {
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
stream *quic.Stream
|
||||
conn *quic.Conn
|
||||
Msg *dns.Msg
|
||||
tsigStatus error
|
||||
}
|
||||
@@ -68,3 +70,15 @@ func (w *DoQWriter) Hijack() {}
|
||||
func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr }
|
||||
func (w *DoQWriter) RemoteAddr() net.Addr { return w.remoteAddr }
|
||||
func (w *DoQWriter) Network() string { return "" }
|
||||
|
||||
// ConnectionState implements the dns.ConnectionStater interface, exposing the
|
||||
// TLS state of the underlying QUIC connection (e.g. for plugins that need to
|
||||
// read the SNI ServerName). Mirrors the DoT behavior already provided by the
|
||||
// miekg/dns response writer.
|
||||
func (w *DoQWriter) ConnectionState() *tls.ConnectionState {
|
||||
if w.conn == nil {
|
||||
return nil
|
||||
}
|
||||
state := w.conn.ConnectionState().TLS
|
||||
return &state
|
||||
}
|
||||
|
||||
@@ -48,3 +48,11 @@ func TestDoQWriter_ResponseWriterMethods(t *testing.T) {
|
||||
t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoQWriter_ConnectionStateNilConn(t *testing.T) {
|
||||
writer := &DoQWriter{}
|
||||
|
||||
if state := writer.ConnectionState(); state != nil {
|
||||
t.Errorf("ConnectionState() = %v, want nil when conn is unset", state)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,6 +221,7 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
|
||||
localAddr: conn.LocalAddr(),
|
||||
remoteAddr: conn.RemoteAddr(),
|
||||
stream: stream,
|
||||
conn: conn,
|
||||
Msg: req,
|
||||
}
|
||||
|
||||
|
||||
@@ -683,3 +683,125 @@ func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) {
|
||||
t.Fatal("ServeDNS() was not called")
|
||||
}
|
||||
}
|
||||
|
||||
// connectionStateCapturePlugin records the *tls.ConnectionState observed via
|
||||
// the dns.ConnectionStater interface implemented by the DoQ response writer.
|
||||
type connectionStateCapturePlugin struct {
|
||||
t *testing.T
|
||||
called chan struct{}
|
||||
state chan *tls.ConnectionState
|
||||
}
|
||||
|
||||
func (p connectionStateCapturePlugin) Name() string { return "connection-state-capture" }
|
||||
|
||||
func (p connectionStateCapturePlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
p.t.Helper()
|
||||
|
||||
if cs, ok := w.(dns.ConnectionStater); ok {
|
||||
p.state <- cs.ConnectionState()
|
||||
} else {
|
||||
p.state <- nil
|
||||
}
|
||||
if p.called != nil {
|
||||
p.called <- struct{}{}
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||
}
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
func TestServerQUIC_ServeQUIC_ConnectionStateExposesSNI(t *testing.T) {
|
||||
const sni = "doq.example.com"
|
||||
|
||||
called := make(chan struct{}, 1)
|
||||
stateCh := make(chan *tls.ConnectionState, 1)
|
||||
|
||||
config := testConfig("quic", connectionStateCapturePlugin{
|
||||
t: t,
|
||||
called: called,
|
||||
state: stateCh,
|
||||
})
|
||||
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
|
||||
|
||||
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
pc, err := server.ListenPacket()
|
||||
if err != nil {
|
||||
t.Fatalf("ListenPacket() failed: %v", err)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
serveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
serveErrCh <- server.ServeQUIC()
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
_ = server.Stop()
|
||||
select {
|
||||
case <-serveErrCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientTLS := mustMakeQUICClientTLSConfig()
|
||||
clientTLS.ServerName = sni
|
||||
|
||||
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), clientTLS, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("quic.DialAddr() failed: %v", err)
|
||||
}
|
||||
defer conn.CloseWithError(DoQCodeNoError, "")
|
||||
|
||||
stream, err := conn.OpenStreamSync(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenStreamSync() failed: %v", err)
|
||||
}
|
||||
|
||||
q := new(dns.Msg)
|
||||
q.SetQuestion("example.com.", dns.TypeA)
|
||||
q.Id = 0
|
||||
wire, err := q.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("dns.Msg.Pack() failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := stream.Write(AddPrefix(wire)); err != nil {
|
||||
t.Fatalf("stream.Write() failed: %v", err)
|
||||
}
|
||||
if err := stream.Close(); err != nil {
|
||||
t.Fatalf("stream.Close() failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := readDOQMessage(stream); err != nil {
|
||||
t.Fatalf("readDOQMessage() failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-called:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("ServeDNS() was not called")
|
||||
}
|
||||
|
||||
select {
|
||||
case state := <-stateCh:
|
||||
if state == nil {
|
||||
t.Fatal("ConnectionState() = nil, want non-nil TLS state for DoQ request")
|
||||
}
|
||||
if state.ServerName != sni {
|
||||
t.Errorf("ConnectionState().ServerName = %q, want %q", state.ServerName, sni)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("did not receive connection state from plugin")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user