diff --git a/api/api.go b/api/api.go index 3970581..870b783 100644 --- a/api/api.go +++ b/api/api.go @@ -12,10 +12,8 @@ import ( "regexp" "strconv" "strings" - "time" "github.com/gorilla/mux" - geoip2 "github.com/oschwald/geoip2-golang" ) const APPLICATION_JSON = "application/json" @@ -27,13 +25,8 @@ var USER_AGENT_RE = regexp.MustCompile( type API struct { CORS bool Template string - lookupAddr func(string) ([]string, error) - lookupCountry func(net.IP) (string, error) - testPort func(net.IP, uint64) error + oracle Oracle ipFromRequest func(*http.Request) (net.IP, error) - reverseLookup bool - countryLookup bool - portTesting bool } type Response struct { @@ -48,37 +41,13 @@ type TestPortResponse struct { Reachable bool `json:"reachable"` } -func New() *API { +func New(oracle Oracle) *API { return &API{ - lookupAddr: func(addr string) (names []string, err error) { return nil, nil }, - lookupCountry: func(ip net.IP) (string, error) { return "", nil }, - testPort: func(ip net.IP, port uint64) error { return nil }, + oracle: oracle, ipFromRequest: ipFromRequest, } } -func (a *API) EnableCountryLookup(filepath string) error { - db, err := geoip2.Open(filepath) - if err != nil { - return err - } - a.lookupCountry = func(ip net.IP) (string, error) { - return lookupCountry(db, ip) - } - a.countryLookup = true - return nil -} - -func (a *API) EnableReverseLookup() { - a.lookupAddr = net.LookupAddr - a.reverseLookup = true -} - -func (a *API) EnablePortTesting() { - a.testPort = testPort - a.portTesting = true -} - func ipFromRequest(r *http.Request) (net.IP, error) { remoteIP := r.Header.Get("X-Real-IP") if remoteIP == "" { @@ -95,43 +64,16 @@ func ipFromRequest(r *http.Request) (net.IP, error) { return ip, nil } -func testPort(ip net.IP, port uint64) error { - address := fmt.Sprintf("%s:%d", ip, port) - conn, err := net.DialTimeout("tcp", address, 2*time.Second) - if err != nil { - return err - } - defer conn.Close() - return nil -} - -func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { - if db == nil { - return "", nil - } - record, err := db.Country(ip) - if err != nil { - return "", err - } - if country, exists := record.Country.Names["en"]; exists { - return country, nil - } - if country, exists := record.RegisteredCountry.Names["en"]; exists { - return country, nil - } - return "Unknown", fmt.Errorf("could not determine country for IP: %s", ip) -} - func (a *API) newResponse(r *http.Request) (Response, error) { ip, err := a.ipFromRequest(r) if err != nil { return Response{}, err } - country, err := a.lookupCountry(ip) + country, err := a.oracle.LookupCountry(ip) if err != nil { log.Print(err) } - hostnames, err := a.lookupAddr(ip.String()) + hostnames, err := a.oracle.LookupAddr(ip.String()) if err != nil { log.Print(err) } @@ -187,7 +129,7 @@ func (a *API) PortHandler(w http.ResponseWriter, r *http.Request) *appError { if err != nil { return internalServerError(err).AsJSON() } - err = a.testPort(ip, port) + err = a.oracle.LookupPort(ip, port) response := TestPortResponse{ IP: ip, Port: port, @@ -213,10 +155,8 @@ func (a *API) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { } var data = struct { Response - ReverseLookup bool - CountryLookup bool - PortTesting bool - }{response, a.reverseLookup, a.countryLookup, a.portTesting} + Oracle + }{response, a.oracle} if err := t.Execute(w, &data); err != nil { return internalServerError(err) } diff --git a/api/api_test.go b/api/api_test.go index ceef36d..6be3000 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -9,20 +9,21 @@ import ( "testing" ) +type mockOracle struct{} + +func (r *mockOracle) LookupAddr(string) ([]string, error) { return []string{"localhost"}, nil } +func (r *mockOracle) LookupCountry(net.IP) (string, error) { return "Elbonia", nil } +func (r *mockOracle) LookupPort(net.IP, uint64) error { return nil } +func (r *mockOracle) IsLookupAddrEnabled() bool { return true } +func (r *mockOracle) IsLookupCountryEnabled() bool { return true } +func (r *mockOracle) IsLookupPortEnabled() bool { return true } + func newTestAPI() *API { return &API{ - lookupAddr: func(string) ([]string, error) { - return []string{"localhost"}, nil - }, - lookupCountry: func(ip net.IP) (string, error) { - return "Elbonia", nil - }, + oracle: &mockOracle{}, ipFromRequest: func(*http.Request) (net.IP, error) { return net.ParseIP("127.0.0.1"), nil }, - testPort: func(net.IP, uint64) error { - return nil - }, } } diff --git a/api/oracle.go b/api/oracle.go new file mode 100644 index 0000000..b745049 --- /dev/null +++ b/api/oracle.go @@ -0,0 +1,97 @@ +package api + +import ( + "fmt" + "net" + "time" + + "github.com/oschwald/geoip2-golang" +) + +type Oracle interface { + LookupAddr(string) ([]string, error) + LookupCountry(net.IP) (string, error) + LookupPort(net.IP, uint64) error + IsLookupAddrEnabled() bool + IsLookupCountryEnabled() bool + IsLookupPortEnabled() bool +} + +type DefaultOracle struct { + lookupAddr func(string) ([]string, error) + lookupCountry func(net.IP) (string, error) + lookupPort func(net.IP, uint64) error + lookupAddrEnabled bool + lookupCountryEnabled bool + lookupPortEnabled bool +} + +func NewOracle() *DefaultOracle { + return &DefaultOracle{ + lookupAddr: func(string) ([]string, error) { return nil, nil }, + lookupCountry: func(net.IP) (string, error) { return "", nil }, + lookupPort: func(net.IP, uint64) error { return nil }, + } +} + +func (r *DefaultOracle) LookupAddr(address string) ([]string, error) { + return r.lookupAddr(address) +} + +func (r *DefaultOracle) LookupCountry(ip net.IP) (string, error) { + return r.lookupCountry(ip) +} + +func (r *DefaultOracle) LookupPort(ip net.IP, port uint64) error { + return r.lookupPort(ip, port) +} + +func (r *DefaultOracle) EnableLookupAddr() { + r.lookupAddr = net.LookupAddr + r.lookupAddrEnabled = true +} + +func (r *DefaultOracle) EnableLookupCountry(filepath string) error { + db, err := geoip2.Open(filepath) + if err != nil { + return err + } + r.lookupCountry = func(ip net.IP) (string, error) { + return lookupCountry(db, ip) + } + r.lookupCountryEnabled = true + return nil +} + +func (r *DefaultOracle) EnableLookupPort() { + r.lookupPort = lookupPort + r.lookupPortEnabled = true +} + +func (r *DefaultOracle) IsLookupAddrEnabled() bool { return r.lookupAddrEnabled } +func (r *DefaultOracle) IsLookupCountryEnabled() bool { return r.lookupCountryEnabled } +func (r *DefaultOracle) IsLookupPortEnabled() bool { return r.lookupPortEnabled } + +func lookupPort(ip net.IP, port uint64) error { + address := fmt.Sprintf("%s:%d", ip, port) + conn, err := net.DialTimeout("tcp", address, 2*time.Second) + if err != nil { + return err + } + defer conn.Close() + return nil +} + +func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { + record, err := db.Country(ip) + if err != nil { + return "", err + } + if country, exists := record.Country.Names["en"]; exists { + return country, nil + } + if country, exists := record.RegisteredCountry.Names["en"]; exists { + return country, nil + } + return "Unknown", fmt.Errorf("could not determine country for IP: %s", ip) +} diff --git a/index.html b/index.html index ffd7e67..c2be443 100644 --- a/index.html +++ b/index.html @@ -60,7 +60,7 @@ $ wget -qO- ifconfig.co $ fetch -qo- http://ifconfig.co {{ .IP }} -{{ if .CountryLookup }} +{{ if .IsLookupCountryEnabled }}

Country lookup:

 $ http ifconfig.co/country
@@ -77,8 +77,8 @@ Content-Length: 61
 Content-Type: application/json
 Date: Fri, 15 Apr 2016 17:26:53 GMT
 
-{ {{ if .CountryLookup }}
-    "country": "{{ .Country }}",{{ end }}{{ if .ReverseLookup }}
+{ {{ if .IsLookupCountryEnabled }}
+    "country": "{{ .Country }}",{{ end }}{{ if .IsLookupAddrEnabled }}
     "hostname": "{{ .Hostname }}",{{ end }}
     "ip": "{{ .IP }}"
 }
@@ -86,7 +86,7 @@ Date: Fri, 15 Apr 2016 17:26:53 GMT
 # or set Accept header to application/json:
 # http --json ifconfig.co
         
-{{ if .PortTesting }} +{{ if .IsLookupPortEnabled }}

Testing port connectivity (only supports JSON output):

 http --json localhost:8080/port/8080
diff --git a/main.go b/main.go
index 53d8da7..46b6e01 100644
--- a/main.go
+++ b/main.go
@@ -15,7 +15,7 @@ func main() {
 		Listen        string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"`
 		CORS          bool   `short:"x" long:"cors" description:"Allow requests from other domains"`
 		ReverseLookup bool   `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"`
-		PortTesting   bool   `short:"p" long:"port-testing" description:"Enable port testing"`
+		PortLookup    bool   `short:"p" long:"port-lookup" description:"Enable port lookup"`
 		Template      string `short:"t" long:"template" description:"Path to template" default:"index.html"`
 	}
 	_, err := flags.ParseArgs(&opts, os.Args)
@@ -23,22 +23,24 @@ func main() {
 		os.Exit(1)
 	}
 
-	api := api.New()
-	api.CORS = opts.CORS
+	oracle := api.NewOracle()
 	if opts.ReverseLookup {
 		log.Println("Enabling reverse lookup")
-		api.EnableReverseLookup()
+		oracle.EnableLookupAddr()
 	}
-	if opts.PortTesting {
-		log.Println("Enabling port testing")
-		api.EnablePortTesting()
+	if opts.PortLookup {
+		log.Println("Enabling port lookup")
+		oracle.EnableLookupPort()
 	}
 	if opts.DBPath != "" {
-		log.Printf("Enabling country lookup (using database: %s)\n", opts.DBPath)
-		if err := api.EnableCountryLookup(opts.DBPath); err != nil {
+		log.Printf("Enabling country lookup (using database: %s)", opts.DBPath)
+		if err := oracle.EnableLookupCountry(opts.DBPath); err != nil {
 			log.Fatal(err)
 		}
 	}
+
+	api := api.New(oracle)
+	api.CORS = opts.CORS
 	api.Template = opts.Template
 
 	log.Printf("Listening on %s", opts.Listen)