diff --git a/plugin/grpc/proxy.go b/plugin/grpc/proxy.go index a94e76902..fc06a5a46 100644 --- a/plugin/grpc/proxy.go +++ b/plugin/grpc/proxy.go @@ -3,6 +3,8 @@ package grpc import ( "context" "crypto/tls" + "errors" + "fmt" "strconv" "time" @@ -16,6 +18,20 @@ import ( "google.golang.org/grpc/status" ) +const ( + // maxDNSMessageBytes is the maximum size of a DNS message on the wire. + maxDNSMessageBytes = dns.MaxMsgSize + + // maxProtobufPayloadBytes accounts for protobuf overhead. + // Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total + maxProtobufPayloadBytes = maxDNSMessageBytes + 4 +) + +var ( + // ErrDNSMessageTooLarge is returned when a DNS message exceeds the maximum allowed size. + ErrDNSMessageTooLarge = errors.New("dns message exceeds size limit") +) + // Proxy defines an upstream host. type Proxy struct { addr string @@ -37,6 +53,15 @@ func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) { p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) } + // Cap send/recv sizes to avoid oversized messages. + // Note: gRPC size limits apply to the serialized protobuf message size. + p.dialOpts = append(p.dialOpts, + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(maxProtobufPayloadBytes), + grpc.MaxCallSendMsgSize(maxProtobufPayloadBytes), + ), + ) + conn, err := grpc.NewClient(p.addr, p.dialOpts...) if err != nil { return nil, err @@ -55,6 +80,10 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) { return nil, err } + if err := validateDNSSize(msg); err != nil { + return nil, err + } + reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg}) if err != nil { // if not found message, return empty message with NXDomain code @@ -64,8 +93,14 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) { } return nil, err } + wire := reply.GetMsg() + + if err := validateDNSSize(wire); err != nil { + return nil, err + } + ret := new(dns.Msg) - if err := ret.Unpack(reply.GetMsg()); err != nil { + if err := ret.Unpack(wire); err != nil { return nil, err } @@ -80,3 +115,11 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) { return ret, nil } + +func validateDNSSize(data []byte) error { + l := len(data) + if l > maxDNSMessageBytes { + return fmt.Errorf("%w: %d bytes (limit %d)", ErrDNSMessageTooLarge, l, maxDNSMessageBytes) + } + return nil +} diff --git a/plugin/grpc/proxy_test.go b/plugin/grpc/proxy_test.go index 2ca0b1b48..b5c92f8e8 100644 --- a/plugin/grpc/proxy_test.go +++ b/plugin/grpc/proxy_test.go @@ -5,6 +5,7 @@ import ( "errors" "net" "path" + "slices" "testing" "github.com/coredns/caddy" @@ -61,6 +62,33 @@ func TestProxy(t *testing.T) { } } +func TestProxy_RejectsOversizedReply(t *testing.T) { + p := &Proxy{} + oversized := make([]byte, maxDNSMessageBytes+1) + p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: oversized}, err: nil} + _, err := p.query(context.TODO(), new(dns.Msg)) + if !errors.Is(err, ErrDNSMessageTooLarge) { + t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err) + } +} + +func TestProxy_RejectsOversizedRequest(t *testing.T) { + p := &Proxy{} + p.client = testServiceClient{dnsPacket: &pb.DnsPacket{Msg: []byte("ok")}, err: nil} + + oversizedMsg := &dns.Msg{} + oversizedMsg.SetQuestion("example.org.", dns.TypeA) + oversizedMsg.Extra = slices.Repeat([]dns.RR{&dns.TXT{ + Hdr: dns.RR_Header{Name: "example.org.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 300}, + Txt: []string{"very long text record to make the message oversized when packed"}, + }}, 2000) + + _, err := p.query(context.TODO(), oversizedMsg) + if !errors.Is(err, ErrDNSMessageTooLarge) { + t.Fatalf("expected %v, got %v", ErrDNSMessageTooLarge, err) + } +} + type testServiceClient struct { dnsPacket *pb.DnsPacket err error