mirror of
https://github.com/coredns/coredns.git
synced 2025-10-30 17:53:21 -04:00
plugin/forward: added support for per-nameserver TLS SNI (#7633)
This commit is contained in:
@@ -78,11 +78,16 @@ forward FROM TO... {
|
|||||||
The server certificate is verified using the specified CA file
|
The server certificate is verified using the specified CA file
|
||||||
|
|
||||||
* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
|
* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
|
||||||
needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario,
|
needs this to be set to `dns.quad9.net`. Using TLS forwarding but not setting `tls_servername` results in anyone
|
||||||
but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (QuadDNS) with 1.1.1.1
|
|
||||||
(Cloudflare) will not work. Using TLS forwarding but not setting `tls_servername` results in anyone
|
|
||||||
being able to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this,
|
being able to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this,
|
||||||
it is strongly recommended to set this value when using TLS forwarding.
|
it is strongly recommended to set this value when using TLS forwarding.
|
||||||
|
|
||||||
|
Per destination endpoint TLS server name indication is possible in the form of `tls://9.9.9.9%dns.quad9.net`.
|
||||||
|
`tls_servername` must not be specified when using per destination endpoint TLS server name indication
|
||||||
|
as it would introduce clash between the server name indication spectifications. If destination endpoint
|
||||||
|
is to be reached via a port other than 853 then the port must be appended to the end of the destination
|
||||||
|
endpoint specifier. In case of port 10853, the above string would be: `tls://9.9.9.9%dns.quad9.net:10853`.
|
||||||
|
|
||||||
* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
|
* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
|
||||||
* `random` is a policy that implements random upstream selection.
|
* `random` is a policy that implements random upstream selection.
|
||||||
* `round_robin` is a policy that selects hosts based on round robin ordering.
|
* `round_robin` is a policy that selects hosts based on round robin ordering.
|
||||||
|
|||||||
@@ -97,6 +97,22 @@ func parseForward(c *caddy.Controller) ([]*Forward, error) {
|
|||||||
return fs, nil
|
return fs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Splits the zone, preserving any port that comes after the zone
|
||||||
|
func splitZone(host string) (newHost string, zone string) {
|
||||||
|
newHost = host
|
||||||
|
if strings.Contains(host, "%") {
|
||||||
|
lastPercent := strings.LastIndex(host, "%")
|
||||||
|
newHost = host[:lastPercent]
|
||||||
|
zone = host[lastPercent+1:]
|
||||||
|
if strings.Contains(zone, ":") {
|
||||||
|
lastColon := strings.LastIndex(zone, ":")
|
||||||
|
newHost += zone[lastColon:]
|
||||||
|
zone = zone[:lastColon]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func parseStanza(c *caddy.Controller) (*Forward, error) {
|
func parseStanza(c *caddy.Controller) (*Forward, error) {
|
||||||
f := New()
|
f := New()
|
||||||
|
|
||||||
@@ -124,27 +140,46 @@ func parseStanza(c *caddy.Controller) (*Forward, error) {
|
|||||||
return f, err
|
return f, err
|
||||||
}
|
}
|
||||||
|
|
||||||
transports := make([]string, len(toHosts))
|
|
||||||
allowedTrans := map[string]bool{"dns": true, "tls": true}
|
|
||||||
for i, host := range toHosts {
|
|
||||||
trans, h := parse.Transport(host)
|
|
||||||
|
|
||||||
if !allowedTrans[trans] {
|
|
||||||
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
|
|
||||||
}
|
|
||||||
p := proxy.NewProxy("forward", h, trans)
|
|
||||||
f.proxies = append(f.proxies, p)
|
|
||||||
transports[i] = trans
|
|
||||||
}
|
|
||||||
|
|
||||||
for c.NextBlock() {
|
for c.NextBlock() {
|
||||||
if err := parseBlock(c, f); err != nil {
|
if err := parseBlock(c, f); err != nil {
|
||||||
return f, err
|
return f, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlsServerNames := make([]string, len(toHosts))
|
||||||
|
perServerNameProxyCount := make(map[string]int)
|
||||||
|
transports := make([]string, len(toHosts))
|
||||||
|
allowedTrans := map[string]bool{"dns": true, "tls": true}
|
||||||
|
for i, hostWithZone := range toHosts {
|
||||||
|
host, serverName := splitZone(hostWithZone)
|
||||||
|
trans, h := parse.Transport(host)
|
||||||
|
|
||||||
|
if !allowedTrans[trans] {
|
||||||
|
return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
|
||||||
|
}
|
||||||
|
if trans == transport.TLS && serverName != "" {
|
||||||
|
if f.tlsServerName != "" {
|
||||||
|
return f, fmt.Errorf("both forward ('%s') and proxy level ('%s') TLS servernames are set for upstream proxy '%s'", f.tlsServerName, serverName, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsServerNames[i] = serverName
|
||||||
|
perServerNameProxyCount[serverName]++
|
||||||
|
}
|
||||||
|
p := proxy.NewProxy("forward", h, trans)
|
||||||
|
f.proxies = append(f.proxies, p)
|
||||||
|
transports[i] = trans
|
||||||
|
}
|
||||||
|
|
||||||
|
perServerNameTlsConfig := make(map[string]*tls.Config)
|
||||||
if f.tlsServerName != "" {
|
if f.tlsServerName != "" {
|
||||||
f.tlsConfig.ServerName = f.tlsServerName
|
f.tlsConfig.ServerName = f.tlsServerName
|
||||||
|
} else {
|
||||||
|
for serverName, proxyCount := range perServerNameProxyCount {
|
||||||
|
tlsConfig := f.tlsConfig.Clone()
|
||||||
|
tlsConfig.ServerName = serverName
|
||||||
|
tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(proxyCount)
|
||||||
|
perServerNameTlsConfig[serverName] = tlsConfig
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
|
// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
|
||||||
@@ -154,7 +189,11 @@ func parseStanza(c *caddy.Controller) (*Forward, error) {
|
|||||||
for i := range f.proxies {
|
for i := range f.proxies {
|
||||||
// Only set this for proxies that need it.
|
// Only set this for proxies that need it.
|
||||||
if transports[i] == transport.TLS {
|
if transports[i] == transport.TLS {
|
||||||
f.proxies[i].SetTLSConfig(f.tlsConfig)
|
if tlsConfig, ok := perServerNameTlsConfig[tlsServerNames[i]]; ok {
|
||||||
|
f.proxies[i].SetTLSConfig(tlsConfig)
|
||||||
|
} else {
|
||||||
|
f.proxies[i].SetTLSConfig(f.tlsConfig)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
f.proxies[i].SetExpire(f.expire)
|
f.proxies[i].SetExpire(f.expire)
|
||||||
f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired)
|
f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired)
|
||||||
|
|||||||
@@ -90,6 +90,36 @@ func TestSetup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSplitZone(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expectedHost string
|
||||||
|
expectedZone string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"tls://127.0.0.1%example.net:854", "tls://127.0.0.1:854", "example.net",
|
||||||
|
}, {
|
||||||
|
"tls://127.0.0.1%example.net", "tls://127.0.0.1", "example.net",
|
||||||
|
}, {
|
||||||
|
"tls://127.0.0.1:854", "tls://127.0.0.1:854", "",
|
||||||
|
}, {
|
||||||
|
"dns://127.0.0.1", "dns://127.0.0.1", "",
|
||||||
|
}, {
|
||||||
|
"foo%bar:baz", "foo:baz", "bar",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, test := range tests {
|
||||||
|
host, zone := splitZone(test.input)
|
||||||
|
|
||||||
|
if host != test.expectedHost {
|
||||||
|
t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host)
|
||||||
|
}
|
||||||
|
if zone != test.expectedZone {
|
||||||
|
t.Errorf("Test %d: expected host %q, actual: %q", i, test.expectedHost, host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetupTLS(t *testing.T) {
|
func TestSetupTLS(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -101,6 +131,19 @@ func TestSetupTLS(t *testing.T) {
|
|||||||
{`forward . tls://127.0.0.1 {
|
{`forward . tls://127.0.0.1 {
|
||||||
tls_servername dns
|
tls_servername dns
|
||||||
}`, false, "dns", ""},
|
}`, false, "dns", ""},
|
||||||
|
{`forward . tls://127.0.0.1%example.net {
|
||||||
|
tls
|
||||||
|
}`, false, "example.net", ""},
|
||||||
|
{`forward . tls://127.0.0.1%example.net:854 tls://127.0.0.2%example.net tls://fe80::1%example.com {
|
||||||
|
tls
|
||||||
|
}`, false, "example.net", ""},
|
||||||
|
{`forward . tls://127.0.0.1%example.net:854 {
|
||||||
|
tls
|
||||||
|
}`, false, "example.net", ""},
|
||||||
|
// SNI specifications clash test
|
||||||
|
{`forward . tls://127.0.0.1%example.net:854 {
|
||||||
|
tls_servername foo
|
||||||
|
}`, true, "", "both forward ('foo') and proxy level ('example.net') TLS servernames are set for upstream proxy 'tls://127.0.0.1:854'"},
|
||||||
{`forward . 127.0.0.1 {
|
{`forward . 127.0.0.1 {
|
||||||
tls_servername dns
|
tls_servername dns
|
||||||
}`, false, "", ""},
|
}`, false, "", ""},
|
||||||
@@ -126,16 +169,48 @@ func TestSetupTLS(t *testing.T) {
|
|||||||
if !strings.Contains(err.Error(), test.expectedErr) {
|
if !strings.Contains(err.Error(), test.expectedErr) {
|
||||||
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
|
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
if len(fs) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
f := fs[0]
|
f := fs[0]
|
||||||
|
|
||||||
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName {
|
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetTransport().GetTLSConfig().ServerName {
|
||||||
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName)
|
t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetTransport().GetTLSConfig().ServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName {
|
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName {
|
||||||
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName)
|
t.Errorf("Test %d: expected server name: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetupTLSclientSessionCacheCount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
}{
|
||||||
|
{`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%foo tls://127.0.0.4%bar tls://127.0.0.5%bar { }`},
|
||||||
|
{`forward . tls://127.0.0.1%foo tls://127.0.0.2%foo tls://127.0.0.3%bar tls://127.0.0.4%bar tls://127.0.0.5%bar { }`},
|
||||||
|
}
|
||||||
|
for i, test := range tests {
|
||||||
|
c := caddy.NewTestController("dns", test.input)
|
||||||
|
fs, err := parseForward(c)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fs[0].proxies[0].GetTransport().GetTLSConfig() == fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() {
|
||||||
|
t.Errorf("Test %d: tlsConfig is the same for both the first and last proxies", i)
|
||||||
|
}
|
||||||
|
if fs[0].proxies[0].GetTransport().GetTLSConfig() != fs[0].proxies[1].GetTransport().GetTLSConfig() {
|
||||||
|
t.Errorf("Test %d: tlsConfig differs for the first two proxies", i)
|
||||||
|
}
|
||||||
|
if fs[0].proxies[len(fs[0].proxies)-1].GetTransport().GetTLSConfig() != fs[0].proxies[len(fs[0].proxies)-2].GetTransport().GetTLSConfig() {
|
||||||
|
t.Errorf("Test %d: tlsConfig differs for the last two proxies", i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -473,13 +548,13 @@ func TestFailover(t *testing.T) {
|
|||||||
}`, s.Addr, server_fail_s.Addr, server_refused_s.Addr), true, "Although failover is not set, as long as the first upstream is work, there should be has a record return"},
|
}`, s.Addr, server_fail_s.Addr, server_refused_s.Addr), true, "Although failover is not set, as long as the first upstream is work, there should be has a record return"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range tests {
|
for i, testCase := range tests {
|
||||||
c := caddy.NewTestController("dns", testCase.input)
|
c := caddy.NewTestController("dns", testCase.input)
|
||||||
fs, err := parseForward(c)
|
fs, err := parseForward(c)
|
||||||
|
|
||||||
f := fs[0]
|
f := fs[0]
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to create forwarder: %s", err)
|
t.Errorf("Test #%d: Failed to create forwarder: %s", i, err)
|
||||||
}
|
}
|
||||||
f.OnStartup()
|
f.OnStartup()
|
||||||
defer f.OnShutdown()
|
defer f.OnShutdown()
|
||||||
@@ -495,11 +570,11 @@ func TestFailover(t *testing.T) {
|
|||||||
rec := dnstest.NewRecorder(&test.ResponseWriter{})
|
rec := dnstest.NewRecorder(&test.ResponseWriter{})
|
||||||
|
|
||||||
if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil {
|
if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil {
|
||||||
t.Fatal("Expected to receive reply, but didn't")
|
t.Fatalf("Test #%d: Expected to receive reply, but didn't", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (len(rec.Msg.Answer) > 0) != testCase.hasRecord {
|
if (len(rec.Msg.Answer) > 0) != testCase.hasRecord {
|
||||||
t.Errorf(" %s: \n %s", testCase.failMsg, testCase.input)
|
t.Errorf("Test #%d: %s: \n %s", i, testCase.failMsg, testCase.input)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,6 +151,9 @@ func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
|||||||
// SetTLSConfig sets the TLS config in transport.
|
// SetTLSConfig sets the TLS config in transport.
|
||||||
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||||
|
|
||||||
|
// GetTLSConfig returns the TLS config in transport.
|
||||||
|
func (t *Transport) GetTLSConfig() *tls.Config { return t.tlsConfig }
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultExpire = 10 * time.Second
|
defaultExpire = 10 * time.Second
|
||||||
minDialTimeout = 1 * time.Second
|
minDialTimeout = 1 * time.Second
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ func (p *Proxy) GetHealthchecker() HealthChecker {
|
|||||||
return p.health
|
return p.health
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) GetTransport() *Transport {
|
||||||
|
return p.transport
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Proxy) Fails() uint32 {
|
func (p *Proxy) Fails() uint32 {
|
||||||
return atomic.LoadUint32(&p.fails)
|
return atomic.LoadUint32(&p.fails)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user