plugin/forward: added option failfast_all_unhealthy_upstreams to return servfail if all upstreams are down (#6999)

* feat: option to return servfail if upstreams are down

Signed-off-by: Puneet Loya <puneetloya@Puneets-MBP.attlocal.net>

* fix based on review comments and added to Readme

Signed-off-by: Puneet Loya <puneetloya@Puneets-MBP.attlocal.net>

* add tests to improve code coverage

Signed-off-by: Puneet Loya <puneetloya@Puneets-MBP.attlocal.net>

* added failfast_all_unhealthy_upstreams option to forward plugin

Signed-off-by: Puneet Loya <puneetloya@Puneets-MBP.attlocal.net>

---------

Signed-off-by: Puneet Loya <puneetloya@Puneets-MBP.attlocal.net>
Co-authored-by: Puneet Loya <puneetloya@Puneets-MBP.attlocal.net>
This commit is contained in:
Puneet Loya
2025-03-07 08:37:25 -08:00
committed by GitHub
parent 669ff527bf
commit 4de8fb57b2
5 changed files with 138 additions and 8 deletions

View File

@@ -51,6 +51,7 @@ forward FROM TO... {
health_check DURATION [no_rec] [domain FQDN] health_check DURATION [no_rec] [domain FQDN]
max_concurrent MAX max_concurrent MAX
next RCODE_1 [RCODE_2] [RCODE_3...] next RCODE_1 [RCODE_2] [RCODE_3...]
failfast_all_unhealthy_upstreams
} }
~~~ ~~~
@@ -97,6 +98,7 @@ forward FROM TO... {
at least greater than the expected *upstream query rate* * *latency* of the upstream servers. at least greater than the expected *upstream query rate* * *latency* of the upstream servers.
As an upper bound for **MAX**, consider that each concurrent query will use about 2kb of memory. As an upper bound for **MAX**, consider that each concurrent query will use about 2kb of memory.
* `next` If the `RCODE` (i.e. `NXDOMAIN`) is returned by the remote then execute the next plugin. If no next plugin is defined, or the next plugin is not a `forward` plugin, this setting is ignored * `next` If the `RCODE` (i.e. `NXDOMAIN`) is returned by the remote then execute the next plugin. If no next plugin is defined, or the next plugin is not a `forward` plugin, this setting is ignored
* `failfast_all_unhealthy_upstreams` - determines the handling of requests when all upstream servers are unhealthy and unresponsive to health checks. Enabling this option will immediately return SERVFAIL responses for all requests. By default, requests are sent to a random upstream.
Also note the TLS config is "global" for the whole forwarding proxy if you need a different Also note the TLS config is "global" for the whole forwarding proxy if you need a different
`tls_servername` for different upstreams you're out of luck. `tls_servername` for different upstreams you're out of luck.

View File

@@ -45,11 +45,12 @@ type Forward struct {
nextAlternateRcodes []int nextAlternateRcodes []int
tlsConfig *tls.Config tlsConfig *tls.Config
tlsServerName string tlsServerName string
maxfails uint32 maxfails uint32
expire time.Duration expire time.Duration
maxConcurrent int64 maxConcurrent int64
failfastUnhealthyUpstreams bool
opts proxy.Options // also here for testing opts proxy.Options // also here for testing
@@ -126,12 +127,16 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
if fails < len(f.proxies) { if fails < len(f.proxies) {
continue continue
} }
// All upstream proxies are dead, assume healthcheck is completely broken and randomly
healthcheckBrokenCount.Add(1)
// All upstreams are dead, return servfail if all upstreams are down
if f.failfastUnhealthyUpstreams {
break
}
// assume healthcheck is completely broken and randomly
// select an upstream to connect to. // select an upstream to connect to.
r := new(random) r := new(random)
proxy = r.List(f.proxies)[0] proxy = r.List(f.proxies)[0]
healthcheckBrokenCount.Add(1)
} }
if span != nil { if span != nil {

View File

@@ -277,3 +277,79 @@ func TestHealthDomain(t *testing.T) {
t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1) t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1)
} }
} }
func TestAllUpstreamsDown(t *testing.T) {
qs := uint32(0)
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
// count non-healthcheck queries
if r.Question[0].Name != "." {
atomic.AddUint32(&qs, 1)
}
// timeout
})
defer s.Close()
s1 := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
// count non-healthcheck queries
if r.Question[0].Name != "." {
atomic.AddUint32(&qs, 1)
}
// timeout
})
defer s1.Close()
p := proxy.NewProxy("TestHealthAllUpstreamsDown", s.Addr, transport.DNS)
p1 := proxy.NewProxy("TestHealthAllUpstreamsDown2", s1.Addr, transport.DNS)
p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond)
p1.GetHealthchecker().SetReadTimeout(10 * time.Millisecond)
f := New()
f.SetProxy(p)
f.SetProxy(p1)
f.failfastUnhealthyUpstreams = true
f.maxfails = 1
// Make proxys fail by checking health twice
// i.e, fails > maxfails
for range f.maxfails + 1 {
p.GetHealthchecker().Check(p)
p1.GetHealthchecker().Check(p1)
}
defer f.OnShutdown()
// Check if all proxies are down
if !p.Down(f.maxfails) || !p1.Down(f.maxfails) {
t.Fatalf("Expected all proxies to be down")
}
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
resp, err := f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req)
if resp != dns.RcodeServerFailure {
t.Errorf("Expected Response code: %d, Got: %d", dns.RcodeServerFailure, resp)
}
if err != ErrNoHealthy {
t.Errorf("Expected error message: no healthy proxies, Got: %s", err.Error())
}
q1 := atomic.LoadUint32(&qs)
if q1 != 0 {
t.Errorf("Expected queries to the upstream: 0, Got: %d", q1)
}
// set failfast to false to check if queries get answered
f.failfastUnhealthyUpstreams = false
req = new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
_, err = f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req)
if err == ErrNoHealthy {
t.Error("Unexpected error message: no healthy proxies")
}
q1 = atomic.LoadUint32(&qs)
if q1 != 1 {
t.Errorf("Expected queries to the upstream: 1, Got: %d", q1)
}
}

View File

@@ -306,6 +306,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error {
f.nextAlternateRcodes = append(f.nextAlternateRcodes, rc) f.nextAlternateRcodes = append(f.nextAlternateRcodes, rc)
} }
case "failfast_all_unhealthy_upstreams":
args := c.RemainingArgs()
if len(args) != 0 {
return c.ArgErr()
}
f.failfastUnhealthyUpstreams = true
default: default:
return c.Errf("unknown property '%s'", c.Val()) return c.Errf("unknown property '%s'", c.Val())
} }

View File

@@ -382,3 +382,44 @@ func TestNextAlternate(t *testing.T) {
} }
} }
} }
func TestFailfastAllUnhealthyUpstreams(t *testing.T) {
tests := []struct {
input string
expectedRecVal bool
expectedErr string
}{
// positive
{"forward . 127.0.0.1\n", false, ""},
{"forward . 127.0.0.1 {\nfailfast_all_unhealthy_upstreams\n}\n", true, ""},
// negative
{"forward . 127.0.0.1 {\nfailfast_all_unhealthy_upstreams false\n}\n", false, "Wrong argument count"},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
fs, err := parseForward(c)
if err != nil {
if test.expectedErr == "" {
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
}
if !strings.Contains(err.Error(), test.expectedErr) {
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
}
} else {
if test.expectedErr != "" {
t.Errorf("Test %d: expected error but found no error for input %s", i, test.input)
}
}
if test.expectedErr != "" {
continue
}
f := fs[0]
if f.failfastUnhealthyUpstreams != test.expectedRecVal {
t.Errorf("Test %d: Expected Rec:%v, got:%v", i, test.expectedRecVal, f.failfastUnhealthyUpstreams)
}
}
}