dnsserver: Rely on dns.Server.ShutdownContext to gracefully stop (#7517)

This commit is contained in:
Ilya Kulakov
2025-09-27 06:34:03 -07:00
committed by GitHub
parent a1dfc2c84d
commit eafc352f58
2 changed files with 81 additions and 59 deletions

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"maps"
"net"
"runtime"
"runtime/debug"
"strings"
"sync"
@@ -42,7 +41,6 @@ type Server struct {
m sync.Mutex // protects the servers
zones map[string][]*Config // zones keyed by their address
dnsWg sync.WaitGroup // used to wait on outstanding connections
graceTimeout time.Duration // the maximum duration of a graceful shutdown
trace trace.Trace // the trace plugin for the server
debug bool // disable recover()
@@ -52,8 +50,8 @@ type Server struct {
tsigSecret map[string]string
// Ensure Stop is idempotent when invoked concurrently (e.g., during reload and SIGTERM).
stopOnce sync.Once
wgDoneOnce sync.Once
stopOnce sync.Once
stopErr error
}
// MetadataCollector is a plugin that can retrieve metadata functions from all metadata providing plugins
@@ -74,14 +72,6 @@ func NewServer(addr string, group []*Config) (*Server, error) {
tsigSecret: make(map[string]string),
}
// We have to bound our wg with one increment
// to prevent a "race condition" that is hard-coded
// into sync.WaitGroup.Wait() - basically, an add
// with a positive delta must be guaranteed to
// occur before Wait() is called on the wg.
// In a way, this kind of acts as a safety barrier.
s.dnsWg.Add(1)
for _, site := range group {
if site.Debug {
s.debug = true
@@ -209,44 +199,36 @@ func (s *Server) ListenPacket() (net.PacketConn, error) {
return p, nil
}
// Stop stops the server. It blocks until the server is
// totally stopped. On POSIX systems, it will wait for
// connections to close (up to a max timeout of a few
// seconds); on Windows it will close the listener
// immediately.
// Stop attempts to gracefully stop the server.
// It waits until the server is stopped and its connections are closed,
// up to a max timeout of a few seconds. If unsuccessful, an error is returned.
//
// This implements Caddy.Stopper interface.
func (s *Server) Stop() (err error) {
var onceErr error
func (s *Server) Stop() error {
s.stopOnce.Do(func() {
if runtime.GOOS != "windows" {
// force connections to close after timeout
done := make(chan struct{})
go func() {
// decrement our initial increment used as a barrier, but only once
s.wgDoneOnce.Do(func() { s.dnsWg.Done() })
s.dnsWg.Wait()
close(done)
}()
ctx, cancelCtx := context.WithTimeout(context.Background(), s.graceTimeout)
defer cancelCtx()
// Wait for remaining connections to finish or
// force them all to close after timeout
select {
case <-time.After(s.graceTimeout):
case <-done:
}
}
// Close the listener now; this stops the server without delay
var wg sync.WaitGroup
s.m.Lock()
for _, s1 := range s.server {
// We might not have started and initialized the full set of servers
if s1 != nil {
onceErr = s1.Shutdown()
if s1 == nil {
continue
}
wg.Add(1)
go func() {
s1.ShutdownContext(ctx)
wg.Done()
}()
}
s.m.Unlock()
wg.Wait()
s.stopErr = ctx.Err()
})
return onceErr
return s.stopErr
}
// Address together with Stop() implement caddy.GracefulServer.

View File

@@ -2,8 +2,11 @@ package dnsserver
import (
"context"
"errors"
"net"
"sync"
"testing"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/log"
@@ -20,6 +23,24 @@ func (tp testPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.
// Name returns the fixed plugin name "local".
func (tp testPlugin) Name() string {
	return "local"
}
// blockingPlugin is a test handler that simulates extended processing.
// It embeds a sync.Mutex; a test that holds the mutex keeps ServeDNS blocked
// until the mutex is released.
type blockingPlugin struct{ sync.Mutex }
// Name returns the fixed plugin name "blocking".
func (b *blockingPlugin) Name() string {
	return "blocking"
}
// ServeDNS writes a format-error reply for r straight away and then waits for
// the plugin's mutex before returning dns.RcodeSuccess with a nil error.
func (b *blockingPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	// Respond immediately to avoid waiting in dns.Exchange
	reply := &dns.Msg{}
	reply.SetRcodeFormatError(r)
	w.WriteMsg(reply)

	// Acquire and release the mutex: this is where the handler stalls while
	// another goroutine holds the lock.
	b.Lock()
	b.Unlock()

	return dns.RcodeSuccess, nil
}
func testConfig(transport string, p plugin.Handler) *Config {
c := &Config{
Zone: "example.com.",
@@ -104,6 +125,44 @@ func TestStacktrace(t *testing.T) {
}
}
// TestGracefulStopTimeout_Internal checks that Stop reports
// context.DeadlineExceeded when a handler is still blocked once the
// (shortened) graceTimeout has elapsed.
func TestGracefulStopTimeout_Internal(t *testing.T) {
	blocker := new(blockingPlugin)

	srv, err := NewServer("127.0.0.1:0", []*Config{testConfig("dns", blocker)})
	if err != nil {
		t.Fatalf("NewServer failed: %v", err)
	}
	// Shorten the graceful timeout
	srv.graceTimeout = 500 * time.Millisecond

	conn, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ListenPacket failed: %v", err)
	}
	defer conn.Close()

	go srv.ServePacket(conn)

	// Block the handler: the plugin's ServeDNS waits on this mutex, which is
	// only released when the test returns.
	blocker.Lock()
	defer blocker.Unlock()

	query := new(dns.Msg)
	query.SetQuestion("example.com.", dns.TypeA)
	if _, err := dns.Exchange(query, conn.LocalAddr().String()); err != nil {
		t.Fatalf("dns.Exchange failed: %v", err)
	}

	if err := srv.Stop(); !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
	}
}
func BenchmarkCoreServeDNS(b *testing.B) {
s, err := NewServer("127.0.0.1:53", []*Config{testConfig("dns", testPlugin{})})
if err != nil {
@@ -121,22 +180,3 @@ func BenchmarkCoreServeDNS(b *testing.B) {
s.ServeDNS(ctx, w, m)
}
}
// TestStopIsIdempotent calls Stop from several goroutines at once on a
// zero-value Server and waits for all of them to return. The returned errors
// are ignored; the test only exercises concurrent invocation.
func TestStopIsIdempotent(t *testing.T) {
	t.Parallel()

	srv := &Server{}
	// Put one outstanding count on dnsWg before Stop is invoked.
	srv.dnsWg.Add(1)

	const callers = 10
	var pending sync.WaitGroup
	for i := 0; i < callers; i++ {
		pending.Add(1)
		go func() {
			defer pending.Done()
			_ = srv.Stop()
		}()
	}
	pending.Wait()
}