mirror of
https://github.com/coredns/coredns.git
synced 2025-12-19 16:45:11 -05:00
Merge commit from fork
Add configurable resource limits to prevent potential DoS vectors via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers. New configuration plugins: - grpc_server: configure max_streams, max_connections - https: configure max_connections - https3: configure max_streams Changes: - Use netutil.LimitListener for connection limiting - Use gRPC MaxConcurrentStreams and message size limits - Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting - Set secure defaults: 256 max streams, 200 max connections - Setting any limit to 0 means unbounded/fallback to previous impl Defaults are applied automatically when plugins are omitted from config. Includes tests and integration tests. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
@@ -66,6 +66,22 @@ type Config struct {
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxQUICWorkerPoolSize *int
|
||||
|
||||
// MaxGRPCStreams defines the maximum number of concurrent streams per gRPC connection.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxGRPCStreams *int
|
||||
|
||||
// MaxGRPCConnections defines the maximum number of concurrent gRPC connections.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxGRPCConnections *int
|
||||
|
||||
// MaxHTTPSConnections defines the maximum number of concurrent HTTPS connections.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxHTTPSConnections *int
|
||||
|
||||
// MaxHTTPS3Streams defines the maximum number of concurrent QUIC streams for HTTPS3.
|
||||
// This is nil if not specified, allowing for a default to be used.
|
||||
MaxHTTPS3Streams *int
|
||||
|
||||
// Timeouts for TCP, TLS and HTTPS servers.
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
|
||||
@@ -15,17 +15,35 @@ import (
|
||||
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/netutil"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
|
||||
maxDNSMessageBytes = dns.MaxMsgSize
|
||||
|
||||
// maxProtobufPayloadBytes accounts for protobuf overhead.
|
||||
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
|
||||
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
|
||||
|
||||
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
|
||||
DefaultGRPCMaxStreams = 256
|
||||
|
||||
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
|
||||
DefaultGRPCMaxConnections = 200
|
||||
)
|
||||
|
||||
// ServergRPC represents an instance of a DNS-over-gRPC server.
|
||||
type ServergRPC struct {
|
||||
*Server
|
||||
*pb.UnimplementedDnsServiceServer
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
maxStreams int
|
||||
maxConnections int
|
||||
}
|
||||
|
||||
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
|
||||
@@ -49,7 +67,22 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
|
||||
tlsConfig.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
|
||||
maxStreams := DefaultGRPCMaxStreams
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil {
|
||||
maxStreams = *group[0].MaxGRPCStreams
|
||||
}
|
||||
|
||||
maxConnections := DefaultGRPCMaxConnections
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil {
|
||||
maxConnections = *group[0].MaxGRPCConnections
|
||||
}
|
||||
|
||||
return &ServergRPC{
|
||||
Server: s,
|
||||
tlsConfig: tlsConfig,
|
||||
maxStreams: maxStreams,
|
||||
maxConnections: maxConnections,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface
|
||||
@@ -61,21 +94,36 @@ func (s *ServergRPC) Serve(l net.Listener) error {
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
serverOpts := []grpc.ServerOption{
|
||||
grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
|
||||
grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
|
||||
}
|
||||
|
||||
// Only set MaxConcurrentStreams if not unbounded (0)
|
||||
if s.maxStreams > 0 {
|
||||
serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams)))
|
||||
}
|
||||
|
||||
if s.Tracer() != nil {
|
||||
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
|
||||
return parentSpanCtx != nil
|
||||
}
|
||||
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
|
||||
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
|
||||
} else {
|
||||
s.grpcServer = grpc.NewServer()
|
||||
serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
|
||||
}
|
||||
|
||||
s.grpcServer = grpc.NewServer(serverOpts...)
|
||||
|
||||
pb.RegisterDnsServiceServer(s.grpcServer, s)
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
|
||||
// Wrap listener to limit concurrent connections
|
||||
if s.maxConnections > 0 {
|
||||
l = netutil.LimitListener(l, s.maxConnections)
|
||||
}
|
||||
|
||||
return s.grpcServer.Serve(l)
|
||||
}
|
||||
|
||||
@@ -122,6 +170,9 @@ func (s *ServergRPC) Stop() (err error) {
|
||||
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
|
||||
// back to the client as a protobuf.
|
||||
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
|
||||
if len(in.GetMsg()) > dns.MaxMsgSize {
|
||||
return nil, fmt.Errorf("dns message exceeds size limit: %d", len(in.GetMsg()))
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
err := msg.Unpack(in.GetMsg())
|
||||
if err != nil {
|
||||
|
||||
@@ -69,6 +69,61 @@ func TestNewServergRPCWithTLS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCWithCustomLimits(t *testing.T) {
|
||||
config := testConfig("grpc", testPlugin{})
|
||||
maxStreams := 50
|
||||
maxConnections := 100
|
||||
config.MaxGRPCStreams = &maxStreams
|
||||
config.MaxGRPCConnections = &maxConnections
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != maxStreams {
|
||||
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
if server.maxConnections != maxConnections {
|
||||
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCDefaults(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != DefaultGRPCMaxStreams {
|
||||
t.Errorf("Expected default maxStreams = %d, got %d", DefaultGRPCMaxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
if server.maxConnections != DefaultGRPCMaxConnections {
|
||||
t.Errorf("Expected default maxConnections = %d, got %d", DefaultGRPCMaxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCZeroLimits(t *testing.T) {
|
||||
config := testConfig("grpc", testPlugin{})
|
||||
zero := 0
|
||||
config.MaxGRPCStreams = &zero
|
||||
config.MaxGRPCConnections = &zero
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() with zero limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != 0 {
|
||||
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
|
||||
}
|
||||
if server.maxConnections != 0 {
|
||||
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Listen(t *testing.T) {
|
||||
server, err := NewServergRPC(transport.GRPC+"://127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
@@ -328,3 +383,67 @@ func TestGRPCResponse_WriteInvalidMessage(t *testing.T) {
|
||||
t.Error("Write() should return error for invalid DNS message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Query_LargeMessage(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC failed: %v", err)
|
||||
}
|
||||
|
||||
// Create oversized message (> dns.MaxMsgSize = 65535)
|
||||
oversizedMsg := make([]byte, dns.MaxMsgSize+1)
|
||||
dnsPacket := &pb.DnsPacket{Msg: oversizedMsg}
|
||||
|
||||
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
p := &peer.Peer{Addr: tcpAddr}
|
||||
ctx := peer.NewContext(context.Background(), p)
|
||||
|
||||
server.listenAddr = tcpAddr
|
||||
|
||||
_, err = server.Query(ctx, dnsPacket)
|
||||
if err == nil {
|
||||
t.Error("Expected error for oversized message")
|
||||
}
|
||||
|
||||
expectedError := "dns message exceeds size limit: 65536"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Query_MaxSizeMessage(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC failed: %v", err)
|
||||
}
|
||||
|
||||
// Create message exactly at the size limit (dns.MaxMsgSize = 65535)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
packed, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to pack DNS message: %v", err)
|
||||
}
|
||||
|
||||
// Pad the message to exactly max size
|
||||
if len(packed) > dns.MaxMsgSize {
|
||||
t.Fatalf("Packed message is already larger than max size: %d", len(packed))
|
||||
}
|
||||
|
||||
maxSizeMsg := make([]byte, dns.MaxMsgSize)
|
||||
copy(maxSizeMsg, packed)
|
||||
|
||||
dnsPacket := &pb.DnsPacket{Msg: maxSizeMsg}
|
||||
|
||||
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
p := &peer.Peer{Addr: tcpAddr}
|
||||
ctx := peer.NewContext(context.Background(), p)
|
||||
|
||||
server.listenAddr = tcpAddr
|
||||
|
||||
// Should not return an error for exactly max size message
|
||||
_, err = server.Query(ctx, dnsPacket)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for max size message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,15 +18,23 @@ import (
|
||||
"github.com/coredns/coredns/plugin/pkg/response"
|
||||
"github.com/coredns/coredns/plugin/pkg/reuseport"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"golang.org/x/net/netutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultHTTPSMaxConnections is the default maximum number of concurrent connections.
|
||||
DefaultHTTPSMaxConnections = 200
|
||||
)
|
||||
|
||||
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
|
||||
type ServerHTTPS struct {
|
||||
*Server
|
||||
httpsServer *http.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
validRequest func(*http.Request) bool
|
||||
httpsServer *http.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
validRequest func(*http.Request) bool
|
||||
maxConnections int
|
||||
}
|
||||
|
||||
// loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server
|
||||
@@ -81,8 +89,17 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
||||
IdleTimeout: s.IdleTimeout,
|
||||
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
|
||||
}
|
||||
maxConnections := DefaultHTTPSMaxConnections
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPSConnections != nil {
|
||||
maxConnections = *group[0].MaxHTTPSConnections
|
||||
}
|
||||
|
||||
sh := &ServerHTTPS{
|
||||
Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator,
|
||||
Server: s,
|
||||
tlsConfig: tlsConfig,
|
||||
httpsServer: srv,
|
||||
validRequest: validator,
|
||||
maxConnections: maxConnections,
|
||||
}
|
||||
sh.httpsServer.Handler = sh
|
||||
|
||||
@@ -98,9 +115,15 @@ func (s *ServerHTTPS) Serve(l net.Listener) error {
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
// Wrap listener to limit concurrent connections (before TLS)
|
||||
if s.maxConnections > 0 {
|
||||
l = netutil.LimitListener(l, s.maxConnections)
|
||||
}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
|
||||
return s.httpsServer.Serve(l)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,11 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultHTTPS3MaxStreams is the default maximum number of concurrent QUIC streams per connection.
|
||||
DefaultHTTPS3MaxStreams = 256
|
||||
)
|
||||
|
||||
// ServerHTTPS3 represents a DNS-over-HTTP/3 server.
|
||||
type ServerHTTPS3 struct {
|
||||
*Server
|
||||
@@ -29,6 +34,7 @@ type ServerHTTPS3 struct {
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
validRequest func(*http.Request) bool
|
||||
maxStreams int
|
||||
}
|
||||
|
||||
// NewServerHTTPS3 builds the HTTP/3 (DoH3) server.
|
||||
@@ -63,11 +69,20 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
|
||||
validator = func(r *http.Request) bool { return r.URL.Path == doh.Path }
|
||||
}
|
||||
|
||||
// QUIC transport config
|
||||
maxStreams := DefaultHTTPS3MaxStreams
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxHTTPS3Streams != nil {
|
||||
maxStreams = *group[0].MaxHTTPS3Streams
|
||||
}
|
||||
|
||||
// QUIC transport config with stream limits (0 means use QUIC default)
|
||||
qconf := &quic.Config{
|
||||
MaxIdleTimeout: s.IdleTimeout,
|
||||
Allow0RTT: true,
|
||||
}
|
||||
if maxStreams > 0 {
|
||||
qconf.MaxIncomingStreams = int64(maxStreams)
|
||||
qconf.MaxIncomingUniStreams = int64(maxStreams)
|
||||
}
|
||||
|
||||
h3srv := &http3.Server{
|
||||
Handler: nil, // set after constructing ServerHTTPS3
|
||||
@@ -83,6 +98,7 @@ func NewServerHTTPS3(addr string, group []*Config) (*ServerHTTPS3, error) {
|
||||
httpsServer: h3srv,
|
||||
quicConfig: qconf,
|
||||
validRequest: validator,
|
||||
maxStreams: maxStreams,
|
||||
}
|
||||
|
||||
h3srv.Handler = sh
|
||||
|
||||
@@ -60,3 +60,82 @@ func TestCustomHTTP3RequestValidator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPS3WithCustomLimits(t *testing.T) {
|
||||
maxStreams := 50
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https3",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPS3Streams: &maxStreams,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS3() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != maxStreams {
|
||||
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
expectedMaxStreams := int64(maxStreams)
|
||||
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
|
||||
if server.quicConfig.MaxIncomingUniStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingUniStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingUniStreams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPS3Defaults(t *testing.T) {
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https3",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS3() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != DefaultHTTPS3MaxStreams {
|
||||
t.Errorf("Expected default maxStreams = %d, got %d", DefaultHTTPS3MaxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
expectedMaxStreams := int64(DefaultHTTPS3MaxStreams)
|
||||
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected default quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPS3ZeroLimits(t *testing.T) {
|
||||
zero := 0
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https3",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPS3Streams: &zero,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS3("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS3() with zero limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != 0 {
|
||||
t.Errorf("Expected maxStreams = 0, got %d", server.maxStreams)
|
||||
}
|
||||
// When maxStreams is 0, quicConfig should not set MaxIncomingStreams (uses QUIC default)
|
||||
if server.quicConfig.MaxIncomingStreams != 0 {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingStreams = 0 (QUIC default), got %d", server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,67 @@ func TestCustomHTTPRequestValidator(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPSWithCustomLimits(t *testing.T) {
|
||||
maxConnections := 100
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPSConnections: &maxConnections,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxConnections != maxConnections {
|
||||
t.Errorf("Expected maxConnections = %d, got %d", maxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPSDefaults(t *testing.T) {
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxConnections != DefaultHTTPSMaxConnections {
|
||||
t.Errorf("Expected default maxConnections = %d, got %d", DefaultHTTPSMaxConnections, server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerHTTPSZeroLimits(t *testing.T) {
|
||||
zero := 0
|
||||
c := Config{
|
||||
Zone: "example.com.",
|
||||
Transport: "https",
|
||||
TLSConfig: &tls.Config{},
|
||||
ListenHosts: []string{"127.0.0.1"},
|
||||
Port: "443",
|
||||
MaxHTTPSConnections: &zero,
|
||||
}
|
||||
|
||||
server, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerHTTPS() with zero limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxConnections != 0 {
|
||||
t.Errorf("Expected maxConnections = 0, got %d", server.maxConnections)
|
||||
}
|
||||
}
|
||||
|
||||
type contextCapturingPlugin struct {
|
||||
capturedContext context.Context
|
||||
contextCancelled bool
|
||||
|
||||
@@ -156,12 +156,29 @@ func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Use a bounded worker pool
|
||||
s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block
|
||||
go func(st *quic.Stream, cn *quic.Conn) {
|
||||
defer func() { <-s.streamProcessPool }() // Release worker slot
|
||||
s.serveQUICStream(st, cn)
|
||||
}(stream, conn)
|
||||
// Use a bounded worker pool with context cancellation
|
||||
select {
|
||||
case s.streamProcessPool <- struct{}{}:
|
||||
// Got worker slot immediately
|
||||
go func(st *quic.Stream, cn *quic.Conn) {
|
||||
defer func() { <-s.streamProcessPool }() // Release worker slot
|
||||
s.serveQUICStream(st, cn)
|
||||
}(stream, conn)
|
||||
default:
|
||||
// Worker pool full, check for context cancellation
|
||||
go func(st *quic.Stream, cn *quic.Conn) {
|
||||
select {
|
||||
case s.streamProcessPool <- struct{}{}:
|
||||
// Got worker slot after waiting
|
||||
defer func() { <-s.streamProcessPool }() // Release worker slot
|
||||
s.serveQUICStream(st, cn)
|
||||
case <-conn.Context().Done():
|
||||
// Connection context was cancelled while waiting
|
||||
st.Close()
|
||||
return
|
||||
}
|
||||
}(stream, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ var Directives = []string{
|
||||
"cancel",
|
||||
"tls",
|
||||
"quic",
|
||||
"grpc_server",
|
||||
"https",
|
||||
"https3",
|
||||
"timeouts",
|
||||
"multisocket",
|
||||
"reload",
|
||||
|
||||
@@ -27,9 +27,12 @@ import (
|
||||
_ "github.com/coredns/coredns/plugin/forward"
|
||||
_ "github.com/coredns/coredns/plugin/geoip"
|
||||
_ "github.com/coredns/coredns/plugin/grpc"
|
||||
_ "github.com/coredns/coredns/plugin/grpc_server"
|
||||
_ "github.com/coredns/coredns/plugin/header"
|
||||
_ "github.com/coredns/coredns/plugin/health"
|
||||
_ "github.com/coredns/coredns/plugin/hosts"
|
||||
_ "github.com/coredns/coredns/plugin/https"
|
||||
_ "github.com/coredns/coredns/plugin/https3"
|
||||
_ "github.com/coredns/coredns/plugin/k8s_external"
|
||||
_ "github.com/coredns/coredns/plugin/kubernetes"
|
||||
_ "github.com/coredns/coredns/plugin/loadbalance"
|
||||
|
||||
@@ -25,6 +25,9 @@ geoip:geoip
|
||||
cancel:cancel
|
||||
tls:tls
|
||||
quic:quic
|
||||
grpc_server:grpc_server
|
||||
https:https
|
||||
https3:https3
|
||||
timeouts:timeouts
|
||||
multisocket:multisocket
|
||||
reload:reload
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package chaos
|
||||
|
||||
// Owners are all GitHub handlers of all maintainers.
|
||||
var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "varyoo", "ykhr53", "yongtang", "zouyee"}
|
||||
var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "thevilledev", "varyoo", "ykhr53", "yongtang", "zouyee"}
|
||||
|
||||
51
plugin/grpc_server/README.md
Normal file
51
plugin/grpc_server/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# grpc_server
|
||||
|
||||
## Name
|
||||
|
||||
*grpc_server* - configures DNS-over-gRPC server options.
|
||||
|
||||
## Description
|
||||
|
||||
The *grpc_server* plugin allows you to configure parameters for the DNS-over-gRPC server to fine-tune the security posture and performance of the server.
|
||||
|
||||
This plugin can only be used once per gRPC listener block.
|
||||
|
||||
## Syntax
|
||||
|
||||
```txt
|
||||
grpc_server {
|
||||
max_streams POSITIVE_INTEGER
|
||||
max_connections POSITIVE_INTEGER
|
||||
}
|
||||
```
|
||||
|
||||
* `max_streams` limits the number of concurrent gRPC streams per connection. This helps prevent unbounded streams on a single connection, exhausting server resources. The default value is 256 if not specified. Set to 0 for unbounded.
|
||||
* `max_connections` limits the number of concurrent TCP connections to the gRPC server. The default value is 200 if not specified. Set to 0 for unbounded.
|
||||
|
||||
## Examples
|
||||
|
||||
Set custom limits for maximum streams and connections:
|
||||
|
||||
```
|
||||
grpc://.:8053 {
|
||||
tls cert.pem key.pem
|
||||
grpc_server {
|
||||
max_streams 50
|
||||
max_connections 100
|
||||
}
|
||||
whoami
|
||||
}
|
||||
```
|
||||
|
||||
Set values to 0 for unbounded, matching CoreDNS behaviour before v1.14.0:
|
||||
|
||||
```
|
||||
grpc://.:8053 {
|
||||
tls cert.pem key.pem
|
||||
grpc_server {
|
||||
max_streams 0
|
||||
max_connections 0
|
||||
}
|
||||
whoami
|
||||
}
|
||||
```
|
||||
79
plugin/grpc_server/setup.go
Normal file
79
plugin/grpc_server/setup.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package grpc_server
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("grpc_server", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
err := parseGRPCServer(c)
|
||||
if err != nil {
|
||||
return plugin.Error("grpc_server", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseGRPCServer(c *caddy.Controller) error {
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
// Skip the "grpc_server" directive itself
|
||||
c.Next()
|
||||
|
||||
// Get any arguments on the "grpc_server" line
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Process all nested directives in the block
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "max_streams":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
val, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return c.Errf("invalid max_streams value '%s': %v", args[0], err)
|
||||
}
|
||||
if val < 0 {
|
||||
return c.Errf("max_streams must be a non-negative integer: %d", val)
|
||||
}
|
||||
if config.MaxGRPCStreams != nil {
|
||||
return c.Err("max_streams already defined for this server block")
|
||||
}
|
||||
config.MaxGRPCStreams = &val
|
||||
case "max_connections":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
val, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return c.Errf("invalid max_connections value '%s': %v", args[0], err)
|
||||
}
|
||||
if val < 0 {
|
||||
return c.Errf("max_connections must be a non-negative integer: %d", val)
|
||||
}
|
||||
if config.MaxGRPCConnections != nil {
|
||||
return c.Err("max_connections already defined for this server block")
|
||||
}
|
||||
config.MaxGRPCConnections = &val
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
169
plugin/grpc_server/setup_test.go
Normal file
169
plugin/grpc_server/setup_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package grpc_server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedErrContent string
|
||||
expectedMaxStreams *int
|
||||
expectedMaxConnections *int
|
||||
}{
|
||||
// Valid configurations
|
||||
{
|
||||
input: `grpc_server`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams 100
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxStreams: intPtr(100),
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_connections 200
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxConnections: intPtr(200),
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams 50
|
||||
max_connections 100
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxStreams: intPtr(50),
|
||||
expectedMaxConnections: intPtr(100),
|
||||
},
|
||||
// Zero values (unbounded)
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams 0
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxStreams: intPtr(0),
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_connections 0
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxConnections: intPtr(0),
|
||||
},
|
||||
// Error cases
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "Wrong argument count",
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams abc
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "invalid max_streams value",
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams -1
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "must be a non-negative integer",
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
max_streams 100
|
||||
max_streams 200
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "already defined",
|
||||
},
|
||||
{
|
||||
input: `grpc_server {
|
||||
unknown_option 123
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "unknown property",
|
||||
},
|
||||
{
|
||||
input: `grpc_server extra_arg`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "Wrong argument count",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.input)
|
||||
err := setup(c)
|
||||
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d (%s): Expected error but got none", i, test.input)
|
||||
continue
|
||||
}
|
||||
|
||||
if !test.shouldErr && err != nil {
|
||||
t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if test.shouldErr && test.expectedErrContent != "" {
|
||||
if !strings.Contains(err.Error(), test.expectedErrContent) {
|
||||
t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v",
|
||||
i, test.input, test.expectedErrContent, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !test.shouldErr {
|
||||
config := dnsserver.GetConfig(c)
|
||||
assertIntPtrValue(t, i, test.input, "MaxGRPCStreams", config.MaxGRPCStreams, test.expectedMaxStreams)
|
||||
assertIntPtrValue(t, i, test.input, "MaxGRPCConnections", config.MaxGRPCConnections, test.expectedMaxConnections)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func intPtr(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
||||
func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) {
|
||||
t.Helper()
|
||||
if actual == nil && expected == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if (actual == nil) != (expected == nil) {
|
||||
t.Errorf("Test %d (%s): Expected %s to be %v, but got %v",
|
||||
testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual))
|
||||
return
|
||||
}
|
||||
|
||||
if *actual != *expected {
|
||||
t.Errorf("Test %d (%s): Expected %s to be %d, but got %d",
|
||||
testIndex, testInput, fieldName, *expected, *actual)
|
||||
}
|
||||
}
|
||||
|
||||
func formatNilableInt(v *int) string {
|
||||
if v == nil {
|
||||
return "nil"
|
||||
}
|
||||
return fmt.Sprintf("%d", *v)
|
||||
}
|
||||
47
plugin/https/README.md
Normal file
47
plugin/https/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# https
|
||||
|
||||
## Name
|
||||
|
||||
*https* - configures DNS-over-HTTPS (DoH) server options.
|
||||
|
||||
## Description
|
||||
|
||||
The *https* plugin allows you to configure parameters for the DNS-over-HTTPS (DoH) server to fine-tune the security posture and performance of the server.
|
||||
|
||||
This plugin can only be used once per HTTPS listener block.
|
||||
|
||||
## Syntax
|
||||
|
||||
```txt
|
||||
https {
|
||||
max_connections POSITIVE_INTEGER
|
||||
}
|
||||
```
|
||||
|
||||
* `max_connections` limits the number of concurrent TCP connections to the HTTPS server. The default value is 200 if not specified. Set to 0 for unbounded.
|
||||
|
||||
## Examples
|
||||
|
||||
Set custom limits for maximum connections:
|
||||
|
||||
```
|
||||
https://.:443 {
|
||||
tls cert.pem key.pem
|
||||
https {
|
||||
max_connections 100
|
||||
}
|
||||
whoami
|
||||
}
|
||||
```
|
||||
|
||||
Set values to 0 for unbounded, matching CoreDNS behaviour before v1.14.0:
|
||||
|
||||
```
|
||||
https://.:443 {
|
||||
tls cert.pem key.pem
|
||||
https {
|
||||
max_connections 0
|
||||
}
|
||||
whoami
|
||||
}
|
||||
```
|
||||
63
plugin/https/setup.go
Normal file
63
plugin/https/setup.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("https", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
err := parseDOH(c)
|
||||
if err != nil {
|
||||
return plugin.Error("https", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDOH(c *caddy.Controller) error {
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
// Skip the "https" directive itself
|
||||
c.Next()
|
||||
|
||||
// Get any arguments on the "https" line
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Process all nested directives in the block
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "max_connections":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
val, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return c.Errf("invalid max_connections value '%s': %v", args[0], err)
|
||||
}
|
||||
if val < 0 {
|
||||
return c.Errf("max_connections must be a non-negative integer: %d", val)
|
||||
}
|
||||
if config.MaxHTTPSConnections != nil {
|
||||
return c.Err("max_connections already defined for this server block")
|
||||
}
|
||||
config.MaxHTTPSConnections = &val
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
144
plugin/https/setup_test.go
Normal file
144
plugin/https/setup_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package https
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedErrContent string
|
||||
expectedMaxConnections *int
|
||||
}{
|
||||
// Valid configurations
|
||||
{
|
||||
input: `https`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: `https {
|
||||
}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: `https {
|
||||
max_connections 200
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxConnections: intPtr(200),
|
||||
},
|
||||
// Zero values (unbounded)
|
||||
{
|
||||
input: `https {
|
||||
max_connections 0
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxConnections: intPtr(0),
|
||||
},
|
||||
// Error cases
|
||||
{
|
||||
input: `https {
|
||||
max_connections
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "Wrong argument count",
|
||||
},
|
||||
{
|
||||
input: `https {
|
||||
max_connections abc
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "invalid max_connections value",
|
||||
},
|
||||
{
|
||||
input: `https {
|
||||
max_connections -1
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "must be a non-negative integer",
|
||||
},
|
||||
{
|
||||
input: `https {
|
||||
max_connections 100
|
||||
max_connections 200
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "already defined",
|
||||
},
|
||||
{
|
||||
input: `https {
|
||||
unknown_option 123
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "unknown property",
|
||||
},
|
||||
{
|
||||
input: `https extra_arg`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "Wrong argument count",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.input)
|
||||
err := setup(c)
|
||||
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d (%s): Expected error but got none", i, test.input)
|
||||
continue
|
||||
}
|
||||
|
||||
if !test.shouldErr && err != nil {
|
||||
t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if test.shouldErr && test.expectedErrContent != "" {
|
||||
if !strings.Contains(err.Error(), test.expectedErrContent) {
|
||||
t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v",
|
||||
i, test.input, test.expectedErrContent, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !test.shouldErr {
|
||||
config := dnsserver.GetConfig(c)
|
||||
assertIntPtrValue(t, i, test.input, "MaxHTTPSConnections", config.MaxHTTPSConnections, test.expectedMaxConnections)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func intPtr(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
||||
func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) {
|
||||
t.Helper()
|
||||
if actual == nil && expected == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if (actual == nil) != (expected == nil) {
|
||||
t.Errorf("Test %d (%s): Expected %s to be %v, but got %v",
|
||||
testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual))
|
||||
return
|
||||
}
|
||||
|
||||
if *actual != *expected {
|
||||
t.Errorf("Test %d (%s): Expected %s to be %d, but got %d",
|
||||
testIndex, testInput, fieldName, *expected, *actual)
|
||||
}
|
||||
}
|
||||
|
||||
func formatNilableInt(v *int) string {
|
||||
if v == nil {
|
||||
return "nil"
|
||||
}
|
||||
return fmt.Sprintf("%d", *v)
|
||||
}
|
||||
47
plugin/https3/README.md
Normal file
47
plugin/https3/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# https3
|
||||
|
||||
## Name
|
||||
|
||||
*https3* - configures DNS-over-HTTPS/3 (DoH3) server options.
|
||||
|
||||
## Description
|
||||
|
||||
The *https3* plugin allows you to configure parameters for the DNS-over-HTTPS/3 (DoH3) server to fine-tune the security posture and performance of the server. HTTPS/3 uses QUIC as the underlying transport.
|
||||
|
||||
This plugin can only be used once per HTTPS3 listener block.
|
||||
|
||||
## Syntax
|
||||
|
||||
```txt
|
||||
https3 {
|
||||
max_streams POSITIVE_INTEGER
|
||||
}
|
||||
```
|
||||
|
||||
* `max_streams` limits the number of concurrent QUIC streams per connection. This helps prevent unbounded streams on a single connection, exhausting server resources. The default value is 256 if not specified. Set to 0 to use underlying QUIC transport default.
|
||||
|
||||
## Examples
|
||||
|
||||
Set custom limits for maximum streams:
|
||||
|
||||
```
|
||||
https3://.:443 {
|
||||
tls cert.pem key.pem
|
||||
https3 {
|
||||
max_streams 50
|
||||
}
|
||||
whoami
|
||||
}
|
||||
```
|
||||
|
||||
Set values to 0 for QUIC transport default, matching CoreDNS behaviour before v1.14.0:
|
||||
|
||||
```
|
||||
https3://.:443 {
|
||||
tls cert.pem key.pem
|
||||
https3 {
|
||||
max_streams 0
|
||||
}
|
||||
whoami
|
||||
}
|
||||
```
|
||||
63
plugin/https3/setup.go
Normal file
63
plugin/https3/setup.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package https3
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("https3", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
err := parseDOH3(c)
|
||||
if err != nil {
|
||||
return plugin.Error("https3", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDOH3(c *caddy.Controller) error {
|
||||
config := dnsserver.GetConfig(c)
|
||||
|
||||
// Skip the "https3" directive itself
|
||||
c.Next()
|
||||
|
||||
// Get any arguments on the "https3" line
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Process all nested directives in the block
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "max_streams":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
val, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return c.Errf("invalid max_streams value '%s': %v", args[0], err)
|
||||
}
|
||||
if val < 0 {
|
||||
return c.Errf("max_streams must be a non-negative integer: %d", val)
|
||||
}
|
||||
if config.MaxHTTPS3Streams != nil {
|
||||
return c.Err("max_streams already defined for this server block")
|
||||
}
|
||||
config.MaxHTTPS3Streams = &val
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
144
plugin/https3/setup_test.go
Normal file
144
plugin/https3/setup_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package https3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedErrContent string
|
||||
expectedMaxStreams *int
|
||||
}{
|
||||
// Valid configurations
|
||||
{
|
||||
input: `https3`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: `https3 {
|
||||
}`,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
input: `https3 {
|
||||
max_streams 100
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxStreams: intPtr(100),
|
||||
},
|
||||
// Zero values (unbounded)
|
||||
{
|
||||
input: `https3 {
|
||||
max_streams 0
|
||||
}`,
|
||||
shouldErr: false,
|
||||
expectedMaxStreams: intPtr(0),
|
||||
},
|
||||
// Error cases
|
||||
{
|
||||
input: `https3 {
|
||||
max_streams
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "Wrong argument count",
|
||||
},
|
||||
{
|
||||
input: `https3 {
|
||||
max_streams abc
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "invalid max_streams value",
|
||||
},
|
||||
{
|
||||
input: `https3 {
|
||||
max_streams -1
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "must be a non-negative integer",
|
||||
},
|
||||
{
|
||||
input: `https3 {
|
||||
max_streams 100
|
||||
max_streams 200
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "already defined",
|
||||
},
|
||||
{
|
||||
input: `https3 {
|
||||
unknown_option 123
|
||||
}`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "unknown property",
|
||||
},
|
||||
{
|
||||
input: `https3 extra_arg`,
|
||||
shouldErr: true,
|
||||
expectedErrContent: "Wrong argument count",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.input)
|
||||
err := setup(c)
|
||||
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d (%s): Expected error but got none", i, test.input)
|
||||
continue
|
||||
}
|
||||
|
||||
if !test.shouldErr && err != nil {
|
||||
t.Errorf("Test %d (%s): Expected no error but got: %v", i, test.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if test.shouldErr && test.expectedErrContent != "" {
|
||||
if !strings.Contains(err.Error(), test.expectedErrContent) {
|
||||
t.Errorf("Test %d (%s): Expected error containing '%s' but got: %v",
|
||||
i, test.input, test.expectedErrContent, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !test.shouldErr {
|
||||
config := dnsserver.GetConfig(c)
|
||||
assertIntPtrValue(t, i, test.input, "MaxHTTPS3Streams", config.MaxHTTPS3Streams, test.expectedMaxStreams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func intPtr(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
||||
func assertIntPtrValue(t *testing.T, testIndex int, testInput, fieldName string, actual, expected *int) {
|
||||
t.Helper()
|
||||
if actual == nil && expected == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if (actual == nil) != (expected == nil) {
|
||||
t.Errorf("Test %d (%s): Expected %s to be %v, but got %v",
|
||||
testIndex, testInput, fieldName, formatNilableInt(expected), formatNilableInt(actual))
|
||||
return
|
||||
}
|
||||
|
||||
if *actual != *expected {
|
||||
t.Errorf("Test %d (%s): Expected %s to be %d, but got %d",
|
||||
testIndex, testInput, fieldName, *expected, *actual)
|
||||
}
|
||||
}
|
||||
|
||||
func formatNilableInt(v *int) string {
|
||||
if v == nil {
|
||||
return "nil"
|
||||
}
|
||||
return fmt.Sprintf("%d", *v)
|
||||
}
|
||||
@@ -2,19 +2,40 @@ package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/pb"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
var grpcCorefile = `grpc://.:0 {
|
||||
whoami
|
||||
}`
|
||||
|
||||
var grpcLimitCorefile = `grpc://.:0 {
|
||||
grpc_server {
|
||||
max_streams 2
|
||||
}
|
||||
whoami
|
||||
}`
|
||||
|
||||
var grpcConnectionLimitCorefile = `grpc://.:0 {
|
||||
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
|
||||
grpc_server {
|
||||
max_connections 2
|
||||
}
|
||||
whoami
|
||||
}`
|
||||
|
||||
func TestGrpc(t *testing.T) {
|
||||
corefile := `grpc://.:0 {
|
||||
whoami
|
||||
}`
|
||||
corefile := grpcCorefile
|
||||
|
||||
g, _, tcp, err := CoreDNSServerAndPorts(corefile)
|
||||
if err != nil {
|
||||
@@ -53,3 +74,127 @@ func TestGrpc(t *testing.T) {
|
||||
t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
|
||||
}
|
||||
}
|
||||
|
||||
// TestGRPCWithLimits tests that the server starts and works with configured limits
|
||||
func TestGRPCWithLimits(t *testing.T) {
|
||||
g, _, tcp, err := CoreDNSServerAndPorts(grpcLimitCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer g.Stop()
|
||||
|
||||
conn, err := grpc.NewClient(tcp, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error but got: %s", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewDnsServiceClient(conn)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("whoami.example.org.", dns.TypeA)
|
||||
msg, _ := m.Pack()
|
||||
|
||||
reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: msg})
|
||||
if err != nil {
|
||||
t.Fatalf("Query failed: %s", err)
|
||||
}
|
||||
|
||||
d := new(dns.Msg)
|
||||
if err := d.Unpack(reply.GetMsg()); err != nil {
|
||||
t.Fatalf("Failed to unpack: %s", err)
|
||||
}
|
||||
|
||||
if d.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected success but got %d", d.Rcode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGRPCConnectionLimit tests that connection limits are enforced
|
||||
func TestGRPCConnectionLimit(t *testing.T) {
|
||||
g, _, tcp, err := CoreDNSServerAndPorts(grpcConnectionLimitCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer g.Stop()
|
||||
|
||||
const maxConns = 2
|
||||
|
||||
// Create TLS connections to hold them open
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||
conns := make([]net.Conn, 0, maxConns+1)
|
||||
defer func() {
|
||||
for _, c := range conns {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Open connections up to the limit - these should succeed
|
||||
for i := range maxConns {
|
||||
conn, err := tls.Dial("tcp", tcp, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection %d failed (should succeed): %v", i+1, err)
|
||||
}
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// Try to open more connections beyond the limit - should timeout
|
||||
conn, err := tls.DialWithDialer(
|
||||
&net.Dialer{Timeout: 100 * time.Millisecond},
|
||||
"tcp", tcp, tlsConfig,
|
||||
)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
t.Fatal("Connection beyond limit should have timed out")
|
||||
}
|
||||
|
||||
// Close one connection and verify a new one can be established
|
||||
conns[0].Close()
|
||||
conns = conns[1:]
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
conn, err = tls.Dial("tcp", tcp, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection after freeing slot failed: %v", err)
|
||||
}
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// TestGRPCTLSWithLimits tests that gRPC with TLS starts and works with configured limits
|
||||
func TestGRPCTLSWithLimits(t *testing.T) {
|
||||
g, _, tcp, err := CoreDNSServerAndPorts(grpcConnectionLimitCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer g.Stop()
|
||||
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||
creds := credentials.NewTLS(tlsConfig)
|
||||
|
||||
conn, err := grpc.NewClient(tcp, grpc.WithTransportCredentials(creds))
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error but got: %s", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewDnsServiceClient(conn)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("whoami.example.org.", dns.TypeA)
|
||||
msg, _ := m.Pack()
|
||||
|
||||
reply, err := client.Query(context.Background(), &pb.DnsPacket{Msg: msg})
|
||||
if err != nil {
|
||||
t.Fatalf("Query failed: %s", err)
|
||||
}
|
||||
|
||||
d := new(dns.Msg)
|
||||
if err := d.Unpack(reply.GetMsg()); err != nil {
|
||||
t.Fatalf("Failed to unpack: %s", err)
|
||||
}
|
||||
|
||||
if d.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected success but got %d", d.Rcode)
|
||||
}
|
||||
}
|
||||
|
||||
145
test/https3_test.go
Normal file
145
test/https3_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ctls "github.com/coredns/coredns/plugin/pkg/tls"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
var https3Corefile = `https3://.:0 {
|
||||
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
|
||||
whoami
|
||||
}`
|
||||
|
||||
var https3LimitCorefile = `https3://.:0 {
|
||||
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
|
||||
https3 {
|
||||
max_streams 2
|
||||
}
|
||||
whoami
|
||||
}`
|
||||
|
||||
func generateHTTPS3TLSConfig() *tls.Config {
|
||||
tlsConfig, err := ctls.NewTLSConfig(
|
||||
"../plugin/tls/test_cert.pem",
|
||||
"../plugin/tls/test_key.pem",
|
||||
"../plugin/tls/test_ca.pem")
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func TestHTTPS3(t *testing.T) {
|
||||
s, udp, _, err := CoreDNSServerAndPorts(https3Corefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer s.Stop()
|
||||
|
||||
// Create HTTP/3 client
|
||||
transport := &http3.Transport{
|
||||
TLSClientConfig: generateHTTPS3TLSConfig(),
|
||||
}
|
||||
defer transport.Close()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Create DNS query
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("whoami.example.org.", dns.TypeA)
|
||||
msg, err := m.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to pack DNS message: %v", err)
|
||||
}
|
||||
|
||||
// Make DoH3 request - use UDP address for HTTP/3
|
||||
url := "https://" + convertAddress(udp) + "/dns-query"
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(msg))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
req.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
}
|
||||
|
||||
d := new(dns.Msg)
|
||||
err = d.Unpack(body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unpack response: %v", err)
|
||||
}
|
||||
|
||||
if d.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected success but got %d", d.Rcode)
|
||||
}
|
||||
|
||||
if len(d.Extra) != 2 {
|
||||
t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPS3WithLimits tests that the server starts and works with configured limits
|
||||
func TestHTTPS3WithLimits(t *testing.T) {
|
||||
s, udp, _, err := CoreDNSServerAndPorts(https3LimitCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer s.Stop()
|
||||
|
||||
transport := &http3.Transport{
|
||||
TLSClientConfig: generateHTTPS3TLSConfig(),
|
||||
}
|
||||
defer transport.Close()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("whoami.example.org.", dns.TypeA)
|
||||
msg, _ := m.Pack()
|
||||
|
||||
req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "https://"+convertAddress(udp)+"/dns-query", bytes.NewReader(msg))
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
177
test/https_test.go
Normal file
177
test/https_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
var httpsCorefile = `https://.:0 {
|
||||
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
|
||||
whoami
|
||||
}`
|
||||
|
||||
var httpsLimitCorefile = `https://.:0 {
|
||||
tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
|
||||
https {
|
||||
max_connections 2
|
||||
}
|
||||
whoami
|
||||
}`
|
||||
|
||||
func TestHTTPS(t *testing.T) {
|
||||
s, _, tcp, err := CoreDNSServerAndPorts(httpsCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer s.Stop()
|
||||
|
||||
// Create HTTPS client with TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Create DNS query
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("whoami.example.org.", dns.TypeA)
|
||||
msg, err := m.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to pack DNS message: %v", err)
|
||||
}
|
||||
|
||||
// Make DoH request
|
||||
url := "https://" + tcp + "/dns-query"
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(msg))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
req.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
}
|
||||
|
||||
d := new(dns.Msg)
|
||||
err = d.Unpack(body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unpack response: %v", err)
|
||||
}
|
||||
|
||||
if d.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("Expected success but got %d", d.Rcode)
|
||||
}
|
||||
|
||||
if len(d.Extra) != 2 {
|
||||
t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPSWithLimits tests that the server starts and works with configured limits
|
||||
func TestHTTPSWithLimits(t *testing.T) {
|
||||
s, _, tcp, err := CoreDNSServerAndPorts(httpsLimitCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer s.Stop()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("whoami.example.org.", dns.TypeA)
|
||||
msg, _ := m.Pack()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://"+tcp+"/dns-query", bytes.NewReader(msg))
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTPSConnectionLimit tests that connection limits are enforced
|
||||
func TestHTTPSConnectionLimit(t *testing.T) {
|
||||
s, _, tcp, err := CoreDNSServerAndPorts(httpsLimitCorefile)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
|
||||
}
|
||||
defer s.Stop()
|
||||
|
||||
const maxConns = 2
|
||||
const totalConns = 4
|
||||
|
||||
// Create raw TLS connections to hold them open
|
||||
conns := make([]net.Conn, 0, totalConns)
|
||||
defer func() {
|
||||
for _, c := range conns {
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Open connections up to the limit - these should succeed
|
||||
for i := range maxConns {
|
||||
conn, err := tls.Dial("tcp", tcp, &tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Connection %d failed (should succeed): %v", i+1, err)
|
||||
}
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// Try to open more connections beyond the limit
|
||||
// The LimitListener blocks Accept() until a slot is free, so Dial with timeout should fail
|
||||
conn, err := tls.DialWithDialer(
|
||||
&net.Dialer{Timeout: 100 * time.Millisecond},
|
||||
"tcp", tcp,
|
||||
&tls.Config{InsecureSkipVerify: true},
|
||||
)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
t.Fatal("Connection beyond limit should have timed out")
|
||||
}
|
||||
|
||||
// Close one connection and verify a new one can be established
|
||||
conns[0].Close()
|
||||
conns = conns[1:]
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // Give the listener time to accept
|
||||
|
||||
conn, err = tls.Dial("tcp", tcp, &tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Connection after freeing slot failed: %v", err)
|
||||
}
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
Reference in New Issue
Block a user