Files
coredns/core/dnsserver/server_grpc.go
Ville Vesilehto 0d8cbb1a6b Merge commit from fork
Add configurable resource limits to prevent potential DoS vectors
via connection/stream exhaustion on gRPC, HTTPS, and HTTPS/3 servers.

New configuration plugins:
- grpc_server: configure max_streams, max_connections
- https: configure max_connections
- https3: configure max_streams

Changes:
- Use netutil.LimitListener for connection limiting
- Use gRPC MaxConcurrentStreams and message size limits
- Add QUIC MaxIncomingStreams for HTTPS/3 stream limiting
- Set secure defaults: 256 max streams, 200 max connections
- Setting any limit to 0 means unbounded/fallback to previous impl

Defaults are applied automatically when plugins are omitted from
config.

Includes tests and integration tests.

Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
2025-12-17 19:08:59 -08:00

238 lines
6.7 KiB
Go

package dnsserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
"github.com/opentracing/opentracing-go"
"golang.org/x/net/netutil"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
const (
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
// It is an alias for dns.MaxMsgSize so the gRPC limits below derive from one value.
maxDNSMessageBytes = dns.MaxMsgSize
// maxProtobufPayloadBytes accounts for protobuf overhead.
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
// Serve passes this value to both grpc.MaxRecvMsgSize and grpc.MaxSendMsgSize.
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
// NewServergRPC uses it when the first config does not set MaxGRPCStreams.
DefaultGRPCMaxStreams = 256
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
// NewServergRPC uses it when the first config does not set MaxGRPCConnections.
DefaultGRPCMaxConnections = 200
)
// ServergRPC represents an instance of a DNS-over-gRPC server.
// It embeds the plain DNS *Server and adds the gRPC transport state and limits.
type ServergRPC struct {
// Server is the underlying DNS server returned by NewServer; Query forwards to its ServeDNS.
*Server
// NOTE(review): this embedded pointer is never assigned in NewServergRPC, so it stays nil.
// Presumably it is embedded only to satisfy the generated service interface — confirm.
*pb.UnimplementedDnsServiceServer
// grpcServer is created in Serve; it is nil until then, which Stop and Shutdown check for.
grpcServer *grpc.Server
// listenAddr is the address of the listener handed to Serve; Query reports it as the local address.
listenAddr net.Addr
// tlsConfig, when non-nil, makes Serve wrap the listener in TLS.
tlsConfig *tls.Config
// maxStreams is passed to grpc.MaxConcurrentStreams when greater than zero; otherwise the option is not set.
maxStreams int
// maxConnections, when greater than zero, makes Serve wrap the listener with netutil.LimitListener.
maxConnections int
}
// NewServergRPC returns a new CoreDNS gRPC server built on top of NewServer(addr, group).
//
// The TLS configuration is taken from the zone configs (the last one visited
// wins) and, when present, is forced to advertise "h2". Stream and connection
// limits start at DefaultGRPCMaxStreams / DefaultGRPCMaxConnections and are
// overridden by MaxGRPCStreams / MaxGRPCConnections of the first config in
// group when those are set. Any error from NewServer is returned unchanged.
func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
	base, err := NewServer(addr, group)
	if err != nil {
		return nil, err
	}

	// The *tls* plugin is responsible for rejecting conflicting TLS
	// configurations, so simply keeping the last one seen is sufficient here.
	// Open question carried over: should configs *without* TLS be an error?
	var tlsCfg *tls.Config
	for _, zoneConfigs := range base.zones {
		for _, zc := range zoneConfigs {
			tlsCfg = zc.TLSConfig
		}
	}

	// gRPC requires HTTP/2; without "h2" in NextProtos the upgrade won't happen.
	if tlsCfg != nil {
		tlsCfg.NextProtos = []string{"h2"}
	}

	srv := &ServergRPC{
		Server:         base,
		tlsConfig:      tlsCfg,
		maxStreams:     DefaultGRPCMaxStreams,
		maxConnections: DefaultGRPCMaxConnections,
	}

	// Only the first config of the group can override the default limits.
	if len(group) > 0 && group[0] != nil {
		first := group[0]
		if first.MaxGRPCStreams != nil {
			srv.maxStreams = *first.MaxGRPCStreams
		}
		if first.MaxGRPCConnections != nil {
			srv.maxConnections = *first.MaxGRPCConnections
		}
	}

	return srv, nil
}
// Compile-time assertion that *ServergRPC satisfies caddy.GracefulServer.
var _ caddy.GracefulServer = (*ServergRPC)(nil)
// Serve implements caddy.TCPServer interface.
//
// It builds a gRPC server whose receive and send message sizes are capped at
// maxProtobufPayloadBytes, registers s as the DNS service, and serves on l.
// The listener is wrapped in TLS when a TLS config is present and in a
// connection limiter when maxConnections is greater than zero. A maxStreams
// value of zero or less leaves the per-connection stream limit unset.
//
// The returned error is whatever grpc.Server.Serve returns.
func (s *ServergRPC) Serve(l net.Listener) error {
	serverOpts := []grpc.ServerOption{
		grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
		grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
	}
	// Only set MaxConcurrentStreams if not unbounded (0)
	if s.maxStreams > 0 {
		serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams)))
	}
	if s.Tracer() != nil {
		// Only create a server span when the caller already supplied a parent span.
		onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp any) bool {
			return parentSpanCtx != nil
		}
		serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
	}

	// Build and register on a local variable first so the shared field is
	// only ever written while holding s.m.
	srv := grpc.NewServer(serverOpts...)
	pb.RegisterDnsServiceServer(srv, s)

	// Stop reads s.grpcServer under s.m; publishing it under the same lock
	// removes the data race between Serve and Stop. The lock is released
	// before the blocking Serve call so Stop can still acquire it.
	s.m.Lock()
	s.listenAddr = l.Addr()
	s.grpcServer = srv
	s.m.Unlock()

	if s.tlsConfig != nil {
		l = tls.NewListener(l, s.tlsConfig)
	}
	// Wrap listener to limit concurrent connections
	if s.maxConnections > 0 {
		l = netutil.LimitListener(l, s.maxConnections)
	}
	return srv.Serve(l)
}
// ServePacket implements caddy.UDPServer interface.
// It ignores the packet connection and always returns nil.
func (s *ServergRPC) ServePacket(p net.PacketConn) error {
	return nil
}
// Listen implements caddy.TCPServer interface.
// It opens a TCP listener via reuseport.Listen on s.Addr with the leading
// transport.GRPC+"://" length stripped, and returns nil plus the error on failure.
func (s *ServergRPC) Listen() (net.Listener, error) {
	// Skip the first len(transport.GRPC+"://") bytes of s.Addr
	// (assumed to be the scheme prefix) to obtain the address to bind.
	hostPort := s.Addr[len(transport.GRPC+"://"):]

	ln, err := reuseport.Listen("tcp", hostPort)
	if err != nil {
		return nil, err
	}
	return ln, nil
}
// ListenPacket implements caddy.UDPServer interface.
// It opens nothing and returns a nil connection together with a nil error.
func (s *ServergRPC) ListenPacket() (net.PacketConn, error) {
	return nil, nil
}
// OnStartupComplete prints the zones served by this server, as rendered by
// startUpZones, to standard output. Nothing is printed when Quiet is true or
// when the rendered text is empty.
func (s *ServergRPC) OnStartupComplete() {
	if Quiet {
		return
	}
	if banner := startUpZones(transport.GRPC+"://", s.Addr, s.zones); banner != "" {
		fmt.Print(banner)
	}
}
// Stop gracefully stops the gRPC server, if one has been created, while
// holding the server mutex. It blocks until GracefulStop returns and always
// reports a nil error.
func (s *ServergRPC) Stop() (err error) {
	s.m.Lock()
	defer s.m.Unlock()

	if srv := s.grpcServer; srv != nil {
		srv.GracefulStop()
	}
	return nil
}
// Query is the gRPC entry point: it decodes the DNS message carried in the
// request, runs it through ServeDNS like any other server would, and returns
// the resulting message re-packed into a protobuf.
//
// A custom response writer (gRPCresponse) captures the reply instead of
// writing it to a socket. Errors are returned when the payload is larger than
// dns.MaxMsgSize, cannot be unpacked, when the context carries no peer or a
// non-TCP peer, or when the reply cannot be packed.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
	raw := in.GetMsg()
	if size := len(raw); size > dns.MaxMsgSize {
		return nil, fmt.Errorf("dns message exceeds size limit: %d", size)
	}

	req := new(dns.Msg)
	if err := req.Unpack(raw); err != nil {
		return nil, err
	}

	// The remote address of the response writer comes from the gRPC peer.
	caller, found := peer.FromContext(ctx)
	if !found {
		return nil, errors.New("no peer in gRPC context")
	}
	tcpAddr, isTCP := caller.Addr.(*net.TCPAddr)
	if !isTCP {
		return nil, fmt.Errorf("no TCP peer in gRPC context: %v", caller.Addr)
	}

	// The writer starts out holding the request; handlers replace Msg
	// through Write or WriteMsg.
	rw := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: tcpAddr, Msg: req}

	dnsCtx := context.WithValue(
		context.WithValue(ctx, Key{}, s.Server),
		LoopKey{}, 0,
	)
	s.ServeDNS(dnsCtx, rw, req)

	wire, err := rw.Msg.Pack()
	if err != nil {
		return nil, err
	}
	return &pb.DnsPacket{Msg: wire}, nil
}
// Shutdown stops the gRPC server immediately (non gracefully) by calling Stop
// on it, if one has been created. It always returns nil.
func (s *ServergRPC) Shutdown() error {
	if srv := s.grpcServer; srv != nil {
		srv.Stop()
	}
	return nil
}
// gRPCresponse is the response writer handed to ServeDNS by Query. Instead of
// writing to a connection it stores the reply in Msg for Query to pack.
type gRPCresponse struct {
// localAddr is returned by LocalAddr; Query sets it to the server's listen address.
localAddr net.Addr
// remoteAddr is returned by RemoteAddr; Query sets it to the gRPC peer's TCP address.
remoteAddr net.Addr
// Msg holds the request initially and is replaced by Write or WriteMsg.
Msg *dns.Msg
}
// Write is the trick that makes the gRPC transport work: nothing is sent on
// the wire here. The raw bytes in b are unpacked into a fresh message stored
// in r.Msg, which Query later packs into the protobuf reply. It reports
// len(b) as written together with any unpack error.
func (r *gRPCresponse) Write(b []byte) (int, error) {
	parsed := new(dns.Msg)
	r.Msg = parsed
	return len(b), parsed.Unpack(b)
}
// These methods implement the dns.ResponseWriter interface from Go DNS.

// Close does nothing and returns nil.
func (r *gRPCresponse) Close() error {
	return nil
}
// TsigStatus always returns nil.
func (r *gRPCresponse) TsigStatus() error {
	return nil
}
// TsigTimersOnly is a no-op; the argument is ignored.
func (r *gRPCresponse) TsigTimersOnly(b bool) {
}
// Hijack is a no-op.
func (r *gRPCresponse) Hijack() {
}
// LocalAddr returns the local address stored in the response writer.
func (r *gRPCresponse) LocalAddr() net.Addr {
	return r.localAddr
}
// RemoteAddr returns the remote address stored in the response writer.
func (r *gRPCresponse) RemoteAddr() net.Addr {
	return r.remoteAddr
}
// Network always returns the empty string.
func (r *gRPCresponse) Network() string {
	return ""
}
// WriteMsg stores m as the reply message without copying or packing it and
// returns nil. Query packs the stored message after ServeDNS returns.
func (r *gRPCresponse) WriteMsg(m *dns.Msg) error {
	r.Msg = m
	return nil
}