mirror of
				https://github.com/coredns/coredns.git
				synced 2025-11-03 18:53:13 -05:00 
			
		
		
		
	fix(plugin): guard nil lookups across plugins (#7494)
This commit is contained in:
		@@ -64,7 +64,7 @@ func A(ctx context.Context, b ServiceBackend, zone string, state request.Request
 | 
			
		||||
			target := newRecord.Target
 | 
			
		||||
			// Lookup
 | 
			
		||||
			m1, e1 := b.Lookup(ctx, state, target, state.QType())
 | 
			
		||||
			if e1 != nil {
 | 
			
		||||
			if e1 != nil || m1 == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if m1.Truncated {
 | 
			
		||||
@@ -137,7 +137,7 @@ func AAAA(ctx context.Context, b ServiceBackend, zone string, state request.Requ
 | 
			
		||||
			// This means we can not complete the CNAME, try to look else where.
 | 
			
		||||
			target := newRecord.Target
 | 
			
		||||
			m1, e1 := b.Lookup(ctx, state, target, state.QType())
 | 
			
		||||
			if e1 != nil {
 | 
			
		||||
			if e1 != nil || m1 == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if m1.Truncated {
 | 
			
		||||
@@ -219,12 +219,12 @@ func SRV(ctx context.Context, b ServiceBackend, zone string, state request.Reque
 | 
			
		||||
 | 
			
		||||
			if !dns.IsSubDomain(zone, srv.Target) {
 | 
			
		||||
				m1, e1 := b.Lookup(ctx, state, srv.Target, dns.TypeA)
 | 
			
		||||
				if e1 == nil {
 | 
			
		||||
				if e1 == nil && m1 != nil {
 | 
			
		||||
					extra = append(extra, m1.Answer...)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				m1, e1 = b.Lookup(ctx, state, srv.Target, dns.TypeAAAA)
 | 
			
		||||
				if e1 == nil {
 | 
			
		||||
				if e1 == nil && m1 != nil {
 | 
			
		||||
					// If we have seen CNAME's we *assume* that they are already added.
 | 
			
		||||
					for _, a := range m1.Answer {
 | 
			
		||||
						if _, ok := a.(*dns.CNAME); !ok {
 | 
			
		||||
@@ -286,12 +286,12 @@ func MX(ctx context.Context, b ServiceBackend, zone string, state request.Reques
 | 
			
		||||
 | 
			
		||||
			if !dns.IsSubDomain(zone, mx.Mx) {
 | 
			
		||||
				m1, e1 := b.Lookup(ctx, state, mx.Mx, dns.TypeA)
 | 
			
		||||
				if e1 == nil {
 | 
			
		||||
				if e1 == nil && m1 != nil {
 | 
			
		||||
					extra = append(extra, m1.Answer...)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				m1, e1 = b.Lookup(ctx, state, mx.Mx, dns.TypeAAAA)
 | 
			
		||||
				if e1 == nil {
 | 
			
		||||
				if e1 == nil && m1 != nil {
 | 
			
		||||
					// If we have seen CNAME's we *assume* that they are already added.
 | 
			
		||||
					for _, a := range m1.Answer {
 | 
			
		||||
						if _, ok := a.(*dns.CNAME); !ok {
 | 
			
		||||
@@ -390,7 +390,7 @@ func TXT(ctx context.Context, b ServiceBackend, zone string, state request.Reque
 | 
			
		||||
			target := newRecord.Target
 | 
			
		||||
			// Lookup
 | 
			
		||||
			m1, e1 := b.Lookup(ctx, state, target, state.QType())
 | 
			
		||||
			if e1 != nil {
 | 
			
		||||
			if e1 != nil || m1 == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			// Len(m1.Answer) > 0 here is well?
 | 
			
		||||
 
 | 
			
		||||
@@ -77,10 +77,17 @@ func (r *cnameTargetRuleWithReqState) RewriteResponse(res *dns.Msg, rr dns.RR) {
 | 
			
		||||
			if cname.Target == fromTarget {
 | 
			
		||||
				// create upstream request with the new target with the same qtype
 | 
			
		||||
				r.state.Req.Question[0].Name = toTarget
 | 
			
		||||
				// upRes can be nil if the internal query path didn't write a response
 | 
			
		||||
				// (e.g. a plugin returned a success rcode without writing, dropped the query,
 | 
			
		||||
				// or the context was canceled). Guard upRes before dereferencing.
 | 
			
		||||
				upRes, err := r.rule.Upstream.Lookup(r.ctx, r.state, toTarget, r.state.Req.Question[0].Qtype)
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Errorf("Error upstream request %v", err)
 | 
			
		||||
					log.Errorf("upstream lookup failed: %v", err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
				if upRes == nil {
 | 
			
		||||
					log.Errorf("upstream lookup returned nil")
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				var newAnswer []dns.RR
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package rewrite
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
@@ -207,3 +208,58 @@ func doTestCNameTargetTests(t *testing.T, rules []Rule) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// nilUpstream returns a nil message to simulate an upstream failure path.
 | 
			
		||||
type nilUpstream struct{}
 | 
			
		||||
 | 
			
		||||
func (f *nilUpstream) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// errUpstream returns a nil message with an error to simulate an upstream failure path.
 | 
			
		||||
type errUpstream struct{}
 | 
			
		||||
 | 
			
		||||
func (f *errUpstream) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) {
 | 
			
		||||
	return nil, errors.New("upstream failure")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCNAMETargetRewrite_upstreamFailurePaths(t *testing.T) {
 | 
			
		||||
	cases := []struct {
 | 
			
		||||
		name     string
 | 
			
		||||
		upstream UpstreamInt
 | 
			
		||||
	}{
 | 
			
		||||
		{name: "nil message, no error", upstream: &nilUpstream{}},
 | 
			
		||||
		{name: "nil message, with error", upstream: &errUpstream{}},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range cases {
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			rule := cnameTargetRule{
 | 
			
		||||
				rewriteType:     ExactMatch,
 | 
			
		||||
				paramFromTarget: "bad.target.",
 | 
			
		||||
				paramToTarget:   "good.target.",
 | 
			
		||||
				nextAction:      Stop,
 | 
			
		||||
				Upstream:        tc.upstream,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			req := new(dns.Msg)
 | 
			
		||||
			req.SetQuestion("bad.test.", dns.TypeA)
 | 
			
		||||
			state := request.Request{Req: req}
 | 
			
		||||
 | 
			
		||||
			rrState := &cnameTargetRuleWithReqState{rule: rule, state: state, ctx: context.Background()}
 | 
			
		||||
 | 
			
		||||
			res := new(dns.Msg)
 | 
			
		||||
			res.SetReply(req)
 | 
			
		||||
			res.Answer = []dns.RR{&dns.CNAME{Hdr: dns.RR_Header{Name: "bad.test.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 60}, Target: "bad.target."}}
 | 
			
		||||
 | 
			
		||||
			rr := &dns.CNAME{Hdr: dns.RR_Header{Name: "bad.test.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 60}, Target: "bad.target."}
 | 
			
		||||
 | 
			
		||||
			rrState.RewriteResponse(res, rr)
 | 
			
		||||
 | 
			
		||||
			finalTarget := res.Answer[0].(*dns.CNAME).Target
 | 
			
		||||
			if finalTarget != "bad.target." {
 | 
			
		||||
				t.Errorf("Expected answer to be %q, but got %q", "bad.target.", finalTarget)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user