proxyproto: add UDP session tracking for Spectrum PPv2 (#7967)

This commit is contained in:
Minghang Chen
2026-03-28 15:06:36 -07:00
committed by GitHub
parent 12d9457e71
commit 34acf8353f
10 changed files with 398 additions and 19 deletions

View File

@@ -10,6 +10,7 @@ import (
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/pires/go-proxyproto"
)
@@ -18,11 +19,49 @@ var (
_ net.Addr = (*Addr)(nil)
)
// errHeaderOnly is a sentinel used internally to signal that the datagram
// contained only a PROXY Protocol header with no DNS payload. It is never
// returned to callers of ReadFrom.
var errHeaderOnly = errors.New("header-only datagram; no payload")
// PacketConn wraps a net.PacketConn and strips PROXY Protocol v2 headers from
// incoming UDP datagrams.
//
// When UDPSessionTrackingTTL is greater than zero the connection implements
// Cloudflare Spectrum's PPv2-over-UDP behavior: the PROXY header arrives in
// the very first datagram of a session (which may carries an empty payload)
// while all subsequent datagrams carry real DNS payload without any header.
// The real source address parsed from the first datagram is cached keyed by
// the Spectrum-side remote address and applied to every headerless datagram
// that arrives from the same remote address within UDPSessionTrackingTTL.
//
// The session cache is a fixed-capacity LRU (capped at udpSessionMaxEntries)
// so that memory usage is bounded regardless of the number of distinct remote
// addresses seen.
type PacketConn struct {
net.PacketConn
ConnPolicy proxyproto.ConnPolicyFunc
ValidateHeader proxyproto.Validator
ReadHeaderTimeout time.Duration
// UDPSessionTrackingTTL enables per-remote-address session state for UDP
// when set to a positive duration. A header-only datagram (valid PPv2
// header with or without payload) causes the parsed source address to be
// cached for this duration. Subsequent datagrams from the same remote
// address that carry no PPv2 header are assigned the cached source
// address. The TTL is refreshed on every matching packet. A zero or
// negative value disables session tracking entirely.
UDPSessionTrackingTTL time.Duration
// UDPSessionTrackingMaxSessions is the maximum number of concurrent UDP
// sessions held in the LRU cache. Zero or negative means use the default
// (udpSessionMaxEntries). Has no effect unless UDPSessionTrackingTTL is
// positive.
UDPSessionTrackingMaxSessions int
// sessionCache is a thread-safe expirable LRU; lazily initialized on
// first use when UDPSessionTrackingTTL > 0.
sessionCache *expirable.LRU[string, *proxyproto.Header]
}
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
@@ -33,6 +72,12 @@ func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}
n, addr, err = c.readFrom(p[:n], addr)
if err != nil {
if errors.Is(err, errHeaderOnly) {
// Header-only datagram with no DNS payload (Spectrum PPv2 UDP
// session establishment). Silently discard and wait for the
// next datagram.
continue
}
// drop invalid packet as returning error would cause the ReadFrom caller to exit
// which could result in DoS if an attacker sends intentional invalid packets
clog.Warningf("dropping invalid Proxy Protocol packet from %s: %v", addr.String(), err)
@@ -84,8 +129,26 @@ func (c *PacketConn) readFrom(p []byte, addr net.Addr) (_ int, _ net.Addr, err e
fallthrough
case proxyproto.USE:
if header != nil {
srcAddr, _, _ := header.UDPAddrs()
addr = &Addr{u: addr, r: srcAddr}
addr = &Addr{u: addr, r: header.SourceAddr}
if c.UDPSessionTrackingTTL > 0 {
// Cache the real source address for subsequent headerless datagrams.
// Spectrum sends the header in a standalone datagram with no DNS
// payload; refresh or insert the entry either way so that the TTL
// resets on every header packet.
c.storeSession(addr.(*Addr).u, header)
if len(payload) == 0 {
// Header-only datagram: no DNS payload to return; loop back
// to read the next datagram.
return 0, nil, errHeaderOnly
}
}
} else if c.UDPSessionTrackingTTL > 0 {
// No header present look for a cached header for this remote.
if cachedHeader, ok := c.lookupSession(addr); ok {
addr = &Addr{u: addr, r: cachedHeader.SourceAddr}
}
}
default:
}

View File

@@ -0,0 +1,63 @@
package proxyproto
import (
"net"
"sync"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/pires/go-proxyproto"
)
// udpSessionMaxEntries is the default maximum number of concurrent UDP
// sessions that the LRU cache will track. When the cache is full the
// least-recently-used entry is evicted.
const udpSessionMaxEntries = 10_240
// sessionInitMu serializes lazy initialization of PacketConn.sessionCache.
var sessionInitMu sync.Mutex
// ensureSessionCache lazily creates the expirable LRU if it hasn't been
// created yet. The expirable.LRU itself is thread-safe once constructed.
func (c *PacketConn) ensureSessionCache() {
if c.sessionCache != nil {
return
}
sessionInitMu.Lock()
defer sessionInitMu.Unlock()
if c.sessionCache != nil {
return // double-check after acquiring lock
}
cap := c.UDPSessionTrackingMaxSessions
if cap <= 0 {
cap = udpSessionMaxEntries
}
c.sessionCache = expirable.NewLRU[string, *proxyproto.Header](cap, nil, c.UDPSessionTrackingTTL)
}
// storeSession inserts or refreshes the session entry for remoteAddr.
// Calling Add on an existing key resets its TTL.
func (c *PacketConn) storeSession(remoteAddr net.Addr, header *proxyproto.Header) {
c.ensureSessionCache()
c.sessionCache.Add(sessionKey(remoteAddr), header)
}
// lookupSession returns the cached source address for remoteAddr, if one
// exists and has not expired. Looking up a key refreshes its TTL by
// re-adding it.
func (c *PacketConn) lookupSession(remoteAddr net.Addr) (*proxyproto.Header, bool) {
if c.sessionCache == nil {
return nil, false
}
key := sessionKey(remoteAddr)
header, ok := c.sessionCache.Get(key)
if !ok {
return nil, false
}
// Refresh TTL by re-adding.
c.sessionCache.Add(key, header)
return header, true
}
func sessionKey(addr net.Addr) string {
return addr.Network() + "://" + addr.String()
}

View File

@@ -0,0 +1,134 @@
package proxyproto
import (
"net"
"testing"
"time"
proxyproto "github.com/pires/go-proxyproto"
)
func udpAddr(host string, port int) *net.UDPAddr {
return &net.UDPAddr{IP: net.ParseIP(host), Port: port}
}
// testHeader builds a minimal PPv2 header with the given source address.
func testHeader(src *net.UDPAddr) *proxyproto.Header {
return &proxyproto.Header{
Version: 2,
SourceAddr: src,
}
}
func TestSessionKey(t *testing.T) {
addr := &net.UDPAddr{IP: net.ParseIP("10.0.0.1"), Port: 5000}
got := sessionKey(addr)
want := "udp://10.0.0.1:5000"
if got != want {
t.Fatalf("sessionKey = %q, want %q", got, want)
}
}
func newTestPacketConn(ttl time.Duration, maxSessions int) *PacketConn {
return &PacketConn{
UDPSessionTrackingTTL: ttl,
UDPSessionTrackingMaxSessions: maxSessions,
}
}
func TestStoreAndLookupSession(t *testing.T) {
pc := newTestPacketConn(time.Second, 0)
remote := udpAddr("10.0.0.1", 5000)
src := udpAddr("192.168.1.1", 12345)
pc.storeSession(remote, testHeader(src))
got, ok := pc.lookupSession(remote)
if !ok {
t.Fatal("expected session to be found")
}
if got.SourceAddr.String() != src.String() {
t.Fatalf("expected SourceAddr %s, got %s", src, got.SourceAddr)
}
}
func TestLookupSessionMiss(t *testing.T) {
pc := newTestPacketConn(time.Second, 0)
_, ok := pc.lookupSession(udpAddr("10.0.0.1", 5000))
if ok {
t.Fatal("expected miss on empty cache")
}
}
func TestLookupSessionExpired(t *testing.T) {
pc := newTestPacketConn(50*time.Millisecond, 0)
remote := udpAddr("10.0.0.1", 5000)
src := udpAddr("192.168.1.1", 12345)
pc.storeSession(remote, testHeader(src))
time.Sleep(100 * time.Millisecond)
_, ok := pc.lookupSession(remote)
if ok {
t.Fatal("expected expired entry to be missing")
}
}
func TestLookupSessionRefreshesTTL(t *testing.T) {
ttl := 50 * time.Millisecond
pc := newTestPacketConn(ttl, 0)
remote := udpAddr("10.0.0.1", 5000)
src := udpAddr("192.168.1.1", 12345)
pc.storeSession(remote, testHeader(src))
// Wait past half the TTL, then look up (which should refresh).
time.Sleep(30 * time.Millisecond)
_, ok := pc.lookupSession(remote)
if !ok {
t.Fatal("expected session to be found before TTL")
}
// Wait another 30ms. Original TTL would have expired (60ms > 50ms),
// but the refresh from lookupSession should keep it alive.
time.Sleep(30 * time.Millisecond)
_, ok = pc.lookupSession(remote)
if !ok {
t.Fatal("expected session to survive after TTL refresh")
}
}
func TestStoreSessionCustomMaxSessions(t *testing.T) {
pc := newTestPacketConn(time.Second, 5)
// Fill beyond custom cap.
for i := range 10 {
pc.storeSession(udpAddr("10.0.0.1", i), testHeader(udpAddr("1.1.1.1", i)))
}
if pc.sessionCache.Len() != 5 {
t.Fatalf("expected cache capped at 5, got %d", pc.sessionCache.Len())
}
}
func TestStoreSessionEvictsOldest(t *testing.T) {
pc := newTestPacketConn(time.Minute, 2)
r1 := udpAddr("10.0.0.1", 1)
r2 := udpAddr("10.0.0.2", 2)
r3 := udpAddr("10.0.0.3", 3)
pc.storeSession(r1, testHeader(udpAddr("1.1.1.1", 1)))
pc.storeSession(r2, testHeader(udpAddr("2.2.2.2", 2)))
// Cache is full (cap=2). Storing r3 evicts r1.
pc.storeSession(r3, testHeader(udpAddr("3.3.3.3", 3)))
if _, ok := pc.lookupSession(r1); ok {
t.Fatal("expected r1 to be evicted")
}
if _, ok := pc.lookupSession(r2); !ok {
t.Fatal("expected r2 to be present")
}
if _, ok := pc.lookupSession(r3); !ok {
t.Fatal("expected r3 to be present")
}
}