From d68cbedbb117097568885eb94c8c80283293bb6e Mon Sep 17 00:00:00 2001 From: Endre Szabo Date: Mon, 27 Oct 2025 16:43:30 +0100 Subject: [PATCH] plugin/forward: added support for per-nameserver TLS SNI (#7633) --- plugin/forward/README.md | 11 +++-- plugin/forward/setup.go | 67 +++++++++++++++++++------ plugin/forward/setup_test.go | 89 +++++++++++++++++++++++++++++++--- plugin/pkg/proxy/persistent.go | 3 ++ plugin/pkg/proxy/proxy.go | 4 ++ 5 files changed, 150 insertions(+), 24 deletions(-) diff --git a/plugin/forward/README.md b/plugin/forward/README.md index 1cc01d43e..0eea32c89 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -78,11 +78,16 @@ forward FROM TO... { The server certificate is verified using the specified CA file * `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9 - needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario, - but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (QuadDNS) with 1.1.1.1 - (Cloudflare) will not work. Using TLS forwarding but not setting `tls_servername` results in anyone + needs this to be set to `dns.quad9.net`. Using TLS forwarding but not setting `tls_servername` results in anyone being able to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this, it is strongly recommended to set this value when using TLS forwarding. + + Per destination endpoint TLS server name indication is possible in the form of `tls://9.9.9.9%dns.quad9.net`. + `tls_servername` must not be specified when using per destination endpoint TLS server name indication + as it would introduce clash between the server name indication spectifications. If destination endpoint + is to be reached via a port other than 853 then the port must be appended to the end of the destination + endpoint specifier. In case of port 10853, the above string would be: `tls://9.9.9.9%dns.quad9.net:10853`. + * `policy` specifies the policy to use for selecting upstream servers. The default is `random`. * `random` is a policy that implements random upstream selection. * `round_robin` is a policy that selects hosts based on round robin ordering. diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 0aca2f4c4..6822e8a5d 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -97,6 +97,22 @@ func parseForward(c *caddy.Controller) ([]*Forward, error) { return fs, nil } +// Splits the zone, preserving any port that comes after the zone +func splitZone(host string) (newHost string, zone string) { + newHost = host + if strings.Contains(host, "%") { + lastPercent := strings.LastIndex(host, "%") + newHost = host[:lastPercent] + zone = host[lastPercent+1:] + if strings.Contains(zone, ":") { + lastColon := strings.LastIndex(zone, ":") + newHost += zone[lastColon:] + zone = zone[:lastColon] + } + } + return +} + func parseStanza(c *caddy.Controller) (*Forward, error) { f := New() @@ -124,27 +140,46 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { return f, err } - transports := make([]string, len(toHosts)) - allowedTrans := map[string]bool{"dns": true, "tls": true} - for i, host := range toHosts { - trans, h := parse.Transport(host) - - if !allowedTrans[trans] { - return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) - } - p := proxy.NewProxy("forward", h, trans) - f.proxies = append(f.proxies, p) - transports[i] = trans - } - for c.NextBlock() { if err := parseBlock(c, f); err != nil { return f, err } } + tlsServerNames := make([]string, len(toHosts)) + perServerNameProxyCount := make(map[string]int) + transports := make([]string, len(toHosts)) + allowedTrans := map[string]bool{"dns": true, "tls": true} + for i, hostWithZone := range toHosts { + host, serverName := splitZone(hostWithZone) + trans, h := parse.Transport(host) + + if !allowedTrans[trans] { + return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) + } + if trans == transport.TLS && serverName != "" { + if f.tlsServerName != "" { + return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host) + } + + tlsServerNames[i] = serverName + perServerNameProxyCount[serverName]++ + } + p := proxy.NewProxy("forward", h, trans) + f.proxies = append(f.proxies, p) + transports[i] = trans + } + + perServerNameTlsConfig := make(map[string]*tls.Config) if f.tlsServerName != "" { f.tlsConfig.ServerName = f.tlsServerName + } else { + for serverName, proxyCount := range perServerNameProxyCount { + tlsConfig := f.tlsConfig.Clone() + tlsConfig.ServerName = serverName + tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount) + perServerNameTlsConfig[serverName] = tlsConfig + } } // Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake @@ -154,7 +189,11 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { for i := range f.proxies { // Only set this for proxies that need it. if transports[i] == transport.TLS { - f.proxies[i].SetTLSConfig(f.tlsConfig) + if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok { + f.proxies[i].SetTLSConfig(tlsConfig) + } else { + f.proxies[i].SetTLSConfig(f.tlsConfig) + } } f.proxies[i].SetExpire(f.expire) f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index 2e66e26c1..28d7241be 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -90,6 +90,36 @@ func TestSetup(t *testing.T) { } } +func TestSplitZone(t *testing.T) { + tests := []struct { + input string + expectedHost string + expectedZone string + }{ + { + "tls://127.0.0.1%example.net:854", "tls://127.0.0.1:854", "example.net", + }, { + "tls://127.0.0.1%example.net", "tls://127.0.0.1", "example.net", + }, { + "tls://127.0.0.1:854", "tls://127.0.0.1:854", "", + }, { + "dns://127.0.0.1", "dns://127.0.0.1", "", + }, { + "foo%bar:baz", "foo:baz", "bar", + }, + } + for i, test := range tests { + host, zone := splitZone(test.input) + + if host != test.expectedHost { + t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host) + } + if zone != test.expectedZone { + t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host) + } + } +} + func TestSetupTLS(t *testing.T) { tests := []struct { input string @@ -101,6 +131,19 @@ func TestSetupTLS(t *testing.T) { {`forward . tls://127.0.0.1 { tls_servername dns }`, false, "dns", ""}, + {`forward . tls://127.0.0.1%example.net { + tls + }`, false, "example.net", ""}, + {`forward . tls://127.0.0.1%example.net:854 tls://127.0.0.2%example.net tls://fe80::1%example.com { + tls + }`, false, "example.net", ""}, + {`forward . tls://127.0.0.1%example.net:854 { + tls + }`, false, "example.net", ""}, + // SNI specifications clash test + {`forward . tls://127.0.0.1%example.net:854 { + tls_servername foo + }`, true, "", "both forward ('foo') and proxy level ('example.net') TLS servernames are set for upstream proxy 'tls://127.0.0.1:854'"}, {`forward . 127.0.0.1 { tls_servername dns }`, false, "", ""}, @@ -126,16 +169,48 @@ func TestSetupTLS(t *testing.T) { if !strings.Contains(err.Error(), test.expectedErr) { t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) } + continue } + /* + if len(fs) == 0 { + continue + } + */ f := fs[0] - if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName { - t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetTransport().GetTLSConfig().ServerName { + t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetTransport().GetTLSConfig().ServerName) } if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName { - t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName) + t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName) + } + } +} + +func TestSetupTLSclientSessionCacheCount(t *testing.T) { + tests := []struct { + input string + }{ + {`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%foo tls://127.0.0.4%bar tls://127.0.0.5%bar { }`}, + {`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%bar tls://127.0.0.4%bar tls://127.0.0.5%bar { }`}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + if err != nil { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if fs[0].proxies[0].GetTransport().GetTLSConfig() == fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() { + t.Errorf("Test %d: tlsConfig is the same for both the first and last proxies", i) + } + if fs[0].proxies[0].GetTransport().GetTLSConfig() != fs[0].proxies[1].GetTransport().GetTLSConfig() { + t.Errorf("Test %d: tlsConfig differs for the first two proxies", i) + } + if fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() != fs[0].proxies[len(fs[0].proxies)-2].GetTransport().GetTLSConfig() { + t.Errorf("Test %d: tlsConfig differs for the last two proxies", i) } } } @@ -473,13 +548,13 @@ func TestFailover(t *testing.T) { }`, s.Addr, server_fail_s.Addr, server_refused_s.Addr), true, "Although failover is not set, as long as the first upstream is work, there should be has a record return"}, } - for _, testCase := range tests { + for i, testCase := range tests { c := caddy.NewTestController("dns", testCase.input) fs, err := parseForward(c) f := fs[0] if err != nil { - t.Errorf("Failed to create forwarder: %s", err) + t.Errorf("Test #%d: Failed to create forwarder: %s", i, err) } f.OnStartup() defer f.OnShutdown() @@ -495,11 +570,11 @@ func TestFailover(t *testing.T) { rec := dnstest.NewRecorder(&test.ResponseWriter{}) if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil { - t.Fatal("Expected to receive reply, but didn't") + t.Fatalf("Test #%d: Expected to receive reply, but didn't", i) } if (len(rec.Msg.Answer) > 0) != testCase.hasRecord { - t.Errorf(" %s: \n %s", testCase.failMsg, testCase.input) + t.Errorf("Test #%d: %s: \n %s", i, testCase.failMsg, testCase.input) } } } diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go index 49c9dd385..280941980 100644 --- a/plugin/pkg/proxy/persistent.go +++ b/plugin/pkg/proxy/persistent.go @@ -151,6 +151,9 @@ func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } // SetTLSConfig sets the TLS config in transport. func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } +// GetTLSConfig returns the TLS config in transport. +func (t *Transport) GetTLSConfig() *tls.Config { return t.tlsConfig } + const ( defaultExpire = 10 * time.Second minDialTimeout = 1 * time.Second diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go index 99fb5df78..35e94bf83 100644 --- a/plugin/pkg/proxy/proxy.go +++ b/plugin/pkg/proxy/proxy.go @@ -56,6 +56,10 @@ func (p *Proxy) GetHealthchecker() HealthChecker { return p.health } +func (p *Proxy) GetTransport() *Transport { + return p.transport +} + func (p *Proxy) Fails() uint32 { return atomic.LoadUint32(&p.fails) }