diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go index 3b8edb016..42f54eaff 100644 --- a/core/dnsserver/server_grpc.go +++ b/core/dnsserver/server_grpc.go @@ -198,8 +198,10 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket if tsig := msg.IsTsig(); tsig != nil { if s.tsigSecret == nil { w.tsigStatus = dns.ErrSecret - } else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { + } else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok { w.tsigStatus = dns.ErrSecret + } else { + w.tsigStatus = dns.TsigVerify(in.GetMsg(), secret, "", false) } } diff --git a/core/dnsserver/server_grpc_test.go b/core/dnsserver/server_grpc_test.go index ae2990f98..9f9e99cb6 100644 --- a/core/dnsserver/server_grpc_test.go +++ b/core/dnsserver/server_grpc_test.go @@ -6,6 +6,7 @@ import ( "errors" "net" "testing" + "time" "github.com/coredns/coredns/pb" "github.com/coredns/coredns/plugin/pkg/transport" @@ -15,6 +16,43 @@ import ( "google.golang.org/grpc/peer" ) +type tsigStatusCheckPlugin struct { + t *testing.T + check func(*testing.T, error) + called *bool +} + +func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" } + +func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.t.Helper() + if p.called != nil { + *p.called = true + } + p.check(p.t, w.TsigStatus()) + + m := new(dns.Msg) + m.SetReply(r) + if err := w.WriteMsg(m); err != nil { + p.t.Fatalf("WriteMsg() failed: %v", err) + } + return dns.RcodeSuccess, nil +} + +func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte { + t.Helper() + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime) + + wire, _, err := dns.TsigGenerate(m, secret, "", false) + if err != nil { + t.Fatalf("dns.TsigGenerate() failed: %v", err) + } + return wire +} + func TestNewServergRPC(t *testing.T) { tests := []struct { name string @@ -460,3 +498,154 @@ func TestGRPCResponseTsigStatusReturnsStoredStatus(t *testing.T) { t.Fatalf("TsigStatus() = %v, want %v", got, want) } } + +func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) { + const keyName = "tsig-key." + const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + const serverSecret = "QUJDREVGR0hJSktMTU5PUA==" + + called := false + + server, err := NewServergRPC("127.0.0.1:0", []*Config{ + testConfig("grpc", tsigStatusCheckPlugin{ + t: t, + called: &called, + check: func(t *testing.T, got error) { + t.Helper() + if got == nil { + t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC") + } + if errors.Is(got, dns.ErrSecret) { + t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got) + } + if errors.Is(got, dns.ErrTime) { + t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got) + } + }, + }), + }) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: serverSecret, + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveTCPAddr() failed: %v", err) + } + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr}) + server.listenAddr = tcpAddr + + wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix()) + + _, err = server.Query(ctx, &pb.DnsPacket{Msg: wire}) + if err != nil { + t.Fatalf("Query() failed: %v", err) + } + + if !called { + t.Fatal("ServeDNS() was not called") + } +} + +func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) { + const keyName = "tsig-key." + const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + + called := false + + server, err := NewServergRPC("127.0.0.1:0", []*Config{ + testConfig("grpc", tsigStatusCheckPlugin{ + t: t, + called: &called, + check: func(t *testing.T, got error) { + t.Helper() + if !errors.Is(got, dns.ErrTime) { + t.Fatalf("TsigStatus() = %v, want %v", got, dns.ErrTime) + } + }, + }), + }) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: secret, + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveTCPAddr() failed: %v", err) + } + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr}) + server.listenAddr = tcpAddr + + wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Add(-10*time.Minute).Unix()) + + _, err = server.Query(ctx, &pb.DnsPacket{Msg: wire}) + if err != nil { + t.Fatalf("Query() failed: %v", err) + } + + if !called { + t.Fatal("ServeDNS() was not called") + } +} + +func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) { + const keyName = "tsig-key." + const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" + + called := false + + server, err := NewServergRPC("127.0.0.1:0", []*Config{ + testConfig("grpc", tsigStatusCheckPlugin{ + t: t, + called: &called, + check: func(t *testing.T, got error) { + t.Helper() + if got != nil { + t.Fatalf("TsigStatus() = %v, want nil", got) + } + }, + }), + }) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + server.tsigSecret = map[string]string{ + keyName: secret, + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("ResolveTCPAddr() failed: %v", err) + } + ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr}) + server.listenAddr = tcpAddr + + wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix()) + + resp, err := server.Query(ctx, &pb.DnsPacket{Msg: wire}) + if err != nil { + t.Fatalf("Query() failed: %v", err) + } + + if !called { + t.Fatal("ServeDNS() was not called") + } + + if len(resp.GetMsg()) == 0 { + t.Fatal("Query() returned empty message") + } + + respMsg := new(dns.Msg) + if err := respMsg.Unpack(resp.GetMsg()); err != nil { + t.Fatalf("Failed to unpack response message: %v", err) + } +}