fix(transfer): batch AXFR records by message size instead of count (#8002)

This commit is contained in:
Umut Polat
2026-04-04 21:35:27 +03:00
committed by GitHub
parent 03d0863a45
commit 61f4145506
2 changed files with 113 additions and 14 deletions

View File

@@ -388,3 +388,89 @@ func TestLongestMatchNilWhenNoMatch(t *testing.T) {
t.Fatalf("expected nil when no zones match, got %+v", got)
}
}
// largeRecordTransferer produces records that are large enough to exceed 64KB
// if batched by a fixed record count.
type largeRecordTransferer struct {
Zone string
Count int
TxtSize int
}
func (lr *largeRecordTransferer) Name() string { return "largerecordtransferer" }
func (lr *largeRecordTransferer) ServeDNS(_ context.Context, _ dns.ResponseWriter, _ *dns.Msg) (int, error) {
return 0, nil
}
func (lr *largeRecordTransferer) Transfer(zone string, _ uint32) (<-chan []dns.RR, error) {
if zone != lr.Zone {
return nil, ErrNotAuthoritative
}
ch := make(chan []dns.RR, 2)
go func() {
defer close(ch)
soa := test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s 1 7200 1800 86400 100", lr.Zone, lr.Zone, lr.Zone))
ch <- []dns.RR{soa}
payload := make([]byte, lr.TxtSize)
for i := range payload {
payload[i] = 'x'
}
txt := string(payload)
for i := range lr.Count {
rr, _ := dns.NewRR(fmt.Sprintf("r%d.%s 3600 IN TXT \"%s\"", i, lr.Zone, txt))
ch <- []dns.RR{rr}
}
ch <- []dns.RR{soa}
}()
return ch, nil
}
func TestTransferLargeRecordBatching(t *testing.T) {
// 300 TXT records of ~250 bytes each = ~75KB total, exceeding a single
// 64KB TCP message. The transfer plugin must split them into multiple
// messages.
lr := &largeRecordTransferer{Zone: "example.org.", Count: 300, TxtSize: 240}
tr := &Transfer{
Transferers: []Transferer{lr},
xfrs: []*xfr{{Zones: []string{"example.org."}, to: []string{"*"}}},
Next: lr,
}
ctx := context.TODO()
w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true})
m := new(dns.Msg)
m.SetAxfr("example.org.")
_, err := tr.ServeDNS(ctx, w, m)
if err != nil {
t.Fatalf("ServeDNS error: %v", err)
}
if len(w.Msgs) == 0 {
t.Fatal("no messages received")
}
// Count total records across all messages.
total := 0
for _, msg := range w.Msgs {
// Each message must fit in a TCP DNS message (65535 bytes).
packed, packErr := msg.Pack()
if packErr != nil {
t.Fatalf("message too large to pack: %v", packErr)
}
if len(packed) > dns.MaxMsgSize {
t.Errorf("message size %d exceeds max %d", len(packed), dns.MaxMsgSize)
}
total += len(msg.Answer)
}
// 300 TXT + 2 SOA = 302 records
if total != 302 {
t.Errorf("expected 302 records total, got %d", total)
}
// Must be split into multiple messages.
if len(w.Msgs) < 2 {
t.Errorf("expected multiple messages for large transfer, got %d", len(w.Msgs))
}
}