From 384be4cd8e5f0d6766dc96bb95f80d39d28cee75 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 24 Mar 2026 13:46:15 -0700 Subject: [PATCH] core: Preserve TSIG status in gRPC transport (#7943) --- core/dnsserver/server_grpc.go | 11 ++++++++++- core/dnsserver/server_grpc_test.go | 13 +++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go index 0fd377072..0b51ab0e2 100644 --- a/core/dnsserver/server_grpc.go +++ b/core/dnsserver/server_grpc.go @@ -195,6 +195,14 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg} + if tsig := msg.IsTsig(); tsig != nil { + if s.tsigSecret == nil { + w.tsigStatus = dns.ErrSecret + } else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + w.tsigStatus = dns.ErrSecret + } + } + dnsCtx := context.WithValue(ctx, Key{}, s.Server) dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0) s.ServeDNS(dnsCtx, w, msg) @@ -219,6 +227,7 @@ type gRPCresponse struct { localAddr net.Addr remoteAddr net.Addr Msg *dns.Msg + tsigStatus error } // Write is the hack that makes this work. It does not actually write the message @@ -232,7 +241,7 @@ func (r *gRPCresponse) Write(b []byte) (int, error) { // These methods implement the dns.ResponseWriter interface from Go DNS. func (r *gRPCresponse) Close() error { return nil } -func (r *gRPCresponse) TsigStatus() error { return nil } +func (r *gRPCresponse) TsigStatus() error { return r.tsigStatus } func (r *gRPCresponse) TsigTimersOnly(b bool) {} func (r *gRPCresponse) Hijack() {} func (r *gRPCresponse) LocalAddr() net.Addr { return r.localAddr } diff --git a/core/dnsserver/server_grpc_test.go b/core/dnsserver/server_grpc_test.go index bfb095cfc..ae2990f98 100644 --- a/core/dnsserver/server_grpc_test.go +++ b/core/dnsserver/server_grpc_test.go @@ -3,6 +3,7 @@ package dnsserver import ( "context" "crypto/tls" + "errors" "net" "testing" @@ -447,3 +448,15 @@ func TestServergRPC_Query_MaxSizeMessage(t *testing.T) { t.Errorf("Expected no error for max size message, got: %v", err) } } + +func TestGRPCResponseTsigStatusReturnsStoredStatus(t *testing.T) { + want := errors.New("bad tsig") + + r := &gRPCresponse{ + tsigStatus: want, + } + + if got := r.TsigStatus(); got != want { + t.Fatalf("TsigStatus() = %v, want %v", got, want) + } +}