From 21176fbf1ab86459f1f7e0ad05e1bd591bd1242f Mon Sep 17 00:00:00 2001 From: Ville Vesilehto Date: Tue, 2 Sep 2025 04:09:51 +0300 Subject: [PATCH] fix(grpc): span leak on error attempt (#7487) --- Review note: each per-attempt "query" span is started from the parent span's own tracer (span.Tracer()); ot.StartSpanFromContext would use the process-global tracer instead, which is a no-op unless one has been registered. plugin/grpc/grpc.go | 27 ++++++++------ plugin/grpc/grpc_test.go | 77 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/plugin/grpc/grpc.go b/plugin/grpc/grpc.go index ea5837e3e..31c3f7de8 100644 --- a/plugin/grpc/grpc.go +++ b/plugin/grpc/grpc.go @@ -43,10 +43,10 @@ func (g *GRPC) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( } var ( - span, child ot.Span - ret *dns.Msg - err error - i int + span ot.Span + ret *dns.Msg + err error + i int ) span = ot.SpanFromContext(ctx) list := g.list() @@ -65,20 +65,25 @@ func (g *GRPC) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( proxy := list[i] i++ + callCtx := ctx + var child ot.Span if span != nil { - child = span.Tracer().StartSpan("query", ot.ChildOf(span.Context())) - ctx = ot.ContextWithSpan(ctx, child) + child, callCtx = ot.StartSpanFromContextWithTracer(callCtx, span.Tracer(), "query") } - ret, err = proxy.query(ctx, r) - if err != nil { - // Continue with the next proxy - continue - } + var cancel context.CancelFunc + callCtx, cancel = context.WithDeadline(callCtx, deadline) + + ret, err = proxy.query(callCtx, r) + cancel() if child != nil { child.Finish() } + if err != nil { + // Continue with the next proxy + continue + } // Check if the reply is correct; if not return FormErr. 
if !state.Match(ret) { diff --git a/plugin/grpc/grpc_test.go b/plugin/grpc/grpc_test.go index 1ab7a85b4..f0ea12efb 100644 --- a/plugin/grpc/grpc_test.go +++ b/plugin/grpc/grpc_test.go @@ -5,6 +5,7 @@ import ( "errors" "strings" "testing" + "time" "github.com/coredns/coredns/pb" "github.com/coredns/coredns/plugin/pkg/dnstest" @@ -12,6 +13,9 @@ import ( "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" + grpcgo "google.golang.org/grpc" ) func TestGRPC(t *testing.T) { @@ -102,3 +106,76 @@ func TestGRPCFallthroughNoNext(t *testing.T) { t.Errorf("Expected SERVFAIL when no backends and no next plugin, got: %d", rcode) } } + +// deadlineCheckingClient records whether a deadline was attached to ctx. +type deadlineCheckingClient struct { + sawDeadline bool + lastDeadline time.Time + dnsPacket *pb.DnsPacket + err error +} + +func (c *deadlineCheckingClient) Query(ctx context.Context, in *pb.DnsPacket, opts ...grpcgo.CallOption) (*pb.DnsPacket, error) { + if dl, ok := ctx.Deadline(); ok { + c.sawDeadline = true + c.lastDeadline = dl + } + return c.dnsPacket, c.err +} + +// Test that on error paths we still finish child spans, and that we set a per-call deadline. +func TestGRPC_SpansOnErrorPath(t *testing.T) { + m := &dns.Msg{} + msgBytes, err := m.Pack() + if err != nil { + t.Fatalf("Error packing response: %s", err) + } + dnsPacket := &pb.DnsPacket{Msg: msgBytes} + + // Proxy 1: returns error, we should still finish its child span and have a deadline + p1 := &deadlineCheckingClient{dnsPacket: nil, err: errors.New("kaboom")} + // Proxy 2: returns success + p2 := &deadlineCheckingClient{dnsPacket: dnsPacket, err: nil} + + g := newGRPC() + g.from = "." 
+ g.proxies = []*Proxy{{client: p1}, {client: p2}} + + // Ensure deterministic order of the retries: try p1 then p2 + g.p = new(sequential) + + // Set a parent span in context so ServeDNS creates child spans per attempt + tracer := mocktracer.New() + prev := ot.GlobalTracer() + ot.SetGlobalTracer(tracer) + defer ot.SetGlobalTracer(prev) + + parent := tracer.StartSpan("parent") + ctx := ot.ContextWithSpan(t.Context(), parent) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := g.ServeDNS(ctx, rec, m); err != nil { + t.Fatalf("ServeDNS returned error: %v", err) + } + + // Assert both attempts finished child spans with retries + // (2 query spans: error + success) + finished := tracer.FinishedSpans() + var finishedQueries int + for _, s := range finished { + if s.OperationName == "query" { + finishedQueries++ + } + } + if finishedQueries != 2 { + t.Fatalf("expected 2 finished 'query' spans, got %d (finished: %v)", finishedQueries, finished) + } + + // Assert we set a deadline on the call contexts + if !p1.sawDeadline { + t.Fatalf("expected deadline to be set on first proxy call context") + } + if !p2.sawDeadline { + t.Fatalf("expected deadline to be set on second proxy call context") + } +}