Merge commit from fork

Add configurable resource limits to prevent potential DoS vectors
via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers.

New configuration plugins:
- grpc_server: configure max_streams, max_connections
- https: configure max_connections
- https3: configure max_streams

Changes:
- Use netutil.LimitListener for connection limiting
- Use gRPC MaxConcurrentStreams and message size limits
- Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting
- Set secure defaults: 256 max streams, 200 max connections
- Setting any limit to 0 means unbounded/fallback to previous impl

Defaults are applied automatically when plugins are omitted from
config.

Includes tests and integration tests.

Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
Ville Vesilehto
2025-12-18 05:08:59 +02:00
committed by GitHub
parent 0fb05f225c
commit 0d8cbb1a6b
24 changed files with 1689 additions and 24 deletions

View File

@@ -66,6 +66,22 @@ type Config struct {
// This is nil if not specified, allowing for a default to be used.
MaxQUICWorkerPoolSize *int
// MaxGRPCStreams defines the maximum number of concurrent streams per gRPC connection.
// This is nil if not specified, allowing for a default to be used.
MaxGRPCStreams *int
// MaxGRPCConnections defines the maximum number of concurrent gRPC connections.
// This is nil if not specified, allowing for a default to be used.
MaxGRPCConnections *int
// MaxHTTPSConnections defines the maximum number of concurrent HTTPS connections.
// This is nil if not specified, allowing for a default to be used.
MaxHTTPSConnections *int
// MaxHTTPS3Streams defines the maximum number of concurrent QUIC streams for HTTPS3.
// This is nil if not specified, allowing for a default to be used.
MaxHTTPS3Streams *int
// Timeouts for TCP, TLS and HTTPS servers.
ReadTimeout time.Duration
WriteTimeout time.Duration

View File

@@ -15,17 +15,35 @@ import (
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/netutil"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
const (
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
maxDNSMessageBytes = dns.MaxMsgSize
// maxProtobufPayloadBytes accounts for protobuf overhead.
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
DefaultGRPCMaxStreams = 256
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
DefaultGRPCMaxConnections = 200
)
// ServergRPC represents an instance of a DNS-over-gRPC server.
type ServergRPC struct {
*Server
*pb.UnimplementedDnsServiceServer
grpcServer *grpc.Server
listenAddr net.Addr
tlsConfig *tls.Config
grpcServer *grpc.Server
listenAddr net.Addr
tlsConfig *tls.Config
maxStreams int
maxConnections int
}
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
@@ -49,7 +67,22 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
tlsConfig.NextProtos = []string{"h2"}
}
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
maxStreams := DefaultGRPCMaxStreams
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil {
maxStreams = *group[0].MaxGRPCStreams
}
maxConnections := DefaultGRPCMaxConnections
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil {
maxConnections = *group[0].MaxGRPCConnections
}
return &ServergRPC{
Server: s,
tlsConfig: tlsConfig,
maxStreams: maxStreams,
maxConnections: maxConnections,
}, nil
}
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface
@@ -61,21 +94,36 @@ func (s *ServergRPC) Serve(l net.Listener) error {
s.listenAddr = l.Addr()
s.m.Unlock()
serverOpts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
}
// Only set MaxConcurrentStreams if not unbounded (0)
if s.maxStreams > 0 {
serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams)))
}
if s.Tracer() != nil {
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
return parentSpanCtx != nil
}
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
} else {
s.grpcServer = grpc.NewServer()
serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
}
s.grpcServer = grpc.NewServer(serverOpts...)
pb.RegisterDnsServiceServer(s.grpcServer, s)
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
// Wrap listener to limit concurrent connections
if s.maxConnections > 0 {
l = netutil.LimitListener(l, s.maxConnections)
}
return s.grpcServer.Serve(l)
}
@@ -122,6 +170,9 @@ func (s *ServergRPC) Stop() (err error) {
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
// back to the client as a protobuf.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
if len(in.GetMsg()) > dns.MaxMsgSize {
return nil, fmt.Errorf("dns message exceeds size limit: %d", len(in.GetMsg()))
}
msg := new(dns.Msg)
err := msg.Unpack(in.GetMsg())
if err != nil {

View File

@@ -69,6 +69,61 @@ func TestNewServergRPCWithTLS(t *testing.T) {
}
}
func TestNewServergRPCWithCustomLimits(t *testing.T) {
config := testConfig("grpc", testPlugin{})
maxStreams := 50
maxConnections := 100
config.MaxGRPCStreams = &maxStreams
config.MaxGRPCConnections = &maxConnections
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServergRPC() with custom limits failed: %v", err)
}
if server.maxStreams != maxStreams {
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
}
if server.maxConnections != maxConnections {
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
}
}
func TestNewServergRPCDefaults(t *testing.T) {
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
t.Fatalf("NewServergRPC() failed: %v", err)
}
if server.maxStreams != DefaultGRPCMaxStreams {
t.Errorf("Expected default maxStreams = %d, got %d", DefaultGRPCMaxStreams, server.maxStreams)
}
if server.maxConnections != DefaultGRPCMaxConnections {
t.Errorf("Expected default maxConnections = %d, got %d", DefaultGRPCMaxConnections, server.maxConnections)
}
}
func TestNewServergRPCZeroLimits(t *testing.T) {
config := testConfig("grpc", testPlugin{})
zero := 0
config.MaxGRPCStreams = &zero
config.MaxGRPCConnections = &zero
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServergRPC() with zero limits failed: %v", err)
}
if server.maxStreams != 0 {
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
}
if server.maxConnections != 0 {
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
}
}
func TestServergRPC_Listen(t *testing.T) {
server, err := NewServergRPC(transport.GRPC+"://127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
@@ -328,3 +383,67 @@ func TestGRPCResponse_WriteInvalidMessage(t *testing.T) {
t.Error("Write() should return error for invalid DNS message")
}
}
func TestServergRPC_Query_LargeMessage(t *testing.T) {
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
t.Fatalf("NewServergRPC failed: %v", err)
}
// Create oversized message (> dns.MaxMsgSize = 65535)
oversizedMsg := make([]byte, dns.MaxMsgSize+1)
dnsPacket := &pb.DnsPacket{Msg: oversizedMsg}
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
p := &peer.Peer{Addr: tcpAddr}
ctx := peer.NewContext(context.Background(), p)
server.listenAddr = tcpAddr
_, err = server.Query(ctx, dnsPacket)
if err == nil {
t.Error("Expected error for oversized message")
}
expectedError := "dns message exceeds size limit: 65536"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
func TestServergRPC_Query_MaxSizeMessage(t *testing.T) {
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
t.Fatalf("NewServergRPC failed: %v", err)
}
// Create message exactly at the size limit (dns.MaxMsgSize = 65535)
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
packed, err := msg.Pack()
if err != nil {
t.Fatalf("Failed to pack DNS message: %v", err)
}
// Pad the message to exactly max size
if len(packed) > dns.MaxMsgSize {
t.Fatalf("Packed message is already larger than max size: %d", len(packed))
}
maxSizeMsg := make([]byte, dns.MaxMsgSize)
copy(maxSizeMsg, packed)
dnsPacket := &pb.DnsPacket{Msg: maxSizeMsg}
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
p := &peer.Peer{Addr: tcpAddr}
ctx := peer.NewContext(context.Background(), p)
server.listenAddr = tcpAddr
// Should not return an error for exactly max size message
_, err = server.Query(ctx, dnsPacket)
if err != nil {
t.Errorf("Expected no error for max size message, got: %v", err)
}
}

View File

@@ -18,15 +18,23 @@ import (
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"golang.org/x/net/netutil"
)
const (
// DefaultHTTPSMaxConnections is the default maximum number of concurrent connections.
DefaultHTTPSMaxConnections = 200
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
type ServerHTTPS struct {
*Server
httpsServer *http.Server
listenAddr net.Addr
tlsConfig *tls.Config
validRequest func(*http.Request) bool
httpsServer *http.Server
listenAddr net.Addr
tlsConfig *tls.Config
validRequest func(*http.Request) bool
maxConnections int
}
// loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server
@@ -81,8 +89,17 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
IdleTimeout: s.IdleTimeout,
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
}
maxConnections := DefaultHTTPSMaxConnections
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil {
maxConnections = *group[0].MaxHTTPSConnections
}
sh := &ServerHTTPS{
Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator,
Server: s,
tlsConfig: tlsConfig,
httpsServer: srv,
validRequest: validator,
maxConnections: maxConnections,
}
sh.httpsServer.Handler = sh
@@ -98,9 +115,15 @@ func (s *ServerHTTPS) Serve(l net.Listener) error {
s.listenAddr = l.Addr()
s.m.Unlock()
// Wrap listener to limit concurrent connections (before TLS)
if s.maxConnections > 0 {
l = netutil.LimitListener(l, s.maxConnections)
}
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
return s.httpsServer.Serve(l)
}

View File

@@ -21,6 +21,11 @@ import (
"github.com/quic-go/quic-go/http3"
)
const (
// DefaultHTTPS3MaxStreams is the default maximum number of concurrent QUIC streams per connection.
DefaultHTTPS3MaxStreams = 256
)
// ServerHTTPS3 represents a DNS-over-HTTP/3 server.
type ServerHTTPS3 struct {
*Server
@@ -29,6 +34,7 @@ type ServerHTTPS3 struct {
tlsConfig *tls.Config
quicConfig *quic.Config
validRequest func(*http.Request) bool
maxStreams int
}
// NewServerHTTPS3 builds the HTTP/3 (DoH3) server.
@@ -63,11 +69,20 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
validator = func(r *http.Request) bool { return r.URL.Path == doh.Path }
}
// QUIC transport config
maxStreams := DefaultHTTPS3MaxStreams
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPS3Streams != nil {
maxStreams = *group[0].MaxHTTPS3Streams
}
// QUIC transport config with stream limits (0 means use QUIC default)
qconf := &quic.Config{
MaxIdleTimeout: s.IdleTimeout,
Allow0RTT: true,
}
if maxStreams > 0 {
qconf.MaxIncomingStreams = int64(maxStreams)
qconf.MaxIncomingUniStreams = int64(maxStreams)
}
h3srv := &http3.Server{
Handler: nil, // set after constructing ServerHTTPS3
@@ -83,6 +98,7 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
httpsServer: h3srv,
quicConfig: qconf,
validRequest: validator,
maxStreams: maxStreams,
}
h3srv.Handler = sh

View File

@@ -60,3 +60,82 @@ func TestCustomHTTP3RequestValidator(t *testing.T) {
})
}
}
func TestNewServerHTTPS3WithCustomLimits(t *testing.T) {
maxStreams := 50
c := Config{
Zone: "example.com.",
Transport: "https3",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPS3Streams: &maxStreams,
}
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS3() with custom limits failed: %v", err)
}
if server.maxStreams != maxStreams {
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
}
expectedMaxStreams := int64(maxStreams)
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
t.Errorf("Expected quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
}
if server.quicConfig.MaxIncomingUniStreams != expectedMaxStreams {
t.Errorf("Expected quicConfig.MaxIncomingUniStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingUniStreams)
}
}
func TestNewServerHTTPS3Defaults(t *testing.T) {
c := Config{
Zone: "example.com.",
Transport: "https3",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
}
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS3() failed: %v", err)
}
if server.maxStreams != DefaultHTTPS3MaxStreams {
t.Errorf("Expected default maxStreams = %d, got %d", DefaultHTTPS3MaxStreams, server.maxStreams)
}
expectedMaxStreams := int64(DefaultHTTPS3MaxStreams)
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
t.Errorf("Expected default quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
}
}
func TestNewServerHTTPS3ZeroLimits(t *testing.T) {
zero := 0
c := Config{
Zone: "example.com.",
Transport: "https3",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPS3Streams: &zero,
}
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS3() with zero limits failed: %v", err)
}
if server.maxStreams != 0 {
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
}
// When maxStreams is 0, quicConfig should not set MaxIncomingStreams (uses QUIC default)
if server.quicConfig.MaxIncomingStreams != 0 {
t.Errorf("Expected quicConfig.MaxIncomingStreams = 0 (QUIC default), got %d", server.quicConfig.MaxIncomingStreams)
}
}

View File

@@ -72,6 +72,67 @@ func TestCustomHTTPRequestValidator(t *testing.T) {
}
}
func TestNewServerHTTPSWithCustomLimits(t *testing.T) {
maxConnections := 100
c := Config{
Zone: "example.com.",
Transport: "https",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPSConnections: &maxConnections,
}
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS() with custom limits failed: %v", err)
}
if server.maxConnections != maxConnections {
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
}
}
func TestNewServerHTTPSDefaults(t *testing.T) {
c := Config{
Zone: "example.com.",
Transport: "https",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
}
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS() failed: %v", err)
}
if server.maxConnections != DefaultHTTPSMaxConnections {
t.Errorf("Expected default maxConnections = %d, got %d", DefaultHTTPSMaxConnections, server.maxConnections)
}
}
func TestNewServerHTTPSZeroLimits(t *testing.T) {
zero := 0
c := Config{
Zone: "example.com.",
Transport: "https",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPSConnections: &zero,
}
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS() with zero limits failed: %v", err)
}
if server.maxConnections != 0 {
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
}
}
type contextCapturingPlugin struct {
capturedContext context.Context
contextCancelled bool

View File

@@ -156,12 +156,29 @@ func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) {
return
}
// Use a bounded worker pool
s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block
go func(st *quic.Stream, cn *quic.Conn) {
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
}(stream, conn)
// Use a bounded worker pool with context cancellation
select {
case s.streamProcessPool <- struct{}{}:
// Got worker slot immediately
go func(st *quic.Stream, cn *quic.Conn) {
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
}(stream, conn)
default:
// Worker pool full, check for context cancellation
go func(st *quic.Stream, cn *quic.Conn) {
select {
case s.streamProcessPool <- struct{}{}:
// Got worker slot after waiting
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
case <-conn.Context().Done():
// Connection context was cancelled while waiting
st.Close()
return
}
}(stream, conn)
}
}
}

View File

@@ -16,6 +16,9 @@ var Directives = []string{
"cancel",
"tls",
"quic",
"grpc_server",
"https",
"https3",
"timeouts",
"multisocket",
"reload",

View File

@@ -27,9 +27,12 @@ import (
_ "github.com/coredns/coredns/plugin/forward"
_ "github.com/coredns/coredns/plugin/geoip"
_ "github.com/coredns/coredns/plugin/grpc"
_ "github.com/coredns/coredns/plugin/grpc_server"
_ "github.com/coredns/coredns/plugin/header"
_ "github.com/coredns/coredns/plugin/health"
_ "github.com/coredns/coredns/plugin/hosts"
_ "github.com/coredns/coredns/plugin/https"
_ "github.com/coredns/coredns/plugin/https3"
_ "github.com/coredns/coredns/plugin/k8s_external"
_ "github.com/coredns/coredns/plugin/kubernetes"
_ "github.com/coredns/coredns/plugin/loadbalance"

View File

@@ -25,6 +25,9 @@ geoip:geoip
cancel:cancel
tls:tls
quic:quic
grpc_server:grpc_server
https:https
https3:https3
timeouts:timeouts
multisocket:multisocket
reload:reload

View File

@@ -1,4 +1,4 @@
package chaos
// Owners are all GitHub handlers of all maintainers.
var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "varyoo", "ykhr53", "yongtang", "zouyee"}
var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "thevilledev", "varyoo", "ykhr53", "yongtang", "zouyee"}

View File

@@ -0,0 +1,51 @@
# grpc_server
## Name
*grpc_server* - configures DNS-over-gRPC server options.
## Description
The *grpc_server* plugin allows you to configure parameters for the DNS-over-gRPC server to fine-tune the security posture and performance of the server.
This plugin can only be used once per gRPC listener block.
## Syntax
```txt
grpc_server {
max_streams POSITIVE_INTEGER
max_connections POSITIVE_INTEGER
}
```
* `max_streams` limits the number of concurrent gRPC streams per connection. This helps prevent unbounded streams on a single connection, exhausting server resources. The default value is 256 if not specified. Set to 0 for unbounded.
* `max_connections` limits the number of concurrent TCP connections to the gRPC server. The default value is 200 if not specified. Set to 0 for unbounded.
## Examples
Set custom limits for maximum streams and connections:
```
grpc://.:8053 {
tls cert.pem key.pem
grpc_server {
max_streams 50
max_connections 100
}
whoami
}
```
Set values to 0 for unbounded, matching CoreDNS behaviour before v1.14.0:
```
grpc://.:8053 {
tls cert.pem key.pem
grpc_server {
max_streams 0
max_connections 0
}
whoami
}
```

View File

@@ -0,0 +1,79 @@
package grpc_server
import (
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() {
caddy.RegisterPlugin("grpc_server", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
err := parseGRPCServer(c)
if err != nil {
return plugin.Error("grpc_server", err)
}
return nil
}
func parseGRPCServer(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
// Skip the "grpc_server" directive itself
c.Next()
// Get any arguments on the "grpc_server" line
args := c.RemainingArgs()
if len(args) > 0 {
return c.ArgErr()
}
// Process all nested directives in the block
for c.NextBlock() {
switch c.Val() {
case "max_streams":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
val, err := strconv.Atoi(args[0])
if err != nil {
return c.Errf("invalid max_streams value '%s': %v", args[0], err)
}
if val < 0 {
return c.Errf("max_streams must be a non-negative integer: %d", val)
}
if config.MaxGRPCStreams != nil {
return c.Err("max_streams already defined for this server block")
}
config.MaxGRPCStreams = &val
case "max_connections":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
val, err := strconv.Atoi(args[0])
if err != nil {
return c.Errf("invalid max_connections value '%s': %v", args[0], err)
}
if val < 0 {
return c.Errf("max_connections must be a non-negative integer: %d", val)
}
if config.MaxGRPCConnections != nil {
return c.Err("max_connections already defined for this server block")
}
config.MaxGRPCConnections = &val
default:
return c.Errf("unknown property '%s'", c.Val())
}
}
return nil
}

View File

@@ -0,0 +1,169 @@
package grpc_server
import (
"fmt"
"strings"
"testing"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
)
func TestSetup(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedErrContent string
expectedMaxStreams *int
expectedMaxConnections *int
}{
// Valid configurations
{
input: `grpc_server`,
shouldErr: false,
},
{
input: `grpc_server {
}`,
shouldErr: false,
},
{
input: `grpc_server {
max_streams 100
}`,
shouldErr: false,
expectedMaxStreams: intPtr(100),
},
{
input: `grpc_server {
max_connections 200
}`,
shouldErr: false,
expectedMaxConnections: intPtr(200),
},
{
input: `grpc_server {
max_streams 50
max_connections 100
}`,
shouldErr: false,
expectedMaxStreams: intPtr(50),
expectedMaxConnections: intPtr(100),
},
// Zero values (unbounded)
{
input: `grpc_server {
max_streams 0
}`,
shouldErr: false,
expectedMaxStreams: intPtr(0),
},
{
input: `grpc_server {
max_connections 0
}`,
shouldErr: false,
expectedMaxConnections: intPtr(0),
},
// Error cases
{
input: `grpc_server {
max_streams
}`,
shouldErr: true,
expectedErrContent: "Wrong argument count",
},
{
input: `grpc_server {
max_streams abc
}`,
shouldErr: true,
expectedErrContent: "invalid max_streams value",
},
{
input: `grpc_server {
max_streams -1
}`,
shouldErr: true,
expectedErrContent: "must be a non-negative integer",
},
{
input: `grpc_server {
max_streams 100
max_streams 200
}`,
shouldErr: true,
expectedErrContent: "already defined",
},
{
input: `grpc_server {
unknown_option 123
}`,
shouldErr: true,
expectedErrContent: "unknown property",
},
{
input: `grpc_server extra_arg`,
shouldErr: true,
expectedErrContent: "Wrong argument count",
},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
err := setup(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d (%s): Expected error but got none", i, test.input)
continue
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err)
continue
}
if test.shouldErr && test.expectedErrContent != "" {
if !strings.Contains(err.Error(), test.expectedErrContent) {
t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v",
i, test.input, test.expectedErrContent, err)
}
continue
}
if !test.shouldErr {
config := dnsserver.GetConfig(c)
assertIntPtrValue(t, i, test.input, "MaxGRPCStreams", config.MaxGRPCStreams, test.expectedMaxStreams)
assertIntPtrValue(t, i, test.input, "MaxGRPCConnections", config.MaxGRPCConnections, test.expectedMaxConnections)
}
}
}
func intPtr(v int) *int {
return &v
}
func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) {
t.Helper()
if actual == nil && expected == nil {
return
}
if (actual == nil) != (expected == nil) {
t.Errorf("Test %d (%s): Expected %s to be %v, but got %v",
testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual))
return
}
if *actual != *expected {
t.Errorf("Test %d (%s): Expected %s to be %d, but got %d",
testIndex, testInput, fieldName, *expected, *actual)
}
}
func formatNilableInt(v *int) string {
if v == nil {
return "nil"
}
return fmt.Sprintf("%d", *v)
}

47
plugin/https/README.md Normal file
View File

@@ -0,0 +1,47 @@
# https
## Name
*https* - configures DNS-over-HTTPS (DoH) server options.
## Description
The *https* plugin allows you to configure parameters for the DNS-over-HTTPS (DoH) server to fine-tune the security posture and performance of the server.
This plugin can only be used once per HTTPS listener block.
## Syntax
```txt
https {
max_connections POSITIVE_INTEGER
}
```
* `max_connections` limits the number of concurrent TCP connections to the HTTPS server. The default value is 200 if not specified. Set to 0 for unbounded.
## Examples
Set custom limits for maximum connections:
```
https://.:443 {
tls cert.pem key.pem
https {
max_connections 100
}
whoami
}
```
Set values to 0 for unbounded, matching CoreDNS behaviour before v1.14.0:
```
https://.:443 {
tls cert.pem key.pem
https {
max_connections 0
}
whoami
}
```

63
plugin/https/setup.go Normal file
View File

@@ -0,0 +1,63 @@
package https
import (
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() {
caddy.RegisterPlugin("https", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
err := parseDOH(c)
if err != nil {
return plugin.Error("https", err)
}
return nil
}
func parseDOH(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
// Skip the "https" directive itself
c.Next()
// Get any arguments on the "https" line
args := c.RemainingArgs()
if len(args) > 0 {
return c.ArgErr()
}
// Process all nested directives in the block
for c.NextBlock() {
switch c.Val() {
case "max_connections":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
val, err := strconv.Atoi(args[0])
if err != nil {
return c.Errf("invalid max_connections value '%s': %v", args[0], err)
}
if val < 0 {
return c.Errf("max_connections must be a non-negative integer: %d", val)
}
if config.MaxHTTPSConnections != nil {
return c.Err("max_connections already defined for this server block")
}
config.MaxHTTPSConnections = &val
default:
return c.Errf("unknown property '%s'", c.Val())
}
}
return nil
}

144
plugin/https/setup_test.go Normal file
View File

@@ -0,0 +1,144 @@
package https
import (
"fmt"
"strings"
"testing"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
)
func TestSetup(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedErrContent string
expectedMaxConnections *int
}{
// Valid configurations
{
input: `https`,
shouldErr: false,
},
{
input: `https {
}`,
shouldErr: false,
},
{
input: `https {
max_connections 200
}`,
shouldErr: false,
expectedMaxConnections: intPtr(200),
},
// Zero values (unbounded)
{
input: `https {
max_connections 0
}`,
shouldErr: false,
expectedMaxConnections: intPtr(0),
},
// Error cases
{
input: `https {
max_connections
}`,
shouldErr: true,
expectedErrContent: "Wrong argument count",
},
{
input: `https {
max_connections abc
}`,
shouldErr: true,
expectedErrContent: "invalid max_connections value",
},
{
input: `https {
max_connections -1
}`,
shouldErr: true,
expectedErrContent: "must be a non-negative integer",
},
{
input: `https {
max_connections 100
max_connections 200
}`,
shouldErr: true,
expectedErrContent: "already defined",
},
{
input: `https {
unknown_option 123
}`,
shouldErr: true,
expectedErrContent: "unknown property",
},
{
input: `https extra_arg`,
shouldErr: true,
expectedErrContent: "Wrong argument count",
},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
err := setup(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d (%s): Expected error but got none", i, test.input)
continue
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err)
continue
}
if test.shouldErr && test.expectedErrContent != "" {
if !strings.Contains(err.Error(), test.expectedErrContent) {
t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v",
i, test.input, test.expectedErrContent, err)
}
continue
}
if !test.shouldErr {
config := dnsserver.GetConfig(c)
assertIntPtrValue(t, i, test.input, "MaxHTTPSConnections", config.MaxHTTPSConnections, test.expectedMaxConnections)
}
}
}
func intPtr(v int) *int {
return &v
}
func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) {
t.Helper()
if actual == nil && expected == nil {
return
}
if (actual == nil) != (expected == nil) {
t.Errorf("Test %d (%s): Expected %s to be %v, but got %v",
testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual))
return
}
if *actual != *expected {
t.Errorf("Test %d (%s): Expected %s to be %d, but got %d",
testIndex, testInput, fieldName, *expected, *actual)
}
}
func formatNilableInt(v *int) string {
if v == nil {
return "nil"
}
return fmt.Sprintf("%d", *v)
}

47
plugin/https3/README.md Normal file
View File

@@ -0,0 +1,47 @@
# https3
## Name
*https3* - configures DNS-over-HTTPS/3 (DoH3) server options.
## Description
The *https3* plugin allows you to configure parameters for the DNS-over-HTTPS/3 (DoH3) server to fine-tune the security posture and performance of the server. HTTPS/3 uses QUIC as the underlying transport.
This plugin can only be used once per HTTPS3 listener block.
## Syntax
```txt
https3 {
max_streams POSITIVE_INTEGER
}
```
* `max_streams` limits the number of concurrent QUIC streams per connection. This helps prevent unbounded streams on a single connection, exhausting server resources. The default value is 256 if not specified. Set to 0 to use underlying QUIC transport default.
## Examples
Set custom limits for maximum streams:
```
https3://.:443 {
tls cert.pem key.pem
https3 {
max_streams 50
}
whoami
}
```
Set values to 0 for QUIC transport default, matching CoreDNS behaviour before v1.14.0:
```
https3://.:443 {
tls cert.pem key.pem
https3 {
max_streams 0
}
whoami
}
```

63
plugin/https3/setup.go Normal file
View File

@@ -0,0 +1,63 @@
package https3
import (
"strconv"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
)
func init() {
caddy.RegisterPlugin("https3", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
err := parseDOH3(c)
if err != nil {
return plugin.Error("https3", err)
}
return nil
}
func parseDOH3(c *caddy.Controller) error {
config := dnsserver.GetConfig(c)
// Skip the "https3" directive itself
c.Next()
// Get any arguments on the "https3" line
args := c.RemainingArgs()
if len(args) > 0 {
return c.ArgErr()
}
// Process all nested directives in the block
for c.NextBlock() {
switch c.Val() {
case "max_streams":
args := c.RemainingArgs()
if len(args) != 1 {
return c.ArgErr()
}
val, err := strconv.Atoi(args[0])
if err != nil {
return c.Errf("invalid max_streams value '%s': %v", args[0], err)
}
if val < 0 {
return c.Errf("max_streams must be a non-negative integer: %d", val)
}
if config.MaxHTTPS3Streams != nil {
return c.Err("max_streams already defined for this server block")
}
config.MaxHTTPS3Streams = &val
default:
return c.Errf("unknown property '%s'", c.Val())
}
}
return nil
}

144
plugin/https3/setup_test.go Normal file
View File

@@ -0,0 +1,144 @@
package https3
import (
"fmt"
"strings"
"testing"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
)
func TestSetup(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedErrContent string
expectedMaxStreams *int
}{
// Valid configurations
{
input: `https3`,
shouldErr: false,
},
{
input: `https3 {
}`,
shouldErr: false,
},
{
input: `https3 {
max_streams 100
}`,
shouldErr: false,
expectedMaxStreams: intPtr(100),
},
// Zero values (unbounded)
{
input: `https3 {
max_streams 0
}`,
shouldErr: false,
expectedMaxStreams: intPtr(0),
},
// Error cases
{
input: `https3 {
max_streams
}`,
shouldErr: true,
expectedErrContent: "Wrong argument count",
},
{
input: `https3 {
max_streams abc
}`,
shouldErr: true,
expectedErrContent: "invalid max_streams value",
},
{
input: `https3 {
max_streams -1
}`,
shouldErr: true,
expectedErrContent: "must be a non-negative integer",
},
{
input: `https3 {
max_streams 100
max_streams 200
}`,
shouldErr: true,
expectedErrContent: "already defined",
},
{
input: `https3 {
unknown_option 123
}`,
shouldErr: true,
expectedErrContent: "unknown property",
},
{
input: `https3 extra_arg`,
shouldErr: true,
expectedErrContent: "Wrong argument count",
},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
err := setup(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d (%s): Expected error but got none", i, test.input)
continue
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err)
continue
}
if test.shouldErr && test.expectedErrContent != "" {
if !strings.Contains(err.Error(), test.expectedErrContent) {
t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v",
i, test.input, test.expectedErrContent, err)
}
continue
}
if !test.shouldErr {
config := dnsserver.GetConfig(c)
assertIntPtrValue(t, i, test.input, "MaxHTTPS3Streams", config.MaxHTTPS3Streams, test.expectedMaxStreams)
}
}
}
func intPtr(v int) *int {
return &v
}
func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) {
t.Helper()
if actual == nil && expected == nil {
return
}
if (actual == nil) != (expected == nil) {
t.Errorf("Test %d (%s): Expected %s to be %v, but got %v",
testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual))
return
}
if *actual != *expected {
t.Errorf("Test %d (%s): Expected %s to be %d, but got %d",
testIndex, testInput, fieldName, *expected, *actual)
}
}
func formatNilableInt(v *int) string {
if v == nil {
return "nil"
}
return fmt.Sprintf("%d", *v)
}

View File

@@ -2,19 +2,40 @@ package test
import (
"context"
"crypto/tls"
"net"
"testing"
"time"
"github.com/coredns/coredns/pb"
"github.com/miekg/dns"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
var grpcCorefile = `grpc://.:0 {
whoami
}`
var grpcLimitCorefile = `grpc://.:0 {
grpc_server {
max_streams 2
}
whoami
}`
var grpcConnectionLimitCorefile = `grpc://.:0 {
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
grpc_server {
max_connections 2
}
whoami
}`
func TestGrpc(t *testing.T) {
corefile := `grpc://.:0 {
whoami
}`
corefile := grpcCorefile
g, _, tcp, err := CoreDNSServerAndPorts(corefile)
if err != nil {
@@ -53,3 +74,127 @@ func TestGrpc(t *testing.T) {
t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
}
}
// TestGRPCWithLimits tests that the server starts and works with configured limits
func TestGRPCWithLimits(t *testing.T) {
g, _, tcp, err := CoreDNSServerAndPorts(grpcLimitCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer g.Stop()
conn, err := grpc.NewClient(tcp, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("Expected no error but got: %s", err)
}
defer conn.Close()
client := pb.NewDnsServiceClient(conn)
m := new(dns.Msg)
m.SetQuestion("whoami.example.org.", dns.TypeA)
msg, _ := m.Pack()
reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: msg})
if err != nil {
t.Fatalf("Query failed: %s", err)
}
d := new(dns.Msg)
if err := d.Unpack(reply.GetMsg()); err != nil {
t.Fatalf("Failed to unpack: %s", err)
}
if d.Rcode != dns.RcodeSuccess {
t.Errorf("Expected success but got %d", d.Rcode)
}
}
// TestGRPCConnectionLimit tests that connection limits are enforced
func TestGRPCConnectionLimit(t *testing.T) {
g, _, tcp, err := CoreDNSServerAndPorts(grpcConnectionLimitCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer g.Stop()
const maxConns = 2
// Create TLS connections to hold them open
tlsConfig := &tls.Config{InsecureSkipVerify: true}
conns := make([]net.Conn, 0, maxConns+1)
defer func() {
for _, c := range conns {
c.Close()
}
}()
// Open connections up to the limit - these should succeed
for i := range maxConns {
conn, err := tls.Dial("tcp", tcp, tlsConfig)
if err != nil {
t.Fatalf("Connection %d failed (should succeed): %v", i+1, err)
}
conns = append(conns, conn)
}
// Try to open more connections beyond the limit - should timeout
conn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 100 * time.Millisecond},
"tcp", tcp, tlsConfig,
)
if err == nil {
conn.Close()
t.Fatal("Connection beyond limit should have timed out")
}
// Close one connection and verify a new one can be established
conns[0].Close()
conns = conns[1:]
time.Sleep(10 * time.Millisecond)
conn, err = tls.Dial("tcp", tcp, tlsConfig)
if err != nil {
t.Fatalf("Connection after freeing slot failed: %v", err)
}
conns = append(conns, conn)
}
// TestGRPCTLSWithLimits tests that gRPC with TLS starts and works with configured limits
func TestGRPCTLSWithLimits(t *testing.T) {
g, _, tcp, err := CoreDNSServerAndPorts(grpcConnectionLimitCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer g.Stop()
tlsConfig := &tls.Config{InsecureSkipVerify: true}
creds := credentials.NewTLS(tlsConfig)
conn, err := grpc.NewClient(tcp, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("Expected no error but got: %s", err)
}
defer conn.Close()
client := pb.NewDnsServiceClient(conn)
m := new(dns.Msg)
m.SetQuestion("whoami.example.org.", dns.TypeA)
msg, _ := m.Pack()
reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: msg})
if err != nil {
t.Fatalf("Query failed: %s", err)
}
d := new(dns.Msg)
if err := d.Unpack(reply.GetMsg()); err != nil {
t.Fatalf("Failed to unpack: %s", err)
}
if d.Rcode != dns.RcodeSuccess {
t.Errorf("Expected success but got %d", d.Rcode)
}
}

145
test/https3_test.go Normal file
View File

@@ -0,0 +1,145 @@
package test
import (
"bytes"
"context"
"crypto/tls"
"io"
"net/http"
"testing"
"time"
ctls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/miekg/dns"
"github.com/quic-go/quic-go/http3"
)
var https3Corefile = `https3://.:0 {
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
whoami
}`
var https3LimitCorefile = `https3://.:0 {
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
https3 {
max_streams 2
}
whoami
}`
func generateHTTPS3TLSConfig() *tls.Config {
tlsConfig, err := ctls.NewTLSConfig(
"../plugin/tls/test_cert.pem",
"../plugin/tls/test_key.pem",
"../plugin/tls/test_ca.pem")
if err != nil {
panic(err)
}
tlsConfig.InsecureSkipVerify = true
return tlsConfig
}
func TestHTTPS3(t *testing.T) {
s, udp, _, err := CoreDNSServerAndPorts(https3Corefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer s.Stop()
// Create HTTP/3 client
transport := &http3.Transport{
TLSClientConfig: generateHTTPS3TLSConfig(),
}
defer transport.Close()
client := &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
}
// Create DNS query
m := new(dns.Msg)
m.SetQuestion("whoami.example.org.", dns.TypeA)
msg, err := m.Pack()
if err != nil {
t.Fatalf("Failed to pack DNS message: %v", err)
}
// Make DoH3 request - use UDP address for HTTP/3
url := "https://" + convertAddress(udp) + "/dns-query"
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(msg))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Accept", "application/dns-message")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
d := new(dns.Msg)
err = d.Unpack(body)
if err != nil {
t.Fatalf("Failed to unpack response: %v", err)
}
if d.Rcode != dns.RcodeSuccess {
t.Errorf("Expected success but got %d", d.Rcode)
}
if len(d.Extra) != 2 {
t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
}
}
// TestHTTPS3WithLimits tests that the server starts and works with configured limits
func TestHTTPS3WithLimits(t *testing.T) {
s, udp, _, err := CoreDNSServerAndPorts(https3LimitCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer s.Stop()
transport := &http3.Transport{
TLSClientConfig: generateHTTPS3TLSConfig(),
}
defer transport.Close()
client := &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
}
m := new(dns.Msg)
m.SetQuestion("whoami.example.org.", dns.TypeA)
msg, _ := m.Pack()
req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "https://"+convertAddress(udp)+"/dns-query", bytes.NewReader(msg))
req.Header.Set("Content-Type", "application/dns-message")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
}

177
test/https_test.go Normal file
View File

@@ -0,0 +1,177 @@
package test
import (
"bytes"
"crypto/tls"
"io"
"net"
"net/http"
"testing"
"time"
"github.com/miekg/dns"
)
var httpsCorefile = `https://.:0 {
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
whoami
}`
var httpsLimitCorefile = `https://.:0 {
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
https {
max_connections 2
}
whoami
}`
func TestHTTPS(t *testing.T) {
s, _, tcp, err := CoreDNSServerAndPorts(httpsCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer s.Stop()
// Create HTTPS client with TLS config
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
Timeout: 5 * time.Second,
}
// Create DNS query
m := new(dns.Msg)
m.SetQuestion("whoami.example.org.", dns.TypeA)
msg, err := m.Pack()
if err != nil {
t.Fatalf("Failed to pack DNS message: %v", err)
}
// Make DoH request
url := "https://" + tcp + "/dns-query"
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(msg))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Accept", "application/dns-message")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
d := new(dns.Msg)
err = d.Unpack(body)
if err != nil {
t.Fatalf("Failed to unpack response: %v", err)
}
if d.Rcode != dns.RcodeSuccess {
t.Errorf("Expected success but got %d", d.Rcode)
}
if len(d.Extra) != 2 {
t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
}
}
// TestHTTPSWithLimits tests that the server starts and works with configured limits
func TestHTTPSWithLimits(t *testing.T) {
s, _, tcp, err := CoreDNSServerAndPorts(httpsLimitCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer s.Stop()
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
Timeout: 5 * time.Second,
}
m := new(dns.Msg)
m.SetQuestion("whoami.example.org.", dns.TypeA)
msg, _ := m.Pack()
req, _ := http.NewRequest(http.MethodPost, "https://"+tcp+"/dns-query", bytes.NewReader(msg))
req.Header.Set("Content-Type", "application/dns-message")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Request failed: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}
}
// TestHTTPSConnectionLimit tests that connection limits are enforced
func TestHTTPSConnectionLimit(t *testing.T) {
s, _, tcp, err := CoreDNSServerAndPorts(httpsLimitCorefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer s.Stop()
const maxConns = 2
const totalConns = 4
// Create raw TLS connections to hold them open
conns := make([]net.Conn, 0, totalConns)
defer func() {
for _, c := range conns {
c.Close()
}
}()
// Open connections up to the limit - these should succeed
for i := range maxConns {
conn, err := tls.Dial("tcp", tcp, &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Fatalf("Connection %d failed (should succeed): %v", i+1, err)
}
conns = append(conns, conn)
}
// Try to open more connections beyond the limit
// The LimitListener blocks Accept() until a slot is free, so Dial with timeout should fail
conn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 100 * time.Millisecond},
"tcp", tcp,
&tls.Config{InsecureSkipVerify: true},
)
if err == nil {
conn.Close()
t.Fatal("Connection beyond limit should have timed out")
}
// Close one connection and verify a new one can be established
conns[0].Close()
conns = conns[1:]
time.Sleep(10 * time.Millisecond) // Give the listener time to accept
conn, err = tls.Dial("tcp", tcp, &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Fatalf("Connection after freeing slot failed: %v", err)
}
conns = append(conns, conn)
}