Cleanup ParseHostOrFile (#2100)

Create plugin/pkg/transport that holds the transport related functions.
This needed to be a new pkg to prevent cyclic import errors.

This cleans up a bunch of duplicated code in core/dnsserver that also
tried to parse a transport (now all done in transport.Parse).

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben
2018-09-19 07:29:37 +01:00
committed by GitHub
parent 2f1223c36a
commit c349446a23
24 changed files with 182 additions and 221 deletions

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
@@ -27,43 +28,13 @@ func (z zoneAddr) String() string {
return s
}
// Transport returns the protocol of the string s
func Transport(s string) string {
switch {
case strings.HasPrefix(s, TransportTLS+"://"):
return TransportTLS
case strings.HasPrefix(s, TransportDNS+"://"):
return TransportDNS
case strings.HasPrefix(s, TransportGRPC+"://"):
return TransportGRPC
case strings.HasPrefix(s, TransportHTTPS+"://"):
return TransportHTTPS
}
return TransportDNS
}
// normalizeZone parses an zone string into a structured format with separate
// host, and port portions, as well as the original input string.
func normalizeZone(str string) (zoneAddr, error) {
var err error
// Default to DNS if there isn't a transport protocol prefix.
trans := TransportDNS
switch {
case strings.HasPrefix(str, TransportTLS+"://"):
trans = TransportTLS
str = str[len(TransportTLS+"://"):]
case strings.HasPrefix(str, TransportDNS+"://"):
trans = TransportDNS
str = str[len(TransportDNS+"://"):]
case strings.HasPrefix(str, TransportGRPC+"://"):
trans = TransportGRPC
str = str[len(TransportGRPC+"://"):]
case strings.HasPrefix(str, TransportHTTPS+"://"):
trans = TransportHTTPS
str = str[len(TransportHTTPS+"://"):]
}
var trans string
trans, str = transport.Parse(str)
host, port, ipnet, err := plugin.SplitHostPort(str)
if err != nil {
@@ -71,17 +42,15 @@ func normalizeZone(str string) (zoneAddr, error) {
}
if port == "" {
if trans == TransportDNS {
switch trans {
case transport.DNS:
port = Port
}
if trans == TransportTLS {
port = TLSPort
}
if trans == TransportGRPC {
port = GRPCPort
}
if trans == TransportHTTPS {
port = HTTPSPort
case transport.TLS:
port = transport.TLSPort
case transport.GRPC:
port = transport.GRPCPort
case transport.HTTPS:
port = transport.HTTPSPort
}
}
@@ -103,14 +72,6 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str
}
}
// Supported transports.
const (
TransportDNS = "dns"
TransportTLS = "tls"
TransportGRPC = "grpc"
TransportHTTPS = "https"
)
type zoneOverlap struct {
registeredAddr map[zoneAddr]zoneAddr // each zoneAddr is registered once by its key
unboundOverlap map[zoneAddr]zoneAddr // the "no bind" equiv ZoneAdddr is registered by its original key

View File

@@ -192,21 +192,3 @@ func TestOverlapAddressChecker(t *testing.T) {
}
}
}
func TestTransport(t *testing.T) {
for i, test := range []struct {
input string
expected string
}{
{"dns://.:53", TransportDNS},
{"2003::1/64.:53", TransportDNS},
{"grpc://example.org:1443 ", TransportGRPC},
{"tls://example.org ", TransportTLS},
{"https://example.org ", TransportHTTPS},
} {
actual := Transport(test.input)
if actual != test.expected {
t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual)
}
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
@@ -111,29 +112,29 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
var servers []caddy.Server
for addr, group := range groups {
// switch on addr
switch Transport(addr) {
case TransportDNS:
switch tr, _ := transport.Parse(addr); tr {
case transport.DNS:
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case TransportTLS:
case transport.TLS:
s, err := NewServerTLS(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case TransportGRPC:
case transport.GRPC:
s, err := NewServergRPC(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
case TransportHTTPS:
case transport.HTTPS:
s, err := NewServerHTTPS(addr, group)
if err != nil {
return nil, err
@@ -234,16 +235,8 @@ func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
return groups, nil
}
const (
// DefaultPort is the default port.
DefaultPort = "53"
// TLSPort is the default port for DNS-over-TLS.
TLSPort = "853"
// GRPCPort is the default port for DNS-over-gRPC.
GRPCPort = "443"
// HTTPSPort is the default port for DNS-over-HTTPS.
HTTPSPort = "443"
)
// DefaultPort is the default port.
const DefaultPort = "53"
// These "soft defaults" are configurable by
// command line flags, etc.

View File

@@ -15,6 +15,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/log"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/plugin/pkg/trace"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@@ -134,7 +135,7 @@ func (s *Server) ServePacket(p net.PacketConn) error {
// Listen implements caddy.TCPServer interface.
func (s *Server) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportDNS+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.DNS+"://"):])
if err != nil {
return nil, err
}
@@ -143,7 +144,7 @@ func (s *Server) Listen() (net.Listener, error) {
// ListenPacket implements caddy.UDPServer interface.
func (s *Server) ListenPacket() (net.PacketConn, error) {
p, err := net.ListenPacket("udp", s.Addr[len(TransportDNS+"://"):])
p, err := net.ListenPacket("udp", s.Addr[len(transport.DNS+"://"):])
if err != nil {
return nil, err
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/coredns/coredns/plugin/pkg/watch"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
@@ -73,7 +74,7 @@ func (s *ServergRPC) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServergRPC) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportGRPC+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.GRPC+"://"):])
if err != nil {
return nil, err
}
@@ -90,7 +91,7 @@ func (s *ServergRPC) OnStartupComplete() {
return
}
out := startUpZones(TransportGRPC+"://", s.Addr, s.zones)
out := startUpZones(transport.GRPC+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/doh"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/plugin/pkg/transport"
)
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
@@ -60,7 +61,7 @@ func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerHTTPS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.HTTPS+"://"):])
if err != nil {
return nil, err
}
@@ -77,7 +78,7 @@ func (s *ServerHTTPS) OnStartupComplete() {
return
}
out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones)
out := startUpZones(transport.HTTPS+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"net"
"github.com/coredns/coredns/plugin/pkg/transport"
"github.com/miekg/dns"
)
@@ -55,7 +57,7 @@ func (s *ServerTLS) ServePacket(p net.PacketConn) error { return nil }
// Listen implements caddy.TCPServer interface.
func (s *ServerTLS) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr[len(TransportTLS+"://"):])
l, err := net.Listen("tcp", s.Addr[len(transport.TLS+"://"):])
if err != nil {
return nil, err
}
@@ -72,7 +74,7 @@ func (s *ServerTLS) OnStartupComplete() {
return
}
out := startUpZones(TransportTLS+"://", s.Addr, s.zones)
out := startUpZones(transport.TLS+"://", s.Addr, s.zones)
if out != "" {
fmt.Print(out)
}