| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | package grpc
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							|  |  |  | 	"context"
 | 
					
						
							|  |  |  | 	"crypto/tls"
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | 	"errors"
 | 
					
						
							|  |  |  | 	"fmt"
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 	"strconv"
 | 
					
						
							|  |  |  | 	"time"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/pb"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | 	"google.golang.org/grpc"
 | 
					
						
							|  |  |  | 	"google.golang.org/grpc/codes"
 | 
					
						
							|  |  |  | 	"google.golang.org/grpc/credentials"
 | 
					
						
							| 
									
										
										
										
											2022-07-10 20:06:33 +02:00
										 |  |  | 	"google.golang.org/grpc/credentials/insecure"
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 	"google.golang.org/grpc/status"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | const (
 | 
					
						
							|  |  |  | 	// maxDNSMessageBytes is the maximum size of a DNS message on the wire.
 | 
					
						
							|  |  |  | 	maxDNSMessageBytes = dns.MaxMsgSize
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// maxProtobufPayloadBytes accounts for protobuf overhead.
 | 
					
						
							|  |  |  | 	// Field tag=1 (1 byte) + length varint for 65535 (3 bytes) = 4 bytes total
 | 
					
						
							|  |  |  | 	maxProtobufPayloadBytes = maxDNSMessageBytes + 4
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var (
 | 
					
						
							|  |  |  | 	// ErrDNSMessageTooLarge is returned when a DNS message exceeds the maximum allowed size.
 | 
					
						
							|  |  |  | 	ErrDNSMessageTooLarge = errors.New("dns message exceeds size limit")
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | // Proxy defines an upstream host.
 | 
					
						
							|  |  |  | type Proxy struct {
 | 
					
						
							|  |  |  | 	addr string
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// connection
 | 
					
						
							|  |  |  | 	client   pb.DnsServiceClient
 | 
					
						
							|  |  |  | 	dialOpts []grpc.DialOption
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // newProxy returns a new proxy.
 | 
					
						
							|  |  |  | func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) {
 | 
					
						
							|  |  |  | 	p := &Proxy{
 | 
					
						
							|  |  |  | 		addr: addr,
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if tlsConfig != nil {
 | 
					
						
							|  |  |  | 		p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
 | 
					
						
							|  |  |  | 	} else {
 | 
					
						
							| 
									
										
										
										
											2022-07-10 20:06:33 +02:00
										 |  |  | 		p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | 	// Cap send/recv sizes to avoid oversized messages.
 | 
					
						
							|  |  |  | 	// Note: gRPC size limits apply to the serialized protobuf message size.
 | 
					
						
							|  |  |  | 	p.dialOpts = append(p.dialOpts,
 | 
					
						
							|  |  |  | 		grpc.WithDefaultCallOptions(
 | 
					
						
							|  |  |  | 			grpc.MaxCallRecvMsgSize(maxProtobufPayloadBytes),
 | 
					
						
							|  |  |  | 			grpc.MaxCallSendMsgSize(maxProtobufPayloadBytes),
 | 
					
						
							|  |  |  | 		),
 | 
					
						
							|  |  |  | 	)
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-18 06:13:10 +02:00
										 |  |  | 	conn, err := grpc.NewClient(p.addr, p.dialOpts...)
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	p.client = pb.NewDnsServiceClient(conn)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return p, nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // query sends the request and waits for a response.
 | 
					
						
							|  |  |  | func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
 | 
					
						
							|  |  |  | 	start := time.Now()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	msg, err := req.Pack()
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | 	if err := validateDNSSize(msg); err != nil {
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 	reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg})
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		// if not found message, return empty message with NXDomain code
 | 
					
						
							|  |  |  | 		if status.Code(err) == codes.NotFound {
 | 
					
						
							|  |  |  | 			m := new(dns.Msg).SetRcode(req, dns.RcodeNameError)
 | 
					
						
							|  |  |  | 			return m, nil
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | 	wire := reply.GetMsg()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := validateDNSSize(wire); err != nil {
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 	ret := new(dns.Msg)
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | 	if err := ret.Unpack(wire); err != nil {
 | 
					
						
							| 
									
										
										
										
											2019-03-14 08:12:28 +01:00
										 |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	rc, ok := dns.RcodeToString[ret.Rcode]
 | 
					
						
							|  |  |  | 	if !ok {
 | 
					
						
							|  |  |  | 		rc = strconv.Itoa(ret.Rcode)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	RequestCount.WithLabelValues(p.addr).Add(1)
 | 
					
						
							|  |  |  | 	RcodeCount.WithLabelValues(rc, p.addr).Add(1)
 | 
					
						
							|  |  |  | 	RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds())
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return ret, nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							| 
									
										
										
										
											2025-09-12 08:21:33 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | func validateDNSSize(data []byte) error {
 | 
					
						
							|  |  |  | 	l := len(data)
 | 
					
						
							|  |  |  | 	if l > maxDNSMessageBytes {
 | 
					
						
							|  |  |  | 		return fmt.Errorf("%w: %d bytes (limit %d)", ErrDNSMessageTooLarge, l, maxDNSMessageBytes)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return nil
 | 
					
						
							|  |  |  | }
 |