mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-30 09:43:17 -04:00 
			
		
		
		
	plugin/loadbalance: support prefer option (#7433)
Signed-off-by: Olli Janatuinen <olli.janatuinen@gmail.com>
This commit is contained in:
		| @@ -2,7 +2,7 @@ | ||||
|  | ||||
| ## Name | ||||
|  | ||||
| *loadbalance* - randomizes the order of A, AAAA and MX records. | ||||
| *loadbalance* - randomizes the order of A, AAAA and MX records  and optionally prefers specific subnets. | ||||
|  | ||||
| ## Description | ||||
|  | ||||
| @@ -18,6 +18,7 @@ implementations (like glibc) are particular about that. | ||||
| ~~~ | ||||
| loadbalance [round_robin | weighted WEIGHTFILE] { | ||||
| 			reload DURATION | ||||
| 			prefer CIDR [CIDR...] | ||||
| } | ||||
| ~~~ | ||||
| * `round_robin` policy randomizes the order of  A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy. | ||||
| @@ -26,6 +27,8 @@ loadbalance [round_robin | weighted WEIGHTFILE] { | ||||
| (top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record | ||||
| returned in the answer. | ||||
|  | ||||
| Additionally, the plugin supports subnet-based ordering using the `prefer` directive, which reorders A/AAAA records so that IPs from preferred subnets appear first. | ||||
|  | ||||
|  * **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section. | ||||
|  | ||||
|  * **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload. | ||||
| @@ -88,3 +91,17 @@ www.example.com | ||||
| 100.64.1.3 2 | ||||
| ~~~ | ||||
|  | ||||
| ### Subnet Prioritization | ||||
|  | ||||
| Prioritize IPs from 10.9.20.0/24 and 192.168.1.0/24: | ||||
|  | ||||
| ```corefile | ||||
| . { | ||||
|     loadbalance round_robin { | ||||
|         prefer 10.9.20.0/24 192.168.1.0/24 | ||||
|     } | ||||
|     forward . 1.1.1.1 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| If the DNS response includes multiple A/AAAA records, the plugin will reorder them to place the ones matching preferred subnets first. | ||||
|   | ||||
							
								
								
									
										76
									
								
								plugin/loadbalance/prefer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								plugin/loadbalance/prefer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | ||||
| package loadbalance | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| func reorderPreferredSubnets(msg *dns.Msg, subnets []*net.IPNet) *dns.Msg { | ||||
| 	msg.Answer = reorderRecords(msg.Answer, subnets) | ||||
| 	msg.Extra = reorderRecords(msg.Extra, subnets) | ||||
| 	return msg | ||||
| } | ||||
|  | ||||
| func reorderRecords(records []dns.RR, subnets []*net.IPNet) []dns.RR { | ||||
| 	var cname, address, mx, rest []dns.RR | ||||
|  | ||||
| 	for _, r := range records { | ||||
| 		switch r.Header().Rrtype { | ||||
| 		case dns.TypeCNAME: | ||||
| 			cname = append(cname, r) | ||||
| 		case dns.TypeA, dns.TypeAAAA: | ||||
| 			address = append(address, r) | ||||
| 		case dns.TypeMX: | ||||
| 			mx = append(mx, r) | ||||
| 		default: | ||||
| 			rest = append(rest, r) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	sorted := sortBySubnetPriority(address, subnets) | ||||
|  | ||||
| 	out := append([]dns.RR{}, cname...) | ||||
| 	out = append(out, sorted...) | ||||
| 	out = append(out, mx...) | ||||
| 	out = append(out, rest...) | ||||
| 	return out | ||||
| } | ||||
|  | ||||
| func sortBySubnetPriority(records []dns.RR, subnets []*net.IPNet) []dns.RR { | ||||
| 	matched := make([]dns.RR, 0, len(records)) | ||||
| 	seen := make(map[int]bool) | ||||
|  | ||||
| 	for _, subnet := range subnets { | ||||
| 		for i, r := range records { | ||||
| 			if seen[i] { | ||||
| 				continue | ||||
| 			} | ||||
| 			ip := extractIP(r) | ||||
| 			if ip != nil && subnet.Contains(ip) { | ||||
| 				matched = append(matched, r) | ||||
| 				seen[i] = true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	unmatched := make([]dns.RR, 0, len(records)-len(matched)) | ||||
| 	for i, r := range records { | ||||
| 		if !seen[i] { | ||||
| 			unmatched = append(unmatched, r) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return append(matched, unmatched...) | ||||
| } | ||||
|  | ||||
| func extractIP(rr dns.RR) net.IP { | ||||
| 	switch r := rr.(type) { | ||||
| 	case *dns.A: | ||||
| 		return r.A | ||||
| 	case *dns.AAAA: | ||||
| 		return r.AAAA | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										96
									
								
								plugin/loadbalance/prefer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								plugin/loadbalance/prefer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | ||||
| package loadbalance | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/coredns/coredns/plugin/test" | ||||
|  | ||||
| 	"github.com/miekg/dns" | ||||
| ) | ||||
|  | ||||
| func TestSortPreferred(t *testing.T) { | ||||
| 	records := []dns.RR{ | ||||
| 		test.A("example.org. 300 IN A 10.9.30.1"), | ||||
| 		test.A("example.org. 300 IN A 10.9.20.5"), | ||||
| 		test.A("example.org. 300 IN A 192.168.1.2"), | ||||
| 		test.A("example.org. 300 IN A 10.10.0.1"), | ||||
| 		test.A("example.org. 300 IN A 10.9.20.3"), | ||||
| 		test.A("example.org. 300 IN A 172.16.0.1"), | ||||
| 		test.AAAA("example.org. 300 IN AAAA 2001:db8::1"), | ||||
| 		test.AAAA("example.org. 300 IN AAAA 2001:db8:abcd::1"), | ||||
| 		test.AAAA("example.org. 300 IN AAAA fd00::1"), | ||||
| 		test.CNAME("example.org. 300 IN CNAME alias.example.org."), | ||||
| 	} | ||||
|  | ||||
| 	subnets := []*net.IPNet{} | ||||
| 	cidrs := []string{"2001:db8::/32", "10.9.20.0/24", "10.9.30.0/24"} | ||||
| 	for _, cidr := range cidrs { | ||||
| 		_, subnet, err := net.ParseCIDR(cidr) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Failed to parse CIDR: %v", err) | ||||
| 		} | ||||
| 		subnets = append(subnets, subnet) | ||||
| 	} | ||||
|  | ||||
| 	msg := &dns.Msg{Answer: records} | ||||
| 	reorderPreferredSubnets(msg, subnets) | ||||
| 	sorted := msg.Answer | ||||
|  | ||||
| 	expectedOrder := []string{ | ||||
| 		"alias.example.org.", | ||||
| 		"2001:db8::1", | ||||
| 		"2001:db8:abcd::1", | ||||
| 		"10.9.20.5", | ||||
| 		"10.9.20.3", | ||||
| 		"10.9.30.1", | ||||
| 		"192.168.1.2", | ||||
| 		"10.10.0.1", | ||||
| 		"172.16.0.1", | ||||
| 		"fd00::1", | ||||
| 	} | ||||
|  | ||||
| 	if len(sorted) != len(expectedOrder) { | ||||
| 		t.Fatalf("Expected %d records, got %d", len(expectedOrder), len(sorted)) | ||||
| 	} | ||||
|  | ||||
| 	for i, rr := range sorted { | ||||
| 		expected := expectedOrder[i] | ||||
| 		switch r := rr.(type) { | ||||
| 		case *dns.CNAME: | ||||
| 			if r.Target != expected { | ||||
| 				t.Errorf("Record %d: expected CNAME %s, got %s", i, expected, r.Target) | ||||
| 			} | ||||
| 		case *dns.A: | ||||
| 			if r.A.String() != expected { | ||||
| 				t.Errorf("Record %d: expected A IP %s, got %s", i, expected, r.A.String()) | ||||
| 			} | ||||
| 		case *dns.AAAA: | ||||
| 			if r.AAAA.String() != expected { | ||||
| 				t.Errorf("Record %d: expected AAAA IP %s, got %s", i, expected, r.AAAA.String()) | ||||
| 			} | ||||
| 		default: | ||||
| 			t.Errorf("Record %d: unexpected RR type %T", i, r) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExtractIP(t *testing.T) { | ||||
| 	a := test.A("example.org. 300 IN A 10.0.0.1") | ||||
| 	ip := extractIP(a) | ||||
| 	if ip.String() != "10.0.0.1" { | ||||
| 		t.Errorf("Expected 10.0.0.1, got %s", ip.String()) | ||||
| 	} | ||||
|  | ||||
| 	aaaa := test.AAAA("example.org. 300 IN AAAA ::1") | ||||
| 	ip = extractIP(aaaa) | ||||
| 	if ip.String() != "::1" { | ||||
| 		t.Errorf("Expected ::1, got %s", ip.String()) | ||||
| 	} | ||||
|  | ||||
| 	cname := test.CNAME("example.org. 300 IN CNAME other.org.") | ||||
| 	ip = extractIP(cname) | ||||
| 	if ip != nil { | ||||
| 		t.Errorf("Expected nil for CNAME, got %v", ip) | ||||
| 	} | ||||
| } | ||||
| @@ -3,6 +3,7 @@ package loadbalance | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"path/filepath" | ||||
| 	"time" | ||||
|  | ||||
| @@ -24,6 +25,7 @@ type lbFuncs struct { | ||||
| 	onStartUpFunc  func() error | ||||
| 	onShutdownFunc func() error | ||||
| 	weighted       *weightedRR // used in unit tests only | ||||
| 	preferSubnets  []*net.IPNet | ||||
| } | ||||
|  | ||||
| func setup(c *caddy.Controller) error { | ||||
| @@ -39,65 +41,90 @@ func setup(c *caddy.Controller) error { | ||||
| 		c.OnShutdown(lb.onShutdownFunc) | ||||
| 	} | ||||
|  | ||||
| 	shuffle := lb.shuffleFunc | ||||
| 	if len(lb.preferSubnets) > 0 { | ||||
| 		original := shuffle | ||||
| 		shuffle = func(res *dns.Msg) *dns.Msg { | ||||
| 			return reorderPreferredSubnets(original(res), lb.preferSubnets) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { | ||||
| 		return LoadBalance{Next: next, shuffle: lb.shuffleFunc} | ||||
| 		return LoadBalance{Next: next, shuffle: shuffle} | ||||
| 	}) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // func parse(c *caddy.Controller) (string, *weightedRR, error) { | ||||
| func parse(c *caddy.Controller) (*lbFuncs, error) { | ||||
| 	config := dnsserver.GetConfig(c) | ||||
| 	lb := &lbFuncs{} | ||||
|  | ||||
| 	for c.Next() { | ||||
| 		args := c.RemainingArgs() | ||||
| 		if len(args) == 0 { | ||||
| 			return &lbFuncs{shuffleFunc: randomShuffle}, nil | ||||
| 		} | ||||
| 		switch args[0] { | ||||
| 		case ramdomShufflePolicy: | ||||
| 			if len(args) > 1 { | ||||
| 				return nil, c.Errf("unknown property for %s", args[0]) | ||||
| 			} | ||||
| 			return &lbFuncs{shuffleFunc: randomShuffle}, nil | ||||
| 		case weightedRoundRobinPolicy: | ||||
| 			if len(args) < 2 { | ||||
| 				return nil, c.Err("missing weight file argument") | ||||
| 			} | ||||
|  | ||||
| 			if len(args) > 2 { | ||||
| 				return nil, c.Err("unexpected argument(s)") | ||||
| 			} | ||||
|  | ||||
| 			weightFileName := args[1] | ||||
| 			if !filepath.IsAbs(weightFileName) && config.Root != "" { | ||||
| 				weightFileName = filepath.Join(config.Root, weightFileName) | ||||
| 			} | ||||
| 			reload := 30 * time.Second // default reload period | ||||
| 			for c.NextBlock() { | ||||
| 				switch c.Val() { | ||||
| 				case "reload": | ||||
| 					t := c.RemainingArgs() | ||||
| 					if len(t) < 1 { | ||||
| 						return nil, c.Err("reload duration value is missing") | ||||
| 					} | ||||
| 					if len(t) > 1 { | ||||
| 						return nil, c.Err("unexpected argument") | ||||
| 					} | ||||
| 					var err error | ||||
| 					reload, err = time.ParseDuration(t[0]) | ||||
| 					if err != nil { | ||||
| 						return nil, c.Errf("invalid reload duration '%s'", t[0]) | ||||
| 					} | ||||
| 				default: | ||||
| 					return nil, c.Errf("unknown property '%s'", c.Val()) | ||||
| 			lb.shuffleFunc = randomShuffle | ||||
| 		} else { | ||||
| 			switch args[0] { | ||||
| 			case ramdomShufflePolicy: | ||||
| 				if len(args) > 1 { | ||||
| 					return nil, c.Errf("unknown property for %s", args[0]) | ||||
| 				} | ||||
| 				lb.shuffleFunc = randomShuffle | ||||
|  | ||||
| 			case weightedRoundRobinPolicy: | ||||
| 				if len(args) < 2 { | ||||
| 					return nil, c.Err("missing weight file argument") | ||||
| 				} | ||||
| 				if len(args) > 2 { | ||||
| 					return nil, c.Err("unexpected argument(s)") | ||||
| 				} | ||||
| 				weightFileName := args[1] | ||||
| 				if !filepath.IsAbs(weightFileName) && config.Root != "" { | ||||
| 					weightFileName = filepath.Join(config.Root, weightFileName) | ||||
| 				} | ||||
| 				reload := 30 * time.Second | ||||
| 				for c.NextBlock() { | ||||
| 					switch c.Val() { | ||||
| 					case "reload": | ||||
| 						t := c.RemainingArgs() | ||||
| 						if len(t) < 1 { | ||||
| 							return nil, c.Err("reload duration value is missing") | ||||
| 						} | ||||
| 						if len(t) > 1 { | ||||
| 							return nil, c.Err("unexpected argument") | ||||
| 						} | ||||
| 						var err error | ||||
| 						reload, err = time.ParseDuration(t[0]) | ||||
| 						if err != nil { | ||||
| 							return nil, c.Errf("invalid reload duration '%s'", t[0]) | ||||
| 						} | ||||
| 					default: | ||||
| 						return nil, c.Errf("unknown property '%s'", c.Val()) | ||||
| 					} | ||||
| 				} | ||||
| 				*lb = *createWeightedFuncs(weightFileName, reload) | ||||
| 			default: | ||||
| 				return nil, fmt.Errorf("unknown policy: %s", args[0]) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		for c.NextBlock() { | ||||
| 			switch c.Val() { | ||||
| 			case "prefer": | ||||
| 				cidrs := c.RemainingArgs() | ||||
| 				for _, cidr := range cidrs { | ||||
| 					_, subnet, err := net.ParseCIDR(cidr) | ||||
| 					if err != nil { | ||||
| 						return nil, c.Errf("invalid CIDR %q: %v", cidr, err) | ||||
| 					} | ||||
| 					lb.preferSubnets = append(lb.preferSubnets, subnet) | ||||
| 				} | ||||
| 			default: | ||||
| 				return nil, c.Errf("unknown property '%s'", c.Val()) | ||||
| 			} | ||||
| 			return createWeightedFuncs(weightFileName, reload), nil | ||||
| 		default: | ||||
| 			return nil, fmt.Errorf("unknown policy: %s", args[0]) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, c.ArgErr() | ||||
|  | ||||
| 	return lb, nil | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user