mirror of
https://github.com/coredns/coredns.git
synced 2025-10-27 08:14:18 -04:00
plugin/loadbalance: support prefer option (#7433)
Signed-off-by: Olli Janatuinen <olli.janatuinen@gmail.com>
This commit is contained in:
76
plugin/loadbalance/prefer.go
Normal file
76
plugin/loadbalance/prefer.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg {
|
||||
msg.Answer = reorderRecords(msg.Answer, subnets)
|
||||
msg.Extra = reorderRecords(msg.Extra, subnets)
|
||||
return msg
|
||||
}
|
||||
|
||||
func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR {
|
||||
var cname, address, mx, rest []dns.RR
|
||||
|
||||
for _, r := range records {
|
||||
switch r.Header().Rrtype {
|
||||
case dns.TypeCNAME:
|
||||
cname = append(cname, r)
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
address = append(address, r)
|
||||
case dns.TypeMX:
|
||||
mx = append(mx, r)
|
||||
default:
|
||||
rest = append(rest, r)
|
||||
}
|
||||
}
|
||||
|
||||
sorted := sortBySubnetPriority(address, subnets)
|
||||
|
||||
out := append([]dns.RR{}, cname...)
|
||||
out = append(out, sorted...)
|
||||
out = append(out, mx...)
|
||||
out = append(out, rest...)
|
||||
return out
|
||||
}
|
||||
|
||||
func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR {
|
||||
matched := make([]dns.RR, 0, len(records))
|
||||
seen := make(map[int]bool)
|
||||
|
||||
for _, subnet := range subnets {
|
||||
for i, r := range records {
|
||||
if seen[i] {
|
||||
continue
|
||||
}
|
||||
ip := extractIP(r)
|
||||
if ip != nil && subnet.Contains(ip) {
|
||||
matched = append(matched, r)
|
||||
seen[i] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unmatched := make([]dns.RR, 0, len(records)-len(matched))
|
||||
for i, r := range records {
|
||||
if !seen[i] {
|
||||
unmatched = append(unmatched, r)
|
||||
}
|
||||
}
|
||||
|
||||
return append(matched, unmatched...)
|
||||
}
|
||||
|
||||
func extractIP(rr dns.RR) net.IP {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
return r.A
|
||||
case *dns.AAAA:
|
||||
return r.AAAA
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user