| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | package loadbalance
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							|  |  |  | 	"bufio"
 | 
					
						
							|  |  |  | 	"bytes"
 | 
					
						
							|  |  |  | 	"crypto/md5"
 | 
					
						
							|  |  |  | 	"errors"
 | 
					
						
							|  |  |  | 	"fmt"
 | 
					
						
							|  |  |  | 	"io"
 | 
					
						
							|  |  |  | 	"math/rand"
 | 
					
						
							|  |  |  | 	"net"
 | 
					
						
							|  |  |  | 	"os"
 | 
					
						
							|  |  |  | 	"path/filepath"
 | 
					
						
							|  |  |  | 	"sort"
 | 
					
						
							|  |  |  | 	"strconv"
 | 
					
						
							|  |  |  | 	"strings"
 | 
					
						
							|  |  |  | 	"sync"
 | 
					
						
							|  |  |  | 	"time"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/plugin"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type (
 | 
					
						
							|  |  |  | 	// "weighted-round-robin" policy specific data
 | 
					
						
							|  |  |  | 	weightedRR struct {
 | 
					
						
							|  |  |  | 		fileName string
 | 
					
						
							|  |  |  | 		reload   time.Duration
 | 
					
						
							|  |  |  | 		md5sum   [md5.Size]byte
 | 
					
						
							|  |  |  | 		domains  map[string]weights
 | 
					
						
							|  |  |  | 		randomGen
 | 
					
						
							|  |  |  | 		mutex sync.Mutex
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	// Per domain weights
 | 
					
						
							|  |  |  | 	weights []*weightItem
 | 
					
						
							|  |  |  | 	// Weight assigned to an address
 | 
					
						
							|  |  |  | 	weightItem struct {
 | 
					
						
							|  |  |  | 		address net.IP
 | 
					
						
							|  |  |  | 		value   uint8
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	// Random uint generator
 | 
					
						
							|  |  |  | 	randomGen interface {
 | 
					
						
							|  |  |  | 		randInit()
 | 
					
						
							|  |  |  | 		randUint(limit uint) uint
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Random uint generator
 | 
					
						
							|  |  |  | type randomUint struct {
 | 
					
						
							|  |  |  | 	rn *rand.Rand
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *randomUint) randInit() {
 | 
					
						
							|  |  |  | 	r.rn = rand.New(rand.NewSource(time.Now().UnixNano()))
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (r *randomUint) randUint(limit uint) uint {
 | 
					
						
							|  |  |  | 	return uint(r.rn.Intn(int(limit)))
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg {
 | 
					
						
							|  |  |  | 	switch res.Question[0].Qtype {
 | 
					
						
							|  |  |  | 	case dns.TypeA, dns.TypeAAAA, dns.TypeSRV:
 | 
					
						
							|  |  |  | 		res.Answer = w.weightedRoundRobin(res.Answer)
 | 
					
						
							|  |  |  | 		res.Extra = w.weightedRoundRobin(res.Extra)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return res
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error {
 | 
					
						
							|  |  |  | 	err := w.updateWeights()
 | 
					
						
							|  |  |  | 	if errors.Is(err, errOpen) && w.reload != 0 {
 | 
					
						
							|  |  |  | 		log.Warningf("Failed to open weight file:%v. Will try again in %v",
 | 
					
						
							|  |  |  | 			err, w.reload)
 | 
					
						
							|  |  |  | 	} else if err != nil {
 | 
					
						
							|  |  |  | 		return plugin.Error("loadbalance", err)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	// start periodic weight file reload go routine
 | 
					
						
							|  |  |  | 	w.periodicWeightUpdate(stopReloadChan)
 | 
					
						
							|  |  |  | 	return nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func createWeightedFuncs(weightFileName string,
 | 
					
						
							|  |  |  | 	reload time.Duration) *lbFuncs {
 | 
					
						
							|  |  |  | 	lb := &lbFuncs{
 | 
					
						
							|  |  |  | 		weighted: &weightedRR{
 | 
					
						
							|  |  |  | 			fileName:  weightFileName,
 | 
					
						
							|  |  |  | 			reload:    reload,
 | 
					
						
							|  |  |  | 			randomGen: &randomUint{},
 | 
					
						
							|  |  |  | 		},
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	lb.weighted.randomGen.randInit()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	lb.shuffleFunc = func(res *dns.Msg) *dns.Msg {
 | 
					
						
							|  |  |  | 		return weightedShuffle(res, lb.weighted)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	stopReloadChan := make(chan bool)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	lb.onStartUpFunc = func() error {
 | 
					
						
							|  |  |  | 		return weightedOnStartUp(lb.weighted, stopReloadChan)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	lb.onShutdownFunc = func() error {
 | 
					
						
							|  |  |  | 		// stop periodic weigh reload go routine
 | 
					
						
							|  |  |  | 		close(stopReloadChan)
 | 
					
						
							|  |  |  | 		return nil
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return lb
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Apply weighted round robin policy to the answer
 | 
					
						
							|  |  |  | func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR {
 | 
					
						
							|  |  |  | 	cname := []dns.RR{}
 | 
					
						
							|  |  |  | 	address := []dns.RR{}
 | 
					
						
							|  |  |  | 	mx := []dns.RR{}
 | 
					
						
							|  |  |  | 	rest := []dns.RR{}
 | 
					
						
							|  |  |  | 	for _, r := range in {
 | 
					
						
							|  |  |  | 		switch r.Header().Rrtype {
 | 
					
						
							|  |  |  | 		case dns.TypeCNAME:
 | 
					
						
							|  |  |  | 			cname = append(cname, r)
 | 
					
						
							|  |  |  | 		case dns.TypeA, dns.TypeAAAA:
 | 
					
						
							|  |  |  | 			address = append(address, r)
 | 
					
						
							|  |  |  | 		case dns.TypeMX:
 | 
					
						
							|  |  |  | 			mx = append(mx, r)
 | 
					
						
							|  |  |  | 		default:
 | 
					
						
							|  |  |  | 			rest = append(rest, r)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if len(address) == 0 {
 | 
					
						
							|  |  |  | 		// no change
 | 
					
						
							|  |  |  | 		return in
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	w.setTopRecord(address)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	out := append(cname, rest...)
 | 
					
						
							|  |  |  | 	out = append(out, address...)
 | 
					
						
							|  |  |  | 	out = append(out, mx...)
 | 
					
						
							|  |  |  | 	return out
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Move the next expected address to the first position in the result list
 | 
					
						
							|  |  |  | func (w *weightedRR) setTopRecord(address []dns.RR) {
 | 
					
						
							|  |  |  | 	itop := w.topAddressIndex(address)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if itop < 0 {
 | 
					
						
							|  |  |  | 		// internal error
 | 
					
						
							|  |  |  | 		return
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if itop != 0 {
 | 
					
						
							|  |  |  | 		// swap the selected top entry with the actual one
 | 
					
						
							|  |  |  | 		address[0], address[itop] = address[itop], address[0]
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Compute the top (first) address index
 | 
					
						
							|  |  |  | func (w *weightedRR) topAddressIndex(address []dns.RR) int {
 | 
					
						
							|  |  |  | 	w.mutex.Lock()
 | 
					
						
							|  |  |  | 	defer w.mutex.Unlock()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Dertermine the weight value for each address in the answer
 | 
					
						
							|  |  |  | 	var wsum uint
 | 
					
						
							|  |  |  | 	type waddress struct {
 | 
					
						
							|  |  |  | 		index  int
 | 
					
						
							|  |  |  | 		weight uint8
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	weightedAddr := make([]waddress, len(address))
 | 
					
						
							|  |  |  | 	for i, ar := range address {
 | 
					
						
							|  |  |  | 		wa := &weightedAddr[i]
 | 
					
						
							|  |  |  | 		wa.index = i
 | 
					
						
							|  |  |  | 		wa.weight = 1 // default weight
 | 
					
						
							|  |  |  | 		var ip net.IP
 | 
					
						
							|  |  |  | 		switch ar.Header().Rrtype {
 | 
					
						
							|  |  |  | 		case dns.TypeA:
 | 
					
						
							|  |  |  | 			ip = ar.(*dns.A).A
 | 
					
						
							|  |  |  | 		case dns.TypeAAAA:
 | 
					
						
							|  |  |  | 			ip = ar.(*dns.AAAA).AAAA
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		ws := w.domains[ar.Header().Name]
 | 
					
						
							|  |  |  | 		for _, w := range ws {
 | 
					
						
							|  |  |  | 			if w.address.Equal(ip) {
 | 
					
						
							|  |  |  | 				wa.weight = w.value
 | 
					
						
							|  |  |  | 				break
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		wsum += uint(wa.weight)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Select the first (top) IP
 | 
					
						
							|  |  |  | 	sort.Slice(weightedAddr, func(i, j int) bool {
 | 
					
						
							|  |  |  | 		return weightedAddr[i].weight > weightedAddr[j].weight
 | 
					
						
							|  |  |  | 	})
 | 
					
						
							|  |  |  | 	v := w.randUint(wsum)
 | 
					
						
							|  |  |  | 	var psum uint
 | 
					
						
							|  |  |  | 	for _, wa := range weightedAddr {
 | 
					
						
							|  |  |  | 		psum += uint(wa.weight)
 | 
					
						
							|  |  |  | 		if v < psum {
 | 
					
						
							|  |  |  | 			return int(wa.index)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// we should never reach this
 | 
					
						
							| 
									
										
										
										
											2023-02-10 22:39:55 +08:00
										 |  |  | 	log.Errorf("Internal error: cannot find top address (randv:%v wsum:%v)", v, wsum)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 	return -1
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Start go routine to update weights from the weight file periodically
 | 
					
						
							|  |  |  | func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) {
 | 
					
						
							|  |  |  | 	if w.reload == 0 {
 | 
					
						
							|  |  |  | 		return
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	go func() {
 | 
					
						
							|  |  |  | 		ticker := time.NewTicker(w.reload)
 | 
					
						
							|  |  |  | 		for {
 | 
					
						
							|  |  |  | 			select {
 | 
					
						
							|  |  |  | 			case <-stopReload:
 | 
					
						
							|  |  |  | 				return
 | 
					
						
							|  |  |  | 			case <-ticker.C:
 | 
					
						
							|  |  |  | 				err := w.updateWeights()
 | 
					
						
							|  |  |  | 				if err != nil {
 | 
					
						
							|  |  |  | 					log.Error(err)
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}()
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Update weights from weight file
 | 
					
						
							|  |  |  | func (w *weightedRR) updateWeights() error {
 | 
					
						
							|  |  |  | 	reader, err := os.Open(filepath.Clean(w.fileName))
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return errOpen
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	defer reader.Close()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// check if the contents has changed
 | 
					
						
							|  |  |  | 	var buf bytes.Buffer
 | 
					
						
							|  |  |  | 	tee := io.TeeReader(reader, &buf)
 | 
					
						
							|  |  |  | 	bytes, err := io.ReadAll(tee)
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	md5sum := md5.Sum(bytes)
 | 
					
						
							|  |  |  | 	if md5sum == w.md5sum {
 | 
					
						
							|  |  |  | 		// file contents has not changed
 | 
					
						
							|  |  |  | 		return nil
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	w.md5sum = md5sum
 | 
					
						
							|  |  |  | 	scanner := bufio.NewScanner(&buf)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Parse the weight file contents
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 	domains, err := w.parseWeights(scanner)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 	// access to weights must be protected
 | 
					
						
							|  |  |  | 	w.mutex.Lock()
 | 
					
						
							|  |  |  | 	w.domains = domains
 | 
					
						
							|  |  |  | 	w.mutex.Unlock()
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 	log.Infof("Successfully reloaded weight file %s", w.fileName)
 | 
					
						
							|  |  |  | 	return nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Parse the weight file contents
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | func (w *weightedRR) parseWeights(scanner *bufio.Scanner) (map[string]weights, error) {
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 	var dname string
 | 
					
						
							|  |  |  | 	var ws weights
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 	domains := make(map[string]weights)
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 	for scanner.Scan() {
 | 
					
						
							|  |  |  | 		nextLine := strings.TrimSpace(scanner.Text())
 | 
					
						
							|  |  |  | 		if len(nextLine) == 0 || nextLine[0:1] == "#" {
 | 
					
						
							|  |  |  | 			// Empty and comment lines are ignored
 | 
					
						
							|  |  |  | 			continue
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		fields := strings.Fields(nextLine)
 | 
					
						
							|  |  |  | 		switch len(fields) {
 | 
					
						
							|  |  |  | 		case 1:
 | 
					
						
							|  |  |  | 			// (domain) name sanity check
 | 
					
						
							|  |  |  | 			if net.ParseIP(fields[0]) != nil {
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 				return nil, fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)",
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 					fields[0], w.fileName)
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 			dname = fields[0]
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// add the root domain if it is missing
 | 
					
						
							|  |  |  | 			if dname[len(dname)-1] != '.' {
 | 
					
						
							|  |  |  | 				dname += "."
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 			var ok bool
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 			ws, ok = domains[dname]
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 			if !ok {
 | 
					
						
							|  |  |  | 				ws = make(weights, 0)
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 				domains[dname] = ws
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 			}
 | 
					
						
							|  |  |  | 		case 2:
 | 
					
						
							|  |  |  | 			// IP address and weight value
 | 
					
						
							|  |  |  | 			ip := net.ParseIP(fields[0])
 | 
					
						
							|  |  |  | 			if ip == nil {
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 				return nil, fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 			}
 | 
					
						
							|  |  |  | 			weight, err := strconv.ParseUint(fields[1], 10, 8)
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 			if err != nil || weight == 0 {
 | 
					
						
							|  |  |  | 				return nil, fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 			}
 | 
					
						
							|  |  |  | 			witem := &weightItem{address: ip, value: uint8(weight)}
 | 
					
						
							|  |  |  | 			if dname == "" {
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 				return nil, fmt.Errorf("Missing domain name in weight file %s", w.fileName)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 			}
 | 
					
						
							|  |  |  | 			ws = append(ws, witem)
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 			domains[dname] = ws
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 		default:
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 			return nil, fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err := scanner.Err(); err != nil {
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 		return nil, fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err)
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-15 18:43:19 +01:00
										 |  |  | 	return domains, nil
 | 
					
						
							| 
									
										
										
										
											2023-01-27 17:36:56 +01:00
										 |  |  | }
 |