diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index 531cbd82d..b7d7fd7ff 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -103,6 +103,14 @@ func NewServerQUIC(addr string, group []*Config) (*ServerQUIC, error) { // ServePacket implements caddy.UDPServer interface. func (s *ServerQUIC) ServePacket(p net.PacketConn) error { s.m.Lock() + if s.quicListener == nil { + listener, err := quic.Listen(p, s.tlsConfig, s.quicConfig) + if err != nil { + s.m.Unlock() + return err + } + s.quicListener = listener + } s.listenAddr = s.quicListener.Addr() s.m.Unlock() diff --git a/test/quic_test.go b/test/quic_test.go index e8d673d74..1027c31d9 100644 --- a/test/quic_test.go +++ b/test/quic_test.go @@ -23,6 +23,12 @@ var quicCorefile = `quic://.:0 { whoami }` +var quicReloadCorefile = `quic://.:0 { + tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem + whoami + reload 2s + }` + // Corefile with custom stream limits var quicLimitCorefile = `quic://.:0 { tls ../plugin/tls/test_cert.pem ../plugin/tls/test_key.pem ../plugin/tls/test_ca.pem @@ -89,6 +95,29 @@ func TestQUIC(t *testing.T) { } } +func TestQUICReloadDoesNotPanic(t *testing.T) { + inst, udp, _, err := CoreDNSServerAndPorts(quicCorefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + t.Cleanup(func() { inst.Stop() }) + + assertQUICQuerySucceeds(t, udp) + + restart, err := inst.Restart(NewInput(quicReloadCorefile)) + if err != nil { + t.Fatalf("Failed to restart CoreDNS: %s", err) + } + t.Cleanup(func() { restart.Stop() }) + + udpReload, _ := CoreDNSServerPorts(restart, 0) + if udpReload == "" { + t.Fatal("Failed to determine QUIC listener address after reload") + } + + assertQUICQuerySucceeds(t, udpReload) +} + func TestQUICProtocolError(t *testing.T) { q, udp, _, err := CoreDNSServerAndPorts(quicCorefile) if err != nil { @@ -352,3 +381,50 @@ func createInvalidDOQMsg() []byte { msg, _ := m.Pack() return msg } + +func assertQUICQuerySucceeds(t *testing.T, address string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn, err := quic.DialAddr(ctx, convertAddress(address), generateTLSConfig(), nil) + if err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + defer func() { _ = conn.CloseWithError(0, "") }() + + stream, err := conn.OpenStreamSync(ctx) + if err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + defer func() { _ = stream.Close() }() + + msg := createTestMsg() + if _, err = stream.Write(msg); err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + + sizeBuf := make([]byte, 2) + if _, err = io.ReadFull(stream, sizeBuf); err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + + size := binary.BigEndian.Uint16(sizeBuf) + buf := make([]byte, size) + if _, err = io.ReadFull(stream, buf); err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + + resp := new(dns.Msg) + if err = resp.Unpack(buf); err != nil { + t.Fatalf("Expected no error but got: %s", err) + } + + if resp.Rcode != dns.RcodeSuccess { + t.Fatalf("Expected success but got %d", resp.Rcode) + } + + if len(resp.Extra) != 2 { + t.Fatalf("Expected 2 RRs in additional section, but got %d", len(resp.Extra)) + } +}