From 4de8fb57b2b9111f9e5a4d57833cb9b6d2987ed8 Mon Sep 17 00:00:00 2001 From: Puneet Loya Date: Fri, 7 Mar 2025 08:37:25 -0800 Subject: [PATCH] plugin/forward: added option `failfast_all_unhealthy_upstreams` to return servfail if all upstreams are down (#6999) * feat: option to return servfail if upstreams are down Signed-off-by: Puneet Loya * fix based on review comments and added to Readme Signed-off-by: Puneet Loya * add tests to improve code coverage Signed-off-by: Puneet Loya * added failfast_all_unhealthy_upstreams option to forward plugin Signed-off-by: Puneet Loya --------- Signed-off-by: Puneet Loya Co-authored-by: Puneet Loya --- plugin/forward/README.md | 2 + plugin/forward/forward.go | 21 ++++++---- plugin/forward/health_test.go | 76 +++++++++++++++++++++++++++++++++++ plugin/forward/setup.go | 6 +++ plugin/forward/setup_test.go | 41 +++++++++++++++++++ 5 files changed, 138 insertions(+), 8 deletions(-) diff --git a/plugin/forward/README.md b/plugin/forward/README.md index bcfb4f355..33f4e60f6 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -51,6 +51,7 @@ forward FROM TO... { health_check DURATION [no_rec] [domain FQDN] max_concurrent MAX next RCODE_1 [RCODE_2] [RCODE_3...] + failfast_all_unhealthy_upstreams } ~~~ @@ -97,6 +98,7 @@ forward FROM TO... { at least greater than the expected *upstream query rate* * *latency* of the upstream servers. As an upper bound for **MAX**, consider that each concurrent query will use about 2kb of memory. * `next` If the `RCODE` (i.e. `NXDOMAIN`) is returned by the remote then execute the next plugin. If no next plugin is defined, or the next plugin is not a `forward` plugin, this setting is ignored +* `failfast_all_unhealthy_upstreams` - determines the handling of requests when all upstream servers are unhealthy and unresponsive to health checks. Enabling this option will immediately return SERVFAIL responses for all requests. By default, requests are sent to a random upstream. Also note the TLS config is "global" for the whole forwarding proxy if you need a different `tls_servername` for different upstreams you're out of luck. diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index cb22391e2..3ea255161 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -45,11 +45,12 @@ type Forward struct { nextAlternateRcodes []int - tlsConfig *tls.Config - tlsServerName string - maxfails uint32 - expire time.Duration - maxConcurrent int64 + tlsConfig *tls.Config + tlsServerName string + maxfails uint32 + expire time.Duration + maxConcurrent int64 + failfastUnhealthyUpstreams bool opts proxy.Options // also here for testing @@ -126,12 +127,16 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg if fails < len(f.proxies) { continue } - // All upstream proxies are dead, assume healthcheck is completely broken and randomly + + healthcheckBrokenCount.Add(1) + // All upstreams are dead, return servfail if all upstreams are down + if f.failfastUnhealthyUpstreams { + break + } + // assume healthcheck is completely broken and randomly // select an upstream to connect to. r := new(random) proxy = r.List(f.proxies)[0] - - healthcheckBrokenCount.Add(1) } if span != nil { diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go index 211a620c4..3f511385e 100644 --- a/plugin/forward/health_test.go +++ b/plugin/forward/health_test.go @@ -277,3 +277,79 @@ func TestHealthDomain(t *testing.T) { t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1) } } + +func TestAllUpstreamsDown(t *testing.T) { + qs := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // count non-healthcheck queries + if r.Question[0].Name != "." { + atomic.AddUint32(&qs, 1) + } + // timeout + }) + defer s.Close() + + s1 := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // count non-healthcheck queries + if r.Question[0].Name != "." { + atomic.AddUint32(&qs, 1) + } + // timeout + }) + defer s1.Close() + + p := proxy.NewProxy("TestHealthAllUpstreamsDown", s.Addr, transport.DNS) + p1 := proxy.NewProxy("TestHealthAllUpstreamsDown2", s1.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p1.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + + f := New() + f.SetProxy(p) + f.SetProxy(p1) + f.failfastUnhealthyUpstreams = true + f.maxfails = 1 + // Make proxys fail by checking health twice + // i.e, fails > maxfails + for range f.maxfails + 1 { + p.GetHealthchecker().Check(p) + p1.GetHealthchecker().Check(p1) + } + + defer f.OnShutdown() + + // Check if all proxies are down + if !p.Down(f.maxfails) || !p1.Down(f.maxfails) { + t.Fatalf("Expected all proxies to be down") + } + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + resp, err := f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + if resp != dns.RcodeServerFailure { + t.Errorf("Expected Response code: %d, Got: %d", dns.RcodeServerFailure, resp) + } + + if err != ErrNoHealthy { + t.Errorf("Expected error message: no healthy proxies, Got: %s", err.Error()) + } + + q1 := atomic.LoadUint32(&qs) + if q1 != 0 { + t.Errorf("Expected queries to the upstream: 0, Got: %d", q1) + } + + // set failfast to false to check if queries get answered + f.failfastUnhealthyUpstreams = false + + req = new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + _, err = f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + if err == ErrNoHealthy { + t.Error("Unexpected error message: no healthy proxies") + } + + q1 = atomic.LoadUint32(&qs) + if q1 != 1 { + t.Errorf("Expected queries to the upstream: 1, Got: %d", q1) + } +} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index e8211abf8..be5b12f94 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -306,6 +306,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { f.nextAlternateRcodes = append(f.nextAlternateRcodes, rc) } + case "failfast_all_unhealthy_upstreams": + args := c.RemainingArgs() + if len(args) != 0 { + return c.ArgErr() + } + f.failfastUnhealthyUpstreams = true default: return c.Errf("unknown property '%s'", c.Val()) } diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index 817ddfde9..c45ccce85 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -382,3 +382,44 @@ func TestNextAlternate(t *testing.T) { } } } + +func TestFailfastAllUnhealthyUpstreams(t *testing.T) { + tests := []struct { + input string + expectedRecVal bool + expectedErr string + }{ + // positive + {"forward . 127.0.0.1\n", false, ""}, + {"forward . 127.0.0.1 {\nfailfast_all_unhealthy_upstreams\n}\n", true, ""}, + // negative + {"forward . 127.0.0.1 {\nfailfast_all_unhealthy_upstreams false\n}\n", false, "Wrong argument count"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if err != nil { + if test.expectedErr == "" { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } else { + if test.expectedErr != "" { + t.Errorf("Test %d: expected error but found no error for input %s", i, test.input) + } + } + + if test.expectedErr != "" { + continue + } + + f := fs[0] + if f.failfastUnhealthyUpstreams != test.expectedRecVal { + t.Errorf("Test %d: Expected Rec:%v, got:%v", i, test.expectedRecVal, f.failfastUnhealthyUpstreams) + } + } +}