diff --git a/http/cache.go b/http/cache.go index 97aac30..6d61b20 100644 --- a/http/cache.go +++ b/http/cache.go @@ -44,13 +44,17 @@ func (c *Cache) Set(ip net.IP, resp Response) { k := key(ip) c.mu.Lock() defer c.mu.Unlock() - if len(c.entries) == c.capacity { - // At capacity. Remove the oldest entry - oldest := c.values.Front() - oldestValue := oldest.Value.(Response) - oldestKey := key(oldestValue.IP) - delete(c.entries, oldestKey) - c.values.Remove(oldest) + minEvictions := len(c.entries) - c.capacity + 1 + if minEvictions > 0 { // At or above capacity. Shrink the cache + evicted := 0 + for el := c.values.Front(); el != nil && evicted < minEvictions; { + value := el.Value.(Response) + delete(c.entries, key(value.IP)) + next := el.Next() + c.values.Remove(el) + el = next + evicted++ + } } current, ok := c.entries[k] if ok { @@ -70,6 +74,16 @@ func (c *Cache) Get(ip net.IP) (Response, bool) { return r.Value.(Response), true } +func (c *Cache) Resize(capacity int) error { + if capacity < 0 { + return fmt.Errorf("invalid capacity: %d\n", capacity) + } + c.mu.Lock() + defer c.mu.Unlock() + c.capacity = capacity + return nil +} + func (c *Cache) Stats() CacheStats { c.mu.RLock() defer c.mu.RUnlock() diff --git a/http/cache_test.go b/http/cache_test.go index 2c6d4ea..0867958 100644 --- a/http/cache_test.go +++ b/http/cache_test.go @@ -54,3 +54,23 @@ func TestCacheDuplicate(t *testing.T) { t.Errorf("want %d values, got %d", want, got) } } + +func TestCacheResize(t *testing.T) { + c := NewCache(10) + for i := 1; i <= 10; i++ { + ip := net.ParseIP(fmt.Sprintf("192.0.2.%d", i)) + r := Response{IP: ip} + c.Set(ip, r) + } + if got, want := len(c.entries), 10; got != want { + t.Errorf("want %d entries, got %d", want, got) + } + if err := c.Resize(5); err != nil { + t.Fatal(err) + } + r := Response{IP: net.ParseIP("192.0.2.42")} + c.Set(r.IP, r) + if got, want := len(c.entries), 5; got != want { + t.Errorf("want %d entries, got %d", want, got) + } +} diff --git a/http/http.go b/http/http.go index 940684f..29295b5 100644 --- a/http/http.go +++ b/http/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "html/template" + "io/ioutil" "path/filepath" "strings" @@ -271,6 +272,30 @@ func (s *Server) PortHandler(w http.ResponseWriter, r *http.Request) *appError { return nil } +func (s *Server) cacheResizeHandler(w http.ResponseWriter, r *http.Request) *appError { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return badRequest(err).WithMessage(err.Error()).AsJSON() + } + capacity, err := strconv.Atoi(string(body)) + if err != nil { + return badRequest(err).WithMessage(err.Error()).AsJSON() + } + if err := s.cache.Resize(capacity); err != nil { + return badRequest(err).WithMessage(err.Error()).AsJSON() + } + data := struct { + Message string `json:"message"` + }{fmt.Sprintf("Changed cache capacity to %d.", capacity)} + b, err := json.Marshal(data) + if err != nil { + return internalServerError(err).AsJSON() + } + w.Header().Set("Content-Type", jsonMediaType) + w.Write(b) + return nil +} + func (s *Server) cacheHandler(w http.ResponseWriter, r *http.Request) *appError { cacheStats := s.cache.Stats() var data = struct { @@ -409,6 +434,7 @@ func (s *Server) Handler() http.Handler { // Profiling if s.profile { + r.Route("POST", "/debug/cache/resize", s.cacheResizeHandler) r.Route("GET", "/debug/cache/", s.cacheHandler) r.Route("GET", "/debug/pprof/cmdline", wrapHandlerFunc(pprof.Cmdline)) r.Route("GET", "/debug/pprof/profile", wrapHandlerFunc(pprof.Profile)) diff --git a/http/http_test.go b/http/http_test.go index f081510..b7ee568 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/mpolden/echoip/iputil/geo" @@ -56,6 +57,23 @@ func httpGet(url string, acceptMediaType string, userAgent string) (string, int, return string(data), res.StatusCode, nil } +func httpPost(url, body string) (*http.Response, string, error) { + r, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + if err != nil { + return nil, "", err + } + res, err := http.DefaultClient.Do(r) + if err != nil { + return nil, "", err + } + defer res.Body.Close() + data, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, "", err + } + return res, string(data), nil +} + func TestCLIHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) s := httptest.NewServer(testServer().Handler()) @@ -175,6 +193,21 @@ func TestCacheHandler(t *testing.T) { } } +func TestCacheResizeHandler(t *testing.T) { + log.SetOutput(ioutil.Discard) + srv := testServer() + srv.profile = true + s := httptest.NewServer(srv.Handler()) + _, got, err := httpPost(s.URL+"/debug/cache/resize", "10") + if err != nil { + t.Fatal(err) + } + want := `{"message":"Changed cache capacity to 10."}` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + func TestIPFromRequest(t *testing.T) { var tests = []struct { remoteAddr string