mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-27 08:14:18 -04:00 
			
		
		
		
	plugin/forward: move Dial goroutine out (#1738)
Rework the TestProxyClose - close the proxy in the *same* goroutine as where we started it. Close channels as long as we don't get data races (this may need another fix). Move the Dial goroutine out of the connManager - this simplifies things *and* makes another goroutine go away and removes the need for connErr channels - the channels can now just carry *dns.Conn. Also: Revert "plugin/forward: gracefull stop (#1701)". This reverts commit 135377bf77. Revert "rework TestProxyClose (#1735)". This reverts commit 9e8893a0b5.
This commit is contained in:
		| @@ -34,16 +34,6 @@ func (p *Proxy) updateRtt(newRtt time.Duration) { | ||||
| } | ||||
|  | ||||
| func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { | ||||
| 	atomic.AddInt32(&p.inProgress, 1) | ||||
| 	defer func() { | ||||
| 		if atomic.AddInt32(&p.inProgress, -1) == 0 { | ||||
| 			p.checkStopTransport() | ||||
| 		} | ||||
| 	}() | ||||
| 	if atomic.LoadUint32(&p.state) != running { | ||||
| 		return nil, errStopped | ||||
| 	} | ||||
|  | ||||
| 	start := time.Now() | ||||
|  | ||||
| 	proto := state.Proto() | ||||
| @@ -55,6 +45,7 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, me | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Set buffer size correctly for this client. | ||||
| 	conn.UDPSize = uint16(state.Size()) | ||||
| 	if conn.UDPSize < 512 { | ||||
|   | ||||
| @@ -119,7 +119,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg | ||||
|  | ||||
| 		if err != nil { | ||||
| 			// Kick off health check to see if *our* upstream is broken. | ||||
| 			if f.maxfails != 0 && err != errStopped { | ||||
| 			if f.maxfails != 0 { | ||||
| 				proxy.Healthcheck() | ||||
| 			} | ||||
|  | ||||
| @@ -185,7 +185,6 @@ var ( | ||||
| 	errNoHealthy     = errors.New("no healthy proxies") | ||||
| 	errNoForward     = errors.New("no forwarder defined") | ||||
| 	errCachedClosed  = errors.New("cached connection was closed by peer") | ||||
| 	errStopped       = errors.New("proxy has been stopped") | ||||
| ) | ||||
|  | ||||
| // policy tells forward what policy for selecting upstream it uses. | ||||
|   | ||||
| @@ -14,13 +14,6 @@ type persistConn struct { | ||||
| 	used time.Time | ||||
| } | ||||
|  | ||||
| // connErr is used to communicate the connection manager. | ||||
| type connErr struct { | ||||
| 	c      *dns.Conn | ||||
| 	err    error | ||||
| 	cached bool | ||||
| } | ||||
|  | ||||
| // transport hold the persistent cache. | ||||
| type transport struct { | ||||
| 	conns     map[string][]*persistConn //  Buckets for udp, tcp and tcp-tls. | ||||
| @@ -29,8 +22,8 @@ type transport struct { | ||||
| 	tlsConfig *tls.Config | ||||
|  | ||||
| 	dial  chan string | ||||
| 	yield chan connErr | ||||
| 	ret   chan connErr | ||||
| 	yield chan *dns.Conn | ||||
| 	ret   chan *dns.Conn | ||||
| 	stop  chan bool | ||||
| } | ||||
|  | ||||
| @@ -40,18 +33,11 @@ func newTransport(addr string, tlsConfig *tls.Config) *transport { | ||||
| 		expire: defaultExpire, | ||||
| 		addr:   addr, | ||||
| 		dial:   make(chan string), | ||||
| 		yield:  make(chan connErr), | ||||
| 		ret:    make(chan connErr), | ||||
| 		yield:  make(chan *dns.Conn), | ||||
| 		ret:    make(chan *dns.Conn), | ||||
| 		stop:   make(chan bool), | ||||
| 	} | ||||
| 	go func() { | ||||
| 		t.connManager() | ||||
| 		// if connManager returns it has been stopped. | ||||
| 		close(t.stop) | ||||
| 		close(t.yield) | ||||
| 		close(t.dial) | ||||
| 		// close(t.ret) // we can still be dialing and wanting to send back the socket on t.ret | ||||
| 	}() | ||||
| 	go func() { t.connManager() }() | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| @@ -80,7 +66,7 @@ Wait: | ||||
| 				if time.Since(pc.used) < t.expire { | ||||
| 					// Found one, remove from pool and return this conn. | ||||
| 					t.conns[proto] = t.conns[proto][i+1:] | ||||
| 					t.ret <- connErr{pc.c, nil, true} | ||||
| 					t.ret <- pc.c | ||||
| 					continue Wait | ||||
| 				} | ||||
| 				// This conn has expired. Close it. | ||||
| @@ -91,35 +77,27 @@ Wait: | ||||
| 			t.conns[proto] = t.conns[proto][i:] | ||||
| 			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len())) | ||||
|  | ||||
| 			go func() { | ||||
| 				if proto != "tcp-tls" { | ||||
| 					c, err := dns.DialTimeout(proto, t.addr, dialTimeout) | ||||
| 					t.ret <- connErr{c, err, false} | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout) | ||||
| 				t.ret <- connErr{c, err, false} | ||||
| 			}() | ||||
| 			t.ret <- nil | ||||
|  | ||||
| 		case conn := <-t.yield: | ||||
|  | ||||
| 			SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1)) | ||||
|  | ||||
| 			// no proto here, infer from config and conn | ||||
| 			if _, ok := conn.c.Conn.(*net.UDPConn); ok { | ||||
| 				t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()}) | ||||
| 			if _, ok := conn.Conn.(*net.UDPConn); ok { | ||||
| 				t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) | ||||
| 				continue Wait | ||||
| 			} | ||||
|  | ||||
| 			if t.tlsConfig == nil { | ||||
| 				t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()}) | ||||
| 				t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) | ||||
| 				continue Wait | ||||
| 			} | ||||
|  | ||||
| 			t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()}) | ||||
| 			t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) | ||||
|  | ||||
| 		case <-t.stop: | ||||
| 			close(t.ret) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| @@ -134,16 +112,24 @@ func (t *transport) Dial(proto string) (*dns.Conn, bool, error) { | ||||
|  | ||||
| 	t.dial <- proto | ||||
| 	c := <-t.ret | ||||
| 	return c.c, c.cached, c.err | ||||
|  | ||||
| 	if c != nil { | ||||
| 		return c, true, nil | ||||
| 	} | ||||
|  | ||||
| 	if proto == "tcp-tls" { | ||||
| 		conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout) | ||||
| 		return conn, false, err | ||||
| 	} | ||||
| 	conn, err := dns.DialTimeout(proto, t.addr, dialTimeout) | ||||
| 	return conn, false, err | ||||
| } | ||||
|  | ||||
| // Yield return the connection to transport for reuse. | ||||
| func (t *transport) Yield(c *dns.Conn) { | ||||
| 	t.yield <- connErr{c, nil, false} | ||||
| } | ||||
| func (t *transport) Yield(c *dns.Conn) { t.yield <- c } | ||||
|  | ||||
| // Stop stops the transport's connection manager. | ||||
| func (t *transport) Stop() { t.stop <- true } | ||||
| func (t *transport) Stop() { close(t.stop) } | ||||
|  | ||||
| // SetExpire sets the connection expire time in transport. | ||||
| func (t *transport) SetExpire(expire time.Duration) { t.expire = expire } | ||||
|   | ||||
| @@ -24,9 +24,6 @@ type Proxy struct { | ||||
| 	fails uint32 | ||||
|  | ||||
| 	avgRtt int64 | ||||
|  | ||||
| 	state      uint32 | ||||
| 	inProgress int32 | ||||
| } | ||||
|  | ||||
| // NewProxy returns a new proxy. | ||||
| @@ -85,26 +82,15 @@ func (p *Proxy) Down(maxfails uint32) bool { | ||||
| 	return fails > maxfails | ||||
| } | ||||
|  | ||||
| // close stops the health checking goroutine and connection manager. | ||||
| // close stops the health checking goroutine. | ||||
| func (p *Proxy) close() { | ||||
| 	if atomic.CompareAndSwapUint32(&p.state, running, stopping) { | ||||
| 	p.probe.Stop() | ||||
| 	} | ||||
| 	if atomic.LoadInt32(&p.inProgress) == 0 { | ||||
| 		p.checkStopTransport() | ||||
| 	} | ||||
| 	p.transport.Stop() | ||||
| } | ||||
|  | ||||
| // start starts the proxy's healthchecking. | ||||
| func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) } | ||||
|  | ||||
| // checkStopTransport checks if stop was requested and stops connection manager | ||||
| func (p *Proxy) checkStopTransport() { | ||||
| 	if atomic.CompareAndSwapUint32(&p.state, stopping, stopped) { | ||||
| 		p.transport.Stop() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	dialTimeout = 4 * time.Second | ||||
| 	timeout     = 2 * time.Second | ||||
| @@ -112,9 +98,3 @@ const ( | ||||
| 	minTimeout  = 10 * time.Millisecond | ||||
| 	hcDuration  = 500 * time.Millisecond | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	running = iota | ||||
| 	stopping | ||||
| 	stopped | ||||
| ) | ||||
|   | ||||
| @@ -2,9 +2,7 @@ package forward | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"runtime" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/coredns/coredns/plugin/pkg/dnstest" | ||||
| 	"github.com/coredns/coredns/plugin/test" | ||||
| @@ -28,50 +26,15 @@ func TestProxyClose(t *testing.T) { | ||||
| 	ctx := context.TODO() | ||||
|  | ||||
| 	for i := 0; i < 100; i++ { | ||||
| 		p := NewProxy(s.Addr, nil /* no TLS */) | ||||
| 		p := NewProxy(s.Addr, nil) | ||||
| 		p.start(hcDuration) | ||||
|  | ||||
| 		doneCnt := 0 | ||||
| 		doneCh := make(chan bool) | ||||
| 		timeCh := time.After(10 * time.Second) | ||||
| 		go func() { | ||||
| 			p.connect(ctx, state, false, false) | ||||
| 			doneCh <- true | ||||
| 		}() | ||||
| 		go func() { | ||||
| 			p.connect(ctx, state, true, false) | ||||
| 			doneCh <- true | ||||
| 		}() | ||||
| 		go func() { | ||||
| 			p.close() | ||||
| 			doneCh <- true | ||||
| 		}() | ||||
| 		go func() { | ||||
| 			p.connect(ctx, state, false, false) | ||||
| 			doneCh <- true | ||||
| 		}() | ||||
| 		go func() { | ||||
| 			p.connect(ctx, state, true, false) | ||||
| 			doneCh <- true | ||||
| 		}() | ||||
| 		go func() { p.connect(ctx, state, false, false) }() | ||||
| 		go func() { p.connect(ctx, state, true, false) }() | ||||
| 		go func() { p.connect(ctx, state, false, false) }() | ||||
| 		go func() { p.connect(ctx, state, true, false) }() | ||||
|  | ||||
| 		for doneCnt < 5 { | ||||
| 			select { | ||||
| 			case <-doneCh: | ||||
| 				doneCnt++ | ||||
| 			case <-timeCh: | ||||
| 				t.Error("TestProxyClose is running too long, dumping goroutines:") | ||||
| 				buf := make([]byte, 100000) | ||||
| 				stackSize := runtime.Stack(buf, true) | ||||
| 				t.Fatal(string(buf[:stackSize])) | ||||
| 			} | ||||
| 		} | ||||
| 		if p.inProgress != 0 { | ||||
| 			t.Errorf("unexpected query in progress") | ||||
| 		} | ||||
| 		if p.state != stopped { | ||||
| 			t.Errorf("unexpected proxy state, expected %d, got %d", stopped, p.state) | ||||
| 		} | ||||
| 		p.close() | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user