mirror of
https://github.com/coredns/coredns.git
synced 2026-04-04 19:25:40 -04:00
core: Add full TSIG verification in QUIC transport (#8007)
* core: Add full TSIG verification in QUIC transport This PR add full TSIG verification in QUIC using dns.TsigVerify() Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Fix Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --------- Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
@@ -227,8 +227,10 @@ func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) {
|
|||||||
if tsig := req.IsTsig(); tsig != nil {
|
if tsig := req.IsTsig(); tsig != nil {
|
||||||
if s.tsigSecret == nil {
|
if s.tsigSecret == nil {
|
||||||
w.tsigStatus = dns.ErrSecret
|
w.tsigStatus = dns.ErrSecret
|
||||||
} else if _, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
|
} else if secret, ok := s.tsigSecret[tsig.Hdr.Name]; !ok {
|
||||||
w.tsigStatus = dns.ErrSecret
|
w.tsigStatus = dns.ErrSecret
|
||||||
|
} else {
|
||||||
|
w.tsigStatus = dns.TsigVerify(buf, secret, "", false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,14 +3,60 @@ package dnsserver
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type tsigStatusCheckPlugin struct {
|
||||||
|
t *testing.T
|
||||||
|
check func(*testing.T, error)
|
||||||
|
called chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
|
||||||
|
|
||||||
|
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
p.t.Helper()
|
||||||
|
if p.called != nil {
|
||||||
|
p.called <- struct{}{}
|
||||||
|
}
|
||||||
|
p.check(p.t, w.TsigStatus())
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
if err := w.WriteMsg(m); err != nil {
|
||||||
|
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||||
|
}
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
m.Id = 0
|
||||||
|
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
|
||||||
|
|
||||||
|
wire, _, err := dns.TsigGenerate(m, secret, "", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dns.TsigGenerate() failed: %v", err)
|
||||||
|
}
|
||||||
|
return wire
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewServerQUIC(t *testing.T) {
|
func TestNewServerQUIC(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -451,3 +497,227 @@ func TestDoQWriterTsigStatusReturnsStoredStatus(t *testing.T) {
|
|||||||
t.Fatalf("TsigStatus() = %v, want %v", got, want)
|
t.Fatalf("TsigStatus() = %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerQUIC_ServeQUIC_TSIGBadSigSetsTsigStatus(t *testing.T) {
|
||||||
|
const keyName = "tsig-key."
|
||||||
|
const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||||
|
const serverSecret = "QUJDREVGR0hJSktMTU5PUA=="
|
||||||
|
|
||||||
|
called := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
config := testConfig("quic", tsigStatusCheckPlugin{
|
||||||
|
t: t,
|
||||||
|
called: called,
|
||||||
|
check: func(t *testing.T, got error) {
|
||||||
|
t.Helper()
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("TsigStatus() = nil, want non-nil for bad TSIG MAC")
|
||||||
|
}
|
||||||
|
if errors.Is(got, dns.ErrSecret) {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrSecret", got)
|
||||||
|
}
|
||||||
|
if errors.Is(got, dns.ErrTime) {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want signature verification error, not ErrTime", got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
|
||||||
|
|
||||||
|
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.tsigSecret = map[string]string{
|
||||||
|
keyName: serverSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, err := server.ListenPacket()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenPacket() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer pc.Close()
|
||||||
|
|
||||||
|
serveErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
serveErrCh <- server.ServeQUIC()
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
select {
|
||||||
|
case <-serveErrCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), mustMakeQUICClientTLSConfig(), &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("quic.DialAddr() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.CloseWithError(DoQCodeNoError, "")
|
||||||
|
|
||||||
|
stream, err := conn.OpenStreamSync(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenStreamSync() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wire := mustPackSignedTSIGQuery(t, keyName, clientSecret, time.Now().Unix())
|
||||||
|
|
||||||
|
if _, err := stream.Write(AddPrefix(wire)); err != nil {
|
||||||
|
t.Fatalf("stream.Write() failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := stream.Close(); err != nil {
|
||||||
|
t.Fatalf("stream.Close() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respWire, err := readDOQMessage(stream)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readDOQMessage() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
if err := resp.Unpack(respWire); err != nil {
|
||||||
|
t.Fatalf("response unpack failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-called:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("ServeDNS() was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMakeQUICServerTLSConfig(t *testing.T) *tls.Config {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rsa.GenerateKey() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "127.0.0.1",
|
||||||
|
},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
IPAddresses: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x509.CreateCertificate() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{der},
|
||||||
|
PrivateKey: priv,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
NextProtos: []string{"doq"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMakeQUICClientTLSConfig() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"doq"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerQUIC_ServeQUIC_TSIGValidSigLeavesTsigStatusNil(t *testing.T) {
|
||||||
|
const keyName = "tsig-key."
|
||||||
|
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||||
|
|
||||||
|
called := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
config := testConfig("quic", tsigStatusCheckPlugin{
|
||||||
|
t: t,
|
||||||
|
called: called,
|
||||||
|
check: func(t *testing.T, got error) {
|
||||||
|
t.Helper()
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("TsigStatus() = %v, want nil for valid TSIG MAC", got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
config.TLSConfig = mustMakeQUICServerTLSConfig(t)
|
||||||
|
|
||||||
|
server, err := NewServerQUIC(transport.QUIC+"://127.0.0.1:0", []*Config{config})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.tsigSecret = map[string]string{
|
||||||
|
keyName: secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, err := server.ListenPacket()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenPacket() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer pc.Close()
|
||||||
|
|
||||||
|
serveErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
serveErrCh <- server.ServeQUIC()
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
select {
|
||||||
|
case <-serveErrCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := quic.DialAddr(ctx, pc.LocalAddr().String(), mustMakeQUICClientTLSConfig(), &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("quic.DialAddr() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.CloseWithError(DoQCodeNoError, "")
|
||||||
|
|
||||||
|
stream, err := conn.OpenStreamSync(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenStreamSync() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wire := mustPackSignedTSIGQuery(t, keyName, secret, time.Now().Unix())
|
||||||
|
|
||||||
|
if _, err := stream.Write(AddPrefix(wire)); err != nil {
|
||||||
|
t.Fatalf("stream.Write() failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := stream.Close(); err != nil {
|
||||||
|
t.Fatalf("stream.Close() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respWire, err := readDOQMessage(stream)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readDOQMessage() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
if err := resp.Unpack(respWire); err != nil {
|
||||||
|
t.Fatalf("response unpack failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-called:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("ServeDNS() was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user