| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | // Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | // client returns, the upstream's Conn will already be cached. Depending on how you benchmark, this looks to be | 
					
						
							| 
									
										
										
										
											2018-08-14 17:55:55 +02:00
										 |  |  | // 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | // inband healthchecking. | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | package proxy | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2018-04-22 08:34:35 +01:00
										 |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	"io" | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"strconv" | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	"sync/atomic" | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"time" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/request" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-28 16:58:48 +03:00
										 |  |  | const ( | 
					
						
							|  |  |  | 	ErrTransportStopped              = "proxy: transport stopped" | 
					
						
							|  |  |  | 	ErrTransportStoppedDuringDial    = "proxy: transport stopped during dial" | 
					
						
							|  |  |  | 	ErrTransportStoppedRetClosed     = "proxy: transport stopped, ret channel closed" | 
					
						
							|  |  |  | 	ErrTransportStoppedDuringRetWait = "proxy: transport stopped during ret wait" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | // limitTimeout is a utility function to auto-tune timeout values | 
					
						
							|  |  |  | // average observed time is moved towards the last observed delay moderated by a weight | 
					
						
							|  |  |  | // next timeout to use will be the double of the computed average, limited by min and max frame. | 
					
						
							|  |  |  | func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { | 
					
						
							|  |  |  | 	rt := time.Duration(atomic.LoadInt64(currentAvg)) | 
					
						
							|  |  |  | 	if rt < minValue { | 
					
						
							|  |  |  | 		return minValue | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if rt < maxValue/2 { | 
					
						
							|  |  |  | 		return 2 * rt | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return maxValue | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { | 
					
						
							|  |  |  | 	dt := time.Duration(atomic.LoadInt64(currentAvg)) | 
					
						
							|  |  |  | 	atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) dialTimeout() time.Duration { | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) updateDialTimeout(newDialTime time.Duration) { | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2018-04-16 19:51:49 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | func (t *Transport) Dial(proto string) (*persistConn, bool, error) { | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	// If tls has been configured; use it. | 
					
						
							|  |  |  | 	if t.tlsConfig != nil { | 
					
						
							|  |  |  | 		proto = "tcp-tls" | 
					
						
							| 
									
										
										
										
											2018-04-16 19:51:49 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-28 16:58:48 +03:00
										 |  |  | 	// Check if transport is stopped before attempting to dial | 
					
						
							|  |  |  | 	select { | 
					
						
							|  |  |  | 	case <-t.stop: | 
					
						
							|  |  |  | 		return nil, false, errors.New(ErrTransportStopped) | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-28 16:58:48 +03:00
										 |  |  | 	// Use select to avoid blocking if connManager has stopped | 
					
						
							|  |  |  | 	select { | 
					
						
							|  |  |  | 	case t.dial <- proto: | 
					
						
							|  |  |  | 		// Successfully sent dial request | 
					
						
							|  |  |  | 	case <-t.stop: | 
					
						
							|  |  |  | 		return nil, false, errors.New(ErrTransportStoppedDuringDial) | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-05-28 16:58:48 +03:00
										 |  |  | 	// Receive response with stop awareness | 
					
						
							|  |  |  | 	select { | 
					
						
							|  |  |  | 	case pc, ok := <-t.ret: | 
					
						
							|  |  |  | 		if !ok { | 
					
						
							|  |  |  | 			// ret channel was closed by connManager during stop | 
					
						
							|  |  |  | 			return nil, false, errors.New(ErrTransportStoppedRetClosed) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		if pc != nil { | 
					
						
							|  |  |  | 			connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) | 
					
						
							|  |  |  | 			return pc, true, nil | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		reqTime := time.Now() | 
					
						
							|  |  |  | 		timeout := t.dialTimeout() | 
					
						
							|  |  |  | 		if proto == "tcp-tls" { | 
					
						
							|  |  |  | 			conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) | 
					
						
							|  |  |  | 			t.updateDialTimeout(time.Since(reqTime)) | 
					
						
							|  |  |  | 			return &persistConn{c: conn}, false, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		conn, err := dns.DialTimeout(proto, t.addr, timeout) | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 		t.updateDialTimeout(time.Since(reqTime)) | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 		return &persistConn{c: conn}, false, err | 
					
						
							| 
									
										
										
										
											2025-05-28 16:58:48 +03:00
										 |  |  | 	case <-t.stop: | 
					
						
							|  |  |  | 		return nil, false, errors.New(ErrTransportStoppedDuringRetWait) | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-04 08:47:26 +03:00
										 |  |  | // Connect selects an upstream, sends the request and waits for a response. | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	start := time.Now() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 	proto := "" | 
					
						
							|  |  |  | 	switch { | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | 	case opts.ForceTCP: // TCP flag has precedence over UDP flag | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		proto = "tcp" | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | 	case opts.PreferUDP: | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 		proto = "udp" | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		proto = state.Proto() | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc, cached, err := p.transport.Dial(proto) | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	// Set buffer size correctly for this client. | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc.c.UDPSize = uint16(state.Size()) | 
					
						
							|  |  |  | 	if pc.c.UDPSize < 512 { | 
					
						
							|  |  |  | 		pc.c.UDPSize = 512 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	pc.c.SetWriteDeadline(time.Now().Add(maxTimeout)) | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 	// records the origin Id before upstream. | 
					
						
							|  |  |  | 	originId := state.Req.Id | 
					
						
							|  |  |  | 	state.Req.Id = dns.Id() | 
					
						
							| 
									
										
										
										
											2021-10-11 10:28:01 +00:00
										 |  |  | 	defer func() { | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 		state.Req.Id = originId | 
					
						
							|  |  |  | 	}() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	if err := pc.c.WriteMsg(state.Req); err != nil { | 
					
						
							|  |  |  | 		pc.c.Close() // not giving it back | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		if err == io.EOF && cached { | 
					
						
							| 
									
										
										
										
											2018-05-09 14:41:14 +03:00
										 |  |  | 			return nil, ErrCachedClosed | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 	var ret *dns.Msg | 
					
						
							| 
									
										
										
										
											2023-03-24 12:55:51 +00:00
										 |  |  | 	pc.c.SetReadDeadline(time.Now().Add(p.readTimeout)) | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 	for { | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 		ret, err = pc.c.ReadMsg() | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 			if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) { | 
					
						
							|  |  |  | 				// For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way. | 
					
						
							|  |  |  | 				// (e.g. sending >512 byte responses without an eDNS0 OPT RR). | 
					
						
							|  |  |  | 				// Instead of returning an error, return an empty response with TC bit set. This will make the | 
					
						
							|  |  |  | 				// client retry over TCP (if that's supported) or at least receive a clean | 
					
						
							|  |  |  | 				// error. The connection is still good so we break before the close. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 				// Truncate the response. | 
					
						
							|  |  |  | 				ret = truncateResponse(ret) | 
					
						
							|  |  |  | 				break | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-14 20:33:37 -04:00
										 |  |  | 			pc.c.Close() // not giving it back | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 			if err == io.EOF && cached { | 
					
						
							|  |  |  | 				return nil, ErrCachedClosed | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2023-08-14 20:33:37 -04:00
										 |  |  | 			// recovery the origin Id after upstream. | 
					
						
							| 
									
										
										
										
											2021-10-11 10:28:01 +00:00
										 |  |  | 			if ret != nil { | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 				ret.Id = originId | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2019-03-01 17:40:52 +03:00
										 |  |  | 			return ret, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		// drop out-of-order responses | 
					
						
							|  |  |  | 		if state.Req.Id == ret.Id { | 
					
						
							|  |  |  | 			break | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2021-10-09 01:34:22 +08:00
										 |  |  | 	// recovery the origin Id after upstream. | 
					
						
							|  |  |  | 	ret.Id = originId | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-10-01 16:39:42 +01:00
										 |  |  | 	p.transport.Yield(pc) | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 	rc, ok := dns.RcodeToString[ret.Rcode] | 
					
						
							|  |  |  | 	if !ok { | 
					
						
							|  |  |  | 		rc = strconv.Itoa(ret.Rcode) | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-04 15:35:55 +01:00
										 |  |  | 	requestDuration.WithLabelValues(p.proxyName, p.addr, rc).Observe(time.Since(start).Seconds()) | 
					
						
							| 
									
										
										
										
											2018-07-07 10:14:21 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	return ret, nil | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | const cumulativeAvgWeight = 4 | 
					
						
							| 
									
										
										
										
											2023-09-07 12:01:45 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | // Function to determine if a response should be truncated. | 
					
						
							|  |  |  | func shouldTruncateResponse(err error) bool { | 
					
						
							|  |  |  | 	// This is to handle a scenario in which upstream sets the TC bit, but doesn't truncate the response | 
					
						
							|  |  |  | 	// and we get ErrBuf instead of overflow. | 
					
						
							|  |  |  | 	if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) { | 
					
						
							|  |  |  | 		return true | 
					
						
							|  |  |  | 	} else if strings.Contains(err.Error(), "overflow") { | 
					
						
							|  |  |  | 		return true | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return false | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Function to return an empty response with TC (truncated) bit set. | 
					
						
							|  |  |  | func truncateResponse(response *dns.Msg) *dns.Msg { | 
					
						
							|  |  |  | 	// Clear out Answer, Extra, and Ns sections | 
					
						
							|  |  |  | 	response.Answer = nil | 
					
						
							|  |  |  | 	response.Extra = nil | 
					
						
							|  |  |  | 	response.Ns = nil | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Set TC bit to indicate truncation. | 
					
						
							|  |  |  | 	response.Truncated = true | 
					
						
							|  |  |  | 	return response | 
					
						
							|  |  |  | } |