mirror of
https://github.com/coredns/coredns.git
synced 2025-10-27 16:24:19 -04:00
plugin/loadbalance: support prefer option (#7433)
Signed-off-by: Olli Janatuinen <olli.janatuinen@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
## Name
|
||||
|
||||
*loadbalance* - randomizes the order of A, AAAA and MX records.
|
||||
*loadbalance* - randomizes the order of A, AAAA and MX records and optionally prefers specific subnets.
|
||||
|
||||
## Description
|
||||
|
||||
@@ -18,6 +18,7 @@ implementations (like glibc) are particular about that.
|
||||
~~~
|
||||
loadbalance [round_robin | weighted WEIGHTFILE] {
|
||||
reload DURATION
|
||||
prefer CIDR [CIDR...]
|
||||
}
|
||||
~~~
|
||||
* `round_robin` policy randomizes the order of A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy.
|
||||
@@ -26,6 +27,8 @@ loadbalance [round_robin | weighted WEIGHTFILE] {
|
||||
(top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record
|
||||
returned in the answer.
|
||||
|
||||
Additionally, the plugin supports subnet-based ordering using the `prefer` directive, which reorders A/AAAA records so that IPs from preferred subnets appear first.
|
||||
|
||||
* **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section.
|
||||
|
||||
* **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload.
|
||||
@@ -88,3 +91,17 @@ www.example.com
|
||||
100.64.1.3 2
|
||||
~~~
|
||||
|
||||
### Subnet Prioritization
|
||||
|
||||
Prioritize IPs from 10.9.20.0/24 and 192.168.1.0/24:
|
||||
|
||||
```corefile
|
||||
. {
|
||||
loadbalance round_robin {
|
||||
prefer 10.9.20.0/24 192.168.1.0/24
|
||||
}
|
||||
forward . 1.1.1.1
|
||||
}
|
||||
```
|
||||
|
||||
If the DNS response includes multiple A/AAAA records, the plugin will reorder them to place the ones matching preferred subnets first.
|
||||
|
||||
76
plugin/loadbalance/prefer.go
Normal file
76
plugin/loadbalance/prefer.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg {
|
||||
msg.Answer = reorderRecords(msg.Answer, subnets)
|
||||
msg.Extra = reorderRecords(msg.Extra, subnets)
|
||||
return msg
|
||||
}
|
||||
|
||||
func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR {
|
||||
var cname, address, mx, rest []dns.RR
|
||||
|
||||
for _, r := range records {
|
||||
switch r.Header().Rrtype {
|
||||
case dns.TypeCNAME:
|
||||
cname = append(cname, r)
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
address = append(address, r)
|
||||
case dns.TypeMX:
|
||||
mx = append(mx, r)
|
||||
default:
|
||||
rest = append(rest, r)
|
||||
}
|
||||
}
|
||||
|
||||
sorted := sortBySubnetPriority(address, subnets)
|
||||
|
||||
out := append([]dns.RR{}, cname...)
|
||||
out = append(out, sorted...)
|
||||
out = append(out, mx...)
|
||||
out = append(out, rest...)
|
||||
return out
|
||||
}
|
||||
|
||||
func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR {
|
||||
matched := make([]dns.RR, 0, len(records))
|
||||
seen := make(map[int]bool)
|
||||
|
||||
for _, subnet := range subnets {
|
||||
for i, r := range records {
|
||||
if seen[i] {
|
||||
continue
|
||||
}
|
||||
ip := extractIP(r)
|
||||
if ip != nil && subnet.Contains(ip) {
|
||||
matched = append(matched, r)
|
||||
seen[i] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unmatched := make([]dns.RR, 0, len(records)-len(matched))
|
||||
for i, r := range records {
|
||||
if !seen[i] {
|
||||
unmatched = append(unmatched, r)
|
||||
}
|
||||
}
|
||||
|
||||
return append(matched, unmatched...)
|
||||
}
|
||||
|
||||
func extractIP(rr dns.RR) net.IP {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
return r.A
|
||||
case *dns.AAAA:
|
||||
return r.AAAA
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
96
plugin/loadbalance/prefer_test.go
Normal file
96
plugin/loadbalance/prefer_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestSortPreferred(t *testing.T) {
|
||||
records := []dns.RR{
|
||||
test.A("example.org. 300 IN A 10.9.30.1"),
|
||||
test.A("example.org. 300 IN A 10.9.20.5"),
|
||||
test.A("example.org. 300 IN A 192.168.1.2"),
|
||||
test.A("example.org. 300 IN A 10.10.0.1"),
|
||||
test.A("example.org. 300 IN A 10.9.20.3"),
|
||||
test.A("example.org. 300 IN A 172.16.0.1"),
|
||||
test.AAAA("example.org. 300 IN AAAA 2001:db8::1"),
|
||||
test.AAAA("example.org. 300 IN AAAA 2001:db8:abcd::1"),
|
||||
test.AAAA("example.org. 300 IN AAAA fd00::1"),
|
||||
test.CNAME("example.org. 300 IN CNAME alias.example.org."),
|
||||
}
|
||||
|
||||
subnets := []*net.IPNet{}
|
||||
cidrs := []string{"2001:db8::/32", "10.9.20.0/24", "10.9.30.0/24"}
|
||||
for _, cidr := range cidrs {
|
||||
_, subnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse CIDR: %v", err)
|
||||
}
|
||||
subnets = append(subnets, subnet)
|
||||
}
|
||||
|
||||
msg := &dns.Msg{Answer: records}
|
||||
reorderPreferredSubnets(msg, subnets)
|
||||
sorted := msg.Answer
|
||||
|
||||
expectedOrder := []string{
|
||||
"alias.example.org.",
|
||||
"2001:db8::1",
|
||||
"2001:db8:abcd::1",
|
||||
"10.9.20.5",
|
||||
"10.9.20.3",
|
||||
"10.9.30.1",
|
||||
"192.168.1.2",
|
||||
"10.10.0.1",
|
||||
"172.16.0.1",
|
||||
"fd00::1",
|
||||
}
|
||||
|
||||
if len(sorted) != len(expectedOrder) {
|
||||
t.Fatalf("Expected %d records, got %d", len(expectedOrder), len(sorted))
|
||||
}
|
||||
|
||||
for i, rr := range sorted {
|
||||
expected := expectedOrder[i]
|
||||
switch r := rr.(type) {
|
||||
case *dns.CNAME:
|
||||
if r.Target != expected {
|
||||
t.Errorf("Record %d: expected CNAME %s, got %s", i, expected, r.Target)
|
||||
}
|
||||
case *dns.A:
|
||||
if r.A.String() != expected {
|
||||
t.Errorf("Record %d: expected A IP %s, got %s", i, expected, r.A.String())
|
||||
}
|
||||
case *dns.AAAA:
|
||||
if r.AAAA.String() != expected {
|
||||
t.Errorf("Record %d: expected AAAA IP %s, got %s", i, expected, r.AAAA.String())
|
||||
}
|
||||
default:
|
||||
t.Errorf("Record %d: unexpected RR type %T", i, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractIP(t *testing.T) {
|
||||
a := test.A("example.org. 300 IN A 10.0.0.1")
|
||||
ip := extractIP(a)
|
||||
if ip.String() != "10.0.0.1" {
|
||||
t.Errorf("Expected 10.0.0.1, got %s", ip.String())
|
||||
}
|
||||
|
||||
aaaa := test.AAAA("example.org. 300 IN AAAA ::1")
|
||||
ip = extractIP(aaaa)
|
||||
if ip.String() != "::1" {
|
||||
t.Errorf("Expected ::1, got %s", ip.String())
|
||||
}
|
||||
|
||||
cname := test.CNAME("example.org. 300 IN CNAME other.org.")
|
||||
ip = extractIP(cname)
|
||||
if ip != nil {
|
||||
t.Errorf("Expected nil for CNAME, got %v", ip)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package loadbalance
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
@@ -24,6 +25,7 @@ type lbFuncs struct {
|
||||
onStartUpFunc func() error
|
||||
onShutdownFunc func() error
|
||||
weighted *weightedRR // used in unit tests only
|
||||
preferSubnets []*net.IPNet
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
@@ -39,65 +41,90 @@ func setup(c *caddy.Controller) error {
|
||||
c.OnShutdown(lb.onShutdownFunc)
|
||||
}
|
||||
|
||||
shuffle := lb.shuffleFunc
|
||||
if len(lb.preferSubnets) > 0 {
|
||||
original := shuffle
|
||||
shuffle = func(res *dns.Msg) *dns.Msg {
|
||||
return reorderPreferredSubnets(original(res), lb.preferSubnets)
|
||||
}
|
||||
}
|
||||
|
||||
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
return LoadBalance{Next: next, shuffle: lb.shuffleFunc}
|
||||
return LoadBalance{Next: next, shuffle: shuffle}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// func parse(c *caddy.Controller) (string, *weightedRR, error) {
|
||||
func parse(c *caddy.Controller) (*lbFuncs, error) {
|
||||
config := dnsserver.GetConfig(c)
|
||||
lb := &lbFuncs{}
|
||||
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return &lbFuncs{shuffleFunc: randomShuffle}, nil
|
||||
}
|
||||
switch args[0] {
|
||||
case ramdomShufflePolicy:
|
||||
if len(args) > 1 {
|
||||
return nil, c.Errf("unknown property for %s", args[0])
|
||||
}
|
||||
return &lbFuncs{shuffleFunc: randomShuffle}, nil
|
||||
case weightedRoundRobinPolicy:
|
||||
if len(args) < 2 {
|
||||
return nil, c.Err("missing weight file argument")
|
||||
}
|
||||
|
||||
if len(args) > 2 {
|
||||
return nil, c.Err("unexpected argument(s)")
|
||||
}
|
||||
|
||||
weightFileName := args[1]
|
||||
if !filepath.IsAbs(weightFileName) && config.Root != "" {
|
||||
weightFileName = filepath.Join(config.Root, weightFileName)
|
||||
}
|
||||
reload := 30 * time.Second // default reload period
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "reload":
|
||||
t := c.RemainingArgs()
|
||||
if len(t) < 1 {
|
||||
return nil, c.Err("reload duration value is missing")
|
||||
}
|
||||
if len(t) > 1 {
|
||||
return nil, c.Err("unexpected argument")
|
||||
}
|
||||
var err error
|
||||
reload, err = time.ParseDuration(t[0])
|
||||
if err != nil {
|
||||
return nil, c.Errf("invalid reload duration '%s'", t[0])
|
||||
}
|
||||
default:
|
||||
return nil, c.Errf("unknown property '%s'", c.Val())
|
||||
lb.shuffleFunc = randomShuffle
|
||||
} else {
|
||||
switch args[0] {
|
||||
case ramdomShufflePolicy:
|
||||
if len(args) > 1 {
|
||||
return nil, c.Errf("unknown property for %s", args[0])
|
||||
}
|
||||
lb.shuffleFunc = randomShuffle
|
||||
|
||||
case weightedRoundRobinPolicy:
|
||||
if len(args) < 2 {
|
||||
return nil, c.Err("missing weight file argument")
|
||||
}
|
||||
if len(args) > 2 {
|
||||
return nil, c.Err("unexpected argument(s)")
|
||||
}
|
||||
weightFileName := args[1]
|
||||
if !filepath.IsAbs(weightFileName) && config.Root != "" {
|
||||
weightFileName = filepath.Join(config.Root, weightFileName)
|
||||
}
|
||||
reload := 30 * time.Second
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "reload":
|
||||
t := c.RemainingArgs()
|
||||
if len(t) < 1 {
|
||||
return nil, c.Err("reload duration value is missing")
|
||||
}
|
||||
if len(t) > 1 {
|
||||
return nil, c.Err("unexpected argument")
|
||||
}
|
||||
var err error
|
||||
reload, err = time.ParseDuration(t[0])
|
||||
if err != nil {
|
||||
return nil, c.Errf("invalid reload duration '%s'", t[0])
|
||||
}
|
||||
default:
|
||||
return nil, c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
*lb = *createWeightedFuncs(weightFileName, reload)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown policy: %s", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "prefer":
|
||||
cidrs := c.RemainingArgs()
|
||||
for _, cidr := range cidrs {
|
||||
_, subnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, c.Errf("invalid CIDR %q: %v", cidr, err)
|
||||
}
|
||||
lb.preferSubnets = append(lb.preferSubnets, subnet)
|
||||
}
|
||||
default:
|
||||
return nil, c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
return createWeightedFuncs(weightFileName, reload), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown policy: %s", args[0])
|
||||
}
|
||||
}
|
||||
return nil, c.ArgErr()
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user