Add weighted-round-robin policy to loadbalance plugin (#5662)

* Add weighted-round-robin policy to loadbalance plugin

Signed-off-by: Gabor Dozsa <gabor.dozsa@ibm.com>
This commit is contained in:
Gabor Dozsa
2023-01-27 17:36:56 +01:00
committed by GitHub
parent bf7c2cf37b
commit 7da2cedaf0
8 changed files with 975 additions and 36 deletions

View File

@@ -16,10 +16,43 @@ implementations (like glibc) are particular about that.
## Syntax
~~~
loadbalance [POLICY]
loadbalance [round_robin | weighted WEIGHTFILE] {
reload DURATION
}
~~~
* `round_robin` policy randomizes the order of A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy.
* `weighted` policy assigns weight values to IPs to control the relative likelihood of particular IPs to be returned as the first
(top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record
returned in the answer.
* **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section.
* **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload.
## Weightfile
The generic weight file syntax:
~~~
# Comment lines are ignored
domain-name1
ip11 weight11
ip12 weight12
ip13 weight13
domain-name2
ip21 weight21
ip22 weight22
# ... etc.
~~~
* **POLICY** is how to balance. The default, and only option, is "round_robin".
where `ipXY` is an IP address for `domain-nameX` and `weightXY` is the weight value associated with that IP. The weight values are in the range of [1,255].
The `weighted` policy selects one of the address record in the result list and moves it to the top (first) position in the list. The random selection takes into account the weight values assigned to the addresses in the weight file. If an address in the result list is associated with no weight value in the weight file then the default weight value "1" is assumed for it when the selection is performed.
## Examples
@@ -31,3 +64,27 @@ Load balance replies coming back from Google Public DNS:
forward . 8.8.8.8 8.8.4.4
}
~~~
Use the `weighted` strategy to load balance replies supplied by the **file** plugin. We assign weight vales `3`, `1` and `2` to the IPs `100.64.1.1`, `100.64.1.2` and `100.64.1.3`, respectively. These IPs are addresses in A records for the domain name `www.example.com` defined in the `./db.example.com` zone file. The ratio between the number of answers in which `100.64.1.1`, `100.64.1.2` or `100.64.1.3` is in the top (first) A record should converge to `3 : 1 : 2`. (E.g. there should be twice as many answers with `100.64.1.3` in the top A record than with `100.64.1.2`).
Corefile:
~~~ corefile
example.com {
file ./db.example.com {
reload 10s
}
loadbalance weighted ./db.example.com.weights {
reload 10s
}
}
~~~
weight file `./db.example.com.weights`:
~~~
www.example.com
100.64.1.1 3
100.64.1.2 1
100.64.1.3 2
~~~

View File

@@ -10,15 +10,16 @@ import (
)
// RoundRobin is a plugin to rewrite responses for "load balancing".
type RoundRobin struct {
Next plugin.Handler
type LoadBalance struct {
Next plugin.Handler
shuffle func(*dns.Msg) *dns.Msg
}
// ServeDNS implements the plugin.Handler interface.
func (rr RoundRobin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wrr := &RoundRobinResponseWriter{w}
return plugin.NextOrFailure(rr.Name(), rr.Next, ctx, wrr, r)
func (lb LoadBalance) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
rw := &LoadBalanceResponseWriter{ResponseWriter: w, shuffle: lb.shuffle}
return plugin.NextOrFailure(lb.Name(), lb.Next, ctx, rw, r)
}
// Name implements the Handler interface.
func (rr RoundRobin) Name() string { return "loadbalance" }
func (lb LoadBalance) Name() string { return "loadbalance" }

View File

@@ -5,11 +5,19 @@ import (
"github.com/miekg/dns"
)
// RoundRobinResponseWriter is a response writer that shuffles A, AAAA and MX records.
type RoundRobinResponseWriter struct{ dns.ResponseWriter }
const (
ramdomShufflePolicy = "round_robin"
weightedRoundRobinPolicy = "weighted"
)
// LoadBalanceResponseWriter is a response writer that shuffles A, AAAA and MX records.
type LoadBalanceResponseWriter struct {
dns.ResponseWriter
shuffle func(*dns.Msg) *dns.Msg
}
// WriteMsg implements the dns.ResponseWriter interface.
func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
func (r *LoadBalanceResponseWriter) WriteMsg(res *dns.Msg) error {
if res.Rcode != dns.RcodeSuccess {
return r.ResponseWriter.WriteMsg(res)
}
@@ -18,11 +26,14 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
return r.ResponseWriter.WriteMsg(res)
}
return r.ResponseWriter.WriteMsg(r.shuffle(res))
}
func randomShuffle(res *dns.Msg) *dns.Msg {
res.Answer = roundRobin(res.Answer)
res.Ns = roundRobin(res.Ns)
res.Extra = roundRobin(res.Extra)
return r.ResponseWriter.WriteMsg(res)
return res
}
func roundRobin(in []dns.RR) []dns.RR {
@@ -72,9 +83,9 @@ func roundRobinShuffle(records []dns.RR) {
}
// Write implements the dns.ResponseWriter interface.
func (r *RoundRobinResponseWriter) Write(buf []byte) (int, error) {
func (r *LoadBalanceResponseWriter) Write(buf []byte) (int, error) {
// Should we pack and unpack here to fiddle with the packet... Not likely.
log.Warning("RoundRobin called with Write: not shuffling records")
log.Warning("LoadBalance called with Write: not shuffling records")
n, err := r.ResponseWriter.Write(buf)
return n, err
}

View File

@@ -11,8 +11,8 @@ import (
"github.com/miekg/dns"
)
func TestLoadBalance(t *testing.T) {
rm := RoundRobin{Next: handler()}
func TestLoadBalanceRandom(t *testing.T) {
rm := LoadBalance{Next: handler(), shuffle: randomShuffle}
// the first X records must be cnames after this test
tests := []struct {
@@ -124,7 +124,7 @@ func TestLoadBalance(t *testing.T) {
}
func TestLoadBalanceXFR(t *testing.T) {
rm := RoundRobin{Next: handler()}
rm := LoadBalance{Next: handler()}
answer := []dns.RR{
test.SOA("skydns.test. 30 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1542756695 7200 1800 86400 30"),

View File

@@ -1,43 +1,103 @@
package loadbalance
import (
"errors"
"fmt"
"path/filepath"
"time"
"github.com/coredns/caddy"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
clog "github.com/coredns/coredns/plugin/pkg/log"
"github.com/miekg/dns"
)
var log = clog.NewWithPlugin("loadbalance")
var errOpen = errors.New("Weight file open error")
func init() { plugin.Register("loadbalance", setup) }
type lbFuncs struct {
shuffleFunc func(*dns.Msg) *dns.Msg
onStartUpFunc func() error
onShutdownFunc func() error
weighted *weightedRR // used in unit tests only
}
func setup(c *caddy.Controller) error {
err := parse(c)
//shuffleFunc, startUpFunc, shutdownFunc, err := parse(c)
lb, err := parse(c)
if err != nil {
return plugin.Error("loadbalance", err)
}
if lb.onStartUpFunc != nil {
c.OnStartup(lb.onStartUpFunc)
}
if lb.onShutdownFunc != nil {
c.OnShutdown(lb.onShutdownFunc)
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
return RoundRobin{Next: next}
return LoadBalance{Next: next, shuffle: lb.shuffleFunc}
})
return nil
}
func parse(c *caddy.Controller) error {
// func parse(c *caddy.Controller) (string, *weightedRR, error) {
func parse(c *caddy.Controller) (*lbFuncs, error) {
config := dnsserver.GetConfig(c)
for c.Next() {
args := c.RemainingArgs()
switch len(args) {
case 0:
return nil
case 1:
if args[0] != "round_robin" {
return fmt.Errorf("unknown policy: %s", args[0])
if len(args) == 0 {
return &lbFuncs{shuffleFunc: randomShuffle}, nil
}
switch args[0] {
case ramdomShufflePolicy:
if len(args) > 1 {
return nil, c.Errf("unknown property for %s", args[0])
}
return nil
return &lbFuncs{shuffleFunc: randomShuffle}, nil
case weightedRoundRobinPolicy:
if len(args) < 2 {
return nil, c.Err("missing weight file argument")
}
if len(args) > 2 {
return nil, c.Err("unexpected argument(s)")
}
weightFileName := args[1]
if !filepath.IsAbs(weightFileName) && config.Root != "" {
weightFileName = filepath.Join(config.Root, weightFileName)
}
reload := 30 * time.Second // default reload period
for c.NextBlock() {
switch c.Val() {
case "reload":
t := c.RemainingArgs()
if len(t) < 1 {
return nil, c.Err("reload duration value is missing")
}
if len(t) > 1 {
return nil, c.Err("unexpected argument")
}
var err error
reload, err = time.ParseDuration(t[0])
if err != nil {
return nil, c.Errf("invalid reload duration '%s'", t[0])
}
default:
return nil, c.Errf("unknown property '%s'", c.Val())
}
}
return createWeightedFuncs(weightFileName, reload), nil
default:
return nil, fmt.Errorf("unknown policy: %s", args[0])
}
}
return c.ArgErr()
return nil, c.ArgErr()
}

View File

@@ -7,24 +7,53 @@ import (
"github.com/coredns/caddy"
)
// weighted round robin specific test data
var testWeighted = []struct {
expectedWeightFile string
expectedWeightReload string
}{
{"wfile", "30s"},
{"wf", "10s"},
{"wf", "0s"},
}
func TestSetup(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedPolicy string
expectedErrContent string // substring from the expected error. Empty for positive cases.
weightedDataIndex int // weighted round robin specific data index
}{
// positive
{`loadbalance`, false, "round_robin", ""},
{`loadbalance round_robin`, false, "round_robin", ""},
{`loadbalance`, false, "round_robin", "", -1},
{`loadbalance round_robin`, false, "round_robin", "", -1},
{`loadbalance weighted wfile`, false, "weighted", "", 0},
{`loadbalance weighted wf {
reload 10s
} `, false, "weighted", "", 1},
{`loadbalance weighted wf {
reload 0s
} `, false, "weighted", "", 2},
// negative
{`loadbalance fleeb`, true, "", "unknown policy"},
{`loadbalance a b`, true, "", "argument count or unexpected line"},
{`loadbalance fleeb`, true, "", "unknown policy", -1},
{`loadbalance round_robin a`, true, "", "unknown property", -1},
{`loadbalance weighted`, true, "", "missing weight file argument", -1},
{`loadbalance weighted a b`, true, "", "unexpected argument", -1},
{`loadbalance weighted wfile {
susu
} `, true, "", "unknown property", -1},
{`loadbalance weighted wfile {
reload a
} `, true, "", "invalid reload duration", -1},
{`loadbalance weighted wfile {
reload 30s a
} `, true, "", "unexpected argument", -1},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
err := parse(c)
lb, err := parse(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input)
@@ -32,11 +61,39 @@ func TestSetup(t *testing.T) {
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err)
t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v",
i, test.input, err)
}
if !strings.Contains(err.Error(), test.expectedErrContent) {
t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input)
t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s",
i, test.expectedErrContent, err, test.input)
}
continue
}
if lb == nil {
t.Errorf("Test %d: Expected valid loadbalance funcs but got nil for input %s",
i, test.input)
continue
}
policy := ramdomShufflePolicy
if lb.weighted != nil {
policy = weightedRoundRobinPolicy
}
if policy != test.expectedPolicy {
t.Errorf("Test %d: Expected policy %s but got %s for input %s", i,
test.expectedPolicy, policy, test.input)
}
if policy == weightedRoundRobinPolicy && test.weightedDataIndex >= 0 {
i := test.weightedDataIndex
if testWeighted[i].expectedWeightFile != lb.weighted.fileName {
t.Errorf("Test %d: Expected weight file name %s but got %s for input %s",
i, testWeighted[i].expectedWeightFile, lb.weighted.fileName, test.input)
}
if testWeighted[i].expectedWeightReload != lb.weighted.reload.String() {
t.Errorf("Test %d: Expected weight reload duration %s but got %s for input %s",
i, testWeighted[i].expectedWeightReload, lb.weighted.reload, test.input)
}
}
}

View File

@@ -0,0 +1,329 @@
package loadbalance
import (
"bufio"
"bytes"
"crypto/md5"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
)
type (
// "weighted-round-robin" policy specific data
weightedRR struct {
fileName string
reload time.Duration
md5sum [md5.Size]byte
domains map[string]weights
randomGen
mutex sync.Mutex
}
// Per domain weights
weights []*weightItem
// Weight assigned to an address
weightItem struct {
address net.IP
value uint8
}
// Random uint generator
randomGen interface {
randInit()
randUint(limit uint) uint
}
)
// Random uint generator
type randomUint struct {
rn *rand.Rand
}
func (r *randomUint) randInit() {
r.rn = rand.New(rand.NewSource(time.Now().UnixNano()))
}
func (r *randomUint) randUint(limit uint) uint {
return uint(r.rn.Intn(int(limit)))
}
func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg {
switch res.Question[0].Qtype {
case dns.TypeA, dns.TypeAAAA, dns.TypeSRV:
res.Answer = w.weightedRoundRobin(res.Answer)
res.Extra = w.weightedRoundRobin(res.Extra)
}
return res
}
func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error {
err := w.updateWeights()
if errors.Is(err, errOpen) && w.reload != 0 {
log.Warningf("Failed to open weight file:%v. Will try again in %v",
err, w.reload)
} else if err != nil {
return plugin.Error("loadbalance", err)
}
// start periodic weight file reload go routine
w.periodicWeightUpdate(stopReloadChan)
return nil
}
func createWeightedFuncs(weightFileName string,
reload time.Duration) *lbFuncs {
lb := &lbFuncs{
weighted: &weightedRR{
fileName: weightFileName,
reload: reload,
randomGen: &randomUint{},
},
}
lb.weighted.randomGen.randInit()
lb.shuffleFunc = func(res *dns.Msg) *dns.Msg {
return weightedShuffle(res, lb.weighted)
}
stopReloadChan := make(chan bool)
lb.onStartUpFunc = func() error {
return weightedOnStartUp(lb.weighted, stopReloadChan)
}
lb.onShutdownFunc = func() error {
// stop periodic weigh reload go routine
close(stopReloadChan)
return nil
}
return lb
}
// Apply weighted round robin policy to the answer
func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{}
address := []dns.RR{}
mx := []dns.RR{}
rest := []dns.RR{}
for _, r := range in {
switch r.Header().Rrtype {
case dns.TypeCNAME:
cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA:
address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default:
rest = append(rest, r)
}
}
if len(address) == 0 {
// no change
return in
}
w.setTopRecord(address)
out := append(cname, rest...)
out = append(out, address...)
out = append(out, mx...)
return out
}
// Move the next expected address to the first position in the result list
func (w *weightedRR) setTopRecord(address []dns.RR) {
itop := w.topAddressIndex(address)
if itop < 0 {
// internal error
return
}
if itop != 0 {
// swap the selected top entry with the actual one
address[0], address[itop] = address[itop], address[0]
}
}
// Compute the top (first) address index
func (w *weightedRR) topAddressIndex(address []dns.RR) int {
w.mutex.Lock()
defer w.mutex.Unlock()
// Dertermine the weight value for each address in the answer
var wsum uint
type waddress struct {
index int
weight uint8
}
weightedAddr := make([]waddress, len(address))
for i, ar := range address {
wa := &weightedAddr[i]
wa.index = i
wa.weight = 1 // default weight
var ip net.IP
switch ar.Header().Rrtype {
case dns.TypeA:
ip = ar.(*dns.A).A
case dns.TypeAAAA:
ip = ar.(*dns.AAAA).AAAA
}
ws := w.domains[ar.Header().Name]
for _, w := range ws {
if w.address.Equal(ip) {
wa.weight = w.value
break
}
}
wsum += uint(wa.weight)
}
// Select the first (top) IP
sort.Slice(weightedAddr, func(i, j int) bool {
return weightedAddr[i].weight > weightedAddr[j].weight
})
v := w.randUint(wsum)
var psum uint
for _, wa := range weightedAddr {
psum += uint(wa.weight)
if v < psum {
return int(wa.index)
}
}
// we should never reach this
log.Errorf("Internal error: cannot find top addres (randv:%v wsum:%v)", v, wsum)
return -1
}
// Start go routine to update weights from the weight file periodically
func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) {
if w.reload == 0 {
return
}
go func() {
ticker := time.NewTicker(w.reload)
for {
select {
case <-stopReload:
return
case <-ticker.C:
err := w.updateWeights()
if err != nil {
log.Error(err)
}
}
}
}()
}
// Update weights from weight file
func (w *weightedRR) updateWeights() error {
reader, err := os.Open(filepath.Clean(w.fileName))
if err != nil {
return errOpen
}
defer reader.Close()
// check if the contents has changed
var buf bytes.Buffer
tee := io.TeeReader(reader, &buf)
bytes, err := io.ReadAll(tee)
if err != nil {
return err
}
md5sum := md5.Sum(bytes)
if md5sum == w.md5sum {
// file contents has not changed
return nil
}
w.md5sum = md5sum
scanner := bufio.NewScanner(&buf)
// Parse the weight file contents
err = w.parseWeights(scanner)
if err != nil {
return err
}
log.Infof("Successfully reloaded weight file %s", w.fileName)
return nil
}
// Parse the weight file contents
func (w *weightedRR) parseWeights(scanner *bufio.Scanner) error {
// access to weights must be protected
w.mutex.Lock()
defer w.mutex.Unlock()
// Reset domains
w.domains = make(map[string]weights)
var dname string
var ws weights
for scanner.Scan() {
nextLine := strings.TrimSpace(scanner.Text())
if len(nextLine) == 0 || nextLine[0:1] == "#" {
// Empty and comment lines are ignored
continue
}
fields := strings.Fields(nextLine)
switch len(fields) {
case 1:
// (domain) name sanity check
if net.ParseIP(fields[0]) != nil {
return fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)",
fields[0], w.fileName)
}
dname = fields[0]
// add the root domain if it is missing
if dname[len(dname)-1] != '.' {
dname += "."
}
var ok bool
ws, ok = w.domains[dname]
if !ok {
ws = make(weights, 0)
w.domains[dname] = ws
}
case 2:
// IP address and weight value
ip := net.ParseIP(fields[0])
if ip == nil {
return fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName)
}
weight, err := strconv.ParseUint(fields[1], 10, 8)
if err != nil {
return fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName)
}
witem := &weightItem{address: ip, value: uint8(weight)}
if dname == "" {
return fmt.Errorf("Missing domain name in weight file %s", w.fileName)
}
ws = append(ws, witem)
w.domains[dname] = ws
default:
return fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName)
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err)
}
return nil
}

View File

@@ -0,0 +1,424 @@
package loadbalance
import (
"context"
"errors"
"net"
"strings"
"testing"
"time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
testutil "github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
)
const oneDomainWRR = `
w1,example.org
192.168.1.15 10
192.168.1.14 20
`
var testOneDomainWRR = map[string]weights{
"w1,example.org.": weights{
&weightItem{net.ParseIP("192.168.1.15"), uint8(10)},
&weightItem{net.ParseIP("192.168.1.14"), uint8(20)},
},
}
const twoDomainsWRR = `
# domain 1
w1.example.org
192.168.1.15 10
192.168.1.14 20
# domain 2
w2.example.org
# domain 3
w3.example.org
192.168.2.16 11
192.168.2.15 12
192.168.2.14 13
`
var testTwoDomainsWRR = map[string]weights{
"w1.example.org.": weights{
&weightItem{net.ParseIP("192.168.1.15"), uint8(10)},
&weightItem{net.ParseIP("192.168.1.14"), uint8(20)},
},
"w2.example.org.": weights{},
"w3.example.org.": weights{
&weightItem{net.ParseIP("192.168.2.16"), uint8(11)},
&weightItem{net.ParseIP("192.168.2.15"), uint8(12)},
&weightItem{net.ParseIP("192.168.2.14"), uint8(13)},
},
}
const missingWeightWRR = `
w1,example.org
192.168.1.14
192.168.1.15 20
`
const missingDomainWRR = `
# missing domain
192.168.1.14 10
w2,example.org
192.168.2.14 11
192.168.2.15 12
`
const wrongIpWRR = `
w1,example.org
192.168.1.300 10
`
const wrongWeightWRR = `
w1,example.org
192.168.1.14 300
`
func TestWeightFileUpdate(t *testing.T) {
tests := []struct {
weightFilContent string
shouldErr bool
expectedDomains map[string]weights
expectedErrContent string // substring from the expected error. Empty for positive cases.
}{
// positive
{"", false, nil, ""},
{oneDomainWRR, false, testOneDomainWRR, ""},
{twoDomainsWRR, false, testTwoDomainsWRR, ""},
// negative
{missingWeightWRR, true, nil, "Wrong domain name"},
{missingDomainWRR, true, nil, "Missing domain name"},
{wrongIpWRR, true, nil, "Wrong IP address"},
{wrongWeightWRR, true, nil, "Wrong weight value"},
}
for i, test := range tests {
testFile, rm, err := testutil.TempFile(".", test.weightFilContent)
if err != nil {
t.Fatal(err)
}
defer rm()
weighted := &weightedRR{fileName: testFile}
err = weighted.updateWeights()
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error but found %s", i, err)
}
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: Expected no error but found error: %v", i, err)
}
if !strings.Contains(err.Error(), test.expectedErrContent) {
t.Errorf("Test %d: Expected error to contain: %v, found error: %v",
i, test.expectedErrContent, err)
}
}
if test.expectedDomains != nil {
if len(test.expectedDomains) != len(weighted.domains) {
t.Errorf("Test %d: Expected len(domains): %d but got %d",
i, len(test.expectedDomains), len(weighted.domains))
} else {
_ = checkDomainsWRR(t, i, test.expectedDomains, weighted.domains)
}
}
}
}
func checkDomainsWRR(t *testing.T, testIndex int, expectedDomains, domains map[string]weights) error {
var ret error
retError := errors.New("Check domains failed")
for dname, expectedWeights := range expectedDomains {
ws, ok := domains[dname]
if !ok {
t.Errorf("Test %d: Expected domain %s but not found it", testIndex, dname)
ret = retError
} else {
if len(expectedWeights) != len(ws) {
t.Errorf("Test %d: Expected len(weights): %d for domain %s but got %d",
testIndex, len(expectedWeights), dname, len(ws))
ret = retError
} else {
for i, w := range expectedWeights {
if !w.address.Equal(ws[i].address) || w.value != ws[i].value {
t.Errorf("Test %d: Weight list differs at index %d for domain %s. "+
"Expected: %v got: %v", testIndex, i, dname, expectedWeights[i], ws[i])
ret = retError
}
}
}
}
}
return ret
}
func TestPeriodicWeightUpdate(t *testing.T) {
testFile1, rm, err := testutil.TempFile(".", oneDomainWRR)
if err != nil {
t.Fatal(err)
}
defer rm()
testFile2, rm, err := testutil.TempFile(".", twoDomainsWRR)
if err != nil {
t.Fatal(err)
}
defer rm()
// configure weightedRR with "oneDomainWRR" weight file content
weighted := &weightedRR{fileName: testFile1}
err = weighted.updateWeights()
if err != nil {
t.Fatal(err)
} else {
err = checkDomainsWRR(t, 0, testOneDomainWRR, weighted.domains)
if err != nil {
t.Fatalf("Initial check domains failed")
}
}
// change weight file
weighted.fileName = testFile2
// start periodic update
weighted.reload = 10 * time.Millisecond
stopChan := make(chan bool)
weighted.periodicWeightUpdate(stopChan)
time.Sleep(20 * time.Millisecond)
// stop periodic update
close(stopChan)
// check updated config
weighted.mutex.Lock()
err = checkDomainsWRR(t, 0, testTwoDomainsWRR, weighted.domains)
weighted.mutex.Unlock()
if err != nil {
t.Fatalf("Final check domains failed")
}
}
// Fake random number generator for testing
type fakeRandomGen struct {
expectedLimit uint
testIndex int
queryIndex int
randv uint
t *testing.T
}
func (r *fakeRandomGen) randInit() {
}
func (r *fakeRandomGen) randUint(limit uint) uint {
if limit != r.expectedLimit {
r.t.Errorf("Test %d query %d: Expected weights sum %d but got %d",
r.testIndex, r.queryIndex, r.expectedLimit, limit)
}
return r.randv
}
func TestLoadBalanceWRR(t *testing.T) {
type testQuery struct {
randv uint // fake random value for selecting the top IP
topIP string // top (first) address record in the answer
}
// domain maps to test
oneDomain := map[string]weights{
"endpoint.region2.skydns.test.": weights{
&weightItem{net.ParseIP("10.240.0.2"), uint8(3)},
&weightItem{net.ParseIP("10.240.0.1"), uint8(2)},
},
}
twoDomains := map[string]weights{
"endpoint.region2.skydns.test.": weights{
&weightItem{net.ParseIP("10.240.0.2"), uint8(5)},
&weightItem{net.ParseIP("10.240.0.1"), uint8(2)},
},
"endpoint.region1.skydns.test.": weights{
&weightItem{net.ParseIP("::2"), uint8(4)},
&weightItem{net.ParseIP("::1"), uint8(3)},
},
}
// the first X records must be cnames after this test
tests := []struct {
answer []dns.RR
extra []dns.RR
cnameAnswer int
cnameExtra int
addressAnswer int
addressExtra int
mxAnswer int
mxExtra int
domains map[string]weights
sumWeights uint // sum of weights in the answer
queries []testQuery
}{
{
answer: []dns.RR{
testutil.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
testutil.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
testutil.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"),
testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"),
testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
testutil.MX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."),
testutil.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."),
},
extra: []dns.RR{
testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"),
testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"),
testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
},
cnameAnswer: 4,
cnameExtra: 1,
addressAnswer: 5,
addressExtra: 5,
mxAnswer: 3,
mxExtra: 1,
domains: twoDomains,
sumWeights: 15,
queries: []testQuery{
{0, "10.240.0.2"}, // domain 1 weight 5
{4, "10.240.0.2"}, // domain 1 weight 5
{5, "::2"}, // domain 2 weight 4
{8, "::2"}, // domain 2 weight 4
{9, "::1"}, // domain 2 weight 3
{11, "::1"}, // domain 2 weight 3
{12, "10.240.0.1"}, // domain 1 weight 2
{13, "10.240.0.1"}, // domain 1 weight 2
{14, "10.240.0.3"}, // domain 1 no weight -> default weight
},
},
{
answer: []dns.RR{
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
testutil.A("endpoint.region1.skydns.test. 300 IN A 10.240.0.3"),
},
cnameAnswer: 1,
addressAnswer: 3,
mxAnswer: 1,
domains: oneDomain,
sumWeights: 6,
queries: []testQuery{
{0, "10.240.0.2"}, // weight 3
{2, "10.240.0.2"}, // weight 3
{3, "10.240.0.1"}, // weight 2
{4, "10.240.0.1"}, // weight 2
{5, "10.240.0.3"}, // no domain -> default weight
},
},
{
answer: []dns.RR{
testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
},
cnameAnswer: 1,
mxAnswer: 1,
domains: oneDomain,
queries: []testQuery{
{0, ""}, // no address records -> answer unaltered
},
},
}
testRand := &fakeRandomGen{t: t}
weighted := &weightedRR{randomGen: testRand}
shuffle := func(res *dns.Msg) *dns.Msg {
return weightedShuffle(res, weighted)
}
rm := LoadBalance{Next: handler(), shuffle: shuffle}
rec := dnstest.NewRecorder(&testutil.ResponseWriter{})
for i, test := range tests {
// set domain map for weighted round robin
weighted.domains = test.domains
testRand.testIndex = i
testRand.expectedLimit = test.sumWeights
for j, query := range test.queries {
req := new(dns.Msg)
req.SetQuestion("endpoint.region2.skydns.test", dns.TypeSRV)
req.Answer = test.answer
req.Extra = test.extra
// Set fake random number
testRand.randv = query.randv
testRand.queryIndex = j
_, err := rm.ServeDNS(context.TODO(), rec, req)
if err != nil {
t.Errorf("Test %d: Expected no error, but got %s", i, err)
continue
}
checkTopIP(t, i, j, rec.Msg.Answer, query.topIP)
checkTopIP(t, i, j, rec.Msg.Extra, query.topIP)
cname, address, mx, sorted := countRecords(rec.Msg.Answer)
if query.topIP != "" && !sorted {
t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j)
}
if cname != test.cnameAnswer {
t.Errorf("Test %d query %d: Expected %d CNAMEs in Answer, but got %d", i, j, test.cnameAnswer, cname)
}
if address != test.addressAnswer {
t.Errorf("Test %d query %d: Expected %d A/AAAAs in Answer, but got %d", i, j, test.addressAnswer, address)
}
if mx != test.mxAnswer {
t.Errorf("Test %d query %d: Expected %d MXs in Answer, but got %d", i, j, test.mxAnswer, mx)
}
cname, address, mx, sorted = countRecords(rec.Msg.Extra)
if query.topIP != "" && !sorted {
t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j)
}
if cname != test.cnameExtra {
t.Errorf("Test %d query %d: Expected %d CNAMEs in Extra, but got %d", i, j, test.cnameAnswer, cname)
}
if address != test.addressExtra {
t.Errorf("Test %d query %d: Expected %d A/AAAAs in Extra, but got %d", i, j, test.addressAnswer, address)
}
if mx != test.mxExtra {
t.Errorf("Test %d query %d: Expected %d MXs in Extra, but got %d", i, j, test.mxAnswer, mx)
}
}
}
}
func checkTopIP(t *testing.T, i, j int, result []dns.RR, expectedTopIP string) {
expected := net.ParseIP(expectedTopIP)
for _, r := range result {
switch r.Header().Rrtype {
case dns.TypeA:
ar := r.(*dns.A)
if !ar.A.Equal(expected) {
t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.A)
}
return
case dns.TypeAAAA:
ar := r.(*dns.AAAA)
if !ar.AAAA.Equal(expected) {
t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.AAAA)
}
return
}
}
}