package forward import ( "fmt" "os" "strings" "testing" "github.com/coredns/caddy" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/parse" "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" ) func TestClassifyToAddrs(t *testing.T) { // Create a resolv.conf for file test const resolv = "test_resolv.conf" if err := os.WriteFile(resolv, []byte("nameserver 10.0.0.1\n"), 0666); err != nil { t.Fatal(err) } defer os.Remove(resolv) tests := []struct { name string input []string wantStatic int wantDynamic int wantErr bool errContains string }{ { name: "simple IP", input: []string{"127.0.0.1"}, wantStatic: 1, }, { name: "IP with port", input: []string{"127.0.0.1:8053"}, wantStatic: 1, }, { name: "IPv6", input: []string{"::1"}, wantStatic: 1, }, { name: "TLS IP", input: []string{"tls://127.0.0.1"}, wantStatic: 1, }, { name: "resolv.conf file", input: []string{resolv}, wantStatic: 1, }, { name: "hostname", input: []string{"dns.example.com"}, wantDynamic: 1, }, { name: "hostname with port", input: []string{"dns.example.com:5353"}, wantDynamic: 1, }, { name: "TLS hostname", input: []string{"tls://dns.example.com"}, wantDynamic: 1, }, { name: "k8s service name", input: []string{"rbldnsd.rbldnsd.svc.cluster.local"}, wantDynamic: 1, }, { name: "mixed IPs and hostnames", input: []string{"127.0.0.1", "dns.example.com", "10.0.0.1"}, wantStatic: 2, wantDynamic: 1, }, { name: "/dev/null returns file error", input: []string{"/dev/null"}, wantErr: true, errContains: "no nameservers", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { entries, err := classifyToAddrs(tc.input) if tc.wantErr { if err == nil { t.Fatal("expected error, got nil") } if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { t.Errorf("expected error to contain %q, got: %v", tc.errContains, err) } return } if err != nil { t.Fatalf("unexpected error: %v", err) } staticCount := 0 dynamicCount := 0 for _, e := range entries { if e.static { staticCount++ } else { dynamicCount++ } } if staticCount != tc.wantStatic { t.Errorf("expected %d static entries, got %d", tc.wantStatic, staticCount) } if dynamicCount != tc.wantDynamic { t.Errorf("expected %d dynamic entries, got %d", tc.wantDynamic, dynamicCount) } }) } } func TestClassifyToAddrsPreservesOrder(t *testing.T) { entries, err := classifyToAddrs([]string{"dns.example.com", "127.0.0.1", "other.example.com"}) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(entries) != 3 { t.Fatalf("expected 3 entries, got %d", len(entries)) } if entries[0].static || entries[0].entry.hostname != "dns.example.com" { t.Errorf("entry 0: expected dynamic dns.example.com, got static=%v entry=%v", entries[0].static, entries[0].entry) } if !entries[1].static || entries[1].addrs[0] != "127.0.0.1:53" { t.Errorf("entry 1: expected static 127.0.0.1:53, got static=%v addrs=%v", entries[1].static, entries[1].addrs) } if entries[2].static || entries[2].entry.hostname != "other.example.com" { t.Errorf("entry 2: expected dynamic other.example.com, got static=%v entry=%v", entries[2].static, entries[2].entry) } } func TestParseAsHostEntry(t *testing.T) { tests := []struct { input string wantOK bool hostname string port string transport string zone string }{ {"dns.example.com", true, "dns.example.com", "53", transport.DNS, ""}, {"dns.example.com:5353", true, "dns.example.com", "5353", transport.DNS, ""}, {"tls://dns.example.com", true, "dns.example.com", "853", transport.TLS, ""}, {"tls://dns.example.com:8853", true, "dns.example.com", "8853", transport.TLS, ""}, {"tls://dns.example.com%servername.example.com", true, "dns.example.com", "853", transport.TLS, "servername.example.com"}, {"rbldnsd.rbldnsd.svc.cluster.local", true, "rbldnsd.rbldnsd.svc.cluster.local", "53", transport.DNS, ""}, // Should fail for IPs {"127.0.0.1", false, "", "", "", ""}, {"::1", false, "", "", "", ""}, // Should fail for unsupported transports {"https://example.com", false, "", "", "", ""}, // Should fail for empty {"", false, "", "", "", ""}, } for _, tc := range tests { t.Run(tc.input, func(t *testing.T) { entry, ok := parseAsHostEntry(tc.input) if ok != tc.wantOK { t.Fatalf("expected ok=%v, got %v", tc.wantOK, ok) } if !ok { return } if entry.hostname != tc.hostname { t.Errorf("expected hostname=%q, got %q", tc.hostname, entry.hostname) } if entry.port != tc.port { t.Errorf("expected port=%q, got %q", tc.port, entry.port) } if entry.transport != tc.transport { t.Errorf("expected transport=%q, got %q", tc.transport, entry.transport) } if entry.zone != tc.zone { t.Errorf("expected zone=%q, got %q", tc.zone, entry.zone) } }) } } func TestFormatResolvedAddr(t *testing.T) { tests := []struct { ip, port, trans, zone string expected string }{ {"10.0.0.1", "53", transport.DNS, "", "10.0.0.1:53"}, {"10.0.0.1", "853", transport.TLS, "", "tls://10.0.0.1:853"}, {"10.0.0.1", "853", transport.TLS, "example.com", "tls://10.0.0.1%example.com:853"}, {"::1", "53", transport.DNS, "", "[::1]:53"}, {"::1", "853", transport.TLS, "", "tls://[::1]:853"}, {"::1", "853", transport.TLS, "example.com", "tls://[::1%example.com]:853"}, } for _, tc := range tests { t.Run(tc.expected, func(t *testing.T) { result := formatResolvedAddr(tc.ip, tc.port, tc.trans, tc.zone) if result != tc.expected { t.Errorf("expected %q, got %q", tc.expected, result) } }) } } func TestExpandAndDedup(t *testing.T) { // Start a test DNS server that returns different IPs for different hostnames s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA { switch r.Question[0].Name { case "host1.example.com.": ret.Answer = append(ret.Answer, test.A("host1.example.com. IN A 10.0.0.1"), test.A("host1.example.com. IN A 10.0.0.2"), ) case "host2.example.com.": ret.Answer = append(ret.Answer, test.A("host2.example.com. IN A 10.0.0.2"), test.A("host2.example.com. IN A 10.0.0.3"), ) } } w.WriteMsg(ret) }) defer s.Close() // Simulate: forward . host1(→10.0.0.1,10.0.0.2) host2(→10.0.0.2,10.0.0.3) 10.0.0.3 10.0.0.2 entries := []toEntry{ {static: false, entry: hostEntry{hostname: "host1.example.com", port: "53", transport: "dns"}}, {static: false, entry: hostEntry{hostname: "host2.example.com", port: "53", transport: "dns"}}, {static: true, addrs: []string{"10.0.0.3:53"}}, {static: true, addrs: []string{"10.0.0.2:53"}}, } result, err := expandAndDedup(entries, []string{s.Addr}) if err != nil { t.Fatalf("unexpected error: %v", err) } // Expected: 10.0.0.1, 10.0.0.2, 10.0.0.3 (first-seen order, deduped) expected := []string{"10.0.0.1:53", "10.0.0.2:53", "10.0.0.3:53"} if len(result) != len(expected) { t.Fatalf("expected %d addresses, got %d: %v", len(expected), len(result), result) } for i, addr := range result { normalized := normalizeAddr(addr) if normalized != expected[i] { t.Errorf("position %d: expected %s, got %s", i, expected[i], normalized) } } } func TestExpandAndDedupOrderPreserved(t *testing.T) { // Start a test DNS server s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA { ret.Answer = append(ret.Answer, test.A("myhost.example.com. IN A 10.0.0.42")) } w.WriteMsg(ret) }) defer s.Close() // Config order: hostname first, then static IP // forward . myhost.example.com 192.168.1.1 entries := []toEntry{ {static: false, entry: hostEntry{hostname: "myhost.example.com", port: "53", transport: "dns"}}, {static: true, addrs: []string{"192.168.1.1:53"}}, } result, err := expandAndDedup(entries, []string{s.Addr}) if err != nil { t.Fatalf("unexpected error: %v", err) } // hostname resolved IP should come first, then static if len(result) != 2 { t.Fatalf("expected 2 addresses, got %d: %v", len(result), result) } if normalizeAddr(result[0]) != "10.0.0.42:53" { t.Errorf("expected first addr 10.0.0.42:53, got %s", normalizeAddr(result[0])) } if normalizeAddr(result[1]) != "192.168.1.1:53" { t.Errorf("expected second addr 192.168.1.1:53, got %s", normalizeAddr(result[1])) } } func TestDnsLookup(t *testing.T) { // Start a test DNS server that responds to A queries s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA { ret.Answer = append(ret.Answer, test.A("myhost.example.com. IN A 10.0.0.42")) } w.WriteMsg(ret) }) defer s.Close() // Use the full server address (IP:port) since the test server uses a random port ips, err := dnsLookup("myhost.example.com", []string{s.Addr}) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(ips) == 0 { t.Fatal("expected at least one IP") } found := false for _, ip := range ips { if ip == "10.0.0.42" { found = true } } if !found { t.Errorf("expected to find 10.0.0.42 in %v", ips) } } func TestSetupResolver(t *testing.T) { tests := []struct { name string input string shouldErr bool expectedErr string resolverLen int }{ { name: "single resolver IP", input: "forward . 127.0.0.1 {\nresolver 10.96.0.10\n}\n", resolverLen: 1, }, { name: "multiple resolver IPs", input: "forward . 127.0.0.1 {\nresolver 10.96.0.10 10.96.0.11\n}\n", resolverLen: 2, }, { name: "IPv6 resolver", input: "forward . 127.0.0.1 {\nresolver ::1\n}\n", resolverLen: 1, }, { name: "resolver not an IP", input: "forward . 127.0.0.1 {\nresolver dns.example.com\n}\n", shouldErr: true, expectedErr: "resolver must be an IP address", }, { name: "resolver no args", input: "forward . 127.0.0.1 {\nresolver\n}\n", shouldErr: true, expectedErr: "Wrong argument count", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { c := caddy.NewTestController("dns", tc.input) fs, err := parseForward(c) if tc.shouldErr { if err == nil { t.Fatal("expected error, got nil") } if !strings.Contains(err.Error(), tc.expectedErr) { t.Errorf("expected error to contain %q, got: %v", tc.expectedErr, err) } return } if err != nil { t.Fatalf("unexpected error: %v", err) } f := fs[0] if len(f.resolver) != tc.resolverLen { t.Errorf("expected %d resolver(s), got %d: %v", tc.resolverLen, len(f.resolver), f.resolver) } }) } } func TestSetupWithHostnameTO(t *testing.T) { // Start a test DNS server that resolves "myupstream.example.com" to 10.0.0.42 s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA && r.Question[0].Name == "myupstream.example.com." { ret.Answer = append(ret.Answer, test.A("myupstream.example.com. IN A 10.0.0.42")) } w.WriteMsg(ret) }) defer s.Close() // Test resolving a hostname entry directly entry := hostEntry{hostname: "myupstream.example.com", port: "53", transport: "dns"} addrs, err := resolveHostEntry(entry, []string{s.Addr}) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(addrs) == 0 { t.Fatal("expected at least one resolved address") } if addrs[0] != "10.0.0.42:53" { t.Errorf("expected resolved addr 10.0.0.42:53, got %s", addrs[0]) } // Test full integration: manually build the Forward with resolver f := New() f.from = "." f.resolver = []string{s.Addr} f.toEntries = []toEntry{ {static: false, entry: entry}, } resolvedAddrs, err := expandAndDedup(f.toEntries, f.resolver) if err != nil { t.Fatalf("resolution failed: %v", err) } for _, addr := range resolvedAddrs { host, _ := splitZone(addr) trans, h := parse.Transport(host) p := proxy.NewProxy("forward", h, trans) f.proxies = append(f.proxies, p) } if len(f.proxies) == 0 { t.Fatal("expected at least one proxy") } if f.proxies[0].Addr() != "10.0.0.42:53" { t.Errorf("expected proxy addr 10.0.0.42:53, got %s", f.proxies[0].Addr()) } } func TestSetupMixedIPAndHostnameTO(t *testing.T) { // Start a test DNS server s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA { ret.Answer = append(ret.Answer, test.A("myupstream.example.com. IN A 10.0.0.42")) } w.WriteMsg(ret) }) defer s.Close() // Manually build Forward to test mixed hostname + IP (hostname first for order test) f := New() f.from = "." f.resolver = []string{s.Addr} f.toEntries = []toEntry{ {static: false, entry: hostEntry{hostname: "myupstream.example.com", port: "53", transport: "dns"}}, {static: true, addrs: []string{"127.0.0.1:53"}}, } resolvedAddrs, err := expandAndDedup(f.toEntries, f.resolver) if err != nil { t.Fatalf("expand error: %v", err) } for _, addr := range resolvedAddrs { host, _ := splitZone(addr) trans, h := parse.Transport(host) p := proxy.NewProxy("forward", h, trans) f.proxies = append(f.proxies, p) } // Should have 2 proxies: resolved hostname first, then static IP if len(f.proxies) != 2 { t.Fatalf("expected 2 proxies, got %d", len(f.proxies)) } if f.proxies[0].Addr() != "10.0.0.42:53" { t.Errorf("expected first proxy 10.0.0.42:53, got %s", f.proxies[0].Addr()) } if f.proxies[1].Addr() != "127.0.0.1:53" { t.Errorf("expected second proxy 127.0.0.1:53, got %s", f.proxies[1].Addr()) } } func TestSetupResolverWithProxyOptions(t *testing.T) { s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA { ret.Answer = append(ret.Answer, test.A("myhost.example.com. IN A 10.0.0.1")) } w.WriteMsg(ret) }) defer s.Close() input := fmt.Sprintf(`forward . myhost.example.com { resolver %s force_tcp health_check 5s domain example.org. max_fails 3 } `, s.Addr) c := caddy.NewTestController("dns", input) fs, err := parseForward(c) if err != nil { t.Fatalf("unexpected error: %v", err) } f := fs[0] if f.maxfails != 3 { t.Errorf("expected maxfails 3, got %d", f.maxfails) } if !f.opts.ForceTCP { t.Error("expected ForceTCP to be true") } if f.opts.HCDomain != "example.org." { t.Errorf("expected HCDomain example.org., got %s", f.opts.HCDomain) } p := f.proxies[0] if p.GetHealthchecker().GetDomain() != "example.org." { t.Errorf("expected healthcheck domain example.org., got %s", p.GetHealthchecker().GetDomain()) } if !p.GetHealthchecker().GetRecursionDesired() { t.Error("expected recursion desired to be true") } } func TestExpandAndDedupTLS(t *testing.T) { // tls://hostname1(A 9.9.9.9, A 149.112.112.112) hostname2(A 149.112.112.112, A 9.9.9.10) 149.112.112.112 9.9.9.10 // Expected after dedup: 9.9.9.9 149.112.112.112 9.9.9.10 (first-seen order) s := dnstest.NewMultipleServer(func(w dns.ResponseWriter, r *dns.Msg) { ret := new(dns.Msg) ret.SetReply(r) if r.Question[0].Qtype == dns.TypeA { switch r.Question[0].Name { case "dns1.example.com.": ret.Answer = append(ret.Answer, test.A("dns1.example.com. IN A 9.9.9.9"), test.A("dns1.example.com. IN A 149.112.112.112"), ) case "dns2.example.com.": ret.Answer = append(ret.Answer, test.A("dns2.example.com. IN A 149.112.112.112"), test.A("dns2.example.com. IN A 9.9.9.10"), ) } } w.WriteMsg(ret) }) defer s.Close() entries := []toEntry{ {static: false, entry: hostEntry{hostname: "dns1.example.com", port: "853", transport: "tls"}}, {static: false, entry: hostEntry{hostname: "dns2.example.com", port: "853", transport: "tls"}}, {static: true, addrs: []string{"tls://149.112.112.112:853"}}, {static: true, addrs: []string{"tls://9.9.9.10:853"}}, } result, err := expandAndDedup(entries, []string{s.Addr}) if err != nil { t.Fatalf("unexpected error: %v", err) } expected := []string{"9.9.9.9:853", "149.112.112.112:853", "9.9.9.10:853"} if len(result) != len(expected) { t.Fatalf("expected %d addresses after dedup, got %d: %v", len(expected), len(result), result) } for i, addr := range result { if normalizeAddr(addr) != expected[i] { t.Errorf("position %d: expected %s, got %s", i, expected[i], normalizeAddr(addr)) } } } func TestResolverWithHCOptions(t *testing.T) { input := "forward . 127.0.0.1 {\nresolver 10.96.0.10\n}\n" c := caddy.NewTestController("dns", input) fs, err := parseForward(c) if err != nil { t.Fatalf("unexpected error: %v", err) } f := fs[0] if len(f.resolver) != 1 || f.resolver[0] != "10.96.0.10" { t.Errorf("unexpected resolver: %v", f.resolver) } expectedOpts := proxy.Options{HCRecursionDesired: true, HCDomain: "."} if f.opts != expectedOpts { t.Errorf("expected opts %v, got %v", expectedOpts, f.opts) } }