From f3983c1111b546c1659117cfe74239e4ff18a3a2 Mon Sep 17 00:00:00 2001
From: Ville Vesilehto
Date: Wed, 14 Jan 2026 03:49:46 +0200
Subject: [PATCH] perf(proxy): use mutex-based connection pool (#7790)

* perf(proxy): use mutex-based connection pool

The proxy package (used, for example, by the forward plugin) used an
actor model in which a single connManager goroutine managed connection
pooling via unbuffered channels (dial, yield, ret). This design
serialized all connection acquisition and release operations through a
single goroutine, creating a bottleneck under high concurrency. This
was observable as a performance degradation when using a single
upstream backend compared to multiple backends (which sharded the
bottleneck).

Changes:

- Removed the dial, yield, and ret channels from the Transport struct.
- Removed the connManager goroutine's request-processing loop.
- Implemented Dial() and Yield() using a sync.Mutex to protect the
  connection slice, allowing fast concurrent access without context
  switching.
- Downgraded connManager to a simple background cleanup loop that only
  handles connection expiration on a ticker.
- Updated plugin/pkg/proxy/connect.go to use direct method calls
  instead of channel sends.
- Updated tests to reflect the removal of the internal channels.

Benchmarks show that this change eliminates the single-backend
bottleneck: a single upstream backend now performs on par with multiple
backends, and overall throughput is improved. The implementation aligns
with standard Go patterns for connection pooling (e.g.,
net/http.Transport).

Signed-off-by: Ville Vesilehto

* fix: address PR review for persistent.go

- Name the mutex field instead of embedding it, so Lock() and Unlock()
  are not exposed.
- Move the stop check outside the lock in Yield().
- Close() without spawning a separate goroutine.
- Change the stop channel to chan struct{}.

Signed-off-by: Ville Vesilehto

* fix: address code review feedback for conn pool

- Switch from LIFO to FIFO connection selection for source port
  diversity, reducing DNS cache poisoning risk (RFC 5452).
- Remove the "clear entire cache" optimization, as it was LIFO-specific.
  FIFO naturally iterates and skips expired connections.
- Remove all goroutines for closing connections; collect connections
  while holding the lock, then close them synchronously after releasing
  it.

Signed-off-by: Ville Vesilehto

* fix: remove unused error consts

No longer used after refactoring away the channel-based approach.

Signed-off-by: Ville Vesilehto

* feat(forward): add max_idle_conns option

Add a configurable connection pool limit for the forward plugin via the
max_idle_conns Corefile option.

Changes:

- Add SetMaxIdleConns to proxy.
- Add a maxIdleConns field to the Forward struct.
- Add max_idle_conns parsing in the forward plugin setup.
- Apply the setting to each proxy during configuration.
- Update the forward plugin README with the new option.

By default the value is 0 (unbounded). When set, excess connections
returned to the pool are closed immediately rather than cached. An
example stanza is shown below.

Also add a Yield-related test.

Signed-off-by: Ville Vesilehto

* chore(proxy): simplify Dial by closing conns inline

Remove the toClose slice collection; instead, close expired connections
directly while iterating. This reduces complexity with negligible
impact on lock hold time.

Signed-off-by: Ville Vesilehto

* chore: fewer explicit Unlock calls

Cleaner, and less chance of forgetting to unlock on future code paths.
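For the max_idle_conns option added above, an illustrative forward
stanza (example upstream address and values only) would look like:

```
forward . 10.0.0.53 {
    max_idle_conns 16
    expire 10s
}
```

With this configuration each upstream proxy caches at most 16 idle
connections per transport type; further connections returned to the
pool are closed immediately instead of being cached.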
Signed-off-by: Ville Vesilehto --------- Signed-off-by: Ville Vesilehto --- plugin/forward/README.md | 3 + plugin/forward/forward.go | 1 + plugin/forward/setup.go | 13 ++ plugin/forward/setup_test.go | 42 ++++++ plugin/pkg/proxy/connect.go | 64 ++++----- plugin/pkg/proxy/connect_test.go | 204 ---------------------------- plugin/pkg/proxy/persistent.go | 98 ++++++------- plugin/pkg/proxy/persistent_test.go | 131 +++++++++++++++++- plugin/pkg/proxy/proxy.go | 4 + plugin/pkg/proxy/proxy_test.go | 79 ++++++++--- test/proxy_test.go | 75 ++++++++++ 11 files changed, 386 insertions(+), 328 deletions(-) diff --git a/plugin/forward/README.md b/plugin/forward/README.md index 436b6c2e5..44679f091 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -44,6 +44,7 @@ forward FROM TO... { force_tcp prefer_udp expire DURATION + max_idle_conns INTEGER max_fails INTEGER max_connect_attempts INTEGER tls CERT KEY CA @@ -71,6 +72,8 @@ forward FROM TO... { performed for a single incoming DNS request. Default value of 0 means no per-request cap. * `expire` **DURATION**, expire (cached) connections after this time, the default is 10s. +* `max_idle_conns` **INTEGER**, maximum number of idle connections to cache per upstream for reuse. + Default is 0, which means unlimited. * `tls` **CERT** **KEY** **CA** define the TLS properties for TLS connection. From 0 to 3 arguments can be provided with the meaning as described below diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 449579e5a..306519dc2 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -49,6 +49,7 @@ type Forward struct { tlsServerName string maxfails uint32 expire time.Duration + maxIdleConns int maxConcurrent int64 failfastUnhealthyUpstreams bool failoverRcodes []int diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 6469bfad2..6f17882f9 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -196,6 +196,7 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { } } f.proxies[i].SetExpire(f.expire) + f.proxies[i].SetMaxIdleConns(f.maxIdleConns) f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) // when TLS is used, checks are set to tcp-tls if f.opts.ForceTCP && transports[i] != transport.TLS { @@ -311,6 +312,18 @@ func parseBlock(c *caddy.Controller, f *Forward) error { return fmt.Errorf("expire can't be negative: %s", dur) } f.expire = dur + case "max_idle_conns": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + if n < 0 { + return fmt.Errorf("max_idle_conns can't be negative: %d", n) + } + f.maxIdleConns = n case "policy": if !c.NextArg() { return c.ArgErr() diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index 49195185a..260eeb567 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -365,6 +365,48 @@ func TestSetupMaxConnectAttempts(t *testing.T) { } } +func TestSetupMaxIdleConns(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedVal int + expectedErr string + }{ + {"forward . 127.0.0.1\n", false, 0, ""}, + {"forward . 127.0.0.1 {\nmax_idle_conns 10\n}\n", false, 10, ""}, + {"forward . 127.0.0.1 {\nmax_idle_conns 0\n}\n", false, 0, ""}, + {"forward . 127.0.0.1 {\nmax_idle_conns many\n}\n", true, 0, "invalid"}, + {"forward . 
127.0.0.1 {\nmax_idle_conns -1\n}\n", true, 0, "negative"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found none for input %s", i, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if test.shouldErr { + continue + } + f := fs[0] + if f.maxIdleConns != test.expectedVal { + t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedVal, f.maxIdleConns) + } + } +} + func TestSetupHealthCheck(t *testing.T) { tests := []struct { input string diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go index 1bce4cfa2..4026cfbdd 100644 --- a/plugin/pkg/proxy/connect.go +++ b/plugin/pkg/proxy/connect.go @@ -1,7 +1,6 @@ -// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same -// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be -// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses -// inband healthchecking. +// Package proxy implements a forwarding proxy with connection caching. +// It manages a pool of upstream connections (UDP and TCP) to reuse them for subsequent requests, +// reducing latency and handshake overhead. It supports in-band health checking. package proxy import ( @@ -19,10 +18,7 @@ import ( ) const ( - ErrTransportStopped = "proxy: transport stopped" - ErrTransportStoppedDuringDial = "proxy: transport stopped during dial" - ErrTransportStoppedRetClosed = "proxy: transport stopped, ret channel closed" - ErrTransportStoppedDuringRetWait = "proxy: transport stopped during ret wait" + ErrTransportStopped = "proxy: transport stopped" ) // limitTimeout is a utility function to auto-tune timeout values @@ -66,41 +62,35 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { default: } - // Use select to avoid blocking if connManager has stopped - select { - case t.dial <- proto: - // Successfully sent dial request - case <-t.stop: - return nil, false, errors.New(ErrTransportStoppedDuringDial) + transtype := stringToTransportType(proto) + + t.mu.Lock() + // FIFO: take the oldest conn (front of slice) for source port diversity + for len(t.conns[transtype]) > 0 { + pc := t.conns[transtype][0] + t.conns[transtype] = t.conns[transtype][1:] + if time.Since(pc.used) > t.expire { + pc.c.Close() + continue + } + t.mu.Unlock() + connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) + return pc, true, nil } + t.mu.Unlock() - // Receive response with stop awareness - select { - case pc, ok := <-t.ret: - if !ok { - // ret channel was closed by connManager during stop - return nil, false, errors.New(ErrTransportStoppedRetClosed) - } + connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) - if pc != nil { - connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) - return pc, true, nil - } - connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) - - reqTime := time.Now() - timeout := t.dialTimeout() - if proto == "tcp-tls" { - conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) - 
t.updateDialTimeout(time.Since(reqTime)) - return &persistConn{c: conn}, false, err - } - conn, err := dns.DialTimeout(proto, t.addr, timeout) + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) t.updateDialTimeout(time.Since(reqTime)) return &persistConn{c: conn}, false, err - case <-t.stop: - return nil, false, errors.New(ErrTransportStoppedDuringRetWait) } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err } // Connect selects an upstream, sends the request and waits for a response. diff --git a/plugin/pkg/proxy/connect_test.go b/plugin/pkg/proxy/connect_test.go index d7a564530..1feb1584f 100644 --- a/plugin/pkg/proxy/connect_test.go +++ b/plugin/pkg/proxy/connect_test.go @@ -1,7 +1,6 @@ package proxy import ( - "sync" "testing" "time" ) @@ -30,209 +29,6 @@ func TestDial_TransportStopped_InitialCheck(t *testing.T) { } } -// TestDial_TransportStoppedDuringDialSend tests that Dial returns ErrTransportStoppedDuringDial -// if Stop() is called while Dial is attempting to send on the (blocked) t.dial channel. -// This is achieved by not starting the connManager, so t.dial remains unread. -func TestDial_TransportStoppedDuringDialSend(t *testing.T) { - tr := newTransport("test_during_dial_send", "127.0.0.1:0") - // No tr.Start() here. This ensures t.dial channel will block. - - dialErrChan := make(chan error, 1) - go func() { - // Dial will pass initial stop check (t.stop is open). - // Then it will block on `t.dial <- proto` because no connManager is reading. - _, _, err := tr.Dial("udp") - dialErrChan <- err - }() - - // Allow Dial goroutine to reach the blocking send on t.dial - time.Sleep(50 * time.Millisecond) - - tr.Stop() // Close t.stop. Dial's select should now pick <-t.stop. - - err := <-dialErrChan - if err == nil { - t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) - } - if err.Error() != ErrTransportStoppedDuringDial { - t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedDuringDial) - } -} - -// TestDial_TransportStoppedDuringRetWait tests that Dial returns ErrTransportStoppedDuringRetWait -// when the transport is stopped while Dial is waiting to receive from t.ret channel. -func TestDial_TransportStoppedDuringRetWait(t *testing.T) { - tr := newTransport("test_during_ret_wait", "127.0.0.1:0") - // Replace transport's channels to control interaction precisely - tr.dial = make(chan string) // Test-controlled, unbuffered - tr.ret = make(chan *persistConn) // Test-controlled, unbuffered - // tr.stop remains the original transport stop channel - - // NOTE: We purposefully do not call tr.Start() here, instead we - // manually simulate connManager behavior below - - dialErrChan := make(chan error, 1) - wg := sync.WaitGroup{} - wg.Add(1) - - go func() { - defer wg.Done() - // Dial will: - // 1. Pass initial stop check. - // 2. Send on our tr.dial. - // 3. Block on our tr.ret in its 3rd select. - // When tr.Stop() is called, this 3rd select should pick <-tr.stop. - _, _, err := tr.Dial("udp") - dialErrChan <- err - }() - - // Simulate connManager reading from our tr.dial. - // This unblocks the Dial goroutine's send. 
- var protoFromDial string - select { - case protoFromDial = <-tr.dial: - t.Logf("Simulated connManager read '%s' from Dial via test-controlled tr.dial", protoFromDial) - case <-time.After(500 * time.Millisecond): - t.Fatal("Timeout waiting for Dial to send on test-controlled tr.dial") - } - - // Stop the transport and the tr.stop channel - tr.Stop() - - wg.Wait() // Wait for Dial goroutine to complete. - err := <-dialErrChan - - if err == nil { - t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) - } - - // Expected error is ErrTransportStoppedDuringRetWait - // However, if connManager (using replaced channels) itself reacts to stop faster - // and somehow closes the test-controlled tr.ret (not its design), other errors are possible. - // But with tr.ret being ours and unwritten-to, Dial should pick tr.stop. - if err.Error() != ErrTransportStoppedDuringRetWait { - t.Errorf("%s: got '%v', want '%s' (or potentially '%s' if timing is very tight)", - testMsgWrongError, err, ErrTransportStoppedDuringRetWait, ErrTransportStopped) - } else { - t.Logf("Dial correctly returned '%s'", ErrTransportStoppedDuringRetWait) - } -} - -// TestDial_Returns_ErrTransportStoppedRetClosed tests that Dial -// returns ErrTransportStoppedRetClosed when tr.ret is closed before Dial reads from it. -func TestDial_Returns_ErrTransportStoppedRetClosed(t *testing.T) { - tr := newTransport("test_returns_ret_closed", "127.0.0.1:0") - - // Replace transport channels with test-controlled ones - testDialChan := make(chan string, 1) // Buffered to allow non-blocking send by Dial - testRetChan := make(chan *persistConn) // This will be closed by the test - tr.dial = testDialChan - tr.ret = testRetChan - // tr.stop remains the original, initially open channel. - - dialErrChan := make(chan error, 1) - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - // Dial will: - // 1. Pass initial stop check (tr.stop is open). - // 2. Send "udp" on tr.dial (which is testDialChan). - // 3. Block on <-tr.ret (which is testRetChan) in its 3rd select. - // When testRetChan is closed, it will read (nil, false), hitting the target error. - _, _, err := tr.Dial("udp") - dialErrChan <- err - }() - - // Step 1: Simulate connManager reading the dial request from Dial. - // Read from testDialChan. This unblocks the Dial goroutine's send to testDialChan. - select { - case proto := <-testDialChan: - if proto != "udp" { - wg.Done() - t.Fatalf("Dial sent wrong proto on testDialChan: got %s, want udp", proto) - } - t.Logf("Simulated connManager received '%s' from Dial via testDialChan.", proto) - case <-time.After(500 * time.Millisecond): - // If Dial didn't send, the test is flawed or Dial is stuck before sending. - wg.Done() - t.Fatal("Timeout waiting for Dial to send on testDialChan.") - } - - // Step 2: Simulate connManager stopping and closing its 'ret' channel. - close(testRetChan) - t.Logf("Closed testRetChan (simulating connManager closing tr.ret).") - - // Step 3: Wait for the Dial goroutine to complete. - wg.Wait() - err := <-dialErrChan - - if err == nil { - t.Fatalf("%s: %s", testMsgExpectedError, testMsgUnexpectedNilError) - } - - if err.Error() != ErrTransportStoppedRetClosed { - t.Errorf("%s: got '%v', want '%s'", testMsgWrongError, err, ErrTransportStoppedRetClosed) - } else { - t.Logf("Dial correctly returned '%s'", ErrTransportStoppedRetClosed) - } - - // Call tr.Stop() for completeness to close the original tr.stop channel. 
- // connManager was not started with original channels, so this mainly affects tr.stop. - tr.Stop() -} - -// TestDial_ConnManagerClosesRetOnStop verifies that connManager closes tr.ret upon stopping. -func TestDial_ConnManagerClosesRetOnStop(t *testing.T) { - tr := newTransport("test_connmanager_closes_ret", "127.0.0.1:0") - tr.Start() - - // Initiate a Dial to interact with connManager so tr.ret is used. - interactionDialErrChan := make(chan error, 1) - go func() { - _, _, err := tr.Dial("udp") - interactionDialErrChan <- err - }() - - // Allow the Dial goroutine to interact with connManager. - time.Sleep(100 * time.Millisecond) - - // Now stop the transport. connManager should clean up and close tr.ret. - tr.Stop() - - // Wait for connManager to fully stop and close its channels. - // This duration needs to be sufficient for the select loop in connManager to see <-t.stop, - // call t.cleanup(true), which in turn calls close(t.ret). - time.Sleep(50 * time.Millisecond) - - // Check if tr.ret is actually closed by trying a non-blocking read. - select { - case _, ok := <-tr.ret: - if !ok { - t.Logf("tr.ret channel is closed as expected after transport stop.") - } else { - t.Errorf("tr.ret channel was not closed after transport stop, or a value was read unexpectedly.") - } - default: - // This case means tr.ret is open but blocking (empty). - // This would be unexpected if connManager is supposed to close it on stop. - t.Errorf("tr.ret channel is not closed and is blocking (or empty but open).") - } - - // Drain the error channel from the initial interaction Dial to ensure the goroutine finishes. - select { - case err := <-interactionDialErrChan: - if err != nil { - t.Logf("Interaction Dial completed with error (possibly expected due to 127.0.0.1:0 or race with Stop): %v", err) - } else { - t.Logf("Interaction Dial completed without error.") - } - case <-time.After(500 * time.Millisecond): // Timeout for safety if Dial hangs - t.Logf("Timeout waiting for interaction Dial to complete.") - } -} - // TestDial_MultipleCallsAfterStop tests that multiple Dial calls after Stop // consistently return ErrTransportStopped. func TestDial_MultipleCallsAfterStop(t *testing.T) { diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go index 0bacc851a..da2dca122 100644 --- a/plugin/pkg/proxy/persistent.go +++ b/plugin/pkg/proxy/persistent.go @@ -3,6 +3,7 @@ package proxy import ( "crypto/tls" "sort" + "sync" "time" "github.com/miekg/dns" @@ -16,17 +17,16 @@ type persistConn struct { // Transport hold the persistent cache. type Transport struct { - avgDialTime int64 // kind of average time of dial time - conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls. - expire time.Duration // After this duration a connection is expired. - addr string - tlsConfig *tls.Config - proxyName string + avgDialTime int64 // kind of average time of dial time + conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls. + expire time.Duration // After this duration a connection is expired. + maxIdleConns int // Max idle connections per transport type; 0 means unlimited. 
+ addr string + tlsConfig *tls.Config + proxyName string - dial chan string - yield chan *persistConn - ret chan *persistConn - stop chan bool + mu sync.Mutex + stop chan struct{} } func newTransport(proxyName, addr string) *Transport { @@ -35,10 +35,7 @@ func newTransport(proxyName, addr string) *Transport { conns: [typeTotalCount][]*persistConn{}, expire: defaultExpire, addr: addr, - dial: make(chan string), - yield: make(chan *persistConn), - ret: make(chan *persistConn), - stop: make(chan bool), + stop: make(chan struct{}), proxyName: proxyName, } return t @@ -48,38 +45,12 @@ func newTransport(proxyName, addr string) *Transport { func (t *Transport) connManager() { ticker := time.NewTicker(defaultExpire) defer ticker.Stop() -Wait: for { select { - case proto := <-t.dial: - transtype := stringToTransportType(proto) - // take the last used conn - complexity O(1) - if stack := t.conns[transtype]; len(stack) > 0 { - pc := stack[len(stack)-1] - if time.Since(pc.used) < t.expire { - // Found one, remove from pool and return this conn. - t.conns[transtype] = stack[:len(stack)-1] - t.ret <- pc - continue Wait - } - // clear entire cache if the last conn is expired - t.conns[transtype] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - } - t.ret <- nil - - case pc := <-t.yield: - transtype := t.transportTypeFromConn(pc) - t.conns[transtype] = append(t.conns[transtype], pc) - case <-ticker.C: t.cleanup(false) - case <-t.stop: t.cleanup(true) - close(t.ret) return } } @@ -94,6 +65,9 @@ func closeConns(conns []*persistConn) { // cleanup removes connections from cache. func (t *Transport) cleanup(all bool) { + var toClose []*persistConn + + t.mu.Lock() staleTime := time.Now().Add(-t.expire) for transtype, stack := range t.conns { if len(stack) == 0 { @@ -101,9 +75,7 @@ func (t *Transport) cleanup(all bool) { } if all { t.conns[transtype] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) + toClose = append(toClose, stack...) continue } if stack[0].used.After(staleTime) { @@ -115,34 +87,38 @@ func (t *Transport) cleanup(all bool) { return stack[i].used.After(staleTime) }) t.conns[transtype] = stack[good:] - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack[:good]) + toClose = append(toClose, stack[:good]...) } -} + t.mu.Unlock() -// It is hard to pin a value to this, the import thing is to no block forever, losing at cached connection is not terrible. -const yieldTimeout = 25 * time.Millisecond + // Close connections after releasing lock + closeConns(toClose) +} // Yield returns the connection to transport for reuse. func (t *Transport) Yield(pc *persistConn) { - pc.used = time.Now() // update used time - - // Optimization: Try to return the connection immediately without creating a timer. - // If the receiver is not ready, we fall back to a timeout-based send to avoid blocking forever. - // Returning the connection is just an optimization, so dropping it on timeout is fine. 
+ // Check if transport is stopped before acquiring lock select { - case t.yield <- pc: + case <-t.stop: + // If stopped, don't return to pool, just close + pc.c.Close() return default: } - select { - case t.yield <- pc: - return - case <-time.After(yieldTimeout): + pc.used = time.Now() // update used time + + t.mu.Lock() + defer t.mu.Unlock() + + transtype := t.transportTypeFromConn(pc) + + if t.maxIdleConns > 0 && len(t.conns[transtype]) >= t.maxIdleConns { + pc.c.Close() return } + + t.conns[transtype] = append(t.conns[transtype], pc) } // Start starts the transport's connection manager. @@ -154,6 +130,10 @@ func (t *Transport) Stop() { close(t.stop) } // SetExpire sets the connection expire time in transport. func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } +// SetMaxIdleConns sets the maximum idle connections per transport type. +// A value of 0 means unlimited (default). +func (t *Transport) SetMaxIdleConns(n int) { t.maxIdleConns = n } + // SetTLSConfig sets the TLS config in transport. func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } diff --git a/plugin/pkg/proxy/persistent_test.go b/plugin/pkg/proxy/persistent_test.go index 89715ab1d..d2b96cddc 100644 --- a/plugin/pkg/proxy/persistent_test.go +++ b/plugin/pkg/proxy/persistent_test.go @@ -35,8 +35,9 @@ func TestCached(t *testing.T) { if !cached3 { t.Error("Expected cached connection (c3)") } - if c2 != c3 { - t.Error("Expected c2 == c3") + // FIFO: first yielded (c1) should be first out + if c1 != c3 { + t.Error("Expected c1 == c3 (FIFO order)") } tr.Yield(c3) @@ -109,6 +110,122 @@ func TestCleanupAll(t *testing.T) { } } +func TestMaxIdleConns(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestMaxIdleConns", s.Addr) + tr.SetMaxIdleConns(2) // Limit to 2 connections per type + tr.Start() + defer tr.Stop() + + // Dial 3 connections + c1, _, _ := tr.Dial("udp") + c2, _, _ := tr.Dial("udp") + c3, _, _ := tr.Dial("udp") + + // Yield all 3 + tr.Yield(c1) + tr.Yield(c2) + tr.Yield(c3) // This should be discarded (pool full) + + // Check pool size is capped at 2 + tr.mu.Lock() + poolSize := len(tr.conns[typeUDP]) + tr.mu.Unlock() + + if poolSize != 2 { + t.Errorf("Expected pool size 2, got %d", poolSize) + } + + // Verify we get the first 2 back (FIFO) + d1, cached1, _ := tr.Dial("udp") + d2, cached2, _ := tr.Dial("udp") + _, cached3, _ := tr.Dial("udp") + + if !cached1 || !cached2 { + t.Error("Expected first 2 dials to be cached") + } + if cached3 { + t.Error("Expected 3rd dial to be non-cached (pool was limited to 2)") + } + if d1 != c1 || d2 != c2 { + t.Error("Expected FIFO order: d1==c1, d2==c2") + } +} + +func TestMaxIdleConnsUnlimited(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestMaxIdleConnsUnlimited", s.Addr) + // maxIdleConns defaults to 0 (unlimited) + tr.Start() + defer tr.Stop() + + // Dial and yield 5 connections + conns := make([]*persistConn, 5) + for i := range conns { + conns[i], _, _ = tr.Dial("udp") + } + for _, c := range conns { + tr.Yield(c) + } + + // Check all 5 are in pool + tr.mu.Lock() + poolSize := len(tr.conns[typeUDP]) + tr.mu.Unlock() + + if poolSize != 5 { + t.Errorf("Expected pool size 5 (unlimited), got %d", poolSize) + } +} + +func TestYieldAfterStop(t *testing.T) { + s := 
dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestYieldAfterStop", s.Addr) + tr.Start() + + // Dial a connection while transport is running + c1, _, err := tr.Dial("udp") + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + + // Stop the transport + tr.Stop() + + // Give cleanup goroutine time to exit + time.Sleep(50 * time.Millisecond) + + // Yield the connection after stop - should close it, not pool it + tr.Yield(c1) + + // Verify pool is empty (connection was closed, not added) + tr.mu.Lock() + poolSize := len(tr.conns[typeUDP]) + tr.mu.Unlock() + + if poolSize != 0 { + t.Errorf("Expected pool size 0 after stop, got %d", poolSize) + } +} + func BenchmarkYield(b *testing.B) { s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) @@ -127,12 +244,12 @@ func BenchmarkYield(b *testing.B) { for b.Loop() { tr.Yield(c) - // Drain the yield channel so we can yield again without blocking/timing out - // We need to simulate the consumer side slightly to keep Yield flowing - select { - case <-tr.yield: - default: + // Simulate FIFO consumption: remove from front + tr.mu.Lock() + if len(tr.conns[typeUDP]) > 0 { + tr.conns[typeUDP] = tr.conns[typeUDP][1:] } + tr.mu.Unlock() runtime.Gosched() } } diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go index 35e94bf83..6c460c397 100644 --- a/plugin/pkg/proxy/proxy.go +++ b/plugin/pkg/proxy/proxy.go @@ -52,6 +52,10 @@ func (p *Proxy) SetTLSConfig(cfg *tls.Config) { // SetExpire sets the expire duration in the lower p.transport. func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } +// SetMaxIdleConns sets the maximum idle connections per transport type. +// A value of 0 means unlimited (default). 
+func (p *Proxy) SetMaxIdleConns(n int) { p.transport.SetMaxIdleConns(n) } + func (p *Proxy) GetHealthchecker() HealthChecker { return p.health } diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go index 03d10ce5f..afb66204b 100644 --- a/plugin/pkg/proxy/proxy_test.go +++ b/plugin/pkg/proxy/proxy_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "math" + "net" "testing" "time" @@ -73,30 +74,66 @@ func TestProxyTLSFail(t *testing.T) { } func TestProtocolSelection(t *testing.T) { - p := NewProxy("TestProtocolSelection", "bad_address", transport.DNS) - p.readTimeout = 10 * time.Millisecond + testCases := []struct { + name string + requestTCP bool // true = TCP request, false = UDP request + opts Options + expectedProto string + }{ + {"UDP request, no options", false, Options{}, "udp"}, + {"UDP request, ForceTCP", false, Options{ForceTCP: true}, "tcp"}, + {"UDP request, PreferUDP", false, Options{PreferUDP: true}, "udp"}, + {"UDP request, ForceTCP+PreferUDP", false, Options{ForceTCP: true, PreferUDP: true}, "tcp"}, + {"TCP request, no options", true, Options{}, "tcp"}, + {"TCP request, ForceTCP", true, Options{ForceTCP: true}, "tcp"}, + {"TCP request, PreferUDP", true, Options{PreferUDP: true}, "udp"}, + {"TCP request, ForceTCP+PreferUDP", true, Options{ForceTCP: true, PreferUDP: true}, "tcp"}, + } - stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} - stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} - ctx := context.TODO() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Track which protocol the server received (use channel to avoid data race) + protoChan := make(chan string, 1) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // Determine protocol from the connection type + if _, ok := w.RemoteAddr().(*net.TCPAddr); ok { + protoChan <- "tcp" + } else { + protoChan <- "udp" + } + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. 
IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() - go func() { - p.Connect(ctx, stateUDP, Options{}) - p.Connect(ctx, stateUDP, Options{ForceTCP: true}) - p.Connect(ctx, stateUDP, Options{PreferUDP: true}) - p.Connect(ctx, stateUDP, Options{PreferUDP: true, ForceTCP: true}) - p.Connect(ctx, stateTCP, Options{}) - p.Connect(ctx, stateTCP, Options{ForceTCP: true}) - p.Connect(ctx, stateTCP, Options{PreferUDP: true}) - p.Connect(ctx, stateTCP, Options{PreferUDP: true, ForceTCP: true}) - }() + p := NewProxy("TestProtocolSelection", s.Addr, transport.DNS) + p.readTimeout = 1 * time.Second + p.Start(5 * time.Second) + defer p.Stop() - for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { - proto := <-p.transport.dial - p.transport.ret <- nil - if proto != exp { - t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) - } + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + + req := request.Request{ + W: &test.ResponseWriter{TCP: tc.requestTCP}, + Req: m, + } + + resp, err := p.Connect(context.Background(), req, tc.opts) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + if resp == nil { + t.Fatal("Expected response, got nil") + } + + receivedProto := <-protoChan + if receivedProto != tc.expectedProto { + t.Errorf("Expected protocol %q, but server received %q", tc.expectedProto, receivedProto) + } + }) } } diff --git a/test/proxy_test.go b/test/proxy_test.go index 7b93daa3d..84b03aee1 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -1,8 +1,14 @@ package test import ( + "context" + "fmt" + "net" "testing" + "github.com/coredns/coredns/plugin/forward" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" @@ -76,3 +82,72 @@ func BenchmarkProxyLookup(b *testing.B) { } } } + +// BenchmarkProxyWithMultipleBackends verifies the serialization issue by running concurrent load +// against 1, 2, and 3 backend proxies using the forward plugin. 
+func BenchmarkProxyWithMultipleBackends(b *testing.B) { + // Start a dummy upstream server + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + counts := []int{1, 2, 3} + + for _, n := range counts { + b.Run(fmt.Sprintf("%d-Backends", n), func(b *testing.B) { + f := forward.New() + f.SetProxyOptions(proxy.Options{PreferUDP: true}) + + proxies := make([]*proxy.Proxy, n) + for i := range n { + p := proxy.NewProxy(fmt.Sprintf("proxy-%d", i), s.Addr, "dns") + f.SetProxy(p) + proxies[i] = p + } + defer func() { + for _, p := range proxies { + p.Stop() + } + }() + + // Pre-warm connections + ctx := context.Background() + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + noop := &benchmarkResponseWriter{} + + for range n * 10 { + f.ServeDNS(ctx, noop, m) + } + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + ctx := context.Background() + w := &benchmarkResponseWriter{} + + for pb.Next() { + // forward plugin handles selection via its policy (default random) + f.ServeDNS(ctx, w, m) + } + }) + }) + } +} + +type benchmarkResponseWriter struct{} + +func (b *benchmarkResponseWriter) LocalAddr() net.Addr { return nil } +func (b *benchmarkResponseWriter) RemoteAddr() net.Addr { return nil } +func (b *benchmarkResponseWriter) WriteMsg(m *dns.Msg) error { return nil } +func (b *benchmarkResponseWriter) Write(p []byte) (int, error) { return len(p), nil } +func (b *benchmarkResponseWriter) Close() error { return nil } +func (b *benchmarkResponseWriter) TsigStatus() error { return nil } +func (b *benchmarkResponseWriter) TsigTimersOnly(bool) {} +func (b *benchmarkResponseWriter) Hijack() {}
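
Outside the patch itself, a minimal sketch of how a plugin embedding
plugin/pkg/proxy would apply the new setting (illustrative address,
name, and values; this mirrors what the forward setup code above does
for each proxy):

```go
package main

import (
	"time"

	"github.com/coredns/coredns/plugin/pkg/proxy"
	"github.com/coredns/coredns/plugin/pkg/transport"
)

func main() {
	// Hypothetical upstream; a real plugin takes this from its config.
	p := proxy.NewProxy("example", "10.0.0.53:53", transport.DNS)

	// Expire cached connections after 10s and keep at most 16 idle
	// connections per transport type (0, the default, is unlimited).
	p.SetExpire(10 * time.Second)
	p.SetMaxIdleConns(16)

	// Start begins health checking (at the given interval) and the
	// transport's background cleanup loop; Stop shuts them down.
	p.Start(500 * time.Millisecond)
	defer p.Stop()
}
```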