mirror of
				https://github.com/coredns/coredns.git
				synced 2025-11-03 18:53:13 -05:00 
			
		
		
		
	
		
			
				
	
	
		
			77 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			77 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package loadbalance
 | 
						|
 | 
						|
import (
 | 
						|
	"net"
 | 
						|
 | 
						|
	"github.com/miekg/dns"
 | 
						|
)
 | 
						|
 | 
						|
func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg {
 | 
						|
	msg.Answer = reorderRecords(msg.Answer, subnets)
 | 
						|
	msg.Extra = reorderRecords(msg.Extra, subnets)
 | 
						|
	return msg
 | 
						|
}
 | 
						|
 | 
						|
func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR {
 | 
						|
	var cname, address, mx, rest []dns.RR
 | 
						|
 | 
						|
	for _, r := range records {
 | 
						|
		switch r.Header().Rrtype {
 | 
						|
		case dns.TypeCNAME:
 | 
						|
			cname = append(cname, r)
 | 
						|
		case dns.TypeA, dns.TypeAAAA:
 | 
						|
			address = append(address, r)
 | 
						|
		case dns.TypeMX:
 | 
						|
			mx = append(mx, r)
 | 
						|
		default:
 | 
						|
			rest = append(rest, r)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	sorted := sortBySubnetPriority(address, subnets)
 | 
						|
 | 
						|
	out := append([]dns.RR{}, cname...)
 | 
						|
	out = append(out, sorted...)
 | 
						|
	out = append(out, mx...)
 | 
						|
	out = append(out, rest...)
 | 
						|
	return out
 | 
						|
}
 | 
						|
 | 
						|
func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR {
 | 
						|
	matched := make([]dns.RR, 0, len(records))
 | 
						|
	seen := make(map[int]bool)
 | 
						|
 | 
						|
	for _, subnet := range subnets {
 | 
						|
		for i, r := range records {
 | 
						|
			if seen[i] {
 | 
						|
				continue
 | 
						|
			}
 | 
						|
			ip := extractIP(r)
 | 
						|
			if ip != nil && subnet.Contains(ip) {
 | 
						|
				matched = append(matched, r)
 | 
						|
				seen[i] = true
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	unmatched := make([]dns.RR, 0, len(records)-len(matched))
 | 
						|
	for i, r := range records {
 | 
						|
		if !seen[i] {
 | 
						|
			unmatched = append(unmatched, r)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return append(matched, unmatched...)
 | 
						|
}
 | 
						|
 | 
						|
func extractIP(rr dns.RR) net.IP {
 | 
						|
	switch r := rr.(type) {
 | 
						|
	case *dns.A:
 | 
						|
		return r.A
 | 
						|
	case *dns.AAAA:
 | 
						|
		return r.AAAA
 | 
						|
	default:
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
}
 |