mirror of
https://github.com/coredns/coredns.git
synced 2026-04-05 03:35:33 -04:00
fix(transfer): batch AXFR records by message size instead of count (#8002)
This commit is contained in:
@@ -124,25 +124,33 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||
|
||||
rrs := []dns.RR{}
|
||||
l := 0
|
||||
batchSize := 0
|
||||
var soa *dns.SOA
|
||||
for records := range pchan {
|
||||
if x, ok := records[0].(*dns.SOA); ok && soa == nil {
|
||||
soa = x
|
||||
}
|
||||
rrs = append(rrs, records...)
|
||||
if len(rrs) > 500 {
|
||||
select {
|
||||
case ch <- &dns.Envelope{RR: rrs}:
|
||||
case err := <-errCh:
|
||||
// Client errored; drain pchan to avoid blocking the producer goroutine.
|
||||
go func() {
|
||||
for range pchan {
|
||||
}
|
||||
}()
|
||||
return dns.RcodeServerFailure, err
|
||||
for _, rr := range records {
|
||||
rrLen := dns.Len(rr)
|
||||
// Flush the batch before it exceeds the 64KB TCP message limit.
|
||||
// The 12-byte header and question section are not counted in rrLen,
|
||||
// so we use a conservative threshold to leave room for framing.
|
||||
if len(rrs) > 0 && batchSize+rrLen > 63000 {
|
||||
select {
|
||||
case ch <- &dns.Envelope{RR: rrs}:
|
||||
case err := <-errCh:
|
||||
go func() {
|
||||
for range pchan {
|
||||
}
|
||||
}()
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
l += len(rrs)
|
||||
rrs = []dns.RR{}
|
||||
batchSize = 0
|
||||
}
|
||||
l += len(rrs)
|
||||
rrs = []dns.RR{}
|
||||
rrs = append(rrs, rr)
|
||||
batchSize += rrLen
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +174,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
||||
}
|
||||
|
||||
if len(rrs) > 0 {
|
||||
ch <- &dns.Envelope{RR: rrs}
|
||||
select {
|
||||
case ch <- &dns.Envelope{RR: rrs}:
|
||||
case err := <-errCh:
|
||||
close(ch)
|
||||
return dns.RcodeServerFailure, err
|
||||
}
|
||||
l += len(rrs)
|
||||
}
|
||||
|
||||
|
||||
@@ -388,3 +388,89 @@ func TestLongestMatchNilWhenNoMatch(t *testing.T) {
|
||||
t.Fatalf("expected nil when no zones match, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// largeRecordTransferer produces records that are large enough to exceed 64KB
|
||||
// if batched by a fixed record count.
|
||||
type largeRecordTransferer struct {
|
||||
Zone string
|
||||
Count int
|
||||
TxtSize int
|
||||
}
|
||||
|
||||
func (lr *largeRecordTransferer) Name() string { return "largerecordtransferer" }
|
||||
func (lr *largeRecordTransferer) ServeDNS(_ context.Context, _ dns.ResponseWriter, _ *dns.Msg) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (lr *largeRecordTransferer) Transfer(zone string, _ uint32) (<-chan []dns.RR, error) {
|
||||
if zone != lr.Zone {
|
||||
return nil, ErrNotAuthoritative
|
||||
}
|
||||
ch := make(chan []dns.RR, 2)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
soa := test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s 1 7200 1800 86400 100", lr.Zone, lr.Zone, lr.Zone))
|
||||
ch <- []dns.RR{soa}
|
||||
payload := make([]byte, lr.TxtSize)
|
||||
for i := range payload {
|
||||
payload[i] = 'x'
|
||||
}
|
||||
txt := string(payload)
|
||||
for i := range lr.Count {
|
||||
rr, _ := dns.NewRR(fmt.Sprintf("r%d.%s 3600 IN TXT \"%s\"", i, lr.Zone, txt))
|
||||
ch <- []dns.RR{rr}
|
||||
}
|
||||
ch <- []dns.RR{soa}
|
||||
}()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func TestTransferLargeRecordBatching(t *testing.T) {
|
||||
// 300 TXT records of ~250 bytes each = ~75KB total, exceeding a single
|
||||
// 64KB TCP message. The transfer plugin must split them into multiple
|
||||
// messages.
|
||||
lr := &largeRecordTransferer{Zone: "example.org.", Count: 300, TxtSize: 240}
|
||||
|
||||
tr := &Transfer{
|
||||
Transferers: []Transferer{lr},
|
||||
xfrs: []*xfr{{Zones: []string{"example.org."}, to: []string{"*"}}},
|
||||
Next: lr,
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true})
|
||||
m := new(dns.Msg)
|
||||
m.SetAxfr("example.org.")
|
||||
|
||||
_, err := tr.ServeDNS(ctx, w, m)
|
||||
if err != nil {
|
||||
t.Fatalf("ServeDNS error: %v", err)
|
||||
}
|
||||
|
||||
if len(w.Msgs) == 0 {
|
||||
t.Fatal("no messages received")
|
||||
}
|
||||
|
||||
// Count total records across all messages.
|
||||
total := 0
|
||||
for _, msg := range w.Msgs {
|
||||
// Each message must fit in a TCP DNS message (65535 bytes).
|
||||
packed, packErr := msg.Pack()
|
||||
if packErr != nil {
|
||||
t.Fatalf("message too large to pack: %v", packErr)
|
||||
}
|
||||
if len(packed) > dns.MaxMsgSize {
|
||||
t.Errorf("message size %d exceeds max %d", len(packed), dns.MaxMsgSize)
|
||||
}
|
||||
total += len(msg.Answer)
|
||||
}
|
||||
|
||||
// 300 TXT + 2 SOA = 302 records
|
||||
if total != 302 {
|
||||
t.Errorf("expected 302 records total, got %d", total)
|
||||
}
|
||||
|
||||
// Must be split into multiple messages.
|
||||
if len(w.Msgs) < 2 {
|
||||
t.Errorf("expected multiple messages for large transfer, got %d", len(w.Msgs))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user