mirror of
https://github.com/coredns/coredns.git
synced 2025-10-26 15:54:16 -04:00
fix(https): propagate HTTP request context (#7491)
This commit is contained in:
@@ -38,7 +38,8 @@ func (l *loggerAdapter) Write(p []byte) (n int, err error) {
|
|||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH)
|
// HTTPRequestKey is the context key for the HTTP request when processing DNS-over-HTTPS.
|
||||||
|
// Plugins can access the original HTTP request to retrieve headers, client IP, and metadata.
|
||||||
type HTTPRequestKey struct{}
|
type HTTPRequestKey struct{}
|
||||||
|
|
||||||
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
||||||
@@ -168,7 +169,11 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// We just call the normal chain handler - all error handling is done there.
|
// We just call the normal chain handler - all error handling is done there.
|
||||||
// We should expect a packet to be returned that we can send to the client.
|
// We should expect a packet to be returned that we can send to the client.
|
||||||
ctx := context.WithValue(context.Background(), Key{}, s.Server)
|
|
||||||
|
// Propagate HTTP request context to DNS processing chain. This ensures that
|
||||||
|
// HTTP request timeouts, cancellations, and other context values are properly
|
||||||
|
// inherited by the DNS processing pipeline.
|
||||||
|
ctx := context.WithValue(r.Context(), Key{}, s.Server)
|
||||||
ctx = context.WithValue(ctx, LoopKey{}, 0)
|
ctx = context.WithValue(ctx, LoopKey{}, 0)
|
||||||
ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
|
ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
|
||||||
s.ServeDNS(ctx, dw, msg)
|
s.ServeDNS(ctx, dw, msg)
|
||||||
|
|||||||
@@ -2,11 +2,16 @@ package dnsserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/plugin"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
@@ -66,3 +71,163 @@ func TestCustomHTTPRequestValidator(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type contextCapturingPlugin struct {
|
||||||
|
capturedContext context.Context
|
||||||
|
contextCancelled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *contextCapturingPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
p.capturedContext = ctx
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.contextCancelled = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetReply(r)
|
||||||
|
m.Authoritative = true
|
||||||
|
w.WriteMsg(m)
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *contextCapturingPlugin) Name() string { return "context_capturing" }
|
||||||
|
|
||||||
|
func testConfigWithPlugin(p *contextCapturingPlugin) *Config {
|
||||||
|
c := &Config{
|
||||||
|
Zone: "example.com.",
|
||||||
|
Transport: "https",
|
||||||
|
TLSConfig: &tls.Config{},
|
||||||
|
ListenHosts: []string{"127.0.0.1"},
|
||||||
|
Port: "443",
|
||||||
|
}
|
||||||
|
c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPRequestContextPropagation(t *testing.T) {
|
||||||
|
plugin := &contextCapturingPlugin{}
|
||||||
|
|
||||||
|
s, err := NewServerHTTPS("127.0.0.1:443", []*Config{testConfigWithPlugin(plugin)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("could not create HTTPS server:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
buf, err := m.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Run("context values propagation", func(t *testing.T) {
|
||||||
|
contextValue := "test-request-id"
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||||
|
ctx := context.WithValue(r.Context(), Key{}, contextValue)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if plugin.capturedContext == nil {
|
||||||
|
t.Fatal("No context received in plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := plugin.capturedContext.Value(Key{}); val != s.Server {
|
||||||
|
t.Error("Server key not properly set in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpReq, ok := plugin.capturedContext.Value(HTTPRequestKey{}).(*http.Request); !ok {
|
||||||
|
t.Error("HTTPRequestKey not found in context")
|
||||||
|
} else if httpReq != r {
|
||||||
|
t.Error("HTTPRequestKey contains different request than expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plugins can access HTTP request details", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||||
|
r.Header.Set("User-Agent", "my-doh-client/2.1")
|
||||||
|
r.Header.Set("X-Forwarded-For", "10.10.10.10")
|
||||||
|
r.Header.Set("Accept", "application/dns-message")
|
||||||
|
r.RemoteAddr = "10.10.10.100:45678"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if plugin.capturedContext == nil {
|
||||||
|
t.Fatal("No context received in plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, ok := plugin.capturedContext.Value(HTTPRequestKey{}).(*http.Request)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("HTTPRequestKey not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpReq.Method != "POST" {
|
||||||
|
t.Errorf("Plugin expected POST method, got %s", httpReq.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ua := httpReq.Header.Get("User-Agent"); ua != "my-doh-client/2.1" {
|
||||||
|
t.Errorf("Plugin expected User-Agent 'my-doh-client/2.1', got %s", ua)
|
||||||
|
}
|
||||||
|
|
||||||
|
if xff := httpReq.Header.Get("X-Forwarded-For"); xff != "10.10.10.10" {
|
||||||
|
t.Errorf("Plugin expected X-Forwarded-For '10.10.10.10', got %s", xff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if accept := httpReq.Header.Get("Accept"); accept != "application/dns-message" {
|
||||||
|
t.Errorf("Plugin expected Accept 'application/dns-message', got %s", accept)
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpReq.RemoteAddr != "10.10.10.100:45678" {
|
||||||
|
t.Errorf("Plugin expected RemoteAddr '10.10.10.100:45678', got %s", httpReq.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loopValue := plugin.capturedContext.Value(LoopKey{}); loopValue != 0 {
|
||||||
|
t.Errorf("Expected LoopKey value 0, got %v", loopValue)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context cancellation propagation", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||||
|
ctx, cancel := context.WithCancel(r.Context())
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if plugin.capturedContext == nil {
|
||||||
|
t.Fatal("No context received in plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !plugin.contextCancelled {
|
||||||
|
t.Error("Context cancellation was not detected in plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := plugin.capturedContext.Err(); err == nil {
|
||||||
|
t.Error("Expected context to be cancelled, but it wasn't")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context timeout propagation", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/dns-query", io.NopCloser(bytes.NewReader(buf)))
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if plugin.capturedContext == nil {
|
||||||
|
t.Fatal("No context received in plugin")
|
||||||
|
}
|
||||||
|
|
||||||
|
if deadline, ok := plugin.capturedContext.Deadline(); !ok {
|
||||||
|
t.Error("Expected context to have a deadline")
|
||||||
|
} else if deadline.IsZero() {
|
||||||
|
t.Error("Context deadline is zero")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
20
plugin.md
20
plugin.md
@@ -71,12 +71,22 @@ your plugin handle reload events better.
|
|||||||
|
|
||||||
## Context
|
## Context
|
||||||
|
|
||||||
Every request get a context.Context these are pre-filled with 2 values:
|
Every request gets a `context.Context` with values that provide information about the request and server state.
|
||||||
|
|
||||||
* `Key`: holds a pointer to the current server, this can be useful for logging or metrics. It is
|
### Core Context Values
|
||||||
infact used in the *metrics* plugin to tie a request to a specific (internal) server.
|
|
||||||
* `LoopKey`: holds an integer to detect loops within the current context. The *file* plugin uses
|
These values are available for all DNS requests:
|
||||||
this to detect loops when resolving CNAMEs.
|
|
||||||
|
* `Key`: holds a pointer to the current server, useful for logging or metrics. Used by the *metrics* plugin to tie requests to specific (internal) server.
|
||||||
|
* `LoopKey`: holds an integer to detect loops within the current context. Used by the *file* plugin when resolving CNAMEs.
|
||||||
|
|
||||||
|
### Transport-Specific Context Values
|
||||||
|
|
||||||
|
Depending on the DNS transport protocol, additional context values may be available:
|
||||||
|
|
||||||
|
* **DNS-over-HTTPS**: `HTTPRequestKey` contains the original `*http.Request`, providing access to HTTP headers, client information, and request metadata.
|
||||||
|
* **DNS-over-gRPC**: Standard gRPC context values are available, including peer information via `peer.FromContext()` and metadata via `metadata.FromIncomingContext()`.
|
||||||
|
* **DNS-over-QUIC**: QUIC stream context is propagated, including timeouts and cancellation signals.
|
||||||
|
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user