mirror of
https://github.com/coredns/coredns.git
synced 2026-04-05 11:45:33 -04:00
fix(doh): use per-connection local address for PROXY protocol (#8005)
This commit is contained in:
@@ -51,6 +51,9 @@ func (l *loggerAdapter) Write(p []byte) (n int, err error) {
|
|||||||
// Plugins can access the original HTTP request to retrieve headers, client IP, and metadata.
|
// Plugins can access the original HTTP request to retrieve headers, client IP, and metadata.
|
||||||
type HTTPRequestKey struct{}
|
type HTTPRequestKey struct{}
|
||||||
|
|
||||||
|
// connAddrKey is the context key for the per-connection local address set by ConnContext.
|
||||||
|
type connAddrKey struct{}
|
||||||
|
|
||||||
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
||||||
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
||||||
s, err := NewServer(addr, group)
|
s, err := NewServer(addr, group)
|
||||||
@@ -89,6 +92,9 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
|||||||
WriteTimeout: s.WriteTimeout,
|
WriteTimeout: s.WriteTimeout,
|
||||||
IdleTimeout: s.IdleTimeout,
|
IdleTimeout: s.IdleTimeout,
|
||||||
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
|
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
|
||||||
|
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
|
||||||
|
return context.WithValue(ctx, connAddrKey{}, c.LocalAddr())
|
||||||
|
},
|
||||||
}
|
}
|
||||||
maxConnections := DefaultHTTPSMaxConnections
|
maxConnections := DefaultHTTPSMaxConnections
|
||||||
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil {
|
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil {
|
||||||
@@ -169,6 +175,14 @@ func (s *ServerHTTPS) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// localAddr returns the per-connection local address from context, or s.listenAddr as fallback.
|
||||||
|
func (s *ServerHTTPS) localAddr(r *http.Request) net.Addr {
|
||||||
|
if addr, ok := r.Context().Value(connAddrKey{}).(net.Addr); ok {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
return s.listenAddr
|
||||||
|
}
|
||||||
|
|
||||||
// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the plugin
|
// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the plugin
|
||||||
// chain, converts it back and write it to the client.
|
// chain, converts it back and write it to the client.
|
||||||
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -189,7 +203,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
h, p, _ := net.SplitHostPort(r.RemoteAddr)
|
h, p, _ := net.SplitHostPort(r.RemoteAddr)
|
||||||
port, _ := strconv.Atoi(p)
|
port, _ := strconv.Atoi(p)
|
||||||
dw := &DoHWriter{
|
dw := &DoHWriter{
|
||||||
laddr: s.listenAddr,
|
laddr: s.localAddr(r),
|
||||||
raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port},
|
raddr: &net.TCPAddr{IP: net.ParseIP(h), Port: port},
|
||||||
request: r,
|
request: r,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -167,6 +168,109 @@ func testConfigWithPlugin(p *contextCapturingPlugin) *Config {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDoHWriterLaddrFromConnContext(t *testing.T) {
|
||||||
|
capturer := &addrCapturingPlugin{}
|
||||||
|
cfg := testConfigWithHandler(capturer)
|
||||||
|
|
||||||
|
s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create HTTPS server:", err)
|
||||||
|
}
|
||||||
|
s.listenAddr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
buf, err := m.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate a PROXY protocol destination that differs from listenAddr.
|
||||||
|
ppDst := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||||
|
ctx := context.WithValue(r.Context(), connAddrKey{}, ppDst)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if !capturer.called {
|
||||||
|
t.Fatal("plugin was not called")
|
||||||
|
}
|
||||||
|
if capturer.localAddr == nil {
|
||||||
|
t.Fatal("DoHWriter.laddr is nil")
|
||||||
|
}
|
||||||
|
if capturer.localAddr.String() != ppDst.String() {
|
||||||
|
t.Errorf("expected laddr %s (PP destination), got %s", ppDst, capturer.localAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDoHWriterLaddrFallback(t *testing.T) {
|
||||||
|
capturer := &addrCapturingPlugin{}
|
||||||
|
cfg := testConfigWithHandler(capturer)
|
||||||
|
|
||||||
|
s, err := NewServerHTTPS("127.0.0.1:443", []*Config{cfg})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create HTTPS server:", err)
|
||||||
|
}
|
||||||
|
s.listenAddr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
buf, err := m.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No connAddrKey in context; should fall back to s.listenAddr.
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if !capturer.called {
|
||||||
|
t.Fatal("plugin was not called")
|
||||||
|
}
|
||||||
|
if capturer.localAddr == nil {
|
||||||
|
t.Fatal("DoHWriter.laddr is nil")
|
||||||
|
}
|
||||||
|
if capturer.localAddr.String() != s.listenAddr.String() {
|
||||||
|
t.Errorf("expected fallback laddr %s, got %s", s.listenAddr, capturer.localAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type addrCapturingPlugin struct {
|
||||||
|
called bool
|
||||||
|
localAddr net.Addr
|
||||||
|
remoteAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *addrCapturingPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
p.called = true
|
||||||
|
p.localAddr = w.LocalAddr()
|
||||||
|
p.remoteAddr = w.RemoteAddr()
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
m.Authoritative = true
|
||||||
|
w.WriteMsg(m)
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *addrCapturingPlugin) Name() string { return "addr_capturing" }
|
||||||
|
|
||||||
|
func testConfigWithHandler(h plugin.Handler) *Config {
|
||||||
|
c := &Config{
|
||||||
|
Zone: "example.com.",
|
||||||
|
Transport: "https",
|
||||||
|
TLSConfig: &tls.Config{},
|
||||||
|
ListenHosts: []string{"127.0.0.1"},
|
||||||
|
Port: "443",
|
||||||
|
}
|
||||||
|
c.AddPlugin(func(_next plugin.Handler) plugin.Handler { return h })
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPRequestContextPropagation(t *testing.T) {
|
func TestHTTPRequestContextPropagation(t *testing.T) {
|
||||||
plugin := &contextCapturingPlugin{}
|
plugin := &contextCapturingPlugin{}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user