diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index cec1adb9c..e0a9f7133 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -191,6 +191,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg proxy.Healthcheck() } + fails++ if fails < len(f.proxies) { continue } diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go index aca58cbf9..0f2b02ae5 100644 --- a/plugin/forward/forward_test.go +++ b/plugin/forward/forward_test.go @@ -1,8 +1,11 @@ package forward import ( + "context" + "net" "strings" "testing" + "time" "github.com/coredns/caddy" "github.com/coredns/caddy/caddyfile" @@ -10,6 +13,10 @@ import ( "github.com/coredns/coredns/plugin/dnstap" "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" ) func TestList(t *testing.T) { @@ -74,3 +81,57 @@ func TestSetTapPlugin(t *testing.T) { t.Error("Unexpected order of dnstap plugins") } } + +type mockResponseWriter struct{} + +func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (m *mockResponseWriter) WriteMsg(msg *dns.Msg) error { return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} + +// TestForward_Regression_NoBusyLoop tests that the ServeDNS function does +// not enter an infinite busy loop when the upstream DNS server refuses +// the connection. +func TestForward_Regression_NoBusyLoop(t *testing.T) { + f := New() + + // ForceTCP ensures that connection refused errors happen immediately on Dial + f.opts.ForceTCP = true + + // Disable healthcheck + f.maxfails = 0 + + // Assume nothing is listening on this port, so the connection will be refused. + p := proxy.NewProxy("forward", "127.0.0.1:54321", "tcp") + f.SetProxy(p) + + // Create a mock tracer to count the number of connection attempts + tracer := mocktracer.New() + span := tracer.StartSpan("test") + + // Create a context with the span and a short timeout + ctx := opentracing.ContextWithSpan(context.Background(), span) + timeout := 500 * time.Millisecond + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + req := new(dns.Msg) + req.SetQuestion("example.com.", dns.TypeA) + + rw := &mockResponseWriter{} + + _, err := f.ServeDNS(ctx, rw, req) + spans := tracer.FinishedSpans() + + if err == nil { + t.Errorf("Expected error from ServeDNS due to connection refused, got nil") + } + + if len(spans) != 1 { + t.Errorf("Expected 1 span, got %d", len(spans)) + } +}