diff --git a/plugin/loadbalance/README.md b/plugin/loadbalance/README.md index fe29b19fc..e8777785b 100644 --- a/plugin/loadbalance/README.md +++ b/plugin/loadbalance/README.md @@ -2,7 +2,7 @@ ## Name -*loadbalance* - randomizes the order of A, AAAA and MX records. +*loadbalance* - randomizes the order of A, AAAA and MX records and optionally prefers specific subnets. ## Description @@ -18,6 +18,7 @@ implementations (like glibc) are particular about that. ~~~ loadbalance [round_robin | weighted WEIGHTFILE] { reload DURATION + prefer CIDR [CIDR...] } ~~~ * `round_robin` policy randomizes the order of A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy. @@ -26,6 +27,8 @@ loadbalance [round_robin | weighted WEIGHTFILE] { (top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record returned in the answer. +Additionally, the plugin supports subnet-based ordering using the `prefer` directive, which reorders A/AAAA records so that IPs from preferred subnets appear first. + * **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section. * **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload. @@ -88,3 +91,17 @@ www.example.com 100.64.1.3 2 ~~~ +### Subnet Prioritization + +Prioritize IPs from 10.9.20.0/24 and 192.168.1.0/24: + +```corefile +. { + loadbalance round_robin { + prefer 10.9.20.0/24 192.168.1.0/24 + } + forward . 1.1.1.1 +} +``` + +If the DNS response includes multiple A/AAAA records, the plugin will reorder them to place the ones matching preferred subnets first. diff --git a/plugin/loadbalance/prefer.go b/plugin/loadbalance/prefer.go new file mode 100644 index 000000000..44b3d9c73 --- /dev/null +++ b/plugin/loadbalance/prefer.go @@ -0,0 +1,76 @@ +package loadbalance + +import ( + "net" + + "github.com/miekg/dns" +) + +func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg { + msg.Answer = reorderRecords(msg.Answer, subnets) + msg.Extra = reorderRecords(msg.Extra, subnets) + return msg +} + +func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR { + var cname, address, mx, rest []dns.RR + + for _, r := range records { + switch r.Header().Rrtype { + case dns.TypeCNAME: + cname = append(cname, r) + case dns.TypeA, dns.TypeAAAA: + address = append(address, r) + case dns.TypeMX: + mx = append(mx, r) + default: + rest = append(rest, r) + } + } + + sorted := sortBySubnetPriority(address, subnets) + + out := append([]dns.RR{}, cname...) + out = append(out, sorted...) + out = append(out, mx...) + out = append(out, rest...) + return out +} + +func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR { + matched := make([]dns.RR, 0, len(records)) + seen := make(map[int]bool) + + for _, subnet := range subnets { + for i, r := range records { + if seen[i] { + continue + } + ip := extractIP(r) + if ip != nil && subnet.Contains(ip) { + matched = append(matched, r) + seen[i] = true + } + } + } + + unmatched := make([]dns.RR, 0, len(records)-len(matched)) + for i, r := range records { + if !seen[i] { + unmatched = append(unmatched, r) + } + } + + return append(matched, unmatched...) +} + +func extractIP(rr dns.RR) net.IP { + switch r := rr.(type) { + case *dns.A: + return r.A + case *dns.AAAA: + return r.AAAA + default: + return nil + } +} diff --git a/plugin/loadbalance/prefer_test.go b/plugin/loadbalance/prefer_test.go new file mode 100644 index 000000000..12537c2f8 --- /dev/null +++ b/plugin/loadbalance/prefer_test.go @@ -0,0 +1,96 @@ +package loadbalance + +import ( + "net" + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestSortPreferred(t *testing.T) { + records := []dns.RR{ + test.A("example.org. 300 IN A 10.9.30.1"), + test.A("example.org. 300 IN A 10.9.20.5"), + test.A("example.org. 300 IN A 192.168.1.2"), + test.A("example.org. 300 IN A 10.10.0.1"), + test.A("example.org. 300 IN A 10.9.20.3"), + test.A("example.org. 300 IN A 172.16.0.1"), + test.AAAA("example.org. 300 IN AAAA 2001:db8::1"), + test.AAAA("example.org. 300 IN AAAA 2001:db8:abcd::1"), + test.AAAA("example.org. 300 IN AAAA fd00::1"), + test.CNAME("example.org. 300 IN CNAME alias.example.org."), + } + + subnets := []*net.IPNet{} + cidrs := []string{"2001:db8::/32", "10.9.20.0/24", "10.9.30.0/24"} + for _, cidr := range cidrs { + _, subnet, err := net.ParseCIDR(cidr) + if err != nil { + t.Fatalf("Failed to parse CIDR: %v", err) + } + subnets = append(subnets, subnet) + } + + msg := &dns.Msg{Answer: records} + reorderPreferredSubnets(msg, subnets) + sorted := msg.Answer + + expectedOrder := []string{ + "alias.example.org.", + "2001:db8::1", + "2001:db8:abcd::1", + "10.9.20.5", + "10.9.20.3", + "10.9.30.1", + "192.168.1.2", + "10.10.0.1", + "172.16.0.1", + "fd00::1", + } + + if len(sorted) != len(expectedOrder) { + t.Fatalf("Expected %d records, got %d", len(expectedOrder), len(sorted)) + } + + for i, rr := range sorted { + expected := expectedOrder[i] + switch r := rr.(type) { + case *dns.CNAME: + if r.Target != expected { + t.Errorf("Record %d: expected CNAME %s, got %s", i, expected, r.Target) + } + case *dns.A: + if r.A.String() != expected { + t.Errorf("Record %d: expected A IP %s, got %s", i, expected, r.A.String()) + } + case *dns.AAAA: + if r.AAAA.String() != expected { + t.Errorf("Record %d: expected AAAA IP %s, got %s", i, expected, r.AAAA.String()) + } + default: + t.Errorf("Record %d: unexpected RR type %T", i, r) + } + } +} + +func TestExtractIP(t *testing.T) { + a := test.A("example.org. 300 IN A 10.0.0.1") + ip := extractIP(a) + if ip.String() != "10.0.0.1" { + t.Errorf("Expected 10.0.0.1, got %s", ip.String()) + } + + aaaa := test.AAAA("example.org. 300 IN AAAA ::1") + ip = extractIP(aaaa) + if ip.String() != "::1" { + t.Errorf("Expected ::1, got %s", ip.String()) + } + + cname := test.CNAME("example.org. 300 IN CNAME other.org.") + ip = extractIP(cname) + if ip != nil { + t.Errorf("Expected nil for CNAME, got %v", ip) + } +} diff --git a/plugin/loadbalance/setup.go b/plugin/loadbalance/setup.go index 6a84e31fd..b1ded1aa6 100644 --- a/plugin/loadbalance/setup.go +++ b/plugin/loadbalance/setup.go @@ -3,6 +3,7 @@ package loadbalance import ( "errors" "fmt" + "net" "path/filepath" "time" @@ -24,6 +25,7 @@ type lbFuncs struct { onStartUpFunc func() error onShutdownFunc func() error weighted *weightedRR // used in unit tests only + preferSubnets []*net.IPNet } func setup(c *caddy.Controller) error { @@ -39,65 +41,90 @@ func setup(c *caddy.Controller) error { c.OnShutdown(lb.onShutdownFunc) } + shuffle := lb.shuffleFunc + if len(lb.preferSubnets) > 0 { + original := shuffle + shuffle = func(res *dns.Msg) *dns.Msg { + return reorderPreferredSubnets(original(res), lb.preferSubnets) + } + } + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { - return LoadBalance{Next: next, shuffle: lb.shuffleFunc} + return LoadBalance{Next: next, shuffle: shuffle} }) return nil } -// func parse(c *caddy.Controller) (string, *weightedRR, error) { func parse(c *caddy.Controller) (*lbFuncs, error) { config := dnsserver.GetConfig(c) + lb := &lbFuncs{} for c.Next() { args := c.RemainingArgs() if len(args) == 0 { - return &lbFuncs{shuffleFunc: randomShuffle}, nil - } - switch args[0] { - case ramdomShufflePolicy: - if len(args) > 1 { - return nil, c.Errf("unknown property for %s", args[0]) - } - return &lbFuncs{shuffleFunc: randomShuffle}, nil - case weightedRoundRobinPolicy: - if len(args) < 2 { - return nil, c.Err("missing weight file argument") - } - - if len(args) > 2 { - return nil, c.Err("unexpected argument(s)") - } - - weightFileName := args[1] - if !filepath.IsAbs(weightFileName) && config.Root != "" { - weightFileName = filepath.Join(config.Root, weightFileName) - } - reload := 30 * time.Second // default reload period - for c.NextBlock() { - switch c.Val() { - case "reload": - t := c.RemainingArgs() - if len(t) < 1 { - return nil, c.Err("reload duration value is missing") - } - if len(t) > 1 { - return nil, c.Err("unexpected argument") - } - var err error - reload, err = time.ParseDuration(t[0]) - if err != nil { - return nil, c.Errf("invalid reload duration '%s'", t[0]) - } - default: - return nil, c.Errf("unknown property '%s'", c.Val()) + lb.shuffleFunc = randomShuffle + } else { + switch args[0] { + case ramdomShufflePolicy: + if len(args) > 1 { + return nil, c.Errf("unknown property for %s", args[0]) } + lb.shuffleFunc = randomShuffle + + case weightedRoundRobinPolicy: + if len(args) < 2 { + return nil, c.Err("missing weight file argument") + } + if len(args) > 2 { + return nil, c.Err("unexpected argument(s)") + } + weightFileName := args[1] + if !filepath.IsAbs(weightFileName) && config.Root != "" { + weightFileName = filepath.Join(config.Root, weightFileName) + } + reload := 30 * time.Second + for c.NextBlock() { + switch c.Val() { + case "reload": + t := c.RemainingArgs() + if len(t) < 1 { + return nil, c.Err("reload duration value is missing") + } + if len(t) > 1 { + return nil, c.Err("unexpected argument") + } + var err error + reload, err = time.ParseDuration(t[0]) + if err != nil { + return nil, c.Errf("invalid reload duration '%s'", t[0]) + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + *lb = *createWeightedFuncs(weightFileName, reload) + default: + return nil, fmt.Errorf("unknown policy: %s", args[0]) + } + } + + for c.NextBlock() { + switch c.Val() { + case "prefer": + cidrs := c.RemainingArgs() + for _, cidr := range cidrs { + _, subnet, err := net.ParseCIDR(cidr) + if err != nil { + return nil, c.Errf("invalid CIDR %q: %v", cidr, err) + } + lb.preferSubnets = append(lb.preferSubnets, subnet) + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) } - return createWeightedFuncs(weightFileName, reload), nil - default: - return nil, fmt.Errorf("unknown policy: %s", args[0]) } } - return nil, c.ArgErr() + + return lb, nil }