mirror of
https://github.com/coredns/coredns.git
synced 2025-12-21 17:45:15 -05:00
Merge commit from fork
Add configurable resource limits to prevent potential DoS vectors via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers. New configuration plugins: - grpc_server: configure max_streams, max_connections - https: configure max_connections - https3: configure max_streams Changes: - Use netutil.LimitListener for connection limiting - Use gRPC MaxConcurrentStreams and message size limits - Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting - Set secure defaults: 256 max streams, 200 max connections - Setting any limit to 0 means unbounded/fallback to previous impl Defaults are applied automatically when plugins are omitted from config. Includes tests and integration tests. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
@@ -66,6 +66,22 @@ type Config struct {
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxQUICWorkerPoolSize *int
|
||||
|
||||
// MaxGRPCStreams defines the maximum number of concurrent streams per gRPC connection.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxGRPCStreams *int
|
||||
|
||||
// MaxGRPCConnections defines the maximum number of concurrent gRPC connections.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxGRPCConnections *int
|
||||
|
||||
// MaxHTTPSConnections defines the maximum number of concurrent HTTPS connections.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxHTTPSConnections *int
|
||||
|
||||
// MaxHTTPS3Streams defines the maximum number of concurrent QUIC streams for HTTPS3.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxHTTPS3Streams *int
|
||||
|
||||
// Timeouts for TCP, TLS and HTTPS servers.
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
|
||||
@@ -15,17 +15,35 @@ import (
|
||||
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/netutil"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
|
||||
maxDNSMessageBytes = dns.MaxMsgSize
|
||||
|
||||
// maxProtobufPayloadBytes accounts for protobuf overhead.
|
||||
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
|
||||
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
|
||||
|
||||
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
|
||||
DefaultGRPCMaxStreams = 256
|
||||
|
||||
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
|
||||
DefaultGRPCMaxConnections = 200
|
||||
)
|
||||
|
||||
// ServergRPC represents an instance of a DNS-over-gRPC server.
|
||||
type ServergRPC struct {
|
||||
*Server
|
||||
*pb.UnimplementedDnsServiceServer
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
maxStreams int
|
||||
maxConnections int
|
||||
}
|
||||
|
||||
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
|
||||
@@ -49,7 +67,22 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
|
||||
tlsConfig.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
|
||||
maxStreams := DefaultGRPCMaxStreams
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil {
|
||||
maxStreams = *group[0].MaxGRPCStreams
|
||||
}
|
||||
|
||||
maxConnections := DefaultGRPCMaxConnections
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil {
|
||||
maxConnections = *group[0].MaxGRPCConnections
|
||||
}
|
||||
|
||||
return &ServergRPC{
|
||||
Server: s,
|
||||
tlsConfig: tlsConfig,
|
||||
maxStreams: maxStreams,
|
||||
maxConnections: maxConnections,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface
|
||||
@@ -61,21 +94,36 @@ func (s *ServergRPC) Serve(l net.Listener) error {
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
serverOpts := []grpc.ServerOption{
|
||||
grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
|
||||
grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
|
||||
}
|
||||
|
||||
// Only set MaxConcurrentStreams if not unbounded (0)
|
||||
if s.maxStreams > 0 {
|
||||
serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams)))
|
||||
}
|
||||
|
||||
if s.Tracer() != nil {
|
||||
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
|
||||
return parentSpanCtx != nil
|
||||
}
|
||||
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
|
||||
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
|
||||
} else {
|
||||
s.grpcServer = grpc.NewServer()
|
||||
serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
|
||||
}
|
||||
|
||||
s.grpcServer = grpc.NewServer(serverOpts...)
|
||||
|
||||
pb.RegisterDnsServiceServer(s.grpcServer, s)
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
|
||||
// Wrap listener to limit concurrent connections
|
||||
if s.maxConnections > 0 {
|
||||
l = netutil.LimitListener(l, s.maxConnections)
|
||||
}
|
||||
|
||||
return s.grpcServer.Serve(l)
|
||||
}
|
||||
|
||||
@@ -122,6 +170,9 @@ func (s *ServergRPC) Stop() (err error) {
|
||||
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
|
||||
// back to the client as a protobuf.
|
||||
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
|
||||
if len(in.GetMsg()) > dns.MaxMsgSize {
|
||||
return nil, fmt.Errorf("dns message exceeds size limit: %d", len(in.GetMsg()))
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
err := msg.Unpack(in.GetMsg())
|
||||
if err != nil {
|
||||
|
||||
@@ -69,6 +69,61 @@ func TestNewServergRPCWithTLS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCWithCustomLimits(t *testing.T) {
|
||||
config := testConfig("grpc", testPlugin{})
|
||||
maxStreams := 50
|
||||
maxConnections := 100
|
||||
config.MaxGRPCStreams = &maxStreams
|
||||
config.MaxGRPCConnections = &maxConnections
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != maxStreams {
|
||||
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
if server.maxConnections != maxConnections {
|
||||
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCDefaults(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != DefaultGRPCMaxStreams {
|
||||
t.Errorf("Expected default maxStreams = %d, got %d", DefaultGRPCMaxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
if server.maxConnections != DefaultGRPCMaxConnections {
|
||||
t.Errorf("Expected default maxConnections = %d, got %d", DefaultGRPCMaxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCZeroLimits(t *testing.T) {
|
||||
config := testConfig("grpc", testPlugin{})
|
||||
zero := 0
|
||||
config.MaxGRPCStreams = &zero
|
||||
config.MaxGRPCConnections = &zero
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() with zero limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != 0 {
|
||||
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
|
||||
}
|
||||
if server.maxConnections != 0 {
|
||||
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Listen(t *testing.T) {
|
||||
server, err := NewServergRPC(transport.GRPC+"://127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
@@ -328,3 +383,67 @@ func TestGRPCResponse_WriteInvalidMessage(t *testing.T) {
|
||||
t.Error("Write() should return error for invalid DNS message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Query_LargeMessage(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC failed: %v", err)
|
||||
}
|
||||
|
||||
// Create oversized message (> dns.MaxMsgSize = 65535)
|
||||
oversizedMsg := make([]byte, dns.MaxMsgSize+1)
|
||||
dnsPacket := &pb.DnsPacket{Msg: oversizedMsg}
|
||||
|
||||
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
p := &peer.Peer{Addr: tcpAddr}
|
||||
ctx := peer.NewContext(context.Background(), p)
|
||||
|
||||
server.listenAddr = tcpAddr
|
||||
|
||||
_, err = server.Query(ctx, dnsPacket)
|
||||
if err == nil {
|
||||
t.Error("Expected error for oversized message")
|
||||
}
|
||||
|
||||
expectedError := "dns message exceeds size limit: 65536"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Query_MaxSizeMessage(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC failed: %v", err)
|
||||
}
|
||||
|
||||
// Create message exactly at the size limit (dns.MaxMsgSize = 65535)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
packed, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to pack DNS message: %v", err)
|
||||
}
|
||||
|
||||
// Pad the message to exactly max size
|
||||
if len(packed) > dns.MaxMsgSize {
|
||||
t.Fatalf("Packed message is already larger than max size: %d", len(packed))
|
||||
}
|
||||
|
||||
maxSizeMsg := make([]byte, dns.MaxMsgSize)
|
||||
copy(maxSizeMsg, packed)
|
||||
|
||||
dnsPacket := &pb.DnsPacket{Msg: maxSizeMsg}
|
||||
|
||||
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
p := &peer.Peer{Addr: tcpAddr}
|
||||
ctx := peer.NewContext(context.Background(), p)
|
||||
|
||||
server.listenAddr = tcpAddr
|
||||
|
||||
// Should not return an error for exactly max size message
|
||||
_, err = server.Query(ctx, dnsPacket)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for max size message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,15 +18,23 @@ import (
|
||||
"github.com/coredns/coredns/plugin/pkg/response"
|
||||
"github.com/coredns/coredns/plugin/pkg/reuseport"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"golang.org/x/net/netutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultHTTPSMaxConnections is the default maximum number of concurrent connections.
|
||||
DefaultHTTPSMaxConnections = 200
|
||||
)
|
||||
|
||||
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
|
||||
type ServerHTTPS struct {
|
||||
*Server
|
||||
httpsServer *http.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
validRequest func(*http.Request) bool
|
||||
httpsServer *http.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
validRequest func(*http.Request) bool
|
||||
maxConnections int
|
||||
}
|
||||
|
||||
// loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server
|
||||
@@ -81,8 +89,17 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
||||
IdleTimeout: s.IdleTimeout,
|
||||
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
|
||||
}
|
||||
maxConnections := DefaultHTTPSMaxConnections
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil {
|
||||
maxConnections = *group[0].MaxHTTPSConnections
|
||||
}
|
||||
|
||||
sh := &ServerHTTPS{
|
||||
Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator,
|
||||
Server: s,
|
||||
tlsConfig: tlsConfig,
|
||||
httpsServer: srv,
|
||||
validRequest: validator,
|
||||
maxConnections: maxConnections,
|
||||
}
|
||||
sh.httpsServer.Handler = sh
|
||||
|
||||
@@ -98,9 +115,15 @@ func (s *ServerHTTPS) Serve(l net.Listener) error {
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
// Wrap listener to limit concurrent connections (before TLS)
|
||||
if s.maxConnections > 0 {
|
||||
l = netutil.LimitListener(l, s.maxConnections)
|
||||
}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
|
||||
return s.httpsServer.Serve(l)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,11 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultHTTPS3MaxStreams is the default maximum number of concurrent QUIC streams per connection.
|
||||
DefaultHTTPS3MaxStreams = 256
|
||||
)
|
||||
|
||||
// ServerHTTPS3 represents a DNS-over-HTTP/3 server.
|
||||
type ServerHTTPS3 struct {
|
||||
*Server
|
||||
@@ -29,6 +34,7 @@ type ServerHTTPS3 struct {
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
validRequest func(*http.Request) bool
|
||||
maxStreams int
|
||||
}
|
||||
|
||||
// NewServerHTTPS3 builds the HTTP/3 (DoH3) server.
|
||||
@@ -63,11 +69,20 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
|
||||
validator = func(r *http.Request) bool { return r.URL.Path == doh.Path }
|
||||
}
|
||||
|
||||
// QUIC transport config
|
||||
maxStreams := DefaultHTTPS3MaxStreams
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPS3Streams != nil {
|
||||
maxStreams = *group[0].MaxHTTPS3Streams
|
||||
}
|
||||
|
||||
// QUIC transport config with stream limits (0 means use QUIC default)
|
||||
qconf := &quic.Config{
|
||||
MaxIdleTimeout: s.IdleTimeout,
|
||||
Allow0RTT: true,
|
||||
}
|
||||
if maxStreams > 0 {
|
||||
qconf.MaxIncomingStreams = int64(maxStreams)
|
||||
qconf.MaxIncomingUniStreams = int64(maxStreams)
|
||||
}
|
||||
|
||||
h3srv := &http3.Server{
|
||||
Handler: nil, // set after constructing ServerHTTPS3
|
||||
@@ -83,6 +98,7 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
|
||||
httpsServer: h3srv,
|
||||
quicConfig: qconf,
|
||||
validRequest: validator,
|
||||
maxStreams: maxStreams,
|
||||
}
|
||||
|
||||
h3srv.Handler = sh
|
||||
|
||||
@@ -60,3 +60,82 @@ func TestCustomHTTP3RequestValidator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPS3WithCustomLimits(t *testing.T) {
|
||||
maxStreams := 50
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https3",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPS3Streams: &maxStreams,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS3() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != maxStreams {
|
||||
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
expectedMaxStreams := int64(maxStreams)
|
||||
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
|
||||
if server.quicConfig.MaxIncomingUniStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingUniStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingUniStreams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPS3Defaults(t *testing.T) {
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https3",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS3() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != DefaultHTTPS3MaxStreams {
|
||||
t.Errorf("Expected default maxStreams = %d, got %d", DefaultHTTPS3MaxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
expectedMaxStreams := int64(DefaultHTTPS3MaxStreams)
|
||||
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected default quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPS3ZeroLimits(t *testing.T) {
|
||||
zero := 0
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https3",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPS3Streams: &zero,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS3() with zero limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != 0 {
|
||||
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
|
||||
}
|
||||
// When maxStreams is 0, quicConfig should not set MaxIncomingStreams (uses QUIC default)
|
||||
if server.quicConfig.MaxIncomingStreams != 0 {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingStreams = 0 (QUIC default), got %d", server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,67 @@ func TestCustomHTTPRequestValidator(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPSWithCustomLimits(t *testing.T) {
|
||||
maxConnections := 100
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPSConnections: &maxConnections,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxConnections != maxConnections {
|
||||
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPSDefaults(t *testing.T) {
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxConnections != DefaultHTTPSMaxConnections {
|
||||
t.Errorf("Expected default maxConnections = %d, got %d", DefaultHTTPSMaxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPSZeroLimits(t *testing.T) {
|
||||
zero := 0
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPSConnections: &zero,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS() with zero limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxConnections != 0 {
|
||||
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
type contextCapturingPlugin struct {
|
||||
capturedContext context.Context
|
||||
contextCancelled bool
|
||||
|
||||
@@ -156,12 +156,29 @@ func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use a bounded worker pool
|
||||
s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block
|
||||
go func(st *quic.Stream, cn *quic.Conn) {
|
||||
defer func() { <-s.streamProcessPool }() // Release worker slot
|
||||
s.serveQUICStream(st, cn)
|
||||
}(stream, conn)
|
||||
// Use a bounded worker pool with context cancellation
|
||||
select {
|
||||
case s.streamProcessPool <- struct{}{}:
|
||||
// Got worker slot immediately
|
||||
go func(st *quic.Stream, cn *quic.Conn) {
|
||||
defer func() { <-s.streamProcessPool }() // Release worker slot
|
||||
s.serveQUICStream(st, cn)
|
||||
}(stream, conn)
|
||||
default:
|
||||
// Worker pool full, check for context cancellation
|
||||
go func(st *quic.Stream, cn *quic.Conn) {
|
||||
select {
|
||||
case s.streamProcessPool <- struct{}{}:
|
||||
// Got worker slot after waiting
|
||||
defer func() { <-s.streamProcessPool }() // Release worker slot
|
||||
s.serveQUICStream(st, cn)
|
||||
case <-conn.Context().Done():
|
||||
// Connection context was cancelled while waiting
|
||||
st.Close()
|
||||
return
|
||||
}
|
||||
}(stream, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ var Directives = []string{
|
||||
"cancel",
|
||||
"tls",
|
||||
"quic",
|
||||
"grpc_server",
|
||||
"https",
|
||||
"https3",
|
||||
"timeouts",
|
||||
"multisocket",
|
||||
"reload",
|
||||
|
||||
@@ -27,9 +27,12 @@ import (
|
||||
_ "github.com/coredns/coredns/plugin/forward"
|
||||
_ "github.com/coredns/coredns/plugin/geoip"
|
||||
_ "github.com/coredns/coredns/plugin/grpc"
|
||||
_ "github.com/coredns/coredns/plugin/grpc_server"
|
||||
_ "github.com/coredns/coredns/plugin/header"
|
||||
_ "github.com/coredns/coredns/plugin/health"
|
||||
_ "github.com/coredns/coredns/plugin/hosts"
|
||||
_ "github.com/coredns/coredns/plugin/https"
|
||||
_ "github.com/coredns/coredns/plugin/https3"
|
||||
_ "github.com/coredns/coredns/plugin/k8s_external"
|
||||
_ "github.com/coredns/coredns/plugin/kubernetes"
|
||||
_ "github.com/coredns/coredns/plugin/loadbalance"
|
||||
|
||||
Reference in New Issue
Block a user