From 0e1870d762e1deb6279e3e5f470708379f5f3e79 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 4 Apr 2026 02:00:23 -0700 Subject: [PATCH] core: Add full TSIG verification in QUIC transport (#8007) * core: Add full TSIG verification in QUIC transport This PR add full TSIG verification in QUIC using dns.TsigVerify() Signed-off-by: Yong Tang * Fix Signed-off-by: Yong Tang --------- Signed-off-by: Yong Tang --- core/dnsserver/server_quic.go | 4 +- core/dnsserver/server_quic_test.go | 270 +++++++++++++++++++++++++++++ 2 files changed, 273 insertions(+), 1 deletion(-) diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index 5163195c0..0deb80c37 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -227,8 +227,10 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) { if tsig := req.IsTsig(); tsig != nil { if s.tsigSecret == nil { w.tsigStatus = dns.ErrSecret - } else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + } else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { w.tsigStatus = dns.ErrSecret + } else { + w.tsigStatus = dns.TsigVerify(buf, secret, "", false) } } diff --git a/core/dnsserver/server_quic_test.go b/core/dnsserver/server_quic_test.go index 563ffc578..f25fa9a2b 100644 --- a/core/dnsserver/server_quic_test.go +++ b/core/dnsserver/server_quic_test.go @@ -3,14 +3,60 @@ package dnsserver import ( "bytes" "context" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" + "math/big" "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/transport" "github.com/miekg/dns" "github.com/quic-go/quic-go" ) +type tsigStatusCheckPlugin struct { + t *testing.T + check func(*testing.T, error) + called chan struct{} +} + +func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" } + +func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.t.Helper() + if p.called != nil { + p.called <- struct{}{} + } + p.check(p.t, w.TsigStatus()) + + m := new(dns.Msg) + m.SetReply(r) + if err := w.WriteMsg(m); err != nil { + p.t.Fatalf("WriteMsg() failed: %v", err) + } + return dns.RcodeSuccess, nil +} + +func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte { + t.Helper() + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Id = 0 + m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime) + + wire, _, err := dns.TsigGenerate(m, secret, "", false) + if err != nil { + t.Fatalf("dns.TsigGenerate() failed: %v", err) + } + return wire +} + func TestNewServerQUIC(t *testing.T) { tests := []struct { name string @@ -451,3 +497,227 @@ func TestDoQWriterTsigStatusReturnsStoredStatus(t *testing.T) { t.Fatalf("TsigStatus() = %v, want %v", got, want) } } + +func TestServerQUIC_ServeQUIC_TSIGBadSigSetsTsigStatus(t *testing.T) { + const keyName = "tsig-key." + const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + const serverSecret = "QUJDREVGR0hJSktMTU5PUA==" + + called := make(chan struct{}, 1) + + config := testConfig("quic", tsigStatusCheckPlugin{ + t: t, + called: called, + check: func(t *testing.T, got error) { + t.Helper() + if got == nil { + t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC") + } + if errors.Is(got, dns.ErrSecret) { + t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got) + } + if errors.Is(got, dns.ErrTime) { + t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got) + } + }, + }) + config.TLSConfig = mustMakeQUICServerTLSConfig(t) + + server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config}) + if err != nil { + t.Fatalf("NewServerQUIC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: serverSecret, + } + + pc, err := server.ListenPacket() + if err != nil { + t.Fatalf("ListenPacket() failed: %v", err) + } + defer pc.Close() + + serveErrCh := make(chan error, 1) + go func() { + serveErrCh <- server.ServeQUIC() + }() + + defer func() { + _ = server.Stop() + select { + case <-serveErrCh: + case <-time.After(2 * time.Second): + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), mustMakeQUICClientTLSConfig(), &quic.Config{}) + if err != nil { + t.Fatalf("quic.DialAddr() failed: %v", err) + } + defer conn.CloseWithError(DoQCodeNoError, "") + + stream, err := conn.OpenStreamSync(ctx) + if err != nil { + t.Fatalf("OpenStreamSync() failed: %v", err) + } + + wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix()) + + if _, err := stream.Write(AddPrefix(wire)); err != nil { + t.Fatalf("stream.Write() failed: %v", err) + } + if err := stream.Close(); err != nil { + t.Fatalf("stream.Close() failed: %v", err) + } + + respWire, err := readDOQMessage(stream) + if err != nil { + t.Fatalf("readDOQMessage() failed: %v", err) + } + + resp := new(dns.Msg) + if err := resp.Unpack(respWire); err != nil { + t.Fatalf("response unpack failed: %v", err) + } + + select { + case <-called: + case <-time.After(5 * time.Second): + t.Fatal("ServeDNS() was not called") + } +} + +func mustMakeQUICServerTLSConfig(t *testing.T) *tls.Config { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey() failed: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: nil, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("x509.CreateCertificate() failed: %v", err) + } + + cert := tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: priv, + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"doq"}, + } +} + +func mustMakeQUICClientTLSConfig() *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"doq"}, + } +} + +func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) { + const keyName = "tsig-key." + const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + + called := make(chan struct{}, 1) + + config := testConfig("quic", tsigStatusCheckPlugin{ + t: t, + called: called, + check: func(t *testing.T, got error) { + t.Helper() + if got != nil { + t.Fatalf("TsigStatus() = %v, want nil for valid TSIG MAC", got) + } + }, + }) + config.TLSConfig = mustMakeQUICServerTLSConfig(t) + + server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config}) + if err != nil { + t.Fatalf("NewServerQUIC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: secret, + } + + pc, err := server.ListenPacket() + if err != nil { + t.Fatalf("ListenPacket() failed: %v", err) + } + defer pc.Close() + + serveErrCh := make(chan error, 1) + go func() { + serveErrCh <- server.ServeQUIC() + }() + + defer func() { + _ = server.Stop() + select { + case <-serveErrCh: + case <-time.After(2 * time.Second): + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), mustMakeQUICClientTLSConfig(), &quic.Config{}) + if err != nil { + t.Fatalf("quic.DialAddr() failed: %v", err) + } + defer conn.CloseWithError(DoQCodeNoError, "") + + stream, err := conn.OpenStreamSync(ctx) + if err != nil { + t.Fatalf("OpenStreamSync() failed: %v", err) + } + + wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix()) + + if _, err := stream.Write(AddPrefix(wire)); err != nil { + t.Fatalf("stream.Write() failed: %v", err) + } + if err := stream.Close(); err != nil { + t.Fatalf("stream.Close() failed: %v", err) + } + + respWire, err := readDOQMessage(stream) + if err != nil { + t.Fatalf("readDOQMessage() failed: %v", err) + } + + resp := new(dns.Msg) + if err := resp.Unpack(respWire); err != nil { + t.Fatalf("response unpack failed: %v", err) + } + + select { + case <-called: + case <-time.After(5 * time.Second): + t.Fatal("ServeDNS() was not called") + } +}