mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 18:23:13 -04:00 
			
		
		
		
	plugin/forward: added support for per-nameserver TLS SNI (#7633)
This commit is contained in:
		| @@ -97,6 +97,22 @@ func parseForward(c *caddy.Controller) ([]*Forward, error) { | ||||
| 	return fs, nil | ||||
| } | ||||
|  | ||||
| // Splits the zone, preserving any port that comes after the zone | ||||
| func splitZone(host string) (newHost string, zone string) { | ||||
| 	newHost = host | ||||
| 	if strings.Contains(host, "%") { | ||||
| 		lastPercent := strings.LastIndex(host, "%") | ||||
| 		newHost = host[:lastPercent] | ||||
| 		zone = host[lastPercent+1:] | ||||
| 		if strings.Contains(zone, ":") { | ||||
| 			lastColon := strings.LastIndex(zone, ":") | ||||
| 			newHost += zone[lastColon:] | ||||
| 			zone = zone[:lastColon] | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func parseStanza(c *caddy.Controller) (*Forward, error) { | ||||
| 	f := New() | ||||
|  | ||||
| @@ -124,27 +140,46 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { | ||||
| 		return f, err | ||||
| 	} | ||||
|  | ||||
| 	transports := make([]string, len(toHosts)) | ||||
| 	allowedTrans := map[string]bool{"dns": true, "tls": true} | ||||
| 	for i, host := range toHosts { | ||||
| 		trans, h := parse.Transport(host) | ||||
|  | ||||
| 		if !allowedTrans[trans] { | ||||
| 			return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) | ||||
| 		} | ||||
| 		p := proxy.NewProxy("forward", h, trans) | ||||
| 		f.proxies = append(f.proxies, p) | ||||
| 		transports[i] = trans | ||||
| 	} | ||||
|  | ||||
| 	for c.NextBlock() { | ||||
| 		if err := parseBlock(c, f); err != nil { | ||||
| 			return f, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tlsServerNames := make([]string, len(toHosts)) | ||||
| 	perServerNameProxyCount := make(map[string]int) | ||||
| 	transports := make([]string, len(toHosts)) | ||||
| 	allowedTrans := map[string]bool{"dns": true, "tls": true} | ||||
| 	for i, hostWithZone := range toHosts { | ||||
| 		host, serverName := splitZone(hostWithZone) | ||||
| 		trans, h := parse.Transport(host) | ||||
|  | ||||
| 		if !allowedTrans[trans] { | ||||
| 			return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) | ||||
| 		} | ||||
| 		if trans == transport.TLS && serverName != "" { | ||||
| 			if f.tlsServerName != "" { | ||||
| 				return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host) | ||||
| 			} | ||||
|  | ||||
| 			tlsServerNames[i] = serverName | ||||
| 			perServerNameProxyCount[serverName]++ | ||||
| 		} | ||||
| 		p := proxy.NewProxy("forward", h, trans) | ||||
| 		f.proxies = append(f.proxies, p) | ||||
| 		transports[i] = trans | ||||
| 	} | ||||
|  | ||||
| 	perServerNameTlsConfig := make(map[string]*tls.Config) | ||||
| 	if f.tlsServerName != "" { | ||||
| 		f.tlsConfig.ServerName = f.tlsServerName | ||||
| 	} else { | ||||
| 		for serverName, proxyCount := range perServerNameProxyCount { | ||||
| 			tlsConfig := f.tlsConfig.Clone() | ||||
| 			tlsConfig.ServerName = serverName | ||||
| 			tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount) | ||||
| 			perServerNameTlsConfig[serverName] = tlsConfig | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake | ||||
| @@ -154,7 +189,11 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { | ||||
| 	for i := range f.proxies { | ||||
| 		// Only set this for proxies that need it. | ||||
| 		if transports[i] == transport.TLS { | ||||
| 			f.proxies[i].SetTLSConfig(f.tlsConfig) | ||||
| 			if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok { | ||||
| 				f.proxies[i].SetTLSConfig(tlsConfig) | ||||
| 			} else { | ||||
| 				f.proxies[i].SetTLSConfig(f.tlsConfig) | ||||
| 			} | ||||
| 		} | ||||
| 		f.proxies[i].SetExpire(f.expire) | ||||
| 		f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user