mirror of
				https://github.com/coredns/coredns.git
				synced 2025-10-31 10:13:14 -04:00 
			
		
		
		
	middleware/cache: fix race (#757)
While adding a parallel performance benchmark I stumbled on a race condition (another reason to add performance benchmarks!), so this PR makes sure the msg is created in a race free manor and adds the parallel benchmark.
This commit is contained in:
		
							
								
								
									
										1
									
								
								middleware/cache/cache.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								middleware/cache/cache.go
									
									
									
									
										vendored
									
									
								
							| @@ -113,7 +113,6 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { | ||||
| 		cacheSize.WithLabelValues(Denial).Set(float64(w.ncache.Len())) | ||||
| 	} | ||||
|  | ||||
| 	setMsgTTL(res, uint32(duration.Seconds())) | ||||
| 	if w.prefetch { | ||||
| 		return nil | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										44
									
								
								middleware/cache/cache_test.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										44
									
								
								middleware/cache/cache_test.go
									
									
									
									
										vendored
									
									
								
							| @@ -6,6 +6,8 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.org/x/net/context" | ||||
|  | ||||
| 	"github.com/coredns/coredns/middleware" | ||||
| 	"github.com/coredns/coredns/middleware/pkg/cache" | ||||
| 	"github.com/coredns/coredns/middleware/pkg/response" | ||||
| @@ -205,3 +207,45 @@ func TestCache(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkCacheResponse(b *testing.B) { | ||||
| 	c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxTTL} | ||||
| 	c.pcache = cache.New(c.pcap) | ||||
| 	c.ncache = cache.New(c.ncap) | ||||
| 	c.prefetch = 1 | ||||
| 	c.duration = 1 * time.Second | ||||
| 	c.Next = BackendHandler() | ||||
|  | ||||
| 	ctx := context.TODO() | ||||
|  | ||||
| 	reqs := make([]*dns.Msg, 5) | ||||
| 	for i, q := range []string{"example1", "example2", "a", "b", "ddd"} { | ||||
| 		reqs[i] = new(dns.Msg) | ||||
| 		reqs[i].SetQuestion(q+".example.org.", dns.TypeA) | ||||
| 	} | ||||
|  | ||||
| 	b.RunParallel(func(pb *testing.PB) { | ||||
| 		i := 0 | ||||
| 		for pb.Next() { | ||||
| 			req := reqs[i] | ||||
| 			c.ServeDNS(ctx, &test.ResponseWriter{}, req) | ||||
| 			i++ | ||||
| 			i = i % 5 | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func BackendHandler() middleware.Handler { | ||||
| 	return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { | ||||
| 		m := new(dns.Msg) | ||||
| 		m.SetReply(r) | ||||
| 		m.Response = true | ||||
| 		m.RecursionAvailable = true | ||||
|  | ||||
| 		owner := m.Question[0].Name | ||||
| 		m.Answer = []dns.RR{test.A(owner + " 303 IN A 127.0.0.53")} | ||||
|  | ||||
| 		w.WriteMsg(m) | ||||
| 		return dns.RcodeSuccess, nil | ||||
| 	}) | ||||
| } | ||||
|   | ||||
							
								
								
									
										1
									
								
								middleware/cache/handler.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								middleware/cache/handler.go
									
									
									
									
										vendored
									
									
								
							| @@ -29,6 +29,7 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) | ||||
| 	i, ttl := c.get(now, qname, qtype, do) | ||||
| 	if i != nil && ttl > 0 { | ||||
| 		resp := i.toMsg(r) | ||||
|  | ||||
| 		state.SizeAndDo(resp) | ||||
| 		resp, _ = state.Scrub(resp) | ||||
| 		w.WriteMsg(resp) | ||||
|   | ||||
							
								
								
									
										48
									
								
								middleware/cache/item.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										48
									
								
								middleware/cache/item.go
									
									
									
									
										vendored
									
									
								
							| @@ -63,12 +63,29 @@ func (i *item) toMsg(m *dns.Msg) *dns.Msg { | ||||
| 	m1.Rcode = i.Rcode | ||||
| 	m1.Compress = true | ||||
|  | ||||
| 	m1.Answer = i.Answer | ||||
| 	m1.Ns = i.Ns | ||||
| 	m1.Extra = i.Extra | ||||
| 	m1.Answer = make([]dns.RR, len(i.Answer)) | ||||
| 	m1.Ns = make([]dns.RR, len(i.Ns)) | ||||
| 	m1.Extra = make([]dns.RR, len(i.Extra)) | ||||
|  | ||||
| 	ttl := int(i.origTTL) - int(time.Now().UTC().Sub(i.stored).Seconds()) | ||||
| 	setMsgTTL(m1, uint32(ttl)) | ||||
| 	ttl := uint32(i.ttl(time.Now())) | ||||
| 	if ttl < minTTL { | ||||
| 		ttl = minTTL | ||||
| 	} | ||||
|  | ||||
| 	for j, r := range i.Answer { | ||||
| 		m1.Answer[j] = dns.Copy(r) | ||||
| 		m1.Answer[j].Header().Ttl = ttl | ||||
| 	} | ||||
| 	for j, r := range i.Ns { | ||||
| 		m1.Ns[j] = dns.Copy(r) | ||||
| 		m1.Ns[j].Header().Ttl = ttl | ||||
| 	} | ||||
| 	for j, r := range i.Extra { | ||||
| 		m1.Extra[j] = dns.Copy(r) | ||||
| 		if m1.Extra[j].Header().Rrtype != dns.TypeOPT { | ||||
| 			m1.Extra[j].Header().Ttl = ttl | ||||
| 		} | ||||
| 	} | ||||
| 	return m1 | ||||
| } | ||||
|  | ||||
| @@ -77,27 +94,6 @@ func (i *item) ttl(now time.Time) int { | ||||
| 	return ttl | ||||
| } | ||||
|  | ||||
| // setMsgTTL sets the ttl on all RRs in all sections. If ttl is smaller than minTTL | ||||
| // that value is used. | ||||
| func setMsgTTL(m *dns.Msg, ttl uint32) { | ||||
| 	if ttl < minTTL { | ||||
| 		ttl = minTTL | ||||
| 	} | ||||
|  | ||||
| 	for _, r := range m.Answer { | ||||
| 		r.Header().Ttl = ttl | ||||
| 	} | ||||
| 	for _, r := range m.Ns { | ||||
| 		r.Header().Ttl = ttl | ||||
| 	} | ||||
| 	for _, r := range m.Extra { | ||||
| 		if r.Header().Rrtype == dns.TypeOPT { | ||||
| 			continue | ||||
| 		} | ||||
| 		r.Header().Ttl = ttl | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func minMsgTTL(m *dns.Msg, mt response.Type) time.Duration { | ||||
| 	if mt != response.NoError && mt != response.NameError && mt != response.NoData { | ||||
| 		return 0 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user