diff --git a/plugin/forward/README.md b/plugin/forward/README.md index 9575e96e6..74278414c 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -31,7 +31,9 @@ forward FROM TO... that expand to multiple reverse zones are not fully supported; only the first expanded zone is used. * **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for plain DNS. The number of upstreams is - limited to 15. + limited to 15. In addition to IP addresses and files (like `/etc/resolv.conf`), **TO** can also be + a hostname (e.g., `my-dns.svc.cluster.local`). Hostnames are resolved to IP addresses at startup. + See the `resolver` option below. Multiple upstreams are randomized (see `policy`) on first use. When a healthy proxy returns an error during the exchange the next upstream in the list is tried. @@ -55,6 +57,7 @@ forward FROM TO... { next RCODE_1 [RCODE_2] [RCODE_3...] failfast_all_unhealthy_upstreams failover RCODE_1 [RCODE_2] [RCODE_3...] + resolver IP[:PORT] [IP[:PORT]...] } ~~~ @@ -114,6 +117,7 @@ forward FROM TO... { * `next_on_nodata` If `NOERROR` is returned by the remote, but an empty answer section (`NODATA`) was provided, execute the next `forward` plugin, if configured. * `failfast_all_unhealthy_upstreams` - determines the handling of requests when all upstream servers are unhealthy and unresponsive to health checks. Enabling this option will immediately return SERVFAIL responses for all requests. By default, requests are sent to a random upstream. * `failover` - By default when a DNS lookup fails to return a DNS response (e.g. timeout), _forward_ will attempt a lookup on the next upstream server. The `failover` option will make _forward_ do the same for any response with a response code matching an `RCODE` ( e.g. `SERVFAIL`、`REFUSED`). `NOERROR` cannot be used. If all upstreams have been tried, the response from the last attempt is returned. +* `resolver` **IP[:PORT] [IP[:PORT]...]** specifies one or more DNS resolver addresses used to resolve hostname-based **TO** endpoints at startup. If not specified, the system resolver (`/etc/resolv.conf`) is used. Each address is either a bare IP (IPv4 or IPv6, port 53 assumed) or `IP:port`. Multiple addresses can be specified for redundancy. Also note the TLS config is "global" for the whole forwarding proxy if you need a different `tls_servername` for different upstreams you're out of luck. @@ -313,6 +317,16 @@ In the following example, if the response from `1.2.3.4` is `SERVFAIL` or `REFUS } ~~~ +Forward to an upstream identified by hostname, using a specific resolver to look it up: + +~~~ txt +. { + forward . dns.example.local { + resolver 10.0.0.1 + } +} +~~~ + ## See Also [RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS. diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index bca47c3cb..0b48da462 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -57,6 +57,10 @@ type Forward struct { failoverRcodes []int maxConnectAttempts uint32 + // Hostname resolution fields + resolver []string // custom resolver IPs for hostname TO resolution + toEntries []toEntry // ordered TO entries preserving config order + opts proxyPkg.Options // also here for testing // ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded diff --git a/plugin/forward/resolve.go b/plugin/forward/resolve.go new file mode 100644 index 000000000..f10e8817b --- /dev/null +++ b/plugin/forward/resolve.go @@ -0,0 +1,257 @@ +package forward + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +// hostEntry represents a hostname-based TO address that needs DNS resolution. +type hostEntry struct { + hostname string // the hostname to resolve (e.g., "rbldnsd.rbldnsd.svc.cluster.local") + port string // port (e.g., "53", "853") + transport string // "dns" or "tls" + zone string // TLS server name zone (from %zone syntax) +} + +// toEntry represents a single TO address from the config, preserving order. +type toEntry struct { + static bool // true for IP/file-based entries + addrs []string // for static: resolved by HostPortOrFile + entry hostEntry // for dynamic: hostname to resolve +} + +// classifyToAddrs processes TO addresses in order, returning an ordered list of +// toEntries that preserves config ordering. +func classifyToAddrs(toAddrs []string) ([]toEntry, error) { + var entries []toEntry + for _, h := range toAddrs { + // Try HostPortOrFile first - this handles IPs and files + hosts, parseErr := parse.HostPortOrFile(h) + if parseErr == nil { + entries = append(entries, toEntry{static: true, addrs: hosts}) + continue + } + + // Only fall through to hostname parsing if the error specifically + // indicates the address is not an IP or file. Other errors (like + // "no nameservers found" from file parsing) should be propagated. + if !strings.Contains(parseErr.Error(), "not an IP address or file") { + return nil, parseErr + } + + // Not an IP or file - check if it's a valid hostname + entry, ok := parseAsHostEntry(h) + if !ok { + return nil, fmt.Errorf("not an IP address, file, or valid domain: %q", h) + } + entries = append(entries, toEntry{static: false, entry: entry}) + } + return entries, nil +} + +// parseAsHostEntry attempts to parse a TO address as a hostname-based entry. +func parseAsHostEntry(h string) (hostEntry, bool) { + cleanH, zone := splitZone(h) + trans, host := parse.Transport(cleanH) + + // Only dns and tls transports are supported for hostname resolution + if trans != transport.DNS && trans != transport.TLS { + return hostEntry{}, false + } + + hostname := host + port := transport.Port + if trans == transport.TLS { + port = transport.TLSPort + } + + // Check if there's a port + if h2, p, err := net.SplitHostPort(host); err == nil { + hostname = h2 + port = p + } + + hostname = strings.Trim(hostname, "[]") + + // Validate as domain name + if _, ok := dns.IsDomainName(hostname); !ok || hostname == "" { + return hostEntry{}, false + } + + // Make sure it's not actually an IP + if net.ParseIP(hostname) != nil { + return hostEntry{}, false + } + + return hostEntry{ + hostname: hostname, + port: port, + transport: trans, + zone: zone, + }, true +} + +// expandAndDedup resolves all toEntries in order, expands hostnames to IPs, +// and deduplicates by first-seen address. Returns the deduplicated address list. +func expandAndDedup(entries []toEntry, resolvers []string) ([]string, error) { + seen := make(map[string]bool) + var result []string + + for _, e := range entries { + var addrs []string + if e.static { + addrs = e.addrs + } else { + resolved, err := resolveHostEntry(e.entry, resolvers) + if err != nil { + return nil, err + } + addrs = resolved + } + + for _, addr := range addrs { + // Normalize the address for dedup comparison + key := normalizeAddr(addr) + if !seen[key] { + seen[key] = true + result = append(result, addr) + } + } + } + return result, nil +} + +// normalizeAddr extracts the canonical IP:port from an address string +// (stripping transport prefix and zone) for deduplication. +func normalizeAddr(addr string) string { + host, _ := splitZone(addr) + _, h := parse.Transport(host) + return h +} + +// resolveHostEntry resolves a single hostname entry and returns its addresses. +func resolveHostEntry(entry hostEntry, resolvers []string) ([]string, error) { + ips, err := lookupHost(entry.hostname, resolvers) + if err != nil { + return nil, fmt.Errorf("failed to resolve %q: %v", entry.hostname, err) + } + var addrs []string + for _, ip := range ips { + addrs = append(addrs, formatResolvedAddr(ip, entry.port, entry.transport, entry.zone)) + } + return addrs, nil +} + +// formatResolvedAddr formats a resolved IP into an address string compatible +// with the proxy creation code in parseStanza. +func formatResolvedAddr(ip, port, trans, zone string) string { + isIPv6 := strings.Contains(ip, ":") + + switch trans { + case transport.TLS: + if zone != "" { + if isIPv6 { + return transport.TLS + "://[" + ip + "%" + zone + "]:" + port + } + return transport.TLS + "://" + ip + "%" + zone + ":" + port + } + return transport.TLS + "://" + net.JoinHostPort(ip, port) + default: // transport.DNS + return net.JoinHostPort(ip, port) + } +} + +// lookupHost resolves a hostname to IP addresses using the specified resolvers. +// If resolvers is empty, the system resolver (/etc/resolv.conf) is used. +func lookupHost(hostname string, resolvers []string) ([]string, error) { + if len(resolvers) == 0 { + return systemLookup(hostname) + } + return dnsLookup(hostname, resolvers) +} + +// systemLookup resolves using the system resolver (/etc/resolv.conf). +func systemLookup(hostname string) ([]string, error) { + ips, err := net.LookupHost(hostname) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, fmt.Errorf("no addresses found for %q", hostname) + } + return ips, nil +} + +// dnsLookup resolves a hostname using specific DNS resolver addresses. +// Each resolver can be a bare IP (port 53 is assumed) or an IP:port pair. +// It tries each resolver in order until one succeeds. +func dnsLookup(hostname string, resolvers []string) ([]string, error) { + c := new(dns.Client) + c.ReadTimeout = 2 * time.Second + c.WriteTimeout = 2 * time.Second + + var lastErr error + + for _, resolver := range resolvers { + resolverAddr := resolver + if _, _, err := net.SplitHostPort(resolver); err != nil { + resolverAddr = net.JoinHostPort(resolver, transport.Port) + } + var ips []string + + // Try A records + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(hostname), dns.TypeA) + m.RecursionDesired = true + + r, _, err := c.Exchange(m, resolverAddr) + if err != nil { + lastErr = err + continue + } + if r != nil { + for _, ans := range r.Answer { + if a, ok := ans.(*dns.A); ok { + ips = append(ips, a.A.String()) + } + } + } + + // Also try AAAA + m = new(dns.Msg) + m.SetQuestion(dns.Fqdn(hostname), dns.TypeAAAA) + m.RecursionDesired = true + + r, _, err = c.Exchange(m, resolverAddr) + if err != nil { + if len(ips) > 0 { + return ips, nil // we have A records, AAAA failure is OK + } + lastErr = err + continue + } + if r != nil { + for _, ans := range r.Answer { + if aaaa, ok := ans.(*dns.AAAA); ok { + ips = append(ips, aaaa.AAAA.String()) + } + } + } + + if len(ips) > 0 { + return ips, nil + } + } + + if lastErr != nil { + return nil, fmt.Errorf("no addresses found for %q: %v", hostname, lastErr) + } + return nil, fmt.Errorf("no addresses found for %q", hostname) +} diff --git a/plugin/forward/resolve_test.go b/plugin/forward/resolve_test.go new file mode 100644 index 000000000..2d989b5dc --- /dev/null +++ b/plugin/forward/resolve_test.go @@ -0,0 +1,608 @@ +package forward + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestClassifyToAddrs(t *testing.T) { + // Create a resolv.conf for file test + const resolv = "test_resolv.conf" + if err := os.WriteFile(resolv, []byte("nameserver 10.0.0.1\n"), 0666); err != nil { + t.Fatal(err) + } + defer os.Remove(resolv) + + tests := []struct { + name string + input []string + wantStatic int + wantDynamic int + wantErr bool + errContains string + }{ + { + name: "simple IP", + input: []string{"127.0.0.1"}, + wantStatic: 1, + }, + { + name: "IP with port", + input: []string{"127.0.0.1:8053"}, + wantStatic: 1, + }, + { + name: "IPv6", + input: []string{"::1"}, + wantStatic: 1, + }, + { + name: "TLS IP", + input: []string{"tls://127.0.0.1"}, + wantStatic: 1, + }, + { + name: "resolv.conf file", + input: []string{resolv}, + wantStatic: 1, + }, + { + name: "hostname", + input: []string{"dns.example.com"}, + wantDynamic: 1, + }, + { + name: "hostname with port", + input: []string{"dns.example.com:5353"}, + wantDynamic: 1, + }, + { + name: "TLS hostname", + input: []string{"tls://dns.example.com"}, + wantDynamic: 1, + }, + { + name: "k8s service name", + input: []string{"rbldnsd.rbldnsd.svc.cluster.local"}, + wantDynamic: 1, + }, + { + name: "mixed IPs and hostnames", + input: []string{"127.0.0.1", "dns.example.com", "10.0.0.1"}, + wantStatic: 2, + wantDynamic: 1, + }, + { + name: "/dev/null returns file error", + input: []string{"/dev/null"}, + wantErr: true, + errContains: "no nameservers", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + entries, err := classifyToAddrs(tc.input) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("expected error to contain %q, got: %v", tc.errContains, err) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + staticCount := 0 + dynamicCount := 0 + for _, e := range entries { + if e.static { + staticCount++ + } else { + dynamicCount++ + } + } + if staticCount != tc.wantStatic { + t.Errorf("expected %d static entries, got %d", tc.wantStatic, staticCount) + } + if dynamicCount != tc.wantDynamic { + t.Errorf("expected %d dynamic entries, got %d", tc.wantDynamic, dynamicCount) + } + }) + } +} + +func TestClassifyToAddrsPreservesOrder(t *testing.T) { + entries, err := classifyToAddrs([]string{"dns.example.com", "127.0.0.1", "other.example.com"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 3 { + t.Fatalf("expected 3 entries, got %d", len(entries)) + } + if entries[0].static || entries[0].entry.hostname != "dns.example.com" { + t.Errorf("entry 0: expected dynamic dns.example.com, got static=%v entry=%v", entries[0].static, entries[0].entry) + } + if !entries[1].static || entries[1].addrs[0] != "127.0.0.1:53" { + t.Errorf("entry 1: expected static 127.0.0.1:53, got static=%v addrs=%v", entries[1].static, entries[1].addrs) + } + if entries[2].static || entries[2].entry.hostname != "other.example.com" { + t.Errorf("entry 2: expected dynamic other.example.com, got static=%v entry=%v", entries[2].static, entries[2].entry) + } +} + +func TestParseAsHostEntry(t *testing.T) { + tests := []struct { + input string + wantOK bool + hostname string + port string + transport string + zone string + }{ + {"dns.example.com", true, "dns.example.com", "53", transport.DNS, ""}, + {"dns.example.com:5353", true, "dns.example.com", "5353", transport.DNS, ""}, + {"tls://dns.example.com", true, "dns.example.com", "853", transport.TLS, ""}, + {"tls://dns.example.com:8853", true, "dns.example.com", "8853", transport.TLS, ""}, + {"tls://dns.example.com%servername.example.com", true, "dns.example.com", "853", transport.TLS, "servername.example.com"}, + {"rbldnsd.rbldnsd.svc.cluster.local", true, "rbldnsd.rbldnsd.svc.cluster.local", "53", transport.DNS, ""}, + // Should fail for IPs + {"127.0.0.1", false, "", "", "", ""}, + {"::1", false, "", "", "", ""}, + // Should fail for unsupported transports + {"https://example.com", false, "", "", "", ""}, + // Should fail for empty + {"", false, "", "", "", ""}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + entry, ok := parseAsHostEntry(tc.input) + if ok != tc.wantOK { + t.Fatalf("expected ok=%v, got %v", tc.wantOK, ok) + } + if !ok { + return + } + if entry.hostname != tc.hostname { + t.Errorf("expected hostname=%q, got %q", tc.hostname, entry.hostname) + } + if entry.port != tc.port { + t.Errorf("expected port=%q, got %q", tc.port, entry.port) + } + if entry.transport != tc.transport { + t.Errorf("expected transport=%q, got %q", tc.transport, entry.transport) + } + if entry.zone != tc.zone { + t.Errorf("expected zone=%q, got %q", tc.zone, entry.zone) + } + }) + } +} + +func TestFormatResolvedAddr(t *testing.T) { + tests := []struct { + ip, port, trans, zone string + expected string + }{ + {"10.0.0.1", "53", transport.DNS, "", "10.0.0.1:53"}, + {"10.0.0.1", "853", transport.TLS, "", "tls://10.0.0.1:853"}, + {"10.0.0.1", "853", transport.TLS, "example.com", "tls://10.0.0.1%example.com:853"}, + {"::1", "53", transport.DNS, "", "[::1]:53"}, + {"::1", "853", transport.TLS, "", "tls://[::1]:853"}, + {"::1", "853", transport.TLS, "example.com", "tls://[::1%example.com]:853"}, + } + + for _, tc := range tests { + t.Run(tc.expected, func(t *testing.T) { + result := formatResolvedAddr(tc.ip, tc.port, tc.trans, tc.zone) + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} + +func TestExpandAndDedup(t *testing.T) { + // Start a test DNS server that returns different IPs for different hostnames + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + switch r.Question[0].Name { + case "host1.example.com.": + ret.Answer = append(ret.Answer, + test.A("host1.example.com. IN A 10.0.0.1"), + test.A("host1.example.com. IN A 10.0.0.2"), + ) + case "host2.example.com.": + ret.Answer = append(ret.Answer, + test.A("host2.example.com. IN A 10.0.0.2"), + test.A("host2.example.com. IN A 10.0.0.3"), + ) + } + } + w.WriteMsg(ret) + }) + defer s.Close() + + // Simulate: forward . host1(→10.0.0.1,10.0.0.2) host2(→10.0.0.2,10.0.0.3) 10.0.0.3 10.0.0.2 + entries := []toEntry{ + {static: false, entry: hostEntry{hostname: "host1.example.com", port: "53", transport: "dns"}}, + {static: false, entry: hostEntry{hostname: "host2.example.com", port: "53", transport: "dns"}}, + {static: true, addrs: []string{"10.0.0.3:53"}}, + {static: true, addrs: []string{"10.0.0.2:53"}}, + } + + result, err := expandAndDedup(entries, []string{s.Addr}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Expected: 10.0.0.1, 10.0.0.2, 10.0.0.3 (first-seen order, deduped) + expected := []string{"10.0.0.1:53", "10.0.0.2:53", "10.0.0.3:53"} + if len(result) != len(expected) { + t.Fatalf("expected %d addresses, got %d: %v", len(expected), len(result), result) + } + for i, addr := range result { + normalized := normalizeAddr(addr) + if normalized != expected[i] { + t.Errorf("position %d: expected %s, got %s", i, expected[i], normalized) + } + } +} + +func TestExpandAndDedupOrderPreserved(t *testing.T) { + // Start a test DNS server + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + ret.Answer = append(ret.Answer, test.A("myhost.example.com. IN A 10.0.0.42")) + } + w.WriteMsg(ret) + }) + defer s.Close() + + // Config order: hostname first, then static IP + // forward . myhost.example.com 192.168.1.1 + entries := []toEntry{ + {static: false, entry: hostEntry{hostname: "myhost.example.com", port: "53", transport: "dns"}}, + {static: true, addrs: []string{"192.168.1.1:53"}}, + } + + result, err := expandAndDedup(entries, []string{s.Addr}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // hostname resolved IP should come first, then static + if len(result) != 2 { + t.Fatalf("expected 2 addresses, got %d: %v", len(result), result) + } + if normalizeAddr(result[0]) != "10.0.0.42:53" { + t.Errorf("expected first addr 10.0.0.42:53, got %s", normalizeAddr(result[0])) + } + if normalizeAddr(result[1]) != "192.168.1.1:53" { + t.Errorf("expected second addr 192.168.1.1:53, got %s", normalizeAddr(result[1])) + } +} + +func TestDnsLookup(t *testing.T) { + // Start a test DNS server that responds to A queries + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + ret.Answer = append(ret.Answer, test.A("myhost.example.com. IN A 10.0.0.42")) + } + w.WriteMsg(ret) + }) + defer s.Close() + + // Use the full server address (IP:port) since the test server uses a random port + ips, err := dnsLookup("myhost.example.com", []string{s.Addr}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(ips) == 0 { + t.Fatal("expected at least one IP") + } + found := false + for _, ip := range ips { + if ip == "10.0.0.42" { + found = true + } + } + if !found { + t.Errorf("expected to find 10.0.0.42 in %v", ips) + } +} + +func TestSetupResolver(t *testing.T) { + tests := []struct { + name string + input string + shouldErr bool + expectedErr string + resolverLen int + }{ + { + name: "single resolver IP", + input: "forward . 127.0.0.1 {\nresolver 10.96.0.10\n}\n", + resolverLen: 1, + }, + { + name: "multiple resolver IPs", + input: "forward . 127.0.0.1 {\nresolver 10.96.0.10 10.96.0.11\n}\n", + resolverLen: 2, + }, + { + name: "IPv6 resolver", + input: "forward . 127.0.0.1 {\nresolver ::1\n}\n", + resolverLen: 1, + }, + { + name: "resolver not an IP", + input: "forward . 127.0.0.1 {\nresolver dns.example.com\n}\n", + shouldErr: true, + expectedErr: "resolver must be an IP address", + }, + { + name: "resolver no args", + input: "forward . 127.0.0.1 {\nresolver\n}\n", + shouldErr: true, + expectedErr: "Wrong argument count", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c := caddy.NewTestController("dns", tc.input) + fs, err := parseForward(c) + + if tc.shouldErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.expectedErr) { + t.Errorf("expected error to contain %q, got: %v", tc.expectedErr, err) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + f := fs[0] + if len(f.resolver) != tc.resolverLen { + t.Errorf("expected %d resolver(s), got %d: %v", tc.resolverLen, len(f.resolver), f.resolver) + } + }) + } +} + +func TestSetupWithHostnameTO(t *testing.T) { + // Start a test DNS server that resolves "myupstream.example.com" to 10.0.0.42 + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA && r.Question[0].Name == "myupstream.example.com." { + ret.Answer = append(ret.Answer, test.A("myupstream.example.com. IN A 10.0.0.42")) + } + w.WriteMsg(ret) + }) + defer s.Close() + + // Test resolving a hostname entry directly + entry := hostEntry{hostname: "myupstream.example.com", port: "53", transport: "dns"} + addrs, err := resolveHostEntry(entry, []string{s.Addr}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(addrs) == 0 { + t.Fatal("expected at least one resolved address") + } + if addrs[0] != "10.0.0.42:53" { + t.Errorf("expected resolved addr 10.0.0.42:53, got %s", addrs[0]) + } + + // Test full integration: manually build the Forward with resolver + f := New() + f.from = "." + f.resolver = []string{s.Addr} + f.toEntries = []toEntry{ + {static: false, entry: entry}, + } + + resolvedAddrs, err := expandAndDedup(f.toEntries, f.resolver) + if err != nil { + t.Fatalf("resolution failed: %v", err) + } + + for _, addr := range resolvedAddrs { + host, _ := splitZone(addr) + trans, h := parse.Transport(host) + p := proxy.NewProxy("forward", h, trans) + f.proxies = append(f.proxies, p) + } + + if len(f.proxies) == 0 { + t.Fatal("expected at least one proxy") + } + if f.proxies[0].Addr() != "10.0.0.42:53" { + t.Errorf("expected proxy addr 10.0.0.42:53, got %s", f.proxies[0].Addr()) + } +} + +func TestSetupMixedIPAndHostnameTO(t *testing.T) { + // Start a test DNS server + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + ret.Answer = append(ret.Answer, test.A("myupstream.example.com. IN A 10.0.0.42")) + } + w.WriteMsg(ret) + }) + defer s.Close() + + // Manually build Forward to test mixed hostname + IP (hostname first for order test) + f := New() + f.from = "." + f.resolver = []string{s.Addr} + f.toEntries = []toEntry{ + {static: false, entry: hostEntry{hostname: "myupstream.example.com", port: "53", transport: "dns"}}, + {static: true, addrs: []string{"127.0.0.1:53"}}, + } + + resolvedAddrs, err := expandAndDedup(f.toEntries, f.resolver) + if err != nil { + t.Fatalf("expand error: %v", err) + } + + for _, addr := range resolvedAddrs { + host, _ := splitZone(addr) + trans, h := parse.Transport(host) + p := proxy.NewProxy("forward", h, trans) + f.proxies = append(f.proxies, p) + } + + // Should have 2 proxies: resolved hostname first, then static IP + if len(f.proxies) != 2 { + t.Fatalf("expected 2 proxies, got %d", len(f.proxies)) + } + + if f.proxies[0].Addr() != "10.0.0.42:53" { + t.Errorf("expected first proxy 10.0.0.42:53, got %s", f.proxies[0].Addr()) + } + if f.proxies[1].Addr() != "127.0.0.1:53" { + t.Errorf("expected second proxy 127.0.0.1:53, got %s", f.proxies[1].Addr()) + } +} + +func TestSetupResolverWithProxyOptions(t *testing.T) { + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + ret.Answer = append(ret.Answer, test.A("myhost.example.com. IN A 10.0.0.1")) + } + w.WriteMsg(ret) + }) + defer s.Close() + + input := fmt.Sprintf(`forward . myhost.example.com { + resolver %s + force_tcp + health_check 5s domain example.org. + max_fails 3 +} +`, s.Addr) + c := caddy.NewTestController("dns", input) + fs, err := parseForward(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + f := fs[0] + + if f.maxfails != 3 { + t.Errorf("expected maxfails 3, got %d", f.maxfails) + } + if !f.opts.ForceTCP { + t.Error("expected ForceTCP to be true") + } + if f.opts.HCDomain != "example.org." { + t.Errorf("expected HCDomain example.org., got %s", f.opts.HCDomain) + } + + p := f.proxies[0] + if p.GetHealthchecker().GetDomain() != "example.org." { + t.Errorf("expected healthcheck domain example.org., got %s", p.GetHealthchecker().GetDomain()) + } + if !p.GetHealthchecker().GetRecursionDesired() { + t.Error("expected recursion desired to be true") + } +} + +func TestExpandAndDedupTLS(t *testing.T) { + // tls://hostname1(A 9.9.9.9, A 149.112.112.112) hostname2(A 149.112.112.112, A 9.9.9.10) 149.112.112.112 9.9.9.10 + // Expected after dedup: 9.9.9.9 149.112.112.112 9.9.9.10 (first-seen order) + s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + switch r.Question[0].Name { + case "dns1.example.com.": + ret.Answer = append(ret.Answer, + test.A("dns1.example.com. IN A 9.9.9.9"), + test.A("dns1.example.com. IN A 149.112.112.112"), + ) + case "dns2.example.com.": + ret.Answer = append(ret.Answer, + test.A("dns2.example.com. IN A 149.112.112.112"), + test.A("dns2.example.com. IN A 9.9.9.10"), + ) + } + } + w.WriteMsg(ret) + }) + defer s.Close() + + entries := []toEntry{ + {static: false, entry: hostEntry{hostname: "dns1.example.com", port: "853", transport: "tls"}}, + {static: false, entry: hostEntry{hostname: "dns2.example.com", port: "853", transport: "tls"}}, + {static: true, addrs: []string{"tls://149.112.112.112:853"}}, + {static: true, addrs: []string{"tls://9.9.9.10:853"}}, + } + + result, err := expandAndDedup(entries, []string{s.Addr}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"9.9.9.9:853", "149.112.112.112:853", "9.9.9.10:853"} + if len(result) != len(expected) { + t.Fatalf("expected %d addresses after dedup, got %d: %v", len(expected), len(result), result) + } + for i, addr := range result { + if normalizeAddr(addr) != expected[i] { + t.Errorf("position %d: expected %s, got %s", i, expected[i], normalizeAddr(addr)) + } + } +} + +func TestResolverWithHCOptions(t *testing.T) { + input := "forward . 127.0.0.1 {\nresolver 10.96.0.10\n}\n" + + c := caddy.NewTestController("dns", input) + fs, err := parseForward(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + f := fs[0] + if len(f.resolver) != 1 || f.resolver[0] != "10.96.0.10" { + t.Errorf("unexpected resolver: %v", f.resolver) + } + + expectedOpts := proxy.Options{HCRecursionDesired: true, HCDomain: "."} + if f.opts != expectedOpts { + t.Errorf("expected opts %v, got %v", expectedOpts, f.opts) + } +} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 2d76b73e6..245aba0c0 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "path/filepath" "strconv" "strings" @@ -32,8 +33,8 @@ func setup(c *caddy.Controller) error { } for i := range fs { f := fs[i] - if f.Len() > max { - return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len())) + if len(f.toEntries) > max { + return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, len(f.toEntries))) } if i == len(fs)-1 { @@ -146,11 +147,7 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { return f, c.ArgErr() } - toHosts, err := parse.HostPortOrFile(to...) - if err != nil { - return f, err - } - + // Parse block first to get resolver and other options before processing TO addresses. for c.NextBlock() { if err := parseBlock(c, f); err != nil { return f, err @@ -161,6 +158,22 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { return f, fmt.Errorf("max_age (%s) must not be less than expire (%s)", f.maxAge, f.expire) } + // Classify TO addresses in order, preserving config ordering. + entries, err := classifyToAddrs(to) + if err != nil { + return f, err + } + f.toEntries = entries + + // Expand hostnames and deduplicate globally (first-seen order wins). + toHosts, err := expandAndDedup(f.toEntries, f.resolver) + if err != nil { + return f, err + } + if len(toHosts) == 0 { + return f, fmt.Errorf("no valid upstream addresses found") + } + tlsServerNames := make([]string, len(toHosts)) perServerNameProxyCount := make(map[string]int) transports := make([]string, len(toHosts)) @@ -424,6 +437,21 @@ func parseBlock(c *caddy.Controller, f *Forward) error { f.failoverRcodes = append(f.failoverRcodes, rc) } + case "resolver": + args := c.RemainingArgs() + if len(args) == 0 { + return c.ArgErr() + } + for _, arg := range args { + host := arg + if h, _, err := net.SplitHostPort(arg); err == nil { + host = h + } + if net.ParseIP(host) == nil { + return fmt.Errorf("resolver must be an IP address or IP:port: %q", arg) + } + } + f.resolver = args default: return c.Errf("unknown property '%s'", c.Val()) } diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index 06b245fc3..2b86f355d 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -46,7 +46,7 @@ func TestSetup(t *testing.T) { forward com ::2`, false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "plugin"}, {"forward . tls://[2400:3200::1%dns.alidns.com]:853 {\ntls\n}\n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, // negative - {"forward . a27.0.0.1", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "not an IP"}, + {"forward . a27.0.0.1", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "failed to resolve"}, {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unknown property"}, {"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "Wrong argument count or unexpected line ending after 'domain'"}, {"forward . https://127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "'https' is not supported as a destination protocol in forward: https://127.0.0.1"},