fix(dnstap): Better error handling (redial & logging) when Dnstap is busy (#7619)

* Fix dnstap redial & improve logging

Signed-off-by: xyang378 <xyang378@bloomberg.net>

* fix CR comments

Signed-off-by: xyang378 <xyang378@bloomberg.net>

* redial at interval

Signed-off-by: xyang378 <xyang378@bloomberg.net>

* CR comments & lint
Signed-off-by: xyang378 <xyang378@bloomberg.net>

CR comment

* fix lint

Signed-off-by: xyang378 <xyang378@bloomberg.net>

---------

Signed-off-by: xyang378 <xyang378@bloomberg.net>
This commit is contained in:
Alicia Y
2025-11-06 21:11:08 +00:00
committed by GitHub
parent 18e70fcde6
commit 59afd4b65e
2 changed files with 198 additions and 57 deletions

View File

@@ -2,6 +2,7 @@ package dnstap
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -13,8 +14,9 @@ const (
tcpWriteBufSize = 1024 * 1024 // there is no good explanation for why this number has this value. tcpWriteBufSize = 1024 * 1024 // there is no good explanation for why this number has this value.
queueSize = 10000 // idem. queueSize = 10000 // idem.
tcpTimeout = 4 * time.Second tcpTimeout = 4 * time.Second
flushTimeout = 1 * time.Second flushTimeout = 1 * time.Second
errorCheckInterval = 10 * time.Second
skipVerify = false // by default, every tls connection is verified to be secure skipVerify = false // by default, every tls connection is verified to be secure
) )
@@ -24,31 +26,41 @@ type tapper interface {
Dnstap(*tap.Dnstap) Dnstap(*tap.Dnstap)
} }
type WarnLogger interface {
Warningf(format string, v ...any)
}
// dio implements the Tapper interface. // dio implements the Tapper interface.
type dio struct { type dio struct {
endpoint string endpoint string
proto string proto string
enc *encoder enc *encoder
queue chan *tap.Dnstap queue chan *tap.Dnstap
dropped uint32 dropped uint32
quit chan struct{} quit chan struct{}
flushTimeout time.Duration flushTimeout time.Duration
tcpTimeout time.Duration tcpTimeout time.Duration
skipVerify bool skipVerify bool
tcpWriteBufSize int tcpWriteBufSize int
logger WarnLogger
errorCheckInterval time.Duration
} }
var errNoOutput = errors.New("dnstap not connected to output socket")
// newIO returns a new and initialized pointer to a dio. // newIO returns a new and initialized pointer to a dio.
func newIO(proto, endpoint string, multipleQueue int, multipleTcpWriteBuf int) *dio { func newIO(proto, endpoint string, multipleQueue int, multipleTcpWriteBuf int) *dio {
return &dio{ return &dio{
endpoint: endpoint, endpoint: endpoint,
proto: proto, proto: proto,
queue: make(chan *tap.Dnstap, multipleQueue*queueSize), queue: make(chan *tap.Dnstap, multipleQueue*queueSize),
quit: make(chan struct{}), quit: make(chan struct{}),
flushTimeout: flushTimeout, flushTimeout: flushTimeout,
tcpTimeout: tcpTimeout, tcpTimeout: tcpTimeout,
skipVerify: skipVerify, skipVerify: skipVerify,
tcpWriteBufSize: multipleTcpWriteBuf * tcpWriteBufSize, tcpWriteBufSize: multipleTcpWriteBuf * tcpWriteBufSize,
logger: log,
errorCheckInterval: errorCheckInterval,
} }
} }
@@ -104,21 +116,21 @@ func (d *dio) close() { close(d.quit) }
func (d *dio) write(payload *tap.Dnstap) error { func (d *dio) write(payload *tap.Dnstap) error {
if d.enc == nil { if d.enc == nil {
atomic.AddUint32(&d.dropped, 1) return errNoOutput
return nil
} }
if err := d.enc.writeMsg(payload); err != nil { if err := d.enc.writeMsg(payload); err != nil {
atomic.AddUint32(&d.dropped, 1)
return err return err
} }
return nil return nil
} }
func (d *dio) serve() { func (d *dio) serve() {
timeout := time.NewTimer(d.flushTimeout) flushTicker := time.NewTicker(d.flushTimeout)
defer timeout.Stop() errorCheckTicker := time.NewTicker(d.errorCheckInterval)
defer flushTicker.Stop()
defer errorCheckTicker.Stop()
for { for {
timeout.Reset(d.flushTimeout)
select { select {
case <-d.quit: case <-d.quit:
if d.enc == nil { if d.enc == nil {
@@ -129,16 +141,22 @@ func (d *dio) serve() {
return return
case payload := <-d.queue: case payload := <-d.queue:
if err := d.write(payload); err != nil { if err := d.write(payload); err != nil {
d.dial() atomic.AddUint32(&d.dropped, 1)
if !errors.Is(err, errNoOutput) {
// Redial immediately if it's not an output connection error
d.dial()
}
} }
case <-timeout.C: case <-flushTicker.C:
if d.enc != nil {
d.enc.flush()
}
case <-errorCheckTicker.C:
if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 { if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 {
log.Warningf("Dropped dnstap messages: %d", dropped) d.logger.Warningf("Dropped dnstap messages: %d\n", dropped)
} }
if d.enc == nil { if d.enc == nil {
d.dial() d.dial()
} else {
d.enc.flush()
} }
} }
} }

View File

@@ -1,6 +1,7 @@
package dnstap package dnstap
import ( import (
"fmt"
"net" "net"
"sync" "sync"
"testing" "testing"
@@ -10,6 +11,7 @@ import (
tap "github.com/dnstap/golang-dnstap" tap "github.com/dnstap/golang-dnstap"
fs "github.com/farsightsec/golang-framestream" fs "github.com/farsightsec/golang-framestream"
"github.com/stretchr/testify/require"
) )
var ( var (
@@ -17,6 +19,16 @@ var (
tmsg = tap.Dnstap{Type: &msgType} tmsg = tap.Dnstap{Type: &msgType}
) )
type MockLogger struct {
WarnCount int
WarnLog string
}
func (l *MockLogger) Warningf(format string, v ...any) {
l.WarnCount++
l.WarnLog += fmt.Sprintf(format, v...)
}
func accept(t *testing.T, l net.Listener, count int) { func accept(t *testing.T, l net.Listener, count int) {
t.Helper() t.Helper()
server, err := l.Accept() server, err := l.Accept()
@@ -64,6 +76,7 @@ func TestTransport(t *testing.T) {
dio := newIO(param[0], l.Addr().String(), 1, 1) dio := newIO(param[0], l.Addr().String(), 1, 1)
dio.tcpTimeout = 10 * time.Millisecond dio.tcpTimeout = 10 * time.Millisecond
dio.flushTimeout = 30 * time.Millisecond dio.flushTimeout = 30 * time.Millisecond
dio.errorCheckInterval = 50 * time.Millisecond
dio.connect() dio.connect()
dio.Dnstap(&tmsg) dio.Dnstap(&tmsg)
@@ -93,6 +106,7 @@ func TestRace(t *testing.T) {
dio := newIO("tcp", l.Addr().String(), 1, 1) dio := newIO("tcp", l.Addr().String(), 1, 1)
dio.tcpTimeout = 10 * time.Millisecond dio.tcpTimeout = 10 * time.Millisecond
dio.flushTimeout = 30 * time.Millisecond dio.flushTimeout = 30 * time.Millisecond
dio.errorCheckInterval = 50 * time.Millisecond
dio.connect() dio.connect()
defer dio.close() defer dio.close()
@@ -108,12 +122,131 @@ func TestRace(t *testing.T) {
} }
func TestReconnect(t *testing.T) { func TestReconnect(t *testing.T) {
count := 5 t.Run("ConnectedOnStart", func(t *testing.T) {
// GIVEN
// TCP connection available before DnsTap start up
// DnsTap successfully established output connection on start up
l, err := reuseport.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
l, err := reuseport.Listen("tcp", ":0") var wg sync.WaitGroup
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
addr := l.Addr().String()
logger := MockLogger{}
dio := newIO("tcp", addr, 1, 1)
dio.tcpTimeout = 10 * time.Millisecond
dio.flushTimeout = 30 * time.Millisecond
dio.errorCheckInterval = 50 * time.Millisecond
dio.logger = &logger
dio.connect()
defer dio.close()
// WHEN
// TCP connection closed when DnsTap is still running
// TCP listener starts again on the same port
// DnsTap send multiple messages
dio.Dnstap(&tmsg)
wg.Wait()
// Close listener
l.Close()
// And start TCP listener again on the same port
l, err = reuseport.Listen("tcp", addr)
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
defer l.Close()
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
messageCount := 5
for range messageCount {
time.Sleep(100 * time.Millisecond)
dio.Dnstap(&tmsg)
}
wg.Wait()
// THEN
// DnsTap is able to reconnect
// Messages can be sent eventually
require.NotNil(t, dio.enc)
require.Equal(t, 0, len(dio.queue))
require.Less(t, logger.WarnCount, messageCount)
})
t.Run("NotConnectedOnStart", func(t *testing.T) {
// GIVEN
// No TCP connection established at DnsTap start up
l, err := reuseport.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
l.Close()
logger := MockLogger{}
addr := l.Addr().String()
dio := newIO("tcp", addr, 1, 1)
dio.tcpTimeout = 10 * time.Millisecond
dio.flushTimeout = 30 * time.Millisecond
dio.errorCheckInterval = 50 * time.Millisecond
dio.logger = &logger
dio.connect()
defer dio.close()
// WHEN
// DnsTap is already running
// TCP listener starts on DnsTap's configured port
// DnsTap send multiple messages
dio.Dnstap(&tmsg)
l, err = reuseport.Listen("tcp", addr)
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
defer l.Close()
var wg sync.WaitGroup
wg.Add(1)
messageCount := 5
go func() {
accept(t, l, messageCount)
wg.Done()
}()
for range messageCount {
time.Sleep(100 * time.Millisecond)
dio.Dnstap(&tmsg)
}
wg.Wait()
// THEN
// DnsTap is able to reconnect
// Messages can be sent eventually
require.NotNil(t, dio.enc)
require.Equal(t, 0, len(dio.queue))
require.Less(t, logger.WarnCount, messageCount)
})
}
func TestFullQueueWriteFail(t *testing.T) {
// GIVEN
// DnsTap I/O with a small queue
l, err := reuseport.Listen("unix", "dn2stap.sock")
if err != nil { if err != nil {
t.Fatalf("Cannot start listener: %s", err) t.Fatalf("Cannot start listener: %s", err)
} }
defer l.Close()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@@ -122,35 +255,25 @@ func TestReconnect(t *testing.T) {
wg.Done() wg.Done()
}() }()
addr := l.Addr().String() logger := MockLogger{}
dio := newIO("tcp", addr, 1, 1) dio := newIO("unix", l.Addr().String(), 1, 1)
dio.tcpTimeout = 10 * time.Millisecond dio.flushTimeout = 500 * time.Millisecond
dio.flushTimeout = 30 * time.Millisecond dio.errorCheckInterval = 50 * time.Millisecond
dio.logger = &logger
dio.queue = make(chan *tap.Dnstap, 1)
dio.connect() dio.connect()
defer dio.close() defer dio.close()
dio.Dnstap(&tmsg) // WHEN
// messages overwhelms the queue
wg.Wait() count := 100
// Close listener
l.Close()
// And start TCP listener again on the same port
l, err = reuseport.Listen("tcp", addr)
if err != nil {
t.Fatalf("Cannot start listener: %s", err)
}
defer l.Close()
wg.Add(1)
go func() {
accept(t, l, 1)
wg.Done()
}()
for range count { for range count {
time.Sleep(100 * time.Millisecond)
dio.Dnstap(&tmsg) dio.Dnstap(&tmsg)
} }
wg.Wait() wg.Wait()
// THEN
// Dropped messages are logged
require.NotEqual(t, 0, logger.WarnCount)
require.Contains(t, logger.WarnLog, "Dropped dnstap messages")
} }