mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-27 08:14:18 -04:00 
			
		
		
		
	plugin/forward: add prefer_udp option (#1944)
* plugin/forward: add prefer_udp option * updated according to code review - fixed linter warning - removed metric parameter in Proxy.Connect()
This commit is contained in:
		
				
					committed by
					
						 Miek Gieben
						Miek Gieben
					
				
			
			
				
	
			
			
			
						parent
						
							7c41f2ce9f
						
					
				
				
					commit
					bc50901234
				
			| @@ -47,6 +47,7 @@ Extra knobs are available with an expanded syntax: | ||||
| forward FROM TO... { | ||||
|     except IGNORED_NAMES... | ||||
|     force_tcp | ||||
|     prefer_udp | ||||
|     expire DURATION | ||||
|     max_fails INTEGER | ||||
|     tls CERT KEY CA | ||||
| @@ -60,6 +61,9 @@ forward FROM TO... { | ||||
| * **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding. | ||||
|   Requests that match none of these names will be passed through. | ||||
| * `force_tcp`, use TCP even when the request comes in over UDP. | ||||
| * `prefer_udp`, try first using UDP even when the request comes in over TCP. If response is truncated | ||||
|   (TC flag set in response) then do another attempt over TCP. In case if both `force_tcp` and `prefer_udp` | ||||
|   options specified the `force_tcp` takes precedence. | ||||
| * `max_fails` is the number of subsequent failed health checks that are needed before considering | ||||
|   an upstream to be down. If 0, the upstream will never be marked as down (nor health checked). | ||||
|   Default is 2. | ||||
|   | ||||
| @@ -78,12 +78,17 @@ func (p *Proxy) updateRtt(newRtt time.Duration) { | ||||
| } | ||||
|  | ||||
| // Connect selects an upstream, sends the request and waits for a response. | ||||
| func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { | ||||
| func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options) (*dns.Msg, error) { | ||||
| 	start := time.Now() | ||||
|  | ||||
| 	proto := state.Proto() | ||||
| 	if forceTCP { | ||||
| 	proto := "" | ||||
| 	switch { | ||||
| 	case opts.forceTCP: // TCP flag has precedence over UDP flag | ||||
| 		proto = "tcp" | ||||
| 	case opts.preferUDP: | ||||
| 		proto = "udp" | ||||
| 	default: | ||||
| 		proto = state.Proto() | ||||
| 	} | ||||
|  | ||||
| 	conn, cached, err := p.Dial(proto) | ||||
| @@ -122,17 +127,15 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, me | ||||
|  | ||||
| 	p.Yield(conn) | ||||
|  | ||||
| 	if metric { | ||||
| 		rc, ok := dns.RcodeToString[ret.Rcode] | ||||
| 		if !ok { | ||||
| 			rc = strconv.Itoa(ret.Rcode) | ||||
| 		} | ||||
|  | ||||
| 		RequestCount.WithLabelValues(p.addr).Add(1) | ||||
| 		RcodeCount.WithLabelValues(rc, p.addr).Add(1) | ||||
| 		RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds()) | ||||
| 	rc, ok := dns.RcodeToString[ret.Rcode] | ||||
| 	if !ok { | ||||
| 		rc = strconv.Itoa(ret.Rcode) | ||||
| 	} | ||||
|  | ||||
| 	RequestCount.WithLabelValues(p.addr).Add(1) | ||||
| 	RcodeCount.WithLabelValues(rc, p.addr).Add(1) | ||||
| 	RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds()) | ||||
|  | ||||
| 	return ret, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -33,7 +33,7 @@ type Forward struct { | ||||
| 	maxfails      uint32 | ||||
| 	expire        time.Duration | ||||
|  | ||||
| 	forceTCP bool // also here for testing | ||||
| 	opts options // also here for testing | ||||
|  | ||||
| 	Next plugin.Handler | ||||
| } | ||||
| @@ -103,9 +103,18 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg | ||||
| 			ret *dns.Msg | ||||
| 			err error | ||||
| 		) | ||||
| 		opts := f.opts | ||||
| 		for { | ||||
| 			ret, err = proxy.Connect(ctx, state, f.forceTCP, true) | ||||
| 			if err != nil && err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. | ||||
| 			ret, err = proxy.Connect(ctx, state, opts) | ||||
| 			if err == nil { | ||||
| 				break | ||||
| 			} | ||||
| 			if err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. | ||||
| 				continue | ||||
| 			} | ||||
| 			// Retry with TCP if truncated and prefer_udp configured | ||||
| 			if err == dns.ErrTruncated && !opts.forceTCP && f.opts.preferUDP { | ||||
| 				opts.forceTCP = true | ||||
| 				continue | ||||
| 			} | ||||
| 			break | ||||
| @@ -183,7 +192,10 @@ func (f *Forward) isAllowedDomain(name string) bool { | ||||
| func (f *Forward) From() string { return f.from } | ||||
|  | ||||
| // ForceTCP returns if TCP is forced to be used even when the request comes in over UDP. | ||||
| func (f *Forward) ForceTCP() bool { return f.forceTCP } | ||||
| func (f *Forward) ForceTCP() bool { return f.opts.forceTCP } | ||||
|  | ||||
| // PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP. | ||||
| func (f *Forward) PreferUDP() bool { return f.opts.preferUDP } | ||||
|  | ||||
| // List returns a set of proxies to be used for this client depending on the policy in f. | ||||
| func (f *Forward) List() []*Proxy { return f.p.List(f.proxies) } | ||||
| @@ -206,4 +218,9 @@ const ( | ||||
| 	sequentialPolicy | ||||
| ) | ||||
|  | ||||
| type options struct { | ||||
| 	forceTCP  bool | ||||
| 	preferUDP bool | ||||
| } | ||||
|  | ||||
| const defaultTimeout = 5 * time.Second | ||||
|   | ||||
| @@ -32,7 +32,7 @@ func (f *Forward) Forward(state request.Request) (*dns.Msg, error) { | ||||
| 			proxy = f.List()[0] | ||||
| 		} | ||||
|  | ||||
| 		ret, err := proxy.Connect(context.Background(), state, f.forceTCP, true) | ||||
| 		ret, err := proxy.Connect(context.Background(), state, f.opts) | ||||
|  | ||||
| 		ret, err = truncated(state, ret, err) | ||||
| 		upstreamErr = err | ||||
|   | ||||
| @@ -29,10 +29,10 @@ func TestProxyClose(t *testing.T) { | ||||
| 		p := NewProxy(s.Addr, nil) | ||||
| 		p.start(hcInterval) | ||||
|  | ||||
| 		go func() { p.Connect(ctx, state, false, false) }() | ||||
| 		go func() { p.Connect(ctx, state, true, false) }() | ||||
| 		go func() { p.Connect(ctx, state, false, false) }() | ||||
| 		go func() { p.Connect(ctx, state, true, false) }() | ||||
| 		go func() { p.Connect(ctx, state, options{}) }() | ||||
| 		go func() { p.Connect(ctx, state, options{forceTCP: true}) }() | ||||
| 		go func() { p.Connect(ctx, state, options{}) }() | ||||
| 		go func() { p.Connect(ctx, state, options{forceTCP: true}) }() | ||||
|  | ||||
| 		p.close() | ||||
| 	} | ||||
| @@ -93,3 +93,30 @@ func TestProxyTLSFail(t *testing.T) { | ||||
| 		t.Fatal("Expected *not* to receive reply, but got one") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestProtocolSelection(t *testing.T) { | ||||
| 	p := NewProxy("bad_address", nil) | ||||
|  | ||||
| 	stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} | ||||
| 	stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} | ||||
| 	ctx := context.TODO() | ||||
|  | ||||
| 	go func() { | ||||
| 		p.Connect(ctx, stateUDP, options{}) | ||||
| 		p.Connect(ctx, stateUDP, options{forceTCP: true}) | ||||
| 		p.Connect(ctx, stateUDP, options{preferUDP: true}) | ||||
| 		p.Connect(ctx, stateUDP, options{preferUDP: true, forceTCP: true}) | ||||
| 		p.Connect(ctx, stateTCP, options{}) | ||||
| 		p.Connect(ctx, stateTCP, options{forceTCP: true}) | ||||
| 		p.Connect(ctx, stateTCP, options{preferUDP: true}) | ||||
| 		p.Connect(ctx, stateTCP, options{preferUDP: true, forceTCP: true}) | ||||
| 	}() | ||||
|  | ||||
| 	for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { | ||||
| 		proto := <-p.transport.dial | ||||
| 		p.transport.ret <- nil | ||||
| 		if proto != exp { | ||||
| 			t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -187,7 +187,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { | ||||
| 		if c.NextArg() { | ||||
| 			return c.ArgErr() | ||||
| 		} | ||||
| 		f.forceTCP = true | ||||
| 		f.opts.forceTCP = true | ||||
| 	case "prefer_udp": | ||||
| 		if c.NextArg() { | ||||
| 			return c.ArgErr() | ||||
| 		} | ||||
| 		f.opts.preferUDP = true | ||||
| 	case "tls": | ||||
| 		args := c.RemainingArgs() | ||||
| 		if len(args) > 3 { | ||||
|   | ||||
| @@ -10,28 +10,30 @@ import ( | ||||
|  | ||||
| func TestSetup(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input            string | ||||
| 		shouldErr        bool | ||||
| 		expectedFrom     string | ||||
| 		expectedIgnored  []string | ||||
| 		expectedFails    uint32 | ||||
| 		expectedForceTCP bool | ||||
| 		expectedErr      string | ||||
| 		input           string | ||||
| 		shouldErr       bool | ||||
| 		expectedFrom    string | ||||
| 		expectedIgnored []string | ||||
| 		expectedFails   uint32 | ||||
| 		expectedOpts    options | ||||
| 		expectedErr     string | ||||
| 	}{ | ||||
| 		// positive | ||||
| 		{"forward . 127.0.0.1", false, ".", nil, 2, false, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""}, | ||||
| 		{"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""}, | ||||
| 		{"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""}, | ||||
| 		{"forward . [::1]:53", false, ".", nil, 2, false, ""}, | ||||
| 		{"forward . [2003::1]:53", false, ".", nil, 2, false, ""}, | ||||
| 		{"forward . 127.0.0.1", false, ".", nil, 2, options{}, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, options{}, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, options{}, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, options{forceTCP: true}, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true}, ""}, | ||||
| 		{"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, forceTCP: true}, ""}, | ||||
| 		{"forward . 127.0.0.1:53", false, ".", nil, 2, options{}, ""}, | ||||
| 		{"forward . 127.0.0.1:8080", false, ".", nil, 2, options{}, ""}, | ||||
| 		{"forward . [::1]:53", false, ".", nil, 2, options{}, ""}, | ||||
| 		{"forward . [2003::1]:53", false, ".", nil, 2, options{}, ""}, | ||||
| 		// negative | ||||
| 		{"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"}, | ||||
| 		{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"}, | ||||
| 		{"forward . a27.0.0.1", true, "", nil, 0, options{}, "not an IP"}, | ||||
| 		{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, options{}, "unknown property"}, | ||||
| 		{`forward . ::1 | ||||
| 		forward com ::2`, true, "", nil, 0, false, "plugin"}, | ||||
| 		forward com ::2`, true, "", nil, 0, options{}, "plugin"}, | ||||
| 	} | ||||
|  | ||||
| 	for i, test := range tests { | ||||
| @@ -63,8 +65,8 @@ func TestSetup(t *testing.T) { | ||||
| 		if !test.shouldErr && f.maxfails != test.expectedFails { | ||||
| 			t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails) | ||||
| 		} | ||||
| 		if !test.shouldErr && f.forceTCP != test.expectedForceTCP { | ||||
| 			t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP) | ||||
| 		if !test.shouldErr && f.opts != test.expectedOpts { | ||||
| 			t.Errorf("Test %d: expected: %v, got: %v", i, test.expectedOpts, f.opts) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -9,12 +9,17 @@ import ( | ||||
| // ResponseWriter is useful for writing tests. It uses some fixed values for the client. The | ||||
| // remote will always be 10.240.0.1 and port 40212. The local address is always 127.0.0.1 and | ||||
| // port 53. | ||||
| type ResponseWriter struct{} | ||||
| type ResponseWriter struct { | ||||
| 	TCP bool | ||||
| } | ||||
|  | ||||
| // LocalAddr returns the local address, always 127.0.0.1:53 (UDP). | ||||
| func (t *ResponseWriter) LocalAddr() net.Addr { | ||||
| 	ip := net.ParseIP("127.0.0.1") | ||||
| 	port := 53 | ||||
| 	if t.TCP { | ||||
| 		return &net.TCPAddr{IP: ip, Port: port, Zone: ""} | ||||
| 	} | ||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||
| } | ||||
|  | ||||
| @@ -22,6 +27,9 @@ func (t *ResponseWriter) LocalAddr() net.Addr { | ||||
| func (t *ResponseWriter) RemoteAddr() net.Addr { | ||||
| 	ip := net.ParseIP("10.240.0.1") | ||||
| 	port := 40212 | ||||
| 	if t.TCP { | ||||
| 		return &net.TCPAddr{IP: ip, Port: port, Zone: ""} | ||||
| 	} | ||||
| 	return &net.UDPAddr{IP: ip, Port: port, Zone: ""} | ||||
| } | ||||
|  | ||||
| @@ -52,10 +60,16 @@ type ResponseWriter6 struct { | ||||
|  | ||||
| // LocalAddr returns the local address, always ::1, port 53 (UDP). | ||||
| func (t *ResponseWriter6) LocalAddr() net.Addr { | ||||
| 	if t.TCP { | ||||
| 		return &net.TCPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} | ||||
| 	} | ||||
| 	return &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} | ||||
| } | ||||
|  | ||||
| // RemoteAddr returns the remote address, always fe80::42:ff:feca:4c65 port 40212 (UDP). | ||||
| func (t *ResponseWriter6) RemoteAddr() net.Addr { | ||||
| 	if t.TCP { | ||||
| 		return &net.TCPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} | ||||
| 	} | ||||
| 	return &net.UDPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user