diff --git a/test/multisocket_test.go b/test/multisocket_test.go index 36da89a64..0ed7ace18 100644 --- a/test/multisocket_test.go +++ b/test/multisocket_test.go @@ -4,12 +4,21 @@ import ( "fmt" "net" "testing" + "time" "github.com/miekg/dns" ) -// These tests need a fixed port, because :0 selects a random port for each socket, but we need all sockets to be on -// the same port. +// pickPort returns a free TCP port on 127.0.0.1 and closes the probe listener. +func pickPort(t *testing.T) int { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("probe listen failed: %v", err) + } + defer ln.Close() + return ln.Addr().(*net.TCPAddr).Port +} func TestMultisocket(t *testing.T) { tests := []struct { @@ -55,10 +64,11 @@ func TestMultisocket(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s, err := CoreDNSServer(test.corefile) - defer s.Stop() if err != nil { t.Fatalf("Could not get CoreDNS serving instance: %s", err) } + defer s.Stop() + // check number of servers if len(s.Servers()) != test.expectedServers { t.Fatalf("Expected %d servers, got %d", test.expectedServers, len(s.Servers())) @@ -82,50 +92,68 @@ func TestMultisocket(t *testing.T) { } } +// NOTE: restart uses a different port to avoid transient EADDRINUSE / shutdown races +// when TCP/UDP from the previous instance haven’t fully torn down yet. func TestMultisocket_Restart(t *testing.T) { tests := []struct { name string numSocketsBefore int numSocketsAfter int }{ - { - name: "increase", - numSocketsBefore: 1, - numSocketsAfter: 2, - }, - { - name: "decrease", - numSocketsBefore: 2, - numSocketsAfter: 1, - }, - { - name: "no changes", - numSocketsBefore: 2, - numSocketsAfter: 2, - }, + {name: "increase", numSocketsBefore: 1, numSocketsAfter: 2}, + {name: "decrease", numSocketsBefore: 2, numSocketsAfter: 1}, + {name: "no changes", numSocketsBefore: 2, numSocketsAfter: 2}, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - corefile := `.:5058 { + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + port1 := pickPort(t) + port2 := pickPort(t) // restart onto a different free port + coreTmpl := `.:%d { multisocket %d }` - srv, err := CoreDNSServer(fmt.Sprintf(corefile, test.numSocketsBefore)) + + srv, err := CoreDNSServer(fmt.Sprintf(coreTmpl, port1, tc.numSocketsBefore)) if err != nil { - t.Fatalf("Could not get CoreDNS serving instance: %s", err) + t.Fatalf("Could not get CoreDNS serving instance: %v", err) } - if test.numSocketsBefore != len(srv.Servers()) { - t.Fatalf("Expected %d servers, got %d", test.numSocketsBefore, len(srv.Servers())) + if got := len(srv.Servers()); got != tc.numSocketsBefore { + t.Fatalf("Expected %d servers, got %d", tc.numSocketsBefore, got) } - newSrv, err := srv.Restart(NewInput(fmt.Sprintf(corefile, test.numSocketsAfter))) - if err != nil { - t.Fatalf("Could not get CoreDNS serving instance: %s", err) + resultCh := make(chan int, 1) + errCh := make(chan error, 1) + stopCh := make(chan struct{}) + + // Do the restart in a goroutine; return only the server count. + go func() { + newSrv, rerr := srv.Restart(NewInput(fmt.Sprintf(coreTmpl, port2, tc.numSocketsAfter))) + if rerr != nil { + errCh <- rerr + return + } + resultCh <- len(newSrv.Servers()) + <-stopCh + newSrv.Stop() + }() + + select { + case got := <-resultCh: + if got != tc.numSocketsAfter { + close(stopCh) // still stop the new instance + t.Fatalf("Expected %d servers, got %d", tc.numSocketsAfter, got) + } + close(stopCh) // now safe to stop the new instance + case rerr := <-errCh: + // Restart failed; stop the original instance. + srv.Stop() + t.Fatalf("Restart failed: %v", rerr) + case <-time.After(30 * time.Second): + // Timeout; stop the original instance. + srv.Stop() + t.Fatalf("Restart timed out after 30s (ports :%d→:%d, %d→%d sockets)", + port1, port2, tc.numSocketsBefore, tc.numSocketsAfter) } - if test.numSocketsAfter != len(newSrv.Servers()) { - t.Fatalf("Expected %d servers, got %d", test.numSocketsAfter, len(newSrv.Servers())) - } - newSrv.Stop() }) } }