core: Add full TSIG verification in QUIC transport (#8007)

* core: Add full TSIG verification in QUIC transport

This PR add full TSIG verification in QUIC using dns.TsigVerify()

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---------

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang
2026-04-04 02:00:23 -07:00
committed by GitHub
parent 4c9a80c296
commit 0e1870d762
2 changed files with 273 additions and 1 deletions

View File

@@ -227,8 +227,10 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
if tsig := req.IsTsig(); tsig != nil { if tsig := req.IsTsig(); tsig != nil {
if s.tsigSecret == nil { if s.tsigSecret == nil {
w.tsigStatus = dns.ErrSecret w.tsigStatus = dns.ErrSecret
} else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { } else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
w.tsigStatus = dns.ErrSecret w.tsigStatus = dns.ErrSecret
} else {
w.tsigStatus = dns.TsigVerify(buf, secret, "", false)
} }
} }

View File

@@ -3,14 +3,60 @@ package dnsserver
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors" "errors"
"math/big"
"testing" "testing"
"time"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
type tsigStatusCheckPlugin struct {
t *testing.T
check func(*testing.T, error)
called chan struct{}
}
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
p.t.Helper()
if p.called != nil {
p.called <- struct{}{}
}
p.check(p.t, w.TsigStatus())
m := new(dns.Msg)
m.SetReply(r)
if err := w.WriteMsg(m); err != nil {
p.t.Fatalf("WriteMsg() failed: %v", err)
}
return dns.RcodeSuccess, nil
}
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
t.Helper()
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Id = 0
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
wire, _, err := dns.TsigGenerate(m, secret, "", false)
if err != nil {
t.Fatalf("dns.TsigGenerate() failed: %v", err)
}
return wire
}
func TestNewServerQUIC(t *testing.T) { func TestNewServerQUIC(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -451,3 +497,227 @@ func TestDoQWriterTsigStatusReturnsStoredStatus(t *testing.T) {
t.Fatalf("TsigStatus() = %v, want %v", got, want) t.Fatalf("TsigStatus() = %v, want %v", got, want)
} }
} }
func TestServerQUIC_ServeQUIC_TSIGBadSigSetsTsigStatus(t *testing.T) {
const keyName = "tsig-key."
const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
const serverSecret = "QUJDREVGR0hJSktMTU5PUA=="
called := make(chan struct{}, 1)
config := testConfig("quic", tsigStatusCheckPlugin{
t: t,
called: called,
check: func(t *testing.T, got error) {
t.Helper()
if got == nil {
t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC")
}
if errors.Is(got, dns.ErrSecret) {
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got)
}
if errors.Is(got, dns.ErrTime) {
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got)
}
},
})
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServerQUIC() failed: %v", err)
}
server.tsigSecret = map[string]string{
keyName: serverSecret,
}
pc, err := server.ListenPacket()
if err != nil {
t.Fatalf("ListenPacket() failed: %v", err)
}
defer pc.Close()
serveErrCh := make(chan error, 1)
go func() {
serveErrCh <- server.ServeQUIC()
}()
defer func() {
_ = server.Stop()
select {
case <-serveErrCh:
case <-time.After(2 * time.Second):
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), mustMakeQUICClientTLSConfig(), &quic.Config{})
if err != nil {
t.Fatalf("quic.DialAddr() failed: %v", err)
}
defer conn.CloseWithError(DoQCodeNoError, "")
stream, err := conn.OpenStreamSync(ctx)
if err != nil {
t.Fatalf("OpenStreamSync() failed: %v", err)
}
wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix())
if _, err := stream.Write(AddPrefix(wire)); err != nil {
t.Fatalf("stream.Write() failed: %v", err)
}
if err := stream.Close(); err != nil {
t.Fatalf("stream.Close() failed: %v", err)
}
respWire, err := readDOQMessage(stream)
if err != nil {
t.Fatalf("readDOQMessage() failed: %v", err)
}
resp := new(dns.Msg)
if err := resp.Unpack(respWire); err != nil {
t.Fatalf("response unpack failed: %v", err)
}
select {
case <-called:
case <-time.After(5 * time.Second):
t.Fatal("ServeDNS() was not called")
}
}
func mustMakeQUICServerTLSConfig(t *testing.T) *tls.Config {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("rsa.GenerateKey() failed: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "127.0.0.1",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: nil,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
if err != nil {
t.Fatalf("x509.CreateCertificate() failed: %v", err)
}
cert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{"doq"},
}
}
func mustMakeQUICClientTLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"doq"},
}
}
func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) {
const keyName = "tsig-key."
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
called := make(chan struct{}, 1)
config := testConfig("quic", tsigStatusCheckPlugin{
t: t,
called: called,
check: func(t *testing.T, got error) {
t.Helper()
if got != nil {
t.Fatalf("TsigStatus() = %v, want nil for valid TSIG MAC", got)
}
},
})
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServerQUIC() failed: %v", err)
}
server.tsigSecret = map[string]string{
keyName: secret,
}
pc, err := server.ListenPacket()
if err != nil {
t.Fatalf("ListenPacket() failed: %v", err)
}
defer pc.Close()
serveErrCh := make(chan error, 1)
go func() {
serveErrCh <- server.ServeQUIC()
}()
defer func() {
_ = server.Stop()
select {
case <-serveErrCh:
case <-time.After(2 * time.Second):
}
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), mustMakeQUICClientTLSConfig(), &quic.Config{})
if err != nil {
t.Fatalf("quic.DialAddr() failed: %v", err)
}
defer conn.CloseWithError(DoQCodeNoError, "")
stream, err := conn.OpenStreamSync(ctx)
if err != nil {
t.Fatalf("OpenStreamSync() failed: %v", err)
}
wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix())
if _, err := stream.Write(AddPrefix(wire)); err != nil {
t.Fatalf("stream.Write() failed: %v", err)
}
if err := stream.Close(); err != nil {
t.Fatalf("stream.Close() failed: %v", err)
}
respWire, err := readDOQMessage(stream)
if err != nil {
t.Fatalf("readDOQMessage() failed: %v", err)
}
resp := new(dns.Msg)
if err := resp.Unpack(respWire); err != nil {
t.Fatalf("response unpack failed: %v", err)
}
select {
case <-called:
case <-time.After(5 * time.Second):
t.Fatal("ServeDNS() was not called")
}
}