mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 10:13:14 -04:00 
			
		
		
		
	Cache elements of State
Cache the size and the do bit whenever someone asked for it. We can probably add more: PASS BenchmarkStateDo-4 100000000 11.9 ns/op BenchmarkStateSize-4 5000000 265 ns/op ok github.com/miekg/coredns/middleware 2.828s PASS BenchmarkStateDo-4 1000000000 2.86 ns/op BenchmarkStateSize-4 500000000 3.10 ns/op ok github.com/miekg/coredns/middleware 5.032s This PR also includes some testing cleanups as well.
This commit is contained in:
		| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
|  | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| @@ -57,7 +58,7 @@ func TestChaos(t *testing.T) { | |||||||
| 		req.Question[0].Qclass = dns.ClassCHAOS | 		req.Question[0].Qclass = dns.ClassCHAOS | ||||||
| 		em.Next = test.next | 		em.Next = test.next | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		code, err := em.ServeDNS(ctx, rec, req) | 		code, err := em.ServeDNS(ctx, rec, req) | ||||||
|  |  | ||||||
| 		if err != test.expectedErr { | 		if err != test.expectedErr { | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
|  | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| @@ -46,7 +47,7 @@ func TestErrors(t *testing.T) { | |||||||
| 	for i, test := range tests { | 	for i, test := range tests { | ||||||
| 		em.Next = test.next | 		em.Next = test.next | ||||||
| 		buf.Reset() | 		buf.Reset() | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		code, err := em.ServeDNS(ctx, rec, req) | 		code, err := em.ServeDNS(ctx, rec, req) | ||||||
|  |  | ||||||
| 		if err != test.expectedErr { | 		if err != test.expectedErr { | ||||||
| @@ -77,7 +78,7 @@ func TestVisibleErrorWithPanic(t *testing.T) { | |||||||
| 	req := new(dns.Msg) | 	req := new(dns.Msg) | ||||||
| 	req.SetQuestion("example.org.", dns.TypeA) | 	req.SetQuestion("example.org.", dns.TypeA) | ||||||
|  |  | ||||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
|  |  | ||||||
| 	code, err := eh.ServeDNS(ctx, rec, req) | 	code, err := eh.ServeDNS(ctx, rec, req) | ||||||
| 	if code != 0 { | 	if code != 0 { | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ func TestCnameLookup(t *testing.T) { | |||||||
| 	for _, tc := range dnsTestCasesCname { | 	for _, tc := range dnsTestCasesCname { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | 		_, err := etc.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -25,7 +25,7 @@ func TestGroupLookup(t *testing.T) { | |||||||
| 	for _, tc := range dnsTestCasesGroup { | 	for _, tc := range dnsTestCasesGroup { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | 		_, err := etc.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -28,7 +28,7 @@ func TestMultiLookup(t *testing.T) { | |||||||
| 	for _, tc := range dnsTestCasesMulti { | 	for _, tc := range dnsTestCasesMulti { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := etcMulti.ServeDNS(ctx, rec, m) | 		_, err := etcMulti.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -27,7 +27,7 @@ func TestOtherLookup(t *testing.T) { | |||||||
| 	for _, tc := range dnsTestCasesOther { | 	for _, tc := range dnsTestCasesOther { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | 		_, err := etc.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -16,9 +16,9 @@ import ( | |||||||
| 	"github.com/miekg/coredns/middleware/etcd/singleflight" | 	"github.com/miekg/coredns/middleware/etcd/singleflight" | ||||||
| 	"github.com/miekg/coredns/middleware/proxy" | 	"github.com/miekg/coredns/middleware/proxy" | ||||||
| 	coretest "github.com/miekg/coredns/middleware/testing" | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
| 	"github.com/miekg/dns" |  | ||||||
|  |  | ||||||
| 	etcdc "github.com/coreos/etcd/client" | 	etcdc "github.com/coreos/etcd/client" | ||||||
|  | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -67,7 +67,7 @@ func TestLookup(t *testing.T) { | |||||||
| 	for _, tc := range dnsTestCases { | 	for _, tc := range dnsTestCases { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := etc.ServeDNS(ctx, rec, m) | 		_, err := etc.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -109,7 +109,7 @@ func TestLookupDNSSEC(t *testing.T) { | |||||||
| 	for _, tc := range dnssecTestCases { | 	for _, tc := range dnssecTestCases { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | 		_, err := fm.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
| @@ -147,7 +147,7 @@ func BenchmarkLookupDNSSEC(b *testing.B) { | |||||||
|  |  | ||||||
| 	fm := File{Next: coretest.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} | 	fm := File{Next: coretest.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} | ||||||
| 	ctx := context.TODO() | 	ctx := context.TODO() | ||||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
|  |  | ||||||
| 	tc := coretest.Case{ | 	tc := coretest.Case{ | ||||||
| 		Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, | 		Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, | ||||||
|   | |||||||
| @@ -42,7 +42,7 @@ func TestLookupENT(t *testing.T) { | |||||||
| 	for _, tc := range entTestCases { | 	for _, tc := range entTestCases { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | 		_, err := fm.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -77,7 +77,7 @@ func TestLookup(t *testing.T) { | |||||||
| 	for _, tc := range dnsTestCases { | 	for _, tc := range dnsTestCases { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | 		_, err := fm.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
| @@ -112,7 +112,7 @@ func TestLookupNil(t *testing.T) { | |||||||
| 	ctx := context.TODO() | 	ctx := context.TODO() | ||||||
|  |  | ||||||
| 	m := dnsTestCases[0].Msg() | 	m := dnsTestCases[0].Msg() | ||||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 	fm.ServeDNS(ctx, rec, m) | 	fm.ServeDNS(ctx, rec, m) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -124,7 +124,7 @@ func BenchmarkLookup(b *testing.B) { | |||||||
|  |  | ||||||
| 	fm := File{Next: coretest.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} | 	fm := File{Next: coretest.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} | ||||||
| 	ctx := context.TODO() | 	ctx := context.TODO() | ||||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
|  |  | ||||||
| 	tc := coretest.Case{ | 	tc := coretest.Case{ | ||||||
| 		Qname: "www.miek.nl.", Qtype: dns.TypeA, | 		Qname: "www.miek.nl.", Qtype: dns.TypeA, | ||||||
|   | |||||||
| @@ -56,7 +56,7 @@ func TestLookupWildcard(t *testing.T) { | |||||||
| 	for _, tc := range wildcardTestCases { | 	for _, tc := range wildcardTestCases { | ||||||
| 		m := tc.Msg() | 		m := tc.Msg() | ||||||
|  |  | ||||||
| 		rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 		rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
| 		_, err := fm.ServeDNS(ctx, rec, m) | 		_, err := fm.ServeDNS(ctx, rec, m) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("expected no error, got %v\n", err) | 			t.Errorf("expected no error, got %v\n", err) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
|  | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| @@ -55,7 +56,7 @@ func TestLoadBalance(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
|  |  | ||||||
| 	for i, test := range tests { | 	for i, test := range tests { | ||||||
| 		req := new(dns.Msg) | 		req := new(dns.Msg) | ||||||
|   | |||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
|  | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
| @@ -35,7 +37,7 @@ func TestLoggedStatus(t *testing.T) { | |||||||
| 	r := new(dns.Msg) | 	r := new(dns.Msg) | ||||||
| 	r.SetQuestion("example.org.", dns.TypeA) | 	r.SetQuestion("example.org.", dns.TypeA) | ||||||
|  |  | ||||||
| 	rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{}) | 	rec := middleware.NewResponseRecorder(&coretest.ResponseWriter{}) | ||||||
|  |  | ||||||
| 	rcode, _ := logger.ServeDNS(ctx, rec, r) | 	rcode, _ := logger.ServeDNS(ctx, rec, r) | ||||||
| 	if rcode != 0 { | 	if rcode != 0 { | ||||||
|   | |||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" | 	"github.com/miekg/coredns/middleware" | ||||||
|  | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -30,5 +32,5 @@ func TestLookupProxy(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func fakeState() middleware.State { | func fakeState() middleware.State { | ||||||
| 	return middleware.State{W: &middleware.TestResponseWriter{}, Req: new(dns.Msg)} | 	return middleware.State{W: &coretest.ResponseWriter{}, Req: new(dns.Msg)} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"net" |  | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -79,24 +78,3 @@ func (r *ResponseRecorder) Hijack() { | |||||||
| 	r.ResponseWriter.Hijack() | 	r.ResponseWriter.Hijack() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| type TestResponseWriter struct{} |  | ||||||
|  |  | ||||||
| func (t *TestResponseWriter) LocalAddr() net.Addr { |  | ||||||
| 	ip := net.ParseIP("127.0.0.1") |  | ||||||
| 	port := 53 |  | ||||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *TestResponseWriter) RemoteAddr() net.Addr { |  | ||||||
| 	ip := net.ParseIP("10.240.0.1") |  | ||||||
| 	port := 40212 |  | ||||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *TestResponseWriter) WriteMsg(m *dns.Msg) error     { return nil } |  | ||||||
| func (t *TestResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } |  | ||||||
| func (t *TestResponseWriter) Close() error                  { return nil } |  | ||||||
| func (t *TestResponseWriter) TsigStatus() error             { return nil } |  | ||||||
| func (t *TestResponseWriter) TsigTimersOnly(bool)           { return } |  | ||||||
| func (t *TestResponseWriter) Hijack()                       { return } |  | ||||||
|   | |||||||
| @@ -2,36 +2,38 @@ package middleware | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // This file contains the state nd functions available for use in the templates. | // This file contains the state functions available for use in the middlewares. | ||||||
|  |  | ||||||
| // State contains some connection state and is useful in middleware. | // State contains some connection state and is useful in middleware. | ||||||
| type State struct { | type State struct { | ||||||
| 	Root http.FileSystem // TODO(miek): needed? |  | ||||||
| 	Req *dns.Msg | 	Req *dns.Msg | ||||||
| 	W   dns.ResponseWriter | 	W   dns.ResponseWriter | ||||||
|  |  | ||||||
|  | 	// Cache size after first call to Size or Do | ||||||
|  | 	size int | ||||||
|  | 	do   int // 0: not, 1: true: 2: false | ||||||
| } | } | ||||||
|  |  | ||||||
| // Now returns the current timestamp in the specified format. | // Now returns the current timestamp in the specified format. | ||||||
| func (s State) Now(format string) string { return time.Now().Format(format) } | func (s *State) Now(format string) string { return time.Now().Format(format) } | ||||||
|  |  | ||||||
| // NowDate returns the current date/time that can be used in other time functions. | // NowDate returns the current date/time that can be used in other time functions. | ||||||
| func (s State) NowDate() time.Time { return time.Now() } | func (s *State) NowDate() time.Time { return time.Now() } | ||||||
|  |  | ||||||
| // Header gets the heaser of the request in State. | // Header gets the heaser of the request in State. | ||||||
| func (s State) Header() *dns.RR_Header { | func (s *State) Header() *dns.RR_Header { | ||||||
| 	// TODO(miek) | 	// TODO(miek) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // IP gets the (remote) IP address of the client making the request. | // IP gets the (remote) IP address of the client making the request. | ||||||
| func (s State) IP() string { | func (s *State) IP() string { | ||||||
| 	ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) | 	ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return s.W.RemoteAddr().String() | 		return s.W.RemoteAddr().String() | ||||||
| @@ -40,7 +42,7 @@ func (s State) IP() string { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Post gets the (remote) Port of the client making the request. | // Post gets the (remote) Port of the client making the request. | ||||||
| func (s State) Port() (string, error) { | func (s *State) Port() (string, error) { | ||||||
| 	_, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) | 	_, port, err := net.SplitHostPort(s.W.RemoteAddr().String()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "0", err | 		return "0", err | ||||||
| @@ -50,7 +52,7 @@ func (s State) Port() (string, error) { | |||||||
|  |  | ||||||
| // Proto gets the protocol used as the transport. This | // Proto gets the protocol used as the transport. This | ||||||
| // will be udp or tcp. | // will be udp or tcp. | ||||||
| func (s State) Proto() string { | func (s *State) Proto() string { | ||||||
| 	if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok { | 	if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok { | ||||||
| 		return "udp" | 		return "udp" | ||||||
| 	} | 	} | ||||||
| @@ -62,7 +64,7 @@ func (s State) Proto() string { | |||||||
|  |  | ||||||
| // Family returns the family of the transport. | // Family returns the family of the transport. | ||||||
| // 1 for IPv4 and 2 for IPv6. | // 1 for IPv4 and 2 for IPv6. | ||||||
| func (s State) Family() int { | func (s *State) Family() int { | ||||||
| 	var a net.IP | 	var a net.IP | ||||||
| 	ip := s.W.RemoteAddr() | 	ip := s.W.RemoteAddr() | ||||||
| 	if i, ok := ip.(*net.UDPAddr); ok { | 	if i, ok := ip.(*net.UDPAddr); ok { | ||||||
| @@ -79,33 +81,56 @@ func (s State) Family() int { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Do returns if the request has the DO (DNSSEC OK) bit set. | // Do returns if the request has the DO (DNSSEC OK) bit set. | ||||||
| func (s State) Do() bool { | func (s *State) Do() bool { | ||||||
|  | 	if s.do != 0 { | ||||||
|  | 		return s.do == doTrue | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if o := s.Req.IsEdns0(); o != nil { | 	if o := s.Req.IsEdns0(); o != nil { | ||||||
|  | 		if o.Do() { | ||||||
|  | 			s.do = doTrue | ||||||
|  | 		} else { | ||||||
|  | 			s.do = doFalse | ||||||
|  | 		} | ||||||
| 		return o.Do() | 		return o.Do() | ||||||
| 	} | 	} | ||||||
|  | 	s.do = doFalse | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
| // UDPSize returns if UDP buffer size advertised in the requests OPT record. | // UDPSize returns if UDP buffer size advertised in the requests OPT record. | ||||||
| // Or when the request was over TCP, we return the maximum allowed size of 64K. | // Or when the request was over TCP, we return the maximum allowed size of 64K. | ||||||
| func (s State) Size() int { | func (s *State) Size() int { | ||||||
|  | 	if s.size != 0 { | ||||||
|  | 		return s.size | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if s.Proto() == "tcp" { | 	if s.Proto() == "tcp" { | ||||||
|  | 		s.size = dns.MaxMsgSize | ||||||
| 		return dns.MaxMsgSize | 		return dns.MaxMsgSize | ||||||
| 	} | 	} | ||||||
| 	if o := s.Req.IsEdns0(); o != nil { | 	if o := s.Req.IsEdns0(); o != nil { | ||||||
| 		s := o.UDPSize() | 		if o.Do() == true { | ||||||
| 		if s < dns.MinMsgSize { | 			s.do = doTrue | ||||||
| 			s = dns.MinMsgSize | 		} else { | ||||||
|  | 			s.do = doFalse | ||||||
| 		} | 		} | ||||||
| 		return int(s) |  | ||||||
|  | 		size := o.UDPSize() | ||||||
|  | 		if size < dns.MinMsgSize { | ||||||
|  | 			size = dns.MinMsgSize | ||||||
| 		} | 		} | ||||||
|  | 		s.size = int(size) | ||||||
|  | 		return int(size) | ||||||
|  | 	} | ||||||
|  | 	s.size = dns.MinMsgSize | ||||||
| 	return dns.MinMsgSize | 	return dns.MinMsgSize | ||||||
| } | } | ||||||
|  |  | ||||||
| // SizeAndDo returns a ready made OPT record that the reflects the intent from | // SizeAndDo returns a ready made OPT record that the reflects the intent from | ||||||
| // state. This can be added to upstream requests that will then hopefully | // state. This can be added to upstream requests that will then hopefully | ||||||
| // return a message that is fits the buffer in the client. | // return a message that is fits the buffer in the client. | ||||||
| func (s State) SizeAndDo() *dns.OPT { | func (s *State) SizeAndDo() *dns.OPT { | ||||||
| 	size := s.Size() | 	size := s.Size() | ||||||
| 	Do := s.Do() | 	Do := s.Do() | ||||||
|  |  | ||||||
| @@ -134,7 +159,7 @@ const ( | |||||||
| // the TC bit will be set regardless of protocol, even TCP message will get the bit, the client | // the TC bit will be set regardless of protocol, even TCP message will get the bit, the client | ||||||
| // should then retry with pigeons. | // should then retry with pigeons. | ||||||
| // TODO(referral). | // TODO(referral). | ||||||
| func (s State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { | func (s *State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { | ||||||
| 	size := s.Size() | 	size := s.Size() | ||||||
| 	l := reply.Len() | 	l := reply.Len() | ||||||
| 	if size >= l { | 	if size >= l { | ||||||
| @@ -150,32 +175,36 @@ func (s State) Scrub(reply *dns.Msg) (*dns.Msg, Result) { | |||||||
| 	// Still?!! does not fit. | 	// Still?!! does not fit. | ||||||
| 	reply.Truncated = true | 	reply.Truncated = true | ||||||
| 	return reply, ScrubDone | 	return reply, ScrubDone | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Type returns the type of the question as a string. | // Type returns the type of the question as a string. | ||||||
| func (s State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } | func (s *State) Type() string { return dns.Type(s.Req.Question[0].Qtype).String() } | ||||||
|  |  | ||||||
| // QType returns the type of the question as a uint16. | // QType returns the type of the question as a uint16. | ||||||
| func (s State) QType() uint16 { return s.Req.Question[0].Qtype } | func (s *State) QType() uint16 { return s.Req.Question[0].Qtype } | ||||||
|  |  | ||||||
| // Name returns the name of the question in the request. Note | // Name returns the name of the question in the request. Note | ||||||
| // this name will always have a closing dot and will be lower cased. | // this name will always have a closing dot and will be lower cased. | ||||||
| func (s State) Name() string { return strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) } | func (s *State) Name() string { return strings.ToLower(dns.Name(s.Req.Question[0].Name).String()) } | ||||||
|  |  | ||||||
| // QName returns the name of the question in the request. | // QName returns the name of the question in the request. | ||||||
| func (s State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } | func (s *State) QName() string { return dns.Name(s.Req.Question[0].Name).String() } | ||||||
|  |  | ||||||
| // Class returns the class of the question in the request. | // Class returns the class of the question in the request. | ||||||
| func (s State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } | func (s *State) Class() string { return dns.Class(s.Req.Question[0].Qclass).String() } | ||||||
|  |  | ||||||
| // QClass returns the class of the question in the request. | // QClass returns the class of the question in the request. | ||||||
| func (s State) QClass() uint16 { return s.Req.Question[0].Qclass } | func (s *State) QClass() uint16 { return s.Req.Question[0].Qclass } | ||||||
|  |  | ||||||
| // ErrorMessage returns an error message suitable for sending | // ErrorMessage returns an error message suitable for sending | ||||||
| // back to the client. | // back to the client. | ||||||
| func (s State) ErrorMessage(rcode int) *dns.Msg { | func (s *State) ErrorMessage(rcode int) *dns.Msg { | ||||||
| 	m := new(dns.Msg) | 	m := new(dns.Msg) | ||||||
| 	m.SetRcode(s.Req, rcode) | 	m.SetRcode(s.Req, rcode) | ||||||
| 	return m | 	return m | ||||||
| } | } | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	doTrue  = 1 | ||||||
|  | 	doFalse = 2 | ||||||
|  | ) | ||||||
|   | |||||||
| @@ -1,5 +1,46 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	coretest "github.com/miekg/coredns/middleware/testing" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestStateDo(t *testing.T) { | ||||||
|  | 	st := testState() | ||||||
|  |  | ||||||
|  | 	st.Do() | ||||||
|  | 	if st.do == 0 { | ||||||
|  | 		t.Fatalf("expected st.do to be set") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkStateDo(b *testing.B) { | ||||||
|  | 	st := testState() | ||||||
|  |  | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		st.Do() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkStateSize(b *testing.B) { | ||||||
|  | 	st := testState() | ||||||
|  |  | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		st.Size() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func testState() State { | ||||||
|  | 	m := new(dns.Msg) | ||||||
|  | 	m.SetQuestion("example.com.", dns.TypeA) | ||||||
|  | 	m.SetEdns0(4097, true) | ||||||
|  |  | ||||||
|  | 	return State{W: &coretest.ResponseWriter{}, Req: m} | ||||||
|  | } | ||||||
|  |  | ||||||
| /* | /* | ||||||
| func TestHeader(t *testing.T) { | func TestHeader(t *testing.T) { | ||||||
| 	state := getContextOrFail(t) | 	state := getContextOrFail(t) | ||||||
|   | |||||||
| @@ -3,8 +3,6 @@ package testing | |||||||
| import ( | import ( | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/miekg/coredns/middleware" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
| @@ -199,11 +197,27 @@ func Section(t *testing.T, tc Case, sect Sect, rr []dns.RR) bool { | |||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
| func ErrorHandler() middleware.Handler { | func ErrorHandler() Handler { | ||||||
| 	return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | 	return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
| 		m := new(dns.Msg) | 		m := new(dns.Msg) | ||||||
| 		m.SetRcode(r, dns.RcodeServerFailure) | 		m.SetRcode(r, dns.RcodeServerFailure) | ||||||
| 		w.WriteMsg(m) | 		w.WriteMsg(m) | ||||||
| 		return dns.RcodeServerFailure, nil | 		return dns.RcodeServerFailure, nil | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Copied here to prevent an import cycle. | ||||||
|  | type ( | ||||||
|  | 	// HandlerFunc is a convenience type like dns.HandlerFunc, except | ||||||
|  | 	// ServeDNS returns an rcode and an error. | ||||||
|  | 	HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) | ||||||
|  |  | ||||||
|  | 	Handler interface { | ||||||
|  | 		ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ServeDNS implements the Handler interface. | ||||||
|  | func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
|  | 	return f(ctx, w, r) | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										28
									
								
								middleware/testing/responsewriter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								middleware/testing/responsewriter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | package testing | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ResponseWriter struct{} | ||||||
|  |  | ||||||
|  | func (t *ResponseWriter) LocalAddr() net.Addr { | ||||||
|  | 	ip := net.ParseIP("127.0.0.1") | ||||||
|  | 	port := 53 | ||||||
|  | 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *ResponseWriter) RemoteAddr() net.Addr { | ||||||
|  | 	ip := net.ParseIP("10.240.0.1") | ||||||
|  | 	port := 40212 | ||||||
|  | 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *ResponseWriter) WriteMsg(m *dns.Msg) error     { return nil } | ||||||
|  | func (t *ResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } | ||||||
|  | func (t *ResponseWriter) Close() error                  { return nil } | ||||||
|  | func (t *ResponseWriter) TsigStatus() error             { return nil } | ||||||
|  | func (t *ResponseWriter) TsigTimersOnly(bool)           { return } | ||||||
|  | func (t *ResponseWriter) Hijack()                       { return } | ||||||
		Reference in New Issue
	
	Block a user