| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | package forward
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 	"crypto/tls"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"net"
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 	"sort"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"time"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // a persistConn hold the dns.Conn and the last used time.
 | 
					
						
							|  |  |  | type persistConn struct {
 | 
					
						
							|  |  |  | 	c    *dns.Conn
 | 
					
						
							|  |  |  | 	used time.Time
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | // Transport hold the persistent cache.
 | 
					
						
							|  |  |  | type Transport struct {
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	avgDialTime int64                     // kind of average time of dial time
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | 	conns       map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 	expire      time.Duration             // After this duration a connection is expired.
 | 
					
						
							|  |  |  | 	addr        string
 | 
					
						
							|  |  |  | 	tlsConfig   *tls.Config
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	dial  chan string
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 	yield chan *dns.Conn
 | 
					
						
							|  |  |  | 	ret   chan *dns.Conn
 | 
					
						
							| 
									
										
										
										
											2018-04-24 16:10:31 +01:00
										 |  |  | 	stop  chan bool
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func newTransport(addr string) *Transport {
 | 
					
						
							|  |  |  | 	t := &Transport{
 | 
					
						
							| 
									
										
										
										
											2018-11-20 08:48:56 +01:00
										 |  |  | 		avgDialTime: int64(maxDialTimeout / 2),
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | 		conns:       make(map[string][]*persistConn),
 | 
					
						
							|  |  |  | 		expire:      defaultExpire,
 | 
					
						
							|  |  |  | 		addr:        addr,
 | 
					
						
							|  |  |  | 		dial:        make(chan string),
 | 
					
						
							|  |  |  | 		yield:       make(chan *dns.Conn),
 | 
					
						
							|  |  |  | 		ret:         make(chan *dns.Conn),
 | 
					
						
							|  |  |  | 		stop:        make(chan bool),
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 	return t
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // len returns the number of connection, used for metrics. Can only be safely
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // used inside connManager() because of data races.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) len() int {
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	l := 0
 | 
					
						
							|  |  |  | 	for _, conns := range t.conns {
 | 
					
						
							|  |  |  | 		l += len(conns)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return l
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // connManagers manages the persistent connection cache for UDP and TCP.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) connManager() {
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 	ticker := time.NewTicker(t.expire)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | Wait:
 | 
					
						
							|  |  |  | 	for {
 | 
					
						
							|  |  |  | 		select {
 | 
					
						
							|  |  |  | 		case proto := <-t.dial:
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 			// take the last used conn - complexity O(1)
 | 
					
						
							|  |  |  | 			if stack := t.conns[proto]; len(stack) > 0 {
 | 
					
						
							|  |  |  | 				pc := stack[len(stack)-1]
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 				if time.Since(pc.used) < t.expire {
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 					// Found one, remove from pool and return this conn.
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 					t.conns[proto] = stack[:len(stack)-1]
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 					t.ret <- pc.c
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 					continue Wait
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 				// clear entire cache if the last conn is expired
 | 
					
						
							|  |  |  | 				t.conns[proto] = nil
 | 
					
						
							|  |  |  | 				// now, the connections being passed to closeConns() are not reachable from
 | 
					
						
							|  |  |  | 				// transport methods anymore. So, it's safe to close them in a separate goroutine
 | 
					
						
							|  |  |  | 				go closeConns(stack)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 			}
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 			t.ret <- nil
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 		case conn := <-t.yield:
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			// no proto here, infer from config and conn
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 			if _, ok := conn.Conn.(*net.UDPConn); ok {
 | 
					
						
							|  |  |  | 				t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 				continue Wait
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 			if t.tlsConfig == nil {
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 				t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 				continue Wait
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 			t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 		case <-ticker.C:
 | 
					
						
							|  |  |  | 			t.cleanup(false)
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		case <-t.stop:
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 			t.cleanup(true)
 | 
					
						
							| 
									
										
										
										
											2018-04-26 09:34:58 +01:00
										 |  |  | 			close(t.ret)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 			return
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | // closeConns closes connections.
 | 
					
						
							|  |  |  | func closeConns(conns []*persistConn) {
 | 
					
						
							|  |  |  | 	for _, pc := range conns {
 | 
					
						
							|  |  |  | 		pc.c.Close()
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // cleanup removes connections from cache.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) cleanup(all bool) {
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 	staleTime := time.Now().Add(-t.expire)
 | 
					
						
							|  |  |  | 	for proto, stack := range t.conns {
 | 
					
						
							|  |  |  | 		if len(stack) == 0 {
 | 
					
						
							|  |  |  | 			continue
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		if all {
 | 
					
						
							|  |  |  | 			t.conns[proto] = nil
 | 
					
						
							|  |  |  | 			// now, the connections being passed to closeConns() are not reachable from
 | 
					
						
							|  |  |  | 			// transport methods anymore. So, it's safe to close them in a separate goroutine
 | 
					
						
							|  |  |  | 			go closeConns(stack)
 | 
					
						
							|  |  |  | 			continue
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 		if stack[0].used.After(staleTime) {
 | 
					
						
							|  |  |  | 			continue
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// connections in stack are sorted by "used"
 | 
					
						
							|  |  |  | 		good := sort.Search(len(stack), func(i int) bool {
 | 
					
						
							|  |  |  | 			return stack[i].used.After(staleTime)
 | 
					
						
							|  |  |  | 		})
 | 
					
						
							|  |  |  | 		t.conns[proto] = stack[good:]
 | 
					
						
							|  |  |  | 		// now, the connections being passed to closeConns() are not reachable from
 | 
					
						
							|  |  |  | 		// transport methods anymore. So, it's safe to close them in a separate goroutine
 | 
					
						
							|  |  |  | 		go closeConns(stack[:good])
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // Yield return the connection to transport for reuse.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | // Start starts the transport's connection manager.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) Start() { go t.connManager() }
 | 
					
						
							| 
									
										
										
										
											2018-05-26 01:00:11 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // Stop stops the transport's connection manager.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) Stop() { close(t.stop) }
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | // SetExpire sets the connection expire time in transport.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | // SetTLSConfig sets the TLS config in transport.
 | 
					
						
							| 
									
										
										
										
											2018-09-19 07:29:37 +01:00
										 |  |  | func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | const (
 | 
					
						
							| 
									
										
										
										
											2018-11-20 08:48:56 +01:00
										 |  |  | 	defaultExpire  = 10 * time.Second
 | 
					
						
							|  |  |  | 	minDialTimeout = 1 * time.Second
 | 
					
						
							|  |  |  | 	maxDialTimeout = 30 * time.Second
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Some resolves might take quite a while, usually (cached) responses are fast. Set to 2s to give us some time to retry a different upstream.
 | 
					
						
							|  |  |  | 	readTimeout = 2 * time.Second
 | 
					
						
							| 
									
										
										
										
											2018-06-15 02:37:22 -04:00
										 |  |  | )
 |