package http import ( "io/ioutil" "log" "net" "net/http" "net/http/httptest" "testing" "github.com/mpolden/echoip/iputil/geo" ) func lookupAddr(net.IP) (string, error) { return "localhost", nil } func lookupPort(net.IP, uint64) error { return nil } type testDb struct{} func (t *testDb) Country(net.IP) (geo.Country, error) { return geo.Country{Name: "Elbonia", ISO: "EB", IsEU: new(bool)}, nil } func (t *testDb) City(net.IP) (geo.City, error) { return geo.City{Name: "Bornyasherk", Latitude: 63.416667, Longitude: 10.416667}, nil } func (t *testDb) ASN(net.IP) (geo.ASN, error) { return geo.ASN{AutonomousSystemNumber: 59795, AutonomousSystemOrganization: "Hosting4Real"}, nil } func (t *testDb) IsEmpty() bool { return false } func testServer() *Server { return &Server{gr: &testDb{}, LookupAddr: lookupAddr, LookupPort: lookupPort} } func httpGet(url string, acceptMediaType string, userAgent string) (string, int, error) { r, err := http.NewRequest("GET", url, nil) if err != nil { return "", 0, err } if acceptMediaType != "" { r.Header.Set("Accept", acceptMediaType) } r.Header.Set("User-Agent", userAgent) res, err := http.DefaultClient.Do(r) if err != nil { return "", 0, err } defer res.Body.Close() data, err := ioutil.ReadAll(res.Body) if err != nil { return "", 0, err } return string(data), res.StatusCode, nil } func TestCLIHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) s := httptest.NewServer(testServer().Handler()) var tests = []struct { url string out string status int userAgent string acceptMediaType string }{ {s.URL, "127.0.0.1\n", 200, "curl/7.43.0", ""}, {s.URL, "127.0.0.1\n", 200, "foo/bar", textMediaType}, {s.URL + "/ip", "127.0.0.1\n", 200, "", ""}, {s.URL + "/country", "Elbonia\n", 200, "", ""}, {s.URL + "/country-iso", "EB\n", 200, "", ""}, {s.URL + "/coordinates", "63.416667,10.416667\n", 200, "", ""}, {s.URL + "/city", "Bornyasherk\n", 200, "", ""}, {s.URL + "/foo", "404 page not found", 404, "", ""}, {s.URL + "/asn", "AS59795\n", 200, "", ""}, } for _, tt := range tests { out, status, err := httpGet(tt.url, tt.acceptMediaType, tt.userAgent) if err != nil { t.Fatal(err) } if status != tt.status { t.Errorf("Expected %d, got %d", tt.status, status) } if out != tt.out { t.Errorf("Expected %q, got %q", tt.out, out) } } } func TestDisabledHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) server := testServer() server.LookupPort = nil server.LookupAddr = nil server.gr, _ = geo.Open("", "", "") s := httptest.NewServer(server.Handler()) var tests = []struct { url string out string status int }{ {s.URL + "/port/1337", "404 page not found", 404}, {s.URL + "/country", "404 page not found", 404}, {s.URL + "/country-iso", "404 page not found", 404}, {s.URL + "/city", "404 page not found", 404}, {s.URL + "/json", `{"ip":"127.0.0.1","ip_decimal":2130706433}`, 200}, } for _, tt := range tests { out, status, err := httpGet(tt.url, "", "") if err != nil { t.Fatal(err) } if status != tt.status { t.Errorf("Expected %d, got %d", tt.status, status) } if out != tt.out { t.Errorf("Expected %q, got %q", tt.out, out) } } } func TestJSONHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) s := httptest.NewServer(testServer().Handler()) var tests = []struct { url string out string status int }{ {s.URL, `{"ip":"127.0.0.1","ip_decimal":2130706433,"country":"Elbonia","country_eu":false,"country_iso":"EB","city":"Bornyasherk","hostname":"localhost","latitude":63.416667,"longitude":10.416667,"asn":"AS59795","asn_org":"Hosting4Real","user_agent":{"product":"curl","version":"7.2.6.0","raw_value":"curl/7.2.6.0"}}`, 200}, {s.URL + "/port/foo", `{"error":"invalid port: foo"}`, 400}, {s.URL + "/port/0", `{"error":"invalid port: 0"}`, 400}, {s.URL + "/port/65537", `{"error":"invalid port: 65537"}`, 400}, {s.URL + "/port/31337", `{"ip":"127.0.0.1","port":31337,"reachable":true}`, 200}, {s.URL + "/foo", `{"error":"404 page not found"}`, 404}, {s.URL + "/health", `{"status":"OK"}`, 200}, } for _, tt := range tests { out, status, err := httpGet(tt.url, jsonMediaType, "curl/7.2.6.0") if err != nil { t.Fatal(err) } if status != tt.status { t.Errorf("Expected %d for %s, got %d", tt.status, tt.url, status) } if out != tt.out { t.Errorf("Expected %q for %s, got %q", tt.out, tt.url, out) } } } func TestIPFromRequest(t *testing.T) { var tests = []struct { remoteAddr string headerKey string headerValue string trustedHeaders []string out string }{ {"127.0.0.1:9999", "", "", nil, "127.0.0.1"}, // No header given {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", nil, "127.0.0.1"}, // Trusted header is empty {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Foo-Bar"}, "127.0.0.1"}, // Trusted header does not match {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Trusted header matches {"127.0.0.1:9999", "X-Forwarded-For", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Second trusted header matches {"127.0.0.1:9999", "X-Forwarded-For", "1.3.3.7,4.2.4.2", []string{"X-Forwarded-For"}, "1.3.3.7"}, // X-Forwarded-For with multiple entries (commas separator) {"127.0.0.1:9999", "X-Forwarded-For", "1.3.3.7, 4.2.4.2", []string{"X-Forwarded-For"}, "1.3.3.7"}, // X-Forwarded-For with multiple entries (space+comma separator) {"127.0.0.1:9999", "X-Forwarded-For", "", []string{"X-Forwarded-For"}, "127.0.0.1"}, // Empty header } for _, tt := range tests { r := &http.Request{ RemoteAddr: tt.remoteAddr, Header: http.Header{}, } r.Header.Add(tt.headerKey, tt.headerValue) ip, err := ipFromRequest(tt.trustedHeaders, r) if err != nil { t.Fatal(err) } out := net.ParseIP(tt.out) if !ip.Equal(out) { t.Errorf("Expected %s, got %s", out, ip) } } } func TestCLIMatcher(t *testing.T) { browserUserAgent := "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_4) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/30.0.1599.28 " + "Safari/537.36" var tests = []struct { in string out bool }{ {"curl/7.26.0", true}, {"Wget/1.13.4 (linux-gnu)", true}, {"Wget", true}, {"fetch libfetch/2.0", true}, {"HTTPie/0.9.3", true}, {"Go 1.1 package http", true}, {"Go-http-client/1.1", true}, {"Go-http-client/2.0", true}, {"ddclient/3.8.3", true}, {"Mikrotik/6.x Fetch", true}, {browserUserAgent, false}, } for _, tt := range tests { r := &http.Request{Header: http.Header{"User-Agent": []string{tt.in}}} if got := cliMatcher(r); got != tt.out { t.Errorf("Expected %t, got %t for %q", tt.out, got, tt.in) } } }