From 0d8cbb1a6bcb6bc9c1a489865278b8725fa20812 Mon Sep 17 00:00:00 2001 From: Ville Vesilehto Date: Thu, 18 Dec 2025 05:08:59 +0200 Subject: [PATCH] Merge commit from fork Add configurable resource limits to prevent potential DoS vectors via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers. New configuration plugins: - grpc_server: configure max_streams, max_connections - https: configure max_connections - https3: configure max_streams Changes: - Use netutil.LimitListener for connection limiting - Use gRPC MaxConcurrentStreams and message size limits - Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting - Set secure defaults: 256 max streams, 200 max connections - Setting any limit to 0 means unbounded/fallback to previous impl Defaults are applied automatically when plugins are omitted from config. Includes tests and integration tests. Signed-off-by: Ville Vesilehto --- core/dnsserver/config.go | 16 +++ core/dnsserver/server_grpc.go | 67 ++++++++-- core/dnsserver/server_grpc_test.go | 119 ++++++++++++++++++ core/dnsserver/server_https.go | 33 ++++- core/dnsserver/server_https3.go | 18 ++- core/dnsserver/server_https3_test.go | 79 ++++++++++++ core/dnsserver/server_https_test.go | 61 +++++++++ core/dnsserver/server_quic.go | 29 ++++- core/dnsserver/zdirectives.go | 3 + core/plugin/zplugin.go | 3 + plugin.cfg | 3 + plugin/chaos/zowners.go | 2 +- plugin/grpc_server/README.md | 51 ++++++++ plugin/grpc_server/setup.go | 79 ++++++++++++ plugin/grpc_server/setup_test.go | 169 +++++++++++++++++++++++++ plugin/https/README.md | 47 +++++++ plugin/https/setup.go | 63 ++++++++++ plugin/https/setup_test.go | 144 ++++++++++++++++++++++ plugin/https3/README.md | 47 +++++++ plugin/https3/setup.go | 63 ++++++++++ plugin/https3/setup_test.go | 144 ++++++++++++++++++++++ test/grpc_test.go | 151 ++++++++++++++++++++++- test/https3_test.go | 145 ++++++++++++++++++++++ test/https_test.go | 177 +++++++++++++++++++++++++++ 24 files changed, 1689 insertions(+), 24 deletions(-) create mode 100644 plugin/grpc_server/README.md create mode 100644 plugin/grpc_server/setup.go create mode 100644 plugin/grpc_server/setup_test.go create mode 100644 plugin/https/README.md create mode 100644 plugin/https/setup.go create mode 100644 plugin/https/setup_test.go create mode 100644 plugin/https3/README.md create mode 100644 plugin/https3/setup.go create mode 100644 plugin/https3/setup_test.go create mode 100644 test/https3_test.go create mode 100644 test/https_test.go diff --git a/core/dnsserver/config.go b/core/dnsserver/config.go index 168120795..fcf9c95ce 100644 --- a/core/dnsserver/config.go +++ b/core/dnsserver/config.go @@ -66,6 +66,22 @@ type Config struct { // This is nil if not specified, allowing for a default to be used. MaxQUICWorkerPoolSize *int + // MaxGRPCStreams defines the maximum number of concurrent streams per gRPC connection. + // This is nil if not specified, allowing for a default to be used. + MaxGRPCStreams *int + + // MaxGRPCConnections defines the maximum number of concurrent gRPC connections. + // This is nil if not specified, allowing for a default to be used. + MaxGRPCConnections *int + + // MaxHTTPSConnections defines the maximum number of concurrent HTTPS connections. + // This is nil if not specified, allowing for a default to be used. + MaxHTTPSConnections *int + + // MaxHTTPS3Streams defines the maximum number of concurrent QUIC streams for HTTPS3. + // This is nil if not specified, allowing for a default to be used. + MaxHTTPS3Streams *int + // Timeouts for TCP, TLS and HTTPS servers. ReadTimeout time.Duration WriteTimeout time.Duration diff --git a/core/dnsserver/server_grpc.go b/core/dnsserver/server_grpc.go index a834502c8..e899ffa48 100644 --- a/core/dnsserver/server_grpc.go +++ b/core/dnsserver/server_grpc.go @@ -15,17 +15,35 @@ import ( "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" "github.com/miekg/dns" "github.com/opentracing/opentracing-go" + "golang.org/x/net/netutil" "google.golang.org/grpc" "google.golang.org/grpc/peer" ) +const ( + // maxDNSMessageBytes is the maximum size of a DNS message on the wire. + maxDNSMessageBytes = dns.MaxMsgSize + + // maxProtobufPayloadBytes accounts for protobuf overhead. + // Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total + maxProtobufPayloadBytes = maxDNSMessageBytes + 4 + + // DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection. + DefaultGRPCMaxStreams = 256 + + // DefaultGRPCMaxConnections is the default maximum number of concurrent connections. + DefaultGRPCMaxConnections = 200 +) + // ServergRPC represents an instance of a DNS-over-gRPC server. type ServergRPC struct { *Server *pb.UnimplementedDnsServiceServer - grpcServer *grpc.Server - listenAddr net.Addr - tlsConfig *tls.Config + grpcServer *grpc.Server + listenAddr net.Addr + tlsConfig *tls.Config + maxStreams int + maxConnections int } // NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it. @@ -49,7 +67,22 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) { tlsConfig.NextProtos = []string{"h2"} } - return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil + maxStreams := DefaultGRPCMaxStreams + if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil { + maxStreams = *group[0].MaxGRPCStreams + } + + maxConnections := DefaultGRPCMaxConnections + if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil { + maxConnections = *group[0].MaxGRPCConnections + } + + return &ServergRPC{ + Server: s, + tlsConfig: tlsConfig, + maxStreams: maxStreams, + maxConnections: maxConnections, + }, nil } // Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface @@ -61,21 +94,36 @@ func (s *ServergRPC) Serve(l net.Listener) error { s.listenAddr = l.Addr() s.m.Unlock() + serverOpts := []grpc.ServerOption{ + grpc.MaxRecvMsgSize(maxProtobufPayloadBytes), + grpc.MaxSendMsgSize(maxProtobufPayloadBytes), + } + + // Only set MaxConcurrentStreams if not unbounded (0) + if s.maxStreams > 0 { + serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams))) + } + if s.Tracer() != nil { onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool { return parentSpanCtx != nil } - intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent)) - s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept)) - } else { - s.grpcServer = grpc.NewServer() + serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent)))) } + s.grpcServer = grpc.NewServer(serverOpts...) + pb.RegisterDnsServiceServer(s.grpcServer, s) if s.tlsConfig != nil { l = tls.NewListener(l, s.tlsConfig) } + + // Wrap listener to limit concurrent connections + if s.maxConnections > 0 { + l = netutil.LimitListener(l, s.maxConnections) + } + return s.grpcServer.Serve(l) } @@ -122,6 +170,9 @@ func (s *ServergRPC) Stop() (err error) { // any normal server. We use a custom responseWriter to pick up the bytes we need to write // back to the client as a protobuf. func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) { + if len(in.GetMsg()) > dns.MaxMsgSize { + return nil, fmt.Errorf("dns message exceeds size limit: %d", len(in.GetMsg())) + } msg := new(dns.Msg) err := msg.Unpack(in.GetMsg()) if err != nil { diff --git a/core/dnsserver/server_grpc_test.go b/core/dnsserver/server_grpc_test.go index 5dc72b55b..bfb095cfc 100644 --- a/core/dnsserver/server_grpc_test.go +++ b/core/dnsserver/server_grpc_test.go @@ -69,6 +69,61 @@ func TestNewServergRPCWithTLS(t *testing.T) { } } +func TestNewServergRPCWithCustomLimits(t *testing.T) { + config := testConfig("grpc", testPlugin{}) + maxStreams := 50 + maxConnections := 100 + config.MaxGRPCStreams = &maxStreams + config.MaxGRPCConnections = &maxConnections + + server, err := NewServergRPC("127.0.0.1:0", []*Config{config}) + if err != nil { + t.Fatalf("NewServergRPC() with custom limits failed: %v", err) + } + + if server.maxStreams != maxStreams { + t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams) + } + + if server.maxConnections != maxConnections { + t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections) + } +} + +func TestNewServergRPCDefaults(t *testing.T) { + server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})}) + if err != nil { + t.Fatalf("NewServergRPC() failed: %v", err) + } + + if server.maxStreams != DefaultGRPCMaxStreams { + t.Errorf("Expected default maxStreams = %d, got %d", DefaultGRPCMaxStreams, server.maxStreams) + } + + if server.maxConnections != DefaultGRPCMaxConnections { + t.Errorf("Expected default maxConnections = %d, got %d", DefaultGRPCMaxConnections, server.maxConnections) + } +} + +func TestNewServergRPCZeroLimits(t *testing.T) { + config := testConfig("grpc", testPlugin{}) + zero := 0 + config.MaxGRPCStreams = &zero + config.MaxGRPCConnections = &zero + + server, err := NewServergRPC("127.0.0.1:0", []*Config{config}) + if err != nil { + t.Fatalf("NewServergRPC() with zero limits failed: %v", err) + } + + if server.maxStreams != 0 { + t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams) + } + if server.maxConnections != 0 { + t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections) + } +} + func TestServergRPC_Listen(t *testing.T) { server, err := NewServergRPC(transport.GRPC+"://127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})}) if err != nil { @@ -328,3 +383,67 @@ func TestGRPCResponse_WriteInvalidMessage(t *testing.T) { t.Error("Write() should return error for invalid DNS message") } } + +func TestServergRPC_Query_LargeMessage(t *testing.T) { + server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})}) + if err != nil { + t.Fatalf("NewServergRPC failed: %v", err) + } + + // Create oversized message (> dns.MaxMsgSize = 65535) + oversizedMsg := make([]byte, dns.MaxMsgSize+1) + dnsPacket := &pb.DnsPacket{Msg: oversizedMsg} + + tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + p := &peer.Peer{Addr: tcpAddr} + ctx := peer.NewContext(context.Background(), p) + + server.listenAddr = tcpAddr + + _, err = server.Query(ctx, dnsPacket) + if err == nil { + t.Error("Expected error for oversized message") + } + + expectedError := "dns message exceeds size limit: 65536" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error()) + } +} + +func TestServergRPC_Query_MaxSizeMessage(t *testing.T) { + server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})}) + if err != nil { + t.Fatalf("NewServergRPC failed: %v", err) + } + + // Create message exactly at the size limit (dns.MaxMsgSize = 65535) + msg := new(dns.Msg) + msg.SetQuestion("example.com.", dns.TypeA) + packed, err := msg.Pack() + if err != nil { + t.Fatalf("Failed to pack DNS message: %v", err) + } + + // Pad the message to exactly max size + if len(packed) > dns.MaxMsgSize { + t.Fatalf("Packed message is already larger than max size: %d", len(packed)) + } + + maxSizeMsg := make([]byte, dns.MaxMsgSize) + copy(maxSizeMsg, packed) + + dnsPacket := &pb.DnsPacket{Msg: maxSizeMsg} + + tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + p := &peer.Peer{Addr: tcpAddr} + ctx := peer.NewContext(context.Background(), p) + + server.listenAddr = tcpAddr + + // Should not return an error for exactly max size message + _, err = server.Query(ctx, dnsPacket) + if err != nil { + t.Errorf("Expected no error for max size message, got: %v", err) + } +} diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index cf84e8c35..0d522a051 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -18,15 +18,23 @@ import ( "github.com/coredns/coredns/plugin/pkg/response" "github.com/coredns/coredns/plugin/pkg/reuseport" "github.com/coredns/coredns/plugin/pkg/transport" + + "golang.org/x/net/netutil" +) + +const ( + // DefaultHTTPSMaxConnections is the default maximum number of concurrent connections. + DefaultHTTPSMaxConnections = 200 ) // ServerHTTPS represents an instance of a DNS-over-HTTPS server. type ServerHTTPS struct { *Server - httpsServer *http.Server - listenAddr net.Addr - tlsConfig *tls.Config - validRequest func(*http.Request) bool + httpsServer *http.Server + listenAddr net.Addr + tlsConfig *tls.Config + validRequest func(*http.Request) bool + maxConnections int } // loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server @@ -81,8 +89,17 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { IdleTimeout: s.IdleTimeout, ErrorLog: stdlog.New(&loggerAdapter{}, "", 0), } + maxConnections := DefaultHTTPSMaxConnections + if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil { + maxConnections = *group[0].MaxHTTPSConnections + } + sh := &ServerHTTPS{ - Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator, + Server: s, + tlsConfig: tlsConfig, + httpsServer: srv, + validRequest: validator, + maxConnections: maxConnections, } sh.httpsServer.Handler = sh @@ -98,9 +115,15 @@ func (s *ServerHTTPS) Serve(l net.Listener) error { s.listenAddr = l.Addr() s.m.Unlock() + // Wrap listener to limit concurrent connections (before TLS) + if s.maxConnections > 0 { + l = netutil.LimitListener(l, s.maxConnections) + } + if s.tlsConfig != nil { l = tls.NewListener(l, s.tlsConfig) } + return s.httpsServer.Serve(l) } diff --git a/core/dnsserver/server_https3.go b/core/dnsserver/server_https3.go index d6d1d85b8..ea36abbda 100644 --- a/core/dnsserver/server_https3.go +++ b/core/dnsserver/server_https3.go @@ -21,6 +21,11 @@ import ( "github.com/quic-go/quic-go/http3" ) +const ( + // DefaultHTTPS3MaxStreams is the default maximum number of concurrent QUIC streams per connection. + DefaultHTTPS3MaxStreams = 256 +) + // ServerHTTPS3 represents a DNS-over-HTTP/3 server. type ServerHTTPS3 struct { *Server @@ -29,6 +34,7 @@ type ServerHTTPS3 struct { tlsConfig *tls.Config quicConfig *quic.Config validRequest func(*http.Request) bool + maxStreams int } // NewServerHTTPS3 builds the HTTP/3 (DoH3) server. @@ -63,11 +69,20 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) { validator = func(r *http.Request) bool { return r.URL.Path == doh.Path } } - // QUIC transport config + maxStreams := DefaultHTTPS3MaxStreams + if len(group) > 0 && group[0] != nil && group[0].MaxHTTPS3Streams != nil { + maxStreams = *group[0].MaxHTTPS3Streams + } + + // QUIC transport config with stream limits (0 means use QUIC default) qconf := &quic.Config{ MaxIdleTimeout: s.IdleTimeout, Allow0RTT: true, } + if maxStreams > 0 { + qconf.MaxIncomingStreams = int64(maxStreams) + qconf.MaxIncomingUniStreams = int64(maxStreams) + } h3srv := &http3.Server{ Handler: nil, // set after constructing ServerHTTPS3 @@ -83,6 +98,7 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) { httpsServer: h3srv, quicConfig: qconf, validRequest: validator, + maxStreams: maxStreams, } h3srv.Handler = sh diff --git a/core/dnsserver/server_https3_test.go b/core/dnsserver/server_https3_test.go index c3a6c3184..bd460c889 100644 --- a/core/dnsserver/server_https3_test.go +++ b/core/dnsserver/server_https3_test.go @@ -60,3 +60,82 @@ func TestCustomHTTP3RequestValidator(t *testing.T) { }) } } + +func TestNewServerHTTPS3WithCustomLimits(t *testing.T) { + maxStreams := 50 + c := Config{ + Zone: "example.com.", + Transport: "https3", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + MaxHTTPS3Streams: &maxStreams, + } + + server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Fatalf("NewServerHTTPS3() with custom limits failed: %v", err) + } + + if server.maxStreams != maxStreams { + t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams) + } + + expectedMaxStreams := int64(maxStreams) + if server.quicConfig.MaxIncomingStreams != expectedMaxStreams { + t.Errorf("Expected quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams) + } + + if server.quicConfig.MaxIncomingUniStreams != expectedMaxStreams { + t.Errorf("Expected quicConfig.MaxIncomingUniStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingUniStreams) + } +} + +func TestNewServerHTTPS3Defaults(t *testing.T) { + c := Config{ + Zone: "example.com.", + Transport: "https3", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + } + + server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Fatalf("NewServerHTTPS3() failed: %v", err) + } + + if server.maxStreams != DefaultHTTPS3MaxStreams { + t.Errorf("Expected default maxStreams = %d, got %d", DefaultHTTPS3MaxStreams, server.maxStreams) + } + + expectedMaxStreams := int64(DefaultHTTPS3MaxStreams) + if server.quicConfig.MaxIncomingStreams != expectedMaxStreams { + t.Errorf("Expected default quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams) + } +} + +func TestNewServerHTTPS3ZeroLimits(t *testing.T) { + zero := 0 + c := Config{ + Zone: "example.com.", + Transport: "https3", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + MaxHTTPS3Streams: &zero, + } + + server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Fatalf("NewServerHTTPS3() with zero limits failed: %v", err) + } + + if server.maxStreams != 0 { + t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams) + } + // When maxStreams is 0, quicConfig should not set MaxIncomingStreams (uses QUIC default) + if server.quicConfig.MaxIncomingStreams != 0 { + t.Errorf("Expected quicConfig.MaxIncomingStreams = 0 (QUIC default), got %d", server.quicConfig.MaxIncomingStreams) + } +} diff --git a/core/dnsserver/server_https_test.go b/core/dnsserver/server_https_test.go index 5d062e168..cde4d7a15 100644 --- a/core/dnsserver/server_https_test.go +++ b/core/dnsserver/server_https_test.go @@ -72,6 +72,67 @@ func TestCustomHTTPRequestValidator(t *testing.T) { } } +func TestNewServerHTTPSWithCustomLimits(t *testing.T) { + maxConnections := 100 + c := Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + MaxHTTPSConnections: &maxConnections, + } + + server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Fatalf("NewServerHTTPS() with custom limits failed: %v", err) + } + + if server.maxConnections != maxConnections { + t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections) + } +} + +func TestNewServerHTTPSDefaults(t *testing.T) { + c := Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + } + + server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Fatalf("NewServerHTTPS() failed: %v", err) + } + + if server.maxConnections != DefaultHTTPSMaxConnections { + t.Errorf("Expected default maxConnections = %d, got %d", DefaultHTTPSMaxConnections, server.maxConnections) + } +} + +func TestNewServerHTTPSZeroLimits(t *testing.T) { + zero := 0 + c := Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + MaxHTTPSConnections: &zero, + } + + server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Fatalf("NewServerHTTPS() with zero limits failed: %v", err) + } + + if server.maxConnections != 0 { + t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections) + } +} + type contextCapturingPlugin struct { capturedContext context.Context contextCancelled bool diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index b7d7fd7ff..cc07c7b05 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -156,12 +156,29 @@ func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) { return } - // Use a bounded worker pool - s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block - go func(st *quic.Stream, cn *quic.Conn) { - defer func() { <-s.streamProcessPool }() // Release worker slot - s.serveQUICStream(st, cn) - }(stream, conn) + // Use a bounded worker pool with context cancellation + select { + case s.streamProcessPool <- struct{}{}: + // Got worker slot immediately + go func(st *quic.Stream, cn *quic.Conn) { + defer func() { <-s.streamProcessPool }() // Release worker slot + s.serveQUICStream(st, cn) + }(stream, conn) + default: + // Worker pool full, check for context cancellation + go func(st *quic.Stream, cn *quic.Conn) { + select { + case s.streamProcessPool <- struct{}{}: + // Got worker slot after waiting + defer func() { <-s.streamProcessPool }() // Release worker slot + s.serveQUICStream(st, cn) + case <-conn.Context().Done(): + // Connection context was cancelled while waiting + st.Close() + return + } + }(stream, conn) + } } } diff --git a/core/dnsserver/zdirectives.go b/core/dnsserver/zdirectives.go index a237cbf34..c356740c1 100644 --- a/core/dnsserver/zdirectives.go +++ b/core/dnsserver/zdirectives.go @@ -16,6 +16,9 @@ var Directives = []string{ "cancel", "tls", "quic", + "grpc_server", + "https", + "https3", "timeouts", "multisocket", "reload", diff --git a/core/plugin/zplugin.go b/core/plugin/zplugin.go index 025c04474..d080cc5f9 100644 --- a/core/plugin/zplugin.go +++ b/core/plugin/zplugin.go @@ -27,9 +27,12 @@ import ( _ "github.com/coredns/coredns/plugin/forward" _ "github.com/coredns/coredns/plugin/geoip" _ "github.com/coredns/coredns/plugin/grpc" + _ "github.com/coredns/coredns/plugin/grpc_server" _ "github.com/coredns/coredns/plugin/header" _ "github.com/coredns/coredns/plugin/health" _ "github.com/coredns/coredns/plugin/hosts" + _ "github.com/coredns/coredns/plugin/https" + _ "github.com/coredns/coredns/plugin/https3" _ "github.com/coredns/coredns/plugin/k8s_external" _ "github.com/coredns/coredns/plugin/kubernetes" _ "github.com/coredns/coredns/plugin/loadbalance" diff --git a/plugin.cfg b/plugin.cfg index b4d3bae03..1612a9cec 100644 --- a/plugin.cfg +++ b/plugin.cfg @@ -25,6 +25,9 @@ geoip:geoip cancel:cancel tls:tls quic:quic +grpc_server:grpc_server +https:https +https3:https3 timeouts:timeouts multisocket:multisocket reload:reload diff --git a/plugin/chaos/zowners.go b/plugin/chaos/zowners.go index 419ca3cf5..b9553f3c4 100644 --- a/plugin/chaos/zowners.go +++ b/plugin/chaos/zowners.go @@ -1,4 +1,4 @@ package chaos // Owners are all GitHub handlers of all maintainers. -var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "varyoo", "ykhr53", "yongtang", "zouyee"} +var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "thevilledev", "varyoo", "ykhr53", "yongtang", "zouyee"} diff --git a/plugin/grpc_server/README.md b/plugin/grpc_server/README.md new file mode 100644 index 000000000..1a19bb1f6 --- /dev/null +++ b/plugin/grpc_server/README.md @@ -0,0 +1,51 @@ +# grpc_server + +## Name + +*grpc_server* - configures DNS-over-gRPC server options. + +## Description + +The *grpc_server* plugin allows you to configure parameters for the DNS-over-gRPC server to fine-tune the security posture and performance of the server. + +This plugin can only be used once per gRPC listener block. + +## Syntax + +```txt +grpc_server { + max_streams POSITIVE_INTEGER + max_connections POSITIVE_INTEGER +} +``` + +* `max_streams` limits the number of concurrent gRPC streams per connection. This helps prevent unbounded streams on a single connection, exhausting server resources. The default value is 256 if not specified. Set to 0 for unbounded. +* `max_connections` limits the number of concurrent TCP connections to the gRPC server. The default value is 200 if not specified. Set to 0 for unbounded. + +## Examples + +Set custom limits for maximum streams and connections: + +``` +grpc://.:8053 { + tls cert.pem key.pem + grpc_server { + max_streams 50 + max_connections 100 + } + whoami +} +``` + +Set values to 0 for unbounded, matching CoreDNS behaviour before v1.14.0: + +``` +grpc://.:8053 { + tls cert.pem key.pem + grpc_server { + max_streams 0 + max_connections 0 + } + whoami +} +``` diff --git a/plugin/grpc_server/setup.go b/plugin/grpc_server/setup.go new file mode 100644 index 000000000..0cecd7dd6 --- /dev/null +++ b/plugin/grpc_server/setup.go @@ -0,0 +1,79 @@ +package grpc_server + +import ( + "strconv" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { + caddy.RegisterPlugin("grpc_server", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + err := parseGRPCServer(c) + if err != nil { + return plugin.Error("grpc_server", err) + } + return nil +} + +func parseGRPCServer(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + // Skip the "grpc_server" directive itself + c.Next() + + // Get any arguments on the "grpc_server" line + args := c.RemainingArgs() + if len(args) > 0 { + return c.ArgErr() + } + + // Process all nested directives in the block + for c.NextBlock() { + switch c.Val() { + case "max_streams": + args := c.RemainingArgs() + if len(args) != 1 { + return c.ArgErr() + } + val, err := strconv.Atoi(args[0]) + if err != nil { + return c.Errf("invalid max_streams value '%s': %v", args[0], err) + } + if val < 0 { + return c.Errf("max_streams must be a non-negative integer: %d", val) + } + if config.MaxGRPCStreams != nil { + return c.Err("max_streams already defined for this server block") + } + config.MaxGRPCStreams = &val + case "max_connections": + args := c.RemainingArgs() + if len(args) != 1 { + return c.ArgErr() + } + val, err := strconv.Atoi(args[0]) + if err != nil { + return c.Errf("invalid max_connections value '%s': %v", args[0], err) + } + if val < 0 { + return c.Errf("max_connections must be a non-negative integer: %d", val) + } + if config.MaxGRPCConnections != nil { + return c.Err("max_connections already defined for this server block") + } + config.MaxGRPCConnections = &val + default: + return c.Errf("unknown property '%s'", c.Val()) + } + } + + return nil +} diff --git a/plugin/grpc_server/setup_test.go b/plugin/grpc_server/setup_test.go new file mode 100644 index 000000000..75fb74980 --- /dev/null +++ b/plugin/grpc_server/setup_test.go @@ -0,0 +1,169 @@ +package grpc_server + +import ( + "fmt" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedErrContent string + expectedMaxStreams *int + expectedMaxConnections *int + }{ + // Valid configurations + { + input: `grpc_server`, + shouldErr: false, + }, + { + input: `grpc_server { + }`, + shouldErr: false, + }, + { + input: `grpc_server { + max_streams 100 + }`, + shouldErr: false, + expectedMaxStreams: intPtr(100), + }, + { + input: `grpc_server { + max_connections 200 + }`, + shouldErr: false, + expectedMaxConnections: intPtr(200), + }, + { + input: `grpc_server { + max_streams 50 + max_connections 100 + }`, + shouldErr: false, + expectedMaxStreams: intPtr(50), + expectedMaxConnections: intPtr(100), + }, + // Zero values (unbounded) + { + input: `grpc_server { + max_streams 0 + }`, + shouldErr: false, + expectedMaxStreams: intPtr(0), + }, + { + input: `grpc_server { + max_connections 0 + }`, + shouldErr: false, + expectedMaxConnections: intPtr(0), + }, + // Error cases + { + input: `grpc_server { + max_streams + }`, + shouldErr: true, + expectedErrContent: "Wrong argument count", + }, + { + input: `grpc_server { + max_streams abc + }`, + shouldErr: true, + expectedErrContent: "invalid max_streams value", + }, + { + input: `grpc_server { + max_streams -1 + }`, + shouldErr: true, + expectedErrContent: "must be a non-negative integer", + }, + { + input: `grpc_server { + max_streams 100 + max_streams 200 + }`, + shouldErr: true, + expectedErrContent: "already defined", + }, + { + input: `grpc_server { + unknown_option 123 + }`, + shouldErr: true, + expectedErrContent: "unknown property", + }, + { + input: `grpc_server extra_arg`, + shouldErr: true, + expectedErrContent: "Wrong argument count", + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d (%s): Expected error but got none", i, test.input) + continue + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err) + continue + } + + if test.shouldErr && test.expectedErrContent != "" { + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v", + i, test.input, test.expectedErrContent, err) + } + continue + } + + if !test.shouldErr { + config := dnsserver.GetConfig(c) + assertIntPtrValue(t, i, test.input, "MaxGRPCStreams", config.MaxGRPCStreams, test.expectedMaxStreams) + assertIntPtrValue(t, i, test.input, "MaxGRPCConnections", config.MaxGRPCConnections, test.expectedMaxConnections) + } + } +} + +func intPtr(v int) *int { + return &v +} + +func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) { + t.Helper() + if actual == nil && expected == nil { + return + } + + if (actual == nil) != (expected == nil) { + t.Errorf("Test %d (%s): Expected %s to be %v, but got %v", + testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual)) + return + } + + if *actual != *expected { + t.Errorf("Test %d (%s): Expected %s to be %d, but got %d", + testIndex, testInput, fieldName, *expected, *actual) + } +} + +func formatNilableInt(v *int) string { + if v == nil { + return "nil" + } + return fmt.Sprintf("%d", *v) +} diff --git a/plugin/https/README.md b/plugin/https/README.md new file mode 100644 index 000000000..938c2dbd2 --- /dev/null +++ b/plugin/https/README.md @@ -0,0 +1,47 @@ +# https + +## Name + +*https* - configures DNS-over-HTTPS (DoH) server options. + +## Description + +The *https* plugin allows you to configure parameters for the DNS-over-HTTPS (DoH) server to fine-tune the security posture and performance of the server. + +This plugin can only be used once per HTTPS listener block. + +## Syntax + +```txt +https { + max_connections POSITIVE_INTEGER +} +``` + +* `max_connections` limits the number of concurrent TCP connections to the HTTPS server. The default value is 200 if not specified. Set to 0 for unbounded. + +## Examples + +Set custom limits for maximum connections: + +``` +https://.:443 { + tls cert.pem key.pem + https { + max_connections 100 + } + whoami +} +``` + +Set values to 0 for unbounded, matching CoreDNS behaviour before v1.14.0: + +``` +https://.:443 { + tls cert.pem key.pem + https { + max_connections 0 + } + whoami +} +``` diff --git a/plugin/https/setup.go b/plugin/https/setup.go new file mode 100644 index 000000000..727a37861 --- /dev/null +++ b/plugin/https/setup.go @@ -0,0 +1,63 @@ +package https + +import ( + "strconv" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { + caddy.RegisterPlugin("https", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + err := parseDOH(c) + if err != nil { + return plugin.Error("https", err) + } + return nil +} + +func parseDOH(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + // Skip the "https" directive itself + c.Next() + + // Get any arguments on the "https" line + args := c.RemainingArgs() + if len(args) > 0 { + return c.ArgErr() + } + + // Process all nested directives in the block + for c.NextBlock() { + switch c.Val() { + case "max_connections": + args := c.RemainingArgs() + if len(args) != 1 { + return c.ArgErr() + } + val, err := strconv.Atoi(args[0]) + if err != nil { + return c.Errf("invalid max_connections value '%s': %v", args[0], err) + } + if val < 0 { + return c.Errf("max_connections must be a non-negative integer: %d", val) + } + if config.MaxHTTPSConnections != nil { + return c.Err("max_connections already defined for this server block") + } + config.MaxHTTPSConnections = &val + default: + return c.Errf("unknown property '%s'", c.Val()) + } + } + + return nil +} diff --git a/plugin/https/setup_test.go b/plugin/https/setup_test.go new file mode 100644 index 000000000..cb7020a8d --- /dev/null +++ b/plugin/https/setup_test.go @@ -0,0 +1,144 @@ +package https + +import ( + "fmt" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedErrContent string + expectedMaxConnections *int + }{ + // Valid configurations + { + input: `https`, + shouldErr: false, + }, + { + input: `https { + }`, + shouldErr: false, + }, + { + input: `https { + max_connections 200 + }`, + shouldErr: false, + expectedMaxConnections: intPtr(200), + }, + // Zero values (unbounded) + { + input: `https { + max_connections 0 + }`, + shouldErr: false, + expectedMaxConnections: intPtr(0), + }, + // Error cases + { + input: `https { + max_connections + }`, + shouldErr: true, + expectedErrContent: "Wrong argument count", + }, + { + input: `https { + max_connections abc + }`, + shouldErr: true, + expectedErrContent: "invalid max_connections value", + }, + { + input: `https { + max_connections -1 + }`, + shouldErr: true, + expectedErrContent: "must be a non-negative integer", + }, + { + input: `https { + max_connections 100 + max_connections 200 + }`, + shouldErr: true, + expectedErrContent: "already defined", + }, + { + input: `https { + unknown_option 123 + }`, + shouldErr: true, + expectedErrContent: "unknown property", + }, + { + input: `https extra_arg`, + shouldErr: true, + expectedErrContent: "Wrong argument count", + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d (%s): Expected error but got none", i, test.input) + continue + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err) + continue + } + + if test.shouldErr && test.expectedErrContent != "" { + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v", + i, test.input, test.expectedErrContent, err) + } + continue + } + + if !test.shouldErr { + config := dnsserver.GetConfig(c) + assertIntPtrValue(t, i, test.input, "MaxHTTPSConnections", config.MaxHTTPSConnections, test.expectedMaxConnections) + } + } +} + +func intPtr(v int) *int { + return &v +} + +func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) { + t.Helper() + if actual == nil && expected == nil { + return + } + + if (actual == nil) != (expected == nil) { + t.Errorf("Test %d (%s): Expected %s to be %v, but got %v", + testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual)) + return + } + + if *actual != *expected { + t.Errorf("Test %d (%s): Expected %s to be %d, but got %d", + testIndex, testInput, fieldName, *expected, *actual) + } +} + +func formatNilableInt(v *int) string { + if v == nil { + return "nil" + } + return fmt.Sprintf("%d", *v) +} diff --git a/plugin/https3/README.md b/plugin/https3/README.md new file mode 100644 index 000000000..9146137e4 --- /dev/null +++ b/plugin/https3/README.md @@ -0,0 +1,47 @@ +# https3 + +## Name + +*https3* - configures DNS-over-HTTPS/3 (DoH3) server options. + +## Description + +The *https3* plugin allows you to configure parameters for the DNS-over-HTTPS/3 (DoH3) server to fine-tune the security posture and performance of the server. HTTPS/3 uses QUIC as the underlying transport. + +This plugin can only be used once per HTTPS3 listener block. + +## Syntax + +```txt +https3 { + max_streams POSITIVE_INTEGER +} +``` + +* `max_streams` limits the number of concurrent QUIC streams per connection. This helps prevent unbounded streams on a single connection, exhausting server resources. The default value is 256 if not specified. Set to 0 to use underlying QUIC transport default. + +## Examples + +Set custom limits for maximum streams: + +``` +https3://.:443 { + tls cert.pem key.pem + https3 { + max_streams 50 + } + whoami +} +``` + +Set values to 0 for QUIC transport default, matching CoreDNS behaviour before v1.14.0: + +``` +https3://.:443 { + tls cert.pem key.pem + https3 { + max_streams 0 + } + whoami +} +``` diff --git a/plugin/https3/setup.go b/plugin/https3/setup.go new file mode 100644 index 000000000..dd42b356d --- /dev/null +++ b/plugin/https3/setup.go @@ -0,0 +1,63 @@ +package https3 + +import ( + "strconv" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { + caddy.RegisterPlugin("https3", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + err := parseDOH3(c) + if err != nil { + return plugin.Error("https3", err) + } + return nil +} + +func parseDOH3(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + // Skip the "https3" directive itself + c.Next() + + // Get any arguments on the "https3" line + args := c.RemainingArgs() + if len(args) > 0 { + return c.ArgErr() + } + + // Process all nested directives in the block + for c.NextBlock() { + switch c.Val() { + case "max_streams": + args := c.RemainingArgs() + if len(args) != 1 { + return c.ArgErr() + } + val, err := strconv.Atoi(args[0]) + if err != nil { + return c.Errf("invalid max_streams value '%s': %v", args[0], err) + } + if val < 0 { + return c.Errf("max_streams must be a non-negative integer: %d", val) + } + if config.MaxHTTPS3Streams != nil { + return c.Err("max_streams already defined for this server block") + } + config.MaxHTTPS3Streams = &val + default: + return c.Errf("unknown property '%s'", c.Val()) + } + } + + return nil +} diff --git a/plugin/https3/setup_test.go b/plugin/https3/setup_test.go new file mode 100644 index 000000000..5e4c7abc3 --- /dev/null +++ b/plugin/https3/setup_test.go @@ -0,0 +1,144 @@ +package https3 + +import ( + "fmt" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedErrContent string + expectedMaxStreams *int + }{ + // Valid configurations + { + input: `https3`, + shouldErr: false, + }, + { + input: `https3 { + }`, + shouldErr: false, + }, + { + input: `https3 { + max_streams 100 + }`, + shouldErr: false, + expectedMaxStreams: intPtr(100), + }, + // Zero values (unbounded) + { + input: `https3 { + max_streams 0 + }`, + shouldErr: false, + expectedMaxStreams: intPtr(0), + }, + // Error cases + { + input: `https3 { + max_streams + }`, + shouldErr: true, + expectedErrContent: "Wrong argument count", + }, + { + input: `https3 { + max_streams abc + }`, + shouldErr: true, + expectedErrContent: "invalid max_streams value", + }, + { + input: `https3 { + max_streams -1 + }`, + shouldErr: true, + expectedErrContent: "must be a non-negative integer", + }, + { + input: `https3 { + max_streams 100 + max_streams 200 + }`, + shouldErr: true, + expectedErrContent: "already defined", + }, + { + input: `https3 { + unknown_option 123 + }`, + shouldErr: true, + expectedErrContent: "unknown property", + }, + { + input: `https3 extra_arg`, + shouldErr: true, + expectedErrContent: "Wrong argument count", + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d (%s): Expected error but got none", i, test.input) + continue + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err) + continue + } + + if test.shouldErr && test.expectedErrContent != "" { + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v", + i, test.input, test.expectedErrContent, err) + } + continue + } + + if !test.shouldErr { + config := dnsserver.GetConfig(c) + assertIntPtrValue(t, i, test.input, "MaxHTTPS3Streams", config.MaxHTTPS3Streams, test.expectedMaxStreams) + } + } +} + +func intPtr(v int) *int { + return &v +} + +func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) { + t.Helper() + if actual == nil && expected == nil { + return + } + + if (actual == nil) != (expected == nil) { + t.Errorf("Test %d (%s): Expected %s to be %v, but got %v", + testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual)) + return + } + + if *actual != *expected { + t.Errorf("Test %d (%s): Expected %s to be %d, but got %d", + testIndex, testInput, fieldName, *expected, *actual) + } +} + +func formatNilableInt(v *int) string { + if v == nil { + return "nil" + } + return fmt.Sprintf("%d", *v) +} diff --git a/test/grpc_test.go b/test/grpc_test.go index 07e7f211b..4481fe139 100644 --- a/test/grpc_test.go +++ b/test/grpc_test.go @@ -2,19 +2,40 @@ package test import ( "context" + "crypto/tls" + "net" "testing" + "time" "github.com/coredns/coredns/pb" "github.com/miekg/dns" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) +var grpcCorefile = `grpc://.:0 { + whoami +}` + +var grpcLimitCorefile = `grpc://.:0 { + grpc_server { + max_streams 2 + } + whoami +}` + +var grpcConnectionLimitCorefile = `grpc://.:0 { + tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem + grpc_server { + max_connections 2 + } + whoami +}` + func TestGrpc(t *testing.T) { - corefile := `grpc://.:0 { - whoami - }` + corefile := grpcCorefile g, _, tcp, err := CoreDNSServerAndPorts(corefile) if err != nil { @@ -53,3 +74,127 @@ func TestGrpc(t *testing.T) { t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra)) } } + +// TestGRPCWithLimits tests that the server starts and works with configured limits +func TestGRPCWithLimits(t *testing.T) { + g, _, tcp, err := CoreDNSServerAndPorts(grpcLimitCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer g.Stop() + + conn, err := grpc.NewClient(tcp, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + defer conn.Close() + + client := pb.NewDnsServiceClient(conn) + + m := new(dns.Msg) + m.SetQuestion("whoami.example.org.", dns.TypeA) + msg, _ := m.Pack() + + reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: msg}) + if err != nil { + t.Fatalf("Query failed: %s", err) + } + + d := new(dns.Msg) + if err := d.Unpack(reply.GetMsg()); err != nil { + t.Fatalf("Failed to unpack: %s", err) + } + + if d.Rcode != dns.RcodeSuccess { + t.Errorf("Expected success but got %d", d.Rcode) + } +} + +// TestGRPCConnectionLimit tests that connection limits are enforced +func TestGRPCConnectionLimit(t *testing.T) { + g, _, tcp, err := CoreDNSServerAndPorts(grpcConnectionLimitCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer g.Stop() + + const maxConns = 2 + + // Create TLS connections to hold them open + tlsConfig := &tls.Config{InsecureSkipVerify: true} + conns := make([]net.Conn, 0, maxConns+1) + defer func() { + for _, c := range conns { + c.Close() + } + }() + + // Open connections up to the limit - these should succeed + for i := range maxConns { + conn, err := tls.Dial("tcp", tcp, tlsConfig) + if err != nil { + t.Fatalf("Connection %d failed (should succeed): %v", i+1, err) + } + conns = append(conns, conn) + } + + // Try to open more connections beyond the limit - should timeout + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 100 * time.Millisecond}, + "tcp", tcp, tlsConfig, + ) + if err == nil { + conn.Close() + t.Fatal("Connection beyond limit should have timed out") + } + + // Close one connection and verify a new one can be established + conns[0].Close() + conns = conns[1:] + + time.Sleep(10 * time.Millisecond) + + conn, err = tls.Dial("tcp", tcp, tlsConfig) + if err != nil { + t.Fatalf("Connection after freeing slot failed: %v", err) + } + conns = append(conns, conn) +} + +// TestGRPCTLSWithLimits tests that gRPC with TLS starts and works with configured limits +func TestGRPCTLSWithLimits(t *testing.T) { + g, _, tcp, err := CoreDNSServerAndPorts(grpcConnectionLimitCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer g.Stop() + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + creds := credentials.NewTLS(tlsConfig) + + conn, err := grpc.NewClient(tcp, grpc.WithTransportCredentials(creds)) + if err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + defer conn.Close() + + client := pb.NewDnsServiceClient(conn) + + m := new(dns.Msg) + m.SetQuestion("whoami.example.org.", dns.TypeA) + msg, _ := m.Pack() + + reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: msg}) + if err != nil { + t.Fatalf("Query failed: %s", err) + } + + d := new(dns.Msg) + if err := d.Unpack(reply.GetMsg()); err != nil { + t.Fatalf("Failed to unpack: %s", err) + } + + if d.Rcode != dns.RcodeSuccess { + t.Errorf("Expected success but got %d", d.Rcode) + } +} diff --git a/test/https3_test.go b/test/https3_test.go new file mode 100644 index 000000000..3057a457e --- /dev/null +++ b/test/https3_test.go @@ -0,0 +1,145 @@ +package test + +import ( + "bytes" + "context" + "crypto/tls" + "io" + "net/http" + "testing" + "time" + + ctls "github.com/coredns/coredns/plugin/pkg/tls" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go/http3" +) + +var https3Corefile = `https3://.:0 { + tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem + whoami +}` + +var https3LimitCorefile = `https3://.:0 { + tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem + https3 { + max_streams 2 + } + whoami +}` + +func generateHTTPS3TLSConfig() *tls.Config { + tlsConfig, err := ctls.NewTLSConfig( + "../plugin/tls/test_cert.pem", + "../plugin/tls/test_key.pem", + "../plugin/tls/test_ca.pem") + + if err != nil { + panic(err) + } + + tlsConfig.InsecureSkipVerify = true + + return tlsConfig +} + +func TestHTTPS3(t *testing.T) { + s, udp, _, err := CoreDNSServerAndPorts(https3Corefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer s.Stop() + + // Create HTTP/3 client + transport := &http3.Transport{ + TLSClientConfig: generateHTTPS3TLSConfig(), + } + defer transport.Close() + + client := &http.Client{ + Transport: transport, + Timeout: 5 * time.Second, + } + + // Create DNS query + m := new(dns.Msg) + m.SetQuestion("whoami.example.org.", dns.TypeA) + msg, err := m.Pack() + if err != nil { + t.Fatalf("Failed to pack DNS message: %v", err) + } + + // Make DoH3 request - use UDP address for HTTP/3 + url := "https://" + convertAddress(udp) + "/dns-query" + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(msg)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/dns-message") + req.Header.Set("Accept", "application/dns-message") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + d := new(dns.Msg) + err = d.Unpack(body) + if err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if d.Rcode != dns.RcodeSuccess { + t.Errorf("Expected success but got %d", d.Rcode) + } + + if len(d.Extra) != 2 { + t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra)) + } +} + +// TestHTTPS3WithLimits tests that the server starts and works with configured limits +func TestHTTPS3WithLimits(t *testing.T) { + s, udp, _, err := CoreDNSServerAndPorts(https3LimitCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer s.Stop() + + transport := &http3.Transport{ + TLSClientConfig: generateHTTPS3TLSConfig(), + } + defer transport.Close() + + client := &http.Client{ + Transport: transport, + Timeout: 5 * time.Second, + } + + m := new(dns.Msg) + m.SetQuestion("whoami.example.org.", dns.TypeA) + msg, _ := m.Pack() + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "https://"+convertAddress(udp)+"/dns-query", bytes.NewReader(msg)) + req.Header.Set("Content-Type", "application/dns-message") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } +} diff --git a/test/https_test.go b/test/https_test.go new file mode 100644 index 000000000..2bf4940f6 --- /dev/null +++ b/test/https_test.go @@ -0,0 +1,177 @@ +package test + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/miekg/dns" +) + +var httpsCorefile = `https://.:0 { + tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem + whoami +}` + +var httpsLimitCorefile = `https://.:0 { + tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem + https { + max_connections 2 + } + whoami +}` + +func TestHTTPS(t *testing.T) { + s, _, tcp, err := CoreDNSServerAndPorts(httpsCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer s.Stop() + + // Create HTTPS client with TLS config + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + Timeout: 5 * time.Second, + } + + // Create DNS query + m := new(dns.Msg) + m.SetQuestion("whoami.example.org.", dns.TypeA) + msg, err := m.Pack() + if err != nil { + t.Fatalf("Failed to pack DNS message: %v", err) + } + + // Make DoH request + url := "https://" + tcp + "/dns-query" + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(msg)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/dns-message") + req.Header.Set("Accept", "application/dns-message") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + d := new(dns.Msg) + err = d.Unpack(body) + if err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if d.Rcode != dns.RcodeSuccess { + t.Errorf("Expected success but got %d", d.Rcode) + } + + if len(d.Extra) != 2 { + t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra)) + } +} + +// TestHTTPSWithLimits tests that the server starts and works with configured limits +func TestHTTPSWithLimits(t *testing.T) { + s, _, tcp, err := CoreDNSServerAndPorts(httpsLimitCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer s.Stop() + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + Timeout: 5 * time.Second, + } + + m := new(dns.Msg) + m.SetQuestion("whoami.example.org.", dns.TypeA) + msg, _ := m.Pack() + + req, _ := http.NewRequest(http.MethodPost, "https://"+tcp+"/dns-query", bytes.NewReader(msg)) + req.Header.Set("Content-Type", "application/dns-message") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Request failed: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } +} + +// TestHTTPSConnectionLimit tests that connection limits are enforced +func TestHTTPSConnectionLimit(t *testing.T) { + s, _, tcp, err := CoreDNSServerAndPorts(httpsLimitCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + defer s.Stop() + + const maxConns = 2 + const totalConns = 4 + + // Create raw TLS connections to hold them open + conns := make([]net.Conn, 0, totalConns) + defer func() { + for _, c := range conns { + c.Close() + } + }() + + // Open connections up to the limit - these should succeed + for i := range maxConns { + conn, err := tls.Dial("tcp", tcp, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Fatalf("Connection %d failed (should succeed): %v", i+1, err) + } + conns = append(conns, conn) + } + + // Try to open more connections beyond the limit + // The LimitListener blocks Accept() until a slot is free, so Dial with timeout should fail + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 100 * time.Millisecond}, + "tcp", tcp, + &tls.Config{InsecureSkipVerify: true}, + ) + if err == nil { + conn.Close() + t.Fatal("Connection beyond limit should have timed out") + } + + // Close one connection and verify a new one can be established + conns[0].Close() + conns = conns[1:] + + time.Sleep(10 * time.Millisecond) // Give the listener time to accept + + conn, err = tls.Dial("tcp", tcp, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Fatalf("Connection after freeing slot failed: %v", err) + } + conns = append(conns, conn) +}