| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | package kubernetes
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							| 
									
										
										
										
											2018-04-22 08:34:35 +01:00
										 |  |  | 	"context"
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | 	"strings"
 | 
					
						
							|  |  |  | 	"testing"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/plugin/pkg/dnstest"
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/plugin/test"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestKubernetesXFR(t *testing.T) {
 | 
					
						
							|  |  |  | 	k := New([]string{"cluster.local."})
 | 
					
						
							|  |  |  | 	k.APIConn = &APIConnServeTest{}
 | 
					
						
							| 
									
										
										
										
											2018-11-13 18:25:30 -05:00
										 |  |  | 	k.TransferTo = []string{"10.240.0.1:53"}
 | 
					
						
							| 
									
										
										
										
											2019-02-17 03:32:28 -05:00
										 |  |  | 	k.Namespaces = map[string]struct{}{"testns": {}}
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	ctx := context.TODO()
 | 
					
						
							|  |  |  | 	w := dnstest.NewMultiRecorder(&test.ResponseWriter{})
 | 
					
						
							|  |  |  | 	dnsmsg := &dns.Msg{}
 | 
					
						
							|  |  |  | 	dnsmsg.SetAxfr(k.Zones[0])
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	_, err := k.ServeDNS(ctx, w, dnsmsg)
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		t.Error(err)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(w.Msgs) == 0 {
 | 
					
						
							|  |  |  | 		t.Logf("%+v\n", w)
 | 
					
						
							| 
									
										
										
										
											2018-11-13 18:25:30 -05:00
										 |  |  | 		t.Fatal("Did not get back a zone response")
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(w.Msgs[0].Answer) == 0 {
 | 
					
						
							|  |  |  | 		t.Logf("%+v\n", w)
 | 
					
						
							|  |  |  | 		t.Fatal("Did not get back an answer")
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Ensure xfr starts with SOA
 | 
					
						
							|  |  |  | 	if w.Msgs[0].Answer[0].Header().Rrtype != dns.TypeSOA {
 | 
					
						
							|  |  |  | 		t.Error("Invalid XFR, does not start with SOA record")
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Ensure xfr starts with SOA
 | 
					
						
							|  |  |  | 	// Last message is empty, so we need to go back one further
 | 
					
						
							|  |  |  | 	if w.Msgs[len(w.Msgs)-2].Answer[len(w.Msgs[len(w.Msgs)-2].Answer)-1].Header().Rrtype != dns.TypeSOA {
 | 
					
						
							|  |  |  | 		t.Error("Invalid XFR, does not end with SOA record")
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	testRRs := []dns.RR{}
 | 
					
						
							|  |  |  | 	for _, tc := range dnsTestCases {
 | 
					
						
							|  |  |  | 		if tc.Rcode != dns.RcodeSuccess {
 | 
					
						
							|  |  |  | 			continue
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		for _, ans := range tc.Answer {
 | 
					
						
							|  |  |  | 			// Exclude wildcard searches
 | 
					
						
							|  |  |  | 			if strings.Contains(ans.Header().Name, "*") {
 | 
					
						
							|  |  |  | 				continue
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Exclude TXT records
 | 
					
						
							|  |  |  | 			if ans.Header().Rrtype == dns.TypeTXT {
 | 
					
						
							|  |  |  | 				continue
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 			testRRs = append(testRRs, ans)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	gotRRs := []dns.RR{}
 | 
					
						
							|  |  |  | 	for _, resp := range w.Msgs {
 | 
					
						
							|  |  |  | 		for _, ans := range resp.Answer {
 | 
					
						
							|  |  |  | 			// Skip SOA records since these
 | 
					
						
							|  |  |  | 			// test cases do not exist
 | 
					
						
							|  |  |  | 			if ans.Header().Rrtype == dns.TypeSOA {
 | 
					
						
							|  |  |  | 				continue
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			gotRRs = append(gotRRs, ans)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	diff := difference(testRRs, gotRRs)
 | 
					
						
							|  |  |  | 	if len(diff) != 0 {
 | 
					
						
							|  |  |  | 		t.Errorf("Got back %d records that do not exist in test cases, should be 0:", len(diff))
 | 
					
						
							|  |  |  | 		for _, rec := range diff {
 | 
					
						
							|  |  |  | 			t.Errorf("%+v", rec)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	diff = difference(gotRRs, testRRs)
 | 
					
						
							|  |  |  | 	if len(diff) != 0 {
 | 
					
						
							| 
									
										
										
										
											2018-10-09 21:56:09 +01:00
										 |  |  | 		t.Errorf("Found %d records we're missing, should be 0:", len(diff))
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | 		for _, rec := range diff {
 | 
					
						
							|  |  |  | 			t.Errorf("%+v", rec)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-11-13 18:25:30 -05:00
										 |  |  | func TestKubernetesXFRNotAllowed(t *testing.T) {
 | 
					
						
							|  |  |  | 	k := New([]string{"cluster.local."})
 | 
					
						
							|  |  |  | 	k.APIConn = &APIConnServeTest{}
 | 
					
						
							|  |  |  | 	k.TransferTo = []string{"1.2.3.4:53"}
 | 
					
						
							| 
									
										
										
										
											2019-02-17 03:32:28 -05:00
										 |  |  | 	k.Namespaces = map[string]struct{}{"testns": {}}
 | 
					
						
							| 
									
										
										
										
											2018-11-13 18:25:30 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	ctx := context.TODO()
 | 
					
						
							|  |  |  | 	w := dnstest.NewMultiRecorder(&test.ResponseWriter{})
 | 
					
						
							|  |  |  | 	dnsmsg := &dns.Msg{}
 | 
					
						
							|  |  |  | 	dnsmsg.SetAxfr(k.Zones[0])
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	_, err := k.ServeDNS(ctx, w, dnsmsg)
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		t.Error(err)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(w.Msgs) == 0 {
 | 
					
						
							|  |  |  | 		t.Logf("%+v\n", w)
 | 
					
						
							|  |  |  | 		t.Fatal("Did not get back a zone response")
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(w.Msgs[0].Answer) != 0 {
 | 
					
						
							|  |  |  | 		t.Logf("%+v\n", w)
 | 
					
						
							|  |  |  | 		t.Fatal("Got an answer, should not have")
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | // difference shows what we're missing when comparing two RR slices
 | 
					
						
							|  |  |  | func difference(testRRs []dns.RR, gotRRs []dns.RR) []dns.RR {
 | 
					
						
							| 
									
										
										
										
											2018-12-10 10:17:15 +00:00
										 |  |  | 	expectedRRs := map[string]struct{}{}
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | 	for _, rr := range testRRs {
 | 
					
						
							| 
									
										
										
										
											2018-12-10 10:17:15 +00:00
										 |  |  | 		expectedRRs[rr.String()] = struct{}{}
 | 
					
						
							| 
									
										
										
										
											2018-02-08 10:11:04 -06:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	foundRRs := []dns.RR{}
 | 
					
						
							|  |  |  | 	for _, rr := range gotRRs {
 | 
					
						
							|  |  |  | 		if _, ok := expectedRRs[rr.String()]; !ok {
 | 
					
						
							|  |  |  | 			foundRRs = append(foundRRs, rr)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return foundRRs
 | 
					
						
							|  |  |  | }
 |