plugin/forward: added support for per-nameserver TLS SNI (#7633)

This commit is contained in:
Endre Szabo
2025-10-27 16:43:30 +01:00
committed by GitHub
parent b72d267a29
commit d68cbedbb1
5 changed files with 150 additions and 24 deletions

View File

@@ -97,6 +97,22 @@ func parseForward(c *caddy.Controller) ([]*Forward, error) {
return fs, nil
}
// Splits the zone, preserving any port that comes after the zone
func splitZone(host string) (newHost string, zone string) {
newHost = host
if strings.Contains(host, "%") {
lastPercent := strings.LastIndex(host, "%")
newHost = host[:lastPercent]
zone = host[lastPercent+1:]
if strings.Contains(zone, ":") {
lastColon := strings.LastIndex(zone, ":")
newHost += zone[lastColon:]
zone = zone[:lastColon]
}
}
return
}
func parseStanza(c *caddy.Controller) (*Forward, error) {
f := New()
@@ -124,27 +140,46 @@ func parseStanza(c *caddy.Controller) (*Forward, error) {
return f, err
}
transports := make([]string, len(toHosts))
allowedTrans := map[string]bool{"dns": true, "tls": true}
for i, host := range toHosts {
trans, h := parse.Transport(host)
if !allowedTrans[trans] {
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
}
p := proxy.NewProxy("forward", h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
for c.NextBlock() {
if err := parseBlock(c, f); err != nil {
return f, err
}
}
tlsServerNames := make([]string, len(toHosts))
perServerNameProxyCount := make(map[string]int)
transports := make([]string, len(toHosts))
allowedTrans := map[string]bool{"dns": true, "tls": true}
for i, hostWithZone := range toHosts {
host, serverName := splitZone(hostWithZone)
trans, h := parse.Transport(host)
if !allowedTrans[trans] {
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
}
if trans == transport.TLS && serverName != "" {
if f.tlsServerName != "" {
return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host)
}
tlsServerNames[i] = serverName
perServerNameProxyCount[serverName]++
}
p := proxy.NewProxy("forward", h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
perServerNameTlsConfig := make(map[string]*tls.Config)
if f.tlsServerName != "" {
f.tlsConfig.ServerName = f.tlsServerName
} else {
for serverName, proxyCount := range perServerNameProxyCount {
tlsConfig := f.tlsConfig.Clone()
tlsConfig.ServerName = serverName
tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount)
perServerNameTlsConfig[serverName] = tlsConfig
}
}
// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
@@ -154,7 +189,11 @@ func parseStanza(c *caddy.Controller) (*Forward, error) {
for i := range f.proxies {
// Only set this for proxies that need it.
if transports[i] == transport.TLS {
f.proxies[i].SetTLSConfig(f.tlsConfig)
if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok {
f.proxies[i].SetTLSConfig(tlsConfig)
} else {
f.proxies[i].SetTLSConfig(f.tlsConfig)
}
}
f.proxies[i].SetExpire(f.expire)
f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired)