From 03d0863a45ba5b9da6c850aabc07f4506f3a1d14 Mon Sep 17 00:00:00 2001 From: Cedric Wang Date: Sat, 4 Apr 2026 11:32:10 -0700 Subject: [PATCH] fix(doh): use per-connection local address for PROXY protocol (#8005) --- core/dnsserver/server_https.go | 16 ++++- core/dnsserver/server_https_test.go | 104 ++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index b74bca87e..0df47c32b 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -51,6 +51,9 @@ func (l *loggerAdapter) Write(p []byte) (n int, err error) { // Plugins can access the original HTTP request to retrieve headers, client IP, and metadata. type HTTPRequestKey struct{} +// connAddrKey is the context key for the per-connection local address set by ConnContext. +type connAddrKey struct{} + // NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { s, err := NewServer(addr, group) @@ -89,6 +92,9 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { WriteTimeout: s.WriteTimeout, IdleTimeout: s.IdleTimeout, ErrorLog: stdlog.New(&loggerAdapter{}, "", 0), + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, connAddrKey{}, c.LocalAddr()) + }, } maxConnections := DefaultHTTPSMaxConnections if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil { @@ -169,6 +175,14 @@ func (s *ServerHTTPS) Stop() error { return nil } +// localAddr returns the per-connection local address from context, or s.listenAddr as fallback. +func (s *ServerHTTPS) localAddr(r *http.Request) net.Addr { + if addr, ok := r.Context().Value(connAddrKey{}).(net.Addr); ok { + return addr + } + return s.listenAddr +} + // ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the plugin // chain, converts it back and write it to the client. func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -189,7 +203,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { h, p, _ := net.SplitHostPort(r.RemoteAddr) port, _ := strconv.Atoi(p) dw := &DoHWriter{ - laddr: s.listenAddr, + laddr: s.localAddr(r), raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port}, request: r, } diff --git a/core/dnsserver/server_https_test.go b/core/dnsserver/server_https_test.go index d5bc49cad..21dbaa84b 100644 --- a/core/dnsserver/server_https_test.go +++ b/core/dnsserver/server_https_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "io" + "net" "net/http" "net/http/httptest" "regexp" @@ -167,6 +168,109 @@ func testConfigWithPlugin(p *contextCapturingPlugin) *Config { return c } +func TestDoHWriterLaddrFromConnContext(t *testing.T) { + capturer := &addrCapturingPlugin{} + cfg := testConfigWithHandler(capturer) + + s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg}) + if err != nil { + t.Fatal("could not create HTTPS server:", err) + } + s.listenAddr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443} + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + buf, err := m.Pack() + if err != nil { + t.Fatal(err) + } + + // Simulate a PROXY protocol destination that differs from listenAddr. + ppDst := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443} + + r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf))) + ctx := context.WithValue(r.Context(), connAddrKey{}, ppDst) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + if !capturer.called { + t.Fatal("plugin was not called") + } + if capturer.localAddr == nil { + t.Fatal("DoHWriter.laddr is nil") + } + if capturer.localAddr.String() != ppDst.String() { + t.Errorf("expected laddr %s (PP destination), got %s", ppDst, capturer.localAddr) + } +} + +func TestDoHWriterLaddrFallback(t *testing.T) { + capturer := &addrCapturingPlugin{} + cfg := testConfigWithHandler(capturer) + + s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg}) + if err != nil { + t.Fatal("could not create HTTPS server:", err) + } + s.listenAddr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443} + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + buf, err := m.Pack() + if err != nil { + t.Fatal(err) + } + + // No connAddrKey in context; should fall back to s.listenAddr. + r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf))) + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + if !capturer.called { + t.Fatal("plugin was not called") + } + if capturer.localAddr == nil { + t.Fatal("DoHWriter.laddr is nil") + } + if capturer.localAddr.String() != s.listenAddr.String() { + t.Errorf("expected fallback laddr %s, got %s", s.listenAddr, capturer.localAddr) + } +} + +type addrCapturingPlugin struct { + called bool + localAddr net.Addr + remoteAddr net.Addr +} + +func (p *addrCapturingPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.called = true + p.localAddr = w.LocalAddr() + p.remoteAddr = w.RemoteAddr() + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +func (p *addrCapturingPlugin) Name() string { return "addr_capturing" } + +func testConfigWithHandler(h plugin.Handler) *Config { + c := &Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + } + c.AddPlugin(func(_next plugin.Handler) plugin.Handler { return h }) + return c +} + func TestHTTPRequestContextPropagation(t *testing.T) { plugin := &contextCapturingPlugin{}