mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-30 01:34:21 -04:00 
			
		
		
		
	EDNS: return error on wrong version. (#95)
Split up the previous changes a bit. This PR only returns the expected error when the received packet has the wrong EDNS version. EDNS0 handling in the middleware needs a nicer abstraction, like ReflectEdns() or something.
This commit is contained in:
		
							
								
								
									
										34
									
								
								middleware/edns.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								middleware/edns.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Edns0Version checks the EDNS version in the request. If error | ||||||
|  | // is nil everything is OK and we can invoke the middleware. If non-nil, the | ||||||
|  | // returned Msg is valid to be returned to the client (and should). For some | ||||||
|  | // reason this response should not contain a question RR in the question section. | ||||||
|  | func Edns0Version(req *dns.Msg) (*dns.Msg, error) { | ||||||
|  | 	opt := req.IsEdns0() | ||||||
|  | 	if opt == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | 	if opt.Version() == 0 { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | 	m := new(dns.Msg) | ||||||
|  | 	m.SetReply(req) | ||||||
|  | 	// zero out question section, wtf. | ||||||
|  | 	m.Question = nil | ||||||
|  |  | ||||||
|  | 	o := new(dns.OPT) | ||||||
|  | 	o.Hdr.Name = "." | ||||||
|  | 	o.Hdr.Rrtype = dns.TypeOPT | ||||||
|  | 	o.SetVersion(0) | ||||||
|  | 	o.SetExtendedRcode(dns.RcodeBadVers) | ||||||
|  | 	m.Extra = []dns.RR{o} | ||||||
|  |  | ||||||
|  | 	return m, errors.New("EDNS0 BADVERS") | ||||||
|  | } | ||||||
							
								
								
									
										37
									
								
								middleware/edns_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								middleware/edns_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestEdns0Version(t *testing.T) { | ||||||
|  | 	m := ednsMsg() | ||||||
|  | 	m.Extra[0].(*dns.OPT).SetVersion(2) | ||||||
|  |  | ||||||
|  | 	_, err := Edns0Version(m) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Errorf("expected wrong version, but got OK") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestEdns0VersionNoEdns(t *testing.T) { | ||||||
|  | 	m := ednsMsg() | ||||||
|  | 	m.Extra = nil | ||||||
|  |  | ||||||
|  | 	_, err := Edns0Version(m) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("expected no error, but got one: %s", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ednsMsg() *dns.Msg { | ||||||
|  | 	m := new(dns.Msg) | ||||||
|  | 	m.SetQuestion("example.com.", dns.TypeA) | ||||||
|  | 	o := new(dns.OPT) | ||||||
|  | 	o.Hdr.Name = "." | ||||||
|  | 	o.Hdr.Rrtype = dns.TypeOPT | ||||||
|  | 	m.Extra = append(m.Extra, o) | ||||||
|  | 	return m | ||||||
|  | } | ||||||
							
								
								
									
										14
									
								
								middleware/rcode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								middleware/rcode.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | |||||||
|  | package middleware | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func RcodeToString(rcode int) string { | ||||||
|  | 	if str, ok := dns.RcodeToString[rcode]; ok { | ||||||
|  | 		return str | ||||||
|  | 	} | ||||||
|  | 	return "RCODE" + strconv.Itoa(rcode) | ||||||
|  | } | ||||||
| @@ -1,7 +1,6 @@ | |||||||
| package middleware | package middleware | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"strconv" |  | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -54,27 +53,16 @@ func (r *ResponseRecorder) Write(buf []byte) (int, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Size returns the size. | // Size returns the size. | ||||||
| func (r *ResponseRecorder) Size() int { | func (r *ResponseRecorder) Size() int { return r.size } | ||||||
| 	return r.size |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Rcode returns the rcode. | // Rcode returns the rcode. | ||||||
| func (r *ResponseRecorder) Rcode() string { | func (r *ResponseRecorder) Rcode() string { return RcodeToString(r.rcode) } | ||||||
| 	if rcode, ok := dns.RcodeToString[r.rcode]; ok { |  | ||||||
| 		return rcode |  | ||||||
| 	} |  | ||||||
| 	return "RCODE" + strconv.Itoa(r.rcode) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Start returns the start time of the ResponseRecorder. | // Start returns the start time of the ResponseRecorder. | ||||||
| func (r *ResponseRecorder) Start() time.Time { | func (r *ResponseRecorder) Start() time.Time { return r.start } | ||||||
| 	return r.start |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Msg returns the written message from the ResponseRecorder. | // Msg returns the written message from the ResponseRecorder. | ||||||
| func (r *ResponseRecorder) Msg() *dns.Msg { | func (r *ResponseRecorder) Msg() *dns.Msg { return r.msg } | ||||||
| 	return r.msg |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Hijack implements dns.Hijacker. It simply wraps the underlying | // Hijack implements dns.Hijacker. It simply wraps the underlying | ||||||
| // ResponseWriter's Hijack method if there is one, or returns an error. | // ResponseWriter's Hijack method if there is one, or returns an error. | ||||||
|   | |||||||
| @@ -12,10 +12,10 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strconv" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/coredns/middleware" | ||||||
| 	"github.com/miekg/coredns/middleware/chaos" | 	"github.com/miekg/coredns/middleware/chaos" | ||||||
| 	"github.com/miekg/coredns/middleware/prometheus" | 	"github.com/miekg/coredns/middleware/prometheus" | ||||||
|  |  | ||||||
| @@ -279,6 +279,14 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
|  | 	if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once. | ||||||
|  | 		qtype := dns.Type(r.Question[0].Qtype).String() | ||||||
|  | 		rc := middleware.RcodeToString(dns.RcodeBadVers) | ||||||
|  | 		metrics.Report(dropped, qtype, rc, m.Len(), time.Now()) | ||||||
|  | 		w.WriteMsg(m) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Execute the optional request callback if it exists | 	// Execute the optional request callback if it exists | ||||||
| 	if s.ReqCallback != nil && s.ReqCallback(w, r) { | 	if s.ReqCallback != nil && s.ReqCallback(w, r) { | ||||||
| 		return | 		return | ||||||
| @@ -332,12 +340,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { | |||||||
| // of the specified HTTP status code. | // of the specified HTTP status code. | ||||||
| func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { | func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { | ||||||
| 	qtype := dns.Type(r.Question[0].Qtype).String() | 	qtype := dns.Type(r.Question[0].Qtype).String() | ||||||
|  | 	rc := middleware.RcodeToString(rcode) | ||||||
| 	// this code is duplicated a few times, TODO(miek) |  | ||||||
| 	rc := dns.RcodeToString[rcode] |  | ||||||
| 	if rc == "" { |  | ||||||
| 		rc = "RCODE" + strconv.Itoa(rcode) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	answer := new(dns.Msg) | 	answer := new(dns.Msg) | ||||||
| 	answer.SetRcode(r, rcode) | 	answer.SetRcode(r, rcode) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user