mirror of
https://github.com/coredns/coredns.git
synced 2025-10-28 16:54:15 -04:00
fix(grpc): enforce DNS message size limits (#7490)
Add DNS wire size validation for requests/replies. Limit gRPC recv/send via default call options, accounting necessary framing/protobuf overhead. An error is returned for oversized messages. Add test. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
@@ -3,6 +3,8 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,6 +18,20 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
|
||||||
|
maxDNSMessageBytes = dns.MaxMsgSize
|
||||||
|
|
||||||
|
// maxProtobufPayloadBytes accounts for protobuf overhead.
|
||||||
|
// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
|
||||||
|
maxProtobufPayloadBytes = maxDNSMessageBytes + 4
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrDNSMessageTooLarge is returned when a DNS message exceeds the maximum allowed size.
|
||||||
|
ErrDNSMessageTooLarge = errors.New("dns message exceeds size limit")
|
||||||
|
)
|
||||||
|
|
||||||
// Proxy defines an upstream host.
|
// Proxy defines an upstream host.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
addr string
|
addr string
|
||||||
@@ -37,6 +53,15 @@ func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) {
|
|||||||
p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cap send/recv sizes to avoid oversized messages.
|
||||||
|
// Note: gRPC size limits apply to the serialized protobuf message size.
|
||||||
|
p.dialOpts = append(p.dialOpts,
|
||||||
|
grpc.WithDefaultCallOptions(
|
||||||
|
grpc.MaxCallRecvMsgSize(maxProtobufPayloadBytes),
|
||||||
|
grpc.MaxCallSendMsgSize(maxProtobufPayloadBytes),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
conn, err := grpc.NewClient(p.addr, p.dialOpts...)
|
conn, err := grpc.NewClient(p.addr, p.dialOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -55,6 +80,10 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := validateDNSSize(msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg})
|
reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if not found message, return empty message with NXDomain code
|
// if not found message, return empty message with NXDomain code
|
||||||
@@ -64,8 +93,14 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
|
|||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
wire := reply.GetMsg()
|
||||||
|
|
||||||
|
if err := validateDNSSize(wire); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ret := new(dns.Msg)
|
ret := new(dns.Msg)
|
||||||
if err := ret.Unpack(reply.GetMsg()); err != nil {
|
if err := ret.Unpack(wire); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,3 +115,11 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
|
|||||||
|
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateDNSSize(data []byte) error {
|
||||||
|
l := len(data)
|
||||||
|
if l > maxDNSMessageBytes {
|
||||||
|
return fmt.Errorf("%w: %d bytes (limit %d)", ErrDNSMessageTooLarge, l, maxDNSMessageBytes)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"path"
|
"path"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coredns/caddy"
|
"github.com/coredns/caddy"
|
||||||
@@ -61,6 +62,33 @@ func TestProxy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxy_RejectsOversizedReply(t *testing.T) {
|
||||||
|
p := &Proxy{}
|
||||||
|
oversized := make([]byte, maxDNSMessageBytes+1)
|
||||||
|
p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: oversized}, err: nil}
|
||||||
|
_, err := p.query(context.TODO(), new(dns.Msg))
|
||||||
|
if !errors.Is(err, ErrDNSMessageTooLarge) {
|
||||||
|
t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy_RejectsOversizedRequest(t *testing.T) {
|
||||||
|
p := &Proxy{}
|
||||||
|
p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: []byte("ok")}, err: nil}
|
||||||
|
|
||||||
|
oversizedMsg := &dns.Msg{}
|
||||||
|
oversizedMsg.SetQuestion("example.org.", dns.TypeA)
|
||||||
|
oversizedMsg.Extra = slices.Repeat([]dns.RR{&dns.TXT{
|
||||||
|
Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 300},
|
||||||
|
Txt: []string{"very long text record to make the message oversized when packed"},
|
||||||
|
}}, 2000)
|
||||||
|
|
||||||
|
_, err := p.query(context.TODO(), oversizedMsg)
|
||||||
|
if !errors.Is(err, ErrDNSMessageTooLarge) {
|
||||||
|
t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type testServiceClient struct {
|
type testServiceClient struct {
|
||||||
dnsPacket *pb.DnsPacket
|
dnsPacket *pb.DnsPacket
|
||||||
err error
|
err error
|
||||||
|
|||||||
Reference in New Issue
Block a user