From 4c9a80c296d0f42b9c0d1beaa8292dfcfb116da3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 4 Apr 2026 01:58:36 -0700 Subject: [PATCH] core: Add full TSIG verification in gRPC transport (#8006) * core: Add full TSIG verification in gRPC transport This PR add full TSIG verification in gRPC using dns.TsigVerify() so invalid signatures and timestamps are correctly detected instead of only checking key presence. Signed-off-by: Yong Tang * Fix Signed-off-by: Yong Tang * Fix Signed-off-by: Yong Tang --------- Signed-off-by: Yong Tang --- core/dnsserver/server_grpc.go | 4 +- core/dnsserver/server_grpc_test.go | 189 +++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 1 deletion(-) diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go index 3b8edb016..42f54eaff 100644 --- a/core/dnsserver/server_grpc.go +++ b/core/dnsserver/server_grpc.go @@ -198,8 +198,10 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket if tsig := msg.IsTsig(); tsig != nil { if s.tsigSecret == nil { w.tsigStatus = dns.ErrSecret - } else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + } else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { w.tsigStatus = dns.ErrSecret + } else { + w.tsigStatus = dns.TsigVerify(in.GetMsg(), secret, "", false) } } diff --git a/core/dnsserver/server_grpc_test.go b/core/dnsserver/server_grpc_test.go index ae2990f98..9f9e99cb6 100644 --- a/core/dnsserver/server_grpc_test.go +++ b/core/dnsserver/server_grpc_test.go @@ -6,6 +6,7 @@ import ( "errors" "net" "testing" + "time" "github.com/coredns/coredns/pb" "github.com/coredns/coredns/plugin/pkg/transport" @@ -15,6 +16,43 @@ import ( "google.golang.org/grpc/peer" ) +type tsigStatusCheckPlugin struct { + t *testing.T + check func(*testing.T, error) + called *bool +} + +func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" } + +func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.t.Helper() + if p.called != nil { + *p.called = true + } + p.check(p.t, w.TsigStatus()) + + m := new(dns.Msg) + m.SetReply(r) + if err := w.WriteMsg(m); err != nil { + p.t.Fatalf("WriteMsg() failed: %v", err) + } + return dns.RcodeSuccess, nil +} + +func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte { + t.Helper() + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime) + + wire, _, err := dns.TsigGenerate(m, secret, "", false) + if err != nil { + t.Fatalf("dns.TsigGenerate() failed: %v", err) + } + return wire +} + func TestNewServergRPC(t *testing.T) { tests := []struct { name string @@ -460,3 +498,154 @@ func TestGRPCResponseTsigStatusReturnsStoredStatus(t *testing.T) { t.Fatalf("TsigStatus() = %v, want %v", got, want) } } + +func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) { + const keyName = "tsig-key." + const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + const serverSecret = "QUJDREVGR0hJSktMTU5PUA==" + + called := false + + server, err := NewServergRPC("127.0.0.1:0", []*Config{ + testConfig("grpc", tsigStatusCheckPlugin{ + t: t, + called: &called, + check: func(t *testing.T, got error) { + t.Helper() + if got == nil { + t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC") + } + if errors.Is(got, dns.ErrSecret) { + t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got) + } + if errors.Is(got, dns.ErrTime) { + t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got) + } + }, + }), + }) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: serverSecret, + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveTCPAddr() failed: %v", err) + } + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr}) + server.listenAddr = tcpAddr + + wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix()) + + _, err = server.Query(ctx, &pb.DnsPacket{Msg: wire}) + if err != nil { + t.Fatalf("Query() failed: %v", err) + } + + if !called { + t.Fatal("ServeDNS() was not called") + } +} + +func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) { + const keyName = "tsig-key." + const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + + called := false + + server, err := NewServergRPC("127.0.0.1:0", []*Config{ + testConfig("grpc", tsigStatusCheckPlugin{ + t: t, + called: &called, + check: func(t *testing.T, got error) { + t.Helper() + if !errors.Is(got, dns.ErrTime) { + t.Fatalf("TsigStatus() = %v, want %v", got, dns.ErrTime) + } + }, + }), + }) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: secret, + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveTCPAddr() failed: %v", err) + } + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr}) + server.listenAddr = tcpAddr + + wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Add(-10*time.Minute).Unix()) + + _, err = server.Query(ctx, &pb.DnsPacket{Msg: wire}) + if err != nil { + t.Fatalf("Query() failed: %v", err) + } + + if !called { + t.Fatal("ServeDNS() was not called") + } +} + +func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) { + const keyName = "tsig-key." + const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + + called := false + + server, err := NewServergRPC("127.0.0.1:0", []*Config{ + testConfig("grpc", tsigStatusCheckPlugin{ + t: t, + called: &called, + check: func(t *testing.T, got error) { + t.Helper() + if got != nil { + t.Fatalf("TsigStatus() = %v, want nil", got) + } + }, + }), + }) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: secret, + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveTCPAddr() failed: %v", err) + } + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr}) + server.listenAddr = tcpAddr + + wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix()) + + resp, err := server.Query(ctx, &pb.DnsPacket{Msg: wire}) + if err != nil { + t.Fatalf("Query() failed: %v", err) + } + + if !called { + t.Fatal("ServeDNS() was not called") + } + + if len(resp.GetMsg()) == 0 { + t.Fatal("Query() returned empty message") + } + + respMsg := new(dns.Msg) + if err := respMsg.Unpack(resp.GetMsg()); err != nil { + t.Fatalf("Failed to unpack response message: %v", err) + } +}