| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | package test
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import (
 | 
					
						
							|  |  |  |  | 	"fmt"
 | 
					
						
							|  |  |  |  | 	"net"
 | 
					
						
							|  |  |  |  | 	"testing"
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 	"time"
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"github.com/miekg/dns"
 | 
					
						
							|  |  |  |  | )
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | // pickPort returns a free TCP port on 127.0.0.1 and closes the probe listener.
 | 
					
						
							|  |  |  |  | func pickPort(t *testing.T) int {
 | 
					
						
							|  |  |  |  | 	t.Helper()
 | 
					
						
							|  |  |  |  | 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 | 
					
						
							|  |  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  |  | 		t.Fatalf("probe listen failed: %v", err)
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | 	defer ln.Close()
 | 
					
						
							|  |  |  |  | 	return ln.Addr().(*net.TCPAddr).Port
 | 
					
						
							|  |  |  |  | }
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | func TestMultisocket(t *testing.T) {
 | 
					
						
							|  |  |  |  | 	tests := []struct {
 | 
					
						
							|  |  |  |  | 		name            string
 | 
					
						
							|  |  |  |  | 		corefile        string
 | 
					
						
							|  |  |  |  | 		expectedServers int
 | 
					
						
							|  |  |  |  | 		expectedErr     string
 | 
					
						
							|  |  |  |  | 		expectedPort    string
 | 
					
						
							|  |  |  |  | 	}{
 | 
					
						
							|  |  |  |  | 		{
 | 
					
						
							|  |  |  |  | 			name: "no multisocket",
 | 
					
						
							|  |  |  |  | 			corefile: `.:5054 {
 | 
					
						
							|  |  |  |  | 			}`,
 | 
					
						
							|  |  |  |  | 			expectedServers: 1,
 | 
					
						
							|  |  |  |  | 			expectedPort:    "5054",
 | 
					
						
							|  |  |  |  | 		},
 | 
					
						
							|  |  |  |  | 		{
 | 
					
						
							|  |  |  |  | 			name: "multisocket 1",
 | 
					
						
							|  |  |  |  | 			corefile: `.:5055 {
 | 
					
						
							|  |  |  |  | 				multisocket 1
 | 
					
						
							|  |  |  |  | 			}`,
 | 
					
						
							|  |  |  |  | 			expectedServers: 1,
 | 
					
						
							|  |  |  |  | 			expectedPort:    "5055",
 | 
					
						
							|  |  |  |  | 		},
 | 
					
						
							|  |  |  |  | 		{
 | 
					
						
							|  |  |  |  | 			name: "multisocket 2",
 | 
					
						
							|  |  |  |  | 			corefile: `.:5056 {
 | 
					
						
							|  |  |  |  | 				multisocket 2
 | 
					
						
							|  |  |  |  | 			}`,
 | 
					
						
							|  |  |  |  | 			expectedServers: 2,
 | 
					
						
							|  |  |  |  | 			expectedPort:    "5056",
 | 
					
						
							|  |  |  |  | 		},
 | 
					
						
							|  |  |  |  | 		{
 | 
					
						
							|  |  |  |  | 			name: "multisocket 100",
 | 
					
						
							|  |  |  |  | 			corefile: `.:5057 {
 | 
					
						
							|  |  |  |  | 				multisocket 100
 | 
					
						
							|  |  |  |  | 			}`,
 | 
					
						
							|  |  |  |  | 			expectedServers: 100,
 | 
					
						
							|  |  |  |  | 			expectedPort:    "5057",
 | 
					
						
							|  |  |  |  | 		},
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	for _, test := range tests {
 | 
					
						
							|  |  |  |  | 		t.Run(test.name, func(t *testing.T) {
 | 
					
						
							|  |  |  |  | 			s, err := CoreDNSServer(test.corefile)
 | 
					
						
							|  |  |  |  | 			if err != nil {
 | 
					
						
							|  |  |  |  | 				t.Fatalf("Could not get CoreDNS serving instance: %s", err)
 | 
					
						
							|  |  |  |  | 			}
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 			defer s.Stop()
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 			// check number of servers
 | 
					
						
							|  |  |  |  | 			if len(s.Servers()) != test.expectedServers {
 | 
					
						
							|  |  |  |  | 				t.Fatalf("Expected %d servers, got %d", test.expectedServers, len(s.Servers()))
 | 
					
						
							|  |  |  |  | 			}
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 			// check that ports are the same
 | 
					
						
							|  |  |  |  | 			for _, listener := range s.Servers() {
 | 
					
						
							|  |  |  |  | 				if listener.Addr().String() != listener.LocalAddr().String() {
 | 
					
						
							|  |  |  |  | 					t.Fatalf("Expected tcp address %s to be on the same port as udp address %s",
 | 
					
						
							|  |  |  |  | 						listener.LocalAddr().String(), listener.Addr().String())
 | 
					
						
							|  |  |  |  | 				}
 | 
					
						
							|  |  |  |  | 				_, port, err := net.SplitHostPort(listener.Addr().String())
 | 
					
						
							|  |  |  |  | 				if err != nil {
 | 
					
						
							|  |  |  |  | 					t.Fatalf("Could not get port from listener addr: %s", err)
 | 
					
						
							|  |  |  |  | 				}
 | 
					
						
							|  |  |  |  | 				if port != test.expectedPort {
 | 
					
						
							|  |  |  |  | 					t.Fatalf("Expected port %s, got %s", test.expectedPort, port)
 | 
					
						
							|  |  |  |  | 				}
 | 
					
						
							|  |  |  |  | 			}
 | 
					
						
							|  |  |  |  | 		})
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | }
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | // NOTE: restart uses a different port to avoid transient EADDRINUSE / shutdown races
 | 
					
						
							|  |  |  |  | // when TCP/UDP from the previous instance haven’t fully torn down yet.
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | func TestMultisocket_Restart(t *testing.T) {
 | 
					
						
							|  |  |  |  | 	tests := []struct {
 | 
					
						
							|  |  |  |  | 		name             string
 | 
					
						
							|  |  |  |  | 		numSocketsBefore int
 | 
					
						
							|  |  |  |  | 		numSocketsAfter  int
 | 
					
						
							|  |  |  |  | 	}{
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 		{name: "increase", numSocketsBefore: 1, numSocketsAfter: 2},
 | 
					
						
							|  |  |  |  | 		{name: "decrease", numSocketsBefore: 2, numSocketsAfter: 1},
 | 
					
						
							|  |  |  |  | 		{name: "no changes", numSocketsBefore: 2, numSocketsAfter: 2},
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 	}
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 	for _, tc := range tests {
 | 
					
						
							|  |  |  |  | 		t.Run(tc.name, func(t *testing.T) {
 | 
					
						
							|  |  |  |  | 			port1 := pickPort(t)
 | 
					
						
							|  |  |  |  | 			port2 := pickPort(t) // restart onto a different free port
 | 
					
						
							|  |  |  |  | 			coreTmpl := `.:%d {
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 				multisocket %d
 | 
					
						
							|  |  |  |  | 			}`
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 			srv, err := CoreDNSServer(fmt.Sprintf(coreTmpl, port1, tc.numSocketsBefore))
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 			if err != nil {
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 				t.Fatalf("Could not get CoreDNS serving instance: %v", err)
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 			}
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 			if got := len(srv.Servers()); got != tc.numSocketsBefore {
 | 
					
						
							|  |  |  |  | 				t.Fatalf("Expected %d servers, got %d", tc.numSocketsBefore, got)
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 			}
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-11 23:37:09 +05:30
										 |  |  |  | 			resultCh := make(chan int, 1)
 | 
					
						
							|  |  |  |  | 			errCh := make(chan error, 1)
 | 
					
						
							|  |  |  |  | 			stopCh := make(chan struct{})
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 			// Do the restart in a goroutine; return only the server count.
 | 
					
						
							|  |  |  |  | 			go func() {
 | 
					
						
							|  |  |  |  | 				newSrv, rerr := srv.Restart(NewInput(fmt.Sprintf(coreTmpl, port2, tc.numSocketsAfter)))
 | 
					
						
							|  |  |  |  | 				if rerr != nil {
 | 
					
						
							|  |  |  |  | 					errCh <- rerr
 | 
					
						
							|  |  |  |  | 					return
 | 
					
						
							|  |  |  |  | 				}
 | 
					
						
							|  |  |  |  | 				resultCh <- len(newSrv.Servers())
 | 
					
						
							|  |  |  |  | 				<-stopCh
 | 
					
						
							|  |  |  |  | 				newSrv.Stop()
 | 
					
						
							|  |  |  |  | 			}()
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 			select {
 | 
					
						
							|  |  |  |  | 			case got := <-resultCh:
 | 
					
						
							|  |  |  |  | 				if got != tc.numSocketsAfter {
 | 
					
						
							|  |  |  |  | 					close(stopCh) // still stop the new instance
 | 
					
						
							|  |  |  |  | 					t.Fatalf("Expected %d servers, got %d", tc.numSocketsAfter, got)
 | 
					
						
							|  |  |  |  | 				}
 | 
					
						
							|  |  |  |  | 				close(stopCh) // now safe to stop the new instance
 | 
					
						
							|  |  |  |  | 			case rerr := <-errCh:
 | 
					
						
							|  |  |  |  | 				// Restart failed; stop the original instance.
 | 
					
						
							|  |  |  |  | 				srv.Stop()
 | 
					
						
							|  |  |  |  | 				t.Fatalf("Restart failed: %v", rerr)
 | 
					
						
							|  |  |  |  | 			case <-time.After(30 * time.Second):
 | 
					
						
							|  |  |  |  | 				// Timeout; stop the original instance.
 | 
					
						
							|  |  |  |  | 				srv.Stop()
 | 
					
						
							|  |  |  |  | 				t.Fatalf("Restart timed out after 30s (ports :%d→:%d, %d→%d sockets)",
 | 
					
						
							|  |  |  |  | 					port1, port2, tc.numSocketsBefore, tc.numSocketsAfter)
 | 
					
						
							| 
									
										
										
										
											2024-11-13 20:40:25 +03:00
										 |  |  |  | 			}
 | 
					
						
							|  |  |  |  | 		})
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | }
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // Just check that server with multisocket works
 | 
					
						
							|  |  |  |  | func TestMultisocket_WhoAmI(t *testing.T) {
 | 
					
						
							|  |  |  |  | 	corefile := `.:5059 {
 | 
					
						
							|  |  |  |  | 		multisocket
 | 
					
						
							|  |  |  |  | 		whoami
 | 
					
						
							|  |  |  |  | 	}`
 | 
					
						
							|  |  |  |  | 	s, udp, tcp, err := CoreDNSServerAndPorts(corefile)
 | 
					
						
							|  |  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  |  | 		t.Fatalf("Could not get CoreDNS serving instance: %s", err)
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | 	defer s.Stop()
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	m := new(dns.Msg)
 | 
					
						
							|  |  |  |  | 	m.SetQuestion("whoami.example.org.", dns.TypeA)
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// check udp
 | 
					
						
							|  |  |  |  | 	cl := dns.Client{Net: "udp"}
 | 
					
						
							|  |  |  |  | 	udpResp, err := dns.Exchange(m, udp)
 | 
					
						
							|  |  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  |  | 		t.Fatalf("Expected to receive reply, but didn't: %v", err)
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | 	// check tcp
 | 
					
						
							|  |  |  |  | 	cl.Net = "tcp"
 | 
					
						
							|  |  |  |  | 	tcpResp, _, err := cl.Exchange(m, tcp)
 | 
					
						
							|  |  |  |  | 	if err != nil {
 | 
					
						
							|  |  |  |  | 		t.Fatalf("Expected to receive reply, but didn't: %v", err)
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	for _, resp := range []*dns.Msg{udpResp, tcpResp} {
 | 
					
						
							|  |  |  |  | 		if resp.Rcode != dns.RcodeSuccess {
 | 
					
						
							|  |  |  |  | 			t.Fatalf("Expected RcodeSuccess, got %v", resp.Rcode)
 | 
					
						
							|  |  |  |  | 		}
 | 
					
						
							|  |  |  |  | 		if len(resp.Extra) != 2 {
 | 
					
						
							|  |  |  |  | 			t.Errorf("Expected 2 RRs in additional section, got %d", len(resp.Extra))
 | 
					
						
							|  |  |  |  | 		}
 | 
					
						
							|  |  |  |  | 	}
 | 
					
						
							|  |  |  |  | }
 |