mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-30 17:53:21 -04:00 
			
		
		
		
	plugin/forward: added support for per-nameserver TLS SNI (#7633)
This commit is contained in:
		| @@ -78,11 +78,16 @@ forward FROM TO... { | ||||
|     The server certificate is verified using the specified CA file | ||||
|  | ||||
| * `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9 | ||||
|   needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario, | ||||
|   but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (QuadDNS) with 1.1.1.1 | ||||
|   (Cloudflare) will not work. Using TLS forwarding but not setting `tls_servername` results in anyone | ||||
|   needs this to be set to `dns.quad9.net`. Using TLS forwarding but not setting `tls_servername` results in anyone | ||||
|   being able to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this, | ||||
|   it is strongly recommended to set this value when using TLS forwarding. | ||||
|  | ||||
|   Per destination endpoint TLS server name indication is possible in the form of `tls://9.9.9.9%dns.quad9.net`. | ||||
|   `tls_servername` must not be specified when using per destination endpoint TLS server name indication | ||||
|   as it would introduce clash between the server name indication spectifications. If destination endpoint | ||||
|   is to be reached via a port other than 853 then the port must be appended to the end of the destination | ||||
|   endpoint specifier. In case of port 10853, the above string would be: `tls://9.9.9.9%dns.quad9.net:10853`. | ||||
|  | ||||
| * `policy` specifies the policy to use for selecting upstream servers. The default is `random`. | ||||
|   * `random` is a policy that implements random upstream selection. | ||||
|   * `round_robin` is a policy that selects hosts based on round robin ordering. | ||||
|   | ||||
| @@ -97,6 +97,22 @@ func parseForward(c *caddy.Controller) ([]*Forward, error) { | ||||
| 	return fs, nil | ||||
| } | ||||
|  | ||||
| // Splits the zone, preserving any port that comes after the zone | ||||
| func splitZone(host string) (newHost string, zone string) { | ||||
| 	newHost = host | ||||
| 	if strings.Contains(host, "%") { | ||||
| 		lastPercent := strings.LastIndex(host, "%") | ||||
| 		newHost = host[:lastPercent] | ||||
| 		zone = host[lastPercent+1:] | ||||
| 		if strings.Contains(zone, ":") { | ||||
| 			lastColon := strings.LastIndex(zone, ":") | ||||
| 			newHost += zone[lastColon:] | ||||
| 			zone = zone[:lastColon] | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func parseStanza(c *caddy.Controller) (*Forward, error) { | ||||
| 	f := New() | ||||
|  | ||||
| @@ -124,27 +140,46 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { | ||||
| 		return f, err | ||||
| 	} | ||||
|  | ||||
| 	transports := make([]string, len(toHosts)) | ||||
| 	allowedTrans := map[string]bool{"dns": true, "tls": true} | ||||
| 	for i, host := range toHosts { | ||||
| 		trans, h := parse.Transport(host) | ||||
|  | ||||
| 		if !allowedTrans[trans] { | ||||
| 			return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) | ||||
| 		} | ||||
| 		p := proxy.NewProxy("forward", h, trans) | ||||
| 		f.proxies = append(f.proxies, p) | ||||
| 		transports[i] = trans | ||||
| 	} | ||||
|  | ||||
| 	for c.NextBlock() { | ||||
| 		if err := parseBlock(c, f); err != nil { | ||||
| 			return f, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tlsServerNames := make([]string, len(toHosts)) | ||||
| 	perServerNameProxyCount := make(map[string]int) | ||||
| 	transports := make([]string, len(toHosts)) | ||||
| 	allowedTrans := map[string]bool{"dns": true, "tls": true} | ||||
| 	for i, hostWithZone := range toHosts { | ||||
| 		host, serverName := splitZone(hostWithZone) | ||||
| 		trans, h := parse.Transport(host) | ||||
|  | ||||
| 		if !allowedTrans[trans] { | ||||
| 			return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) | ||||
| 		} | ||||
| 		if trans == transport.TLS && serverName != "" { | ||||
| 			if f.tlsServerName != "" { | ||||
| 				return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host) | ||||
| 			} | ||||
|  | ||||
| 			tlsServerNames[i] = serverName | ||||
| 			perServerNameProxyCount[serverName]++ | ||||
| 		} | ||||
| 		p := proxy.NewProxy("forward", h, trans) | ||||
| 		f.proxies = append(f.proxies, p) | ||||
| 		transports[i] = trans | ||||
| 	} | ||||
|  | ||||
| 	perServerNameTlsConfig := make(map[string]*tls.Config) | ||||
| 	if f.tlsServerName != "" { | ||||
| 		f.tlsConfig.ServerName = f.tlsServerName | ||||
| 	} else { | ||||
| 		for serverName, proxyCount := range perServerNameProxyCount { | ||||
| 			tlsConfig := f.tlsConfig.Clone() | ||||
| 			tlsConfig.ServerName = serverName | ||||
| 			tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount) | ||||
| 			perServerNameTlsConfig[serverName] = tlsConfig | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake | ||||
| @@ -154,7 +189,11 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { | ||||
| 	for i := range f.proxies { | ||||
| 		// Only set this for proxies that need it. | ||||
| 		if transports[i] == transport.TLS { | ||||
| 			f.proxies[i].SetTLSConfig(f.tlsConfig) | ||||
| 			if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok { | ||||
| 				f.proxies[i].SetTLSConfig(tlsConfig) | ||||
| 			} else { | ||||
| 				f.proxies[i].SetTLSConfig(f.tlsConfig) | ||||
| 			} | ||||
| 		} | ||||
| 		f.proxies[i].SetExpire(f.expire) | ||||
| 		f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) | ||||
|   | ||||
| @@ -90,6 +90,36 @@ func TestSetup(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSplitZone(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input        string | ||||
| 		expectedHost string | ||||
| 		expectedZone string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"tls://127.0.0.1%example.net:854", "tls://127.0.0.1:854", "example.net", | ||||
| 		}, { | ||||
| 			"tls://127.0.0.1%example.net", "tls://127.0.0.1", "example.net", | ||||
| 		}, { | ||||
| 			"tls://127.0.0.1:854", "tls://127.0.0.1:854", "", | ||||
| 		}, { | ||||
| 			"dns://127.0.0.1", "dns://127.0.0.1", "", | ||||
| 		}, { | ||||
| 			"foo%bar:baz", "foo:baz", "bar", | ||||
| 		}, | ||||
| 	} | ||||
| 	for i, test := range tests { | ||||
| 		host, zone := splitZone(test.input) | ||||
|  | ||||
| 		if host != test.expectedHost { | ||||
| 			t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host) | ||||
| 		} | ||||
| 		if zone != test.expectedZone { | ||||
| 			t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSetupTLS(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input              string | ||||
| @@ -101,6 +131,19 @@ func TestSetupTLS(t *testing.T) { | ||||
| 		{`forward . tls://127.0.0.1 { | ||||
| 				tls_servername dns | ||||
| 			}`, false, "dns", ""}, | ||||
| 		{`forward . tls://127.0.0.1%example.net { | ||||
| 				tls | ||||
| 			}`, false, "example.net", ""}, | ||||
| 		{`forward . tls://127.0.0.1%example.net:854 tls://127.0.0.2%example.net tls://fe80::1%example.com { | ||||
| 				tls | ||||
| 			}`, false, "example.net", ""}, | ||||
| 		{`forward . tls://127.0.0.1%example.net:854 { | ||||
| 				tls | ||||
| 			}`, false, "example.net", ""}, | ||||
| 		// SNI specifications clash test | ||||
| 		{`forward . tls://127.0.0.1%example.net:854 { | ||||
| 				tls_servername foo | ||||
| 			}`, true, "", "both forward ('foo') and proxy level ('example.net') TLS servernames are set for upstream proxy 'tls://127.0.0.1:854'"}, | ||||
| 		{`forward . 127.0.0.1 { | ||||
| 				tls_servername dns | ||||
| 			}`, false, "", ""}, | ||||
| @@ -126,16 +169,48 @@ func TestSetupTLS(t *testing.T) { | ||||
| 			if !strings.Contains(err.Error(), test.expectedErr) { | ||||
| 				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		/* | ||||
| 			if len(fs) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
| 		*/ | ||||
|  | ||||
| 		f := fs[0] | ||||
|  | ||||
| 		if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName { | ||||
| 			t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) | ||||
| 		if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetTransport().GetTLSConfig().ServerName { | ||||
| 			t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetTransport().GetTLSConfig().ServerName) | ||||
| 		} | ||||
|  | ||||
| 		if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName { | ||||
| 			t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName) | ||||
| 			t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSetupTLSclientSessionCacheCount(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input string | ||||
| 	}{ | ||||
| 		{`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%foo tls://127.0.0.4%bar tls://127.0.0.5%bar { }`}, | ||||
| 		{`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%bar tls://127.0.0.4%bar tls://127.0.0.5%bar { }`}, | ||||
| 	} | ||||
| 	for i, test := range tests { | ||||
| 		c := caddy.NewTestController("dns", test.input) | ||||
| 		fs, err := parseForward(c) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) | ||||
| 		} | ||||
|  | ||||
| 		if fs[0].proxies[0].GetTransport().GetTLSConfig() == fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() { | ||||
| 			t.Errorf("Test %d: tlsConfig is the same for both the first and last proxies", i) | ||||
| 		} | ||||
| 		if fs[0].proxies[0].GetTransport().GetTLSConfig() != fs[0].proxies[1].GetTransport().GetTLSConfig() { | ||||
| 			t.Errorf("Test %d: tlsConfig differs for the first two proxies", i) | ||||
| 		} | ||||
| 		if fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() != fs[0].proxies[len(fs[0].proxies)-2].GetTransport().GetTLSConfig() { | ||||
| 			t.Errorf("Test %d: tlsConfig differs for the last two proxies", i) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -473,13 +548,13 @@ func TestFailover(t *testing.T) { | ||||
| 				}`, s.Addr, server_fail_s.Addr, server_refused_s.Addr), true, "Although failover is not set, as long as the first upstream is work, there should be has a record return"}, | ||||
| 	} | ||||
|  | ||||
| 	for _, testCase := range tests { | ||||
| 	for i, testCase := range tests { | ||||
| 		c := caddy.NewTestController("dns", testCase.input) | ||||
| 		fs, err := parseForward(c) | ||||
|  | ||||
| 		f := fs[0] | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Failed to create forwarder: %s", err) | ||||
| 			t.Errorf("Test #%d: Failed to create forwarder: %s", i, err) | ||||
| 		} | ||||
| 		f.OnStartup() | ||||
| 		defer f.OnShutdown() | ||||
| @@ -495,11 +570,11 @@ func TestFailover(t *testing.T) { | ||||
| 		rec := dnstest.NewRecorder(&test.ResponseWriter{}) | ||||
|  | ||||
| 		if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil { | ||||
| 			t.Fatal("Expected to receive reply, but didn't") | ||||
| 			t.Fatalf("Test #%d: Expected to receive reply, but didn't", i) | ||||
| 		} | ||||
|  | ||||
| 		if (len(rec.Msg.Answer) > 0) != testCase.hasRecord { | ||||
| 			t.Errorf(" %s: \n %s", testCase.failMsg, testCase.input) | ||||
| 			t.Errorf("Test #%d: %s: \n %s", i, testCase.failMsg, testCase.input) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -151,6 +151,9 @@ func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } | ||||
| // SetTLSConfig sets the TLS config in transport. | ||||
| func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } | ||||
|  | ||||
| // GetTLSConfig returns the TLS config in transport. | ||||
| func (t *Transport) GetTLSConfig() *tls.Config { return t.tlsConfig } | ||||
|  | ||||
| const ( | ||||
| 	defaultExpire  = 10 * time.Second | ||||
| 	minDialTimeout = 1 * time.Second | ||||
|   | ||||
| @@ -56,6 +56,10 @@ func (p *Proxy) GetHealthchecker() HealthChecker { | ||||
| 	return p.health | ||||
| } | ||||
|  | ||||
| func (p *Proxy) GetTransport() *Transport { | ||||
| 	return p.transport | ||||
| } | ||||
|  | ||||
| func (p *Proxy) Fails() uint32 { | ||||
| 	return atomic.LoadUint32(&p.fails) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user