mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-30 17:53:21 -04:00 
			
		
		
		
	Cleanup ParseHostOrFile (#2100)
Create plugin/pkg/transport that holds the transport related functions. This needed to be a new pkg to prevent cyclic import errors. This cleans up a bunch of duplicated code in core/dnsserver that also tried to parse a transport (now all done in transport.Parse). Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
		| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin" | 	"github.com/coredns/coredns/plugin" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
| @@ -27,43 +28,13 @@ func (z zoneAddr) String() string { | |||||||
| 	return s | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| // Transport returns the protocol of the string s |  | ||||||
| func Transport(s string) string { |  | ||||||
| 	switch { |  | ||||||
| 	case strings.HasPrefix(s, TransportTLS+"://"): |  | ||||||
| 		return TransportTLS |  | ||||||
| 	case strings.HasPrefix(s, TransportDNS+"://"): |  | ||||||
| 		return TransportDNS |  | ||||||
| 	case strings.HasPrefix(s, TransportGRPC+"://"): |  | ||||||
| 		return TransportGRPC |  | ||||||
| 	case strings.HasPrefix(s, TransportHTTPS+"://"): |  | ||||||
| 		return TransportHTTPS |  | ||||||
| 	} |  | ||||||
| 	return TransportDNS |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // normalizeZone parses an zone string into a structured format with separate | // normalizeZone parses an zone string into a structured format with separate | ||||||
| // host, and port portions, as well as the original input string. | // host, and port portions, as well as the original input string. | ||||||
| func normalizeZone(str string) (zoneAddr, error) { | func normalizeZone(str string) (zoneAddr, error) { | ||||||
| 	var err error | 	var err error | ||||||
|  |  | ||||||
| 	// Default to DNS if there isn't a transport protocol prefix. | 	var trans string | ||||||
| 	trans := TransportDNS | 	trans, str = transport.Parse(str) | ||||||
|  |  | ||||||
| 	switch { |  | ||||||
| 	case strings.HasPrefix(str, TransportTLS+"://"): |  | ||||||
| 		trans = TransportTLS |  | ||||||
| 		str = str[len(TransportTLS+"://"):] |  | ||||||
| 	case strings.HasPrefix(str, TransportDNS+"://"): |  | ||||||
| 		trans = TransportDNS |  | ||||||
| 		str = str[len(TransportDNS+"://"):] |  | ||||||
| 	case strings.HasPrefix(str, TransportGRPC+"://"): |  | ||||||
| 		trans = TransportGRPC |  | ||||||
| 		str = str[len(TransportGRPC+"://"):] |  | ||||||
| 	case strings.HasPrefix(str, TransportHTTPS+"://"): |  | ||||||
| 		trans = TransportHTTPS |  | ||||||
| 		str = str[len(TransportHTTPS+"://"):] |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	host, port, ipnet, err := plugin.SplitHostPort(str) | 	host, port, ipnet, err := plugin.SplitHostPort(str) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if port == "" { | 	if port == "" { | ||||||
| 		if trans == TransportDNS { | 		switch trans { | ||||||
|  | 		case transport.DNS: | ||||||
| 			port = Port | 			port = Port | ||||||
| 		} | 		case transport.TLS: | ||||||
| 		if trans == TransportTLS { | 			port = transport.TLSPort | ||||||
| 			port = TLSPort | 		case transport.GRPC: | ||||||
| 		} | 			port = transport.GRPCPort | ||||||
| 		if trans == TransportGRPC { | 		case transport.HTTPS: | ||||||
| 			port = GRPCPort | 			port = transport.HTTPSPort | ||||||
| 		} |  | ||||||
| 		if trans == TransportHTTPS { |  | ||||||
| 			port = HTTPSPort |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Supported transports. |  | ||||||
| const ( |  | ||||||
| 	TransportDNS   = "dns" |  | ||||||
| 	TransportTLS   = "tls" |  | ||||||
| 	TransportGRPC  = "grpc" |  | ||||||
| 	TransportHTTPS = "https" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type zoneOverlap struct { | type zoneOverlap struct { | ||||||
| 	registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key | 	registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key | ||||||
| 	unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key | 	unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key | ||||||
|   | |||||||
| @@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestTransport(t *testing.T) { |  | ||||||
| 	for i, test := range []struct { |  | ||||||
| 		input    string |  | ||||||
| 		expected string |  | ||||||
| 	}{ |  | ||||||
| 		{"dns://.:53", TransportDNS}, |  | ||||||
| 		{"2003::1/64.:53", TransportDNS}, |  | ||||||
| 		{"grpc://example.org:1443 ", TransportGRPC}, |  | ||||||
| 		{"tls://example.org ", TransportTLS}, |  | ||||||
| 		{"https://example.org ", TransportHTTPS}, |  | ||||||
| 	} { |  | ||||||
| 		actual := Transport(test.input) |  | ||||||
| 		if actual != test.expected { |  | ||||||
| 			t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin" | 	"github.com/coredns/coredns/plugin" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnsutil" | 	"github.com/coredns/coredns/plugin/pkg/dnsutil" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/mholt/caddy" | 	"github.com/mholt/caddy" | ||||||
| 	"github.com/mholt/caddy/caddyfile" | 	"github.com/mholt/caddy/caddyfile" | ||||||
| @@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) { | |||||||
| 	var servers []caddy.Server | 	var servers []caddy.Server | ||||||
| 	for addr, group := range groups { | 	for addr, group := range groups { | ||||||
| 		// switch on addr | 		// switch on addr | ||||||
| 		switch Transport(addr) { | 		switch tr, _ := transport.Parse(addr); tr { | ||||||
| 		case TransportDNS: | 		case transport.DNS: | ||||||
| 			s, err := NewServer(addr, group) | 			s, err := NewServer(addr, group) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| 			servers = append(servers, s) | 			servers = append(servers, s) | ||||||
|  |  | ||||||
| 		case TransportTLS: | 		case transport.TLS: | ||||||
| 			s, err := NewServerTLS(addr, group) | 			s, err := NewServerTLS(addr, group) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| 			servers = append(servers, s) | 			servers = append(servers, s) | ||||||
|  |  | ||||||
| 		case TransportGRPC: | 		case transport.GRPC: | ||||||
| 			s, err := NewServergRPC(addr, group) | 			s, err := NewServergRPC(addr, group) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| 			} | 			} | ||||||
| 			servers = append(servers, s) | 			servers = append(servers, s) | ||||||
|  |  | ||||||
| 		case TransportHTTPS: | 		case transport.HTTPS: | ||||||
| 			s, err := NewServerHTTPS(addr, group) | 			s, err := NewServerHTTPS(addr, group) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return nil, err | 				return nil, err | ||||||
| @@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) { | |||||||
| 	return groups, nil | 	return groups, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | // DefaultPort is the default port. | ||||||
| 	// DefaultPort is the default port. | const DefaultPort = "53" | ||||||
| 	DefaultPort = "53" |  | ||||||
| 	// TLSPort is the default port for DNS-over-TLS. |  | ||||||
| 	TLSPort = "853" |  | ||||||
| 	// GRPCPort is the default port for DNS-over-gRPC. |  | ||||||
| 	GRPCPort = "443" |  | ||||||
| 	// HTTPSPort is the default port for DNS-over-HTTPS. |  | ||||||
| 	HTTPSPort = "443" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // These "soft defaults" are configurable by | // These "soft defaults" are configurable by | ||||||
| // command line flags, etc. | // command line flags, etc. | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ import ( | |||||||
| 	"github.com/coredns/coredns/plugin/pkg/log" | 	"github.com/coredns/coredns/plugin/pkg/log" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/rcode" | 	"github.com/coredns/coredns/plugin/pkg/rcode" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/trace" | 	"github.com/coredns/coredns/plugin/pkg/trace" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error { | |||||||
|  |  | ||||||
| // Listen implements caddy.TCPServer interface. | // Listen implements caddy.TCPServer interface. | ||||||
| func (s *Server) Listen() (net.Listener, error) { | func (s *Server) Listen() (net.Listener, error) { | ||||||
| 	l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):]) | 	l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) { | |||||||
|  |  | ||||||
| // ListenPacket implements caddy.UDPServer interface. | // ListenPacket implements caddy.UDPServer interface. | ||||||
| func (s *Server) ListenPacket() (net.PacketConn, error) { | func (s *Server) ListenPacket() (net.PacketConn, error) { | ||||||
| 	p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):]) | 	p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/pb" | 	"github.com/coredns/coredns/pb" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/watch" | 	"github.com/coredns/coredns/plugin/pkg/watch" | ||||||
|  |  | ||||||
| 	"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" | 	"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" | ||||||
| @@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil } | |||||||
| // Listen implements caddy.TCPServer interface. | // Listen implements caddy.TCPServer interface. | ||||||
| func (s *ServergRPC) Listen() (net.Listener, error) { | func (s *ServergRPC) Listen() (net.Listener, error) { | ||||||
|  |  | ||||||
| 	l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):]) | 	l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	out := startUpZones(TransportGRPC+"://", s.Addr, s.zones) | 	out := startUpZones(transport.GRPC+"://", s.Addr, s.zones) | ||||||
| 	if out != "" { | 	if out != "" { | ||||||
| 		fmt.Print(out) | 		fmt.Print(out) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ import ( | |||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnsutil" | 	"github.com/coredns/coredns/plugin/pkg/dnsutil" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/doh" | 	"github.com/coredns/coredns/plugin/pkg/doh" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/response" | 	"github.com/coredns/coredns/plugin/pkg/response" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ServerHTTPS represents an instance of a DNS-over-HTTPS server. | // ServerHTTPS represents an instance of a DNS-over-HTTPS server. | ||||||
| @@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil } | |||||||
| // Listen implements caddy.TCPServer interface. | // Listen implements caddy.TCPServer interface. | ||||||
| func (s *ServerHTTPS) Listen() (net.Listener, error) { | func (s *ServerHTTPS) Listen() (net.Listener, error) { | ||||||
|  |  | ||||||
| 	l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):]) | 	l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones) | 	out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones) | ||||||
| 	if out != "" { | 	if out != "" { | ||||||
| 		fmt.Print(out) | 		fmt.Print(out) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -6,6 +6,8 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil } | |||||||
|  |  | ||||||
| // Listen implements caddy.TCPServer interface. | // Listen implements caddy.TCPServer interface. | ||||||
| func (s *ServerTLS) Listen() (net.Listener, error) { | func (s *ServerTLS) Listen() (net.Listener, error) { | ||||||
| 	l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):]) | 	l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):]) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	out := startUpZones(TransportTLS+"://", s.Addr, s.zones) | 	out := startUpZones(transport.TLS+"://", s.Addr, s.zones) | ||||||
| 	if out != "" { | 	if out != "" { | ||||||
| 		fmt.Print(out) | 		fmt.Print(out) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -35,16 +35,16 @@ func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight in | |||||||
| 	atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) | 	atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *transport) dialTimeout() time.Duration { | func (t *Transport) dialTimeout() time.Duration { | ||||||
| 	return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) | 	return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *transport) updateDialTimeout(newDialTime time.Duration) { | func (t *Transport) updateDialTimeout(newDialTime time.Duration) { | ||||||
| 	averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) | 	averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. | // Dial dials the address configured in transport, potentially reusing a connection or creating a new one. | ||||||
| func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { | func (t *Transport) Dial(proto string) (*dns.Conn, bool, error) { | ||||||
| 	// If tls has been configured; use it. | 	// If tls has been configured; use it. | ||||||
| 	if t.tlsConfig != nil { | 	if t.tlsConfig != nil { | ||||||
| 		proto = "tcp-tls" | 		proto = "tcp-tls" | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| @@ -19,7 +20,7 @@ func TestForward(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
| @@ -51,7 +52,7 @@ func TestForwardRefused(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
|   | |||||||
| @@ -5,6 +5,8 @@ import ( | |||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -17,10 +19,10 @@ type HealthChecker interface { | |||||||
| // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). | // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). | ||||||
| type dnsHc struct{ c *dns.Client } | type dnsHc struct{ c *dns.Client } | ||||||
|  |  | ||||||
| // NewHealthChecker returns a new HealthChecker based on protocol. | // NewHealthChecker returns a new HealthChecker based on transport. | ||||||
| func NewHealthChecker(protocol int) HealthChecker { | func NewHealthChecker(trans string) HealthChecker { | ||||||
| 	switch protocol { | 	switch trans { | ||||||
| 	case DNS, TLS: | 	case transport.DNS, transport.TLS: | ||||||
| 		c := new(dns.Client) | 		c := new(dns.Client) | ||||||
| 		c.Net = "udp" | 		c.Net = "udp" | ||||||
| 		c.ReadTimeout = 1 * time.Second | 		c.ReadTimeout = 1 * time.Second | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -25,7 +26,7 @@ func TestHealth(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
| @@ -65,7 +66,7 @@ func TestHealthTimeout(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
| @@ -109,7 +110,7 @@ func TestHealthFailTwice(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
| @@ -132,7 +133,7 @@ func TestHealthMaxFails(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.maxfails = 2 | 	f.maxfails = 2 | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| @@ -163,7 +164,7 @@ func TestHealthNoMaxFails(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.maxfails = 0 | 	f.maxfails = 0 | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ package forward | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -81,7 +82,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M | |||||||
| func NewLookup(addr []string) *Forward { | func NewLookup(addr []string) *Forward { | ||||||
| 	f := New() | 	f := New() | ||||||
| 	for i := range addr { | 	for i := range addr { | ||||||
| 		p := NewProxy(addr[i], DNS) | 		p := NewProxy(addr[i], transport.DNS) | ||||||
| 		f.SetProxy(p) | 		f.SetProxy(p) | ||||||
| 	} | 	} | ||||||
| 	return f | 	return f | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| @@ -19,7 +20,7 @@ func TestLookup(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
|   | |||||||
| @@ -15,8 +15,8 @@ type persistConn struct { | |||||||
| 	used time.Time | 	used time.Time | ||||||
| } | } | ||||||
|  |  | ||||||
| // transport hold the persistent cache. | // Transport hold the persistent cache. | ||||||
| type transport struct { | type Transport struct { | ||||||
| 	avgDialTime int64                     // kind of average time of dial time | 	avgDialTime int64                     // kind of average time of dial time | ||||||
| 	conns       map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. | 	conns       map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. | ||||||
| 	expire      time.Duration             // After this duration a connection is expired. | 	expire      time.Duration             // After this duration a connection is expired. | ||||||
| @@ -29,8 +29,8 @@ type transport struct { | |||||||
| 	stop  chan bool | 	stop  chan bool | ||||||
| } | } | ||||||
|  |  | ||||||
| func newTransport(addr string) *transport { | func newTransport(addr string) *Transport { | ||||||
| 	t := &transport{ | 	t := &Transport{ | ||||||
| 		avgDialTime: int64(defaultDialTimeout / 2), | 		avgDialTime: int64(defaultDialTimeout / 2), | ||||||
| 		conns:       make(map[string][]*persistConn), | 		conns:       make(map[string][]*persistConn), | ||||||
| 		expire:      defaultExpire, | 		expire:      defaultExpire, | ||||||
| @@ -45,7 +45,7 @@ func newTransport(addr string) *transport { | |||||||
|  |  | ||||||
| // len returns the number of connection, used for metrics. Can only be safely | // len returns the number of connection, used for metrics. Can only be safely | ||||||
| // used inside connManager() because of data races. | // used inside connManager() because of data races. | ||||||
| func (t *transport) len() int { | func (t *Transport) len() int { | ||||||
| 	l := 0 | 	l := 0 | ||||||
| 	for _, conns := range t.conns { | 	for _, conns := range t.conns { | ||||||
| 		l += len(conns) | 		l += len(conns) | ||||||
| @@ -54,7 +54,7 @@ func (t *transport) len() int { | |||||||
| } | } | ||||||
|  |  | ||||||
| // connManagers manages the persistent connection cache for UDP and TCP. | // connManagers manages the persistent connection cache for UDP and TCP. | ||||||
| func (t *transport) connManager() { | func (t *Transport) connManager() { | ||||||
| 	ticker := time.NewTicker(t.expire) | 	ticker := time.NewTicker(t.expire) | ||||||
| Wait: | Wait: | ||||||
| 	for { | 	for { | ||||||
| @@ -115,7 +115,7 @@ func closeConns(conns []*persistConn) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // cleanup removes connections from cache. | // cleanup removes connections from cache. | ||||||
| func (t *transport) cleanup(all bool) { | func (t *Transport) cleanup(all bool) { | ||||||
| 	staleTime := time.Now().Add(-t.expire) | 	staleTime := time.Now().Add(-t.expire) | ||||||
| 	for proto, stack := range t.conns { | 	for proto, stack := range t.conns { | ||||||
| 		if len(stack) == 0 { | 		if len(stack) == 0 { | ||||||
| @@ -144,19 +144,19 @@ func (t *transport) cleanup(all bool) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Yield return the connection to transport for reuse. | // Yield return the connection to transport for reuse. | ||||||
| func (t *transport) Yield(c *dns.Conn) { t.yield <- c } | func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } | ||||||
|  |  | ||||||
| // Start starts the transport's connection manager. | // Start starts the transport's connection manager. | ||||||
| func (t *transport) Start() { go t.connManager() } | func (t *Transport) Start() { go t.connManager() } | ||||||
|  |  | ||||||
| // Stop stops the transport's connection manager. | // Stop stops the transport's connection manager. | ||||||
| func (t *transport) Stop() { close(t.stop) } | func (t *Transport) Stop() { close(t.stop) } | ||||||
|  |  | ||||||
| // SetExpire sets the connection expire time in transport. | // SetExpire sets the connection expire time in transport. | ||||||
| func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } | func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } | ||||||
|  |  | ||||||
| // SetTLSConfig sets the TLS config in transport. | // SetTLSConfig sets the TLS config in transport. | ||||||
| func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } | func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	defaultExpire      = 10 * time.Second | 	defaultExpire      = 10 * time.Second | ||||||
|   | |||||||
| @@ -1,30 +0,0 @@ | |||||||
| package forward |  | ||||||
|  |  | ||||||
| // Copied from coredns/core/dnsserver/address.go |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"strings" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // protocol returns the protocol of the string s. The second string returns s |  | ||||||
| // with the prefix chopped off. |  | ||||||
| func protocol(s string) (int, string) { |  | ||||||
| 	switch { |  | ||||||
| 	case strings.HasPrefix(s, _tls+"://"): |  | ||||||
| 		return TLS, s[len(_tls)+3:] |  | ||||||
| 	case strings.HasPrefix(s, _dns+"://"): |  | ||||||
| 		return DNS, s[len(_dns)+3:] |  | ||||||
| 	} |  | ||||||
| 	return DNS, s |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Supported protocols. |  | ||||||
| const ( |  | ||||||
| 	DNS = iota + 1 |  | ||||||
| 	TLS |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	_dns = "dns" |  | ||||||
| 	_tls = "tls" |  | ||||||
| ) |  | ||||||
| @@ -18,7 +18,7 @@ type Proxy struct { | |||||||
|  |  | ||||||
| 	// Connection caching | 	// Connection caching | ||||||
| 	expire    time.Duration | 	expire    time.Duration | ||||||
| 	transport *transport | 	transport *Transport | ||||||
|  |  | ||||||
| 	// health checking | 	// health checking | ||||||
| 	probe  *up.Probe | 	probe  *up.Probe | ||||||
| @@ -26,7 +26,7 @@ type Proxy struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // NewProxy returns a new proxy. | // NewProxy returns a new proxy. | ||||||
| func NewProxy(addr string, protocol int) *Proxy { | func NewProxy(addr, trans string) *Proxy { | ||||||
| 	p := &Proxy{ | 	p := &Proxy{ | ||||||
| 		addr:      addr, | 		addr:      addr, | ||||||
| 		fails:     0, | 		fails:     0, | ||||||
| @@ -34,7 +34,7 @@ func NewProxy(addr string, protocol int) *Proxy { | |||||||
| 		transport: newTransport(addr), | 		transport: newTransport(addr), | ||||||
| 		avgRtt:    int64(maxTimeout / 2), | 		avgRtt:    int64(maxTimeout / 2), | ||||||
| 	} | 	} | ||||||
| 	p.health = NewHealthChecker(protocol) | 	p.health = NewHealthChecker(trans) | ||||||
| 	runtime.SetFinalizer(p, (*Proxy).finalizer) | 	runtime.SetFinalizer(p, (*Proxy).finalizer) | ||||||
| 	return p | 	return p | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| @@ -26,7 +27,7 @@ func TestProxyClose(t *testing.T) { | |||||||
| 	ctx := context.TODO() | 	ctx := context.TODO() | ||||||
|  |  | ||||||
| 	for i := 0; i < 100; i++ { | 	for i := 0; i < 100; i++ { | ||||||
| 		p := NewProxy(s.Addr, DNS) | 		p := NewProxy(s.Addr, transport.DNS) | ||||||
| 		p.start(hcInterval) | 		p.start(hcInterval) | ||||||
|  |  | ||||||
| 		go func() { p.Connect(ctx, state, options{}) }() | 		go func() { p.Connect(ctx, state, options{}) }() | ||||||
| @@ -95,7 +96,7 @@ func TestProxyTLSFail(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestProtocolSelection(t *testing.T) { | func TestProtocolSelection(t *testing.T) { | ||||||
| 	p := NewProxy("bad_address", DNS) | 	p := NewProxy("bad_address", transport.DNS) | ||||||
|  |  | ||||||
| 	stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} | 	stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} | ||||||
| 	stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} | 	stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} | ||||||
|   | |||||||
| @@ -2,7 +2,6 @@ package forward | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -11,6 +10,7 @@ import ( | |||||||
| 	"github.com/coredns/coredns/plugin/metrics" | 	"github.com/coredns/coredns/plugin/metrics" | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnsutil" | 	"github.com/coredns/coredns/plugin/pkg/dnsutil" | ||||||
| 	pkgtls "github.com/coredns/coredns/plugin/pkg/tls" | 	pkgtls "github.com/coredns/coredns/plugin/pkg/tls" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/mholt/caddy" | 	"github.com/mholt/caddy" | ||||||
| 	"github.com/mholt/caddy/caddyfile" | 	"github.com/mholt/caddy/caddyfile" | ||||||
| @@ -93,8 +93,6 @@ func parseForward(c *caddy.Controller) (*Forward, error) { | |||||||
| func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { | func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { | ||||||
| 	f := New() | 	f := New() | ||||||
|  |  | ||||||
| 	protocols := map[int]int{} |  | ||||||
|  |  | ||||||
| 	if !c.Args(&f.from) { | 	if !c.Args(&f.from) { | ||||||
| 		return f, c.ArgErr() | 		return f, c.ArgErr() | ||||||
| 	} | 	} | ||||||
| @@ -105,41 +103,17 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { | |||||||
| 		return f, c.ArgErr() | 		return f, c.ArgErr() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies. |  | ||||||
| 	protocols = make(map[int]int) |  | ||||||
| 	for i := range to { |  | ||||||
| 		protocols[i], to[i] = protocol(to[i]) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make |  | ||||||
| 	// any sense anymore... For now: lets don't care. |  | ||||||
| 	toHosts, err := dnsutil.ParseHostPortOrFile(to...) | 	toHosts, err := dnsutil.ParseHostPortOrFile(to...) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return f, err | 		return f, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for i, h := range toHosts { | 	transports := make([]string, len(toHosts)) | ||||||
| 		// Double check the port, if e.g. is 53 and the transport is TLS make it 853. | 	for i, host := range toHosts { | ||||||
| 		// This can be somewhat annoying because you *can't* have TLS on port 53 then. | 		trans, h := transport.Parse(host) | ||||||
| 		switch protocols[i] { | 		p := NewProxy(h, trans) | ||||||
| 		case TLS: |  | ||||||
| 			h1, p, err := net.SplitHostPort(h) |  | ||||||
| 			if err != nil { |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to |  | ||||||
| 			// 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence |  | ||||||
| 			// Fix the port number here, back to what the user intended. |  | ||||||
| 			if p == "53" { |  | ||||||
| 				h = net.JoinHostPort(h1, "853") |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// We can't set tlsConfig here, because we haven't parsed it yet. |  | ||||||
| 		// We set it below at the end of parseBlock, use nil now. |  | ||||||
| 		p := NewProxy(h, protocols[i]) |  | ||||||
| 		f.proxies = append(f.proxies, p) | 		f.proxies = append(f.proxies, p) | ||||||
|  | 		transports[i] = trans | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for c.NextBlock() { | 	for c.NextBlock() { | ||||||
| @@ -153,7 +127,7 @@ func ParseForwardStanza(c *caddyfile.Dispenser) (*Forward, error) { | |||||||
| 	} | 	} | ||||||
| 	for i := range f.proxies { | 	for i := range f.proxies { | ||||||
| 		// Only set this for proxies that need it. | 		// Only set this for proxies that need it. | ||||||
| 		if protocols[i] == TLS { | 		if transports[i] == transport.TLS { | ||||||
| 			f.proxies[i].SetTLSConfig(f.tlsConfig) | 			f.proxies[i].SetTLSConfig(f.tlsConfig) | ||||||
| 		} | 		} | ||||||
| 		f.proxies[i].SetExpire(f.expire) | 		f.proxies[i].SetExpire(f.expire) | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import ( | |||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| @@ -34,7 +35,7 @@ func TestLookupTruncated(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr, DNS) | 	p := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
| @@ -88,9 +89,9 @@ func TestForwardTruncated(t *testing.T) { | |||||||
|  |  | ||||||
| 	f := New() | 	f := New() | ||||||
|  |  | ||||||
| 	p1 := NewProxy(s.Addr, DNS) | 	p1 := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f.SetProxy(p1) | 	f.SetProxy(p1) | ||||||
| 	p2 := NewProxy(s.Addr, DNS) | 	p2 := NewProxy(s.Addr, transport.DNS) | ||||||
| 	f.SetProxy(p2) | 	f.SetProxy(p2) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,6 +6,8 @@ import ( | |||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -61,22 +63,10 @@ type ( | |||||||
| // Normalize will return the host portion of host, stripping | // Normalize will return the host portion of host, stripping | ||||||
| // of any port or transport. The host will also be fully qualified and lowercased. | // of any port or transport. The host will also be fully qualified and lowercased. | ||||||
| func (h Host) Normalize() string { | func (h Host) Normalize() string { | ||||||
|  |  | ||||||
| 	s := string(h) | 	s := string(h) | ||||||
|  | 	_, s = transport.Parse(s) | ||||||
|  |  | ||||||
| 	switch { | 	// The error can be ignore here, because this function is called after the corefile has already been vetted. | ||||||
| 	case strings.HasPrefix(s, TransportTLS+"://"): |  | ||||||
| 		s = s[len(TransportTLS+"://"):] |  | ||||||
| 	case strings.HasPrefix(s, TransportDNS+"://"): |  | ||||||
| 		s = s[len(TransportDNS+"://"):] |  | ||||||
| 	case strings.HasPrefix(s, TransportGRPC+"://"): |  | ||||||
| 		s = s[len(TransportGRPC+"://"):] |  | ||||||
| 	case strings.HasPrefix(s, TransportHTTPS+"://"): |  | ||||||
| 		s = s[len(TransportHTTPS+"://"):] |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// The error can be ignore here, because this function is called after the corefile |  | ||||||
| 	// has already been vetted. |  | ||||||
| 	host, _, _, _ := SplitHostPort(s) | 	host, _, _, _ := SplitHostPort(s) | ||||||
| 	return Name(host).Normalize() | 	return Name(host).Normalize() | ||||||
| } | } | ||||||
| @@ -138,11 +128,3 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) { | |||||||
| 	} | 	} | ||||||
| 	return host, port, n, nil | 	return host, port, n, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Duplicated from core/dnsserver/address.go ! |  | ||||||
| const ( |  | ||||||
| 	TransportDNS   = "dns" |  | ||||||
| 	TransportTLS   = "tls" |  | ||||||
| 	TransportGRPC  = "grpc" |  | ||||||
| 	TransportHTTPS = "https" |  | ||||||
| ) |  | ||||||
|   | |||||||
| @@ -5,15 +5,21 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/transport" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ParseHostPortOrFile parses the strings in s, each string can either be a address, | // ParseHostPortOrFile parses the strings in s, each string can either be a | ||||||
| // address:port or a filename. The address part is checked and the filename case a | // address, [scheme://]address:port or a filename. The address part is checked | ||||||
| // resolv.conf like file is parsed and the nameserver found are returned. | // and in case of filename a resolv.conf like file is (assumed) and parsed and | ||||||
|  | // the nameservers found are returned. | ||||||
| func ParseHostPortOrFile(s ...string) ([]string, error) { | func ParseHostPortOrFile(s ...string) ([]string, error) { | ||||||
| 	var servers []string | 	var servers []string | ||||||
| 	for _, host := range s { | 	for _, h := range s { | ||||||
|  |  | ||||||
|  | 		trans, host := transport.Parse(h) | ||||||
|  |  | ||||||
| 		addr, _, err := net.SplitHostPort(host) | 		addr, _, err := net.SplitHostPort(host) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			// Parse didn't work, it is not a addr:port combo | 			// Parse didn't work, it is not a addr:port combo | ||||||
| @@ -26,13 +32,23 @@ func ParseHostPortOrFile(s ...string) ([]string, error) { | |||||||
| 				} | 				} | ||||||
| 				return servers, fmt.Errorf("not an IP address or file: %q", host) | 				return servers, fmt.Errorf("not an IP address or file: %q", host) | ||||||
| 			} | 			} | ||||||
| 			ss := net.JoinHostPort(host, "53") | 			var ss string | ||||||
|  | 			switch trans { | ||||||
|  | 			case transport.DNS: | ||||||
|  | 				ss = net.JoinHostPort(host, "53") | ||||||
|  | 			case transport.TLS: | ||||||
|  | 				ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort) | ||||||
|  | 			case transport.GRPC: | ||||||
|  | 				ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort) | ||||||
|  | 			case transport.HTTPS: | ||||||
|  | 				ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort) | ||||||
|  | 			} | ||||||
| 			servers = append(servers, ss) | 			servers = append(servers, ss) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if net.ParseIP(addr) == nil { | 		if net.ParseIP(addr) == nil { | ||||||
| 			// No an IP address. | 			// Not an IP address. | ||||||
| 			ss, err := tryFile(host) | 			ss, err := tryFile(host) | ||||||
| 			if err == nil { | 			if err == nil { | ||||||
| 				servers = append(servers, ss...) | 				servers = append(servers, ss...) | ||||||
| @@ -40,7 +56,7 @@ func ParseHostPortOrFile(s ...string) ([]string, error) { | |||||||
| 			} | 			} | ||||||
| 			return servers, fmt.Errorf("not an IP address or file: %q", host) | 			return servers, fmt.Errorf("not an IP address or file: %q", host) | ||||||
| 		} | 		} | ||||||
| 		servers = append(servers, host) | 		servers = append(servers, h) | ||||||
| 	} | 	} | ||||||
| 	return servers, nil | 	return servers, nil | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										49
									
								
								plugin/pkg/transport/transport.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								plugin/pkg/transport/transport.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | |||||||
|  | package transport | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Parse returns the transport defined in s and a string where the | ||||||
|  | // transport prefix is removed (if there was any). If no transport is defined | ||||||
|  | // we default to TransportDNS | ||||||
|  | func Parse(s string) (transport string, addr string) { | ||||||
|  | 	switch { | ||||||
|  | 	case strings.HasPrefix(s, TLS+"://"): | ||||||
|  | 		s = s[len(TLS+"://"):] | ||||||
|  | 		return TLS, s | ||||||
|  |  | ||||||
|  | 	case strings.HasPrefix(s, DNS+"://"): | ||||||
|  | 		s = s[len(DNS+"://"):] | ||||||
|  | 		return DNS, s | ||||||
|  |  | ||||||
|  | 	case strings.HasPrefix(s, GRPC+"://"): | ||||||
|  | 		s = s[len(GRPC+"://"):] | ||||||
|  | 		return GRPC, s | ||||||
|  |  | ||||||
|  | 	case strings.HasPrefix(s, HTTPS+"://"): | ||||||
|  | 		s = s[len(HTTPS+"://"):] | ||||||
|  |  | ||||||
|  | 		return HTTPS, s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return DNS, s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Supported transports. | ||||||
|  | const ( | ||||||
|  | 	DNS   = "dns" | ||||||
|  | 	TLS   = "tls" | ||||||
|  | 	GRPC  = "grpc" | ||||||
|  | 	HTTPS = "https" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Port numbers for the various protocols | ||||||
|  | const ( | ||||||
|  | 	// TLSPort is the default port for DNS-over-TLS. | ||||||
|  | 	TLSPort = "853" | ||||||
|  | 	// GRPCPort is the default port for DNS-over-gRPC. | ||||||
|  | 	GRPCPort = "443" | ||||||
|  | 	// HTTPSPort is the default port for DNS-over-HTTPS. | ||||||
|  | 	HTTPSPort = "443" | ||||||
|  | ) | ||||||
							
								
								
									
										21
									
								
								plugin/pkg/transport/transport_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								plugin/pkg/transport/transport_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | package transport | ||||||
|  |  | ||||||
|  | import "testing" | ||||||
|  |  | ||||||
|  | func TestParse(t *testing.T) { | ||||||
|  | 	for i, test := range []struct { | ||||||
|  | 		input    string | ||||||
|  | 		expected string | ||||||
|  | 	}{ | ||||||
|  | 		{"dns://.:53", DNS}, | ||||||
|  | 		{"2003::1/64.:53", DNS}, | ||||||
|  | 		{"grpc://example.org:1443 ", GRPC}, | ||||||
|  | 		{"tls://example.org ", TLS}, | ||||||
|  | 		{"https://example.org ", HTTPS}, | ||||||
|  | 	} { | ||||||
|  | 		actual, _ := Parse(test.input) | ||||||
|  | 		if actual != test.expected { | ||||||
|  | 			t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user