mirror of
https://github.com/coredns/coredns.git
synced 2025-12-03 00:54:01 -05:00
feat(forward): add max connect attempts knob (#7722)
This commit is contained in:
@@ -45,6 +45,7 @@ forward FROM TO... {
|
||||
prefer_udp
|
||||
expire DURATION
|
||||
max_fails INTEGER
|
||||
max_connect_attempts INTEGER
|
||||
tls CERT KEY CA
|
||||
tls_servername NAME
|
||||
policy random|round_robin|sequential
|
||||
@@ -66,6 +67,9 @@ forward FROM TO... {
|
||||
* `max_fails` is the number of consecutive failed health checks that are needed before considering
|
||||
an upstream to be down. If 0, the upstream will never be marked as down (nor health checked).
|
||||
Default is 2.
|
||||
* `max_connect_attempts` caps the total number of upstream connect attempts
|
||||
performed for a single incoming DNS request. The default value of 0 means no per-request
|
||||
cap.
|
||||
* `expire` **DURATION**, expire (cached) connections after this time, the default is 10s.
|
||||
* `tls` **CERT** **KEY** **CA** defines the TLS properties for the TLS connection. From 0 to 3 arguments can be
|
||||
provided with the meaning as described below
|
||||
|
||||
@@ -52,6 +52,7 @@ type Forward struct {
|
||||
maxConcurrent int64
|
||||
failfastUnhealthyUpstreams bool
|
||||
failoverRcodes []int
|
||||
maxConnectAttempts uint32
|
||||
|
||||
opts proxyPkg.Options // also here for testing
|
||||
|
||||
@@ -119,7 +120,9 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
||||
list := f.List()
|
||||
deadline := time.Now().Add(defaultTimeout)
|
||||
start := time.Now()
|
||||
for time.Now().Before(deadline) && ctx.Err() == nil {
|
||||
connectAttempts := uint32(0)
|
||||
|
||||
for time.Now().Before(deadline) && ctx.Err() == nil && (f.maxConnectAttempts == 0 || connectAttempts < f.maxConnectAttempts) {
|
||||
if i >= len(list) {
|
||||
// reached the end of list, reset to begin
|
||||
i = 0
|
||||
@@ -191,7 +194,15 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
||||
proxy.Healthcheck()
|
||||
}
|
||||
|
||||
fails++
|
||||
// If a per-request connect-attempt cap is configured, count this
|
||||
// failed connect attempt and stop retrying when the cap is hit.
|
||||
if f.maxConnectAttempts > 0 {
|
||||
connectAttempts++
|
||||
if connectAttempts >= f.maxConnectAttempts {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fails < len(f.proxies) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -93,45 +93,67 @@ func (m *mockResponseWriter) TsigStatus() error { return nil }
|
||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
||||
func (m *mockResponseWriter) Hijack() {}
|
||||
|
||||
// TestForward_Regression_NoBusyLoop tests that the ServeDNS function does
|
||||
// not enter an infinite busy loop when the upstream DNS server refuses
|
||||
// the connection.
|
||||
// TestForward_Regression_NoBusyLoop ensures that ServeDNS does not perform
|
||||
// an unbounded number of upstream connect attempts for a single request when
|
||||
// maxConnectAttempts is configured, and that maxConnectAttempts=0 keeps the
|
||||
// legacy behaviour (no per-request cap).
|
||||
func TestForward_Regression_NoBusyLoop(t *testing.T) {
|
||||
f := New()
|
||||
|
||||
// ForceTCP ensures that connection refused errors happen immediately on Dial
|
||||
f.opts.ForceTCP = true
|
||||
|
||||
// Disable healthcheck
|
||||
f.maxfails = 0
|
||||
|
||||
// Assume nothing is listening on this port, so the connection will be refused.
|
||||
p := proxy.NewProxy("forward", "127.0.0.1:54321", "tcp")
|
||||
f.SetProxy(p)
|
||||
|
||||
// Create a mock tracer to count the number of connection attempts
|
||||
tracer := mocktracer.New()
|
||||
span := tracer.StartSpan("test")
|
||||
|
||||
// Create a context with the span and a short timeout
|
||||
ctx := opentracing.ContextWithSpan(context.Background(), span)
|
||||
timeout := 500 * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rw := &mockResponseWriter{}
|
||||
|
||||
_, err := f.ServeDNS(ctx, rw, req)
|
||||
spans := tracer.FinishedSpans()
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from ServeDNS due to connection refused, got nil")
|
||||
tests := []struct {
|
||||
name string
|
||||
maxAttempts uint32
|
||||
}{
|
||||
{name: "unbounded", maxAttempts: 0},
|
||||
{name: "single attempt", maxAttempts: 1},
|
||||
{name: "10 attempts", maxAttempts: 10},
|
||||
}
|
||||
|
||||
if len(spans) != 1 {
|
||||
t.Errorf("Expected 1 span, got %d", len(spans))
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f := New()
|
||||
|
||||
// ForceTCP ensures that connection refused errors happen immediately on Dial.
|
||||
f.opts.ForceTCP = true
|
||||
// Disable healthcheck so that only the per-request attempts cap applies here.
|
||||
f.maxfails = 0
|
||||
|
||||
// Set maxConnectAttempts to the number of attempts we want to test.
|
||||
f.maxConnectAttempts = tc.maxAttempts
|
||||
|
||||
// Assume nothing is listening on this port, so the connection will be refused.
|
||||
p := proxy.NewProxy("forward", "127.0.0.1:54321", "tcp")
|
||||
f.SetProxy(p)
|
||||
|
||||
// Create a mock tracer to count the number of connection attempts.
|
||||
tracer := mocktracer.New()
|
||||
span := tracer.StartSpan("test")
|
||||
|
||||
ctx := opentracing.ContextWithSpan(context.Background(), span)
|
||||
timeout := 500 * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rw := &mockResponseWriter{}
|
||||
|
||||
_, err := f.ServeDNS(ctx, rw, req)
|
||||
spans := tracer.FinishedSpans()
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from ServeDNS due to connection refused, got nil")
|
||||
}
|
||||
|
||||
// In all cases we expect at least one attempt/span.
|
||||
if len(spans) == 0 {
|
||||
t.Errorf("Expected at least 1 span, got 0")
|
||||
}
|
||||
|
||||
// When maxConnectAttempts is configured (> 0), the number of connect
|
||||
// attempts as observed via spans should be equal to the configured value.
|
||||
if tc.maxAttempts > 0 && uint32(len(spans)) != tc.maxAttempts {
|
||||
t.Errorf("Expected %d spans, got %d", tc.maxAttempts, len(spans))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,6 +227,15 @@ func parseBlock(c *caddy.Controller, f *Forward) error {
|
||||
return err
|
||||
}
|
||||
f.maxfails = uint32(n)
|
||||
case "max_connect_attempts":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
n, err := strconv.ParseUint(c.Val(), 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.maxConnectAttempts = uint32(n)
|
||||
case "health_check":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
|
||||
@@ -324,6 +324,47 @@ func TestSetupMaxConcurrent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupMaxConnectAttempts(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedVal uint32
|
||||
expectedErr string
|
||||
}{
|
||||
|
||||
{"forward . 127.0.0.1 {\n}\n", false, 0, ""},
|
||||
{"forward . 127.0.0.1 {\nmax_connect_attempts 5\n}\n", false, 5, ""},
|
||||
{"forward . 127.0.0.1 {\nmax_connect_attempts many\n}\n", true, 0, "invalid"},
|
||||
{"forward . 127.0.0.1 {\nmax_connect_attempts -4\n}\n", true, 0, "invalid"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.input)
|
||||
fs, err := parseForward(c)
|
||||
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !test.shouldErr {
|
||||
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), test.expectedErr) {
|
||||
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
|
||||
}
|
||||
}
|
||||
|
||||
if !test.shouldErr {
|
||||
f := fs[0]
|
||||
if f.maxConnectAttempts != test.expectedVal {
|
||||
t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedVal, f.maxConnectAttempts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupHealthCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
Reference in New Issue
Block a user