mirror of
https://github.com/coredns/coredns.git
synced 2025-10-27 08:14:18 -04:00
fix(grpc): enforce DNS message size limits (#7490)
Add DNS wire size validation for requests/replies. Limit gRPC recv/send via default call options, accounting necessary framing/protobuf overhead. An error is returned for oversized messages. Add test. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
@@ -3,6 +3,8 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +18,20 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
|
||||
maxDNSMessageBytes = dns.MaxMsgSize
|
||||
|
||||
// maxProtobufPayloadBytes accounts for protobuf overhead.
|
||||
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
|
||||
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDNSMessageTooLarge is returned when a DNS message exceeds the maximum allowed size.
|
||||
ErrDNSMessageTooLarge = errors.New("dns message exceeds size limit")
|
||||
)
|
||||
|
||||
// Proxy defines an upstream host.
|
||||
type Proxy struct {
|
||||
addr string
|
||||
@@ -37,6 +53,15 @@ func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) {
|
||||
p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
}
|
||||
|
||||
// Cap send/recv sizes to avoid oversized messages.
|
||||
// Note: gRPC size limits apply to the serialized protobuf message size.
|
||||
p.dialOpts = append(p.dialOpts,
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(maxProtobufPayloadBytes),
|
||||
grpc.MaxCallSendMsgSize(maxProtobufPayloadBytes),
|
||||
),
|
||||
)
|
||||
|
||||
conn, err := grpc.NewClient(p.addr, p.dialOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -55,6 +80,10 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateDNSSize(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg})
|
||||
if err != nil {
|
||||
// if not found message, return empty message with NXDomain code
|
||||
@@ -64,8 +93,14 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
wire := reply.GetMsg()
|
||||
|
||||
if err := validateDNSSize(wire); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := new(dns.Msg)
|
||||
if err := ret.Unpack(reply.GetMsg()); err != nil {
|
||||
if err := ret.Unpack(wire); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -80,3 +115,11 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func validateDNSSize(data []byte) error {
|
||||
l := len(data)
|
||||
if l > maxDNSMessageBytes {
|
||||
return fmt.Errorf("%w: %d bytes (limit %d)", ErrDNSMessageTooLarge, l, maxDNSMessageBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"path"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/caddy"
|
||||
@@ -61,6 +62,33 @@ func TestProxy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_RejectsOversizedReply(t *testing.T) {
|
||||
p := &Proxy{}
|
||||
oversized := make([]byte, maxDNSMessageBytes+1)
|
||||
p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: oversized}, err: nil}
|
||||
_, err := p.query(context.TODO(), new(dns.Msg))
|
||||
if !errors.Is(err, ErrDNSMessageTooLarge) {
|
||||
t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_RejectsOversizedRequest(t *testing.T) {
|
||||
p := &Proxy{}
|
||||
p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: []byte("ok")}, err: nil}
|
||||
|
||||
oversizedMsg := &dns.Msg{}
|
||||
oversizedMsg.SetQuestion("example.org.", dns.TypeA)
|
||||
oversizedMsg.Extra = slices.Repeat([]dns.RR{&dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 300},
|
||||
Txt: []string{"very long text record to make the message oversized when packed"},
|
||||
}}, 2000)
|
||||
|
||||
_, err := p.query(context.TODO(), oversizedMsg)
|
||||
if !errors.Is(err, ErrDNSMessageTooLarge) {
|
||||
t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err)
|
||||
}
|
||||
}
|
||||
|
||||
type testServiceClient struct {
|
||||
dnsPacket *pb.DnsPacket
|
||||
err error
|
||||
|
||||
Reference in New Issue
Block a user