mirror of
https://github.com/coredns/coredns.git
synced 2025-12-20 17:15:10 -05:00
Merge commit from fork
Add configurable resource limits to prevent potential DoS vectors via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers. New configuration plugins: - grpc_server: configure max_streams, max_connections - https: configure max_connections - https3: configure max_streams Changes: - Use netutil.LimitListener for connection limiting - Use gRPC MaxConcurrentStreams and message size limits - Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting - Set secure defaults: 256 max streams, 200 max connections - Setting any limit to 0 means unbounded/fallback to previous impl Defaults are applied automatically when plugins are omitted from config. Includes tests and integration tests. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
@@ -15,17 +15,35 @@ import (
|
||||
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/netutil"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
|
||||
maxDNSMessageBytes = dns.MaxMsgSize
|
||||
|
||||
// maxProtobufPayloadBytes accounts for protobuf overhead.
|
||||
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
|
||||
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
|
||||
|
||||
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
|
||||
DefaultGRPCMaxStreams = 256
|
||||
|
||||
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
|
||||
DefaultGRPCMaxConnections = 200
|
||||
)
|
||||
|
||||
// ServergRPC represents an instance of a DNS-over-gRPC server.
|
||||
type ServergRPC struct {
|
||||
*Server
|
||||
*pb.UnimplementedDnsServiceServer
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
grpcServer *grpc.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
maxStreams int
|
||||
maxConnections int
|
||||
}
|
||||
|
||||
// NewServergRPC returns a new CoreDNS GRPC server and compiles all plugin in to it.
|
||||
@@ -49,7 +67,22 @@ func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
|
||||
tlsConfig.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
return &ServergRPC{Server: s, tlsConfig: tlsConfig}, nil
|
||||
maxStreams := DefaultGRPCMaxStreams
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil {
|
||||
maxStreams = *group[0].MaxGRPCStreams
|
||||
}
|
||||
|
||||
maxConnections := DefaultGRPCMaxConnections
|
||||
if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil {
|
||||
maxConnections = *group[0].MaxGRPCConnections
|
||||
}
|
||||
|
||||
return &ServergRPC{
|
||||
Server: s,
|
||||
tlsConfig: tlsConfig,
|
||||
maxStreams: maxStreams,
|
||||
maxConnections: maxConnections,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compile-time check to ensure ServergRPC implements the caddy.GracefulServer interface
|
||||
@@ -61,21 +94,36 @@ func (s *ServergRPC) Serve(l net.Listener) error {
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
serverOpts := []grpc.ServerOption{
|
||||
grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
|
||||
grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
|
||||
}
|
||||
|
||||
// Only set MaxConcurrentStreams if not unbounded (0)
|
||||
if s.maxStreams > 0 {
|
||||
serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams)))
|
||||
}
|
||||
|
||||
if s.Tracer() != nil {
|
||||
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
|
||||
return parentSpanCtx != nil
|
||||
}
|
||||
intercept := otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
|
||||
s.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(intercept))
|
||||
} else {
|
||||
s.grpcServer = grpc.NewServer()
|
||||
serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
|
||||
}
|
||||
|
||||
s.grpcServer = grpc.NewServer(serverOpts...)
|
||||
|
||||
pb.RegisterDnsServiceServer(s.grpcServer, s)
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
|
||||
// Wrap listener to limit concurrent connections
|
||||
if s.maxConnections > 0 {
|
||||
l = netutil.LimitListener(l, s.maxConnections)
|
||||
}
|
||||
|
||||
return s.grpcServer.Serve(l)
|
||||
}
|
||||
|
||||
@@ -122,6 +170,9 @@ func (s *ServergRPC) Stop() (err error) {
|
||||
// any normal server. We use a custom responseWriter to pick up the bytes we need to write
|
||||
// back to the client as a protobuf.
|
||||
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
|
||||
if len(in.GetMsg()) > dns.MaxMsgSize {
|
||||
return nil, fmt.Errorf("dns message exceeds size limit: %d", len(in.GetMsg()))
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
err := msg.Unpack(in.GetMsg())
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user