| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | package forward
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 	"crypto/tls"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"net"
 | 
					
						
							|  |  |  | 	"time"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // a persistConn hold the dns.Conn and the last used time.
 | 
					
						
							|  |  |  | type persistConn struct {
 | 
					
						
							|  |  |  | 	c    *dns.Conn
 | 
					
						
							|  |  |  | 	used time.Time
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // connErr is used to communicate the connection manager.
 | 
					
						
							|  |  |  | type connErr struct {
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	c      *dns.Conn
 | 
					
						
							|  |  |  | 	err    error
 | 
					
						
							|  |  |  | 	cached bool
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // transport hold the persistent cache.
 | 
					
						
							|  |  |  | type transport struct {
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 	conns     map[string][]*persistConn //  Buckets for udp, tcp and tcp-tls.
 | 
					
						
							|  |  |  | 	expire    time.Duration             // After this duration a connection is expired.
 | 
					
						
							|  |  |  | 	addr      string
 | 
					
						
							|  |  |  | 	tlsConfig *tls.Config
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	dial  chan string
 | 
					
						
							|  |  |  | 	yield chan connErr
 | 
					
						
							|  |  |  | 	ret   chan connErr
 | 
					
						
							| 
									
										
										
										
											2018-04-24 16:10:31 +01:00
										 |  |  | 	stop  chan bool
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | func newTransport(addr string, tlsConfig *tls.Config) *transport {
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	t := &transport{
 | 
					
						
							| 
									
										
										
										
											2018-04-24 16:10:31 +01:00
										 |  |  | 		conns:  make(map[string][]*persistConn),
 | 
					
						
							|  |  |  | 		expire: defaultExpire,
 | 
					
						
							|  |  |  | 		addr:   addr,
 | 
					
						
							|  |  |  | 		dial:   make(chan string),
 | 
					
						
							|  |  |  | 		yield:  make(chan connErr),
 | 
					
						
							|  |  |  | 		ret:    make(chan connErr),
 | 
					
						
							|  |  |  | 		stop:   make(chan bool),
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 	go t.connManager()
 | 
					
						
							|  |  |  | 	return t
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // len returns the number of connection, used for metrics. Can only be safely
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // used inside connManager() because of data races.
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | func (t *transport) len() int {
 | 
					
						
							|  |  |  | 	l := 0
 | 
					
						
							|  |  |  | 	for _, conns := range t.conns {
 | 
					
						
							|  |  |  | 		l += len(conns)
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	return l
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // connManagers manages the persistent connection cache for UDP and TCP.
 | 
					
						
							|  |  |  | func (t *transport) connManager() {
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Wait:
 | 
					
						
							|  |  |  | 	for {
 | 
					
						
							|  |  |  | 		select {
 | 
					
						
							|  |  |  | 		case proto := <-t.dial:
 | 
					
						
							|  |  |  | 			// Yes O(n), shouldn't put millions in here. We walk all connection until we find the first
 | 
					
						
							|  |  |  | 			// one that is usuable.
 | 
					
						
							|  |  |  | 			i := 0
 | 
					
						
							|  |  |  | 			for i = 0; i < len(t.conns[proto]); i++ {
 | 
					
						
							|  |  |  | 				pc := t.conns[proto][i]
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 				if time.Since(pc.used) < t.expire {
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 					// Found one, remove from pool and return this conn.
 | 
					
						
							|  |  |  | 					t.conns[proto] = t.conns[proto][i+1:]
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 					t.ret <- connErr{pc.c, nil, true}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 					continue Wait
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 				// This conn has expired. Close it.
 | 
					
						
							|  |  |  | 				pc.c.Close()
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			// Not conns were found. Connect to the upstream to create one.
 | 
					
						
							|  |  |  | 			t.conns[proto] = t.conns[proto][i:]
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			go func() {
 | 
					
						
							|  |  |  | 				if proto != "tcp-tls" {
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 					c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 					t.ret <- connErr{c, err, false}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 					return
 | 
					
						
							|  |  |  | 				}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 				c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 				t.ret <- connErr{c, err, false}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 			}()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		case conn := <-t.yield:
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 			// no proto here, infer from config and conn
 | 
					
						
							|  |  |  | 			if _, ok := conn.c.Conn.(*net.UDPConn); ok {
 | 
					
						
							|  |  |  | 				t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()})
 | 
					
						
							|  |  |  | 				continue Wait
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 			if t.tlsConfig == nil {
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 				t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
 | 
					
						
							|  |  |  | 				continue Wait
 | 
					
						
							|  |  |  | 			}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()})
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		case <-t.stop:
 | 
					
						
							|  |  |  | 			return
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | func (t *transport) Dial(proto string) (*dns.Conn, bool, error) {
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 	// If tls has been configured; use it.
 | 
					
						
							|  |  |  | 	if t.tlsConfig != nil {
 | 
					
						
							|  |  |  | 		proto = "tcp-tls"
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	t.dial <- proto
 | 
					
						
							|  |  |  | 	c := <-t.ret
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	return c.c, c.cached, c.err
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // Yield return the connection to transport for reuse.
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | func (t *transport) Yield(c *dns.Conn) {
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	t.yield <- connErr{c, nil, false}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | // Stop stops the transport's connection manager.
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | func (t *transport) Stop() { t.stop <- true }
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | // SetExpire sets the connection expire time in transport.
 | 
					
						
							|  |  |  | func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SetTLSConfig sets the TLS config in transport.
 | 
					
						
							|  |  |  | func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const defaultExpire = 10 * time.Second
 |