fix(test): deduplicate TSIG test helpers (#8009)

This commit is contained in:
Ville Vesilehto
2026-04-04 20:37:59 +03:00
committed by GitHub
parent 0e1870d762
commit ce9da6fa41
3 changed files with 61 additions and 84 deletions

View File

@@ -16,43 +16,6 @@ import (
"google.golang.org/grpc/peer"
)
type tsigStatusCheckPlugin struct {
t *testing.T
check func(*testing.T, error)
called *bool
}
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
p.t.Helper()
if p.called != nil {
*p.called = true
}
p.check(p.t, w.TsigStatus())
m := new(dns.Msg)
m.SetReply(r)
if err := w.WriteMsg(m); err != nil {
p.t.Fatalf("WriteMsg() failed: %v", err)
}
return dns.RcodeSuccess, nil
}
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
t.Helper()
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
wire, _, err := dns.TsigGenerate(m, secret, "", false)
if err != nil {
t.Fatalf("dns.TsigGenerate() failed: %v", err)
}
return wire
}
func TestNewServergRPC(t *testing.T) {
tests := []struct {
name string
@@ -504,12 +467,12 @@ func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) {
const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
const serverSecret = "QUJDREVGR0hJSktMTU5PUA=="
called := false
called := make(chan struct{}, 1)
server, err := NewServergRPC("127.0.0.1:0", []*Config{
testConfig("grpc", tsigStatusCheckPlugin{
t: t,
called: &called,
called: called,
check: func(t *testing.T, got error) {
t.Helper()
if got == nil {
@@ -546,7 +509,9 @@ func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) {
t.Fatalf("Query() failed: %v", err)
}
if !called {
select {
case <-called:
default:
t.Fatal("ServeDNS() was not called")
}
}
@@ -555,12 +520,12 @@ func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) {
const keyName = "tsig-key."
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
called := false
called := make(chan struct{}, 1)
server, err := NewServergRPC("127.0.0.1:0", []*Config{
testConfig("grpc", tsigStatusCheckPlugin{
t: t,
called: &called,
called: called,
check: func(t *testing.T, got error) {
t.Helper()
if !errors.Is(got, dns.ErrTime) {
@@ -591,7 +556,9 @@ func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) {
t.Fatalf("Query() failed: %v", err)
}
if !called {
select {
case <-called:
default:
t.Fatal("ServeDNS() was not called")
}
}
@@ -600,12 +567,12 @@ func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) {
const keyName = "tsig-key."
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
called := false
called := make(chan struct{}, 1)
server, err := NewServergRPC("127.0.0.1:0", []*Config{
testConfig("grpc", tsigStatusCheckPlugin{
t: t,
called: &called,
called: called,
check: func(t *testing.T, got error) {
t.Helper()
if got != nil {
@@ -636,7 +603,9 @@ func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) {
t.Fatalf("Query() failed: %v", err)
}
if !called {
select {
case <-called:
default:
t.Fatal("ServeDNS() was not called")
}

View File

@@ -19,44 +19,6 @@ import (
"github.com/quic-go/quic-go"
)
type tsigStatusCheckPlugin struct {
t *testing.T
check func(*testing.T, error)
called chan struct{}
}
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
p.t.Helper()
if p.called != nil {
p.called <- struct{}{}
}
p.check(p.t, w.TsigStatus())
m := new(dns.Msg)
m.SetReply(r)
if err := w.WriteMsg(m); err != nil {
p.t.Fatalf("WriteMsg() failed: %v", err)
}
return dns.RcodeSuccess, nil
}
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
t.Helper()
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Id = 0
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
wire, _, err := dns.TsigGenerate(m, secret, "", false)
if err != nil {
t.Fatalf("dns.TsigGenerate() failed: %v", err)
}
return wire
}
func TestNewServerQUIC(t *testing.T) {
tests := []struct {
name string

View File

@@ -0,0 +1,46 @@
package dnsserver
import (
"context"
"testing"
"github.com/miekg/dns"
)
type tsigStatusCheckPlugin struct {
t *testing.T
check func(*testing.T, error)
called chan struct{}
}
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
p.t.Helper()
if p.called != nil {
p.called <- struct{}{}
}
p.check(p.t, w.TsigStatus())
m := new(dns.Msg)
m.SetReply(r)
if err := w.WriteMsg(m); err != nil {
p.t.Fatalf("WriteMsg() failed: %v", err)
}
return dns.RcodeSuccess, nil
}
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
t.Helper()
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
m.Id = 0
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
wire, _, err := dns.TsigGenerate(m, secret, "", false)
if err != nil {
t.Fatalf("dns.TsigGenerate() failed: %v", err)
}
return wire
}