mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-30 17:53:21 -04:00 
			
		
		
		
	plugin/trace: read trace context info from headers for DOH (#5439)
Signed-off-by: Ondřej Benkovský <ondrej.benkovsky@jamf.com>
This commit is contained in:
		| @@ -27,6 +27,9 @@ type ServerHTTPS struct { | |||||||
| 	validRequest func(*http.Request) bool | 	validRequest func(*http.Request) bool | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH) | ||||||
|  | type HTTPRequestKey struct{} | ||||||
|  |  | ||||||
| // NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. | // NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. | ||||||
| func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { | func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { | ||||||
| 	s, err := NewServer(addr, group) | 	s, err := NewServer(addr, group) | ||||||
| @@ -153,6 +156,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||||||
| 	// We should expect a packet to be returned that we can send to the client. | 	// We should expect a packet to be returned that we can send to the client. | ||||||
| 	ctx := context.WithValue(context.Background(), Key{}, s.Server) | 	ctx := context.WithValue(context.Background(), Key{}, s.Server) | ||||||
| 	ctx = context.WithValue(ctx, LoopKey{}, 0) | 	ctx = context.WithValue(ctx, LoopKey{}, 0) | ||||||
|  | 	ctx = context.WithValue(ctx, HTTPRequestKey{}, r) | ||||||
| 	s.ServeDNS(ctx, dw, msg) | 	s.ServeDNS(ctx, dw, msg) | ||||||
|  |  | ||||||
| 	// See section 4.2.1 of RFC 8484. | 	// See section 4.2.1 of RFC 8484. | ||||||
|   | |||||||
| @@ -4,9 +4,11 @@ package trace | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/core/dnsserver" | ||||||
| 	"github.com/coredns/coredns/plugin" | 	"github.com/coredns/coredns/plugin" | ||||||
| 	"github.com/coredns/coredns/plugin/metadata" | 	"github.com/coredns/coredns/plugin/metadata" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
| @@ -140,8 +142,15 @@ func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) | |||||||
| 		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) | 		return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	var spanCtx ot.SpanContext | ||||||
|  | 	if val := ctx.Value(dnsserver.HTTPRequestKey{}); val != nil { | ||||||
|  | 		if httpReq, ok := val.(*http.Request); ok { | ||||||
|  | 			spanCtx, _ = t.Tracer().Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(httpReq.Header)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	req := request.Request{W: w, Req: r} | 	req := request.Request{W: w, Req: r} | ||||||
| 	span = t.Tracer().StartSpan(defaultTopLevelSpanName) | 	span = t.Tracer().StartSpan(defaultTopLevelSpanName, otext.RPCServerOption(spanCtx)) | ||||||
| 	defer span.Finish() | 	defer span.Finish() | ||||||
|  |  | ||||||
| 	switch spanCtx := span.Context().(type) { | 	switch spanCtx := span.Context().(type) { | ||||||
|   | |||||||
| @@ -3,9 +3,11 @@ package trace | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/coredns/caddy" | 	"github.com/coredns/caddy" | ||||||
|  | 	"github.com/coredns/coredns/core/dnsserver" | ||||||
| 	"github.com/coredns/coredns/plugin" | 	"github.com/coredns/coredns/plugin" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/rcode" | 	"github.com/coredns/coredns/plugin/pkg/rcode" | ||||||
| @@ -13,6 +15,7 @@ import ( | |||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
|  | 	"github.com/opentracing/opentracing-go" | ||||||
| 	"github.com/opentracing/opentracing-go/mocktracer" | 	"github.com/opentracing/opentracing-go/mocktracer" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -131,3 +134,40 @@ func TestTrace(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestTrace_DOH_TraceHeaderExtraction(t *testing.T) { | ||||||
|  | 	w := dnstest.NewRecorder(&test.ResponseWriter{}) | ||||||
|  | 	m := mocktracer.New() | ||||||
|  | 	tr := &trace{ | ||||||
|  | 		Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||||
|  | 			if plugin.ClientWrite(dns.RcodeSuccess) { | ||||||
|  | 				m := new(dns.Msg) | ||||||
|  | 				m.SetRcode(r, dns.RcodeSuccess) | ||||||
|  | 				w.WriteMsg(m) | ||||||
|  | 			} | ||||||
|  | 			return dns.RcodeSuccess, nil | ||||||
|  | 		}), | ||||||
|  | 		every:  1, | ||||||
|  | 		tracer: m, | ||||||
|  | 	} | ||||||
|  | 	q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA) | ||||||
|  |  | ||||||
|  | 	req := httptest.NewRequest("POST", "/dns-query", nil) | ||||||
|  |  | ||||||
|  | 	outsideSpan := m.StartSpan("test-header-span") | ||||||
|  | 	outsideSpan.Tracer().Inject(outsideSpan.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)) | ||||||
|  | 	defer outsideSpan.Finish() | ||||||
|  |  | ||||||
|  | 	ctx := context.TODO() | ||||||
|  | 	ctx = context.WithValue(ctx, dnsserver.HTTPRequestKey{}, req) | ||||||
|  |  | ||||||
|  | 	tr.ServeDNS(ctx, w, q) | ||||||
|  |  | ||||||
|  | 	fs := m.FinishedSpans() | ||||||
|  | 	rootCoreDNSspan := fs[1] | ||||||
|  | 	rootCoreDNSTraceID := rootCoreDNSspan.Context().(mocktracer.MockSpanContext).TraceID | ||||||
|  | 	outsideSpanTraceID := outsideSpan.Context().(mocktracer.MockSpanContext).TraceID | ||||||
|  | 	if rootCoreDNSTraceID != outsideSpanTraceID { | ||||||
|  | 		t.Errorf("Unexpected traceID: rootSpan.TraceID: want %v, got %v", rootCoreDNSTraceID, outsideSpanTraceID) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user