plugin/forward: added support for per-nameserver TLS SNI (#7633)

This commit is contained in:
Endre Szabo
2025-10-27 16:43:30 +01:00
committed by GitHub
parent b72d267a29
commit d68cbedbb1
5 changed files with 150 additions and 24 deletions

View File

@@ -78,11 +78,16 @@ forward FROM TO... {
The server certificate is verified using the specified CA file
* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario,
but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (QuadDNS) with 1.1.1.1
(Cloudflare) will not work. Using TLS forwarding but not setting `tls_servername` results in anyone
needs this to be set to `dns.quad9.net`. Using TLS forwarding but not setting `tls_servername` results in anyone
being able to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this,
it is strongly recommended to set this value when using TLS forwarding.
Per destination endpoint TLS server name indication is possible in the form of `tls://9.9.9.9%dns.quad9.net`.
`tls_servername` must not be specified when using per destination endpoint TLS server name indication
as it would introduce clash between the server name indication spectifications. If destination endpoint
is to be reached via a port other than 853 then the port must be appended to the end of the destination
endpoint specifier. In case of port 10853, the above string would be: `tls://9.9.9.9%dns.quad9.net:10853`.
* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
* `random` is a policy that implements random upstream selection.
* `round_robin` is a policy that selects hosts based on round robin ordering.

View File

@@ -97,6 +97,22 @@ func parseForward(c *caddy.Controller) ([]*Forward, error) {
return fs, nil
}
// Splits the zone, preserving any port that comes after the zone
func splitZone(host string) (newHost string, zone string) {
newHost = host
if strings.Contains(host, "%") {
lastPercent := strings.LastIndex(host, "%")
newHost = host[:lastPercent]
zone = host[lastPercent+1:]
if strings.Contains(zone, ":") {
lastColon := strings.LastIndex(zone, ":")
newHost += zone[lastColon:]
zone = zone[:lastColon]
}
}
return
}
func parseStanza(c *caddy.Controller) (*Forward, error) {
f := New()
@@ -124,27 +140,46 @@ func parseStanza(c *caddy.Controller) (*Forward, error) {
return f, err
}
transports := make([]string, len(toHosts))
allowedTrans := map[string]bool{"dns": true, "tls": true}
for i, host := range toHosts {
trans, h := parse.Transport(host)
if !allowedTrans[trans] {
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
}
p := proxy.NewProxy("forward", h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
for c.NextBlock() {
if err := parseBlock(c, f); err != nil {
return f, err
}
}
tlsServerNames := make([]string, len(toHosts))
perServerNameProxyCount := make(map[string]int)
transports := make([]string, len(toHosts))
allowedTrans := map[string]bool{"dns": true, "tls": true}
for i, hostWithZone := range toHosts {
host, serverName := splitZone(hostWithZone)
trans, h := parse.Transport(host)
if !allowedTrans[trans] {
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
}
if trans == transport.TLS && serverName != "" {
if f.tlsServerName != "" {
return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host)
}
tlsServerNames[i] = serverName
perServerNameProxyCount[serverName]++
}
p := proxy.NewProxy("forward", h, trans)
f.proxies = append(f.proxies, p)
transports[i] = trans
}
perServerNameTlsConfig := make(map[string]*tls.Config)
if f.tlsServerName != "" {
f.tlsConfig.ServerName = f.tlsServerName
} else {
for serverName, proxyCount := range perServerNameProxyCount {
tlsConfig := f.tlsConfig.Clone()
tlsConfig.ServerName = serverName
tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount)
perServerNameTlsConfig[serverName] = tlsConfig
}
}
// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
@@ -154,8 +189,12 @@ func parseStanza(c *caddy.Controller) (*Forward, error) {
for i := range f.proxies {
// Only set this for proxies that need it.
if transports[i] == transport.TLS {
if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok {
f.proxies[i].SetTLSConfig(tlsConfig)
} else {
f.proxies[i].SetTLSConfig(f.tlsConfig)
}
}
f.proxies[i].SetExpire(f.expire)
f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired)
// when TLS is used, checks are set to tcp-tls

View File

@@ -90,6 +90,36 @@ func TestSetup(t *testing.T) {
}
}
func TestSplitZone(t *testing.T) {
tests := []struct {
input string
expectedHost string
expectedZone string
}{
{
"tls://127.0.0.1%example.net:854", "tls://127.0.0.1:854", "example.net",
}, {
"tls://127.0.0.1%example.net", "tls://127.0.0.1", "example.net",
}, {
"tls://127.0.0.1:854", "tls://127.0.0.1:854", "",
}, {
"dns://127.0.0.1", "dns://127.0.0.1", "",
}, {
"foo%bar:baz", "foo:baz", "bar",
},
}
for i, test := range tests {
host, zone := splitZone(test.input)
if host != test.expectedHost {
t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host)
}
if zone != test.expectedZone {
t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host)
}
}
}
func TestSetupTLS(t *testing.T) {
tests := []struct {
input string
@@ -101,6 +131,19 @@ func TestSetupTLS(t *testing.T) {
{`forward . tls://127.0.0.1 {
tls_servername dns
}`, false, "dns", ""},
{`forward . tls://127.0.0.1%example.net {
tls
}`, false, "example.net", ""},
{`forward . tls://127.0.0.1%example.net:854 tls://127.0.0.2%example.net tls://fe80::1%example.com {
tls
}`, false, "example.net", ""},
{`forward . tls://127.0.0.1%example.net:854 {
tls
}`, false, "example.net", ""},
// SNI specifications clash test
{`forward . tls://127.0.0.1%example.net:854 {
tls_servername foo
}`, true, "", "both forward ('foo') and proxy level ('example.net') TLS servernames are set for upstream proxy 'tls://127.0.0.1:854'"},
{`forward . 127.0.0.1 {
tls_servername dns
}`, false, "", ""},
@@ -126,16 +169,48 @@ func TestSetupTLS(t *testing.T) {
if !strings.Contains(err.Error(), test.expectedErr) {
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
}
continue
}
/*
if len(fs) == 0 {
continue
}
*/
f := fs[0]
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName {
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName)
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetTransport().GetTLSConfig().ServerName {
t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetTransport().GetTLSConfig().ServerName)
}
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName {
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName)
t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName)
}
}
}
func TestSetupTLSclientSessionCacheCount(t *testing.T) {
tests := []struct {
input string
}{
{`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%foo tls://127.0.0.4%bar tls://127.0.0.5%bar { }`},
{`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%bar tls://127.0.0.4%bar tls://127.0.0.5%bar { }`},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
fs, err := parseForward(c)
if err != nil {
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
}
if fs[0].proxies[0].GetTransport().GetTLSConfig() == fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() {
t.Errorf("Test %d: tlsConfig is the same for both the first and last proxies", i)
}
if fs[0].proxies[0].GetTransport().GetTLSConfig() != fs[0].proxies[1].GetTransport().GetTLSConfig() {
t.Errorf("Test %d: tlsConfig differs for the first two proxies", i)
}
if fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() != fs[0].proxies[len(fs[0].proxies)-2].GetTransport().GetTLSConfig() {
t.Errorf("Test %d: tlsConfig differs for the last two proxies", i)
}
}
}
@@ -473,13 +548,13 @@ func TestFailover(t *testing.T) {
}`, s.Addr, server_fail_s.Addr, server_refused_s.Addr), true, "Although failover is not set, as long as the first upstream is work, there should be has a record return"},
}
for _, testCase := range tests {
for i, testCase := range tests {
c := caddy.NewTestController("dns", testCase.input)
fs, err := parseForward(c)
f := fs[0]
if err != nil {
t.Errorf("Failed to create forwarder: %s", err)
t.Errorf("Test #%d: Failed to create forwarder: %s", i, err)
}
f.OnStartup()
defer f.OnShutdown()
@@ -495,11 +570,11 @@ func TestFailover(t *testing.T) {
rec := dnstest.NewRecorder(&test.ResponseWriter{})
if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil {
t.Fatal("Expected to receive reply, but didn't")
t.Fatalf("Test #%d: Expected to receive reply, but didn't", i)
}
if (len(rec.Msg.Answer) > 0) != testCase.hasRecord {
t.Errorf(" %s: \n %s", testCase.failMsg, testCase.input)
t.Errorf("Test #%d: %s: \n %s", i, testCase.failMsg, testCase.input)
}
}
}

View File

@@ -151,6 +151,9 @@ func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport.
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
// GetTLSConfig returns the TLS config in transport.
func (t *Transport) GetTLSConfig() *tls.Config { return t.tlsConfig }
const (
defaultExpire = 10 * time.Second
minDialTimeout = 1 * time.Second

View File

@@ -56,6 +56,10 @@ func (p *Proxy) GetHealthchecker() HealthChecker {
return p.health
}
func (p *Proxy) GetTransport() *Transport {
return p.transport
}
func (p *Proxy) Fails() uint32 {
return atomic.LoadUint32(&p.fails)
}