| 
									
										
										
										
											2016-03-18 20:57:35 +00:00
										 |  |  | package middleware | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2016-03-19 11:16:08 +00:00
										 |  |  | 	"net" | 
					
						
							| 
									
										
										
										
											2016-03-18 20:57:35 +00:00
										 |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // ResponseRecorder is a type of ResponseWriter that captures | 
					
						
							|  |  |  | // the rcode code written to it and also the size of the message | 
					
						
							|  |  |  | // written in the response. A rcode code does not have | 
					
						
							|  |  |  | // to be written, however, in which case 0 must be assumed. | 
					
						
							|  |  |  | // It is best to have the constructor initialize this type | 
					
						
							|  |  |  | // with that default status code. | 
					
						
							|  |  |  | type ResponseRecorder struct { | 
					
						
							|  |  |  | 	dns.ResponseWriter | 
					
						
							|  |  |  | 	rcode int | 
					
						
							|  |  |  | 	size  int | 
					
						
							| 
									
										
										
										
											2016-03-23 08:30:29 +00:00
										 |  |  | 	msg   *dns.Msg | 
					
						
							| 
									
										
										
										
											2016-03-18 20:57:35 +00:00
										 |  |  | 	start time.Time | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // NewResponseRecorder makes and returns a new responseRecorder, | 
					
						
							|  |  |  | // which captures the DNS rcode from the ResponseWriter | 
					
						
							|  |  |  | // and also the length of the response message written through it. | 
					
						
							|  |  |  | func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder { | 
					
						
							|  |  |  | 	return &ResponseRecorder{ | 
					
						
							|  |  |  | 		ResponseWriter: w, | 
					
						
							|  |  |  | 		rcode:          0, | 
					
						
							| 
									
										
										
										
											2016-03-23 08:30:29 +00:00
										 |  |  | 		msg:            nil, | 
					
						
							| 
									
										
										
										
											2016-03-18 20:57:35 +00:00
										 |  |  | 		start:          time.Now(), | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // WriteMsg records the status code and calls the | 
					
						
							|  |  |  | // underlying ResponseWriter's WriteMsg method. | 
					
						
							|  |  |  | func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error { | 
					
						
							|  |  |  | 	r.rcode = res.Rcode | 
					
						
							| 
									
										
										
										
											2016-03-28 12:08:05 +01:00
										 |  |  | 	// We may get called multiple times (axfr for instance). | 
					
						
							|  |  |  | 	// Save the last message, but add the sizes. | 
					
						
							|  |  |  | 	r.size += res.Len() | 
					
						
							| 
									
										
										
										
											2016-03-23 08:30:29 +00:00
										 |  |  | 	r.msg = res | 
					
						
							| 
									
										
										
										
											2016-03-18 20:57:35 +00:00
										 |  |  | 	return r.ResponseWriter.WriteMsg(res) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Write is a wrapper that records the size of the message that gets written. | 
					
						
							|  |  |  | func (r *ResponseRecorder) Write(buf []byte) (int, error) { | 
					
						
							|  |  |  | 	n, err := r.ResponseWriter.Write(buf) | 
					
						
							|  |  |  | 	if err == nil { | 
					
						
							|  |  |  | 		r.size += n | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return n, err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Size returns the size. | 
					
						
							|  |  |  | func (r *ResponseRecorder) Size() int { | 
					
						
							|  |  |  | 	return r.size | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Rcode returns the rcode. | 
					
						
							|  |  |  | func (r *ResponseRecorder) Rcode() int { | 
					
						
							|  |  |  | 	return r.rcode | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Start returns the start time of the ResponseRecorder. | 
					
						
							|  |  |  | func (r *ResponseRecorder) Start() time.Time { | 
					
						
							|  |  |  | 	return r.start | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-03-24 17:58:31 +00:00
										 |  |  | // Msg returns the written message from the ResponseRecorder. | 
					
						
							|  |  |  | func (r *ResponseRecorder) Msg() *dns.Msg { | 
					
						
							| 
									
										
										
										
											2016-03-23 08:30:29 +00:00
										 |  |  | 	return r.msg | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-03-18 20:57:35 +00:00
										 |  |  | // Hijack implements dns.Hijacker. It simply wraps the underlying | 
					
						
							|  |  |  | // ResponseWriter's Hijack method if there is one, or returns an error. | 
					
						
							|  |  |  | func (r *ResponseRecorder) Hijack() { | 
					
						
							|  |  |  | 	r.ResponseWriter.Hijack() | 
					
						
							|  |  |  | 	return | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2016-03-19 11:16:08 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | type TestResponseWriter struct{} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (t *TestResponseWriter) LocalAddr() net.Addr { | 
					
						
							|  |  |  | 	ip := net.ParseIP("127.0.0.1") | 
					
						
							|  |  |  | 	port := 53 | 
					
						
							|  |  |  | 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (t *TestResponseWriter) RemoteAddr() net.Addr { | 
					
						
							|  |  |  | 	ip := net.ParseIP("10.240.0.1") | 
					
						
							|  |  |  | 	port := 40212 | 
					
						
							|  |  |  | 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (t *TestResponseWriter) WriteMsg(m *dns.Msg) error     { return nil } | 
					
						
							|  |  |  | func (t *TestResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } | 
					
						
							|  |  |  | func (t *TestResponseWriter) Close() error                  { return nil } | 
					
						
							|  |  |  | func (t *TestResponseWriter) TsigStatus() error             { return nil } | 
					
						
							|  |  |  | func (t *TestResponseWriter) TsigTimersOnly(bool)           { return } | 
					
						
							|  |  |  | func (t *TestResponseWriter) Hijack()                       { return } |