feat(forward): add max connect attempts knob (#7722)

This commit is contained in:
Ville Vesilehto
2025-12-02 04:06:52 +02:00
committed by GitHub
parent 5cb2c5dbf5
commit c2894d47d6
5 changed files with 126 additions and 39 deletions

View File

@@ -45,6 +45,7 @@ forward FROM TO... {
prefer_udp
expire DURATION
max_fails INTEGER
max_connect_attempts INTEGER
tls CERT KEY CA
tls_servername NAME
policy random|round_robin|sequential
@@ -66,6 +67,9 @@ forward FROM TO... {
* `max_fails` is the number of subsequent failed health checks that are needed before considering
an upstream to be down. If 0, the upstream will never be marked as down (nor health checked).
Default is 2.
* `max_connect_attempts` caps the total number of upstream connect attempts
performed for a single incoming DNS request. The default is 0, which means there is no
per-request cap.
* `expire` **DURATION**, expire (cached) connections after this time, the default is 10s.
* `tls` **CERT** **KEY** **CA** define the TLS properties for TLS connection. From 0 to 3 arguments can be
provided with the meaning as described below

View File

@@ -52,6 +52,7 @@ type Forward struct {
maxConcurrent int64
failfastUnhealthyUpstreams bool
failoverRcodes []int
maxConnectAttempts uint32
opts proxyPkg.Options // also here for testing
@@ -119,7 +120,9 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
list := f.List()
deadline := time.Now().Add(defaultTimeout)
start := time.Now()
for time.Now().Before(deadline) && ctx.Err() == nil {
connectAttempts := uint32(0)
for time.Now().Before(deadline) && ctx.Err() == nil && (f.maxConnectAttempts == 0 || connectAttempts < f.maxConnectAttempts) {
if i >= len(list) {
// reached the end of list, reset to begin
i = 0
@@ -191,7 +194,15 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
proxy.Healthcheck()
}
fails++
// If a per-request connect-attempt cap is configured, count this
// failed connect attempt and stop retrying when the cap is hit.
if f.maxConnectAttempts > 0 {
connectAttempts++
if connectAttempts >= f.maxConnectAttempts {
break
}
}
if fails < len(f.proxies) {
continue
}

View File

@@ -93,45 +93,67 @@ func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
// TestForward_Regression_NoBusyLoop tests that the ServeDNS function does
// not enter an infinite busy loop when the upstream DNS server refuses
// the connection.
// TestForward_Regression_NoBusyLoop ensures that ServeDNS does not perform
// an unbounded number of upstream connect attempts for a single request when
// maxConnectAttempts is configured, and that maxConnectAttempts=0 keeps the
// legacy behaviour (no per-request cap).
func TestForward_Regression_NoBusyLoop(t *testing.T) {
f := New()
// ForceTCP ensures that connection refused errors happen immediately on Dial
f.opts.ForceTCP = true
// Disable healthcheck
f.maxfails = 0
// Assume nothing is listening on this port, so the connection will be refused.
p := proxy.NewProxy("forward", "127.0.0.1:54321", "tcp")
f.SetProxy(p)
// Create a mock tracer to count the number of connection attempts
tracer := mocktracer.New()
span := tracer.StartSpan("test")
// Create a context with the span and a short timeout
ctx := opentracing.ContextWithSpan(context.Background(), span)
timeout := 500 * time.Millisecond
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req := new(dns.Msg)
req.SetQuestion("example.com.", dns.TypeA)
rw := &mockResponseWriter{}
_, err := f.ServeDNS(ctx, rw, req)
spans := tracer.FinishedSpans()
if err == nil {
t.Errorf("Expected error from ServeDNS due to connection refused, got nil")
tests := []struct {
name string
maxAttempts uint32
}{
{name: "unbounded", maxAttempts: 0},
{name: "single attempt", maxAttempts: 1},
{name: "10 attempts", maxAttempts: 10},
}
if len(spans) != 1 {
t.Errorf("Expected 1 span, got %d", len(spans))
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
f := New()
// ForceTCP ensures that connection refused errors happen immediately on Dial.
f.opts.ForceTCP = true
// Disable healthcheck so that only the per-request attempts cap applies here.
f.maxfails = 0
// Set maxConnectAttempts to the number of attempts we want to test.
f.maxConnectAttempts = tc.maxAttempts
// Assume nothing is listening on this port, so the connection will be refused.
p := proxy.NewProxy("forward", "127.0.0.1:54321", "tcp")
f.SetProxy(p)
// Create a mock tracer to count the number of connection attempts.
tracer := mocktracer.New()
span := tracer.StartSpan("test")
ctx := opentracing.ContextWithSpan(context.Background(), span)
timeout := 500 * time.Millisecond
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req := new(dns.Msg)
req.SetQuestion("example.com.", dns.TypeA)
rw := &mockResponseWriter{}
_, err := f.ServeDNS(ctx, rw, req)
spans := tracer.FinishedSpans()
if err == nil {
t.Errorf("Expected error from ServeDNS due to connection refused, got nil")
}
// In all cases we expect at least one attempt/span.
if len(spans) == 0 {
t.Errorf("Expected at least 1 span, got 0")
}
// When maxConnectAttempts is configured (> 0), the number of connect
// attempts as observed via spans should be equal to the configured value.
if tc.maxAttempts > 0 && uint32(len(spans)) != tc.maxAttempts {
t.Errorf("Expected %d spans, got %d", tc.maxAttempts, len(spans))
}
})
}
}

View File

@@ -227,6 +227,15 @@ func parseBlock(c *caddy.Controller, f *Forward) error {
return err
}
f.maxfails = uint32(n)
case "max_connect_attempts":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.ParseUint(c.Val(), 10, 32)
if err != nil {
return err
}
f.maxConnectAttempts = uint32(n)
case "health_check":
if !c.NextArg() {
return c.ArgErr()

View File

@@ -324,6 +324,47 @@ func TestSetupMaxConcurrent(t *testing.T) {
}
}
// TestSetupMaxConnectAttempts checks how parseForward handles the
// max_connect_attempts option: the field stays 0 when the option is absent,
// a valid unsigned integer is stored as-is, and non-numeric or negative
// values are rejected with an error containing the expected substring.
func TestSetupMaxConnectAttempts(t *testing.T) {
	tests := []struct {
		input       string
		shouldErr   bool
		expectedVal uint32
		expectedErr string
	}{
		{"forward . 127.0.0.1 {\n}\n", false, 0, ""},
		{"forward . 127.0.0.1 {\nmax_connect_attempts 5\n}\n", false, 5, ""},
		{"forward . 127.0.0.1 {\nmax_connect_attempts many\n}\n", true, 0, "invalid"},
		{"forward . 127.0.0.1 {\nmax_connect_attempts -4\n}\n", true, 0, "invalid"},
	}

	for i, test := range tests {
		c := caddy.NewTestController("dns", test.input)
		fs, err := parseForward(c)

		if test.shouldErr {
			if err == nil {
				// err is nil here, so there is nothing meaningful to print for it.
				t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
				continue
			}
			if !strings.Contains(err.Error(), test.expectedErr) {
				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
			}
			continue
		}

		if err != nil {
			t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
			// Do not touch fs after a failed parse; it may be empty.
			continue
		}
		if len(fs) == 0 {
			t.Errorf("Test %d: expected at least one forward instance for input %s, got none", i, test.input)
			continue
		}

		if got := fs[0].maxConnectAttempts; got != test.expectedVal {
			t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedVal, got)
		}
	}
}
func TestSetupHealthCheck(t *testing.T) {
tests := []struct {
input string