mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 18:23:13 -04:00 
			
		
		
		
	
		
			
	
	
		
			256 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			256 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package tsig
 | ||
|  | 
 | ||
|  | import (
 | ||
|  | 	"context"
 | ||
|  | 	"fmt"
 | ||
|  | 	"testing"
 | ||
|  | 	"time"
 | ||
|  | 
 | ||
|  | 	"github.com/coredns/coredns/plugin/pkg/dnstest"
 | ||
|  | 	"github.com/coredns/coredns/plugin/test"
 | ||
|  | 	"github.com/coredns/coredns/request"
 | ||
|  | 
 | ||
|  | 	"github.com/miekg/dns"
 | ||
|  | )
 | ||
|  | 
 | ||
|  | func TestServeDNS(t *testing.T) {
 | ||
|  | 	cases := []struct {
 | ||
|  | 		zones       []string
 | ||
|  | 		reqTypes    qTypes
 | ||
|  | 		qType       uint16
 | ||
|  | 		qTsig, all  bool
 | ||
|  | 		expectRcode int
 | ||
|  | 		expectTsig  bool
 | ||
|  | 		statusError bool
 | ||
|  | 	}{
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"."},
 | ||
|  | 			all:         true,
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       true,
 | ||
|  | 			expectRcode: dns.RcodeSuccess,
 | ||
|  | 			expectTsig:  true,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"."},
 | ||
|  | 			all:         true,
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       false,
 | ||
|  | 			expectRcode: dns.RcodeRefused,
 | ||
|  | 			expectTsig:  false,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"another.domain."},
 | ||
|  | 			all:         true,
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       false,
 | ||
|  | 			expectRcode: dns.RcodeSuccess,
 | ||
|  | 			expectTsig:  false,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"another.domain."},
 | ||
|  | 			all:         true,
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       true,
 | ||
|  | 			expectRcode: dns.RcodeSuccess,
 | ||
|  | 			expectTsig:  false,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"."},
 | ||
|  | 			reqTypes:    qTypes{dns.TypeAXFR: {}},
 | ||
|  | 			qType:       dns.TypeAXFR,
 | ||
|  | 			qTsig:       true,
 | ||
|  | 			expectRcode: dns.RcodeSuccess,
 | ||
|  | 			expectTsig:  true,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"."},
 | ||
|  | 			reqTypes:    qTypes{},
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       false,
 | ||
|  | 			expectRcode: dns.RcodeSuccess,
 | ||
|  | 			expectTsig:  false,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"."},
 | ||
|  | 			reqTypes:    qTypes{},
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       true,
 | ||
|  | 			expectRcode: dns.RcodeSuccess,
 | ||
|  | 			expectTsig:  true,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			zones:       []string{"."},
 | ||
|  | 			all:         true,
 | ||
|  | 			qType:       dns.TypeA,
 | ||
|  | 			qTsig:       true,
 | ||
|  | 			expectRcode: dns.RcodeNotAuth,
 | ||
|  | 			expectTsig:  true,
 | ||
|  | 			statusError: true,
 | ||
|  | 		},
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	for i, tc := range cases {
 | ||
|  | 		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
 | ||
|  | 			tsig := TSIGServer{
 | ||
|  | 				Zones: tc.zones,
 | ||
|  | 				all:   tc.all,
 | ||
|  | 				types: tc.reqTypes,
 | ||
|  | 				Next:  testHandler(),
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			ctx := context.TODO()
 | ||
|  | 
 | ||
|  | 			var w *dnstest.Recorder
 | ||
|  | 			if tc.statusError {
 | ||
|  | 				w = dnstest.NewRecorder(&ErrWriter{err: dns.ErrSig})
 | ||
|  | 			} else {
 | ||
|  | 				w = dnstest.NewRecorder(&test.ResponseWriter{})
 | ||
|  | 			}
 | ||
|  | 			r := new(dns.Msg)
 | ||
|  | 			r.SetQuestion("test.example.", tc.qType)
 | ||
|  | 			if tc.qTsig {
 | ||
|  | 				r.SetTsig("test.key.", dns.HmacSHA256, 300, time.Now().Unix())
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			_, err := tsig.ServeDNS(ctx, w, r)
 | ||
|  | 			if err != nil {
 | ||
|  | 				t.Fatal(err)
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if w.Msg.Rcode != tc.expectRcode {
 | ||
|  | 				t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode)
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if ts := w.Msg.IsTsig(); ts == nil && tc.expectTsig {
 | ||
|  | 				t.Fatal("expected TSIG in response")
 | ||
|  | 			}
 | ||
|  | 			if ts := w.Msg.IsTsig(); ts != nil && !tc.expectTsig {
 | ||
|  | 				t.Fatal("expected no TSIG in response")
 | ||
|  | 			}
 | ||
|  | 		})
 | ||
|  | 	}
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func TestServeDNSTsigErrors(t *testing.T) {
 | ||
|  | 	clientNow := time.Now().Unix()
 | ||
|  | 
 | ||
|  | 	cases := []struct {
 | ||
|  | 		desc              string
 | ||
|  | 		tsigErr           error
 | ||
|  | 		expectRcode       int
 | ||
|  | 		expectError       int
 | ||
|  | 		expectOtherLength int
 | ||
|  | 		expectTimeSigned  int64
 | ||
|  | 	}{
 | ||
|  | 		{
 | ||
|  | 			desc:              "Unknown Key",
 | ||
|  | 			tsigErr:           dns.ErrSecret,
 | ||
|  | 			expectRcode:       dns.RcodeNotAuth,
 | ||
|  | 			expectError:       dns.RcodeBadKey,
 | ||
|  | 			expectOtherLength: 0,
 | ||
|  | 			expectTimeSigned:  0,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			desc:              "Bad Signature",
 | ||
|  | 			tsigErr:           dns.ErrSig,
 | ||
|  | 			expectRcode:       dns.RcodeNotAuth,
 | ||
|  | 			expectError:       dns.RcodeBadSig,
 | ||
|  | 			expectOtherLength: 0,
 | ||
|  | 			expectTimeSigned:  0,
 | ||
|  | 		},
 | ||
|  | 		{
 | ||
|  | 			desc:              "Bad Time",
 | ||
|  | 			tsigErr:           dns.ErrTime,
 | ||
|  | 			expectRcode:       dns.RcodeNotAuth,
 | ||
|  | 			expectError:       dns.RcodeBadTime,
 | ||
|  | 			expectOtherLength: 6,
 | ||
|  | 			expectTimeSigned:  clientNow,
 | ||
|  | 		},
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	tsig := TSIGServer{
 | ||
|  | 		Zones: []string{"."},
 | ||
|  | 		all:   true,
 | ||
|  | 		Next:  testHandler(),
 | ||
|  | 	}
 | ||
|  | 
 | ||
|  | 	for _, tc := range cases {
 | ||
|  | 		t.Run(tc.desc, func(t *testing.T) {
 | ||
|  | 			ctx := context.TODO()
 | ||
|  | 
 | ||
|  | 			var w *dnstest.Recorder
 | ||
|  | 
 | ||
|  | 			w = dnstest.NewRecorder(&ErrWriter{err: tc.tsigErr})
 | ||
|  | 
 | ||
|  | 			r := new(dns.Msg)
 | ||
|  | 			r.SetQuestion("test.example.", dns.TypeA)
 | ||
|  | 			r.SetTsig("test.key.", dns.HmacSHA256, 300, clientNow)
 | ||
|  | 
 | ||
|  | 			// set a fake MAC and Size in request
 | ||
|  | 			rtsig := r.IsTsig()
 | ||
|  | 			rtsig.MAC = "0123456789012345678901234567890101234567890123456789012345678901"
 | ||
|  | 			rtsig.MACSize = 32
 | ||
|  | 
 | ||
|  | 			_, err := tsig.ServeDNS(ctx, w, r)
 | ||
|  | 			if err != nil {
 | ||
|  | 				t.Fatal(err)
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if w.Msg.Rcode != tc.expectRcode {
 | ||
|  | 				t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode)
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			ts := w.Msg.IsTsig()
 | ||
|  | 
 | ||
|  | 			if ts == nil {
 | ||
|  | 				t.Fatal("expected TSIG in response")
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if int(ts.Error) != tc.expectError {
 | ||
|  | 				t.Errorf("expected TSIG error code %v, got %v", tc.expectError, ts.Error)
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if len(ts.OtherData)/2 != tc.expectOtherLength {
 | ||
|  | 				t.Errorf("expected Other of length %v, got %v", tc.expectOtherLength, len(ts.OtherData))
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if int(ts.OtherLen) != tc.expectOtherLength {
 | ||
|  | 				t.Errorf("expected OtherLen %v, got %v", tc.expectOtherLength, ts.OtherLen)
 | ||
|  | 			}
 | ||
|  | 
 | ||
|  | 			if ts.TimeSigned != uint64(tc.expectTimeSigned) {
 | ||
|  | 				t.Errorf("expected TimeSigned to be %v, got %v", tc.expectTimeSigned, ts.TimeSigned)
 | ||
|  | 			}
 | ||
|  | 		})
 | ||
|  | 	}
 | ||
|  | }
 | ||
|  | 
 | ||
|  | func testHandler() test.HandlerFunc {
 | ||
|  | 	return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 | ||
|  | 		state := request.Request{W: w, Req: r}
 | ||
|  | 		qname := state.Name()
 | ||
|  | 		m := new(dns.Msg)
 | ||
|  | 		rcode := dns.RcodeServerFailure
 | ||
|  | 		if qname == "test.example." {
 | ||
|  | 			m.SetReply(r)
 | ||
|  | 			rr := test.A("test.example.  300  IN  A  1.2.3.48")
 | ||
|  | 			m.Answer = []dns.RR{rr}
 | ||
|  | 			m.Authoritative = true
 | ||
|  | 			rcode = dns.RcodeSuccess
 | ||
|  | 		}
 | ||
|  | 		m.SetRcode(r, rcode)
 | ||
|  | 		w.WriteMsg(m)
 | ||
|  | 		return rcode, nil
 | ||
|  | 	}
 | ||
|  | }
 | ||
|  | 
 | ||
|  | // a test.ResponseWriter that always returns err as the TSIG status error
 | ||
|  | type ErrWriter struct {
 | ||
|  | 	err error
 | ||
|  | 	test.ResponseWriter
 | ||
|  | }
 | ||
|  | 
 | ||
|  | // TsigStatus always returns an error.
 | ||
|  | func (t *ErrWriter) TsigStatus() error { return t.err }
 |