mirror of
https://github.com/coredns/coredns.git
synced 2026-04-05 11:45:33 -04:00
fix(transfer): batch AXFR records by message size instead of count (#8002)
This commit is contained in:
@@ -124,25 +124,33 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
|||||||
|
|
||||||
rrs := []dns.RR{}
|
rrs := []dns.RR{}
|
||||||
l := 0
|
l := 0
|
||||||
|
batchSize := 0
|
||||||
var soa *dns.SOA
|
var soa *dns.SOA
|
||||||
for records := range pchan {
|
for records := range pchan {
|
||||||
if x, ok := records[0].(*dns.SOA); ok && soa == nil {
|
if x, ok := records[0].(*dns.SOA); ok && soa == nil {
|
||||||
soa = x
|
soa = x
|
||||||
}
|
}
|
||||||
rrs = append(rrs, records...)
|
for _, rr := range records {
|
||||||
if len(rrs) > 500 {
|
rrLen := dns.Len(rr)
|
||||||
select {
|
// Flush the batch before it exceeds the 64KB TCP message limit.
|
||||||
case ch <- &dns.Envelope{RR: rrs}:
|
// The 12-byte header and question section are not counted in rrLen,
|
||||||
case err := <-errCh:
|
// so we use a conservative threshold to leave room for framing.
|
||||||
// Client errored; drain pchan to avoid blocking the producer goroutine.
|
if len(rrs) > 0 && batchSize+rrLen > 63000 {
|
||||||
go func() {
|
select {
|
||||||
for range pchan {
|
case ch <- &dns.Envelope{RR: rrs}:
|
||||||
}
|
case err := <-errCh:
|
||||||
}()
|
go func() {
|
||||||
return dns.RcodeServerFailure, err
|
for range pchan {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
|
l += len(rrs)
|
||||||
|
rrs = []dns.RR{}
|
||||||
|
batchSize = 0
|
||||||
}
|
}
|
||||||
l += len(rrs)
|
rrs = append(rrs, rr)
|
||||||
rrs = []dns.RR{}
|
batchSize += rrLen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +174,12 @@ func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(rrs) > 0 {
|
if len(rrs) > 0 {
|
||||||
ch <- &dns.Envelope{RR: rrs}
|
select {
|
||||||
|
case ch <- &dns.Envelope{RR: rrs}:
|
||||||
|
case err := <-errCh:
|
||||||
|
close(ch)
|
||||||
|
return dns.RcodeServerFailure, err
|
||||||
|
}
|
||||||
l += len(rrs)
|
l += len(rrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -388,3 +388,89 @@ func TestLongestMatchNilWhenNoMatch(t *testing.T) {
|
|||||||
t.Fatalf("expected nil when no zones match, got %+v", got)
|
t.Fatalf("expected nil when no zones match, got %+v", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// largeRecordTransferer produces records that are large enough to exceed 64KB
|
||||||
|
// if batched by a fixed record count.
|
||||||
|
type largeRecordTransferer struct {
|
||||||
|
Zone string
|
||||||
|
Count int
|
||||||
|
TxtSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *largeRecordTransferer) Name() string { return "largerecordtransferer" }
|
||||||
|
func (lr *largeRecordTransferer) ServeDNS(_ context.Context, _ dns.ResponseWriter, _ *dns.Msg) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (lr *largeRecordTransferer) Transfer(zone string, _ uint32) (<-chan []dns.RR, error) {
|
||||||
|
if zone != lr.Zone {
|
||||||
|
return nil, ErrNotAuthoritative
|
||||||
|
}
|
||||||
|
ch := make(chan []dns.RR, 2)
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
soa := test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s 1 7200 1800 86400 100", lr.Zone, lr.Zone, lr.Zone))
|
||||||
|
ch <- []dns.RR{soa}
|
||||||
|
payload := make([]byte, lr.TxtSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = 'x'
|
||||||
|
}
|
||||||
|
txt := string(payload)
|
||||||
|
for i := range lr.Count {
|
||||||
|
rr, _ := dns.NewRR(fmt.Sprintf("r%d.%s 3600 IN TXT \"%s\"", i, lr.Zone, txt))
|
||||||
|
ch <- []dns.RR{rr}
|
||||||
|
}
|
||||||
|
ch <- []dns.RR{soa}
|
||||||
|
}()
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransferLargeRecordBatching(t *testing.T) {
|
||||||
|
// 300 TXT records of ~250 bytes each = ~75KB total, exceeding a single
|
||||||
|
// 64KB TCP message. The transfer plugin must split them into multiple
|
||||||
|
// messages.
|
||||||
|
lr := &largeRecordTransferer{Zone: "example.org.", Count: 300, TxtSize: 240}
|
||||||
|
|
||||||
|
tr := &Transfer{
|
||||||
|
Transferers: []Transferer{lr},
|
||||||
|
xfrs: []*xfr{{Zones: []string{"example.org."}, to: []string{"*"}}},
|
||||||
|
Next: lr,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
|
w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true})
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetAxfr("example.org.")
|
||||||
|
|
||||||
|
_, err := tr.ServeDNS(ctx, w, m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ServeDNS error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(w.Msgs) == 0 {
|
||||||
|
t.Fatal("no messages received")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count total records across all messages.
|
||||||
|
total := 0
|
||||||
|
for _, msg := range w.Msgs {
|
||||||
|
// Each message must fit in a TCP DNS message (65535 bytes).
|
||||||
|
packed, packErr := msg.Pack()
|
||||||
|
if packErr != nil {
|
||||||
|
t.Fatalf("message too large to pack: %v", packErr)
|
||||||
|
}
|
||||||
|
if len(packed) > dns.MaxMsgSize {
|
||||||
|
t.Errorf("message size %d exceeds max %d", len(packed), dns.MaxMsgSize)
|
||||||
|
}
|
||||||
|
total += len(msg.Answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 300 TXT + 2 SOA = 302 records
|
||||||
|
if total != 302 {
|
||||||
|
t.Errorf("expected 302 records total, got %d", total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must be split into multiple messages.
|
||||||
|
if len(w.Msgs) < 2 {
|
||||||
|
t.Errorf("expected multiple messages for large transfer, got %d", len(w.Msgs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user