diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index c88f831a4..3b59c9189 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -24,12 +24,12 @@ type Cache struct { zonesMetricLabel string viewMetricLabel string - ncache *cache.Cache + ncache *cache.Cache[*item] ncap int nttl time.Duration minnttl time.Duration - pcache *cache.Cache + pcache *cache.Cache[*item] pcap int pttl time.Duration minpttl time.Duration @@ -61,11 +61,11 @@ func New() *Cache { return &Cache{ Zones: []string{"."}, pcap: defaultCap, - pcache: cache.New(defaultCap), + pcache: cache.New[*item](defaultCap), pttl: maxTTL, minpttl: minTTL, ncap: defaultCap, - ncache: cache.New(defaultCap), + ncache: cache.New[*item](defaultCap), nttl: maxNTTL, minnttl: minNTTL, failttl: minNTTL, diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go index 25b301b98..4dd0fb3d8 100644 --- a/plugin/cache/cache_test.go +++ b/plugin/cache/cache_test.go @@ -712,8 +712,8 @@ func TestCacheWildcardMetadata(t *testing.T) { } _, k := key(qname, w.Msg, response.NoError, state.Do(), state.Req.CheckingDisabled) i, _ := c.pcache.Get(k) - if i.(*item).wildcard != wildcard { - t.Errorf("expected wildcard response to enter cache with cache item's wildcard = %q, got %q", wildcard, i.(*item).wildcard) + if i.wildcard != wildcard { + t.Errorf("expected wildcard response to enter cache with cache item's wildcard = %q, got %q", wildcard, i.wildcard) } // 2. Test retrieving the cached item from cache and writing its wildcard value to metadata @@ -728,7 +728,7 @@ func TestCacheWildcardMetadata(t *testing.T) { t.Fatal("expected metadata func for wildcard response retrieved from cache, got nil") } if f() != wildcard { - t.Errorf("after retrieving wildcard item from cache, expected \"zone/wildcard\" metadata value to be %q, got %q", wildcard, i.(*item).wildcard) + t.Errorf("after retrieving wildcard item from cache, expected \"zone/wildcard\" metadata value to be %q, got %q", wildcard, i.wildcard) } } diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go index b6815ee0e..4e8edc3d2 100644 --- a/plugin/cache/handler.go +++ b/plugin/cache/handler.go @@ -127,19 +127,17 @@ func (c *Cache) getIfNotStale(now time.Time, state request.Request, server strin cacheRequests.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() if i, ok := c.ncache.Get(k); ok { - itm := i.(*item) - ttl := itm.ttl(now) - if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { + ttl := i.ttl(now) + if i.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel, c.viewMetricLabel).Inc() - return i.(*item) + return i } } if i, ok := c.pcache.Get(k); ok { - itm := i.(*item) - ttl := itm.ttl(now) - if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { + ttl := i.ttl(now) + if i.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel, c.viewMetricLabel).Inc() - return i.(*item) + return i } } cacheMisses.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() @@ -150,10 +148,10 @@ func (c *Cache) getIfNotStale(now time.Time, state request.Request, server strin func (c *Cache) exists(state request.Request) *item { k := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled) if i, ok := c.ncache.Get(k); ok { - return i.(*item) + return i } if i, ok := c.pcache.Get(k); ok { - return i.(*item) + return i } return nil } diff --git a/plugin/cache/setup.go b/plugin/cache/setup.go index f8278b872..9f18c673c 100644 --- a/plugin/cache/setup.go +++ b/plugin/cache/setup.go @@ -253,8 +253,8 @@ func cacheParse(c *caddy.Controller) (*Cache, error) { ca.Zones = origins ca.zonesMetricLabel = strings.Join(origins, ",") - ca.pcache = cache.New(ca.pcap) - ca.ncache = cache.New(ca.ncap) + ca.pcache = cache.New[*item](ca.pcap) + ca.ncache = cache.New[*item](ca.ncap) } return ca, nil diff --git a/plugin/dnssec/cache.go b/plugin/dnssec/cache.go index 9c94efd91..a2483ee2b 100644 --- a/plugin/dnssec/cache.go +++ b/plugin/dnssec/cache.go @@ -23,7 +23,7 @@ func hash(rrs []dns.RR) uint64 { return h.Sum64() } -func periodicClean(c *cache.Cache, stop <-chan struct{}) { +func periodicClean(c *cache.Cache[[]dns.RR], stop <-chan struct{}) { tick := time.NewTicker(8 * time.Hour) defer tick.Stop() for { @@ -32,8 +32,8 @@ func periodicClean(c *cache.Cache, stop <-chan struct{}) { // we sign for 8 days, check if a signature in the cache reached 75% of that (i.e. 6), if found delete // the signature is75 := time.Now().UTC().Add(twoDays) - c.Walk(func(items map[uint64]any, key uint64) bool { - for _, rr := range items[key].([]dns.RR) { + c.Walk(func(items map[uint64][]dns.RR, key uint64) bool { + for _, rr := range items[key] { if !rr.(*dns.RRSIG).ValidityPeriod(is75) { delete(items, key) } diff --git a/plugin/dnssec/cache_test.go b/plugin/dnssec/cache_test.go index 8d5ea8876..64e5bd068 100644 --- a/plugin/dnssec/cache_test.go +++ b/plugin/dnssec/cache_test.go @@ -7,6 +7,8 @@ import ( "github.com/coredns/coredns/plugin/pkg/cache" "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/request" + + "github.com/miekg/dns" ) func TestCacheSet(t *testing.T) { @@ -20,7 +22,7 @@ func TestCacheSet(t *testing.T) { t.Fatalf("Failed to parse key: %v\n", err) } - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) m := testMsg() state := request.Request{Req: m, Zone: "miek.nl."} k := hash(m.Answer) // calculate *before* we add the sig @@ -44,7 +46,7 @@ func TestCacheNotValidExpired(t *testing.T) { t.Fatalf("Failed to parse key: %v\n", err) } - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) m := testMsg() state := request.Request{Req: m, Zone: "miek.nl."} k := hash(m.Answer) // calculate *before* we add the sig @@ -68,7 +70,7 @@ func TestCacheNotValidYet(t *testing.T) { t.Fatalf("Failed to parse key: %v\n", err) } - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) m := testMsg() state := request.Request{Req: m, Zone: "miek.nl."} k := hash(m.Answer) // calculate *before* we add the sig diff --git a/plugin/dnssec/dnssec.go b/plugin/dnssec/dnssec.go index bb2abd052..0032c048c 100644 --- a/plugin/dnssec/dnssec.go +++ b/plugin/dnssec/dnssec.go @@ -22,11 +22,11 @@ type Dnssec struct { keys []*DNSKEY splitkeys bool inflight *singleflight.Group - cache *cache.Cache + cache *cache.Cache[[]dns.RR] } // New returns a new Dnssec. -func New(zones []string, keys []*DNSKEY, splitkeys bool, next plugin.Handler, c *cache.Cache) Dnssec { +func New(zones []string, keys []*DNSKEY, splitkeys bool, next plugin.Handler, c *cache.Cache[[]dns.RR]) Dnssec { return Dnssec{Next: next, zones: zones, keys: keys, @@ -152,7 +152,7 @@ func (d Dnssec) get(key uint64, server string) ([]dns.RR, bool) { if s, ok := d.cache.Get(key); ok { // we sign for 8 days, check if a signature in the cache reached 3/4 of that is75 := time.Now().UTC().Add(twoDays) - for _, rr := range s.([]dns.RR) { + for _, rr := range s { if !rr.(*dns.RRSIG).ValidityPeriod(is75) { cacheMisses.WithLabelValues(server).Inc() return nil, false @@ -160,7 +160,7 @@ func (d Dnssec) get(key uint64, server string) ([]dns.RR, bool) { } cacheHits.WithLabelValues(server).Inc() - return s.([]dns.RR), true + return s, true } cacheMisses.WithLabelValues(server).Inc() return nil, false diff --git a/plugin/dnssec/dnssec_test.go b/plugin/dnssec/dnssec_test.go index f48d9a9ec..786c92608 100644 --- a/plugin/dnssec/dnssec_test.go +++ b/plugin/dnssec/dnssec_test.go @@ -69,7 +69,7 @@ func TestSigningDifferentZone(t *testing.T) { m := testMsgEx() state := request.Request{Req: m, Zone: "example.org."} - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) d := New([]string{"example.org."}, []*DNSKEY{key}, false, nil, c) m = d.Sign(state, time.Now().UTC(), server) if !section(m.Answer, 1) { @@ -250,7 +250,7 @@ func testEmptyMsg() *dns.Msg { func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) { t.Helper() k, rm1, rm2 := newKey(t) - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) d := New(zones, []*DNSKEY{k}, false, nil, c) return d, rm1, rm2 } diff --git a/plugin/dnssec/handler_test.go b/plugin/dnssec/handler_test.go index e82e546d3..04a972a0e 100644 --- a/plugin/dnssec/handler_test.go +++ b/plugin/dnssec/handler_test.go @@ -170,7 +170,7 @@ func TestLookupZone(t *testing.T) { dnskey, rm1, rm2 := newKey(t) defer rm1() defer rm2() - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, fm, c) for _, tc := range dnsTestCases { @@ -193,7 +193,7 @@ func TestLookupDNSKEY(t *testing.T) { dnskey, rm1, rm2 := newKey(t) defer rm1() defer rm2() - c := cache.New(defaultCap) + c := cache.New[[]dns.RR](defaultCap) dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, test.ErrorHandler(), c) for _, tc := range dnssecTestCases { diff --git a/plugin/dnssec/setup.go b/plugin/dnssec/setup.go index e3feae05b..3ab968f8b 100644 --- a/plugin/dnssec/setup.go +++ b/plugin/dnssec/setup.go @@ -12,6 +12,8 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/cache" clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" ) var log = clog.NewWithPlugin("dnssec") @@ -24,7 +26,7 @@ func setup(c *caddy.Controller) error { return plugin.Error("dnssec", err) } - ca := cache.New(capacity) + ca := cache.New[[]dns.RR](capacity) stop := make(chan struct{}) c.OnShutdown(func() error { diff --git a/plugin/pkg/cache/cache.go b/plugin/pkg/cache/cache.go index e063a45c2..fe8dd9437 100644 --- a/plugin/pkg/cache/cache.go +++ b/plugin/pkg/cache/cache.go @@ -16,52 +16,52 @@ func Hash(what []byte) uint64 { } // Cache is cache. -type Cache struct { - shards [shardSize]*shard +type Cache[T any] struct { + shards [shardSize]*shard[T] } // shard is a cache with random eviction. -type shard struct { - items map[uint64]any +type shard[T any] struct { + items map[uint64]T size int sync.RWMutex } // New returns a new cache. -func New(size int) *Cache { +func New[T any](size int) *Cache[T] { ssize := max(size/shardSize, 4) - c := &Cache{} + c := &Cache[T]{} // Initialize all the shards for i := range shardSize { - c.shards[i] = newShard(ssize) + c.shards[i] = newShard[T](ssize) } return c } // Add adds a new element to the cache. If the element already exists it is overwritten. // Returns true if an existing element was evicted to make room for this element. -func (c *Cache) Add(key uint64, el any) bool { +func (c *Cache[T]) Add(key uint64, el T) bool { shard := key & (shardSize - 1) return c.shards[shard].Add(key, el) } // Get looks up element index under key. -func (c *Cache) Get(key uint64) (any, bool) { +func (c *Cache[T]) Get(key uint64) (T, bool) { shard := key & (shardSize - 1) return c.shards[shard].Get(key) } // Remove removes the element indexed with key. -func (c *Cache) Remove(key uint64) { +func (c *Cache[T]) Remove(key uint64) { shard := key & (shardSize - 1) c.shards[shard].Remove(key) } // Len returns the number of elements in the cache. -func (c *Cache) Len() int { +func (c *Cache[T]) Len() int { l := 0 for _, s := range &c.shards { l += s.Len() @@ -70,18 +70,18 @@ func (c *Cache) Len() int { } // Walk walks each shard in the cache. -func (c *Cache) Walk(f func(map[uint64]any, uint64) bool) { +func (c *Cache[T]) Walk(f func(map[uint64]T, uint64) bool) { for _, s := range &c.shards { s.Walk(f) } } // newShard returns a new shard with size. -func newShard(size int) *shard { return &shard{items: make(map[uint64]any), size: size} } +func newShard[T any](size int) *shard[T] { return &shard[T]{items: make(map[uint64]T), size: size} } // Add adds element indexed by key into the cache. Any existing element is overwritten // Returns true if an existing element was evicted to make room for this element. -func (s *shard) Add(key uint64, el any) bool { +func (s *shard[T]) Add(key uint64, el T) bool { eviction := false s.Lock() if len(s.items) >= s.size { @@ -99,14 +99,14 @@ func (s *shard) Add(key uint64, el any) bool { } // Remove removes the element indexed by key from the cache. -func (s *shard) Remove(key uint64) { +func (s *shard[T]) Remove(key uint64) { s.Lock() delete(s.items, key) s.Unlock() } // Evict removes a random element from the cache. -func (s *shard) Evict() { +func (s *shard[T]) Evict() { s.Lock() for k := range s.items { delete(s.items, k) @@ -116,7 +116,7 @@ func (s *shard) Evict() { } // Get looks up the element indexed under key. -func (s *shard) Get(key uint64) (any, bool) { +func (s *shard[T]) Get(key uint64) (T, bool) { s.RLock() el, found := s.items[key] s.RUnlock() @@ -124,7 +124,7 @@ func (s *shard) Get(key uint64) (any, bool) { } // Len returns the current length of the cache. -func (s *shard) Len() int { +func (s *shard[T]) Len() int { s.RLock() l := len(s.items) s.RUnlock() @@ -132,7 +132,7 @@ func (s *shard) Len() int { } // Walk walks the shard for each element the function f is executed while holding a write lock. -func (s *shard) Walk(f func(map[uint64]any, uint64) bool) { +func (s *shard[T]) Walk(f func(map[uint64]T, uint64) bool) { s.RLock() items := make([]uint64, len(s.items)) i := 0 diff --git a/plugin/pkg/cache/cache_test.go b/plugin/pkg/cache/cache_test.go index 53d1152cc..11a3b32c3 100644 --- a/plugin/pkg/cache/cache_test.go +++ b/plugin/pkg/cache/cache_test.go @@ -6,7 +6,7 @@ import ( func TestCacheAddAndGet(t *testing.T) { const N = shardSize * 4 - c := New(N) + c := New[int](N) c.Add(1, 1) if _, found := c.Get(1); !found { @@ -25,7 +25,7 @@ func TestCacheAddAndGet(t *testing.T) { } func TestCacheLen(t *testing.T) { - c := New(4) + c := New[int](4) c.Add(1, 1) if l := c.Len(); l != 1 { @@ -44,7 +44,7 @@ func TestCacheLen(t *testing.T) { } func TestCacheSharding(t *testing.T) { - c := New(shardSize) + c := New[int](shardSize) for i := range shardSize * 2 { c.Add(uint64(i), 1) } @@ -56,15 +56,15 @@ func TestCacheSharding(t *testing.T) { } func TestCacheWalk(t *testing.T) { - c := New(10) + c := New[int](10) exp := make([]int, 10*2) for i := range 10 * 2 { c.Add(uint64(i), 1) exp[i] = 1 } got := make([]int, 10*2) - c.Walk(func(items map[uint64]any, key uint64) bool { - got[key] = items[key].(int) + c.Walk(func(items map[uint64]int, key uint64) bool { + got[key] = items[key] return true }) for i := range exp { @@ -77,7 +77,7 @@ func TestCacheWalk(t *testing.T) { func BenchmarkCache(b *testing.B) { b.ReportAllocs() - c := New(4) + c := New[int](4) for b.Loop() { c.Add(1, 1) c.Get(1) diff --git a/plugin/pkg/cache/shard_test.go b/plugin/pkg/cache/shard_test.go index b9f046027..d0eefa72e 100644 --- a/plugin/pkg/cache/shard_test.go +++ b/plugin/pkg/cache/shard_test.go @@ -6,7 +6,7 @@ import ( ) func TestShardAddAndGet(t *testing.T) { - s := newShard(1) + s := newShard[int](1) s.Add(1, 1) if _, found := s.Get(1); !found { @@ -24,7 +24,7 @@ func TestShardAddAndGet(t *testing.T) { func TestAddEvict(t *testing.T) { const size = 1024 - s := newShard(size) + s := newShard[int](size) for i := range size { s.Add(uint64(i), 1) @@ -38,7 +38,7 @@ func TestAddEvict(t *testing.T) { } func TestShardLen(t *testing.T) { - s := newShard(4) + s := newShard[int](4) s.Add(1, 1) if l := s.Len(); l != 1 { @@ -57,7 +57,7 @@ func TestShardLen(t *testing.T) { } func TestShardEvict(t *testing.T) { - s := newShard(1) + s := newShard[int](1) s.Add(1, 1) s.Add(2, 2) // 1 should be gone @@ -68,7 +68,7 @@ func TestShardEvict(t *testing.T) { } func TestShardLenEvict(t *testing.T) { - s := newShard(4) + s := newShard[int](4) s.Add(1, 1) s.Add(2, 1) s.Add(3, 1) @@ -95,7 +95,7 @@ func TestShardLenEvict(t *testing.T) { } func TestShardEvictParallel(t *testing.T) { - s := newShard(shardSize) + s := newShard[struct{}](shardSize) for i := range shardSize { s.Add(uint64(i), struct{}{}) } @@ -117,7 +117,7 @@ func TestShardEvictParallel(t *testing.T) { } func BenchmarkShard(b *testing.B) { - s := newShard(shardSize) + s := newShard[int](shardSize) b.ResetTimer() for i := range b.N { k := uint64(i) % shardSize * 2 @@ -127,7 +127,7 @@ func BenchmarkShard(b *testing.B) { } func BenchmarkShardParallel(b *testing.B) { - s := newShard(shardSize) + s := newShard[int](shardSize) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for i := uint64(0); pb.Next(); i++ {