mirror of
https://github.com/coredns/coredns.git
synced 2025-10-26 15:54:16 -04:00
fix(https): propagate HTTP request context (#7491)
This commit is contained in:
@@ -38,7 +38,8 @@ func (l *loggerAdapter) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH)
|
||||
// HTTPRequestKey is the context key for the HTTP request when processing DNS-over-HTTPS.
|
||||
// Plugins can access the original HTTP request to retrieve headers, client IP, and metadata.
|
||||
type HTTPRequestKey struct{}
|
||||
|
||||
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
||||
@@ -168,7 +169,11 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// We just call the normal chain handler - all error handling is done there.
|
||||
// We should expect a packet to be returned that we can send to the client.
|
||||
ctx := context.WithValue(context.Background(), Key{}, s.Server)
|
||||
|
||||
// Propagate HTTP request context to DNS processing chain. This ensures that
|
||||
// HTTP request timeouts, cancellations, and other context values are properly
|
||||
// inherited by the DNS processing pipeline.
|
||||
ctx := context.WithValue(r.Context(), Key{}, s.Server)
|
||||
ctx = context.WithValue(ctx, LoopKey{}, 0)
|
||||
ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
|
||||
s.ServeDNS(ctx, dw, msg)
|
||||
|
||||
@@ -2,11 +2,16 @@ package dnsserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -66,3 +71,163 @@ func TestCustomHTTPRequestValidator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type contextCapturingPlugin struct {
|
||||
capturedContext context.Context
|
||||
contextCancelled bool
|
||||
}
|
||||
|
||||
func (p *contextCapturingPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
p.capturedContext = ctx
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
p.contextCancelled = true
|
||||
default:
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
w.WriteMsg(m)
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
func (p *contextCapturingPlugin) Name() string { return "context_capturing" }
|
||||
|
||||
func testConfigWithPlugin(p *contextCapturingPlugin) *Config {
|
||||
c := &Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
}
|
||||
c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
|
||||
return c
|
||||
}
|
||||
|
||||
func TestHTTPRequestContextPropagation(t *testing.T) {
|
||||
plugin := &contextCapturingPlugin{}
|
||||
|
||||
s, err := NewServerHTTPS("127.0.0.1:443", []*Config{testConfigWithPlugin(plugin)})
|
||||
if err != nil {
|
||||
t.Fatal("could not create HTTPS server:", err)
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
buf, err := m.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Run("context values propagation", func(t *testing.T) {
|
||||
contextValue := "test-request-id"
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||
ctx := context.WithValue(r.Context(), Key{}, contextValue)
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
|
||||
if plugin.capturedContext == nil {
|
||||
t.Fatal("No context received in plugin")
|
||||
}
|
||||
|
||||
if val := plugin.capturedContext.Value(Key{}); val != s.Server {
|
||||
t.Error("Server key not properly set in context")
|
||||
}
|
||||
|
||||
if httpReq, ok := plugin.capturedContext.Value(HTTPRequestKey{}).(*http.Request); !ok {
|
||||
t.Error("HTTPRequestKey not found in context")
|
||||
} else if httpReq != r {
|
||||
t.Error("HTTPRequestKey contains different request than expected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plugins can access HTTP request details", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||
r.Header.Set("User-Agent", "my-doh-client/2.1")
|
||||
r.Header.Set("X-Forwarded-For", "10.10.10.10")
|
||||
r.Header.Set("Accept", "application/dns-message")
|
||||
r.RemoteAddr = "10.10.10.100:45678"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
|
||||
if plugin.capturedContext == nil {
|
||||
t.Fatal("No context received in plugin")
|
||||
}
|
||||
|
||||
httpReq, ok := plugin.capturedContext.Value(HTTPRequestKey{}).(*http.Request)
|
||||
if !ok {
|
||||
t.Fatal("HTTPRequestKey not found in context")
|
||||
}
|
||||
|
||||
if httpReq.Method != "POST" {
|
||||
t.Errorf("Plugin expected POST method, got %s", httpReq.Method)
|
||||
}
|
||||
|
||||
if ua := httpReq.Header.Get("User-Agent"); ua != "my-doh-client/2.1" {
|
||||
t.Errorf("Plugin expected User-Agent 'my-doh-client/2.1', got %s", ua)
|
||||
}
|
||||
|
||||
if xff := httpReq.Header.Get("X-Forwarded-For"); xff != "10.10.10.10" {
|
||||
t.Errorf("Plugin expected X-Forwarded-For '10.10.10.10', got %s", xff)
|
||||
}
|
||||
|
||||
if accept := httpReq.Header.Get("Accept"); accept != "application/dns-message" {
|
||||
t.Errorf("Plugin expected Accept 'application/dns-message', got %s", accept)
|
||||
}
|
||||
|
||||
if httpReq.RemoteAddr != "10.10.10.100:45678" {
|
||||
t.Errorf("Plugin expected RemoteAddr '10.10.10.100:45678', got %s", httpReq.RemoteAddr)
|
||||
}
|
||||
|
||||
if loopValue := plugin.capturedContext.Value(LoopKey{}); loopValue != 0 {
|
||||
t.Errorf("Expected LoopKey value 0, got %v", loopValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context cancellation propagation", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
cancel()
|
||||
s.ServeHTTP(w, r)
|
||||
|
||||
if plugin.capturedContext == nil {
|
||||
t.Fatal("No context received in plugin")
|
||||
}
|
||||
|
||||
if !plugin.contextCancelled {
|
||||
t.Error("Context cancellation was not detected in plugin")
|
||||
}
|
||||
|
||||
if err := plugin.capturedContext.Err(); err == nil {
|
||||
t.Error("Expected context to be cancelled, but it wasn't")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context timeout propagation", func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||
ctx, cancel := context.WithTimeout(r.Context(), time.Millisecond)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
|
||||
if plugin.capturedContext == nil {
|
||||
t.Fatal("No context received in plugin")
|
||||
}
|
||||
|
||||
if deadline, ok := plugin.capturedContext.Deadline(); !ok {
|
||||
t.Error("Expected context to have a deadline")
|
||||
} else if deadline.IsZero() {
|
||||
t.Error("Context deadline is zero")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
20
plugin.md
20
plugin.md
@@ -71,12 +71,22 @@ your plugin handle reload events better.
|
||||
|
||||
## Context
|
||||
|
||||
Every request get a context.Context these are pre-filled with 2 values:
|
||||
Every request gets a `context.Context` with values that provide information about the request and server state.
|
||||
|
||||
* `Key`: holds a pointer to the current server, this can be useful for logging or metrics. It is
|
||||
infact used in the *metrics* plugin to tie a request to a specific (internal) server.
|
||||
* `LoopKey`: holds an integer to detect loops within the current context. The *file* plugin uses
|
||||
this to detect loops when resolving CNAMEs.
|
||||
### Core Context Values
|
||||
|
||||
These values are available for all DNS requests:
|
||||
|
||||
* `Key`: holds a pointer to the current server, useful for logging or metrics. Used by the *metrics* plugin to tie requests to specific (internal) server.
|
||||
* `LoopKey`: holds an integer to detect loops within the current context. Used by the *file* plugin when resolving CNAMEs.
|
||||
|
||||
### Transport-Specific Context Values
|
||||
|
||||
Depending on the DNS transport protocol, additional context values may be available:
|
||||
|
||||
* **DNS-over-HTTPS**: `HTTPRequestKey` contains the original `*http.Request`, providing access to HTTP headers, client information, and request metadata.
|
||||
* **DNS-over-gRPC**: Standard gRPC context values are available, including peer information via `peer.FromContext()` and metadata via `metadata.FromIncomingContext()`.
|
||||
* **DNS-over-QUIC**: QUIC stream context is propagated, including timeouts and cancellation signals.
|
||||
|
||||
## Documentation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user