mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 10:13:14 -04:00 
			
		
		
		
	EDNS: return error on wrong version. (#95)
Split up the previous changes a bit. This PR only returns the expected error when the received packet has the wrong EDNS version. EDNS0 handling in the middleware needs a nicer abstraction, like ReflectEdns() or something.
This commit is contained in:
		
							
								
								
									
										34
									
								
								middleware/edns.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								middleware/edns.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| // Edns0Version checks the EDNS version in the request. If error | ||||
| // is nil everything is OK and we can invoke the middleware. If non-nil, the | ||||
| // returned Msg is valid to be returned to the client (and should). For some | ||||
| // reason this response should not contain a question RR in the question section. | ||||
| func Edns0Version(req *dns.Msg) (*dns.Msg, error) { | ||||
| 	opt := req.IsEdns0() | ||||
| 	if opt == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	if opt.Version() == 0 { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	m := new(dns.Msg) | ||||
| 	m.SetReply(req) | ||||
| 	// zero out question section, wtf. | ||||
| 	m.Question = nil | ||||
|  | ||||
| 	o := new(dns.OPT) | ||||
| 	o.Hdr.Name = "." | ||||
| 	o.Hdr.Rrtype = dns.TypeOPT | ||||
| 	o.SetVersion(0) | ||||
| 	o.SetExtendedRcode(dns.RcodeBadVers) | ||||
| 	m.Extra = []dns.RR{o} | ||||
|  | ||||
| 	return m, errors.New("EDNS0 BADVERS") | ||||
| } | ||||
							
								
								
									
										37
									
								
								middleware/edns_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								middleware/edns_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| func TestEdns0Version(t *testing.T) { | ||||
| 	m := ednsMsg() | ||||
| 	m.Extra[0].(*dns.OPT).SetVersion(2) | ||||
|  | ||||
| 	_, err := Edns0Version(m) | ||||
| 	if err == nil { | ||||
| 		t.Errorf("expected wrong version, but got OK") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEdns0VersionNoEdns(t *testing.T) { | ||||
| 	m := ednsMsg() | ||||
| 	m.Extra = nil | ||||
|  | ||||
| 	_, err := Edns0Version(m) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("expected no error, but got one: %s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ednsMsg() *dns.Msg { | ||||
| 	m := new(dns.Msg) | ||||
| 	m.SetQuestion("example.com.", dns.TypeA) | ||||
| 	o := new(dns.OPT) | ||||
| 	o.Hdr.Name = "." | ||||
| 	o.Hdr.Rrtype = dns.TypeOPT | ||||
| 	m.Extra = append(m.Extra, o) | ||||
| 	return m | ||||
| } | ||||
							
								
								
									
										14
									
								
								middleware/rcode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								middleware/rcode.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| func RcodeToString(rcode int) string { | ||||
| 	if str, ok := dns.RcodeToString[rcode]; ok { | ||||
| 		return str | ||||
| 	} | ||||
| 	return "RCODE" + strconv.Itoa(rcode) | ||||
| } | ||||
| @@ -1,7 +1,6 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| @@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) { | ||||
| } | ||||
|  | ||||
| // Size returns the size. | ||||
| func (r *ResponseRecorder) Size() int { | ||||
| 	return r.size | ||||
| } | ||||
| func (r *ResponseRecorder) Size() int { return r.size } | ||||
|  | ||||
| // Rcode returns the rcode. | ||||
| func (r *ResponseRecorder) Rcode() string { | ||||
| 	if rcode, ok := dns.RcodeToString[r.rcode]; ok { | ||||
| 		return rcode | ||||
| 	} | ||||
| 	return "RCODE" + strconv.Itoa(r.rcode) | ||||
| } | ||||
| func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) } | ||||
|  | ||||
| // Start returns the start time of the ResponseRecorder. | ||||
| func (r *ResponseRecorder) Start() time.Time { | ||||
| 	return r.start | ||||
| } | ||||
| func (r *ResponseRecorder) Start() time.Time { return r.start } | ||||
|  | ||||
| // Msg returns the written message from the ResponseRecorder. | ||||
| func (r *ResponseRecorder) Msg() *dns.Msg { | ||||
| 	return r.msg | ||||
| } | ||||
| func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg } | ||||
|  | ||||
| // Hijack implements dns.Hijacker. It simply wraps the underlying | ||||
| // ResponseWriter's Hijack method if there is one, or returns an error. | ||||
|   | ||||
| @@ -12,10 +12,10 @@ import ( | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"runtime" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/miekg/coredns/middleware" | ||||
| 	"github.com/miekg/coredns/middleware/chaos" | ||||
| 	"github.com/miekg/coredns/middleware/prometheus" | ||||
|  | ||||
| @@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. | ||||
| 		qtype := dns.Type(r.Question[0].Qtype).String() | ||||
| 		rc := middleware.RcodeToString(dns.RcodeBadVers) | ||||
| 		metrics.Report(dropped, qtype, rc, m.Len(), time.Now()) | ||||
| 		w.WriteMsg(m) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Execute the optional request callback if it exists | ||||
| 	if s.ReqCallback != nil && s.ReqCallback(w, r) { | ||||
| 		return | ||||
| @@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | ||||
| // of the specified HTTP status code. | ||||
| func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { | ||||
| 	qtype := dns.Type(r.Question[0].Qtype).String() | ||||
|  | ||||
| 	// this code is duplicated a few times, TODO(miek) | ||||
| 	rc := dns.RcodeToString[rcode] | ||||
| 	if rc == "" { | ||||
| 		rc = "RCODE" + strconv.Itoa(rcode) | ||||
| 	} | ||||
| 	rc := middleware.RcodeToString(rcode) | ||||
|  | ||||
| 	answer := new(dns.Msg) | ||||
| 	answer.SetRcode(r, rcode) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user