diff --git a/api/api.go b/api/api.go index ed8f43e..8d9101c 100644 --- a/api/api.go +++ b/api/api.go @@ -29,20 +29,32 @@ const ( var cliUserAgentExp = regexp.MustCompile("^(?i)(curl|wget|fetch\\slibfetch)\\/.*$") type API struct { - db *geoip2.Reader CORS bool ReverseLookup bool Template string + lookupAddr func(string) ([]string, error) + lookupCountry func(net.IP) (string, error) + ipFromRequest func(*http.Request) (net.IP, error) } -func New() *API { return &API{} } +func New() *API { + return &API{ + lookupAddr: net.LookupAddr, + lookupCountry: func(ip net.IP) (string, error) { return "", nil }, + ipFromRequest: ipFromRequest, + } +} func NewWithGeoIP(filepath string) (*API, error) { db, err := geoip2.Open(filepath) if err != nil { return nil, err } - return &API{db: db}, nil + api := New() + api.lookupCountry = func(ip net.IP) (string, error) { + return lookupCountry(db, ip) + } + return api, nil } type Cmd struct { @@ -100,11 +112,11 @@ func headerPairFromRequest(r *http.Request) (string, string, error) { return header, value, nil } -func (a *API) lookupCountry(ip net.IP) (string, error) { - if a.db == nil { +func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { + if db == nil { return "", nil } - record, err := a.db.Country(ip) + record, err := db.Country(ip) if err != nil { return "", err } @@ -187,25 +199,24 @@ func cliMatcher(r *http.Request, rm *mux.RouteMatch) bool { func (a *API) requestFilter(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip, err := ipFromRequest(r) + ip, err := a.ipFromRequest(r) if err != nil { - r.Header.Set(IP_HEADER, err.Error()) + log.Print(err) + r.Header.Set(IP_HEADER, "") } else { r.Header.Set(IP_HEADER, ip.String()) country, err := a.lookupCountry(ip) if err != nil { - r.Header.Set(COUNTRY_HEADER, err.Error()) - } else { - r.Header.Set(COUNTRY_HEADER, country) + log.Print(err) } + r.Header.Set(COUNTRY_HEADER, country) } if a.ReverseLookup { - hostname, err := net.LookupAddr(ip.String()) + hostname, err := a.lookupAddr(ip.String()) if err != nil { - r.Header.Set(HOSTNAME_HEADER, err.Error()) - } else { - r.Header.Set(HOSTNAME_HEADER, strings.Join(hostname, ", ")) + log.Print(err) } + r.Header.Set(HOSTNAME_HEADER, strings.Join(hostname, ", ")) } if a.CORS { w.Header().Set("Access-Control-Allow-Methods", "GET") diff --git a/api/api_test.go b/api/api_test.go index d77b62e..89ecf0d 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -9,9 +9,25 @@ import ( "net/http/httptest" "net/url" "reflect" + "strings" "testing" ) +func newTestAPI() *API { + return &API{ + lookupAddr: func(string) ([]string, error) { + return []string{"localhost"}, nil + }, + lookupCountry: func(ip net.IP) (string, error) { + return "Elbonia", nil + }, + ipFromRequest: func(*http.Request) (net.IP, error) { + return net.ParseIP("127.0.0.1"), nil + }, + ReverseLookup: true, + } +} + func httpGet(url string, json bool, userAgent string) (string, int, error) { r, err := http.NewRequest("GET", url, nil) if err != nil { @@ -38,9 +54,10 @@ func TestGetIP(t *testing.T) { toJSON := func(k string, v string) string { return fmt.Sprintf("{\n \"%s\": \"%s\"\n}", k, v) } - s := httptest.NewServer(New().Handlers()) + s := httptest.NewServer(newTestAPI().Handlers()) jsonAll := "{\n \"Accept-Encoding\": [\n \"gzip\"\n ]," + - "\n \"X-Ifconfig-Country\": [\n \"\"\n ]," + + "\n \"X-Ifconfig-Country\": [\n \"Elbonia\"\n ]," + + "\n \"X-Ifconfig-Hostname\": [\n \"localhost\"\n ]," + "\n \"X-Ifconfig-Ip\": [\n \"127.0.0.1\"\n ]\n}" var tests = []struct { @@ -74,6 +91,21 @@ func TestGetIP(t *testing.T) { } } +func TestGetIPWithoutReverse(t *testing.T) { + log.SetOutput(ioutil.Discard) + api := newTestAPI() + api.ReverseLookup = false + s := httptest.NewServer(api.Handlers()) + + out, _, err := httpGet(s.URL, false, "curl/7.26.0") + if err != nil { + t.Fatal(err) + } + if header := "X-Ifconfig-Hostname"; strings.Contains(out, header) { + t.Errorf("Expected response to not contain %q", header) + } +} + func TestIPFromRequest(t *testing.T) { var tests = []struct { in *http.Request