mirror of
https://github.com/coredns/coredns.git
synced 2025-10-27 08:14:18 -04:00
Add weighted-round-robin policy to loadbalance plugin (#5662)
* Add weighted-round-robin policy to loadbalance plugin Signed-off-by: Gabor Dozsa <gabor.dozsa@ibm.com>
This commit is contained in:
329
plugin/loadbalance/weighted.go
Normal file
329
plugin/loadbalance/weighted.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type (
|
||||
// "weighted-round-robin" policy specific data
|
||||
weightedRR struct {
|
||||
fileName string
|
||||
reload time.Duration
|
||||
md5sum [md5.Size]byte
|
||||
domains map[string]weights
|
||||
randomGen
|
||||
mutex sync.Mutex
|
||||
}
|
||||
// Per domain weights
|
||||
weights []*weightItem
|
||||
// Weight assigned to an address
|
||||
weightItem struct {
|
||||
address net.IP
|
||||
value uint8
|
||||
}
|
||||
// Random uint generator
|
||||
randomGen interface {
|
||||
randInit()
|
||||
randUint(limit uint) uint
|
||||
}
|
||||
)
|
||||
|
||||
// Random uint generator
|
||||
type randomUint struct {
|
||||
rn *rand.Rand
|
||||
}
|
||||
|
||||
func (r *randomUint) randInit() {
|
||||
r.rn = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
}
|
||||
|
||||
func (r *randomUint) randUint(limit uint) uint {
|
||||
return uint(r.rn.Intn(int(limit)))
|
||||
}
|
||||
|
||||
func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg {
|
||||
switch res.Question[0].Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeSRV:
|
||||
res.Answer = w.weightedRoundRobin(res.Answer)
|
||||
res.Extra = w.weightedRoundRobin(res.Extra)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error {
|
||||
err := w.updateWeights()
|
||||
if errors.Is(err, errOpen) && w.reload != 0 {
|
||||
log.Warningf("Failed to open weight file:%v. Will try again in %v",
|
||||
err, w.reload)
|
||||
} else if err != nil {
|
||||
return plugin.Error("loadbalance", err)
|
||||
}
|
||||
// start periodic weight file reload go routine
|
||||
w.periodicWeightUpdate(stopReloadChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func createWeightedFuncs(weightFileName string,
|
||||
reload time.Duration) *lbFuncs {
|
||||
lb := &lbFuncs{
|
||||
weighted: &weightedRR{
|
||||
fileName: weightFileName,
|
||||
reload: reload,
|
||||
randomGen: &randomUint{},
|
||||
},
|
||||
}
|
||||
lb.weighted.randomGen.randInit()
|
||||
|
||||
lb.shuffleFunc = func(res *dns.Msg) *dns.Msg {
|
||||
return weightedShuffle(res, lb.weighted)
|
||||
}
|
||||
|
||||
stopReloadChan := make(chan bool)
|
||||
|
||||
lb.onStartUpFunc = func() error {
|
||||
return weightedOnStartUp(lb.weighted, stopReloadChan)
|
||||
}
|
||||
|
||||
lb.onShutdownFunc = func() error {
|
||||
// stop periodic weigh reload go routine
|
||||
close(stopReloadChan)
|
||||
return nil
|
||||
}
|
||||
return lb
|
||||
}
|
||||
|
||||
// Apply weighted round robin policy to the answer
|
||||
func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR {
|
||||
cname := []dns.RR{}
|
||||
address := []dns.RR{}
|
||||
mx := []dns.RR{}
|
||||
rest := []dns.RR{}
|
||||
for _, r := range in {
|
||||
switch r.Header().Rrtype {
|
||||
case dns.TypeCNAME:
|
||||
cname = append(cname, r)
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
address = append(address, r)
|
||||
case dns.TypeMX:
|
||||
mx = append(mx, r)
|
||||
default:
|
||||
rest = append(rest, r)
|
||||
}
|
||||
}
|
||||
|
||||
if len(address) == 0 {
|
||||
// no change
|
||||
return in
|
||||
}
|
||||
|
||||
w.setTopRecord(address)
|
||||
|
||||
out := append(cname, rest...)
|
||||
out = append(out, address...)
|
||||
out = append(out, mx...)
|
||||
return out
|
||||
}
|
||||
|
||||
// Move the next expected address to the first position in the result list
|
||||
func (w *weightedRR) setTopRecord(address []dns.RR) {
|
||||
itop := w.topAddressIndex(address)
|
||||
|
||||
if itop < 0 {
|
||||
// internal error
|
||||
return
|
||||
}
|
||||
|
||||
if itop != 0 {
|
||||
// swap the selected top entry with the actual one
|
||||
address[0], address[itop] = address[itop], address[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the top (first) address index
|
||||
func (w *weightedRR) topAddressIndex(address []dns.RR) int {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
// Dertermine the weight value for each address in the answer
|
||||
var wsum uint
|
||||
type waddress struct {
|
||||
index int
|
||||
weight uint8
|
||||
}
|
||||
weightedAddr := make([]waddress, len(address))
|
||||
for i, ar := range address {
|
||||
wa := &weightedAddr[i]
|
||||
wa.index = i
|
||||
wa.weight = 1 // default weight
|
||||
var ip net.IP
|
||||
switch ar.Header().Rrtype {
|
||||
case dns.TypeA:
|
||||
ip = ar.(*dns.A).A
|
||||
case dns.TypeAAAA:
|
||||
ip = ar.(*dns.AAAA).AAAA
|
||||
}
|
||||
ws := w.domains[ar.Header().Name]
|
||||
for _, w := range ws {
|
||||
if w.address.Equal(ip) {
|
||||
wa.weight = w.value
|
||||
break
|
||||
}
|
||||
}
|
||||
wsum += uint(wa.weight)
|
||||
}
|
||||
|
||||
// Select the first (top) IP
|
||||
sort.Slice(weightedAddr, func(i, j int) bool {
|
||||
return weightedAddr[i].weight > weightedAddr[j].weight
|
||||
})
|
||||
v := w.randUint(wsum)
|
||||
var psum uint
|
||||
for _, wa := range weightedAddr {
|
||||
psum += uint(wa.weight)
|
||||
if v < psum {
|
||||
return int(wa.index)
|
||||
}
|
||||
}
|
||||
|
||||
// we should never reach this
|
||||
log.Errorf("Internal error: cannot find top addres (randv:%v wsum:%v)", v, wsum)
|
||||
return -1
|
||||
}
|
||||
|
||||
// Start go routine to update weights from the weight file periodically
|
||||
func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) {
|
||||
if w.reload == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(w.reload)
|
||||
for {
|
||||
select {
|
||||
case <-stopReload:
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := w.updateWeights()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Update weights from weight file
|
||||
func (w *weightedRR) updateWeights() error {
|
||||
reader, err := os.Open(filepath.Clean(w.fileName))
|
||||
if err != nil {
|
||||
return errOpen
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// check if the contents has changed
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(reader, &buf)
|
||||
bytes, err := io.ReadAll(tee)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
md5sum := md5.Sum(bytes)
|
||||
if md5sum == w.md5sum {
|
||||
// file contents has not changed
|
||||
return nil
|
||||
}
|
||||
w.md5sum = md5sum
|
||||
scanner := bufio.NewScanner(&buf)
|
||||
|
||||
// Parse the weight file contents
|
||||
err = w.parseWeights(scanner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Successfully reloaded weight file %s", w.fileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the weight file contents
|
||||
func (w *weightedRR) parseWeights(scanner *bufio.Scanner) error {
|
||||
// access to weights must be protected
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
// Reset domains
|
||||
w.domains = make(map[string]weights)
|
||||
|
||||
var dname string
|
||||
var ws weights
|
||||
for scanner.Scan() {
|
||||
nextLine := strings.TrimSpace(scanner.Text())
|
||||
if len(nextLine) == 0 || nextLine[0:1] == "#" {
|
||||
// Empty and comment lines are ignored
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(nextLine)
|
||||
switch len(fields) {
|
||||
case 1:
|
||||
// (domain) name sanity check
|
||||
if net.ParseIP(fields[0]) != nil {
|
||||
return fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)",
|
||||
fields[0], w.fileName)
|
||||
}
|
||||
dname = fields[0]
|
||||
|
||||
// add the root domain if it is missing
|
||||
if dname[len(dname)-1] != '.' {
|
||||
dname += "."
|
||||
}
|
||||
var ok bool
|
||||
ws, ok = w.domains[dname]
|
||||
if !ok {
|
||||
ws = make(weights, 0)
|
||||
w.domains[dname] = ws
|
||||
}
|
||||
case 2:
|
||||
// IP address and weight value
|
||||
ip := net.ParseIP(fields[0])
|
||||
if ip == nil {
|
||||
return fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName)
|
||||
}
|
||||
weight, err := strconv.ParseUint(fields[1], 10, 8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName)
|
||||
}
|
||||
witem := &weightItem{address: ip, value: uint8(weight)}
|
||||
if dname == "" {
|
||||
return fmt.Errorf("Missing domain name in weight file %s", w.fileName)
|
||||
}
|
||||
ws = append(ws, witem)
|
||||
w.domains[dname] = ws
|
||||
default:
|
||||
return fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user