diff --git a/core/dnsserver/quic.go b/core/dnsserver/quic.go index 04201f220..1dcbe3039 100644 --- a/core/dnsserver/quic.go +++ b/core/dnsserver/quic.go @@ -14,6 +14,7 @@ type DoQWriter struct { remoteAddr net.Addr stream *quic.Stream Msg *dns.Msg + tsigStatus error } func (w *DoQWriter) Write(b []byte) (int, error) { @@ -61,7 +62,7 @@ func AddPrefix(b []byte) (m []byte) { // These methods implement the dns.ResponseWriter interface from Go DNS. -func (w *DoQWriter) TsigStatus() error { return nil } +func (w *DoQWriter) TsigStatus() error { return w.tsigStatus } func (w *DoQWriter) TsigTimersOnly(b bool) {} func (w *DoQWriter) Hijack() {} func (w *DoQWriter) LocalAddr() net.Addr { return w.localAddr } diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index 8589c40a8..d80eba174 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -224,6 +224,14 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) { Msg: req, } + if tsig := req.IsTsig(); tsig != nil { + if s.tsigSecret == nil { + w.tsigStatus = dns.ErrSecret + } else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + w.tsigStatus = dns.ErrSecret + } + } + dnsCtx := context.WithValue(stream.Context(), Key{}, s.Server) dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0) s.ServeDNS(dnsCtx, w, req) diff --git a/core/dnsserver/server_quic_test.go b/core/dnsserver/server_quic_test.go index 19cadd2f0..563ffc578 100644 --- a/core/dnsserver/server_quic_test.go +++ b/core/dnsserver/server_quic_test.go @@ -439,3 +439,15 @@ func TestAcquireQUICWorkerReturnsFalseOnCancelledContext(t *testing.T) { t.Fatal("expected acquireQUICWorker to return false when context is cancelled") } } + +func TestDoQWriterTsigStatusReturnsStoredStatus(t *testing.T) { + want := errors.New("bad tsig") + + w := &DoQWriter{ + tsigStatus: want, + } + + if got := w.TsigStatus(); got != want { + t.Fatalf("TsigStatus() = %v, want %v", got, want) + } +}