fix(grpc): enforce DNS message size limits (#7490)

Add DNS wire size validation for requests/replies. Limit gRPC
recv/send via default call options, accounting necessary
framing/protobuf overhead. An error is returned for oversized
messages. Add test.

Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
Ville Vesilehto
2025-09-12 08:21:33 +03:00
committed by GitHub
parent 39abf5aeba
commit 8817d8f2f9
2 changed files with 72 additions and 1 deletions

View File

@@ -3,6 +3,8 @@ package grpc
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"time"
@@ -16,6 +18,20 @@ import (
"google.golang.org/grpc/status"
)
const (
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
maxDNSMessageBytes = dns.MaxMsgSize
// maxProtobufPayloadBytes accounts for protobuf overhead.
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
)
var (
// ErrDNSMessageTooLarge is returned when a DNS message exceeds the maximum allowed size.
ErrDNSMessageTooLarge = errors.New("dns message exceeds size limit")
)
// Proxy defines an upstream host.
type Proxy struct {
addr string
@@ -37,6 +53,15 @@ func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) {
p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
// Cap send/recv sizes to avoid oversized messages.
// Note: gRPC size limits apply to the serialized protobuf message size.
p.dialOpts = append(p.dialOpts,
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(maxProtobufPayloadBytes),
grpc.MaxCallSendMsgSize(maxProtobufPayloadBytes),
),
)
conn, err := grpc.NewClient(p.addr, p.dialOpts...)
if err != nil {
return nil, err
@@ -55,6 +80,10 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
return nil, err
}
if err := validateDNSSize(msg); err != nil {
return nil, err
}
reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg})
if err != nil {
// if not found message, return empty message with NXDomain code
@@ -64,8 +93,14 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
}
return nil, err
}
wire := reply.GetMsg()
if err := validateDNSSize(wire); err != nil {
return nil, err
}
ret := new(dns.Msg)
if err := ret.Unpack(reply.GetMsg()); err != nil {
if err := ret.Unpack(wire); err != nil {
return nil, err
}
@@ -80,3 +115,11 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
return ret, nil
}
func validateDNSSize(data []byte) error {
l := len(data)
if l > maxDNSMessageBytes {
return fmt.Errorf("%w: %d bytes (limit %d)", ErrDNSMessageTooLarge, l, maxDNSMessageBytes)
}
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net"
"path"
"slices"
"testing"
"github.com/coredns/caddy"
@@ -61,6 +62,33 @@ func TestProxy(t *testing.T) {
}
}
func TestProxy_RejectsOversizedReply(t *testing.T) {
p := &Proxy{}
oversized := make([]byte, maxDNSMessageBytes+1)
p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: oversized}, err: nil}
_, err := p.query(context.TODO(), new(dns.Msg))
if !errors.Is(err, ErrDNSMessageTooLarge) {
t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err)
}
}
func TestProxy_RejectsOversizedRequest(t *testing.T) {
p := &Proxy{}
p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: []byte("ok")}, err: nil}
oversizedMsg := &dns.Msg{}
oversizedMsg.SetQuestion("example.org.", dns.TypeA)
oversizedMsg.Extra = slices.Repeat([]dns.RR{&dns.TXT{
Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 300},
Txt: []string{"very long text record to make the message oversized when packed"},
}}, 2000)
_, err := p.query(context.TODO(), oversizedMsg)
if !errors.Is(err, ErrDNSMessageTooLarge) {
t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err)
}
}
type testServiceClient struct {
dnsPacket *pb.DnsPacket
err error