mirror of
https://github.com/coredns/coredns.git
synced 2026-04-04 19:25:40 -04:00
fix(test): deduplicate TSIG test helpers (#8009)
This commit is contained in:
@@ -16,43 +16,6 @@ import (
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type tsigStatusCheckPlugin struct {
|
||||
t *testing.T
|
||||
check func(*testing.T, error)
|
||||
called *bool
|
||||
}
|
||||
|
||||
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
|
||||
|
||||
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
p.t.Helper()
|
||||
if p.called != nil {
|
||||
*p.called = true
|
||||
}
|
||||
p.check(p.t, w.TsigStatus())
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||
}
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
|
||||
t.Helper()
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
|
||||
|
||||
wire, _, err := dns.TsigGenerate(m, secret, "", false)
|
||||
if err != nil {
|
||||
t.Fatalf("dns.TsigGenerate() failed: %v", err)
|
||||
}
|
||||
return wire
|
||||
}
|
||||
|
||||
func TestNewServergRPC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -504,12 +467,12 @@ func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) {
|
||||
const clientSecret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||
const serverSecret = "QUJDREVGR0hJSktMTU5PUA=="
|
||||
|
||||
called := false
|
||||
called := make(chan struct{}, 1)
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{
|
||||
testConfig("grpc", tsigStatusCheckPlugin{
|
||||
t: t,
|
||||
called: &called,
|
||||
called: called,
|
||||
check: func(t *testing.T, got error) {
|
||||
t.Helper()
|
||||
if got == nil {
|
||||
@@ -546,7 +509,9 @@ func TestServergRPC_Query_TSIGBadSigSetsTsigStatus(t *testing.T) {
|
||||
t.Fatalf("Query() failed: %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
select {
|
||||
case <-called:
|
||||
default:
|
||||
t.Fatal("ServeDNS() was not called")
|
||||
}
|
||||
}
|
||||
@@ -555,12 +520,12 @@ func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) {
|
||||
const keyName = "tsig-key."
|
||||
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||
|
||||
called := false
|
||||
called := make(chan struct{}, 1)
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{
|
||||
testConfig("grpc", tsigStatusCheckPlugin{
|
||||
t: t,
|
||||
called: &called,
|
||||
called: called,
|
||||
check: func(t *testing.T, got error) {
|
||||
t.Helper()
|
||||
if !errors.Is(got, dns.ErrTime) {
|
||||
@@ -591,7 +556,9 @@ func TestServergRPC_Query_TSIGBadTimeSetsTsigStatus(t *testing.T) {
|
||||
t.Fatalf("Query() failed: %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
select {
|
||||
case <-called:
|
||||
default:
|
||||
t.Fatal("ServeDNS() was not called")
|
||||
}
|
||||
}
|
||||
@@ -600,12 +567,12 @@ func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) {
|
||||
const keyName = "tsig-key."
|
||||
const secret = "MTIzNDU2Nzg5MDEyMzQ1Ng=="
|
||||
|
||||
called := false
|
||||
called := make(chan struct{}, 1)
|
||||
|
||||
server, err := NewServergRPC("127.0.0.1:0", []*Config{
|
||||
testConfig("grpc", tsigStatusCheckPlugin{
|
||||
t: t,
|
||||
called: &called,
|
||||
called: called,
|
||||
check: func(t *testing.T, got error) {
|
||||
t.Helper()
|
||||
if got != nil {
|
||||
@@ -636,7 +603,9 @@ func TestServergRPC_Query_TSIGValidLeavesTsigStatusNil(t *testing.T) {
|
||||
t.Fatalf("Query() failed: %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
select {
|
||||
case <-called:
|
||||
default:
|
||||
t.Fatal("ServeDNS() was not called")
|
||||
}
|
||||
|
||||
|
||||
@@ -19,44 +19,6 @@ import (
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type tsigStatusCheckPlugin struct {
|
||||
t *testing.T
|
||||
check func(*testing.T, error)
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
|
||||
|
||||
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
p.t.Helper()
|
||||
if p.called != nil {
|
||||
p.called <- struct{}{}
|
||||
}
|
||||
p.check(p.t, w.TsigStatus())
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||
}
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
|
||||
t.Helper()
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Id = 0
|
||||
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
|
||||
|
||||
wire, _, err := dns.TsigGenerate(m, secret, "", false)
|
||||
if err != nil {
|
||||
t.Fatalf("dns.TsigGenerate() failed: %v", err)
|
||||
}
|
||||
return wire
|
||||
}
|
||||
|
||||
func TestNewServerQUIC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
46
core/dnsserver/tsig_test.go
Normal file
46
core/dnsserver/tsig_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type tsigStatusCheckPlugin struct {
|
||||
t *testing.T
|
||||
check func(*testing.T, error)
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (p tsigStatusCheckPlugin) Name() string { return "tsig-status-check" }
|
||||
|
||||
func (p tsigStatusCheckPlugin) ServeDNS(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
p.t.Helper()
|
||||
if p.called != nil {
|
||||
p.called <- struct{}{}
|
||||
}
|
||||
p.check(p.t, w.TsigStatus())
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
p.t.Fatalf("WriteMsg() failed: %v", err)
|
||||
}
|
||||
return dns.RcodeSuccess, nil
|
||||
}
|
||||
|
||||
func mustPackSignedTSIGQuery(t *testing.T, keyName, secret string, tsigTime int64) []byte {
|
||||
t.Helper()
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
m.Id = 0
|
||||
m.SetTsig(keyName, dns.HmacSHA256, 300, tsigTime)
|
||||
|
||||
wire, _, err := dns.TsigGenerate(m, secret, "", false)
|
||||
if err != nil {
|
||||
t.Fatalf("dns.TsigGenerate() failed: %v", err)
|
||||
}
|
||||
return wire
|
||||
}
|
||||
Reference in New Issue
Block a user