core: Add full TSIG verification in gRPC transport (#8006)

* core: Add full TSIG verification in gRPC transport

This PR add full TSIG verification in gRPC using dns.TsigVerify() so invalid signatures and timestamps are correctly detected instead of only checking key presence.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---------

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang
2026-04-04 01:58:36 -07:00
committed by GitHub
parent 510977c476
commit 4c9a80c296
2 changed files with 192 additions and 1 deletions

View File

@@ -198,8 +198,10 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket
if tsig := msg.IsTsig(); tsig != nil {
if s.tsigSecret == nil {
w.tsigStatus = dns.ErrSecret
} else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
} else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
w.tsigStatus = dns.ErrSecret
} else {
w.tsigStatus = dns.TsigVerify(in.GetMsg(), secret, "", false)
}
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"net"
"testing"
"time"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/transport"
@@ -15,6 +16,43 @@ import (
"google.golang.org/grpc/peer"
)
type tsigStatusCheckPlugin struct {
t *testing.T
check func(*testing.T, error)
called *bool
}
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
p.t.Helper()
if p.called != nil {
*p.called = true
}
p.check(p.t, w.TsigStatus())
m := new(dns.Msg)
m.SetReply(r)
if err := w.WriteMsg(m); err != nil {
p.t.Fatalf("WriteMsg() failed: %v", err)
}
return dns.RcodeSuccess, nil
}
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
t.Helper()
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
wire, _, err := dns.TsigGenerate(m, secret, "", false)
if err != nil {
t.Fatalf("dns.TsigGenerate() failed: %v", err)
}
return wire
}
func TestNewServergRPC(t *testing.T) {
tests := []struct {
name string
@@ -460,3 +498,154 @@ func TestGRPCResponseTsigStatusReturnsStoredStatus(t *testing.T) {
t.Fatalf("TsigStatus() = %v, want %v", got, want)
}
}
func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) {
const keyName = "tsig-key."
const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
const serverSecret = "QUJDREVGR0hJSktMTU5PUA=="
called := false
server, err := NewServergRPC("127.0.0.1:0", []*Config{
testConfig("grpc", tsigStatusCheckPlugin{
t: t,
called: &called,
check: func(t *testing.T, got error) {
t.Helper()
if got == nil {
t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC")
}
if errors.Is(got, dns.ErrSecret) {
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got)
}
if errors.Is(got, dns.ErrTime) {
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got)
}
},
}),
})
if err != nil {
t.Fatalf("NewServergRPC() failed: %v", err)
}
server.tsigSecret = map[string]string{
keyName: serverSecret,
}
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
if err != nil {
t.Fatalf("ResolveTCPAddr() failed: %v", err)
}
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr})
server.listenAddr = tcpAddr
wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix())
_, err = server.Query(ctx, &pb.DnsPacket{Msg: wire})
if err != nil {
t.Fatalf("Query() failed: %v", err)
}
if !called {
t.Fatal("ServeDNS() was not called")
}
}
func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) {
const keyName = "tsig-key."
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
called := false
server, err := NewServergRPC("127.0.0.1:0", []*Config{
testConfig("grpc", tsigStatusCheckPlugin{
t: t,
called: &called,
check: func(t *testing.T, got error) {
t.Helper()
if !errors.Is(got, dns.ErrTime) {
t.Fatalf("TsigStatus() = %v, want %v", got, dns.ErrTime)
}
},
}),
})
if err != nil {
t.Fatalf("NewServergRPC() failed: %v", err)
}
server.tsigSecret = map[string]string{
keyName: secret,
}
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
if err != nil {
t.Fatalf("ResolveTCPAddr() failed: %v", err)
}
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr})
server.listenAddr = tcpAddr
wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Add(-10*time.Minute).Unix())
_, err = server.Query(ctx, &pb.DnsPacket{Msg: wire})
if err != nil {
t.Fatalf("Query() failed: %v", err)
}
if !called {
t.Fatal("ServeDNS() was not called")
}
}
func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) {
const keyName = "tsig-key."
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
called := false
server, err := NewServergRPC("127.0.0.1:0", []*Config{
testConfig("grpc", tsigStatusCheckPlugin{
t: t,
called: &called,
check: func(t *testing.T, got error) {
t.Helper()
if got != nil {
t.Fatalf("TsigStatus() = %v, want nil", got)
}
},
}),
})
if err != nil {
t.Fatalf("NewServergRPC() failed: %v", err)
}
server.tsigSecret = map[string]string{
keyName: secret,
}
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
if err != nil {
t.Fatalf("ResolveTCPAddr() failed: %v", err)
}
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr})
server.listenAddr = tcpAddr
wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix())
resp, err := server.Query(ctx, &pb.DnsPacket{Msg: wire})
if err != nil {
t.Fatalf("Query() failed: %v", err)
}
if !called {
t.Fatal("ServeDNS() was not called")
}
if len(resp.GetMsg()) == 0 {
t.Fatal("Query() returned empty message")
}
respMsg := new(dns.Msg)
if err := respMsg.Unpack(resp.GetMsg()); err != nil {
t.Fatalf("Failed to unpack response message: %v", err)
}
}