mirror of
https://github.com/coredns/coredns.git
synced 2026-04-05 03:35:33 -04:00
core: Add full TSIG verification in gRPC transport (#8006)
* core: Add full TSIG verification in gRPC transport This PR add full TSIG verification in gRPC using dns.TsigVerify() so invalid signatures and timestamps are correctly detected instead of only checking key presence. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --------- Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
@@ -198,8 +198,10 @@ func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket
|
|||||||
if tsig := msg.IsTsig(); tsig != nil {
|
if tsig := msg.IsTsig(); tsig != nil {
|
||||||
if s.tsigSecret == nil {
|
if s.tsigSecret == nil {
|
||||||
w.tsigStatus = dns.ErrSecret
|
w.tsigStatus = dns.ErrSecret
|
||||||
} else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
|
} else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
|
||||||
w.tsigStatus = dns.ErrSecret
|
w.tsigStatus = dns.ErrSecret
|
||||||
|
} else {
|
||||||
|
w.tsigStatus = dns.TsigVerify(in.GetMsg(), secret, "", false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coredns/coredns/pb"
|
"github.com/coredns/coredns/pb"
|
||||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||||
@@ -15,6 +16,43 @@ import (
|
|||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type tsigStatusCheckPlugin struct {
|
||||||
|
t *testing.T
|
||||||
|
check func(*testing.T, error)
|
||||||
|
called *bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
|
||||||
|
|
||||||
|
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
p.t.Helper()
|
||||||
|
if p.called != nil {
|
||||||
|
*p.called = true
|
||||||
|
}
|
||||||
|
p.check(p.t, w.TsigStatus())
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||||
|
}
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
|
||||||
|
|
||||||
|
wire, _, err := dns.TsigGenerate(m, secret, "", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dns.TsigGenerate() failed: %v", err)
|
||||||
|
}
|
||||||
|
return wire
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewServergRPC(t *testing.T) {
|
func TestNewServergRPC(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -460,3 +498,154 @@ func TestGRPCResponseTsigStatusReturnsStoredStatus(t *testing.T) {
|
|||||||
t.Fatalf("TsigStatus() = %v, want %v", got, want)
|
t.Fatalf("TsigStatus() = %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) {
|
||||||
|
const keyName = "tsig-key."
|
||||||
|
const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||||
|
const serverSecret = "QUJDREVGR0hJSktMTU5PUA=="
|
||||||
|
|
||||||
|
called := false
|
||||||
|
|
||||||
|
server, err := NewServergRPC("127.0.0.1:0", []*Config{
|
||||||
|
testConfig("grpc", tsigStatusCheckPlugin{
|
||||||
|
t: t,
|
||||||
|
called: &called,
|
||||||
|
check: func(t *testing.T, got error) {
|
||||||
|
t.Helper()
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC")
|
||||||
|
}
|
||||||
|
if errors.Is(got, dns.ErrSecret) {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got)
|
||||||
|
}
|
||||||
|
if errors.Is(got, dns.ErrTime) {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.tsigSecret = map[string]string{
|
||||||
|
keyName: serverSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() failed: %v", err)
|
||||||
|
}
|
||||||
|
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr})
|
||||||
|
server.listenAddr = tcpAddr
|
||||||
|
|
||||||
|
wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix())
|
||||||
|
|
||||||
|
_, err = server.Query(ctx, &pb.DnsPacket{Msg: wire})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("ServeDNS() was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) {
|
||||||
|
const keyName = "tsig-key."
|
||||||
|
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||||
|
|
||||||
|
called := false
|
||||||
|
|
||||||
|
server, err := NewServergRPC("127.0.0.1:0", []*Config{
|
||||||
|
testConfig("grpc", tsigStatusCheckPlugin{
|
||||||
|
t: t,
|
||||||
|
called: &called,
|
||||||
|
check: func(t *testing.T, got error) {
|
||||||
|
t.Helper()
|
||||||
|
if !errors.Is(got, dns.ErrTime) {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want %v", got, dns.ErrTime)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.tsigSecret = map[string]string{
|
||||||
|
keyName: secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() failed: %v", err)
|
||||||
|
}
|
||||||
|
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr})
|
||||||
|
server.listenAddr = tcpAddr
|
||||||
|
|
||||||
|
wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Add(-10*time.Minute).Unix())
|
||||||
|
|
||||||
|
_, err = server.Query(ctx, &pb.DnsPacket{Msg: wire})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("ServeDNS() was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) {
|
||||||
|
const keyName = "tsig-key."
|
||||||
|
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||||
|
|
||||||
|
called := false
|
||||||
|
|
||||||
|
server, err := NewServergRPC("127.0.0.1:0", []*Config{
|
||||||
|
testConfig("grpc", tsigStatusCheckPlugin{
|
||||||
|
t: t,
|
||||||
|
called: &called,
|
||||||
|
check: func(t *testing.T, got error) {
|
||||||
|
t.Helper()
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.tsigSecret = map[string]string{
|
||||||
|
keyName: secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveTCPAddr() failed: %v", err)
|
||||||
|
}
|
||||||
|
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: tcpAddr})
|
||||||
|
server.listenAddr = tcpAddr
|
||||||
|
|
||||||
|
wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix())
|
||||||
|
|
||||||
|
resp, err := server.Query(ctx, &pb.DnsPacket{Msg: wire})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("ServeDNS() was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.GetMsg()) == 0 {
|
||||||
|
t.Fatal("Query() returned empty message")
|
||||||
|
}
|
||||||
|
|
||||||
|
respMsg := new(dns.Msg)
|
||||||
|
if err := respMsg.Unpack(resp.GetMsg()); err != nil {
|
||||||
|
t.Fatalf("Failed to unpack response message: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user