diff --git a/request/edns0_test.go b/request/edns0_test.go new file mode 100644 index 000000000..78addd026 --- /dev/null +++ b/request/edns0_test.go @@ -0,0 +1,50 @@ +package request + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestSupportedOptions(t *testing.T) { + tests := []struct { + name string + options []dns.EDNS0 + expected int + }{ + { + name: "empty options", + options: []dns.EDNS0{}, + expected: 0, + }, + { + name: "all supported options", + options: []dns.EDNS0{ + &dns.EDNS0_NSID{}, + &dns.EDNS0_EXPIRE{}, + &dns.EDNS0_COOKIE{}, + &dns.EDNS0_TCP_KEEPALIVE{}, + &dns.EDNS0_PADDING{}, + }, + expected: 5, + }, + { + name: "mixed supported and unsupported options", + options: []dns.EDNS0{ + &dns.EDNS0_NSID{}, + &dns.EDNS0_LOCAL{Code: 65001}, // unsupported code + &dns.EDNS0_PADDING{}, + }, + expected: 2, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := supportedOptions(tc.options) + if len(result) != tc.expected { + t.Errorf("Expected %d supported options, got %d", tc.expected, len(result)) + } + }) + } +} diff --git a/request/request_test.go b/request/request_test.go index 0a3b1f2d8..b548f1483 100644 --- a/request/request_test.go +++ b/request/request_test.go @@ -32,6 +32,102 @@ func TestRequestRemote(t *testing.T) { } } +// TestRequestLocal tests LocalIP and LocalPort methods +func TestRequestLocal(t *testing.T) { + st := testRequest() + if st.LocalIP() != "127.0.0.1" { + t.Errorf("Wrong LocalIP from request, got %s", st.LocalIP()) + } + p := st.LocalPort() + if p == "" { + t.Errorf("Failed to get LocalPort from request") + } + if p != "53" { + t.Errorf("Wrong LocalPort from request, got %s", p) + } +} + +// TestRequestAddrs tests RemoteAddr and LocalAddr methods +func TestRequestAddrs(t *testing.T) { + st := testRequest() + remote := st.RemoteAddr() + if remote != "10.240.0.1:40212" { + t.Errorf("Wrong RemoteAddr from request, got %s", remote) + } + local := st.LocalAddr() + if local != "127.0.0.1:53" { + t.Errorf("Wrong LocalAddr from request, got %s", local) + } +} + +// TestRequestProto tests Proto and Family methods together +func TestRequestProto(t *testing.T) { + st := testRequest() + proto := st.Proto() + if proto != "udp" { + t.Errorf("Expected proto to be udp, got %s", proto) + } + family := st.Family() + if family != 1 { + t.Errorf("Expected family to be 1 (IPv4), got %d", family) + } +} + +// TestRequestSizeAndDo tests the SizeAndDo method +func TestRequestSizeAndDo(t *testing.T) { + st := testRequest() + m := new(dns.Msg) + + // Test with no OPT in the response + modified := st.SizeAndDo(m) + if !modified { + t.Errorf("Expected SizeAndDo to return true") + } + if m.IsEdns0() == nil { + t.Errorf("Expected OPT record to be added to response") + } + + // Test with existing OPT in the response + m = new(dns.Msg) + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + opt.SetUDPSize(2048) + m.Extra = append(m.Extra, opt) + + modified = st.SizeAndDo(m) + if !modified { + t.Errorf("Expected SizeAndDo to return true") + } + if m.IsEdns0() == nil { + t.Errorf("Expected OPT record to remain in response") + } + if m.IsEdns0().UDPSize() != 4096 { + t.Errorf("Expected UDP size to be updated to 4096, got %d", m.IsEdns0().UDPSize()) + } +} + +// TestRequestNewWithQuestion tests the NewWithQuestion method +func TestRequestNewWithQuestion(t *testing.T) { + st := testRequest() + newReq := st.NewWithQuestion("example.org.", dns.TypeMX) + + if newReq.Name() != "example.org." { + t.Errorf("Expected new request name to be example.org., got %s", newReq.Name()) + } + if newReq.QType() != dns.TypeMX { + t.Errorf("Expected new request type to be MX, got %d", newReq.QType()) + } + + // Original request should be unchanged + if st.Name() != "example.com." { + t.Errorf("Expected original request to be unchanged, got %s", st.Name()) + } + if st.QType() != dns.TypeA { + t.Errorf("Expected original request type to remain A, got %d", st.QType()) + } +} + func TestRequestMalformed(t *testing.T) { m := new(dns.Msg) st := Request{Req: m} diff --git a/request/writer_test.go b/request/writer_test.go new file mode 100644 index 000000000..2b6a918f3 --- /dev/null +++ b/request/writer_test.go @@ -0,0 +1,51 @@ +package request + +import ( + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// mockResponseWriter implements dns.ResponseWriter interface for testing +type mockResponseWriter struct { + test.ResponseWriter + lastMsg *dns.Msg +} + +func (m *mockResponseWriter) WriteMsg(msg *dns.Msg) error { + m.lastMsg = msg + return nil +} + +func TestScrubWriter(t *testing.T) { + req := new(dns.Msg) + req.SetQuestion("example.com.", dns.TypeA) + req.SetEdns0(4096, true) + + mock := &mockResponseWriter{} + sw := NewScrubWriter(req, mock) + + // Create a large response message + resp := new(dns.Msg) + resp.SetReply(req) + + // Add a lot of records to make it large + for i := 1; i < 100; i++ { + resp.Answer = append(resp.Answer, test.A( + fmt.Sprintf("example.com. 10 IN A 10.0.0.%d", i))) + } + + // Write the message through ScrubWriter + err := sw.WriteMsg(resp) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + + // Verify that ScrubWriter called methods properly + if mock.lastMsg == nil { + t.Fatalf("Expected WriteMsg to be called with a message") + } +}