plugin/loadbalance: support prefer option (#7433)

Signed-off-by: Olli Janatuinen <olli.janatuinen@gmail.com>
This commit is contained in:
Olli Janatuinen
2025-08-05 20:34:38 +02:00
committed by GitHub
parent dc8f3b08e5
commit 52639bc66c
4 changed files with 262 additions and 46 deletions

View File

@@ -2,7 +2,7 @@
## Name
*loadbalance* - randomizes the order of A, AAAA and MX records.
*loadbalance* - randomizes the order of A, AAAA and MX records and optionally prefers specific subnets.
## Description
@@ -18,6 +18,7 @@ implementations (like glibc) are particular about that.
~~~
loadbalance [round_robin | weighted WEIGHTFILE] {
reload DURATION
prefer CIDR [CIDR...]
}
~~~
* `round_robin` policy randomizes the order of A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy.
@@ -26,6 +27,8 @@ loadbalance [round_robin | weighted WEIGHTFILE] {
(top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record
returned in the answer.
Additionally, the plugin supports subnet-based ordering using the `prefer` directive, which reorders A/AAAA records so that IPs from preferred subnets appear first.
* **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section.
* **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload.
@@ -88,3 +91,17 @@ www.example.com
100.64.1.3 2
~~~
### Subnet Prioritization
Prioritize IPs from 10.9.20.0/24 and 192.168.1.0/24:
```corefile
. {
loadbalance round_robin {
prefer 10.9.20.0/24 192.168.1.0/24
}
forward . 1.1.1.1
}
```
If the DNS response includes multiple A/AAAA records, the plugin will reorder them to place the ones matching preferred subnets first.

View File

@@ -0,0 +1,76 @@
package loadbalance
import (
"net"
"github.com/miekg/dns"
)
func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg {
msg.Answer = reorderRecords(msg.Answer, subnets)
msg.Extra = reorderRecords(msg.Extra, subnets)
return msg
}
func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR {
var cname, address, mx, rest []dns.RR
for _, r := range records {
switch r.Header().Rrtype {
case dns.TypeCNAME:
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default:
rest = append(rest, r)
}
}
sorted := sortBySubnetPriority(address, subnets)
out := append([]dns.RR{}, cname...)
out = append(out, sorted...)
out = append(out, mx...)
out = append(out, rest...)
return out
}
func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR {
matched := make([]dns.RR, 0, len(records))
seen := make(map[int]bool)
for _, subnet := range subnets {
for i, r := range records {
if seen[i] {
continue
}
ip := extractIP(r)
if ip != nil && subnet.Contains(ip) {
matched = append(matched, r)
seen[i] = true
}
}
}
unmatched := make([]dns.RR, 0, len(records)-len(matched))
for i, r := range records {
if !seen[i] {
unmatched = append(unmatched, r)
}
}
return append(matched, unmatched...)
}
func extractIP(rr dns.RR) net.IP {
switch r := rr.(type) {
case *dns.A:
return r.A
case *dns.AAAA:
return r.AAAA
default:
return nil
}
}

View File

@@ -0,0 +1,96 @@
package loadbalance
import (
"net"
"testing"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
func TestSortPreferred(t *testing.T) {
records := []dns.RR{
test.A("example.org. 300 IN A 10.9.30.1"),
test.A("example.org. 300 IN A 10.9.20.5"),
test.A("example.org. 300 IN A 192.168.1.2"),
test.A("example.org. 300 IN A 10.10.0.1"),
test.A("example.org. 300 IN A 10.9.20.3"),
test.A("example.org. 300 IN A 172.16.0.1"),
test.AAAA("example.org. 300 IN AAAA 2001:db8::1"),
test.AAAA("example.org. 300 IN AAAA 2001:db8:abcd::1"),
test.AAAA("example.org. 300 IN AAAA fd00::1"),
test.CNAME("example.org. 300 IN CNAME alias.example.org."),
}
subnets := []*net.IPNet{}
cidrs := []string{"2001:db8::/32", "10.9.20.0/24", "10.9.30.0/24"}
for _, cidr := range cidrs {
_, subnet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatalf("Failed to parse CIDR: %v", err)
}
subnets = append(subnets, subnet)
}
msg := &dns.Msg{Answer: records}
reorderPreferredSubnets(msg, subnets)
sorted := msg.Answer
expectedOrder := []string{
"alias.example.org.",
"2001:db8::1",
"2001:db8:abcd::1",
"10.9.20.5",
"10.9.20.3",
"10.9.30.1",
"192.168.1.2",
"10.10.0.1",
"172.16.0.1",
"fd00::1",
}
if len(sorted) != len(expectedOrder) {
t.Fatalf("Expected %d records, got %d", len(expectedOrder), len(sorted))
}
for i, rr := range sorted {
expected := expectedOrder[i]
switch r := rr.(type) {
case *dns.CNAME:
if r.Target != expected {
t.Errorf("Record %d: expected CNAME %s, got %s", i, expected, r.Target)
}
case *dns.A:
if r.A.String() != expected {
t.Errorf("Record %d: expected A IP %s, got %s", i, expected, r.A.String())
}
case *dns.AAAA:
if r.AAAA.String() != expected {
t.Errorf("Record %d: expected AAAA IP %s, got %s", i, expected, r.AAAA.String())
}
default:
t.Errorf("Record %d: unexpected RR type %T", i, r)
}
}
}
func TestExtractIP(t *testing.T) {
a := test.A("example.org. 300 IN A 10.0.0.1")
ip := extractIP(a)
if ip.String() != "10.0.0.1" {
t.Errorf("Expected 10.0.0.1, got %s", ip.String())
}
aaaa := test.AAAA("example.org. 300 IN AAAA ::1")
ip = extractIP(aaaa)
if ip.String() != "::1" {
t.Errorf("Expected ::1, got %s", ip.String())
}
cname := test.CNAME("example.org. 300 IN CNAME other.org.")
ip = extractIP(cname)
if ip != nil {
t.Errorf("Expected nil for CNAME, got %v", ip)
}
}

View File

@@ -3,6 +3,7 @@ package loadbalance
import (
"errors"
"fmt"
"net"
"path/filepath"
"time"
@@ -24,6 +25,7 @@ type lbFuncs struct {
onStartUpFunc func() error
onShutdownFunc func() error
weighted *weightedRR // used in unit tests only
preferSubnets []*net.IPNet
}
func setup(c *caddy.Controller) error {
@@ -39,65 +41,90 @@ func setup(c *caddy.Controller) error {
c.OnShutdown(lb.onShutdownFunc)
}
shuffle := lb.shuffleFunc
if len(lb.preferSubnets) > 0 {
original := shuffle
shuffle = func(res *dns.Msg) *dns.Msg {
return reorderPreferredSubnets(original(res), lb.preferSubnets)
}
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return LoadBalance{Next: next, shuffle: lb.shuffleFunc}
return LoadBalance{Next: next, shuffle: shuffle}
})
return nil
}
// func parse(c *caddy.Controller) (string, *weightedRR, error) {
func parse(c *caddy.Controller) (*lbFuncs, error) {
config := dnsserver.GetConfig(c)
lb := &lbFuncs{}
for c.Next() {
args := c.RemainingArgs()
if len(args) == 0 {
return &lbFuncs{shuffleFunc: randomShuffle}, nil
}
switch args[0] {
case ramdomShufflePolicy:
if len(args) > 1 {
return nil, c.Errf("unknown property for %s", args[0])
}
return &lbFuncs{shuffleFunc: randomShuffle}, nil
case weightedRoundRobinPolicy:
if len(args) < 2 {
return nil, c.Err("missing weight file argument")
}
if len(args) > 2 {
return nil, c.Err("unexpected argument(s)")
}
weightFileName := args[1]
if !filepath.IsAbs(weightFileName) && config.Root != "" {
weightFileName = filepath.Join(config.Root, weightFileName)
}
reload := 30 * time.Second // default reload period
for c.NextBlock() {
switch c.Val() {
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return nil, c.Err("reload duration value is missing")
}
if len(t) > 1 {
return nil, c.Err("unexpected argument")
}
var err error
reload, err = time.ParseDuration(t[0])
if err != nil {
return nil, c.Errf("invalid reload duration '%s'", t[0])
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
lb.shuffleFunc = randomShuffle
} else {
switch args[0] {
case ramdomShufflePolicy:
if len(args) > 1 {
return nil, c.Errf("unknown property for %s", args[0])
}
lb.shuffleFunc = randomShuffle
case weightedRoundRobinPolicy:
if len(args) < 2 {
return nil, c.Err("missing weight file argument")
}
if len(args) > 2 {
return nil, c.Err("unexpected argument(s)")
}
weightFileName := args[1]
if !filepath.IsAbs(weightFileName) && config.Root != "" {
weightFileName = filepath.Join(config.Root, weightFileName)
}
reload := 30 * time.Second
for c.NextBlock() {
switch c.Val() {
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return nil, c.Err("reload duration value is missing")
}
if len(t) > 1 {
return nil, c.Err("unexpected argument")
}
var err error
reload, err = time.ParseDuration(t[0])
if err != nil {
return nil, c.Errf("invalid reload duration '%s'", t[0])
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
*lb = *createWeightedFuncs(weightFileName, reload)
default:
return nil, fmt.Errorf("unknown policy: %s", args[0])
}
}
for c.NextBlock() {
switch c.Val() {
case "prefer":
cidrs := c.RemainingArgs()
for _, cidr := range cidrs {
_, subnet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, c.Errf("invalid CIDR %q: %v", cidr, err)
}
lb.preferSubnets = append(lb.preferSubnets, subnet)
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
return createWeightedFuncs(weightFileName, reload), nil
default:
return nil, fmt.Errorf("unknown policy: %s", args[0])
}
}
return nil, c.ArgErr()
return lb, nil
}