diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index d4cb64be9..2e5d22bdf 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -38,7 +38,8 @@ func (l *loggerAdapter) Write(p []byte) (n int, err error) { return len(p), nil } -// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH) +// HTTPRequestKey is the context key for the HTTP request when processing DNS-over-HTTPS. +// Plugins can access the original HTTP request to retrieve headers, client IP, and metadata. type HTTPRequestKey struct{} // NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. @@ -168,7 +169,11 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { // We just call the normal chain handler - all error handling is done there. // We should expect a packet to be returned that we can send to the client. - ctx := context.WithValue(context.Background(), Key{}, s.Server) + + // Propagate HTTP request context to DNS processing chain. This ensures that + // HTTP request timeouts, cancellations, and other context values are properly + // inherited by the DNS processing pipeline. + ctx := context.WithValue(r.Context(), Key{}, s.Server) ctx = context.WithValue(ctx, LoopKey{}, 0) ctx = context.WithValue(ctx, HTTPRequestKey{}, r) s.ServeDNS(ctx, dw, msg) diff --git a/core/dnsserver/server_https_test.go b/core/dnsserver/server_https_test.go index 031cbf14a..5d062e168 100644 --- a/core/dnsserver/server_https_test.go +++ b/core/dnsserver/server_https_test.go @@ -2,11 +2,16 @@ package dnsserver import ( "bytes" + "context" "crypto/tls" + "io" "net/http" "net/http/httptest" "regexp" "testing" + "time" + + "github.com/coredns/coredns/plugin" "github.com/miekg/dns" ) @@ -66,3 +71,163 @@ func TestCustomHTTPRequestValidator(t *testing.T) { }) } } + +type contextCapturingPlugin struct { + capturedContext context.Context + contextCancelled bool +} + +func (p *contextCapturingPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.capturedContext = ctx + select { + case <-ctx.Done(): + p.contextCancelled = true + default: + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +func (p *contextCapturingPlugin) Name() string { return "context_capturing" } + +func testConfigWithPlugin(p *contextCapturingPlugin) *Config { + c := &Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + } + c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p }) + return c +} + +func TestHTTPRequestContextPropagation(t *testing.T) { + plugin := &contextCapturingPlugin{} + + s, err := NewServerHTTPS("127.0.0.1:443", []*Config{testConfigWithPlugin(plugin)}) + if err != nil { + t.Fatal("could not create HTTPS server:", err) + } + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + buf, err := m.Pack() + if err != nil { + t.Fatal(err) + } + t.Run("context values propagation", func(t *testing.T) { + contextValue := "test-request-id" + + r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf))) + ctx := context.WithValue(r.Context(), Key{}, contextValue) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + if plugin.capturedContext == nil { + t.Fatal("No context received in plugin") + } + + if val := plugin.capturedContext.Value(Key{}); val != s.Server { + t.Error("Server key not properly set in context") + } + + if httpReq, ok := plugin.capturedContext.Value(HTTPRequestKey{}).(*http.Request); !ok { + t.Error("HTTPRequestKey not found in context") + } else if httpReq != r { + t.Error("HTTPRequestKey contains different request than expected") + } + }) + + t.Run("plugins can access HTTP request details", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf))) + r.Header.Set("User-Agent", "my-doh-client/2.1") + r.Header.Set("X-Forwarded-For", "10.10.10.10") + r.Header.Set("Accept", "application/dns-message") + r.RemoteAddr = "10.10.10.100:45678" + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + if plugin.capturedContext == nil { + t.Fatal("No context received in plugin") + } + + httpReq, ok := plugin.capturedContext.Value(HTTPRequestKey{}).(*http.Request) + if !ok { + t.Fatal("HTTPRequestKey not found in context") + } + + if httpReq.Method != "POST" { + t.Errorf("Plugin expected POST method, got %s", httpReq.Method) + } + + if ua := httpReq.Header.Get("User-Agent"); ua != "my-doh-client/2.1" { + t.Errorf("Plugin expected User-Agent 'my-doh-client/2.1', got %s", ua) + } + + if xff := httpReq.Header.Get("X-Forwarded-For"); xff != "10.10.10.10" { + t.Errorf("Plugin expected X-Forwarded-For '10.10.10.10', got %s", xff) + } + + if accept := httpReq.Header.Get("Accept"); accept != "application/dns-message" { + t.Errorf("Plugin expected Accept 'application/dns-message', got %s", accept) + } + + if httpReq.RemoteAddr != "10.10.10.100:45678" { + t.Errorf("Plugin expected RemoteAddr '10.10.10.100:45678', got %s", httpReq.RemoteAddr) + } + + if loopValue := plugin.capturedContext.Value(LoopKey{}); loopValue != 0 { + t.Errorf("Expected LoopKey value 0, got %v", loopValue) + } + }) + + t.Run("context cancellation propagation", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf))) + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + w := httptest.NewRecorder() + + cancel() + s.ServeHTTP(w, r) + + if plugin.capturedContext == nil { + t.Fatal("No context received in plugin") + } + + if !plugin.contextCancelled { + t.Error("Context cancellation was not detected in plugin") + } + + if err := plugin.capturedContext.Err(); err == nil { + t.Error("Expected context to be cancelled, but it wasn't") + } + }) + + t.Run("context timeout propagation", func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf))) + ctx, cancel := context.WithTimeout(r.Context(), time.Millisecond) + defer cancel() + r = r.WithContext(ctx) + w := httptest.NewRecorder() + + s.ServeHTTP(w, r) + + if plugin.capturedContext == nil { + t.Fatal("No context received in plugin") + } + + if deadline, ok := plugin.capturedContext.Deadline(); !ok { + t.Error("Expected context to have a deadline") + } else if deadline.IsZero() { + t.Error("Context deadline is zero") + } + }) +} diff --git a/plugin.md b/plugin.md index be6360c57..b1c68151f 100644 --- a/plugin.md +++ b/plugin.md @@ -71,12 +71,22 @@ your plugin handle reload events better. ## Context -Every request get a context.Context these are pre-filled with 2 values: +Every request gets a `context.Context` with values that provide information about the request and server state. -* `Key`: holds a pointer to the current server, this can be useful for logging or metrics. It is - infact used in the *metrics* plugin to tie a request to a specific (internal) server. -* `LoopKey`: holds an integer to detect loops within the current context. The *file* plugin uses - this to detect loops when resolving CNAMEs. +### Core Context Values + +These values are available for all DNS requests: + +* `Key`: holds a pointer to the current server, useful for logging or metrics. Used by the *metrics* plugin to tie requests to specific (internal) server. +* `LoopKey`: holds an integer to detect loops within the current context. Used by the *file* plugin when resolving CNAMEs. + +### Transport-Specific Context Values + +Depending on the DNS transport protocol, additional context values may be available: + +* **DNS-over-HTTPS**: `HTTPRequestKey` contains the original `*http.Request`, providing access to HTTP headers, client information, and request metadata. +* **DNS-over-gRPC**: Standard gRPC context values are available, including peer information via `peer.FromContext()` and metadata via `metadata.FromIncomingContext()`. +* **DNS-over-QUIC**: QUIC stream context is propagated, including timeouts and cancellation signals. ## Documentation