mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 02:03:20 -04:00 
			
		
		
		
	Cache elements of State
Cache the size and the do bit whenever someone asked for it. We can probably add more: PASS BenchmarkStateDo-4 100000000 11.9 ns/op BenchmarkStateSize-4 5000000 265 ns/op ok github.com/miekg/coredns/middleware 2.828s PASS BenchmarkStateDo-4 1000000000 2.86 ns/op BenchmarkStateSize-4 500000000 3.10 ns/op ok github.com/miekg/coredns/middleware 5.032s This PR also includes some testing cleanups as well.
This commit is contained in:
		| @@ -4,6 +4,7 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/context" | ||||
| @@ -57,7 +58,7 @@ func TestChaos(t *testing.T) { | ||||
| 		req.Question[0].Qclass = dns.ClassCHAOS | ||||
| 		em.Next = test.next | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		code, err := em.ServeDNS(ctx, rec, req) | ||||
|  | ||||
| 		if err != test.expectedErr { | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/context" | ||||
| @@ -46,7 +47,7 @@ func TestErrors(t *testing.T) { | ||||
| 	for i, test := range tests { | ||||
| 		em.Next = test.next | ||||
| 		buf.Reset() | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		code, err := em.ServeDNS(ctx, rec, req) | ||||
|  | ||||
| 		if err != test.expectedErr { | ||||
| @@ -77,7 +78,7 @@ func TestVisibleErrorWithPanic(t *testing.T) { | ||||
| 	req := new(dns.Msg) | ||||
| 	req.SetQuestion("example.org.", dns.TypeA) | ||||
|  | ||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
|  | ||||
| 	code, err := eh.ServeDNS(ctx, rec, req) | ||||
| 	if code != 0 { | ||||
|   | ||||
| @@ -23,7 +23,7 @@ func TestCnameLookup(t *testing.T) { | ||||
| 	for _, tc := range dnsTestCasesCname { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -25,7 +25,7 @@ func TestGroupLookup(t *testing.T) { | ||||
| 	for _, tc := range dnsTestCasesGroup { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -28,7 +28,7 @@ func TestMultiLookup(t *testing.T) { | ||||
| 	for _, tc := range dnsTestCasesMulti { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := etcMulti.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -27,7 +27,7 @@ func TestOtherLookup(t *testing.T) { | ||||
| 	for _, tc := range dnsTestCasesOther { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -16,9 +16,9 @@ import ( | ||||
| 	"github.com/miekg/coredns/middleware/etcd/singleflight" | ||||
| 	"github.com/miekg/coredns/middleware/proxy" | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
| 	"github.com/miekg/dns" | ||||
|  | ||||
| 	etcdc "github.com/coreos/etcd/client" | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
|  | ||||
| @@ -67,7 +67,7 @@ func TestLookup(t *testing.T) { | ||||
| 	for _, tc := range dnsTestCases { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -109,7 +109,7 @@ func TestLookupDNSSEC(t *testing.T) { | ||||
| 	for _, tc := range dnssecTestCases { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
| @@ -147,7 +147,7 @@ func BenchmarkLookupDNSSEC(b *testing.B) { | ||||
|  | ||||
| 	fm := File{Next: coretest.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} | ||||
| 	ctx := context.TODO() | ||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
|  | ||||
| 	tc := coretest.Case{ | ||||
| 		Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, | ||||
|   | ||||
| @@ -42,7 +42,7 @@ func TestLookupENT(t *testing.T) { | ||||
| 	for _, tc := range entTestCases { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -77,7 +77,7 @@ func TestLookup(t *testing.T) { | ||||
| 	for _, tc := range dnsTestCases { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
| @@ -112,7 +112,7 @@ func TestLookupNil(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
|  | ||||
| 	m := dnsTestCases[0].Msg() | ||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 	fm.ServeDNS(ctx, rec, m) | ||||
| } | ||||
|  | ||||
| @@ -124,7 +124,7 @@ func BenchmarkLookup(b *testing.B) { | ||||
|  | ||||
| 	fm := File{Next: coretest.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} | ||||
| 	ctx := context.TODO() | ||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
|  | ||||
| 	tc := coretest.Case{ | ||||
| 		Qname: "www.miek.nl.", Qtype: dns.TypeA, | ||||
|   | ||||
| @@ -56,7 +56,7 @@ func TestLookupWildcard(t *testing.T) { | ||||
| 	for _, tc := range wildcardTestCases { | ||||
| 		m := tc.Msg() | ||||
|  | ||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("expected no error, got %v\n", err) | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/context" | ||||
| @@ -55,7 +56,7 @@ func TestLoadBalance(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
|  | ||||
| 	for i, test := range tests { | ||||
| 		req := new(dns.Msg) | ||||
|   | ||||
| @@ -7,6 +7,8 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
| @@ -35,7 +37,7 @@ func TestLoggedStatus(t *testing.T) { | ||||
| 	r := new(dns.Msg) | ||||
| 	r.SetQuestion("example.org.", dns.TypeA) | ||||
|  | ||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | ||||
| 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||
|  | ||||
| 	rcode, _ := logger.ServeDNS(ctx, rec, r) | ||||
| 	if rcode != 0 { | ||||
|   | ||||
| @@ -7,6 +7,8 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| @@ -30,5 +32,5 @@ func TestLookupProxy(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func fakeState() middleware.State { | ||||
| 	return middleware.State{W: &middleware.TestResponseWriter{}, Req: new(dns.Msg)} | ||||
| 	return middleware.State{W: &coretest.ResponseWriter{}, Req: new(dns.Msg)} | ||||
| } | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| @@ -79,24 +78,3 @@ func (r *ResponseRecorder) Hijack() { | ||||
| 	r.ResponseWriter.Hijack() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| type TestResponseWriter struct{} | ||||
|  | ||||
| func (t *TestResponseWriter) LocalAddr() net.Addr { | ||||
| 	ip := net.ParseIP("127.0.0.1") | ||||
| 	port := 53 | ||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||
| } | ||||
|  | ||||
| func (t *TestResponseWriter) RemoteAddr() net.Addr { | ||||
| 	ip := net.ParseIP("10.240.0.1") | ||||
| 	port := 40212 | ||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||
| } | ||||
|  | ||||
| func (t *TestResponseWriter) WriteMsg(m *dns.Msg) error     { return nil } | ||||
| func (t *TestResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } | ||||
| func (t *TestResponseWriter) Close() error                  { return nil } | ||||
| func (t *TestResponseWriter) TsigStatus() error             { return nil } | ||||
| func (t *TestResponseWriter) TsigTimersOnly(bool)           { return } | ||||
| func (t *TestResponseWriter) Hijack()                       { return } | ||||
|   | ||||
| @@ -2,36 +2,38 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| // This file contains the state nd functions available for use in the templates. | ||||
| // This file contains the state functions available for use in the middlewares. | ||||
|  | ||||
| // State contains some connection state and is useful in middleware. | ||||
| type State struct { | ||||
| 	Root http.FileSystem // TODO(miek): needed? | ||||
| 	Req *dns.Msg | ||||
| 	W   dns.ResponseWriter | ||||
|  | ||||
| 	// Cache size after first call to Size or Do | ||||
| 	size int | ||||
| 	do   int // 0: not, 1: true: 2: false | ||||
| } | ||||
|  | ||||
| // Now returns the current timestamp in the specified format. | ||||
| func (s State) Now(format string) string { return time.Now().Format(format) } | ||||
| func (s *State) Now(format string) string { return time.Now().Format(format) } | ||||
|  | ||||
| // NowDate returns the current date/time that can be used in other time functions. | ||||
| func (s State) NowDate() time.Time { return time.Now() } | ||||
| func (s *State) NowDate() time.Time { return time.Now() } | ||||
|  | ||||
| // Header gets the heaser of the request in State. | ||||
| func (s State) Header() *dns.RR_Header { | ||||
| func (s *State) Header() *dns.RR_Header { | ||||
| 	// TODO(miek) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // IP gets the (remote) IP address of the client making the request. | ||||
| func (s State) IP() string { | ||||
| func (s *State) IP() string { | ||||
| 	ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) | ||||
| 	if err != nil { | ||||
| 		return s.W.RemoteAddr().String() | ||||
| @@ -40,7 +42,7 @@ func (s State) IP() string { | ||||
| } | ||||
|  | ||||
| // Post gets the (remote) Port of the client making the request. | ||||
| func (s State) Port() (string, error) { | ||||
| func (s *State) Port() (string, error) { | ||||
| 	_, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) | ||||
| 	if err != nil { | ||||
| 		return "0", err | ||||
| @@ -50,7 +52,7 @@ func (s State) Port() (string, error) { | ||||
|  | ||||
| // Proto gets the protocol used as the transport. This | ||||
| // will be udp or tcp. | ||||
| func (s State) Proto() string { | ||||
| func (s *State) Proto() string { | ||||
| 	if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok { | ||||
| 		return "udp" | ||||
| 	} | ||||
| @@ -62,7 +64,7 @@ func (s State) Proto() string { | ||||
|  | ||||
| // Family returns the family of the transport. | ||||
| // 1 for IPv4 and 2 for IPv6. | ||||
| func (s State) Family() int { | ||||
| func (s *State) Family() int { | ||||
| 	var a net.IP | ||||
| 	ip := s.W.RemoteAddr() | ||||
| 	if i, ok := ip.(*net.UDPAddr); ok { | ||||
| @@ -79,33 +81,56 @@ func (s State) Family() int { | ||||
| } | ||||
|  | ||||
| // Do returns if the request has the DO (DNSSEC OK) bit set. | ||||
| func (s State) Do() bool { | ||||
| func (s *State) Do() bool { | ||||
| 	if s.do != 0 { | ||||
| 		return s.do == doTrue | ||||
| 	} | ||||
|  | ||||
| 	if o := s.Req.IsEdns0(); o != nil { | ||||
| 		if o.Do() { | ||||
| 			s.do = doTrue | ||||
| 		} else { | ||||
| 			s.do = doFalse | ||||
| 		} | ||||
| 		return o.Do() | ||||
| 	} | ||||
| 	s.do = doFalse | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // UDPSize returns if UDP buffer size advertised in the requests OPT record. | ||||
| // Or when the request was over TCP, we return the maximum allowed size of 64K. | ||||
| func (s State) Size() int { | ||||
| func (s *State) Size() int { | ||||
| 	if s.size != 0 { | ||||
| 		return s.size | ||||
| 	} | ||||
|  | ||||
| 	if s.Proto() == "tcp" { | ||||
| 		s.size = dns.MaxMsgSize | ||||
| 		return dns.MaxMsgSize | ||||
| 	} | ||||
| 	if o := s.Req.IsEdns0(); o != nil { | ||||
| 		s := o.UDPSize() | ||||
| 		if s < dns.MinMsgSize { | ||||
| 			s = dns.MinMsgSize | ||||
| 		if o.Do() == true { | ||||
| 			s.do = doTrue | ||||
| 		} else { | ||||
| 			s.do = doFalse | ||||
| 		} | ||||
| 		return int(s) | ||||
|  | ||||
| 		size := o.UDPSize() | ||||
| 		if size < dns.MinMsgSize { | ||||
| 			size = dns.MinMsgSize | ||||
| 		} | ||||
| 		s.size = int(size) | ||||
| 		return int(size) | ||||
| 	} | ||||
| 	s.size = dns.MinMsgSize | ||||
| 	return dns.MinMsgSize | ||||
| } | ||||
|  | ||||
| // SizeAndDo returns a ready made OPT record that the reflects the intent from | ||||
| // state. This can be added to upstream requests that will then hopefully | ||||
| // return a message that is fits the buffer in the client. | ||||
| func (s State) SizeAndDo() *dns.OPT { | ||||
| func (s *State) SizeAndDo() *dns.OPT { | ||||
| 	size := s.Size() | ||||
| 	Do := s.Do() | ||||
|  | ||||
| @@ -134,7 +159,7 @@ const ( | ||||
| // the TC bit will be set regardless of protocol, even TCP message will get the bit, the client | ||||
| // should then retry with pigeons. | ||||
| // TODO(referral). | ||||
| func (s State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { | ||||
| func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { | ||||
| 	size := s.Size() | ||||
| 	l := reply.Len() | ||||
| 	if size >= l { | ||||
| @@ -150,32 +175,36 @@ func (s State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { | ||||
| 	// Still?!! does not fit. | ||||
| 	reply.Truncated = true | ||||
| 	return reply, ScrubDone | ||||
|  | ||||
| } | ||||
|  | ||||
| // Type returns the type of the question as a string. | ||||
| func (s State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } | ||||
| func (s *State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } | ||||
|  | ||||
| // QType returns the type of the question as a uint16. | ||||
| func (s State) QType() uint16 { return s.Req.Question[0].Qtype } | ||||
| func (s *State) QType() uint16 { return s.Req.Question[0].Qtype } | ||||
|  | ||||
| // Name returns the name of the question in the request. Note | ||||
| // this name will always have a closing dot and will be lower cased. | ||||
| func (s State) Name() string { return strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) } | ||||
| func (s *State) Name() string { return strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) } | ||||
|  | ||||
| // QName returns the name of the question in the request. | ||||
| func (s State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } | ||||
| func (s *State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } | ||||
|  | ||||
| // Class returns the class of the question in the request. | ||||
| func (s State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } | ||||
| func (s *State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } | ||||
|  | ||||
| // QClass returns the class of the question in the request. | ||||
| func (s State) QClass() uint16 { return s.Req.Question[0].Qclass } | ||||
| func (s *State) QClass() uint16 { return s.Req.Question[0].Qclass } | ||||
|  | ||||
| // ErrorMessage returns an error message suitable for sending | ||||
| // back to the client. | ||||
| func (s State) ErrorMessage(rcode int) *dns.Msg { | ||||
| func (s *State) ErrorMessage(rcode int) *dns.Msg { | ||||
| 	m := new(dns.Msg) | ||||
| 	m.SetRcode(s.Req, rcode) | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	doTrue  = 1 | ||||
| 	doFalse = 2 | ||||
| ) | ||||
|   | ||||
| @@ -1,5 +1,46 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| func TestStateDo(t *testing.T) { | ||||
| 	st := testState() | ||||
|  | ||||
| 	st.Do() | ||||
| 	if st.do == 0 { | ||||
| 		t.Fatalf("expected st.do to be set") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkStateDo(b *testing.B) { | ||||
| 	st := testState() | ||||
|  | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		st.Do() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkStateSize(b *testing.B) { | ||||
| 	st := testState() | ||||
|  | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		st.Size() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func testState() State { | ||||
| 	m := new(dns.Msg) | ||||
| 	m.SetQuestion("example.com.", dns.TypeA) | ||||
| 	m.SetEdns0(4097, true) | ||||
|  | ||||
| 	return State{W: &coretest.ResponseWriter{}, Req: m} | ||||
| } | ||||
|  | ||||
| /* | ||||
| func TestHeader(t *testing.T) { | ||||
| 	state := getContextOrFail(t) | ||||
|   | ||||
| @@ -3,8 +3,6 @@ package testing | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
| @@ -199,11 +197,27 @@ func Section(t *testing.T, tc Case, sect Sect, rr []dns.RR) bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func ErrorHandler() middleware.Handler { | ||||
| 	return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||
| func ErrorHandler() Handler { | ||||
| 	return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||
| 		m := new(dns.Msg) | ||||
| 		m.SetRcode(r, dns.RcodeServerFailure) | ||||
| 		w.WriteMsg(m) | ||||
| 		return dns.RcodeServerFailure, nil | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // Copied here to prevent an import cycle. | ||||
| type ( | ||||
| 	// HandlerFunc is a convenience type like dns.HandlerFunc, except | ||||
| 	// ServeDNS returns an rcode and an error. | ||||
| 	HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) | ||||
|  | ||||
| 	Handler interface { | ||||
| 		ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| // ServeDNS implements the Handler interface. | ||||
| func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||
| 	return f(ctx, w, r) | ||||
| } | ||||
|   | ||||
							
								
								
									
										28
									
								
								middleware/testing/responsewriter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								middleware/testing/responsewriter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | ||||
| package testing | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| type ResponseWriter struct{} | ||||
|  | ||||
| func (t *ResponseWriter) LocalAddr() net.Addr { | ||||
| 	ip := net.ParseIP("127.0.0.1") | ||||
| 	port := 53 | ||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||
| } | ||||
|  | ||||
| func (t *ResponseWriter) RemoteAddr() net.Addr { | ||||
| 	ip := net.ParseIP("10.240.0.1") | ||||
| 	port := 40212 | ||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||
| } | ||||
|  | ||||
| func (t *ResponseWriter) WriteMsg(m *dns.Msg) error     { return nil } | ||||
| func (t *ResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } | ||||
| func (t *ResponseWriter) Close() error                  { return nil } | ||||
| func (t *ResponseWriter) TsigStatus() error             { return nil } | ||||
| func (t *ResponseWriter) TsigTimersOnly(bool)           { return } | ||||
| func (t *ResponseWriter) Hijack()                       { return } | ||||
		Reference in New Issue
	
	Block a user