mirror of
https://github.com/coredns/coredns.git
synced 2026-04-05 11:45:33 -04:00
fix(transfer): batch AXFR records by message size instead of count (#8002)
This commit is contained in:
@@ -388,3 +388,89 @@ func TestLongestMatchNilWhenNoMatch(t *testing.T) {
|
||||
t.Fatalf("expected nil when no zones match, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// largeRecordTransferer produces records that are large enough to exceed 64KB
|
||||
// if batched by a fixed record count.
|
||||
type largeRecordTransferer struct {
|
||||
Zone string
|
||||
Count int
|
||||
TxtSize int
|
||||
}
|
||||
|
||||
func (lr *largeRecordTransferer) Name() string { return "largerecordtransferer" }
|
||||
func (lr *largeRecordTransferer) ServeDNS(_ context.Context, _ dns.ResponseWriter, _ *dns.Msg) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (lr *largeRecordTransferer) Transfer(zone string, _ uint32) (<-chan []dns.RR, error) {
|
||||
if zone != lr.Zone {
|
||||
return nil, ErrNotAuthoritative
|
||||
}
|
||||
ch := make(chan []dns.RR, 2)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
soa := test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s 1 7200 1800 86400 100", lr.Zone, lr.Zone, lr.Zone))
|
||||
ch <- []dns.RR{soa}
|
||||
payload := make([]byte, lr.TxtSize)
|
||||
for i := range payload {
|
||||
payload[i] = 'x'
|
||||
}
|
||||
txt := string(payload)
|
||||
for i := range lr.Count {
|
||||
rr, _ := dns.NewRR(fmt.Sprintf("r%d.%s 3600 IN TXT \"%s\"", i, lr.Zone, txt))
|
||||
ch <- []dns.RR{rr}
|
||||
}
|
||||
ch <- []dns.RR{soa}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func TestTransferLargeRecordBatching(t *testing.T) {
|
||||
// 300 TXT records of ~250 bytes each = ~75KB total, exceeding a single
|
||||
// 64KB TCP message. The transfer plugin must split them into multiple
|
||||
// messages.
|
||||
lr := &largeRecordTransferer{Zone: "example.org.", Count: 300, TxtSize: 240}
|
||||
|
||||
tr := &Transfer{
|
||||
Transferers: []Transferer{lr},
|
||||
xfrs: []*xfr{{Zones: []string{"example.org."}, to: []string{"*"}}},
|
||||
Next: lr,
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true})
|
||||
m := new(dns.Msg)
|
||||
m.SetAxfr("example.org.")
|
||||
|
||||
_, err := tr.ServeDNS(ctx, w, m)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeDNS error: %v", err)
|
||||
}
|
||||
|
||||
if len(w.Msgs) == 0 {
|
||||
t.Fatal("no messages received")
|
||||
}
|
||||
|
||||
// Count total records across all messages.
|
||||
total := 0
|
||||
for _, msg := range w.Msgs {
|
||||
// Each message must fit in a TCP DNS message (65535 bytes).
|
||||
packed, packErr := msg.Pack()
|
||||
if packErr != nil {
|
||||
t.Fatalf("message too large to pack: %v", packErr)
|
||||
}
|
||||
if len(packed) > dns.MaxMsgSize {
|
||||
t.Errorf("message size %d exceeds max %d", len(packed), dns.MaxMsgSize)
|
||||
}
|
||||
total += len(msg.Answer)
|
||||
}
|
||||
|
||||
// 300 TXT + 2 SOA = 302 records
|
||||
if total != 302 {
|
||||
t.Errorf("expected 302 records total, got %d", total)
|
||||
}
|
||||
|
||||
// Must be split into multiple messages.
|
||||
if len(w.Msgs) < 2 {
|
||||
t.Errorf("expected multiple messages for large transfer, got %d", len(w.Msgs))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user