| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | // Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
 | 
					
						
							|  |  |  | // client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
 | 
					
						
							|  |  |  | // 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
 | 
					
						
							|  |  |  | // inband healthchecking.
 | 
					
						
							|  |  |  | package forward
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import (
 | 
					
						
							| 
									
										
										
										
											2018-04-22 08:34:35 +01:00
										 |  |  | 	"context"
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	"io"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"strconv"
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	"sync/atomic"
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	"time"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/coredns/coredns/request"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  | )
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | func (p *Proxy) readTimeout() time.Duration {
 | 
					
						
							|  |  |  | 	rtt := time.Duration(atomic.LoadInt64(&p.avgRtt))
 | 
					
						
							| 
									
										
										
										
											2018-04-16 19:51:49 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	if rtt < minTimeout {
 | 
					
						
							|  |  |  | 		return minTimeout
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	if rtt < maxTimeout/2 {
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 		return 2 * rtt
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							| 
									
										
										
										
											2018-04-16 19:51:49 +01:00
										 |  |  | 	return maxTimeout
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (p *Proxy) updateRtt(newRtt time.Duration) {
 | 
					
						
							|  |  |  | 	rtt := time.Duration(atomic.LoadInt64(&p.avgRtt))
 | 
					
						
							|  |  |  | 	atomic.AddInt64(&p.avgRtt, int64((newRtt-rtt)/rttCount))
 | 
					
						
							|  |  |  | }
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
 | 
					
						
							| 
									
										
										
										
											2018-04-20 17:47:46 +03:00
										 |  |  | 	atomic.AddInt32(&p.inProgress, 1)
 | 
					
						
							|  |  |  | 	defer func() {
 | 
					
						
							|  |  |  | 		if atomic.AddInt32(&p.inProgress, -1) == 0 {
 | 
					
						
							|  |  |  | 			p.checkStopTransport()
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 	}()
 | 
					
						
							|  |  |  | 	if atomic.LoadUint32(&p.state) != running {
 | 
					
						
							|  |  |  | 		return nil, errStopped
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	start := time.Now()
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	proto := state.Proto()
 | 
					
						
							|  |  |  | 	if forceTCP {
 | 
					
						
							|  |  |  | 		proto = "tcp"
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 	conn, cached, err := p.Dial(proto)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	if err != nil {
 | 
					
						
							|  |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 	// Set buffer size correctly for this client.
 | 
					
						
							|  |  |  | 	conn.UDPSize = uint16(state.Size())
 | 
					
						
							|  |  |  | 	if conn.UDPSize < 512 {
 | 
					
						
							|  |  |  | 		conn.UDPSize = 512
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	conn.SetWriteDeadline(time.Now().Add(timeout))
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	reqTime := time.Now()
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	if err := conn.WriteMsg(state.Req); err != nil {
 | 
					
						
							|  |  |  | 		conn.Close() // not giving it back
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		if err == io.EOF && cached {
 | 
					
						
							|  |  |  | 			return nil, errCachedClosed
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		return nil, err
 | 
					
						
							|  |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	conn.SetReadDeadline(time.Now().Add(p.readTimeout()))
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	ret, err := conn.ReadMsg()
 | 
					
						
							|  |  |  | 	if err != nil {
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 		p.updateRtt(timeout)
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 		conn.Close() // not giving it back
 | 
					
						
							| 
									
										
										
										
											2018-04-06 15:41:48 +03:00
										 |  |  | 		if err == io.EOF && cached {
 | 
					
						
							|  |  |  | 			return nil, errCachedClosed
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							| 
									
										
										
										
											2018-04-12 21:17:05 +02:00
										 |  |  | 		return ret, err
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 	p.updateRtt(time.Since(reqTime))
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	p.Yield(conn)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if metric {
 | 
					
						
							|  |  |  | 		rc, ok := dns.RcodeToString[ret.Rcode]
 | 
					
						
							|  |  |  | 		if !ok {
 | 
					
						
							|  |  |  | 			rc = strconv.Itoa(ret.Rcode)
 | 
					
						
							|  |  |  | 		}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-02-15 10:21:57 +01:00
										 |  |  | 		RequestCount.WithLabelValues(p.addr).Add(1)
 | 
					
						
							|  |  |  | 		RcodeCount.WithLabelValues(rc, p.addr).Add(1)
 | 
					
						
							|  |  |  | 		RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds())
 | 
					
						
							| 
									
										
										
										
											2018-02-05 22:00:47 +00:00
										 |  |  | 	}
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return ret, nil
 | 
					
						
							|  |  |  | }
 | 
					
						
							| 
									
										
										
										
											2018-04-11 09:50:06 +03:00
										 |  |  | 
 | 
					
						
// rttCount is the divisor updateRtt applies to the difference between a new
// round-trip sample and the stored average before adding it to that average.
const rttCount = 4
 |