Fix context passing (#2681)

This commit is contained in:
Stefan Budeanu
2019-03-13 14:08:33 -04:00
committed by Miek Gieben
parent 26e4026ec1
commit f798d18bdd
3 changed files with 26 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ import (
) )
func TestDnstapContext(t *testing.T) { func TestDnstapContext(t *testing.T) {
ctx := tapContext{context.TODO(), Dnstap{}} ctx := ContextWithTapper(context.TODO(), Dnstap{})
tapper := TapperFromContext(ctx) tapper := TapperFromContext(ctx)
if tapper == nil { if tapper == nil {

View File

@@ -0,0 +1,23 @@
package dnstap
import "context"
type contextKey struct{}
var dnstapKey = contextKey{}
// ContextWithTapper returns a new `context.Context` that holds a reference to
// `t`'s Tapper.
func ContextWithTapper(ctx context.Context, t Tapper) context.Context {
return context.WithValue(ctx, dnstapKey, t)
}
// TapperFromContext returns the `Tapper` previously associated with `ctx`, or
// `nil` if no such `Tapper` could be found.
func TapperFromContext(ctx context.Context) Tapper {
val := ctx.Value(dnstapKey)
if sp, ok := val.(Tapper); ok {
return sp
}
return nil
}

View File

@@ -44,12 +44,6 @@ const (
DnstapSendOption ContextKey = "dnstap-send-option" DnstapSendOption ContextKey = "dnstap-send-option"
) )
// TapperFromContext will return a Tapper if the dnstap plugin is enabled.
func TapperFromContext(ctx context.Context) (t Tapper) {
t, _ = ctx.(Tapper)
return
}
// TapMessage implements Tapper. // TapMessage implements Tapper.
func (h Dnstap) TapMessage(m *tap.Message) { func (h Dnstap) TapMessage(m *tap.Message) {
t := tap.Dnstap_MESSAGE t := tap.Dnstap_MESSAGE
@@ -71,6 +65,7 @@ func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
// message to be sent out // message to be sent out
sendOption := taprw.SendOption{Cq: true, Cr: true} sendOption := taprw.SendOption{Cq: true, Cr: true}
newCtx := context.WithValue(ctx, DnstapSendOption, &sendOption) newCtx := context.WithValue(ctx, DnstapSendOption, &sendOption)
newCtx = ContextWithTapper(newCtx, h)
rw := &taprw.ResponseWriter{ rw := &taprw.ResponseWriter{
ResponseWriter: w, ResponseWriter: w,
@@ -80,7 +75,7 @@ func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
QueryEpoch: time.Now(), QueryEpoch: time.Now(),
} }
code, err := plugin.NextOrFailure(h.Name(), h.Next, tapContext{newCtx, h}, rw, r) code, err := plugin.NextOrFailure(h.Name(), h.Next, newCtx, rw, r)
if err != nil { if err != nil {
// ignore dnstap errors // ignore dnstap errors
return code, err return code, err