diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 306519dc2..f8e1ffca8 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -49,6 +49,7 @@ type Forward struct { tlsServerName string maxfails uint32 expire time.Duration + maxAge time.Duration maxIdleConns int maxConcurrent int64 failfastUnhealthyUpstreams bool diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 45cb00cdc..c6d850890 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -157,6 +157,10 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { } } + if f.maxAge > 0 && f.maxAge < f.expire { + return f, fmt.Errorf("max_age (%s) must not be less than expire (%s)", f.maxAge, f.expire) + } + tlsServerNames := make([]string, len(toHosts)) perServerNameProxyCount := make(map[string]int) transports := make([]string, len(toHosts)) @@ -207,6 +211,7 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { } } f.proxies[i].SetExpire(f.expire) + f.proxies[i].SetMaxAge(f.maxAge) f.proxies[i].SetMaxIdleConns(f.maxIdleConns) f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) // when TLS is used, checks are set to tcp-tls @@ -323,6 +328,18 @@ func parseBlock(c *caddy.Controller, f *Forward) error { return fmt.Errorf("expire can't be negative: %s", dur) } f.expire = dur + case "max_age": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + if dur < 0 { + return fmt.Errorf("max_age can't be negative: %s", dur) + } + f.maxAge = dur case "max_idle_conns": if !c.NextArg() { return c.ArgErr() diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index d606fdf81..06b245fc3 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -707,3 +707,78 @@ func TestFailoverValidation(t *testing.T) { }) } } + +func TestSetupMaxAge(t *testing.T) { + tests := []struct { + name string + input string + shouldErr bool + expectedVal time.Duration + expectedErr string + }{ + { + name: "default (no max_age)", + input: "forward . 127.0.0.1\n", + expectedVal: 0, + }, + { + name: "valid max_age", + input: "forward . 127.0.0.1 {\nmax_age 30s\n}\n", + expectedVal: 30 * time.Second, + }, + { + name: "max_age equal to expire", + input: "forward . 127.0.0.1 {\nexpire 10s\nmax_age 10s\n}\n", + expectedVal: 10 * time.Second, + }, + { + name: "max_age zero (unlimited)", + input: "forward . 127.0.0.1 {\nmax_age 0s\n}\n", + expectedVal: 0, + }, + { + name: "negative max_age", + input: "forward . 127.0.0.1 {\nmax_age -1s\n}\n", + shouldErr: true, + expectedErr: "negative", + }, + { + name: "invalid max_age value", + input: "forward . 127.0.0.1 {\nmax_age invalid\n}\n", + shouldErr: true, + expectedErr: "invalid", + }, + { + name: "max_age less than expire", + input: "forward . 127.0.0.1 {\nexpire 30s\nmax_age 10s\n}\n", + shouldErr: true, + expectedErr: "max_age", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr { + if err == nil { + t.Errorf("expected error but found none for input %s", test.input) + return + } + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("expected error to contain %q, got: %v", test.expectedErr, err) + } + return + } + + if err != nil { + t.Errorf("expected no error but found: %v", err) + return + } + if fs[0].maxAge != test.expectedVal { + t.Errorf("expected maxAge %v, got %v", test.expectedVal, fs[0].maxAge) + } + }) + } +} diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go index 4026cfbdd..33a8f426e 100644 --- a/plugin/pkg/proxy/connect.go +++ b/plugin/pkg/proxy/connect.go @@ -65,6 +65,11 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { transtype := stringToTransportType(proto) t.mu.Lock() + // Pre-compute max-age deadline outside the loop to avoid repeated time.Now() calls. + var maxAgeDeadline time.Time + if t.maxAge > 0 { + maxAgeDeadline = time.Now().Add(-t.maxAge) + } // FIFO: take the oldest conn (front of slice) for source port diversity for len(t.conns[transtype]) > 0 { pc := t.conns[transtype][0] @@ -73,6 +78,10 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { pc.c.Close() continue } + if !maxAgeDeadline.IsZero() && pc.created.Before(maxAgeDeadline) { + pc.c.Close() + continue + } t.mu.Unlock() connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) return pc, true, nil @@ -86,11 +95,11 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { if proto == "tcp-tls" { conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) t.updateDialTimeout(time.Since(reqTime)) - return &persistConn{c: conn}, false, err + return &persistConn{c: conn, created: time.Now()}, false, err } conn, err := dns.DialTimeout(proto, t.addr, timeout) t.updateDialTimeout(time.Since(reqTime)) - return &persistConn{c: conn}, false, err + return &persistConn{c: conn, created: time.Now()}, false, err } // Connect selects an upstream, sends the request and waits for a response. diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go index da2dca122..74c5d8d8d 100644 --- a/plugin/pkg/proxy/persistent.go +++ b/plugin/pkg/proxy/persistent.go @@ -9,17 +9,19 @@ import ( "github.com/miekg/dns" ) -// a persistConn hold the dns.Conn and the last used time. +// a persistConn holds the dns.Conn, its creation time, and the last used time. type persistConn struct { - c *dns.Conn - used time.Time + c *dns.Conn + created time.Time + used time.Time } // Transport hold the persistent cache. type Transport struct { avgDialTime int64 // kind of average time of dial time conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls. - expire time.Duration // After this duration a connection is expired. + expire time.Duration // After this duration an idle connection is expired. + maxAge time.Duration // After this duration a connection is closed regardless of activity; 0 means unlimited. maxIdleConns int // Max idle connections per transport type; 0 means unlimited. addr string tlsConfig *tls.Config @@ -68,7 +70,13 @@ func (t *Transport) cleanup(all bool) { var toClose []*persistConn t.mu.Lock() - staleTime := time.Now().Add(-t.expire) + now := time.Now() + staleTime := now.Add(-t.expire) + // Pre-compute max-age deadline outside the loop to avoid repeated time.Now() calls. + var maxAgeDeadline time.Time + if t.maxAge > 0 { + maxAgeDeadline = now.Add(-t.maxAge) + } for transtype, stack := range t.conns { if len(stack) == 0 { continue @@ -78,10 +86,26 @@ func (t *Transport) cleanup(all bool) { toClose = append(toClose, stack...) continue } - if stack[0].used.After(staleTime) { + + // When max-age is set, use a linear scan to evaluate both the idle-timeout + // (expire, based on last-used time) and the max-age (based on creation time). + if t.maxAge > 0 { + var alive []*persistConn + for _, pc := range stack { + if !pc.used.After(staleTime) || pc.created.Before(maxAgeDeadline) { + toClose = append(toClose, pc) + } else { + alive = append(alive, pc) + } + } + t.conns[transtype] = alive continue } + // Original expire-only path: connections are sorted by "used"; use binary search. + if stack[0].used.After(staleTime) { + continue + } // connections in stack are sorted by "used" good := sort.Search(len(stack), func(i int) bool { return stack[i].used.After(staleTime) @@ -130,6 +154,10 @@ func (t *Transport) Stop() { close(t.stop) } // SetExpire sets the connection expire time in transport. func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } +// SetMaxAge sets the maximum lifetime of a connection regardless of activity. +// A value of 0 (default) disables max-age and connections are only closed by expire (idle-timeout). +func (t *Transport) SetMaxAge(maxAge time.Duration) { t.maxAge = maxAge } + // SetMaxIdleConns sets the maximum idle connections per transport type. // A value of 0 means unlimited (default). func (t *Transport) SetMaxIdleConns(n int) { t.maxIdleConns = n } diff --git a/plugin/pkg/proxy/persistent_test.go b/plugin/pkg/proxy/persistent_test.go index d2b96cddc..dbd9c057f 100644 --- a/plugin/pkg/proxy/persistent_test.go +++ b/plugin/pkg/proxy/persistent_test.go @@ -98,7 +98,12 @@ func TestCleanupAll(t *testing.T) { c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) - tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}} + now := time.Now() + tr.conns[typeUDP] = []*persistConn{ + {c: c1, created: now, used: now}, + {c: c2, created: now, used: now}, + {c: c3, created: now, used: now}, + } if len(tr.conns[typeUDP]) != 3 { t.Error("Expected 3 connections") @@ -226,6 +231,89 @@ func TestYieldAfterStop(t *testing.T) { } } +// TestMaxAgeExpireByCreation verifies that a connection is rejected when its +// creation time exceeds max_age, even if it was recently yielded (fresh used time). +// This guards against the FIFO rotation bug where used time is continually +// refreshed, preventing connections from expiring by idle-timeout alone. +func TestMaxAgeExpireByCreation(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestMaxAgeExpireByCreation", s.Addr) + tr.SetExpire(10 * time.Second) // long idle-timeout: would not expire the connection + tr.SetMaxAge(100 * time.Millisecond) // short max-age: should close old connection + tr.Start() + defer tr.Stop() + + // Inject a connection whose creation time is past max_age but whose used + // time is fresh, simulating a FIFO-rotated connection that is never idle. + oldConn, err := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + pc := &persistConn{ + c: oldConn, + created: time.Now().Add(-200 * time.Millisecond), // 2x max-age: should be closed + used: time.Now(), // freshly used: idle-timeout would pass + } + tr.mu.Lock() + tr.conns[typeUDP] = []*persistConn{pc} + tr.mu.Unlock() + + _, cached, _ := tr.Dial("udp") + if cached { + t.Error("connection should be closed by max_age, not reused despite fresh used time") + } +} + +// TestMaxAgeFIFORotation verifies that connections in a FIFO pool are closed by +// max_age even when continuously rotated (which refreshes their used timestamps). +// Regression test for Scale up: new upstream pods should receive traffic after +// existing connections exceed max_age, regardless of request rate. +func TestMaxAgeFIFORotation(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestMaxAgeFIFORotation", s.Addr) + tr.SetExpire(10 * time.Second) // long idle-timeout: FIFO rotation keeps connections alive + tr.SetMaxAge(100 * time.Millisecond) // max-age: connections must be closed by creation age + tr.Start() + defer tr.Stop() + + // Inject 3 connections old by creation time but with fresh used timestamps, + // simulating active FIFO rotation where idle-timeout never triggers. + tr.mu.Lock() + for range 3 { + c, err := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + if err != nil { + tr.mu.Unlock() + t.Fatalf("Failed to dial: %v", err) + } + tr.conns[typeUDP] = append(tr.conns[typeUDP], &persistConn{ + c: c, + created: time.Now().Add(-200 * time.Millisecond), // exceeds max-age + used: time.Now(), // fresh: idle-timeout would pass + }) + } + tr.mu.Unlock() + + // All 3 connections must be rejected by max_age despite fresh used timestamps. + for i := range 3 { + _, cached, _ := tr.Dial("udp") + if cached { + t.Errorf("Dial %d: connection should be closed by max_age (FIFO rotation must not prevent max-age expiry)", i+1) + } + } +} + func BenchmarkYield(b *testing.B) { s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go index 6c460c397..d455da2f9 100644 --- a/plugin/pkg/proxy/proxy.go +++ b/plugin/pkg/proxy/proxy.go @@ -52,6 +52,10 @@ func (p *Proxy) SetTLSConfig(cfg *tls.Config) { // SetExpire sets the expire duration in the lower p.transport. func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } +// SetMaxAge sets the maximum connection lifetime in the lower p.transport. +// A value of 0 (default) disables max-age. +func (p *Proxy) SetMaxAge(maxAge time.Duration) { p.transport.SetMaxAge(maxAge) } + // SetMaxIdleConns sets the maximum idle connections per transport type. // A value of 0 means unlimited (default). func (p *Proxy) SetMaxIdleConns(n int) { p.transport.SetMaxIdleConns(n) }