Files
coredns/core/dnsserver/server_grpc.go
Yong Tang 4c9a80c296 core: Add full TSIG verification in gRPC transport (#8006)
* core: Add full TSIG verification in gRPC transport

This PR adds full TSIG verification to the gRPC transport using dns.TsigVerify(), so invalid signatures and timestamps are correctly detected instead of only checking for the presence of a key.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Fix

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

---------

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
2026-04-04 11:58:36 +03:00

253 lines
7.2 KiB
Go

package dnsserver
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"github.com/coredns/caddy"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/reuseport"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/miekg/dns"
"github.com/opentracing/opentracing-go"
"github.com/pires/go-proxyproto"
"golang.org/x/net/netutil"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
// Size limits for DNS-over-gRPC messages and defaults for the per-server
// stream/connection limits.
const (
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
maxDNSMessageBytes = dns.MaxMsgSize
// maxProtobufPayloadBytes accounts for protobuf overhead.
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
// Serve uses it for both the gRPC receive and send message size limits.
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
// DefaultGRPCMaxStreams is the default maximum number of concurrent streams per connection.
// NewServergRPC falls back to it when the first Config leaves MaxGRPCStreams unset.
DefaultGRPCMaxStreams = 256
// DefaultGRPCMaxConnections is the default maximum number of concurrent connections.
// NewServergRPC falls back to it when the first Config leaves MaxGRPCConnections unset.
DefaultGRPCMaxConnections = 200
)
// ServergRPC represents an instance of a DNS-over-gRPC server.
//
// The embedded *Server supplies ServeDNS, which Query calls for every decoded
// request. The embedded *pb.UnimplementedDnsServiceServer presumably provides
// default implementations for the remaining generated service methods
// (confirm against the pb package).
type ServergRPC struct {
*Server
*pb.UnimplementedDnsServiceServer
// grpcServer is created in Serve; Stop and Shutdown do nothing while it is nil.
grpcServer *grpc.Server
// listenAddr is the address of the listener given to Serve; Query reports it as LocalAddr.
listenAddr net.Addr
// tlsConfig, when non-nil, wraps the listener in TLS; NewServergRPC forces NextProtos to "h2".
tlsConfig *tls.Config
// maxStreams caps concurrent streams per connection; values <= 0 leave gRPC's default.
maxStreams int
// maxConnections caps concurrently accepted connections; values <= 0 disable the limit.
maxConnections int
}
// NewServergRPC returns a new CoreDNS gRPC server and compiles all plugins into it.
//
// TLS settings are taken from the zone configs of the underlying Server. The
// stream and connection limits are read from group[0], falling back to
// DefaultGRPCMaxStreams and DefaultGRPCMaxConnections when unset. Any error
// from NewServer is returned unchanged.
func NewServergRPC(addr string, group []*Config) (*ServergRPC, error) {
	s, err := NewServer(addr, group)
	if err != nil {
		return nil, err
	}
	// The *tls* plugin must make sure that multiple conflicting
	// TLS configuration returns an error: it can only be specified once.
	//
	// Configs without TLS are skipped. Otherwise a zone that lacks the tls
	// directive could overwrite (and thereby silently disable) the TLS config
	// of another zone, depending on which zone happens to be visited last.
	var tlsConfig *tls.Config
	for _, z := range s.zones {
		for _, conf := range z {
			// Should we error if some configs *don't* have TLS?
			if conf.TLSConfig != nil {
				tlsConfig = conf.TLSConfig
			}
		}
	}
	// http/2 is required when using gRPC. We need to specify it in next protos
	// or the upgrade won't happen.
	if tlsConfig != nil {
		tlsConfig.NextProtos = []string{"h2"}
	}

	maxStreams := DefaultGRPCMaxStreams
	if len(group) > 0 && group[0] != nil && group[0].MaxGRPCStreams != nil {
		maxStreams = *group[0].MaxGRPCStreams
	}
	maxConnections := DefaultGRPCMaxConnections
	if len(group) > 0 && group[0] != nil && group[0].MaxGRPCConnections != nil {
		maxConnections = *group[0].MaxGRPCConnections
	}

	return &ServergRPC{
		Server:         s,
		tlsConfig:      tlsConfig,
		maxStreams:     maxStreams,
		maxConnections: maxConnections,
	}, nil
}
// Compile-time assertion that *ServergRPC satisfies caddy.GracefulServer.
// A typed nil pointer is all the type checker needs.
var _ caddy.GracefulServer = (*ServergRPC)(nil)
// Serve implements caddy.TCPServer interface.
//
// It records l's address, builds the gRPC server (message size limits,
// optional per-connection stream limit, optional tracing interceptor),
// registers s as the DnsService implementation, wraps l with TLS and a
// connection limit when configured, and then blocks serving on l.
func (s *ServergRPC) Serve(l net.Listener) error {
	s.m.Lock()
	s.listenAddr = l.Addr()
	s.m.Unlock()

	serverOpts := []grpc.ServerOption{
		grpc.MaxRecvMsgSize(maxProtobufPayloadBytes),
		grpc.MaxSendMsgSize(maxProtobufPayloadBytes),
	}
	// Only set MaxConcurrentStreams if not unbounded (0)
	if s.maxStreams > 0 {
		serverOpts = append(serverOpts, grpc.MaxConcurrentStreams(uint32(s.maxStreams))) // #nosec G115 -- maxStreams is bounded
	}
	if s.Tracer() != nil {
		// Only record spans for requests that arrive with a parent span.
		onlyIfParent := func(parentSpanCtx opentracing.SpanContext, _ string, _, _ any) bool {
			return parentSpanCtx != nil
		}
		serverOpts = append(serverOpts, grpc.UnaryInterceptor(otgrpc.OpenTracingServerInterceptor(s.Tracer(), otgrpc.IncludingSpans(onlyIfParent))))
	}

	grpcServer := grpc.NewServer(serverOpts...)
	pb.RegisterDnsServiceServer(grpcServer, s)

	// Publish the server under s.m: Stop reads s.grpcServer while holding the
	// same lock, so an unsynchronized write here would be a data race.
	s.m.Lock()
	s.grpcServer = grpcServer
	s.m.Unlock()

	if s.tlsConfig != nil {
		l = tls.NewListener(l, s.tlsConfig)
	}
	// Wrap listener to limit concurrent connections
	if s.maxConnections > 0 {
		l = netutil.LimitListener(l, s.maxConnections)
	}
	return grpcServer.Serve(l)
}
// ServePacket implements caddy.UDPServer interface.
// gRPC runs over TCP only, so the packet connection is ignored and nil is returned.
func (s *ServergRPC) ServePacket(_ net.PacketConn) error {
	return nil
}
// Listen implements caddy.TCPServer interface.
//
// It strips the gRPC scheme prefix from s.Addr, opens a TCP listener on the
// remaining address via reuseport, and wraps it in a PROXY-protocol listener
// when a connection policy is configured.
func (s *ServergRPC) Listen() (net.Listener, error) {
	hostPort := s.Addr[len(transport.GRPC+"://"):]
	ln, err := reuseport.Listen("tcp", hostPort)
	if err != nil {
		return nil, err
	}
	if s.connPolicy == nil {
		return ln, nil
	}
	return &proxyproto.Listener{Listener: ln, ConnPolicy: s.connPolicy}, nil
}
// ListenPacket implements caddy.UDPServer interface.
// No packet listener is opened for gRPC; both results are nil.
func (s *ServergRPC) ListenPacket() (net.PacketConn, error) {
	return nil, nil
}
// OnStartupComplete prints the zones served by this gRPC server to stdout,
// unless Quiet is set or there is nothing to print.
func (s *ServergRPC) OnStartupComplete() {
	if !Quiet {
		if summary := startUpZones(transport.GRPC+"://", s.Addr, s.zones); summary != "" {
			fmt.Print(summary)
		}
	}
}
// Stop gracefully stops the gRPC server and blocks until it has fully
// stopped. s.m is held for the whole call. It does nothing if Serve has not
// created the server yet, and it always returns nil.
func (s *ServergRPC) Stop() error {
	s.m.Lock()
	defer s.m.Unlock()

	if srv := s.grpcServer; srv != nil {
		srv.GracefulStop()
	}
	return nil
}
// Query is the gRPC entry point for a single DNS request.
//
// It decodes the wire-format message carried in the protobuf, checks any
// TSIG record against the configured secrets, and runs the request through
// ServeDNS exactly like any other transport. The reply is not written to a
// socket: gRPCresponse captures it, and it is re-packed into the returned
// protobuf.
//
// Errors are returned when the payload exceeds dns.MaxMsgSize, cannot be
// unpacked, the context carries no peer or a non-TCP peer, or the reply
// cannot be packed.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
	raw := in.GetMsg()
	if len(raw) > dns.MaxMsgSize {
		return nil, fmt.Errorf("dns message exceeds size limit: %d", len(raw))
	}

	req := new(dns.Msg)
	if err := req.Unpack(raw); err != nil {
		return nil, err
	}

	p, ok := peer.FromContext(ctx)
	if !ok {
		return nil, errors.New("no peer in gRPC context")
	}
	remote, ok := p.Addr.(*net.TCPAddr)
	if !ok {
		return nil, fmt.Errorf("no TCP peer in gRPC context: %v", p.Addr)
	}

	w := &gRPCresponse{localAddr: s.listenAddr, remoteAddr: remote, Msg: req}

	// Verify the TSIG signature over the raw bytes. A lookup in a nil map
	// yields found == false, so a server without any configured secrets
	// also reports dns.ErrSecret.
	if t := req.IsTsig(); t != nil {
		if secret, found := s.tsigSecret[t.Hdr.Name]; found {
			w.tsigStatus = dns.TsigVerify(raw, secret, "", false)
		} else {
			w.tsigStatus = dns.ErrSecret
		}
	}

	ctx = context.WithValue(ctx, Key{}, s.Server)
	ctx = context.WithValue(ctx, LoopKey{}, 0)
	s.ServeDNS(ctx, w, req)

	packed, err := w.Msg.Pack()
	if err != nil {
		return nil, err
	}
	return &pb.DnsPacket{Msg: packed}, nil
}
// Shutdown stops the server immediately (non-gracefully) and always returns
// nil. Unlike Stop it does not acquire s.m. That is presumably so it can
// still force a stop while a GracefulStop holds the lock (confirm intent).
func (s *ServergRPC) Shutdown() error {
	srv := s.grpcServer
	if srv == nil {
		return nil
	}
	srv.Stop()
	return nil
}
// gRPCresponse is the response writer that Query hands to ServeDNS. It never
// writes to the network. Instead it keeps the reply in Msg, which Query then
// packs into the protobuf response.
type gRPCresponse struct {
// localAddr is the server's listen address, reported by LocalAddr.
localAddr net.Addr
// remoteAddr is the client's TCP address from the gRPC peer, reported by RemoteAddr.
remoteAddr net.Addr
// Msg starts as the decoded request and is replaced by whatever Write or WriteMsg receives.
Msg *dns.Msg
// tsigStatus is the TSIG verification result set by Query; it stays nil for unsigned requests.
tsigStatus error
}
// Write captures a wire-format reply instead of sending it anywhere.
//
// b is decoded into a freshly allocated message that replaces r.Msg, so
// Query can later re-pack it into the protobuf response. The byte count is
// always len(b). The error is the result of unpacking b; r.Msg has already
// been replaced even when unpacking fails.
func (r *gRPCresponse) Write(b []byte) (int, error) {
	reply := new(dns.Msg)
	r.Msg = reply
	err := reply.Unpack(b)
	return len(b), err
}
// The remaining methods complete the dns.ResponseWriter interface from Go DNS.

// Close is a no-op; the writer holds no connection of its own.
func (r *gRPCresponse) Close() error {
	return nil
}

// TsigStatus reports the TSIG verification result recorded by Query.
func (r *gRPCresponse) TsigStatus() error {
	return r.tsigStatus
}

// TsigTimersOnly is a no-op for the gRPC transport.
func (r *gRPCresponse) TsigTimersOnly(_ bool) {}

// Hijack is a no-op for the gRPC transport.
func (r *gRPCresponse) Hijack() {}

// LocalAddr returns the server's listen address.
func (r *gRPCresponse) LocalAddr() net.Addr {
	return r.localAddr
}

// RemoteAddr returns the client's TCP address.
func (r *gRPCresponse) RemoteAddr() net.Addr {
	return r.remoteAddr
}

// Network returns an empty string for the gRPC transport.
func (r *gRPCresponse) Network() string {
	return ""
}

// WriteMsg stores m as the reply without serializing it; it never fails.
func (r *gRPCresponse) WriteMsg(m *dns.Msg) error {
	r.Msg = m
	return nil
}