mirror of
https://github.com/coredns/coredns.git
synced 2025-10-26 15:54:16 -04:00
fix(transfer): goroutine leak on axfr err (#7516)
This commit is contained in:
@@ -134,6 +134,11 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||
select {
|
||||
case ch <- &dns.Envelope{RR: rrs}:
|
||||
case err := <-errCh:
|
||||
// Client errored; drain pchan to avoid blocking the producer goroutine.
|
||||
go func() {
|
||||
for range pchan {
|
||||
}
|
||||
}()
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
l += len(rrs)
|
||||
@@ -161,11 +166,7 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||
}
|
||||
|
||||
if len(rrs) > 0 {
|
||||
select {
|
||||
case ch <- &dns.Envelope{RR: rrs}:
|
||||
case err := <-errCh:
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
ch <- &dns.Envelope{RR: rrs}
|
||||
l += len(rrs)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
@@ -277,3 +278,64 @@ func TestTransferNotAllowed(t *testing.T) {
|
||||
t.Errorf("Expected REFUSED response code, got %s", dns.RcodeToString[w.Msg.Rcode])
|
||||
}
|
||||
}
|
||||
|
||||
// errWriter is a dns.ResponseWriter that simulates a client error on write.
|
||||
type errWriter struct {
|
||||
test.ResponseWriter
|
||||
}
|
||||
|
||||
func (e *errWriter) WriteMsg(m *dns.Msg) error { return fmt.Errorf("write error") }
|
||||
|
||||
// blockingTransferer produces many records into the channel and signals when done.
|
||||
type blockingTransferer struct {
|
||||
Zone string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (b *blockingTransferer) Name() string { return "blockingtransferer" }
|
||||
func (b *blockingTransferer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (b *blockingTransferer) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
|
||||
if zone != b.Zone {
|
||||
return nil, ErrNotAuthoritative
|
||||
}
|
||||
ch := make(chan []dns.RR, 2)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer close(b.done)
|
||||
soa := test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s %d 7200 1800 86400 100", b.Zone, b.Zone, b.Zone, 1))
|
||||
ch <- []dns.RR{soa}
|
||||
for range 2000 {
|
||||
ch <- []dns.RR{test.A("ns.dns." + b.Zone + " 100 IN A 1.2.3.4")}
|
||||
}
|
||||
ch <- []dns.RR{soa}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// Test that when the client errors mid-transfer, the server drains the producer channel
|
||||
// so the producer goroutine can complete (no leak/block on small buffer).
|
||||
func TestTransferDrainsProducerOnClientError(t *testing.T) {
|
||||
b := &blockingTransferer{Zone: "example.org.", done: make(chan struct{})}
|
||||
|
||||
transfer := &Transfer{
|
||||
Transferers: []Transferer{b},
|
||||
xfrs: []*xfr{{Zones: []string{"example.org."}, to: []string{"*"}}},
|
||||
Next: b,
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
w := &errWriter{ResponseWriter: test.ResponseWriter{TCP: true}}
|
||||
m := &dns.Msg{}
|
||||
m.SetAxfr("example.org.")
|
||||
|
||||
_, _ = transfer.ServeDNS(ctx, w, m)
|
||||
|
||||
select {
|
||||
case <-b.done:
|
||||
// success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("producer goroutine did not finish; channel likely not drained on client error")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user