From 31e16025ef6688badfaed2d68d39a1e76148037c Mon Sep 17 00:00:00 2001 From: rpb-ant Date: Tue, 24 Mar 2026 11:47:11 -0400 Subject: [PATCH] plugin/cache: prefetch without holding a client connection (#7944) --- plugin/cache/cache.go | 80 +++++++++++++++++++++++++++++------------ plugin/cache/handler.go | 18 +++++----- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index 05223dec7..04eb3283a 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -142,30 +142,26 @@ type ResponseWriter struct { nexcept []string // negative zone exceptions } -// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in -// prefetch requests. It ensures RemoteAddr() can be called even after the -// original connection has already been closed. -func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter { - // Resolve the address now, the connection might be already closed when the - // actual prefetch request is made. - addr := state.W.RemoteAddr() - // The protocol of the client triggering a cache prefetch doesn't matter. - // The address type is used by request.Proto to determine the response size, - // and using TCP ensures the message isn't unnecessarily truncated. - if u, ok := addr.(*net.UDPAddr); ok { - addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone} - } +// prefetchAddr is the synthetic remote address for prefetch requests. There is +// no client connection, and per request.Proto the address type is what selects +// the response-size budget; TCP ensures upstream replies aren't truncated. +var prefetchAddr = &net.TCPAddr{} - return &ResponseWriter{ - ResponseWriter: state.W, - Cache: c, - state: state, - server: server, - do: state.Do(), - cd: state.Req.CheckingDisabled, - prefetch: true, - remoteAddr: addr, +// newPrefetchResponseWriter returns a ResponseWriter for prefetch requests. +// Prefetch has no client connection: the inner ResponseWriter is nil, WriteMsg +// short-circuits after caching when w.prefetch is true, and the nil-safe +// overrides below make the remaining dns.ResponseWriter methods well-defined. +func newPrefetchResponseWriter(server string, req *dns.Msg, do, cd bool, c *Cache) *ResponseWriter { + cw := &ResponseWriter{ + Cache: c, + server: server, + do: do, + cd: cd, + prefetch: true, + remoteAddr: prefetchAddr, } + cw.state = request.Request{Req: req} + return cw } // RemoteAddr implements the dns.ResponseWriter interface. @@ -176,6 +172,46 @@ func (w *ResponseWriter) RemoteAddr() net.Addr { return w.ResponseWriter.RemoteAddr() } +// The following overrides make a nil inner ResponseWriter well-defined. +// Prefetch constructs a ResponseWriter with no client connection; WriteMsg +// and Write already short-circuit on w.prefetch before delegating, and +// RemoteAddr uses w.remoteAddr. These cover the rest of the interface. + +func (w *ResponseWriter) LocalAddr() net.Addr { + if w.ResponseWriter == nil { + return prefetchAddr + } + return w.ResponseWriter.LocalAddr() +} + +func (w *ResponseWriter) Close() error { + if w.ResponseWriter == nil { + return nil + } + return w.ResponseWriter.Close() +} + +func (w *ResponseWriter) TsigStatus() error { + if w.ResponseWriter == nil { + return nil + } + return w.ResponseWriter.TsigStatus() +} + +func (w *ResponseWriter) TsigTimersOnly(b bool) { + if w.ResponseWriter == nil { + return + } + w.ResponseWriter.TsigTimersOnly(b) +} + +func (w *ResponseWriter) Hijack() { + if w.ResponseWriter == nil { + return + } + w.ResponseWriter.Hijack() +} + // WriteMsg implements the dns.ResponseWriter interface. func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { res = res.Copy() diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go index 4e8edc3d2..4e8b1450e 100644 --- a/plugin/cache/handler.go +++ b/plugin/cache/handler.go @@ -55,13 +55,13 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) // Adjust the time to get a 0 TTL in the reply built from a stale item. now = now.Add(time.Duration(ttl) * time.Second) if !c.verifyStale { - cw := newPrefetchResponseWriter(server, state, c) - go c.doPrefetch(ctx, state, cw, i, now) + cw := newPrefetchResponseWriter(server, rc, do, cd, c) + go c.doPrefetch(ctx, cw, i, now) } servedStale.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() } else if c.shouldPrefetch(i, now) { - cw := newPrefetchResponseWriter(server, state, c) - go c.doPrefetch(ctx, state, cw, i, now) + cw := newPrefetchResponseWriter(server, rc, do, cd, c) + go c.doPrefetch(ctx, cw, i, now) } if i.wildcard != "" { @@ -91,16 +91,16 @@ func wildcardFunc(ctx context.Context) func() string { } } -func (c *Cache) doPrefetch(ctx context.Context, state request.Request, cw *ResponseWriter, i *item, now time.Time) { +func (c *Cache) doPrefetch(ctx context.Context, cw *ResponseWriter, i *item, now time.Time) { // Use a fresh metadata map to avoid concurrent writes to the original request's metadata. ctx = metadata.ContextWithMetadata(ctx) cachePrefetches.WithLabelValues(cw.server, c.zonesMetricLabel, c.viewMetricLabel).Inc() - c.doRefresh(ctx, state, cw) + c.doRefresh(ctx, cw.state, cw) // When prefetching we loose the item i, and with it the frequency // that we've gathered sofar. See we copy the frequencies info back // into the new item that was stored in the cache. - if i1 := c.exists(state); i1 != nil { + if i1 := c.exists(cw.state.Name(), cw.state.QType(), cw.do, cw.cd); i1 != nil { i1.Reset(now, i.Hits()) } } @@ -145,8 +145,8 @@ func (c *Cache) getIfNotStale(now time.Time, state request.Request, server strin } // exists unconditionally returns an item if it exists in the cache. -func (c *Cache) exists(state request.Request) *item { - k := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled) +func (c *Cache) exists(name string, qtype uint16, do, cd bool) *item { + k := hash(name, qtype, do, cd) if i, ok := c.ncache.Get(k); ok { return i }