diff --git a/http/cache.go b/http/cache.go index 647c1dc..8ea55de 100644 --- a/http/cache.go +++ b/http/cache.go @@ -10,19 +10,18 @@ import ( type Cache struct { capacity int mu sync.RWMutex - entries map[uint64]Response - keys *list.List + entries map[uint64]*list.Element + values *list.List } func NewCache(capacity int) *Cache { if capacity < 0 { capacity = 0 } - keys := list.New() return &Cache{ capacity: capacity, - entries: make(map[uint64]Response), - keys: keys, + entries: make(map[uint64]*list.Element), + values: list.New(), } } @@ -41,12 +40,17 @@ func (c *Cache) Set(ip net.IP, resp Response) { defer c.mu.Unlock() if len(c.entries) == c.capacity { // At capacity. Remove the oldest entry - oldest := c.keys.Front() - delete(c.entries, oldest.Value.(uint64)) - c.keys.Remove(oldest) + oldest := c.values.Front() + oldestValue := oldest.Value.(Response) + oldestKey := key(oldestValue.IP) + delete(c.entries, oldestKey) + c.values.Remove(oldest) } - c.entries[k] = resp - c.keys.PushBack(k) + current, ok := c.entries[k] + if ok { + c.values.Remove(current) + } + c.entries[k] = c.values.PushBack(resp) } func (c *Cache) Get(ip net.IP) (Response, bool) { @@ -54,5 +58,8 @@ func (c *Cache) Get(ip net.IP) (Response, bool) { c.mu.RLock() defer c.mu.RUnlock() r, ok := c.entries[k] - return r, ok + if !ok { + return Response{}, false + } + return r.Value.(Response), true } diff --git a/http/cache_test.go b/http/cache_test.go index 1020750..2c6d4ea 100644 --- a/http/cache_test.go +++ b/http/cache_test.go @@ -39,3 +39,18 @@ func TestCacheCapacity(t *testing.T) { } } } + +func TestCacheDuplicate(t *testing.T) { + c := NewCache(10) + ip := net.ParseIP("192.0.2.1") + response := Response{IP: ip} + c.Set(ip, response) + c.Set(ip, response) + want := 1 + if got := len(c.entries); got != want { + t.Errorf("want %d entries, got %d", want, got) + } + if got := c.values.Len(); got != want { + t.Errorf("want %d values, got %d", want, got) + } +}