Mirror of https://github.com/coredns/coredns.git, synced 2025-10-30 01:34:21 -04:00
	plugin/forward using pkg/up (#1493)
* plugin/forward: on-demand health checking. Only start doing health checks when we encounter an error (any error). This uses the new plugin/pkg/up package to abstract away the actual checking. This reduces the LOC quite a bit; does need more testing, unit testing and tcpdumping a bit.
* Fix tests
* Fix readme
* Use pkg/up for healthchecks
* Remove unused channel
* More cleanups
* Update readme
* Again do go generate and go build; still referencing the wrong forward repo? Anyway, fixed.
* Use pkg/up for doing the health checks to cut back on unwanted queries
* Change up.Func to return an error instead of a boolean
* Drop the string target argument as it doesn't make sense
* Add healthcheck test on failing to get an upstream answer. TODO(miek): double check Forward and Lookup and how they interact with HC, and if we correctly call close() on those
* Actual test
* Tests here
* More tests
* Try getting rid of host
* Get rid of the host indirection
* Finish removing hosts
* More testing
* Import fmt
* Field is not used
* Docs
* Move some stuff
* Bring back health_check
* maxfails=0 test
* Git and merging, bah
* Review
		| @@ -6,10 +6,17 @@ | |||||||
|  |  | ||||||
| ## Description | ## Description | ||||||
|  |  | ||||||
| The *forward* plugin is generally faster (~30+%) than *proxy* as it re-uses already opened sockets | The *forward* plugin re-uses already opened sockets to the upstreams. It supports UDP, TCP and | ||||||
| to the upstreams. It supports UDP, TCP and DNS-over-TLS and uses inband health checking that is | DNS-over-TLS and uses in-band health checking. | ||||||
| enabled by default. |  | ||||||
| When *all* upstreams are down it assumes healtchecking as a mechanism has failed and will try to | When it detects an error a health check is performed. This check runs in a loop, every *0.5s*, for | ||||||
|  | as long as the upstream reports unhealthy. Once healthy we stop health checking (until the next | ||||||
|  | error). The health checks use a recursive DNS query (`. IN NS`) to get upstream health. Any response | ||||||
|  | that is not a network error (REFUSED, NOTIMPL, SERVFAIL, etc) is taken as a healthy upstream. The | ||||||
|  | health check uses the same protocol as specified in **TO**. If `max_fails` is set to 0, no checking | ||||||
|  | is performed and upstreams will always be considered healthy. | ||||||
|  |  | ||||||
|  | When *all* upstreams are down it assumes health checking as a mechanism has failed and will try to | ||||||
| connect to a random upstream (which may or may not work). | connect to a random upstream (which may or may not work). | ||||||
|  |  | ||||||
| ## Syntax | ## Syntax | ||||||
| @@ -22,16 +29,11 @@ forward FROM TO... | |||||||
|  |  | ||||||
| * **FROM** is the base domain to match for the request to be forwarded. | * **FROM** is the base domain to match for the request to be forwarded. | ||||||
| * **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify | * **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify | ||||||
|   a protocol, `tls://9.9.9.9` or `dns://` for plain DNS. The number of upstreams is limited to 15. |   a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for plain DNS. The number of upstreams is | ||||||
|  |   limited to 15. | ||||||
|  |  | ||||||
| The health checks are done every *0.5s*. After *two* failed checks the upstream is considered | Multiple upstreams are randomized (see `policy`) on first use. When a healthy proxy returns an error | ||||||
| unhealthy. The health checks use a recursive DNS query (`. IN NS`) to get upstream health. Any | during the exchange the next upstream in the list is tried. | ||||||
| response that is not an error (REFUSED, NOTIMPL, SERVFAIL, etc) is taken as a healthy upstream. The |  | ||||||
| health check uses the same protocol as specific in the **TO**. On startup each upstream is marked |  | ||||||
| unhealthy until it passes a health check. A 0 duration will disable any health checks. |  | ||||||
|  |  | ||||||
| Multiple upstreams are randomized (default policy) on first use. When a healthy proxy returns an |  | ||||||
| error during the exchange the next upstream in the list is tried. |  | ||||||
|  |  | ||||||
| Extra knobs are available with an expanded syntax: | Extra knobs are available with an expanded syntax: | ||||||
|  |  | ||||||
| @@ -39,12 +41,12 @@ Extra knobs are available with an expanded syntax: | |||||||
| forward FROM TO... { | forward FROM TO... { | ||||||
|     except IGNORED_NAMES... |     except IGNORED_NAMES... | ||||||
|     force_tcp |     force_tcp | ||||||
|     health_check DURATION |  | ||||||
|     expire DURATION |     expire DURATION | ||||||
|     max_fails INTEGER |     max_fails INTEGER | ||||||
|     tls CERT KEY CA |     tls CERT KEY CA | ||||||
|     tls_servername NAME |     tls_servername NAME | ||||||
|     policy random|round_robin |     policy random|round_robin | ||||||
|  |     health_checks DURATION | ||||||
| } | } | ||||||
| ~~~ | ~~~ | ||||||
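For illustration, the expanded knobs might be combined like this. This is a sketch only: the 9.9.9.9/quad9 values come from the option descriptions, and the durations are arbitrary.

~~~ corefile
. {
    forward . tls://9.9.9.9 {
        tls_servername dns.quad9.net
        health_checks 5s
        max_fails 3
        expire 20s
        policy round_robin
    }
}
~~~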
|  |  | ||||||
| @@ -52,21 +54,16 @@ forward FROM TO... { | |||||||
| * **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding. | * **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding. | ||||||
|   Requests that match none of these names will be passed through. |   Requests that match none of these names will be passed through. | ||||||
| * `force_tcp`, use TCP even when the request comes in over UDP. | * `force_tcp`, use TCP even when the request comes in over UDP. | ||||||
| * `health_checks`, use a different **DURATION** for health checking, the default duration is 0.5s. |  | ||||||
|   A value of 0 disables the health checks completely. |  | ||||||
| * `max_fails` is the number of subsequent failed health checks that are needed before considering | * `max_fails` is the number of subsequent failed health checks that are needed before considering | ||||||
|   a backend to be down. If 0, the backend will never be marked as down. Default is 2. |   an upstream to be down. If 0, the upstream will never be marked as down (nor health checked). | ||||||
|  |   Default is 2. | ||||||
| * `expire` **DURATION**, expire (cached) connections after this time, the default is 10s. | * `expire` **DURATION**, expire (cached) connections after this time, the default is 10s. | ||||||
| * `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the | * `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the | ||||||
|   system's configuration will be used. |   system's configuration will be used. | ||||||
| * `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9 | * `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9 | ||||||
|   needs this to be set to `dns.quad9.net`. |   needs this to be set to `dns.quad9.net`. | ||||||
| * `policy` specifies the policy to use for selecting upstream servers. The default is `random`. | * `policy` specifies the policy to use for selecting upstream servers. The default is `random`. | ||||||
|  | * `health_checks`, use a different **DURATION** for health checking, the default duration is 0.5s. | ||||||
| The upstream selection is done via random (default policy) selection. If the socket for this client |  | ||||||
| isn't known *forward* will randomly choose one. If this turns out to be unhealthy, the next one is |  | ||||||
| tried. If *all* hosts are down, we assume health checking is broken and select a *random* upstream to |  | ||||||
| try. |  | ||||||
|  |  | ||||||
| Also note the TLS config is "global" for the whole forwarding proxy; if you need a different | Also note the TLS config is "global" for the whole forwarding proxy; if you need a different | ||||||
| `tls-name` for different upstreams you're out of luck. | `tls-name` for different upstreams you're out of luck. | ||||||
| @@ -80,7 +77,7 @@ If monitoring is enabled (via the *prometheus* directive) then the following met | |||||||
| * `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream. | * `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream. | ||||||
| * `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream. | * `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream. | ||||||
| * `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy, | * `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy, | ||||||
|   and we are randomly spraying to a target. |   and we are randomly (this always uses the `random` policy) spraying to an upstream. | ||||||
| * `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream. | * `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream. | ||||||
|  |  | ||||||
| Where `to` is one of the upstream servers (**TO** from the config), `proto` is the protocol used by | Where `to` is one of the upstream servers (**TO** from the config), `proto` is the protocol used by | ||||||
| @@ -125,16 +122,10 @@ Proxy everything except `example.org` using the host's `resolv.conf`'s nameserve | |||||||
| } | } | ||||||
| ~~~ | ~~~ | ||||||
|  |  | ||||||
| Forward to a IPv6 host: |  | ||||||
|  |  | ||||||
| ~~~ corefile |  | ||||||
| . { |  | ||||||
|     forward . [::1]:1053 |  | ||||||
| } |  | ||||||
| ~~~ |  | ||||||
|  |  | ||||||
| Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30 | Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30 | ||||||
| seconds. | seconds. Note the `tls_servername` is mandatory if you want a working setup, as 9.9.9.9 can't be | ||||||
|  | used in the TLS negotiation. Also set the health check duration to 5s to not completely swamp the | ||||||
|  | service with health checks. | ||||||
|  |  | ||||||
| ~~~ corefile | ~~~ corefile | ||||||
| . { | . { | ||||||
| @@ -148,7 +139,7 @@ seconds. | |||||||
|  |  | ||||||
| ## Bugs | ## Bugs | ||||||
|  |  | ||||||
| The TLS config is global for the whole forwarding proxy if you need a different `tls-name` for | The TLS config is global for the whole forwarding proxy; if you need a different `tls_servername` for | ||||||
| different upstreams you're out of luck. | different upstreams you're out of luck. | ||||||
|  |  | ||||||
| ## Also See | ## Also See | ||||||
|   | |||||||
| @@ -21,9 +21,6 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me | |||||||
| 	if forceTCP { | 	if forceTCP { | ||||||
| 		proto = "tcp" | 		proto = "tcp" | ||||||
| 	} | 	} | ||||||
| 	if p.host.tlsConfig != nil { |  | ||||||
| 		proto = "tcp-tls" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	conn, err := p.Dial(proto) | 	conn, err := p.Dial(proto) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -57,9 +54,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me | |||||||
| 			rc = strconv.Itoa(ret.Rcode) | 			rc = strconv.Itoa(ret.Rcode) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		RequestCount.WithLabelValues(p.host.addr).Add(1) | 		RequestCount.WithLabelValues(p.addr).Add(1) | ||||||
| 		RcodeCount.WithLabelValues(rc, p.host.addr).Add(1) | 		RcodeCount.WithLabelValues(rc, p.addr).Add(1) | ||||||
| 		RequestDuration.WithLabelValues(p.host.addr).Observe(time.Since(start).Seconds()) | 		RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds()) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return ret, nil | 	return ret, nil | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ import ( | |||||||
| type Forward struct { | type Forward struct { | ||||||
| 	proxies    []*Proxy | 	proxies    []*Proxy | ||||||
| 	p          Policy | 	p          Policy | ||||||
|  | 	hcInterval time.Duration | ||||||
|  |  | ||||||
| 	from    string | 	from    string | ||||||
| 	ignored []string | 	ignored []string | ||||||
| @@ -32,21 +33,20 @@ type Forward struct { | |||||||
| 	expire        time.Duration | 	expire        time.Duration | ||||||
|  |  | ||||||
| 	forceTCP bool // also here for testing | 	forceTCP bool // also here for testing | ||||||
| 	hcInterval time.Duration // also here for testing |  | ||||||
|  |  | ||||||
| 	Next plugin.Handler | 	Next plugin.Handler | ||||||
| } | } | ||||||
|  |  | ||||||
| // New returns a new Forward. | // New returns a new Forward. | ||||||
| func New() *Forward { | func New() *Forward { | ||||||
| 	f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: hcDuration, p: new(random)} | 	f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcDuration} | ||||||
| 	return f | 	return f | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetProxy appends p to the proxy list and starts healthchecking. | // SetProxy appends p to the proxy list and starts healthchecking. | ||||||
| func (f *Forward) SetProxy(p *Proxy) { | func (f *Forward) SetProxy(p *Proxy) { | ||||||
| 	f.proxies = append(f.proxies, p) | 	f.proxies = append(f.proxies, p) | ||||||
| 	go p.healthCheck() | 	p.start(f.hcInterval) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Len returns the number of configured proxies. | // Len returns the number of configured proxies. | ||||||
| @@ -92,7 +92,27 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg | |||||||
| 			child.Finish() | 			child.Finish() | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// If you query for instance ANY isc.org, you get a truncated reply back which miekg/dns fails to unpack | ||||||
|  | 		// because the RRs are not finished. The returned message can be useful or useless. Return the original | ||||||
|  | 		// query with some header bits set that they should retry with TCP. | ||||||
|  | 		if err == dns.ErrTruncated { | ||||||
|  | 			// We may or may not have something sensible... if not reassemble something to send to the client. | ||||||
|  | 			if ret == nil { | ||||||
|  | 				ret = new(dns.Msg) | ||||||
|  | 				ret.SetReply(r) | ||||||
|  | 				ret.Truncated = true | ||||||
|  | 				ret.Authoritative = true | ||||||
|  | 				ret.Rcode = dns.RcodeSuccess | ||||||
|  | 			} | ||||||
|  | 			err = nil // and reset err to pass this back to the client. | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  | 			// Kick off health check to see if *our* upstream is broken. | ||||||
|  | 			if f.maxfails != 0 { | ||||||
|  | 				proxy.Healthcheck() | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			if fails < len(f.proxies) { | 			if fails < len(f.proxies) { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| @@ -140,8 +160,8 @@ func (f *Forward) isAllowedDomain(name string) bool { | |||||||
| func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) } | func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) } | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	errInvalidDomain = errors.New("invalid domain for proxy") | 	errInvalidDomain = errors.New("invalid domain for forward") | ||||||
| 	errNoHealthy     = errors.New("no healthy proxies") | 	errNoHealthy     = errors.New("no healthy proxies or upstream error") | ||||||
| 	errNoForward     = errors.New("no forwarder defined") | 	errNoForward     = errors.New("no forwarder defined") | ||||||
| ) | ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -18,7 +19,7 @@ func TestForward(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr) | 	p := NewProxy(s.Addr, nil /* not TLS */) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| package forward | package forward | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"log" |  | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -10,41 +9,25 @@ import ( | |||||||
| // For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty | // For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty | ||||||
| // replies are considered fails, basically anything else constitutes a healthy upstream. | // replies are considered fails, basically anything else constitutes a healthy upstream. | ||||||
|  |  | ||||||
| func (h *host) Check() { | // Check is used as the up.Func in the up.Probe. | ||||||
| 	h.Lock() | func (p *Proxy) Check() error { | ||||||
|  | 	err := p.send() | ||||||
| 	if h.checking { |  | ||||||
| 		h.Unlock() |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	h.checking = true |  | ||||||
| 	h.Unlock() |  | ||||||
|  |  | ||||||
| 	err := h.send() |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("[INFO] healtheck of %s failed with %s", h.addr, err) | 		HealthcheckFailureCount.WithLabelValues(p.addr).Add(1) | ||||||
|  | 		atomic.AddUint32(&p.fails, 1) | ||||||
| 		HealthcheckFailureCount.WithLabelValues(h.addr).Add(1) | 		return err | ||||||
|  |  | ||||||
| 		atomic.AddUint32(&h.fails, 1) |  | ||||||
| 	} else { |  | ||||||
| 		atomic.StoreUint32(&h.fails, 0) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h.Lock() | 	atomic.StoreUint32(&p.fails, 0) | ||||||
| 	h.checking = false | 	return nil | ||||||
| 	h.Unlock() |  | ||||||
|  |  | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *host) send() error { | func (p *Proxy) send() error { | ||||||
| 	hcping := new(dns.Msg) | 	hcping := new(dns.Msg) | ||||||
| 	hcping.SetQuestion(".", dns.TypeNS) | 	hcping.SetQuestion(".", dns.TypeNS) | ||||||
| 	hcping.RecursionDesired = false | 	hcping.RecursionDesired = false | ||||||
|  |  | ||||||
| 	m, _, err := h.client.Exchange(hcping, h.addr) | 	m, _, err := p.client.Exchange(hcping, p.addr) | ||||||
| 	// If we got a header, we're alright, basically only care about I/O errors 'n stuff | 	// If we got a header, we're alright, basically only care about I/O errors 'n stuff | ||||||
| 	if err != nil && m != nil { | 	if err != nil && m != nil { | ||||||
| 		// Silly check, something sane came back | 		// Silly check, something sane came back | ||||||
| @@ -55,13 +38,3 @@ func (h *host) send() error { | |||||||
|  |  | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| // down returns true is this host has more than maxfails fails. |  | ||||||
| func (h *host) down(maxfails uint32) bool { |  | ||||||
| 	if maxfails == 0 { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	fails := atomic.LoadUint32(&h.fails) |  | ||||||
| 	return fails > maxfails |  | ||||||
| } |  | ||||||
|   | |||||||
plugin/forward/health_test.go (new file, 136 lines)
							| @@ -0,0 +1,136 @@ | |||||||
|  | package forward | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
|  | 	"github.com/coredns/coredns/plugin/test" | ||||||
|  |  | ||||||
|  | 	"github.com/miekg/dns" | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestHealth(t *testing.T) { | ||||||
|  | 	const expected = 0 | ||||||
|  | 	i := uint32(0) | ||||||
|  | 	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { | ||||||
|  | 		if r.Question[0].Name == "." { | ||||||
|  | 			atomic.AddUint32(&i, 1) | ||||||
|  | 		} | ||||||
|  | 		ret := new(dns.Msg) | ||||||
|  | 		ret.SetReply(r) | ||||||
|  | 		w.WriteMsg(ret) | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	p := NewProxy(s.Addr, nil /* no TLS */) | ||||||
|  | 	f := New() | ||||||
|  | 	f.SetProxy(p) | ||||||
|  | 	defer f.Close() | ||||||
|  |  | ||||||
|  | 	req := new(dns.Msg) | ||||||
|  | 	req.SetQuestion("example.org.", dns.TypeA) | ||||||
|  |  | ||||||
|  | 	f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) | ||||||
|  |  | ||||||
|  | 	time.Sleep(1 * time.Second) | ||||||
|  | 	i1 := atomic.LoadUint32(&i) | ||||||
|  | 	if i1 != expected { | ||||||
|  | 		t.Errorf("Expected number of health checks to be %d, got %d", expected, i1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestHealthTimeout(t *testing.T) { | ||||||
|  | 	const expected = 1 | ||||||
|  | 	i := uint32(0) | ||||||
|  | 	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { | ||||||
|  | 		if r.Question[0].Name == "." { | ||||||
|  | 			// health check, answer | ||||||
|  | 			atomic.AddUint32(&i, 1) | ||||||
|  | 			ret := new(dns.Msg) | ||||||
|  | 			ret.SetReply(r) | ||||||
|  | 			w.WriteMsg(ret) | ||||||
|  | 		} | ||||||
|  | 		// not a health check, do a timeout | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	p := NewProxy(s.Addr, nil /* no TLS */) | ||||||
|  | 	f := New() | ||||||
|  | 	f.SetProxy(p) | ||||||
|  | 	defer f.Close() | ||||||
|  |  | ||||||
|  | 	req := new(dns.Msg) | ||||||
|  | 	req.SetQuestion("example.org.", dns.TypeA) | ||||||
|  |  | ||||||
|  | 	f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) | ||||||
|  |  | ||||||
|  | 	time.Sleep(1 * time.Second) | ||||||
|  | 	i1 := atomic.LoadUint32(&i) | ||||||
|  | 	if i1 != expected { | ||||||
|  | 		t.Errorf("Expected number of health checks to be %d, got %d", expected, i1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestHealthFailTwice(t *testing.T) { | ||||||
|  | 	const expected = 2 | ||||||
|  | 	i := uint32(0) | ||||||
|  | 	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { | ||||||
|  | 		if r.Question[0].Name == "." { | ||||||
|  | 			atomic.AddUint32(&i, 1) | ||||||
|  | 			i1 := atomic.LoadUint32(&i) | ||||||
|  | 			// Timeout health until we get the second one | ||||||
|  | 			if i1 < 2 { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			ret := new(dns.Msg) | ||||||
|  | 			ret.SetReply(r) | ||||||
|  | 			w.WriteMsg(ret) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	p := NewProxy(s.Addr, nil /* no TLS */) | ||||||
|  | 	f := New() | ||||||
|  | 	f.SetProxy(p) | ||||||
|  | 	defer f.Close() | ||||||
|  |  | ||||||
|  | 	req := new(dns.Msg) | ||||||
|  | 	req.SetQuestion("example.org.", dns.TypeA) | ||||||
|  |  | ||||||
|  | 	f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) | ||||||
|  |  | ||||||
|  | 	time.Sleep(3 * time.Second) | ||||||
|  | 	i1 := atomic.LoadUint32(&i) | ||||||
|  | 	if i1 != expected { | ||||||
|  | 		t.Errorf("Expected number of health checks to be %d, got %d", expected, i1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestHealthMaxFails(t *testing.T) { | ||||||
|  | 	const expected = 0 | ||||||
|  | 	i := uint32(0) | ||||||
|  | 	s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { | ||||||
|  | 		// timeout | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	p := NewProxy(s.Addr, nil /* no TLS */) | ||||||
|  | 	f := New() | ||||||
|  | 	f.maxfails = 0 | ||||||
|  | 	f.SetProxy(p) | ||||||
|  | 	defer f.Close() | ||||||
|  |  | ||||||
|  | 	req := new(dns.Msg) | ||||||
|  | 	req.SetQuestion("example.org.", dns.TypeA) | ||||||
|  |  | ||||||
|  | 	f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) | ||||||
|  |  | ||||||
|  | 	time.Sleep(1 * time.Second) | ||||||
|  | 	i1 := atomic.LoadUint32(&i) | ||||||
|  | 	if i1 != expected { | ||||||
|  | 		t.Errorf("Expected number of health checks to be %d, got %d", expected, i1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,44 +0,0 @@ | |||||||
| package forward |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"crypto/tls" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type host struct { |  | ||||||
| 	addr   string |  | ||||||
| 	client *dns.Client |  | ||||||
|  |  | ||||||
| 	tlsConfig *tls.Config |  | ||||||
| 	expire    time.Duration |  | ||||||
|  |  | ||||||
| 	fails uint32 |  | ||||||
| 	sync.RWMutex |  | ||||||
| 	checking bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // newHost returns a new host, the fails are set to 1, i.e. |  | ||||||
| // the first healthcheck must succeed before we use this host. |  | ||||||
| func newHost(addr string) *host { |  | ||||||
| 	return &host{addr: addr, fails: 1, expire: defaultExpire} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // setClient sets and configures the dns.Client in host. |  | ||||||
| func (h *host) SetClient() { |  | ||||||
| 	c := new(dns.Client) |  | ||||||
| 	c.Net = "udp" |  | ||||||
| 	c.ReadTimeout = 2 * time.Second |  | ||||||
| 	c.WriteTimeout = 2 * time.Second |  | ||||||
|  |  | ||||||
| 	if h.tlsConfig != nil { |  | ||||||
| 		c.Net = "tcp-tls" |  | ||||||
| 		c.TLSConfig = h.tlsConfig |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	h.client = c |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const defaultExpire = 10 * time.Second |  | ||||||
| @@ -5,10 +5,6 @@ | |||||||
| package forward | package forward | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"crypto/tls" |  | ||||||
| 	"log" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| @@ -32,12 +28,10 @@ func (f *Forward) Forward(state request.Request) (*dns.Msg, error) { | |||||||
| 			// All upstream proxies are dead, assume healthcheck is completely broken and randomly | 			// All upstream proxies are dead, assume healthcheck is completely broken and randomly | ||||||
| 			// select an upstream to connect to. | 			// select an upstream to connect to. | ||||||
| 			proxy = f.list()[0] | 			proxy = f.list()[0] | ||||||
| 			log.Printf("[WARNING] All upstreams down, picking random one to connect to %s", proxy.host.addr) |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		ret, err := proxy.connect(context.Background(), state, f.forceTCP, true) | 		ret, err := proxy.connect(context.Background(), state, f.forceTCP, true) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Printf("[WARNING] Failed to connect to %s: %s", proxy.host.addr, err) |  | ||||||
| 			if fails < len(f.proxies) { | 			if fails < len(f.proxies) { | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| @@ -68,10 +62,11 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M | |||||||
| } | } | ||||||
|  |  | ||||||
| // NewLookup returns a Forward that can be used for plugin that need an upstream to resolve external names. | // NewLookup returns a Forward that can be used for plugin that need an upstream to resolve external names. | ||||||
|  | // Note that the caller must run Close on the forward to stop the health checking goroutines. | ||||||
| func NewLookup(addr []string) *Forward { | func NewLookup(addr []string) *Forward { | ||||||
| 	f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: 2 * time.Second} | 	f := New() | ||||||
| 	for i := range addr { | 	for i := range addr { | ||||||
| 		p := NewProxy(addr[i]) | 		p := NewProxy(addr[i], nil) | ||||||
| 		f.SetProxy(p) | 		f.SetProxy(p) | ||||||
| 	} | 	} | ||||||
| 	return f | 	return f | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||||
| 	"github.com/coredns/coredns/plugin/test" | 	"github.com/coredns/coredns/plugin/test" | ||||||
| 	"github.com/coredns/coredns/request" | 	"github.com/coredns/coredns/request" | ||||||
|  |  | ||||||
| 	"github.com/miekg/dns" | 	"github.com/miekg/dns" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -18,7 +19,7 @@ func TestLookup(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	p := NewProxy(s.Addr) | 	p := NewProxy(s.Addr, nil /* no TLS */) | ||||||
| 	f := New() | 	f := New() | ||||||
| 	f.SetProxy(p) | 	f.SetProxy(p) | ||||||
| 	defer f.Close() | 	defer f.Close() | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package forward | package forward | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/tls" | ||||||
| 	"net" | 	"net" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -22,7 +23,9 @@ type connErr struct { | |||||||
| // transport hold the persistent cache. | // transport hold the persistent cache. | ||||||
 type transport struct {
 	conns     map[string][]*persistConn //  Buckets for udp, tcp and tcp-tls.
-	host  *host
+	expire    time.Duration             // After this duration a connection is expired.
+	addr      string
+	tlsConfig *tls.Config
 
 	dial  chan string
 	yield chan connErr
@@ -35,10 +38,11 @@ type transport struct {
 	stop chan bool
 }
 
-func newTransport(h *host) *transport {
+func newTransport(addr string, tlsConfig *tls.Config) *transport {
 	t := &transport{
 		conns:   make(map[string][]*persistConn),
-		host:    h,
+		expire:  defaultExpire,
+		addr:    addr,
 		dial:    make(chan string),
 		yield:   make(chan connErr),
 		ret:     make(chan connErr),
@@ -51,7 +55,7 @@ func newTransport(h *host) *transport {
 }
 
 // len returns the number of connection, used for metrics. Can only be safely
-// used inside connManager() because of races.
+// used inside connManager() because of data races.
 func (t *transport) len() int {
 	l := 0
 	for _, conns := range t.conns {
@@ -79,7 +83,7 @@ Wait:
 			i := 0
 			for i = 0; i < len(t.conns[proto]); i++ {
 				pc := t.conns[proto][i]
-				if time.Since(pc.used) < t.host.expire {
+				if time.Since(pc.used) < t.expire {
 					// Found one, remove from pool and return this conn.
 					t.conns[proto] = t.conns[proto][i+1:]
 					t.ret <- connErr{pc.c, nil}
@@ -91,22 +95,22 @@ Wait:
 
 			// Not conns were found. Connect to the upstream to create one.
 			t.conns[proto] = t.conns[proto][i:]
-			SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
+			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
 
 			go func() {
 				if proto != "tcp-tls" {
-					c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout)
+					c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
 					t.ret <- connErr{c, err}
 					return
 				}
 
-				c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
+				c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
 				t.ret <- connErr{c, err}
 			}()
 
 		case conn := <-t.yield:
 
-			SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))
+			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
 
 			// no proto here, infer from config and conn
 			if _, ok := conn.c.Conn.(*net.UDPConn); ok {
@@ -114,7 +118,7 @@ Wait:
 				continue Wait
 			}
 
-			if t.host.tlsConfig == nil {
+			if t.tlsConfig == nil {
 				t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
 				continue Wait
 			}
@@ -134,15 +138,30 @@ Wait:
 	}
 }
 
+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
 func (t *transport) Dial(proto string) (*dns.Conn, error) {
+	// If tls has been configured; use it.
+	if t.tlsConfig != nil {
+		proto = "tcp-tls"
+	}
+
 	t.dial <- proto
 	c := <-t.ret
 	return c.c, c.err
 }
 
+// Yield return the connection to transport for reuse.
 func (t *transport) Yield(c *dns.Conn) {
 	t.yield <- connErr{c, nil}
 }
 
-// Stop stops the transports.
+// Stop stops the transport's connection manager.
 func (t *transport) Stop() { t.stop <- true }
 
+// SetExpire sets the connection expire time in transport.
+func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const defaultExpire = 10 * time.Second

@@ -16,8 +16,7 @@ func TestPersistent(t *testing.T) {
 	})
 	defer s.Close()
 
-	h := newHost(s.Addr)
-	tr := newTransport(h)
+	tr := newTransport(s.Addr, nil /* no TLS */)
 	defer tr.Stop()
 
 	c1, _ := tr.Dial("udp")

@@ -2,47 +2,60 @@ package forward
 
 import (
 	"crypto/tls"
-	"sync"
+	"sync/atomic"
 	"time"
 
+	"github.com/coredns/coredns/plugin/pkg/up"
+
 	"github.com/miekg/dns"
 )
 
 // Proxy defines an upstream host.
 type Proxy struct {
-	host *host
+	addr   string
+	client *dns.Client
+
+	// Connection caching
+	expire    time.Duration
 	transport *transport
 
-	// copied from Forward.
-	hcInterval time.Duration
-	forceTCP   bool
-
-	stop chan bool
-
-	sync.RWMutex
+	// health checking
+	probe *up.Probe
+	fails uint32
 }
 
 // NewProxy returns a new proxy.
-func NewProxy(addr string) *Proxy {
-	host := newHost(addr)
-
+func NewProxy(addr string, tlsConfig *tls.Config) *Proxy {
 	p := &Proxy{
-		host:       host,
-		hcInterval: hcDuration,
-		stop:       make(chan bool),
-		transport:  newTransport(host),
+		addr:      addr,
+		fails:     0,
+		probe:     up.New(),
+		transport: newTransport(addr, tlsConfig),
 	}
+	p.client = dnsClient(tlsConfig)
 	return p
 }
 
-// SetTLSConfig sets the TLS config in the lower p.host.
-func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.host.tlsConfig = cfg }
+// dnsClient returns a client used for health checking.
+func dnsClient(tlsConfig *tls.Config) *dns.Client {
+	c := new(dns.Client)
+	c.Net = "udp"
+	// TODO(miek): this should be half of hcDuration?
+	c.ReadTimeout = 1 * time.Second
+	c.WriteTimeout = 1 * time.Second
 
-// SetExpire sets the expire duration in the lower p.host.
-func (p *Proxy) SetExpire(expire time.Duration) { p.host.expire = expire }
+	if tlsConfig != nil {
+		c.Net = "tcp-tls"
+		c.TLSConfig = tlsConfig
+	}
+	return c
+}
 
-func (p *Proxy) close() { p.stop <- true }
+// SetTLSConfig sets the TLS config in the lower p.transport.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) }
+
+// SetExpire sets the expire duration in the lower p.transport.
+func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
 
 // Dial connects to the host in p with the configured transport.
 func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
@@ -50,26 +63,28 @@ func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(
 // Yield returns the connection to the pool.
 func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
 
-// Down returns if this proxy is up or down.
-func (p *Proxy) Down(maxfails uint32) bool { return p.host.down(maxfails) }
+// Healthcheck kicks of a round of health checks for this proxy.
+func (p *Proxy) Healthcheck() { p.probe.Do(p.Check) }
 
-func (p *Proxy) healthCheck() {
+// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
+func (p *Proxy) Down(maxfails uint32) bool {
+	if maxfails == 0 {
+		return false
+	}
 
-	// stop channel
-	p.host.SetClient()
+	fails := atomic.LoadUint32(&p.fails)
+	return fails > maxfails
+}
 
-	p.host.Check()
-	tick := time.NewTicker(p.hcInterval)
-	for {
-		select {
-		case <-tick.C:
-			p.host.Check()
-		case <-p.stop:
-			return
-		}
-	}
-}
+// close stops the health checking goroutine.
+func (p *Proxy) close() {
+	p.probe.Stop()
+	p.transport.Stop()
+}
 
+// start starts the proxy's healthchecking.
+func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
 
 const (
 	dialTimeout = 4 * time.Second
 	timeout     = 2 * time.Second

@@ -62,25 +62,14 @@ func setup(c *caddy.Controller) error {
 
 // OnStartup starts a goroutines for all proxies.
 func (f *Forward) OnStartup() (err error) {
-	if f.hcInterval == 0 {
-		for _, p := range f.proxies {
-			p.host.fails = 0
-		}
-		return nil
-	}
-
 	for _, p := range f.proxies {
-		go p.healthCheck()
+		p.start(f.hcInterval)
 	}
 	return nil
 }
 
 // OnShutdown stops all configured proxies.
 func (f *Forward) OnShutdown() error {
-	if f.hcInterval == 0 {
-		return nil
-	}
-
 	for _, p := range f.proxies {
 		p.close()
 	}
@@ -88,9 +77,7 @@ func (f *Forward) OnShutdown() error {
 }
 
 // Close is a synonym for OnShutdown().
-func (f *Forward) Close() {
-	f.OnShutdown()
-}
+func (f *Forward) Close() { f.OnShutdown() }
 
 func parseForward(c *caddy.Controller) (*Forward, error) {
 	f := New()
@@ -140,8 +127,8 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
 			}
 
 			// We can't set tlsConfig here, because we haven't parsed it yet.
-			// We set it below at the end of parseBlock.
-			p := NewProxy(h)
+			// We set it below at the end of parseBlock, use nil now.
+			p := NewProxy(h, nil /* no TLS */)
 			f.proxies = append(f.proxies, p)
 		}
 
@@ -200,17 +187,11 @@ func parseBlock(c *caddy.Controller, f *Forward) error {
 			return fmt.Errorf("health_check can't be negative: %d", dur)
 		}
 		f.hcInterval = dur
-		for i := range f.proxies {
-			f.proxies[i].hcInterval = dur
-		}
 	case "force_tcp":
 		if c.NextArg() {
 			return c.ArgErr()
 		}
 		f.forceTCP = true
-		for i := range f.proxies {
-			f.proxies[i].forceTCP = true
-		}
 	case "tls":
 		args := c.RemainingArgs()
 		if len(args) != 3 {

@@ -17,8 +17,8 @@ type Probe struct {
 	inprogress bool
 }
 
-// Func is used to determine if a target is alive. If so this function must return true.
-type Func func(target string) bool
+// Func is used to determine if a target is alive. If so this function must return nil.
+type Func func() error
 
 // New returns a pointer to an intialized Probe.
 func New() *Probe {
@@ -32,9 +32,9 @@ func (p *Probe) Do(f Func) { p.do <- f }
 func (p *Probe) Stop() { p.stop <- true }
 
 // Start will start the probe manager, after which probes can be initialized with Do.
-func (p *Probe) Start(target string, interval time.Duration) { go p.start(target, interval) }
+func (p *Probe) Start(interval time.Duration) { go p.start(interval) }
 
-func (p *Probe) start(target string, interval time.Duration) {
+func (p *Probe) start(interval time.Duration) {
 	for {
 		select {
 		case <-p.stop:
@@ -52,9 +52,10 @@ func (p *Probe) start(target string, interval time.Duration) {
 			// we return from the goroutine and we can accept another Func to run.
 			go func() {
 				for {
-					if ok := f(target); ok {
+					if err := f(); err == nil {
 						break
 					}
+					// TODO(miek): little bit of exponential backoff here?
 					time.Sleep(interval)
 				}
 				p.Lock()

@@ -12,20 +12,20 @@ func TestUp(t *testing.T) {
 	wg := sync.WaitGroup{}
 	hits := int32(0)
 
-	upfunc := func(s string) bool {
+	upfunc := func() error {
 		atomic.AddInt32(&hits, 1)
 		// Sleep tiny amount so that our other pr.Do() calls hit the lock.
 		time.Sleep(3 * time.Millisecond)
 		wg.Done()
-		return true
+		return nil
 	}
 
-	pr.Start("nonexistent", 5*time.Millisecond)
+	pr.Start(5 * time.Millisecond)
 	defer pr.Stop()
 
 	// These functions AddInt32 to the same hits variable, but we only want to wait when
 	// upfunc finishes, as that only calls Done() on the waitgroup.
-	upfuncNoWg := func(s string) bool { atomic.AddInt32(&hits, 1); return true }
+	upfuncNoWg := func() error { atomic.AddInt32(&hits, 1); return nil }
 	wg.Add(1)
 	pr.Do(upfunc)
 	pr.Do(upfuncNoWg)