mirror of
https://github.com/coredns/coredns.git
synced 2025-10-27 16:24:19 -04:00
test(dnsserver): add unit tests for gRPC and QUIC servers (#7319)
Add comprehensive unit test coverage for DNS-over-gRPC and DNS-over-QUIC server implementations: - server_grpc_test.go: Tests gRPC server creation, TLS config, lifecycle methods, Query handling, and response writer - server_quic_test.go: Tests QUIC server creation, custom limits, message validation, DOQ message parsing, and writer interface Tests focus on component-level validation with mocks, complementing existing integration tests without overlap. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
330
core/dnsserver/server_grpc_test.go
Normal file
330
core/dnsserver/server_grpc_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/pb"
|
||||
"github.com/coredns/coredns/plugin/pkg/transport"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
func TestNewServergRPC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
configs []*Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid grpc server",
|
||||
addr: "127.0.0.1:0",
|
||||
configs: []*Config{testConfig("grpc", testPlugin{})},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty configs",
|
||||
addr: "127.0.0.1:0",
|
||||
configs: []*Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server, err := NewServergRPC(tt.addr, tt.configs)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewServergRPC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && server == nil {
|
||||
t.Error("NewServergRPC() returned nil server without error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServergRPCWithTLS(t *testing.T) {
|
||||
config := testConfig("grpc", testPlugin{})
|
||||
config.TLSConfig = &tls.Config{
|
||||
ServerName: "example.com",
|
||||
}
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() with TLS failed: %v", err)
|
||||
}
|
||||
|
||||
if server.tlsConfig == nil {
|
||||
t.Error("Expected TLS config to be set")
|
||||
}
|
||||
|
||||
if len(server.tlsConfig.NextProtos) == 0 || server.tlsConfig.NextProtos[0] != "h2" {
|
||||
t.Error("Expected NextProtos to include h2 for gRPC")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Listen(t *testing.T) {
|
||||
server, err := NewServergRPC(transport.GRPC+"://127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
listener, err := server.Listen()
|
||||
if err != nil {
|
||||
t.Fatalf("Listen() failed: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
if listener == nil {
|
||||
t.Error("Listen() returned nil listener")
|
||||
}
|
||||
|
||||
// Verify it's a TCP listener
|
||||
if _, ok := listener.Addr().(*net.TCPAddr); !ok {
|
||||
t.Errorf("Expected TCP listener, got %T", listener.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Listen_InvalidAddress(t *testing.T) {
|
||||
server, err := NewServergRPC(transport.GRPC+"://invalid:99999", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = server.Listen()
|
||||
if err == nil {
|
||||
t.Error("Listen() should fail with invalid address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_ListenPacket(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
conn, err := server.ListenPacket()
|
||||
if err != nil {
|
||||
t.Errorf("ListenPacket() failed: %v", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Error("ListenPacket() should return nil for gRPC server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_ServePacket(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
err = server.ServePacket(nil)
|
||||
if err != nil {
|
||||
t.Errorf("ServePacket() should not return error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Stop(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test stopping server without grpcServer initialized
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Stop() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test stopping after initializing grpcServer
|
||||
server.grpcServer = grpc.NewServer()
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Stop() with grpcServer failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Shutdown(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test shutdown without grpcServer
|
||||
err = server.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("Shutdown() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test shutdown with grpcServer
|
||||
server.grpcServer = grpc.NewServer()
|
||||
err = server.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("Shutdown() with grpcServer failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_OnStartupComplete(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:53", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
Quiet = true
|
||||
server.OnStartupComplete()
|
||||
|
||||
Quiet = false
|
||||
server.OnStartupComplete()
|
||||
}
|
||||
|
||||
func TestServergRPC_Query(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
packed, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to pack DNS message: %v", err)
|
||||
}
|
||||
|
||||
dnsPacket := &pb.DnsPacket{Msg: packed}
|
||||
|
||||
tcpAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
p := &peer.Peer{Addr: tcpAddr}
|
||||
ctx := peer.NewContext(context.Background(), p)
|
||||
|
||||
server.listenAddr = tcpAddr
|
||||
|
||||
response, err := server.Query(ctx, dnsPacket)
|
||||
if err != nil {
|
||||
t.Errorf("Query() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(response.Msg) == 0 {
|
||||
t.Error("Query() returned empty message")
|
||||
}
|
||||
|
||||
// Verify the response can be unpacked
|
||||
respMsg := new(dns.Msg)
|
||||
err = respMsg.Unpack(response.Msg)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to unpack response message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServergRPC_Query_ErrorCases(t *testing.T) {
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{testConfig("grpc", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServergRPC() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
packet *pb.DnsPacket
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid DNS message",
|
||||
ctx: peer.NewContext(context.Background(), &peer.Peer{Addr: &net.TCPAddr{}}),
|
||||
packet: &pb.DnsPacket{Msg: []byte("invalid")},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no peer in context",
|
||||
ctx: context.Background(),
|
||||
packet: &pb.DnsPacket{Msg: []byte{}},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-TCP peer",
|
||||
ctx: peer.NewContext(context.Background(), &peer.Peer{Addr: &net.UDPAddr{}}),
|
||||
packet: &pb.DnsPacket{Msg: []byte{}},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := server.Query(tt.ctx, tt.packet)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Query() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGRPCResponse(t *testing.T) {
|
||||
localAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:53")
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
|
||||
r := &gRPCresponse{
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
|
||||
if r.LocalAddr() != localAddr {
|
||||
t.Errorf("LocalAddr() = %v, want %v", r.LocalAddr(), localAddr)
|
||||
}
|
||||
|
||||
if r.RemoteAddr() != remoteAddr {
|
||||
t.Errorf("RemoteAddr() = %v, want %v", r.RemoteAddr(), remoteAddr)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
packed, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to pack DNS message: %v", err)
|
||||
}
|
||||
|
||||
n, err := r.Write(packed)
|
||||
if err != nil {
|
||||
t.Errorf("Write() failed: %v", err)
|
||||
}
|
||||
|
||||
if n != len(packed) {
|
||||
t.Errorf("Write() returned %d, want %d", n, len(packed))
|
||||
}
|
||||
|
||||
if r.Msg == nil {
|
||||
t.Error("Write() did not set Msg")
|
||||
}
|
||||
|
||||
newMsg := new(dns.Msg)
|
||||
newMsg.SetQuestion("test.com.", dns.TypeAAAA)
|
||||
|
||||
err = r.WriteMsg(newMsg)
|
||||
if err != nil {
|
||||
t.Errorf("WriteMsg() failed: %v", err)
|
||||
}
|
||||
|
||||
if r.Msg != newMsg {
|
||||
t.Error("WriteMsg() did not set correct message")
|
||||
}
|
||||
if err := r.Close(); err != nil {
|
||||
t.Errorf("Close() returned error: %v", err)
|
||||
}
|
||||
|
||||
if err := r.TsigStatus(); err != nil {
|
||||
t.Errorf("TsigStatus() returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGRPCResponse_WriteInvalidMessage(t *testing.T) {
|
||||
r := &gRPCresponse{}
|
||||
|
||||
_, err := r.Write([]byte("invalid dns message"))
|
||||
if err == nil {
|
||||
t.Error("Write() should return error for invalid DNS message")
|
||||
}
|
||||
}
|
||||
506
core/dnsserver/server_quic_test.go
Normal file
506
core/dnsserver/server_quic_test.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package dnsserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
func TestNewServerQUIC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
configs []*Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid quic server",
|
||||
addr: "127.0.0.1:0",
|
||||
configs: []*Config{testConfig("quic", testPlugin{})},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty configs",
|
||||
addr: "127.0.0.1:0",
|
||||
configs: []*Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server, err := NewServerQUIC(tt.addr, tt.configs)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewServerQUIC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && server == nil {
|
||||
t.Error("NewServerQUIC() returned nil server without error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerQUICWithTLS(t *testing.T) {
|
||||
config := testConfig("quic", testPlugin{})
|
||||
config.TLSConfig = &tls.Config{
|
||||
ServerName: "example.com",
|
||||
}
|
||||
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() with TLS failed: %v", err)
|
||||
}
|
||||
|
||||
if server.tlsConfig == nil {
|
||||
t.Error("Expected TLS config to be set")
|
||||
}
|
||||
|
||||
if len(server.tlsConfig.NextProtos) == 0 || server.tlsConfig.NextProtos[0] != "doq" {
|
||||
t.Error("Expected NextProtos to include doq for QUIC")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerQUICWithCustomLimits(t *testing.T) {
|
||||
config := testConfig("quic", testPlugin{})
|
||||
maxStreams := 100
|
||||
workerPoolSize := 50
|
||||
config.MaxQUICStreams = &maxStreams
|
||||
config.MaxQUICWorkerPoolSize = &workerPoolSize
|
||||
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{config})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() with custom limits failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != maxStreams {
|
||||
t.Errorf("Expected maxStreams = %d, got %d", maxStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
if cap(server.streamProcessPool) != workerPoolSize {
|
||||
t.Errorf("Expected streamProcessPool capacity = %d, got %d", workerPoolSize, cap(server.streamProcessPool))
|
||||
}
|
||||
|
||||
expectedMaxStreams := int64(maxStreams)
|
||||
if server.quicConfig.MaxIncomingStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingStreams)
|
||||
}
|
||||
|
||||
if server.quicConfig.MaxIncomingUniStreams != expectedMaxStreams {
|
||||
t.Errorf("Expected quicConfig.MaxIncomingUniStreams = %d, got %d", expectedMaxStreams, server.quicConfig.MaxIncomingUniStreams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServerQUICDefaults(t *testing.T) {
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{testConfig("quic", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
if server.maxStreams != DefaultMaxQUICStreams {
|
||||
t.Errorf("Expected default maxStreams = %d, got %d", DefaultMaxQUICStreams, server.maxStreams)
|
||||
}
|
||||
|
||||
if cap(server.streamProcessPool) != DefaultQUICStreamWorkers {
|
||||
t.Errorf("Expected default streamProcessPool capacity = %d, got %d", DefaultQUICStreamWorkers, cap(server.streamProcessPool))
|
||||
}
|
||||
|
||||
if !server.quicConfig.Allow0RTT {
|
||||
t.Error("Expected Allow0RTT to be true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerQUIC_ServeAndListen(t *testing.T) {
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{testConfig("quic", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
// Test Serve - should return nil for QUIC (not used)
|
||||
err = server.Serve(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Serve() should return nil for QUIC server, got: %v", err)
|
||||
}
|
||||
|
||||
// Test Listen - should return nil for QUIC (not used)
|
||||
listener, err := server.Listen()
|
||||
if err != nil {
|
||||
t.Errorf("Listen() should return nil error for QUIC server, got: %v", err)
|
||||
}
|
||||
if listener != nil {
|
||||
t.Error("Listen() should return nil listener for QUIC server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerQUIC_OnStartupComplete(t *testing.T) {
|
||||
server, err := NewServerQUIC("127.0.0.1:53", []*Config{testConfig("quic", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
Quiet = true
|
||||
server.OnStartupComplete()
|
||||
|
||||
Quiet = false
|
||||
server.OnStartupComplete()
|
||||
}
|
||||
|
||||
func TestServerQUIC_Stop(t *testing.T) {
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{testConfig("quic", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Stop() without listener should not error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerQUIC_CloseQUICConn(t *testing.T) {
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{testConfig("quic", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
server.closeQUICConn(nil, DoQCodeNoError)
|
||||
}
|
||||
|
||||
func TestServerQUIC_IsExpectedErr(t *testing.T) {
|
||||
server, err := NewServerQUIC("127.0.0.1:0", []*Config{testConfig("quic", testPlugin{})})
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerQUIC() failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "server closed error",
|
||||
err: quic.ErrServerClosed,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "application error code 2",
|
||||
err: &quic.ApplicationError{ErrorCode: 2},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "application error code 1",
|
||||
err: &quic.ApplicationError{ErrorCode: 1},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "idle timeout error",
|
||||
err: &quic.IdleTimeoutError{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("some other error"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := server.isExpectedErr(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isExpectedErr(%v) = %v, want %v", tt.err, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMsg func() *dns.Msg
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
setupMsg: func() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Id = 0
|
||||
return m
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "non-zero message ID",
|
||||
setupMsg: func() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Id = 1234
|
||||
return m
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "with EDNS TCP keepalive",
|
||||
setupMsg: func() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Id = 0
|
||||
opt := &dns.OPT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: ".",
|
||||
Rrtype: dns.TypeOPT,
|
||||
Class: 4096,
|
||||
Ttl: 0,
|
||||
},
|
||||
Option: []dns.EDNS0{
|
||||
&dns.EDNS0_TCP_KEEPALIVE{
|
||||
Code: dns.EDNS0TCPKEEPALIVE,
|
||||
Timeout: 300,
|
||||
},
|
||||
},
|
||||
}
|
||||
m.Extra = append(m.Extra, opt)
|
||||
return m
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "with other EDNS options",
|
||||
setupMsg: func() *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Id = 0
|
||||
opt := &dns.OPT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: ".",
|
||||
Rrtype: dns.TypeOPT,
|
||||
Class: 4096,
|
||||
Ttl: 0,
|
||||
},
|
||||
Option: []dns.EDNS0{
|
||||
&dns.EDNS0_NSID{
|
||||
Code: dns.EDNS0NSID,
|
||||
Nsid: "test",
|
||||
},
|
||||
},
|
||||
}
|
||||
m.Extra = append(m.Extra, opt)
|
||||
return m
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msg := tt.setupMsg()
|
||||
result := validRequest(msg)
|
||||
if result != tt.valid {
|
||||
t.Errorf("validRequest() = %v, want %v", result, tt.valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadDOQMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantMsg []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid message",
|
||||
input: []byte{0x00, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
wantMsg: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero length message",
|
||||
input: []byte{0x00, 0x00},
|
||||
wantMsg: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete length prefix",
|
||||
input: []byte{0x00},
|
||||
wantMsg: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete message",
|
||||
input: []byte{0x00, 0x05, 0x01, 0x02},
|
||||
wantMsg: []byte{0x01, 0x02, 0x00, 0x00, 0x00},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
wantMsg: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := bytes.NewReader(tt.input)
|
||||
msg, err := readDOQMessage(reader)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("readDOQMessage() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(msg, tt.wantMsg) {
|
||||
t.Errorf("readDOQMessage() msg = %v, want %v", msg, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoQWriter(t *testing.T) {
|
||||
mockStream := &mockQUICStream{}
|
||||
localAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:53")
|
||||
remoteAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345")
|
||||
|
||||
writer := &DoQWriter{
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
stream: mockStream,
|
||||
}
|
||||
|
||||
if writer.LocalAddr() != localAddr {
|
||||
t.Errorf("LocalAddr() = %v, want %v", writer.LocalAddr(), localAddr)
|
||||
}
|
||||
|
||||
if writer.RemoteAddr() != remoteAddr {
|
||||
t.Errorf("RemoteAddr() = %v, want %v", writer.RemoteAddr(), remoteAddr)
|
||||
}
|
||||
|
||||
testData := []byte("test message")
|
||||
n, err := writer.Write(testData)
|
||||
if err != nil {
|
||||
t.Errorf("Write() failed: %v", err)
|
||||
}
|
||||
|
||||
expectedLen := len(testData) + 2 // +2 for length prefix
|
||||
if n != expectedLen {
|
||||
t.Errorf("Write() returned %d, want %d", n, expectedLen)
|
||||
}
|
||||
|
||||
// Verify the written data includes length prefix
|
||||
written := mockStream.writtenData
|
||||
if len(written) != expectedLen {
|
||||
t.Errorf("Expected written data length %d, got %d", expectedLen, len(written))
|
||||
}
|
||||
|
||||
// Check length prefix
|
||||
expectedLength := uint16(len(testData))
|
||||
actualLength := binary.BigEndian.Uint16(written[:2])
|
||||
if actualLength != expectedLength {
|
||||
t.Errorf("Expected length prefix %d, got %d", expectedLength, actualLength)
|
||||
}
|
||||
|
||||
// Check message content
|
||||
if !bytes.Equal(written[2:], testData) {
|
||||
t.Errorf("Expected message content %v, got %v", testData, written[2:])
|
||||
}
|
||||
|
||||
// Test WriteMsg method
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
msg.Id = 0
|
||||
|
||||
mockStream.reset()
|
||||
err = writer.WriteMsg(msg)
|
||||
if err != nil {
|
||||
t.Errorf("WriteMsg() failed: %v", err)
|
||||
}
|
||||
|
||||
if !mockStream.closed {
|
||||
t.Error("WriteMsg() should close the stream")
|
||||
}
|
||||
|
||||
if err := writer.TsigStatus(); err != nil {
|
||||
t.Errorf("TsigStatus() returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "empty message",
|
||||
input: []byte{},
|
||||
expected: []byte{0x00, 0x00},
|
||||
},
|
||||
{
|
||||
name: "short message",
|
||||
input: []byte{0x01, 0x02},
|
||||
expected: []byte{0x00, 0x02, 0x01, 0x02},
|
||||
},
|
||||
{
|
||||
name: "longer message",
|
||||
input: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
expected: []byte{0x00, 0x05, 0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AddPrefix(tt.input)
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("AddPrefix() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockQUICStream struct {
|
||||
writtenData []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockQUICStream) Write(data []byte) (int, error) {
|
||||
m.writtenData = append(m.writtenData, data...)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (m *mockQUICStream) Read([]byte) (int, error) { return 0, io.EOF }
|
||||
|
||||
func (m *mockQUICStream) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockQUICStream) reset() {
|
||||
m.writtenData = nil
|
||||
m.closed = false
|
||||
}
|
||||
|
||||
// Minimal implementation of other required methods
|
||||
func (m *mockQUICStream) StreamID() quic.StreamID { return 0 }
|
||||
func (m *mockQUICStream) SetReadDeadline(time.Time) error { return nil }
|
||||
func (m *mockQUICStream) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (m *mockQUICStream) SetDeadline(time.Time) error { return nil }
|
||||
func (m *mockQUICStream) Context() context.Context { return context.Background() }
|
||||
func (m *mockQUICStream) CancelWrite(quic.StreamErrorCode) {}
|
||||
func (m *mockQUICStream) CancelRead(quic.StreamErrorCode) {}
|
||||
Reference in New Issue
Block a user