core: Preserve TSIG status in gRPC transport (#7943)

This commit is contained in:
Yong Tang
2026-03-24 13:46:15 -07:00
committed by GitHub
parent a025712827
commit 384be4cd8e
2 changed files with 23 additions and 1 deletions

View File

@@ -195,6 +195,14 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket
w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg} w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: a, Msg: msg}
if tsig := msg.IsTsig(); tsig != nil {
if s.tsigSecret == nil {
w.tsigStatus = dns.ErrSecret
} else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
w.tsigStatus = dns.ErrSecret
}
}
dnsCtx := context.WithValue(ctx, Key{}, s.Server) dnsCtx := context.WithValue(ctx, Key{}, s.Server)
dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0) dnsCtx = context.WithValue(dnsCtx, LoopKey{}, 0)
s.ServeDNS(dnsCtx, w, msg) s.ServeDNS(dnsCtx, w, msg)
@@ -219,6 +227,7 @@ type gRPCresponse struct {
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
Msg *dns.Msg Msg *dns.Msg
tsigStatus error
} }
// Write is the hack that makes this work. It does not actually write the message // Write is the hack that makes this work. It does not actually write the message
@@ -232,7 +241,7 @@ func (r *gRPCresponse) Write(b []byte) (int, error) {
// These methods implement the dns.ResponseWriter interface from Go DNS. // These methods implement the dns.ResponseWriter interface from Go DNS.
func (r *gRPCresponse) Close() error { return nil } func (r *gRPCresponse) Close() error { return nil }
func (r *gRPCresponse) TsigStatus() error { return nil } func (r *gRPCresponse) TsigStatus() error { return r.tsigStatus }
func (r *gRPCresponse) TsigTimersOnly(b bool) {} func (r *gRPCresponse) TsigTimersOnly(b bool) {}
func (r *gRPCresponse) Hijack() {} func (r *gRPCresponse) Hijack() {}
func (r *gRPCresponse) LocalAddr() net.Addr { return r.localAddr } func (r *gRPCresponse) LocalAddr() net.Addr { return r.localAddr }

View File

@@ -3,6 +3,7 @@ package dnsserver
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"testing" "testing"
@@ -447,3 +448,15 @@ func TestServergRPC_Query_MaxSizeMessage(t *testing.T) {
t.Errorf("Expected no error for max size message, got: %v", err) t.Errorf("Expected no error for max size message, got: %v", err)
} }
} }
func TestGRPCResponseTsigStatusReturnsStoredStatus(t *testing.T) {
want := errors.New("bad tsig")
r := &gRPCresponse{
tsigStatus: want,
}
if got := r.TsigStatus(); got != want {
t.Fatalf("TsigStatus() = %v, want %v", got, want)
}
}