diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index a16375d44..8589c40a8 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -136,6 +136,15 @@ func (s *ServerQUIC) ServeQUIC() error { } } +func acquireQUICWorker(ctx context.Context, pool chan struct{}) bool { + select { + case pool <- struct{}{}: + return true + case <-ctx.Done(): + return false + } +} + // serveQUICConnection handles a new QUIC connection. It waits for new streams // and passes them to serveQUICStream. func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) { @@ -157,29 +166,15 @@ func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) { return } - // Use a bounded worker pool with context cancellation - select { - case s.streamProcessPool <- struct{}{}: - // Got worker slot immediately - go func(st *quic.Stream, cn *quic.Conn) { - defer func() { <-s.streamProcessPool }() // Release worker slot - s.serveQUICStream(st, cn) - }(stream, conn) - default: - // Worker pool full, check for context cancellation - go func(st *quic.Stream, cn *quic.Conn) { - select { - case s.streamProcessPool <- struct{}{}: - // Got worker slot after waiting - defer func() { <-s.streamProcessPool }() // Release worker slot - s.serveQUICStream(st, cn) - case <-conn.Context().Done(): - // Connection context was cancelled while waiting - st.Close() - return - } - }(stream, conn) + if !acquireQUICWorker(conn.Context(), s.streamProcessPool) { + _ = stream.Close() + return } + + go func(st *quic.Stream, cn *quic.Conn) { + defer func() { <-s.streamProcessPool }() + s.serveQUICStream(st, cn) + }(stream, conn) } } diff --git a/core/dnsserver/server_quic_test.go b/core/dnsserver/server_quic_test.go index 8deb11c7c..19cadd2f0 100644 --- a/core/dnsserver/server_quic_test.go +++ b/core/dnsserver/server_quic_test.go @@ -2,6 +2,7 @@ package dnsserver import ( "bytes" + "context" "crypto/tls" "errors" "testing" @@ -400,3 +401,41 @@ func TestAddPrefix(t *testing.T) { }) } } + +func TestAcquireQUICWorkerWaitsForSlot(t *testing.T) { + pool := make(chan struct{}, 1) + pool <- struct{}{} + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + done := make(chan bool, 1) + go func() { + done <- acquireQUICWorker(ctx, pool) + }() + + select { + case <-done: + t.Fatal("acquireQUICWorker returned before a slot was released") + default: + } + + <-pool + + got := <-done + if !got { + t.Fatal("expected acquireQUICWorker to succeed after slot release") + } +} + +func TestAcquireQUICWorkerReturnsFalseOnCancelledContext(t *testing.T) { + pool := make(chan struct{}, 1) + pool <- struct{}{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if got := acquireQUICWorker(ctx, pool); got { + t.Fatal("expected acquireQUICWorker to return false when context is cancelled") + } +}