From 61f4145506d9bbdae3d9f7cb56e9ea2e7fb1bb26 Mon Sep 17 00:00:00 2001 From: Umut Polat <52835619+umut-polat@users.noreply.github.com> Date: Sat, 4 Apr 2026 21:35:27 +0300 Subject: [PATCH] fix(transfer): batch AXFR records by message size instead of count (#8002) --- plugin/transfer/transfer.go | 41 +++++++++------ plugin/transfer/transfer_test.go | 86 ++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 14 deletions(-) diff --git a/plugin/transfer/transfer.go b/plugin/transfer/transfer.go index bbd91b5eb..d4d26c982 100644 --- a/plugin/transfer/transfer.go +++ b/plugin/transfer/transfer.go @@ -124,25 +124,33 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms rrs := []dns.RR{} l := 0 + batchSize := 0 var soa *dns.SOA for records := range pchan { if x, ok := records[0].(*dns.SOA); ok && soa == nil { soa = x } - rrs = append(rrs, records...) - if len(rrs) > 500 { - select { - case ch <- &dns.Envelope{RR: rrs}: - case err := <-errCh: - // Client errored; drain pchan to avoid blocking the producer goroutine. - go func() { - for range pchan { - } - }() - return dns.RcodeServerFailure, err + for _, rr := range records { + rrLen := dns.Len(rr) + // Flush the batch before it exceeds the 64KB TCP message limit. + // The 12-byte header and question section are not counted in rrLen, + // so we use a conservative threshold to leave room for framing. + if len(rrs) > 0 && batchSize+rrLen > 63000 { + select { + case ch <- &dns.Envelope{RR: rrs}: + case err := <-errCh: + go func() { + for range pchan { + } + }() + return dns.RcodeServerFailure, err + } + l += len(rrs) + rrs = []dns.RR{} + batchSize = 0 } - l += len(rrs) - rrs = []dns.RR{} + rrs = append(rrs, rr) + batchSize += rrLen } } @@ -166,7 +174,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms } if len(rrs) > 0 { - ch <- &dns.Envelope{RR: rrs} + select { + case ch <- &dns.Envelope{RR: rrs}: + case err := <-errCh: + close(ch) + return dns.RcodeServerFailure, err + } l += len(rrs) } diff --git a/plugin/transfer/transfer_test.go b/plugin/transfer/transfer_test.go index e968b5b4d..10c77b4d9 100644 --- a/plugin/transfer/transfer_test.go +++ b/plugin/transfer/transfer_test.go @@ -388,3 +388,89 @@ func TestLongestMatchNilWhenNoMatch(t *testing.T) { t.Fatalf("expected nil when no zones match, got %+v", got) } } + +// largeRecordTransferer produces records that are large enough to exceed 64KB +// if batched by a fixed record count. +type largeRecordTransferer struct { + Zone string + Count int + TxtSize int +} + +func (lr *largeRecordTransferer) Name() string { return "largerecordtransferer" } +func (lr *largeRecordTransferer) ServeDNS(_ context.Context, _ dns.ResponseWriter, _ *dns.Msg) (int, error) { + return 0, nil +} +func (lr *largeRecordTransferer) Transfer(zone string, _ uint32) (<-chan []dns.RR, error) { + if zone != lr.Zone { + return nil, ErrNotAuthoritative + } + ch := make(chan []dns.RR, 2) + go func() { + defer close(ch) + soa := test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s 1 7200 1800 86400 100", lr.Zone, lr.Zone, lr.Zone)) + ch <- []dns.RR{soa} + payload := make([]byte, lr.TxtSize) + for i := range payload { + payload[i] = 'x' + } + txt := string(payload) + for i := range lr.Count { + rr, _ := dns.NewRR(fmt.Sprintf("r%d.%s 3600 IN TXT \"%s\"", i, lr.Zone, txt)) + ch <- []dns.RR{rr} + } + ch <- []dns.RR{soa} + }() + return ch, nil +} + +func TestTransferLargeRecordBatching(t *testing.T) { + // 300 TXT records of ~250 bytes each = ~75KB total, exceeding a single + // 64KB TCP message. The transfer plugin must split them into multiple + // messages. + lr := &largeRecordTransferer{Zone: "example.org.", Count: 300, TxtSize: 240} + + tr := &Transfer{ + Transferers: []Transferer{lr}, + xfrs: []*xfr{{Zones: []string{"example.org."}, to: []string{"*"}}}, + Next: lr, + } + + ctx := context.TODO() + w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true}) + m := new(dns.Msg) + m.SetAxfr("example.org.") + + _, err := tr.ServeDNS(ctx, w, m) + if err != nil { + t.Fatalf("ServeDNS error: %v", err) + } + + if len(w.Msgs) == 0 { + t.Fatal("no messages received") + } + + // Count total records across all messages. + total := 0 + for _, msg := range w.Msgs { + // Each message must fit in a TCP DNS message (65535 bytes). + packed, packErr := msg.Pack() + if packErr != nil { + t.Fatalf("message too large to pack: %v", packErr) + } + if len(packed) > dns.MaxMsgSize { + t.Errorf("message size %d exceeds max %d", len(packed), dns.MaxMsgSize) + } + total += len(msg.Answer) + } + + // 300 TXT + 2 SOA = 302 records + if total != 302 { + t.Errorf("expected 302 records total, got %d", total) + } + + // Must be split into multiple messages. + if len(w.Msgs) < 2 { + t.Errorf("expected multiple messages for large transfer, got %d", len(w.Msgs)) + } +}