coredns/test/quic_test.go
Nico Berlee 7d7bbc8061 fix: prevent QUIC reload panic by lazily initializing the listener (#7680)
* fix: prevent QUIC reload panic by lazily initializing the listener

ServePacket on reload receives the reused PacketConn before the new
ServerQUIC has recreated its quic.Listener, so quicListener is nil and
the process panics. Lazily initialise quicListener from the provided
PacketConn when it’s nil and then proceed with ServeQUIC.

fixes: #7679
Signed-off-by: Nico Berlee <nico.berlee@on2it.net>

* test: add regression test for QUIC reload panic

Signed-off-by: Nico Berlee <nico.berlee@on2it.net>

---------

Signed-off-by: Nico Berlee <nico.berlee@on2it.net>
2025-11-18 08:34:29 -08:00
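
The fix described above amounts to creating the QUIC listener on demand inside ServePacket instead of assuming an earlier lifecycle step already built it. The sketch below only illustrates that pattern: the type, field, and method names (serverQUIC, quicListener, tlsConfig, quicConfig, serveQUIC) are placeholders rather than the actual CoreDNS dnsserver code, and quic.Listen is used simply as a known quic-go entry point for wrapping an existing net.PacketConn.

package quicsketch

import (
	"crypto/tls"
	"net"

	"github.com/quic-go/quic-go"
)

// serverQUIC is a stand-in for CoreDNS's QUIC server; all names here are
// illustrative assumptions, not the upstream implementation.
type serverQUIC struct {
	quicListener *quic.Listener
	tlsConfig    *tls.Config
	quicConfig   *quic.Config
}

// ServePacket receives the PacketConn that is reused across a reload. If the
// listener has not been (re)created yet, build it lazily from that conn
// instead of dereferencing a nil quicListener.
func (s *serverQUIC) ServePacket(p net.PacketConn) error {
	if s.quicListener == nil {
		l, err := quic.Listen(p, s.tlsConfig, s.quicConfig)
		if err != nil {
			return err
		}
		s.quicListener = l
	}
	return s.serveQUIC()
}

// serveQUIC would run the accept loop; elided in this sketch.
func (s *serverQUIC) serveQUIC() error { return nil }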


package test

import (
	"context"
	"crypto/tls"
	"encoding/binary"
	"errors"
	"io"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/coredns/coredns/core/dnsserver"
	ctls "github.com/coredns/coredns/plugin/pkg/tls"

	"github.com/miekg/dns"
	"github.com/quic-go/quic-go"
)
var quicCorefile = `quic://.:0 {
	tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
	whoami
}`

var quicReloadCorefile = `quic://.:0 {
	tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
	whoami
	reload 2s
}`

// Corefile with custom stream limits
var quicLimitCorefile = `quic://.:0 {
	tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem
	quic {
		max_streams 5
		worker_pool_size 10
	}
	whoami
}`
func TestQUIC(t *testing.T) {
	q, udp, _, err := CoreDNSServerAndPorts(quicCorefile)
	if err != nil {
		t.Fatalf("Could not get CoreDNS serving instance: %s", err)
	}
	defer q.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := quic.DialAddr(ctx, convertAddress(udp), generateTLSConfig(), nil)
	if err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}

	m := createTestMsg()

	streamSync, err := conn.OpenStreamSync(ctx)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}
	_, err = streamSync.Write(m)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}
	// Close only the send direction to signal the end of the query; the
	// stream remains readable for the response.
	_ = streamSync.Close()

	// DoQ responses are prefixed with a 2-octet length field (RFC 9250).
	sizeBuf := make([]byte, 2)
	_, err = io.ReadFull(streamSync, sizeBuf)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}

	size := binary.BigEndian.Uint16(sizeBuf)
	buf := make([]byte, size)
	_, err = io.ReadFull(streamSync, buf)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}

	d := new(dns.Msg)
	err = d.Unpack(buf)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}

	if d.Rcode != dns.RcodeSuccess {
		t.Errorf("Expected success but got %d", d.Rcode)
	}
	if len(d.Extra) != 2 {
		t.Errorf("Expected 2 RRs in additional section, but got %d", len(d.Extra))
	}
}
func TestQUICReloadDoesNotPanic(t *testing.T) {
	inst, udp, _, err := CoreDNSServerAndPorts(quicCorefile)
	if err != nil {
		t.Fatalf("Could not get CoreDNS serving instance: %s", err)
	}
	t.Cleanup(func() { inst.Stop() })

	assertQUICQuerySucceeds(t, udp)

	restart, err := inst.Restart(NewInput(quicReloadCorefile))
	if err != nil {
		t.Fatalf("Failed to restart CoreDNS: %s", err)
	}
	t.Cleanup(func() { restart.Stop() })

	udpReload, _ := CoreDNSServerPorts(restart, 0)
	if udpReload == "" {
		t.Fatal("Failed to determine QUIC listener address after reload")
	}
	assertQUICQuerySucceeds(t, udpReload)
}
func TestQUICProtocolError(t *testing.T) {
	q, udp, _, err := CoreDNSServerAndPorts(quicCorefile)
	if err != nil {
		t.Fatalf("Could not get CoreDNS serving instance: %s", err)
	}
	defer q.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := quic.DialAddr(ctx, convertAddress(udp), generateTLSConfig(), nil)
	if err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}

	m := createInvalidDOQMsg()

	streamSync, err := conn.OpenStreamSync(ctx)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}
	_, err = streamSync.Write(m)
	if err != nil {
		t.Errorf("Expected no error but got: %s", err)
	}
	_ = streamSync.Close()

	errorBuf := make([]byte, 2)
	_, err = io.ReadFull(streamSync, errorBuf)
	if err == nil {
		t.Errorf("Expected protocol error but got: %s", errorBuf)
	}
	if !isProtocolErr(err) {
		t.Errorf("Expected \"Application Error 0x2\" but got: %s", err)
	}
}
// TestQUICStreamLimits tests that the max_streams limit is correctly enforced
func TestQUICStreamLimits(t *testing.T) {
	q, udp, _, err := CoreDNSServerAndPorts(quicLimitCorefile)
	if err != nil {
		t.Fatalf("Could not get CoreDNS serving instance: %s", err)
	}
	defer q.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	conn, err := quic.DialAddr(ctx, convertAddress(udp), generateTLSConfig(), nil)
	if err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}

	m := createTestMsg()

	// Test opening exactly the max number of streams
	var wg sync.WaitGroup
	streamCount := 5 // Must match max_streams in quicLimitCorefile
	successCount := 0
	var mu sync.Mutex

	// Create a slice to store all the streams so we can keep them open
	streams := make([]*quic.Stream, 0, streamCount)
	streamsMu := sync.Mutex{}

	// Attempt to open exactly the configured number of streams
	for i := range streamCount {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()

			// Open stream
			streamSync, err := conn.OpenStreamSync(ctx)
			if err != nil {
				t.Logf("Stream %d: Failed to open: %s", idx, err)
				return
			}

			// Store the stream so we can keep it open
			streamsMu.Lock()
			streams = append(streams, streamSync)
			streamsMu.Unlock()

			// Write DNS message
			_, err = streamSync.Write(m)
			if err != nil {
				t.Logf("Stream %d: Failed to write: %s", idx, err)
				return
			}

			// Read response
			sizeBuf := make([]byte, 2)
			_, err = io.ReadFull(streamSync, sizeBuf)
			if err != nil {
				t.Logf("Stream %d: Failed to read size: %s", idx, err)
				return
			}

			size := binary.BigEndian.Uint16(sizeBuf)
			buf := make([]byte, size)
			_, err = io.ReadFull(streamSync, buf)
			if err != nil {
				t.Logf("Stream %d: Failed to read response: %s", idx, err)
				return
			}

			mu.Lock()
			successCount++
			mu.Unlock()
		}(i)
	}

	wg.Wait()

	if successCount != streamCount {
		t.Errorf("Expected all %d streams to succeed, but only %d succeeded", streamCount, successCount)
	}

	// Now try to open more streams beyond the limit while keeping existing streams open.
	// The QUIC protocol doesn't immediately reject streams; they might be allowed
	// to open but will be blocked (flow control) until other streams close.

	// First, make sure none of our streams have been closed
	for i, s := range streams {
		if s == nil {
			t.Errorf("Stream %d is nil", i)
			continue
		}
	}

	// Try to open a batch of additional streams - with streams limited to 5,
	// these should either block or be queued but should not allow concurrent use
	extraCount := 10
	extraSuccess := 0
	var extraSuccessMu sync.Mutex

	// Set a shorter timeout for these attempts
	extraCtx, extraCancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer extraCancel()

	var extraWg sync.WaitGroup
	// Create a channel to signal test completion
	done := make(chan struct{})

	// Launch goroutines to attempt opening additional streams
	for i := range extraCount {
		extraWg.Add(1)
		go func(idx int) {
			defer extraWg.Done()

			select {
			case <-done:
				return // Test is finishing, abandon attempts
			default:
				// Continue with the test
			}

			// Attempt to open an additional stream
			stream, err := conn.OpenStreamSync(extraCtx)
			if err != nil {
				t.Logf("Extra stream %d correctly failed to open: %s", idx, err)
				return
			}

			// If we got this far, we managed to open a stream,
			// but we shouldn't be able to use more than max_streams concurrently
			_, err = stream.Write(m)
			if err != nil {
				t.Logf("Extra stream %d failed to write: %s", idx, err)
				return
			}

			// Read response
			sizeBuf := make([]byte, 2)
			_, err = io.ReadFull(stream, sizeBuf)
			if err != nil {
				t.Logf("Extra stream %d failed to read: %s", idx, err)
				return
			}

			// This stream completed successfully
			extraSuccessMu.Lock()
			extraSuccess++
			extraSuccessMu.Unlock()

			// Close the stream explicitly
			_ = stream.Close()
		}(i)
	}

	// Start closing original streams after a delay.
	// This should allow extra streams to proceed as slots become available.
	time.Sleep(500 * time.Millisecond)

	// Close all the original streams
	for _, s := range streams {
		_ = s.Close()
	}

	// Allow extra streams some time to progress
	extraWg.Wait()
	close(done)

	// Since original streams are now closed, extra streams might succeed,
	// but we shouldn't see more than max_streams succeed during the blocked phase
	if extraSuccess > streamCount {
		t.Logf("Warning: %d extra streams succeeded, which is more than the limit of %d. This might be because original streams were closed.",
			extraSuccess, streamCount)
	}

	t.Logf("%d/%d extra streams were able to complete after original streams were closed",
		extraSuccess, extraCount)
}
func isProtocolErr(err error) bool {
	var qAppErr *quic.ApplicationError
	return errors.As(err, &qAppErr) && qAppErr.ErrorCode == 2
}

// convertAddress transforms the address given in CoreDNSServerAndPorts to a format
// that quic.DialAddr can read. It is unable to use [::]:61799, see:
// "INTERNAL_ERROR (local): write udp [::]:50676->[::]:61799: sendmsg: no route to host"
// So it transforms it to localhost:61799.
func convertAddress(address string) string {
	if strings.HasPrefix(address, "[::]") {
		address = strings.Replace(address, "[::]", "localhost", 1)
	}
	return address
}
func generateTLSConfig() *tls.Config {
	tlsConfig, err := ctls.NewTLSConfig(
		"../plugin/tls/test_cert.pem",
		"../plugin/tls/test_key.pem",
		"../plugin/tls/test_ca.pem")
	if err != nil {
		panic(err)
	}
	tlsConfig.NextProtos = []string{"doq"}
	tlsConfig.InsecureSkipVerify = true
	return tlsConfig
}
func createTestMsg() []byte {
	m := new(dns.Msg)
	m.SetQuestion("whoami.example.org.", dns.TypeA)
	m.Id = 0 // DoQ requires the DNS Message ID to be set to 0 (RFC 9250)
	msg, _ := m.Pack()
	return dnsserver.AddPrefix(msg)
}

// createInvalidDOQMsg packs a query without the 2-octet length prefix that
// DoQ requires, which the server rejects as a protocol error.
func createInvalidDOQMsg() []byte {
	m := new(dns.Msg)
	m.SetQuestion("whoami.example.org.", dns.TypeA)
	msg, _ := m.Pack()
	return msg
}
// assertQUICQuerySucceeds dials the given address over QUIC, sends a whoami
// query, and fails the test unless a successful response comes back.
func assertQUICQuerySucceeds(t *testing.T, address string) {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := quic.DialAddr(ctx, convertAddress(address), generateTLSConfig(), nil)
	if err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}
	defer func() { _ = conn.CloseWithError(0, "") }()

	stream, err := conn.OpenStreamSync(ctx)
	if err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}
	defer func() { _ = stream.Close() }()

	msg := createTestMsg()
	if _, err = stream.Write(msg); err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}

	sizeBuf := make([]byte, 2)
	if _, err = io.ReadFull(stream, sizeBuf); err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}

	size := binary.BigEndian.Uint16(sizeBuf)
	buf := make([]byte, size)
	if _, err = io.ReadFull(stream, buf); err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}

	resp := new(dns.Msg)
	if err = resp.Unpack(buf); err != nil {
		t.Fatalf("Expected no error but got: %s", err)
	}
	if resp.Rcode != dns.RcodeSuccess {
		t.Fatalf("Expected success but got %d", resp.Rcode)
	}
	if len(resp.Extra) != 2 {
		t.Fatalf("Expected 2 RRs in additional section, but got %d", len(resp.Extra))
	}
}