| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns, the upstream's Conn will already be cached. Depending on how you benchmark, this looks to be
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
// in-band health checking.
 | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | package proxy
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							| 
									
										
										
										
											2018-04-22 08:34:35 +01:00
										 |  |  | 	"context"
 | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 	"errors"
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	"io"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"strconv"
 | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 	"strings"
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	"sync/atomic"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"time"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/request"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | // limitTimeout is a utility function to auto-tune timeout values
 | 
					
						
							|  |  |  | // average observed time is moved towards the last observed delay moderated by a weight
 | 
					
						
							|  |  |  | // next timeout to use will be the double of the computed average, limited by min and max frame.
 | 
					
						
							|  |  |  | func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
 | 
					
						
							|  |  |  | 	rt := time.Duration(atomic.LoadInt64(currentAvg))
 | 
					
						
							|  |  |  | 	if rt < minValue {
 | 
					
						
							|  |  |  | 		return minValue
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	if rt < maxValue/2 {
 | 
					
						
							|  |  |  | 		return 2 * rt
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return maxValue
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
 | 
					
						
							|  |  |  | 	dt := time.Duration(atomic.LoadInt64(currentAvg))
 | 
					
						
							|  |  |  | 	atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) dialTimeout() time.Duration {
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
 | 
					
						
							|  |  |  | }
 | 
					
						
							| 
									
										
										
										
											2018-04-16 19:51:49 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | // Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	// If tls has been configured; use it.
 | 
					
						
							|  |  |  | 	if t.tlsConfig != nil {
 | 
					
						
							|  |  |  | 		proto = "tcp-tls"
 | 
					
						
							| 
									
										
										
										
											2018-04-16 19:51:49 +01:00
										 |  |  | 	}
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	t.dial <- proto
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc := <-t.ret
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	if pc != nil {
 | 
					
						
							| 
									
										
										
										
											2023-07-04 15:35:55 +01:00
										 |  |  | 		connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1)
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 		return pc, true, nil
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	}
 | 
					
						
							| 
									
										
										
										
											2023-07-04 15:35:55 +01:00
										 |  |  | 	connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1)
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	reqTime := time.Now()
 | 
					
						
							|  |  |  | 	timeout := t.dialTimeout()
 | 
					
						
							|  |  |  | 	if proto == "tcp-tls" {
 | 
					
						
							|  |  |  | 		conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
 | 
					
						
							|  |  |  | 		t.updateDialTimeout(time.Since(reqTime))
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 		return &persistConn{c: conn}, false, err
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 	conn, err := dns.DialTimeout(proto, t.addr, timeout)
 | 
					
						
							|  |  |  | 	t.updateDialTimeout(time.Since(reqTime))
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	return &persistConn{c: conn}, false, err
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-04 08:47:26 +03:00
										 |  |  | // Connect selects an upstream, sends the request and waits for a response.
 | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	start := time.Now()
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 	proto := ""
 | 
					
						
							|  |  |  | 	switch {
 | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | 	case opts.ForceTCP: // TCP flag has precedence over UDP flag
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		proto = "tcp"
 | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | 	case opts.PreferUDP:
 | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 		proto = "udp"
 | 
					
						
							|  |  |  | 	default:
 | 
					
						
							|  |  |  | 		proto = state.Proto()
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc, cached, err := p.transport.Dial(proto)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	// Set buffer size correctly for this client.
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc.c.UDPSize = uint16(state.Size())
 | 
					
						
							|  |  |  | 	if pc.c.UDPSize < 512 {
 | 
					
						
							|  |  |  | 		pc.c.UDPSize = 512
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
 | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 	// records the origin Id before upstream.
 | 
					
						
							|  |  |  | 	originId := state.Req.Id
 | 
					
						
							|  |  |  | 	state.Req.Id = dns.Id()
 | 
					
						
							| 
									
										
										
										
											2021-10-11 10:28:01 +00:00
										 |  |  | 	defer func() {
 | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 		state.Req.Id = originId
 | 
					
						
							|  |  |  | 	}()
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	if err := pc.c.WriteMsg(state.Req); err != nil {
 | 
					
						
							|  |  |  | 		pc.c.Close() // not giving it back
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		if err == io.EOF && cached {
 | 
					
						
							| 
									
										
										
										
											2018-05-09 14:41:14 +03:00
										 |  |  | 			return nil, ErrCachedClosed
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 	var ret *dns.Msg
 | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | 	pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
 | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 	for {
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 		ret, err = pc.c.ReadMsg()
 | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 		if err != nil {
 | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 			if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) {
 | 
					
						
							|  |  |  | 				// For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way.
 | 
					
						
							|  |  |  | 				// (e.g. sending >512 byte responses without an eDNS0 OPT RR).
 | 
					
						
							|  |  |  | 				// Instead of returning an error, return an empty response with TC bit set. This will make the
 | 
					
						
							|  |  |  | 				// client retry over TCP (if that's supported) or at least receive a clean
 | 
					
						
							|  |  |  | 				// error. The connection is still good so we break before the close.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Truncate the response.
 | 
					
						
							|  |  |  | 				ret = truncateResponse(ret)
 | 
					
						
							|  |  |  | 				break
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-14 20:33:37 -04:00
										 |  |  | 			pc.c.Close() // not giving it back
 | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 			if err == io.EOF && cached {
 | 
					
						
							|  |  |  | 				return nil, ErrCachedClosed
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							| 
									
										
										
										
											2023-08-14 20:33:37 -04:00
										 |  |  | 			// recovery the origin Id after upstream.
 | 
					
						
							| 
									
										
										
										
											2021-10-11 10:28:01 +00:00
										 |  |  | 			if ret != nil {
 | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 				ret.Id = originId
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 			return ret, err
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		// drop out-of-order responses
 | 
					
						
							|  |  |  | 		if state.Req.Id == ret.Id {
 | 
					
						
							|  |  |  | 			break
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 	// recovery the origin Id after upstream.
 | 
					
						
							|  |  |  | 	ret.Id = originId
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	p.transport.Yield(pc)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 	rc, ok := dns.RcodeToString[ret.Rcode]
 | 
					
						
							|  |  |  | 	if !ok {
 | 
					
						
							|  |  |  | 		rc = strconv.Itoa(ret.Rcode)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-04 15:35:55 +01:00
										 |  |  | 	requestDuration.WithLabelValues(p.proxyName, p.addr, rc).Observe(time.Since(start).Seconds())
 | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	return ret, nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
// cumulativeAvgWeight is the weight handed to averageTimeout when the average
// dial time is updated: each new observation moves the average by
// 1/cumulativeAvgWeight of its distance from the current average.
const cumulativeAvgWeight = 4
 | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | // Function to determine if a response should be truncated.
 | 
					
						
							|  |  |  | func shouldTruncateResponse(err error) bool {
 | 
					
						
							|  |  |  | 	// This is to handle a scenario in which upstream sets the TC bit, but doesn't truncate the response
 | 
					
						
							|  |  |  | 	// and we get ErrBuf instead of overflow.
 | 
					
						
							|  |  |  | 	if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) {
 | 
					
						
							|  |  |  | 		return true
 | 
					
						
							|  |  |  | 	} else if strings.Contains(err.Error(), "overflow") {
 | 
					
						
							|  |  |  | 		return true
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return false
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Function to return an empty response with TC (truncated) bit set.
 | 
					
						
							|  |  |  | func truncateResponse(response *dns.Msg) *dns.Msg {
 | 
					
						
							|  |  |  | 	// Clear out Answer, Extra, and Ns sections
 | 
					
						
							|  |  |  | 	response.Answer = nil
 | 
					
						
							|  |  |  | 	response.Extra = nil
 | 
					
						
							|  |  |  | 	response.Ns = nil
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Set TC bit to indicate truncation.
 | 
					
						
							|  |  |  | 	response.Truncated = true
 | 
					
						
							|  |  |  | 	return response
 | 
					
						
							|  |  |  | }
 |