diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index 89ee5f31f..3f7441dfc 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -6,7 +6,6 @@ import ( "fmt" "maps" "net" - "runtime" "runtime/debug" "strings" "sync" @@ -42,7 +41,6 @@ type Server struct { m sync.Mutex // protects the servers zones map[string][]*Config // zones keyed by their address - dnsWg sync.WaitGroup // used to wait on outstanding connections graceTimeout time.Duration // the maximum duration of a graceful shutdown trace trace.Trace // the trace plugin for the server debug bool // disable recover() @@ -52,8 +50,8 @@ type Server struct { tsigSecret map[string]string // Ensure Stop is idempotent when invoked concurrently (e.g., during reload and SIGTERM). - stopOnce sync.Once - wgDoneOnce sync.Once + stopOnce sync.Once + stopErr error } // MetadataCollector is a plugin that can retrieve metadata functions from all metadata providing plugins @@ -74,14 +72,6 @@ func NewServer(addr string, group []*Config) (*Server, error) { tsigSecret: make(map[string]string), } - // We have to bound our wg with one increment - // to prevent a "race condition" that is hard-coded - // into sync.WaitGroup.Wait() - basically, an add - // with a positive delta must be guaranteed to - // occur before Wait() is called on the wg. - // In a way, this kind of acts as a safety barrier. - s.dnsWg.Add(1) - for _, site := range group { if site.Debug { s.debug = true @@ -209,44 +199,36 @@ func (s *Server) ListenPacket() (net.PacketConn, error) { return p, nil } -// Stop stops the server. It blocks until the server is -// totally stopped. On POSIX systems, it will wait for -// connections to close (up to a max timeout of a few -// seconds); on Windows it will close the listener -// immediately. +// Stop attempts to gracefully stop the server. +// It waits until the server is stopped and its connections are closed, +// up to a max timeout of a few seconds. If unsuccessful, an error is returned. +// // This implements Caddy.Stopper interface. -func (s *Server) Stop() (err error) { - var onceErr error +func (s *Server) Stop() error { s.stopOnce.Do(func() { - if runtime.GOOS != "windows" { - // force connections to close after timeout - done := make(chan struct{}) - go func() { - // decrement our initial increment used as a barrier, but only once - s.wgDoneOnce.Do(func() { s.dnsWg.Done() }) - s.dnsWg.Wait() - close(done) - }() + ctx, cancelCtx := context.WithTimeout(context.Background(), s.graceTimeout) + defer cancelCtx() - // Wait for remaining connections to finish or - // force them all to close after timeout - select { - case <-time.After(s.graceTimeout): - case <-done: - } - } - - // Close the listener now; this stops the server without delay + var wg sync.WaitGroup s.m.Lock() for _, s1 := range s.server { // We might not have started and initialized the full set of servers - if s1 != nil { - onceErr = s1.Shutdown() + if s1 == nil { + continue } + + wg.Add(1) + go func() { + s1.ShutdownContext(ctx) + wg.Done() + }() } s.m.Unlock() + wg.Wait() + + s.stopErr = ctx.Err() }) - return onceErr + return s.stopErr } // Address together with Stop() implement caddy.GracefulServer. diff --git a/core/dnsserver/server_test.go b/core/dnsserver/server_test.go index 8eca39b78..fef01faed 100644 --- a/core/dnsserver/server_test.go +++ b/core/dnsserver/server_test.go @@ -2,8 +2,11 @@ package dnsserver import ( "context" + "errors" + "net" "sync" "testing" + "time" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/log" @@ -20,6 +23,24 @@ func (tp testPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns. func (tp testPlugin) Name() string { return "local" } +// blockingPlugin uses sync.Mutex to simulate extended processing. +type blockingPlugin struct { + sync.Mutex +} + +func (b *blockingPlugin) Name() string { return "blocking" } + +func (b *blockingPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + // Respond immediately to avoid waiting in dns.Exchange + m := new(dns.Msg) + m.SetRcodeFormatError(r) + w.WriteMsg(m) + + b.Lock() + defer b.Unlock() + return dns.RcodeSuccess, nil +} + func testConfig(transport string, p plugin.Handler) *Config { c := &Config{ Zone: "example.com.", @@ -104,6 +125,44 @@ func TestStacktrace(t *testing.T) { } } +func TestGracefulStopTimeout_Internal(t *testing.T) { + p := new(blockingPlugin) + cfg := testConfig("dns", p) + + s, err := NewServer("127.0.0.1:0", []*Config{cfg}) + if err != nil { + t.Fatalf("NewServer failed: %v", err) + } + + // Shorten the graceful timeout + s.graceTimeout = 500 * time.Millisecond + + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + defer pc.Close() + + go s.ServePacket(pc) + udp := pc.LocalAddr().String() + + // Block the handler + p.Lock() + defer p.Unlock() + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + _, err = dns.Exchange(m, udp) + if err != nil { + t.Fatalf("dns.Exchange failed: %v", err) + } + + err = s.Stop() + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + func BenchmarkCoreServeDNS(b *testing.B) { s, err := NewServer("127.0.0.1:53", []*Config{testConfig("dns", testPlugin{})}) if err != nil { @@ -121,22 +180,3 @@ func BenchmarkCoreServeDNS(b *testing.B) { s.ServeDNS(ctx, w, m) } } - -// Validates Stop is idempotent and safe under concurrent calls. -func TestStopIsIdempotent(t *testing.T) { - t.Parallel() - - s := &Server{} - s.dnsWg.Add(1) - - const n = 10 - var wg sync.WaitGroup - wg.Add(n) - for range n { - go func() { - defer wg.Done() - _ = s.Stop() - }() - } - wg.Wait() -}