Prevent fail counter of a proxy overflows (#5990)

Signed-off-by: vanceli <vanceli@tencent.com>
Signed-off-by: Vance Li <vncl@YingyingM1.local>
Co-authored-by: vanceli <vanceli@tencent.com>
This commit is contained in:
Vancl
2023-04-16 22:08:56 +08:00
committed by GitHub
parent 8e8231d627
commit 7db1d4f6e9
3 changed files with 42 additions and 1 deletions

View File

@@ -105,7 +105,7 @@ func (h *dnsHc) Check(p *Proxy) error {
err := h.send(p.addr) err := h.send(p.addr)
if err != nil { if err != nil {
HealthcheckFailureCount.WithLabelValues(p.addr).Add(1) HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
atomic.AddUint32(&p.fails, 1) p.incrementFails()
return err return err
} }

View File

@@ -93,6 +93,16 @@ func (p *Proxy) SetReadTimeout(duration time.Duration) {
p.readTimeout = duration p.readTimeout = duration
} }
// incrementFails increments the number of fails safely.
func (p *Proxy) incrementFails() {
curVal := atomic.LoadUint32(&p.fails)
if curVal > curVal+1 {
// overflow occurred, do not update the counter again
return
}
atomic.AddUint32(&p.fails, 1)
}
const ( const (
maxTimeout = 2 * time.Second maxTimeout = 2 * time.Second
) )

View File

@@ -3,6 +3,7 @@ package proxy
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"math"
"testing" "testing"
"time" "time"
@@ -97,3 +98,33 @@ func TestProtocolSelection(t *testing.T) {
} }
} }
} }
func TestProxyIncrementFails(t *testing.T) {
var testCases = []struct {
name string
fails uint32
expectFails uint32
}{
{
name: "increment fails counter overflows",
fails: math.MaxUint32,
expectFails: math.MaxUint32,
},
{
name: "increment fails counter",
fails: 0,
expectFails: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
p := NewProxy("bad_address", transport.DNS)
p.fails = tc.fails
p.incrementFails()
if p.fails != tc.expectFails {
t.Errorf("Expected fails to be %d, got %d", tc.expectFails, p.fails)
}
})
}
}