Merge commit from fork

Add configurable resource limits to prevent potential DoS vectors
via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers.

New configuration plugins:
- grpc_server: configure max_streams, max_connections
- https: configure max_connections
- https3: configure max_streams

Changes:
- Use netutil.LimitListener for connection limiting
- Use gRPC MaxConcurrentStreams and message size limits
- Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting
- Set secure defaults: 256 max streams, 200 max connections
- Setting any limit to 0 means unbounded/fallback to previous impl

Defaults are applied automatically when plugins are omitted from
config.

Includes tests and integration tests.

Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
Ville Vesilehto
2025-12-18 05:08:59 +02:00
committed by GitHub
parent 0fb05f225c
commit 0d8cbb1a6b
24 changed files with 1689 additions and 24 deletions

View File

@@ -66,6 +66,22 @@ type Config struct {
// This is nil if not specified, allowing for a default to be used.
MaxQUICWorkerPoolSize *int
// MaxGRPCStreams defines the maximum number of concurrent streams per gRPC connection.
// This is nil if not specified, allowing for a default to be used.
MaxGRPCStreams *int
// MaxGRPCConnections defines the maximum number of concurrent gRPC connections.
// This is nil if not specified, allowing for a default to be used.
MaxGRPCConnections *int
// MaxHTTPSConnections defines the maximum number of concurrent HTTPS connections.
// This is nil if not specified, allowing for a default to be used.
MaxHTTPSConnections *int
// MaxHTTPS3Streams defines the maximum number of concurrent QUIC streams for HTTPS3.
// This is nil if not specified, allowing for a default to be used.
MaxHTTPS3Streams *int
// Timeouts for TCP, TLS and HTTPS servers.
ReadTimeout time.Duration
WriteTimeout time.Duration

View File

@@ -15,17 +15,35 @@ import (
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/netutil"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
const (
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
maxDNSMessageBytes = dns.MaxMsgSize
// maxProtobufPayloadBytes accounts for protobuf overhead.
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
DefaultGRPCMaxStreams = 256
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
DefaultGRPCMaxConnections = 200
)
// ServergRPC represents an instance of a DNS-over-gRPC server.
type ServergRPC struct {
*Server
*pb.UnimplementedDnsServiceServer
grpcServer *grpc.Server
listenAddr net.Addr
tlsConfig *tls.Config
grpcServer *grpc.Server
listenAddr net.Addr
tlsConfig *tls.Config
maxStreams int
maxConnections int
}
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
@@ -49,7 +67,22 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
tlsConfig.NextProtos = []string{"h2"}
}
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
maxStreams := DefaultGRPCMaxStreams
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil {
maxStreams = *group[0].MaxGRPCStreams
}
maxConnections := DefaultGRPCMaxConnections
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil {
maxConnections = *group[0].MaxGRPCConnections
}
return &ServergRPC{
Server: s,
tlsConfig: tlsConfig,
maxStreams: maxStreams,
maxConnections: maxConnections,
}, nil
}
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface
@@ -61,21 +94,36 @@ func (s *ServergRPC) Serve(l net.Listener) error {
s.listenAddr = l.Addr()
s.m.Unlock()
serverOpts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
}
// Only set MaxConcurrentStreams if not unbounded (0)
if s.maxStreams > 0 {
serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams)))
}
if s.Tracer() != nil {
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
return parentSpanCtx != nil
}
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
} else {
s.grpcServer = grpc.NewServer()
serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
}
s.grpcServer = grpc.NewServer(serverOpts...)
pb.RegisterDnsServiceServer(s.grpcServer, s)
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
// Wrap listener to limit concurrent connections
if s.maxConnections > 0 {
l = netutil.LimitListener(l, s.maxConnections)
}
return s.grpcServer.Serve(l)
}
@@ -122,6 +170,9 @@ func (s *ServergRPC) Stop() (err error) {
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
// back to the client as a protobuf.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
if len(in.GetMsg()) > dns.MaxMsgSize {
return nil, fmt.Errorf("dns message exceeds size limit: %d", len(in.GetMsg()))
}
msg := new(dns.Msg)
err := msg.Unpack(in.GetMsg())
if err != nil {

View File

@@ -69,6 +69,61 @@ func TestNewServergRPCWithTLS(t *testing.T) {
}
}
func TestNewServergRPCWithCustomLimits(t *testing.T) {
config := testConfig("grpc", testPlugin{})
maxStreams := 50
maxConnections := 100
config.MaxGRPCStreams = &maxStreams
config.MaxGRPCConnections = &maxConnections
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServergRPC() with custom limits failed: %v", err)
}
if server.maxStreams != maxStreams {
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
}
if server.maxConnections != maxConnections {
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
}
}
func TestNewServergRPCDefaults(t *testing.T) {
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
t.Fatalf("NewServergRPC() failed: %v", err)
}
if server.maxStreams != DefaultGRPCMaxStreams {
t.Errorf("Expected default maxStreams = %d, got %d", DefaultGRPCMaxStreams, server.maxStreams)
}
if server.maxConnections != DefaultGRPCMaxConnections {
t.Errorf("Expected default maxConnections = %d, got %d", DefaultGRPCMaxConnections, server.maxConnections)
}
}
func TestNewServergRPCZeroLimits(t *testing.T) {
config := testConfig("grpc", testPlugin{})
zero := 0
config.MaxGRPCStreams = &zero
config.MaxGRPCConnections = &zero
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
if err != nil {
t.Fatalf("NewServergRPC() with zero limits failed: %v", err)
}
if server.maxStreams != 0 {
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
}
if server.maxConnections != 0 {
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
}
}
func TestServergRPC_Listen(t *testing.T) {
server, err := NewServergRPC(transport.GRPC+"://127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
@@ -328,3 +383,67 @@ func TestGRPCResponse_WriteInvalidMessage(t *testing.T) {
t.Error("Write() should return error for invalid DNS message")
}
}
func TestServergRPC_Query_LargeMessage(t *testing.T) {
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
t.Fatalf("NewServergRPC failed: %v", err)
}
// Create oversized message (> dns.MaxMsgSize = 65535)
oversizedMsg := make([]byte, dns.MaxMsgSize+1)
dnsPacket := &pb.DnsPacket{Msg: oversizedMsg}
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
p := &peer.Peer{Addr: tcpAddr}
ctx := peer.NewContext(context.Background(), p)
server.listenAddr = tcpAddr
_, err = server.Query(ctx, dnsPacket)
if err == nil {
t.Error("Expected error for oversized message")
}
expectedError := "dns message exceeds size limit: 65536"
if err.Error() != expectedError {
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
}
}
func TestServergRPC_Query_MaxSizeMessage(t *testing.T) {
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
if err != nil {
t.Fatalf("NewServergRPC failed: %v", err)
}
// Create message exactly at the size limit (dns.MaxMsgSize = 65535)
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
packed, err := msg.Pack()
if err != nil {
t.Fatalf("Failed to pack DNS message: %v", err)
}
// Pad the message to exactly max size
if len(packed) > dns.MaxMsgSize {
t.Fatalf("Packed message is already larger than max size: %d", len(packed))
}
maxSizeMsg := make([]byte, dns.MaxMsgSize)
copy(maxSizeMsg, packed)
dnsPacket := &pb.DnsPacket{Msg: maxSizeMsg}
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
p := &peer.Peer{Addr: tcpAddr}
ctx := peer.NewContext(context.Background(), p)
server.listenAddr = tcpAddr
// Should not return an error for exactly max size message
_, err = server.Query(ctx, dnsPacket)
if err != nil {
t.Errorf("Expected no error for max size message, got: %v", err)
}
}

View File

@@ -18,15 +18,23 @@ import (
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"golang.org/x/net/netutil"
)
const (
// DefaultHTTPSMaxConnections is the default maximum number of concurrent connections.
DefaultHTTPSMaxConnections = 200
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
type ServerHTTPS struct {
*Server
httpsServer *http.Server
listenAddr net.Addr
tlsConfig *tls.Config
validRequest func(*http.Request) bool
httpsServer *http.Server
listenAddr net.Addr
tlsConfig *tls.Config
validRequest func(*http.Request) bool
maxConnections int
}
// loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server
@@ -81,8 +89,17 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
IdleTimeout: s.IdleTimeout,
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
}
maxConnections := DefaultHTTPSMaxConnections
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil {
maxConnections = *group[0].MaxHTTPSConnections
}
sh := &ServerHTTPS{
Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator,
Server: s,
tlsConfig: tlsConfig,
httpsServer: srv,
validRequest: validator,
maxConnections: maxConnections,
}
sh.httpsServer.Handler = sh
@@ -98,9 +115,15 @@ func (s *ServerHTTPS) Serve(l net.Listener) error {
s.listenAddr = l.Addr()
s.m.Unlock()
// Wrap listener to limit concurrent connections (before TLS)
if s.maxConnections > 0 {
l = netutil.LimitListener(l, s.maxConnections)
}
if s.tlsConfig != nil {
l = tls.NewListener(l, s.tlsConfig)
}
return s.httpsServer.Serve(l)
}

View File

@@ -21,6 +21,11 @@ import (
"github.com/quic-go/quic-go/http3"
)
const (
// DefaultHTTPS3MaxStreams is the default maximum number of concurrent QUIC streams per connection.
DefaultHTTPS3MaxStreams = 256
)
// ServerHTTPS3 represents a DNS-over-HTTP/3 server.
type ServerHTTPS3 struct {
*Server
@@ -29,6 +34,7 @@ type ServerHTTPS3 struct {
tlsConfig *tls.Config
quicConfig *quic.Config
validRequest func(*http.Request) bool
maxStreams int
}
// NewServerHTTPS3 builds the HTTP/3 (DoH3) server.
@@ -63,11 +69,20 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
validator = func(r *http.Request) bool { return r.URL.Path == doh.Path }
}
// QUIC transport config
maxStreams := DefaultHTTPS3MaxStreams
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPS3Streams != nil {
maxStreams = *group[0].MaxHTTPS3Streams
}
// QUIC transport config with stream limits (0 means use QUIC default)
qconf := &quic.Config{
MaxIdleTimeout: s.IdleTimeout,
Allow0RTT: true,
}
if maxStreams > 0 {
qconf.MaxIncomingStreams = int64(maxStreams)
qconf.MaxIncomingUniStreams = int64(maxStreams)
}
h3srv := &http3.Server{
Handler: nil, // set after constructing ServerHTTPS3
@@ -83,6 +98,7 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
httpsServer: h3srv,
quicConfig: qconf,
validRequest: validator,
maxStreams: maxStreams,
}
h3srv.Handler = sh

View File

@@ -60,3 +60,82 @@ func TestCustomHTTP3RequestValidator(t *testing.T) {
})
}
}
func TestNewServerHTTPS3WithCustomLimits(t *testing.T) {
maxStreams := 50
c := Config{
Zone: "example.com.",
Transport: "https3",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPS3Streams: &maxStreams,
}
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS3() with custom limits failed: %v", err)
}
if server.maxStreams != maxStreams {
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
}
expectedMaxStreams := int64(maxStreams)
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
t.Errorf("Expected quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
}
if server.quicConfig.MaxIncomingUniStreams != expectedMaxStreams {
t.Errorf("Expected quicConfig.MaxIncomingUniStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingUniStreams)
}
}
func TestNewServerHTTPS3Defaults(t *testing.T) {
c := Config{
Zone: "example.com.",
Transport: "https3",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
}
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS3() failed: %v", err)
}
if server.maxStreams != DefaultHTTPS3MaxStreams {
t.Errorf("Expected default maxStreams = %d, got %d", DefaultHTTPS3MaxStreams, server.maxStreams)
}
expectedMaxStreams := int64(DefaultHTTPS3MaxStreams)
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
t.Errorf("Expected default quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
}
}
func TestNewServerHTTPS3ZeroLimits(t *testing.T) {
zero := 0
c := Config{
Zone: "example.com.",
Transport: "https3",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPS3Streams: &zero,
}
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS3() with zero limits failed: %v", err)
}
if server.maxStreams != 0 {
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
}
// When maxStreams is 0, quicConfig should not set MaxIncomingStreams (uses QUIC default)
if server.quicConfig.MaxIncomingStreams != 0 {
t.Errorf("Expected quicConfig.MaxIncomingStreams = 0 (QUIC default), got %d", server.quicConfig.MaxIncomingStreams)
}
}

View File

@@ -72,6 +72,67 @@ func TestCustomHTTPRequestValidator(t *testing.T) {
}
}
func TestNewServerHTTPSWithCustomLimits(t *testing.T) {
maxConnections := 100
c := Config{
Zone: "example.com.",
Transport: "https",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPSConnections: &maxConnections,
}
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS() with custom limits failed: %v", err)
}
if server.maxConnections != maxConnections {
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
}
}
func TestNewServerHTTPSDefaults(t *testing.T) {
c := Config{
Zone: "example.com.",
Transport: "https",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
}
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS() failed: %v", err)
}
if server.maxConnections != DefaultHTTPSMaxConnections {
t.Errorf("Expected default maxConnections = %d, got %d", DefaultHTTPSMaxConnections, server.maxConnections)
}
}
func TestNewServerHTTPSZeroLimits(t *testing.T) {
zero := 0
c := Config{
Zone: "example.com.",
Transport: "https",
TLSConfig: &tls.Config{},
ListenHosts: []string{"127.0.0.1"},
Port: "443",
MaxHTTPSConnections: &zero,
}
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
if err != nil {
t.Fatalf("NewServerHTTPS() with zero limits failed: %v", err)
}
if server.maxConnections != 0 {
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
}
}
type contextCapturingPlugin struct {
capturedContext context.Context
contextCancelled bool

View File

@@ -156,12 +156,29 @@ func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) {
return
}
// Use a bounded worker pool
s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block
go func(st *quic.Stream, cn *quic.Conn) {
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
}(stream, conn)
// Use a bounded worker pool with context cancellation
select {
case s.streamProcessPool <- struct{}{}:
// Got worker slot immediately
go func(st *quic.Stream, cn *quic.Conn) {
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
}(stream, conn)
default:
// Worker pool full, check for context cancellation
go func(st *quic.Stream, cn *quic.Conn) {
select {
case s.streamProcessPool <- struct{}{}:
// Got worker slot after waiting
defer func() { <-s.streamProcessPool }() // Release worker slot
s.serveQUICStream(st, cn)
case <-conn.Context().Done():
// Connection context was cancelled while waiting
st.Close()
return
}
}(stream, conn)
}
}
}

View File

@@ -16,6 +16,9 @@ var Directives = []string{
"cancel",
"tls",
"quic",
"grpc_server",
"https",
"https3",
"timeouts",
"multisocket",
"reload",

View File

@@ -27,9 +27,12 @@ import (
_ "github.com/coredns/coredns/plugin/forward"
_ "github.com/coredns/coredns/plugin/geoip"
_ "github.com/coredns/coredns/plugin/grpc"
_ "github.com/coredns/coredns/plugin/grpc_server"
_ "github.com/coredns/coredns/plugin/header"
_ "github.com/coredns/coredns/plugin/health"
_ "github.com/coredns/coredns/plugin/hosts"
_ "github.com/coredns/coredns/plugin/https"
_ "github.com/coredns/coredns/plugin/https3"
_ "github.com/coredns/coredns/plugin/k8s_external"
_ "github.com/coredns/coredns/plugin/kubernetes"
_ "github.com/coredns/coredns/plugin/loadbalance"