From 78116f69adf1dde8561f1acd41d1f1ed33d32b76 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Wed, 25 Dec 2019 21:04:26 +0100 Subject: [PATCH] Implement response cache --- Makefile | 2 +- cmd/echoip/main.go | 5 +++-- http/cache.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++ http/cache_test.go | 41 +++++++++++++++++++++++++++++++++++ http/http.go | 15 +++++++++---- http/http_test.go | 2 +- 6 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 http/cache.go create mode 100644 http/cache_test.go diff --git a/Makefile b/Makefile index c80b973..b69bfe1 100644 --- a/Makefile +++ b/Makefile @@ -52,4 +52,4 @@ heroku-run: geoip-download ifndef PORT $(error PORT must be set) endif - echoip -f data/country.mmdb -c data/city.mmdb -a data/asn.mmdb -p -r -H CF-Connecting-IP -H X-Forwarded-For -l :$(PORT) + echoip -C 1000000 -f data/country.mmdb -c data/city.mmdb -a data/asn.mmdb -p -r -H CF-Connecting-IP -H X-Forwarded-For -l :$(PORT) diff --git a/cmd/echoip/main.go b/cmd/echoip/main.go index 9f18681..7e6d6cf 100644 --- a/cmd/echoip/main.go +++ b/cmd/echoip/main.go @@ -22,6 +22,7 @@ func main() { PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"` Template string `short:"t" long:"template" description:"Path to template" default:"index.html" value-name:"FILE"` IPHeaders []string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)" value-name:"NAME"` + CacheCapacity int `short:"C" long:"cache-size" description:"Size of response cache. Set to 0 to disable" value-name:"SIZE"` } _, err := flags.ParseArgs(&opts, os.Args) if err != nil { @@ -33,8 +34,8 @@ func main() { if err != nil { log.Fatal(err) } - - server := http.New(r) + cache := http.NewCache(opts.CacheCapacity) + server := http.New(r, cache) server.IPHeaders = opts.IPHeaders if _, err := os.Stat(opts.Template); err == nil { server.Template = opts.Template diff --git a/http/cache.go b/http/cache.go new file mode 100644 index 0000000..6b68dfb --- /dev/null +++ b/http/cache.go @@ -0,0 +1,54 @@ +package http + +import ( + "hash/fnv" + "net" + "sync" +) + +type Cache struct { + capacity int + mu sync.RWMutex + entries map[uint64]*Response + keys []uint64 +} + +func NewCache(capacity int) *Cache { + if capacity < 0 { + capacity = 0 + } + return &Cache{ + capacity: capacity, + entries: make(map[uint64]*Response), + keys: make([]uint64, 0, capacity), + } +} + +func key(ip net.IP) uint64 { + h := fnv.New64a() + h.Write(ip) + return h.Sum64() +} + +func (c *Cache) Set(ip net.IP, resp *Response) { + if c.capacity == 0 { + return + } + k := key(ip) + c.mu.Lock() + defer c.mu.Unlock() + if len(c.entries) == c.capacity && c.capacity > 0 { + delete(c.entries, c.keys[0]) + c.keys = c.keys[1:] + } + c.entries[k] = resp + c.keys = append(c.keys, k) +} + +func (c *Cache) Get(ip net.IP) (*Response, bool) { + k := key(ip) + c.mu.RLock() + defer c.mu.RUnlock() + r, ok := c.entries[k] + return r, ok +} diff --git a/http/cache_test.go b/http/cache_test.go new file mode 100644 index 0000000..c16c068 --- /dev/null +++ b/http/cache_test.go @@ -0,0 +1,41 @@ +package http + +import ( + "fmt" + "net" + "testing" +) + +func TestCacheCapacity(t *testing.T) { + var tests = []struct { + addCount, capacity, size int + }{ + {1, 0, 0}, + {1, 2, 1}, + {2, 2, 2}, + {3, 2, 2}, + } + for i, tt := range tests { + c := NewCache(tt.capacity) + var responses []*Response + for i := 0; i < tt.addCount; i++ { + ip := net.ParseIP(fmt.Sprintf("192.0.2.%d", i)) + r := &Response{IP: ip} + responses = append(responses, r) + c.Set(ip, r) + } + if got := len(c.entries); got != tt.size { + t.Errorf("#%d: len(entries) = %d, want %d", i, got, tt.size) + } + if tt.capacity > 0 && tt.addCount > tt.capacity && tt.capacity == tt.size { + lastAdded := responses[tt.addCount-1] + if _, ok := c.Get(lastAdded.IP); !ok { + t.Errorf("#%d: Get(%s) = (_, %t), want (_, %t)", i, lastAdded.IP.String(), ok, !ok) + } + firstAdded := responses[0] + if _, ok := c.Get(firstAdded.IP); ok { + t.Errorf("#%d: Get(%s) = (_, %t), want (_, %t)", i, firstAdded.IP.String(), ok, !ok) + } + } + } +} diff --git a/http/http.go b/http/http.go index 6df6563..14d7d12 100644 --- a/http/http.go +++ b/http/http.go @@ -27,6 +27,7 @@ type Server struct { IPHeaders []string LookupAddr func(net.IP) (string, error) LookupPort func(net.IP, uint64) error + cache *Cache gr geo.Reader } @@ -51,8 +52,8 @@ type PortResponse struct { Reachable bool `json:"reachable"` } -func New(db geo.Reader) *Server { - return &Server{gr: db} +func New(db geo.Reader, cache *Cache) *Server { + return &Server{cache: cache, gr: db} } func ipFromForwardedForHeader(v string) string { @@ -93,6 +94,10 @@ func (s *Server) newResponse(r *http.Request) (Response, error) { if err != nil { return Response{}, err } + response, ok := s.cache.Get(ip) + if ok { + return *response, nil + } ipDecimal := iputil.ToDecimal(ip) country, _ := s.gr.Country(ip) city, _ := s.gr.City(ip) @@ -111,7 +116,7 @@ func (s *Server) newResponse(r *http.Request) (Response, error) { parsed := useragent.Parse(userAgentRaw) userAgent = &parsed } - return Response{ + response = &Response{ IP: ip, IPDecimal: ipDecimal, Country: country.Name, @@ -124,7 +129,9 @@ func (s *Server) newResponse(r *http.Request) (Response, error) { ASN: autonomousSystemNumber, ASNOrg: asn.AutonomousSystemOrganization, UserAgent: userAgent, - }, nil + } + s.cache.Set(ip, response) + return *response, nil } func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { diff --git a/http/http_test.go b/http/http_test.go index c766bf5..60bc048 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -31,7 +31,7 @@ func (t *testDb) ASN(net.IP) (geo.ASN, error) { func (t *testDb) IsEmpty() bool { return false } func testServer() *Server { - return &Server{gr: &testDb{}, LookupAddr: lookupAddr, LookupPort: lookupPort} + return &Server{cache: NewCache(100), gr: &testDb{}, LookupAddr: lookupAddr, LookupPort: lookupPort} } func httpGet(url string, acceptMediaType string, userAgent string) (string, int, error) {