mirror of
https://github.com/coredns/coredns.git
synced 2025-11-01 10:43:17 -04:00
plugin/forward: Allow Proxy to be used outside of forward plugin. (#5951)
* plugin/forward: Move Proxy into pkg/plugin/proxy, to allow forward.Proxy to be used outside of forward plugin. Signed-off-by: Patrick Downey <patrick.downey@dioadconsulting.com>
This commit is contained in:
152
plugin/pkg/proxy/connect.go
Normal file
152
plugin/pkg/proxy/connect.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
|
||||
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
|
||||
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
|
||||
// inband healthchecking.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// limitTimeout is a utility function to auto-tune timeout values
|
||||
// average observed time is moved towards the last observed delay moderated by a weight
|
||||
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
||||
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
||||
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
if rt < minValue {
|
||||
return minValue
|
||||
}
|
||||
if rt < maxValue/2 {
|
||||
return 2 * rt
|
||||
}
|
||||
return maxValue
|
||||
}
|
||||
|
||||
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
|
||||
dt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||
}
|
||||
|
||||
func (t *Transport) dialTimeout() time.Duration {
|
||||
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||
}
|
||||
|
||||
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||
}
|
||||
|
||||
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||
func (t *Transport) Dial(proto string) (*persistConn, bool, error) {
|
||||
// If tls has been configured; use it.
|
||||
if t.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
}
|
||||
|
||||
t.dial <- proto
|
||||
pc := <-t.ret
|
||||
|
||||
if pc != nil {
|
||||
ConnCacheHitsCount.WithLabelValues(t.addr, proto).Add(1)
|
||||
return pc, true, nil
|
||||
}
|
||||
ConnCacheMissesCount.WithLabelValues(t.addr, proto).Add(1)
|
||||
|
||||
reqTime := time.Now()
|
||||
timeout := t.dialTimeout()
|
||||
if proto == "tcp-tls" {
|
||||
conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return &persistConn{c: conn}, false, err
|
||||
}
|
||||
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return &persistConn{c: conn}, false, err
|
||||
}
|
||||
|
||||
// Connect selects an upstream, sends the request and waits for a response.
|
||||
func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) {
|
||||
start := time.Now()
|
||||
|
||||
proto := ""
|
||||
switch {
|
||||
case opts.ForceTCP: // TCP flag has precedence over UDP flag
|
||||
proto = "tcp"
|
||||
case opts.PreferUDP:
|
||||
proto = "udp"
|
||||
default:
|
||||
proto = state.Proto()
|
||||
}
|
||||
|
||||
pc, cached, err := p.transport.Dial(proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set buffer size correctly for this client.
|
||||
pc.c.UDPSize = uint16(state.Size())
|
||||
if pc.c.UDPSize < 512 {
|
||||
pc.c.UDPSize = 512
|
||||
}
|
||||
|
||||
pc.c.SetWriteDeadline(time.Now().Add(maxTimeout))
|
||||
// records the origin Id before upstream.
|
||||
originId := state.Req.Id
|
||||
state.Req.Id = dns.Id()
|
||||
defer func() {
|
||||
state.Req.Id = originId
|
||||
}()
|
||||
|
||||
if err := pc.c.WriteMsg(state.Req); err != nil {
|
||||
pc.c.Close() // not giving it back
|
||||
if err == io.EOF && cached {
|
||||
return nil, ErrCachedClosed
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ret *dns.Msg
|
||||
pc.c.SetReadDeadline(time.Now().Add(p.readTimeout))
|
||||
for {
|
||||
ret, err = pc.c.ReadMsg()
|
||||
if err != nil {
|
||||
pc.c.Close() // not giving it back
|
||||
if err == io.EOF && cached {
|
||||
return nil, ErrCachedClosed
|
||||
}
|
||||
// recovery the origin Id after upstream.
|
||||
if ret != nil {
|
||||
ret.Id = originId
|
||||
}
|
||||
return ret, err
|
||||
}
|
||||
// drop out-of-order responses
|
||||
if state.Req.Id == ret.Id {
|
||||
break
|
||||
}
|
||||
}
|
||||
// recovery the origin Id after upstream.
|
||||
ret.Id = originId
|
||||
|
||||
p.transport.Yield(pc)
|
||||
|
||||
rc, ok := dns.RcodeToString[ret.Rcode]
|
||||
if !ok {
|
||||
rc = strconv.Itoa(ret.Rcode)
|
||||
}
|
||||
|
||||
RequestCount.WithLabelValues(p.addr).Add(1)
|
||||
RcodeCount.WithLabelValues(rc, p.addr).Add(1)
|
||||
RequestDuration.WithLabelValues(p.addr, rc).Observe(time.Since(start).Seconds())
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
const cumulativeAvgWeight = 4
|
||||
26
plugin/pkg/proxy/errors.go
Normal file
26
plugin/pkg/proxy/errors.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoHealthy means no healthy proxies left.
|
||||
ErrNoHealthy = errors.New("no healthy proxies")
|
||||
// ErrNoForward means no forwarder defined.
|
||||
ErrNoForward = errors.New("no forwarder defined")
|
||||
// ErrCachedClosed means cached connection was closed by peer.
|
||||
ErrCachedClosed = errors.New("cached connection was closed by peer")
|
||||
)
|
||||
|
||||
// Options holds various Options that can be set.
|
||||
type Options struct {
|
||||
// ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag
|
||||
ForceTCP bool
|
||||
// PreferUDP use UDP protocol for upstream DNS request.
|
||||
PreferUDP bool
|
||||
// HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests
|
||||
HCRecursionDesired bool
|
||||
// HCDomain sets domain for Proxy healthcheck requests
|
||||
HCDomain string
|
||||
}
|
||||
131
plugin/pkg/proxy/health.go
Normal file
131
plugin/pkg/proxy/health.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/log"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// HealthChecker checks the upstream health.
|
||||
type HealthChecker interface {
|
||||
Check(*Proxy) error
|
||||
SetTLSConfig(*tls.Config)
|
||||
GetTLSConfig() *tls.Config
|
||||
SetRecursionDesired(bool)
|
||||
GetRecursionDesired() bool
|
||||
SetDomain(domain string)
|
||||
GetDomain() string
|
||||
SetTCPTransport()
|
||||
GetReadTimeout() time.Duration
|
||||
SetReadTimeout(time.Duration)
|
||||
GetWriteTimeout() time.Duration
|
||||
SetWriteTimeout(time.Duration)
|
||||
}
|
||||
|
||||
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
|
||||
type dnsHc struct {
|
||||
c *dns.Client
|
||||
recursionDesired bool
|
||||
domain string
|
||||
}
|
||||
|
||||
// NewHealthChecker returns a new HealthChecker based on transport.
|
||||
func NewHealthChecker(trans string, recursionDesired bool, domain string) HealthChecker {
|
||||
switch trans {
|
||||
case transport.DNS, transport.TLS:
|
||||
c := new(dns.Client)
|
||||
c.Net = "udp"
|
||||
c.ReadTimeout = 1 * time.Second
|
||||
c.WriteTimeout = 1 * time.Second
|
||||
|
||||
return &dnsHc{
|
||||
c: c,
|
||||
recursionDesired: recursionDesired,
|
||||
domain: domain,
|
||||
}
|
||||
}
|
||||
|
||||
log.Warningf("No healthchecker for transport %q", trans)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
|
||||
h.c.Net = "tcp-tls"
|
||||
h.c.TLSConfig = cfg
|
||||
}
|
||||
|
||||
func (h *dnsHc) GetTLSConfig() *tls.Config {
|
||||
return h.c.TLSConfig
|
||||
}
|
||||
|
||||
func (h *dnsHc) SetRecursionDesired(recursionDesired bool) {
|
||||
h.recursionDesired = recursionDesired
|
||||
}
|
||||
func (h *dnsHc) GetRecursionDesired() bool {
|
||||
return h.recursionDesired
|
||||
}
|
||||
|
||||
func (h *dnsHc) SetDomain(domain string) {
|
||||
h.domain = domain
|
||||
}
|
||||
func (h *dnsHc) GetDomain() string {
|
||||
return h.domain
|
||||
}
|
||||
|
||||
func (h *dnsHc) SetTCPTransport() {
|
||||
h.c.Net = "tcp"
|
||||
}
|
||||
|
||||
func (h *dnsHc) GetReadTimeout() time.Duration {
|
||||
return h.c.ReadTimeout
|
||||
}
|
||||
|
||||
func (h *dnsHc) SetReadTimeout(t time.Duration) {
|
||||
h.c.ReadTimeout = t
|
||||
}
|
||||
|
||||
func (h *dnsHc) GetWriteTimeout() time.Duration {
|
||||
return h.c.WriteTimeout
|
||||
}
|
||||
|
||||
func (h *dnsHc) SetWriteTimeout(t time.Duration) {
|
||||
h.c.WriteTimeout = t
|
||||
}
|
||||
|
||||
// For HC, we send to . IN NS +[no]rec message to the upstream. Dial timeouts and empty
|
||||
// replies are considered fails, basically anything else constitutes a healthy upstream.
|
||||
|
||||
// Check is used as the up.Func in the up.Probe.
|
||||
func (h *dnsHc) Check(p *Proxy) error {
|
||||
err := h.send(p.addr)
|
||||
if err != nil {
|
||||
HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
|
||||
atomic.AddUint32(&p.fails, 1)
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&p.fails, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *dnsHc) send(addr string) error {
|
||||
ping := new(dns.Msg)
|
||||
ping.SetQuestion(h.domain, dns.TypeNS)
|
||||
ping.MsgHdr.RecursionDesired = h.recursionDesired
|
||||
|
||||
m, _, err := h.c.Exchange(ping, addr)
|
||||
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
|
||||
if err != nil && m != nil {
|
||||
// Silly check, something sane came back.
|
||||
if m.Response || m.Opcode == dns.OpcodeQuery {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
153
plugin/pkg/proxy/health_test.go
Normal file
153
plugin/pkg/proxy/health_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestHealth(t *testing.T) {
|
||||
i := uint32(0)
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if r.Question[0].Name == "." && r.RecursionDesired == true {
|
||||
atomic.AddUint32(&i, 1)
|
||||
}
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
hc := NewHealthChecker(transport.DNS, true, "")
|
||||
hc.SetReadTimeout(10 * time.Millisecond)
|
||||
hc.SetWriteTimeout(10 * time.Millisecond)
|
||||
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
err := hc.Check(p)
|
||||
if err != nil {
|
||||
t.Errorf("check failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
i1 := atomic.LoadUint32(&i)
|
||||
if i1 != 1 {
|
||||
t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTCP(t *testing.T) {
|
||||
i := uint32(0)
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if r.Question[0].Name == "." && r.RecursionDesired == true {
|
||||
atomic.AddUint32(&i, 1)
|
||||
}
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
hc := NewHealthChecker(transport.DNS, true, "")
|
||||
hc.SetTCPTransport()
|
||||
hc.SetReadTimeout(10 * time.Millisecond)
|
||||
hc.SetWriteTimeout(10 * time.Millisecond)
|
||||
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
err := hc.Check(p)
|
||||
if err != nil {
|
||||
t.Errorf("check failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
i1 := atomic.LoadUint32(&i)
|
||||
if i1 != 1 {
|
||||
t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthNoRecursion(t *testing.T) {
|
||||
i := uint32(0)
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if r.Question[0].Name == "." && r.RecursionDesired == false {
|
||||
atomic.AddUint32(&i, 1)
|
||||
}
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
hc := NewHealthChecker(transport.DNS, false, "")
|
||||
hc.SetReadTimeout(10 * time.Millisecond)
|
||||
hc.SetWriteTimeout(10 * time.Millisecond)
|
||||
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
err := hc.Check(p)
|
||||
if err != nil {
|
||||
t.Errorf("check failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
i1 := atomic.LoadUint32(&i)
|
||||
if i1 != 1 {
|
||||
t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTimeout(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// timeout
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
hc := NewHealthChecker(transport.DNS, false, "")
|
||||
hc.SetReadTimeout(10 * time.Millisecond)
|
||||
hc.SetWriteTimeout(10 * time.Millisecond)
|
||||
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
err := hc.Check(p)
|
||||
if err == nil {
|
||||
t.Errorf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthDomain(t *testing.T) {
|
||||
hcDomain := "example.org."
|
||||
|
||||
i := uint32(0)
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if r.Question[0].Name == hcDomain && r.RecursionDesired == true {
|
||||
atomic.AddUint32(&i, 1)
|
||||
}
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
hc := NewHealthChecker(transport.DNS, true, hcDomain)
|
||||
hc.SetReadTimeout(10 * time.Millisecond)
|
||||
hc.SetWriteTimeout(10 * time.Millisecond)
|
||||
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
err := hc.Check(p)
|
||||
if err != nil {
|
||||
t.Errorf("check failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(12 * time.Millisecond)
|
||||
i1 := atomic.LoadUint32(&i)
|
||||
if i1 != 1 {
|
||||
t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1)
|
||||
}
|
||||
}
|
||||
49
plugin/pkg/proxy/metrics.go
Normal file
49
plugin/pkg/proxy/metrics.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
// Variables declared for monitoring.
|
||||
var (
|
||||
RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "proxy",
|
||||
Name: "requests_total",
|
||||
Help: "Counter of requests made per upstream.",
|
||||
}, []string{"to"})
|
||||
RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "proxy",
|
||||
Name: "responses_total",
|
||||
Help: "Counter of responses received per upstream.",
|
||||
}, []string{"rcode", "to"})
|
||||
RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "proxy",
|
||||
Name: "request_duration_seconds",
|
||||
Buckets: plugin.TimeBuckets,
|
||||
Help: "Histogram of the time each request took.",
|
||||
}, []string{"to", "rcode"})
|
||||
HealthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "proxy",
|
||||
Name: "healthcheck_failures_total",
|
||||
Help: "Counter of the number of failed healthchecks.",
|
||||
}, []string{"to"})
|
||||
ConnCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "proxy",
|
||||
Name: "conn_cache_hits_total",
|
||||
Help: "Counter of connection cache hits per upstream and protocol.",
|
||||
}, []string{"to", "proto"})
|
||||
ConnCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "proxy",
|
||||
Name: "conn_cache_misses_total",
|
||||
Help: "Counter of connection cache misses per upstream and protocol.",
|
||||
}, []string{"to", "proto"})
|
||||
)
|
||||
156
plugin/pkg/proxy/persistent.go
Normal file
156
plugin/pkg/proxy/persistent.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// a persistConn hold the dns.Conn and the last used time.
|
||||
type persistConn struct {
|
||||
c *dns.Conn
|
||||
used time.Time
|
||||
}
|
||||
|
||||
// Transport hold the persistent cache.
|
||||
type Transport struct {
|
||||
avgDialTime int64 // kind of average time of dial time
|
||||
conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
expire time.Duration // After this duration a connection is expired.
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
dial chan string
|
||||
yield chan *persistConn
|
||||
ret chan *persistConn
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
func newTransport(addr string) *Transport {
|
||||
t := &Transport{
|
||||
avgDialTime: int64(maxDialTimeout / 2),
|
||||
conns: [typeTotalCount][]*persistConn{},
|
||||
expire: defaultExpire,
|
||||
addr: addr,
|
||||
dial: make(chan string),
|
||||
yield: make(chan *persistConn),
|
||||
ret: make(chan *persistConn),
|
||||
stop: make(chan bool),
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||
func (t *Transport) connManager() {
|
||||
ticker := time.NewTicker(defaultExpire)
|
||||
defer ticker.Stop()
|
||||
Wait:
|
||||
for {
|
||||
select {
|
||||
case proto := <-t.dial:
|
||||
transtype := stringToTransportType(proto)
|
||||
// take the last used conn - complexity O(1)
|
||||
if stack := t.conns[transtype]; len(stack) > 0 {
|
||||
pc := stack[len(stack)-1]
|
||||
if time.Since(pc.used) < t.expire {
|
||||
// Found one, remove from pool and return this conn.
|
||||
t.conns[transtype] = stack[:len(stack)-1]
|
||||
t.ret <- pc
|
||||
continue Wait
|
||||
}
|
||||
// clear entire cache if the last conn is expired
|
||||
t.conns[transtype] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
}
|
||||
t.ret <- nil
|
||||
|
||||
case pc := <-t.yield:
|
||||
transtype := t.transportTypeFromConn(pc)
|
||||
t.conns[transtype] = append(t.conns[transtype], pc)
|
||||
|
||||
case <-ticker.C:
|
||||
t.cleanup(false)
|
||||
|
||||
case <-t.stop:
|
||||
t.cleanup(true)
|
||||
close(t.ret)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConns closes connections.
|
||||
func closeConns(conns []*persistConn) {
|
||||
for _, pc := range conns {
|
||||
pc.c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes connections from cache.
|
||||
func (t *Transport) cleanup(all bool) {
|
||||
staleTime := time.Now().Add(-t.expire)
|
||||
for transtype, stack := range t.conns {
|
||||
if len(stack) == 0 {
|
||||
continue
|
||||
}
|
||||
if all {
|
||||
t.conns[transtype] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
continue
|
||||
}
|
||||
if stack[0].used.After(staleTime) {
|
||||
continue
|
||||
}
|
||||
|
||||
// connections in stack are sorted by "used"
|
||||
good := sort.Search(len(stack), func(i int) bool {
|
||||
return stack[i].used.After(staleTime)
|
||||
})
|
||||
t.conns[transtype] = stack[good:]
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack[:good])
|
||||
}
|
||||
}
|
||||
|
||||
// It is hard to pin a value to this, the import thing is to no block forever, losing at cached connection is not terrible.
|
||||
const yieldTimeout = 25 * time.Millisecond
|
||||
|
||||
// Yield returns the connection to transport for reuse.
|
||||
func (t *Transport) Yield(pc *persistConn) {
|
||||
pc.used = time.Now() // update used time
|
||||
|
||||
// Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This
|
||||
// blocks the outer go-routine and stuff will just pile up. We timeout when the send fails to as returning
|
||||
// these connection is an optimization anyway.
|
||||
select {
|
||||
case t.yield <- pc:
|
||||
return
|
||||
case <-time.After(yieldTimeout):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the transport's connection manager.
|
||||
func (t *Transport) Start() { go t.connManager() }
|
||||
|
||||
// Stop stops the transport's connection manager.
|
||||
func (t *Transport) Stop() { close(t.stop) }
|
||||
|
||||
// SetExpire sets the connection expire time in transport.
|
||||
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||
|
||||
// SetTLSConfig sets the TLS config in transport.
|
||||
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||
|
||||
const (
|
||||
defaultExpire = 10 * time.Second
|
||||
minDialTimeout = 1 * time.Second
|
||||
maxDialTimeout = 30 * time.Second
|
||||
)
|
||||
109
plugin/pkg/proxy/persistent_test.go
Normal file
109
plugin/pkg/proxy/persistent_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestCached(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
tr := newTransport(s.Addr)
|
||||
tr.Start()
|
||||
defer tr.Stop()
|
||||
|
||||
c1, cache1, _ := tr.Dial("udp")
|
||||
c2, cache2, _ := tr.Dial("udp")
|
||||
|
||||
if cache1 || cache2 {
|
||||
t.Errorf("Expected non-cached connection")
|
||||
}
|
||||
|
||||
tr.Yield(c1)
|
||||
tr.Yield(c2)
|
||||
c3, cached3, _ := tr.Dial("udp")
|
||||
if !cached3 {
|
||||
t.Error("Expected cached connection (c3)")
|
||||
}
|
||||
if c2 != c3 {
|
||||
t.Error("Expected c2 == c3")
|
||||
}
|
||||
|
||||
tr.Yield(c3)
|
||||
|
||||
// dial another protocol
|
||||
c4, cached4, _ := tr.Dial("tcp")
|
||||
if cached4 {
|
||||
t.Errorf("Expected non-cached connection (c4)")
|
||||
}
|
||||
tr.Yield(c4)
|
||||
}
|
||||
|
||||
func TestCleanupByTimer(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
tr := newTransport(s.Addr)
|
||||
tr.SetExpire(100 * time.Millisecond)
|
||||
tr.Start()
|
||||
defer tr.Stop()
|
||||
|
||||
c1, _, _ := tr.Dial("udp")
|
||||
c2, _, _ := tr.Dial("udp")
|
||||
tr.Yield(c1)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
tr.Yield(c2)
|
||||
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
c3, cached, _ := tr.Dial("udp")
|
||||
if cached {
|
||||
t.Error("Expected non-cached connection (c3)")
|
||||
}
|
||||
tr.Yield(c3)
|
||||
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
c4, cached, _ := tr.Dial("udp")
|
||||
if cached {
|
||||
t.Error("Expected non-cached connection (c4)")
|
||||
}
|
||||
tr.Yield(c4)
|
||||
}
|
||||
|
||||
func TestCleanupAll(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
tr := newTransport(s.Addr)
|
||||
|
||||
c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
|
||||
c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
|
||||
c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout)
|
||||
|
||||
tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}}
|
||||
|
||||
if len(tr.conns[typeUDP]) != 3 {
|
||||
t.Error("Expected 3 connections")
|
||||
}
|
||||
tr.cleanup(true)
|
||||
|
||||
if len(tr.conns[typeUDP]) > 0 {
|
||||
t.Error("Expected no cached connections")
|
||||
}
|
||||
}
|
||||
98
plugin/pkg/proxy/proxy.go
Normal file
98
plugin/pkg/proxy/proxy.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/log"
|
||||
"github.com/coredns/coredns/plugin/pkg/up"
|
||||
)
|
||||
|
||||
// Proxy defines an upstream host.
|
||||
type Proxy struct {
|
||||
fails uint32
|
||||
addr string
|
||||
|
||||
transport *Transport
|
||||
|
||||
readTimeout time.Duration
|
||||
|
||||
// health checking
|
||||
probe *up.Probe
|
||||
health HealthChecker
|
||||
}
|
||||
|
||||
// NewProxy returns a new proxy.
|
||||
func NewProxy(addr, trans string) *Proxy {
|
||||
p := &Proxy{
|
||||
addr: addr,
|
||||
fails: 0,
|
||||
probe: up.New(),
|
||||
readTimeout: 2 * time.Second,
|
||||
transport: newTransport(addr),
|
||||
}
|
||||
p.health = NewHealthChecker(trans, true, ".")
|
||||
runtime.SetFinalizer(p, (*Proxy).finalizer)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Proxy) Addr() string { return p.addr }
|
||||
|
||||
// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
|
||||
func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
|
||||
p.transport.SetTLSConfig(cfg)
|
||||
p.health.SetTLSConfig(cfg)
|
||||
}
|
||||
|
||||
// SetExpire sets the expire duration in the lower p.transport.
|
||||
func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
|
||||
|
||||
func (p *Proxy) GetHealthchecker() HealthChecker {
|
||||
return p.health
|
||||
}
|
||||
|
||||
func (p *Proxy) Fails() uint32 {
|
||||
return atomic.LoadUint32(&p.fails)
|
||||
}
|
||||
|
||||
// Healthcheck kicks of a round of health checks for this proxy.
|
||||
func (p *Proxy) Healthcheck() {
|
||||
if p.health == nil {
|
||||
log.Warning("No healthchecker")
|
||||
return
|
||||
}
|
||||
|
||||
p.probe.Do(func() error {
|
||||
return p.health.Check(p)
|
||||
})
|
||||
}
|
||||
|
||||
// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
|
||||
func (p *Proxy) Down(maxfails uint32) bool {
|
||||
if maxfails == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fails := atomic.LoadUint32(&p.fails)
|
||||
return fails > maxfails
|
||||
}
|
||||
|
||||
// Stop close stops the health checking goroutine.
|
||||
func (p *Proxy) Stop() { p.probe.Stop() }
|
||||
func (p *Proxy) finalizer() { p.transport.Stop() }
|
||||
|
||||
// Start starts the proxy's healthchecking.
|
||||
func (p *Proxy) Start(duration time.Duration) {
|
||||
p.probe.Start(duration)
|
||||
p.transport.Start()
|
||||
}
|
||||
|
||||
func (p *Proxy) SetReadTimeout(duration time.Duration) {
|
||||
p.readTimeout = duration
|
||||
}
|
||||
|
||||
const (
|
||||
maxTimeout = 2 * time.Second
|
||||
)
|
||||
99
plugin/pkg/proxy/proxy_test.go
Normal file
99
plugin/pkg/proxy/proxy_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
p.Start(5 * time.Second)
|
||||
m := new(dns.Msg)
|
||||
|
||||
m.SetQuestion("example.org.", dns.TypeA)
|
||||
|
||||
rec := dnstest.NewRecorder(&test.ResponseWriter{})
|
||||
req := request.Request{Req: m, W: rec}
|
||||
|
||||
resp, err := p.Connect(context.Background(), req, Options{PreferUDP: true})
|
||||
if err != nil {
|
||||
t.Errorf("Failed to connect to testdnsserver: %s", err)
|
||||
}
|
||||
|
||||
if x := resp.Answer[0].Header().Name; x != "example.org." {
|
||||
t.Errorf("Expected %s, got %s", "example.org.", x)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyTLSFail(t *testing.T) {
|
||||
// This is an udp/tcp test server, so we shouldn't reach it with TLS.
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr, transport.TLS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
p.SetTLSConfig(&tls.Config{})
|
||||
p.Start(5 * time.Second)
|
||||
m := new(dns.Msg)
|
||||
|
||||
m.SetQuestion("example.org.", dns.TypeA)
|
||||
|
||||
rec := dnstest.NewRecorder(&test.ResponseWriter{})
|
||||
req := request.Request{Req: m, W: rec}
|
||||
|
||||
_, err := p.Connect(context.Background(), req, Options{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected *not* to receive reply, but got one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocolSelection(t *testing.T) {
|
||||
p := NewProxy("bad_address", transport.DNS)
|
||||
p.readTimeout = 10 * time.Millisecond
|
||||
|
||||
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
|
||||
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}
|
||||
ctx := context.TODO()
|
||||
|
||||
go func() {
|
||||
p.Connect(ctx, stateUDP, Options{})
|
||||
p.Connect(ctx, stateUDP, Options{ForceTCP: true})
|
||||
p.Connect(ctx, stateUDP, Options{PreferUDP: true})
|
||||
p.Connect(ctx, stateUDP, Options{PreferUDP: true, ForceTCP: true})
|
||||
p.Connect(ctx, stateTCP, Options{})
|
||||
p.Connect(ctx, stateTCP, Options{ForceTCP: true})
|
||||
p.Connect(ctx, stateTCP, Options{PreferUDP: true})
|
||||
p.Connect(ctx, stateTCP, Options{PreferUDP: true, ForceTCP: true})
|
||||
}()
|
||||
|
||||
for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} {
|
||||
proto := <-p.transport.dial
|
||||
p.transport.ret <- nil
|
||||
if proto != exp {
|
||||
t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
|
||||
}
|
||||
}
|
||||
}
|
||||
39
plugin/pkg/proxy/type.go
Normal file
39
plugin/pkg/proxy/type.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type transportType int
|
||||
|
||||
const (
|
||||
typeUDP transportType = iota
|
||||
typeTCP
|
||||
typeTLS
|
||||
typeTotalCount // keep this last
|
||||
)
|
||||
|
||||
func stringToTransportType(s string) transportType {
|
||||
switch s {
|
||||
case "udp":
|
||||
return typeUDP
|
||||
case "tcp":
|
||||
return typeTCP
|
||||
case "tcp-tls":
|
||||
return typeTLS
|
||||
}
|
||||
|
||||
return typeUDP
|
||||
}
|
||||
|
||||
func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
|
||||
if _, ok := pc.c.Conn.(*net.UDPConn); ok {
|
||||
return typeUDP
|
||||
}
|
||||
|
||||
if t.tlsConfig == nil {
|
||||
return typeTCP
|
||||
}
|
||||
|
||||
return typeTLS
|
||||
}
|
||||
Reference in New Issue
Block a user