fix(grpc): check proxy list length in policies (#7512)

This commit is contained in:
Ville Vesilehto
2025-09-04 02:24:44 +03:00
committed by GitHub
parent abef207695
commit 066e51675c
2 changed files with 133 additions and 0 deletions

View File

@@ -20,6 +20,8 @@ func (r *random) String() string { return "random" }
func (r *random) List(p []*Proxy) []*Proxy {
switch len(p) {
case 0:
return nil
case 1:
return p
case 2:
@@ -46,6 +48,9 @@ type roundRobin struct {
func (r *roundRobin) String() string { return "round_robin" }
func (r *roundRobin) List(p []*Proxy) []*Proxy {
if len(p) == 0 {
return nil
}
poolLen := uint32(len(p))
i := atomic.AddUint32(&r.robin, 1) % poolLen

128
plugin/grpc/policy_test.go Normal file
View File

@@ -0,0 +1,128 @@
package grpc
import (
"testing"
)
func TestRoundRobinEmpty(t *testing.T) {
t.Parallel()
r := &roundRobin{}
got := r.List(nil)
if len(got) != 0 {
t.Fatalf("expected length 0, got %d", len(got))
}
}
func TestRandomEmpty(t *testing.T) {
t.Parallel()
r := &random{}
got := r.List(nil)
if len(got) != 0 {
t.Fatalf("expected length 0, got %d", len(got))
}
}
func TestSequentialEmpty(t *testing.T) {
t.Parallel()
r := &sequential{}
got := r.List(nil)
if len(got) != 0 {
t.Fatalf("expected length 0, got %d", len(got))
}
}
func TestPoliciesOrdering(t *testing.T) {
t.Parallel()
p0 := &Proxy{addr: "p0"}
p1 := &Proxy{addr: "p1"}
p2 := &Proxy{addr: "p2"}
in := []*Proxy{p0, p1, p2}
t.Run("sequential keeps order", func(t *testing.T) {
t.Parallel()
r := &sequential{}
got := r.List(in)
if len(got) != len(in) {
t.Fatalf("expected length %d, got %d", len(in), len(got))
}
for i := range in {
if got[i] != in[i] {
t.Fatalf("sequential order changed at %d: want %p, got %p", i, in[i], got[i])
}
}
})
t.Run("round robin advances and permutation", func(t *testing.T) {
t.Parallel()
r := &roundRobin{}
got1 := r.List(in)
if !isPermutation(in, got1) {
t.Fatalf("first call: expected permutation of input")
}
if got1[0] != p1 {
t.Fatalf("first element should advance to p1, got %p", got1[0])
}
got2 := r.List(in)
if !isPermutation(in, got2) {
t.Fatalf("second call: expected permutation of input")
}
if got2[0] != p2 {
t.Fatalf("first element should advance to p2 on second call, got %p", got2[0])
}
got3 := r.List(in)
if !isPermutation(in, got3) {
t.Fatalf("third call: expected permutation of input")
}
if got3[0] != p0 {
t.Fatalf("first element should wrap to p0 on third call, got %p", got3[0])
}
})
t.Run("random is a permutation", func(t *testing.T) {
t.Parallel()
r := &random{}
got := r.List(in)
if !isPermutation(in, got) {
t.Fatalf("random did not return a permutation of input")
}
})
t.Run("random with two proxies", func(t *testing.T) {
t.Parallel()
r := &random{}
in2 := []*Proxy{p0, p1}
got := r.List(in2)
if !isPermutation(in2, got) {
t.Fatalf("random did not return a permutation of input")
}
})
}
// Helper: returns true if b is a permutation of a (same multiset of pointers).
func isPermutation(a, b []*Proxy) bool {
if len(a) != len(b) {
return false
}
count := make(map[*Proxy]int, len(a))
for _, p := range a {
count[p]++
}
for _, p := range b {
count[p]--
if count[p] < 0 {
return false
}
}
return true
}