From ce9da6fa41ec207bbbbda43385126097f843eff5 Mon Sep 17 00:00:00 2001 From: Ville Vesilehto Date: Sat, 4 Apr 2026 20:37:59 +0300 Subject: [PATCH] fix(test): deduplicate TSIG test helpers (#8009) --- core/dnsserver/server_grpc_test.go | 61 ++++++++---------------------- core/dnsserver/server_quic_test.go | 38 ------------------- core/dnsserver/tsig_test.go | 46 ++++++++++++++++++++++ 3 files changed, 61 insertions(+), 84 deletions(-) create mode 100644 core/dnsserver/tsig_test.go diff --git a/core/dnsserver/server_grpc_test.go b/core/dnsserver/server_grpc_test.go index 9f9e99cb6..3a74b1677 100644 --- a/core/dnsserver/server_grpc_test.go +++ b/core/dnsserver/server_grpc_test.go @@ -16,43 +16,6 @@ import ( "google.golang.org/grpc/peer" ) -type tsigStatusCheckPlugin struct { - t *testing.T - check func(*testing.T, error) - called *bool -} - -func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" } - -func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - p.t.Helper() - if p.called != nil { - *p.called = true - } - p.check(p.t, w.TsigStatus()) - - m := new(dns.Msg) - m.SetReply(r) - if err := w.WriteMsg(m); err != nil { - p.t.Fatalf("WriteMsg() failed: %v", err) - } - return dns.RcodeSuccess, nil -} - -func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte { - t.Helper() - - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime) - - wire, _, err := dns.TsigGenerate(m, secret, "", false) - if err != nil { - t.Fatalf("dns.TsigGenerate() failed: %v", err) - } - return wire -} - func TestNewServergRPC(t *testing.T) { tests := []struct { name string @@ -504,12 +467,12 @@ func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) { const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" const serverSecret = "QUJDREVGR0hJSktMTU5PUA==" - called := false + called := make(chan struct{}, 1) server, err := NewServergRPC("127.0.0.1:0", []*Config{ testConfig("grpc", tsigStatusCheckPlugin{ t: t, - called: &called, + called: called, check: func(t *testing.T, got error) { t.Helper() if got == nil { @@ -546,7 +509,9 @@ func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) { t.Fatalf("Query() failed: %v", err) } - if !called { + select { + case <-called: + default: t.Fatal("ServeDNS() was not called") } } @@ -555,12 +520,12 @@ func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) { const keyName = "tsig-key." const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" - called := false + called := make(chan struct{}, 1) server, err := NewServergRPC("127.0.0.1:0", []*Config{ testConfig("grpc", tsigStatusCheckPlugin{ t: t, - called: &called, + called: called, check: func(t *testing.T, got error) { t.Helper() if !errors.Is(got, dns.ErrTime) { @@ -591,7 +556,9 @@ func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) { t.Fatalf("Query() failed: %v", err) } - if !called { + select { + case <-called: + default: t.Fatal("ServeDNS() was not called") } } @@ -600,12 +567,12 @@ func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) { const keyName = "tsig-key." const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng==" - called := false + called := make(chan struct{}, 1) server, err := NewServergRPC("127.0.0.1:0", []*Config{ testConfig("grpc", tsigStatusCheckPlugin{ t: t, - called: &called, + called: called, check: func(t *testing.T, got error) { t.Helper() if got != nil { @@ -636,7 +603,9 @@ func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) { t.Fatalf("Query() failed: %v", err) } - if !called { + select { + case <-called: + default: t.Fatal("ServeDNS() was not called") } diff --git a/core/dnsserver/server_quic_test.go b/core/dnsserver/server_quic_test.go index f25fa9a2b..28d8931be 100644 --- a/core/dnsserver/server_quic_test.go +++ b/core/dnsserver/server_quic_test.go @@ -19,44 +19,6 @@ import ( "github.com/quic-go/quic-go" ) -type tsigStatusCheckPlugin struct { - t *testing.T - check func(*testing.T, error) - called chan struct{} -} - -func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" } - -func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - p.t.Helper() - if p.called != nil { - p.called <- struct{}{} - } - p.check(p.t, w.TsigStatus()) - - m := new(dns.Msg) - m.SetReply(r) - if err := w.WriteMsg(m); err != nil { - p.t.Fatalf("WriteMsg() failed: %v", err) - } - return dns.RcodeSuccess, nil -} - -func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte { - t.Helper() - - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - m.Id = 0 - m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime) - - wire, _, err := dns.TsigGenerate(m, secret, "", false) - if err != nil { - t.Fatalf("dns.TsigGenerate() failed: %v", err) - } - return wire -} - func TestNewServerQUIC(t *testing.T) { tests := []struct { name string diff --git a/core/dnsserver/tsig_test.go b/core/dnsserver/tsig_test.go new file mode 100644 index 000000000..61cef3e80 --- /dev/null +++ b/core/dnsserver/tsig_test.go @@ -0,0 +1,46 @@ +package dnsserver + +import ( + "context" + "testing" + + "github.com/miekg/dns" +) + +type tsigStatusCheckPlugin struct { + t *testing.T + check func(*testing.T, error) + called chan struct{} +} + +func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" } + +func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + p.t.Helper() + if p.called != nil { + p.called <- struct{}{} + } + p.check(p.t, w.TsigStatus()) + + m := new(dns.Msg) + m.SetReply(r) + if err := w.WriteMsg(m); err != nil { + p.t.Fatalf("WriteMsg() failed: %v", err) + } + return dns.RcodeSuccess, nil +} + +func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte { + t.Helper() + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Id = 0 + m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime) + + wire, _, err := dns.TsigGenerate(m, secret, "", false) + if err != nil { + t.Fatalf("dns.TsigGenerate() failed: %v", err) + } + return wire +}