diff --git a/api/api.go b/api/api.go index 3970581..870b783 100644 --- a/api/api.go +++ b/api/api.go @@ -12,10 +12,8 @@ import ( "regexp" "strconv" "strings" - "time" "github.com/gorilla/mux" - geoip2 "github.com/oschwald/geoip2-golang" ) const APPLICATION_JSON = "application/json" @@ -27,13 +25,8 @@ var USER_AGENT_RE = regexp.MustCompile( type API struct { CORS bool Template string - lookupAddr func(string) ([]string, error) - lookupCountry func(net.IP) (string, error) - testPort func(net.IP, uint64) error + oracle Oracle ipFromRequest func(*http.Request) (net.IP, error) - reverseLookup bool - countryLookup bool - portTesting bool } type Response struct { @@ -48,37 +41,13 @@ type TestPortResponse struct { Reachable bool `json:"reachable"` } -func New() *API { +func New(oracle Oracle) *API { return &API{ - lookupAddr: func(addr string) (names []string, err error) { return nil, nil }, - lookupCountry: func(ip net.IP) (string, error) { return "", nil }, - testPort: func(ip net.IP, port uint64) error { return nil }, + oracle: oracle, ipFromRequest: ipFromRequest, } } -func (a *API) EnableCountryLookup(filepath string) error { - db, err := geoip2.Open(filepath) - if err != nil { - return err - } - a.lookupCountry = func(ip net.IP) (string, error) { - return lookupCountry(db, ip) - } - a.countryLookup = true - return nil -} - -func (a *API) EnableReverseLookup() { - a.lookupAddr = net.LookupAddr - a.reverseLookup = true -} - -func (a *API) EnablePortTesting() { - a.testPort = testPort - a.portTesting = true -} - func ipFromRequest(r *http.Request) (net.IP, error) { remoteIP := r.Header.Get("X-Real-IP") if remoteIP == "" { @@ -95,43 +64,16 @@ func ipFromRequest(r *http.Request) (net.IP, error) { return ip, nil } -func testPort(ip net.IP, port uint64) error { - address := fmt.Sprintf("%s:%d", ip, port) - conn, err := net.DialTimeout("tcp", address, 2*time.Second) - if err != nil { - return err - } - defer conn.Close() - return nil -} - -func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { - if db == nil { - return "", nil - } - record, err := db.Country(ip) - if err != nil { - return "", err - } - if country, exists := record.Country.Names["en"]; exists { - return country, nil - } - if country, exists := record.RegisteredCountry.Names["en"]; exists { - return country, nil - } - return "Unknown", fmt.Errorf("could not determine country for IP: %s", ip) -} - func (a *API) newResponse(r *http.Request) (Response, error) { ip, err := a.ipFromRequest(r) if err != nil { return Response{}, err } - country, err := a.lookupCountry(ip) + country, err := a.oracle.LookupCountry(ip) if err != nil { log.Print(err) } - hostnames, err := a.lookupAddr(ip.String()) + hostnames, err := a.oracle.LookupAddr(ip.String()) if err != nil { log.Print(err) } @@ -187,7 +129,7 @@ func (a *API) PortHandler(w http.ResponseWriter, r *http.Request) *appError { if err != nil { return internalServerError(err).AsJSON() } - err = a.testPort(ip, port) + err = a.oracle.LookupPort(ip, port) response := TestPortResponse{ IP: ip, Port: port, @@ -213,10 +155,8 @@ func (a *API) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { } var data = struct { Response - ReverseLookup bool - CountryLookup bool - PortTesting bool - }{response, a.reverseLookup, a.countryLookup, a.portTesting} + Oracle + }{response, a.oracle} if err := t.Execute(w, &data); err != nil { return internalServerError(err) } diff --git a/api/api_test.go b/api/api_test.go index ceef36d..6be3000 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -9,20 +9,21 @@ import ( "testing" ) +type mockOracle struct{} + +func (r *mockOracle) LookupAddr(string) ([]string, error) { return []string{"localhost"}, nil } +func (r *mockOracle) LookupCountry(net.IP) (string, error) { return "Elbonia", nil } +func (r *mockOracle) LookupPort(net.IP, uint64) error { return nil } +func (r *mockOracle) IsLookupAddrEnabled() bool { return true } +func (r *mockOracle) IsLookupCountryEnabled() bool { return true } +func (r *mockOracle) IsLookupPortEnabled() bool { return true } + func newTestAPI() *API { return &API{ - lookupAddr: func(string) ([]string, error) { - return []string{"localhost"}, nil - }, - lookupCountry: func(ip net.IP) (string, error) { - return "Elbonia", nil - }, + oracle: &mockOracle{}, ipFromRequest: func(*http.Request) (net.IP, error) { return net.ParseIP("127.0.0.1"), nil }, - testPort: func(net.IP, uint64) error { - return nil - }, } } diff --git a/api/oracle.go b/api/oracle.go new file mode 100644 index 0000000..b745049 --- /dev/null +++ b/api/oracle.go @@ -0,0 +1,97 @@ +package api + +import ( + "fmt" + "net" + "time" + + "github.com/oschwald/geoip2-golang" +) + +type Oracle interface { + LookupAddr(string) ([]string, error) + LookupCountry(net.IP) (string, error) + LookupPort(net.IP, uint64) error + IsLookupAddrEnabled() bool + IsLookupCountryEnabled() bool + IsLookupPortEnabled() bool +} + +type DefaultOracle struct { + lookupAddr func(string) ([]string, error) + lookupCountry func(net.IP) (string, error) + lookupPort func(net.IP, uint64) error + lookupAddrEnabled bool + lookupCountryEnabled bool + lookupPortEnabled bool +} + +func NewOracle() *DefaultOracle { + return &DefaultOracle{ + lookupAddr: func(string) ([]string, error) { return nil, nil }, + lookupCountry: func(net.IP) (string, error) { return "", nil }, + lookupPort: func(net.IP, uint64) error { return nil }, + } +} + +func (r *DefaultOracle) LookupAddr(address string) ([]string, error) { + return r.lookupAddr(address) +} + +func (r *DefaultOracle) LookupCountry(ip net.IP) (string, error) { + return r.lookupCountry(ip) +} + +func (r *DefaultOracle) LookupPort(ip net.IP, port uint64) error { + return r.lookupPort(ip, port) +} + +func (r *DefaultOracle) EnableLookupAddr() { + r.lookupAddr = net.LookupAddr + r.lookupAddrEnabled = true +} + +func (r *DefaultOracle) EnableLookupCountry(filepath string) error { + db, err := geoip2.Open(filepath) + if err != nil { + return err + } + r.lookupCountry = func(ip net.IP) (string, error) { + return lookupCountry(db, ip) + } + r.lookupCountryEnabled = true + return nil +} + +func (r *DefaultOracle) EnableLookupPort() { + r.lookupPort = lookupPort + r.lookupPortEnabled = true +} + +func (r *DefaultOracle) IsLookupAddrEnabled() bool { return r.lookupAddrEnabled } +func (r *DefaultOracle) IsLookupCountryEnabled() bool { return r.lookupCountryEnabled } +func (r *DefaultOracle) IsLookupPortEnabled() bool { return r.lookupPortEnabled } + +func lookupPort(ip net.IP, port uint64) error { + address := fmt.Sprintf("%s:%d", ip, port) + conn, err := net.DialTimeout("tcp", address, 2*time.Second) + if err != nil { + return err + } + defer conn.Close() + return nil +} + +func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { + record, err := db.Country(ip) + if err != nil { + return "", err + } + if country, exists := record.Country.Names["en"]; exists { + return country, nil + } + if country, exists := record.RegisteredCountry.Names["en"]; exists { + return country, nil + } + return "Unknown", fmt.Errorf("could not determine country for IP: %s", ip) +} diff --git a/index.html b/index.html index ffd7e67..c2be443 100644 --- a/index.html +++ b/index.html @@ -60,7 +60,7 @@ $ wget -qO- ifconfig.co $ fetch -qo- http://ifconfig.co {{ .IP }} -{{ if .CountryLookup }} +{{ if .IsLookupCountryEnabled }}
Country lookup:
$ http ifconfig.co/country @@ -77,8 +77,8 @@ Content-Length: 61 Content-Type: application/json Date: Fri, 15 Apr 2016 17:26:53 GMT -{ {{ if .CountryLookup }} - "country": "{{ .Country }}",{{ end }}{{ if .ReverseLookup }} +{ {{ if .IsLookupCountryEnabled }} + "country": "{{ .Country }}",{{ end }}{{ if .IsLookupAddrEnabled }} "hostname": "{{ .Hostname }}",{{ end }} "ip": "{{ .IP }}" } @@ -86,7 +86,7 @@ Date: Fri, 15 Apr 2016 17:26:53 GMT # or set Accept header to application/json: # http --json ifconfig.co-{{ if .PortTesting }} +{{ if .IsLookupPortEnabled }}
Testing port connectivity (only supports JSON output):
http --json localhost:8080/port/8080 diff --git a/main.go b/main.go index 53d8da7..46b6e01 100644 --- a/main.go +++ b/main.go @@ -15,7 +15,7 @@ func main() { Listen string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"` CORS bool `short:"x" long:"cors" description:"Allow requests from other domains"` ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"` - PortTesting bool `short:"p" long:"port-testing" description:"Enable port testing"` + PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"` Template string `short:"t" long:"template" description:"Path to template" default:"index.html"` } _, err := flags.ParseArgs(&opts, os.Args) @@ -23,22 +23,24 @@ func main() { os.Exit(1) } - api := api.New() - api.CORS = opts.CORS + oracle := api.NewOracle() if opts.ReverseLookup { log.Println("Enabling reverse lookup") - api.EnableReverseLookup() + oracle.EnableLookupAddr() } - if opts.PortTesting { - log.Println("Enabling port testing") - api.EnablePortTesting() + if opts.PortLookup { + log.Println("Enabling port lookup") + oracle.EnableLookupPort() } if opts.DBPath != "" { - log.Printf("Enabling country lookup (using database: %s)\n", opts.DBPath) - if err := api.EnableCountryLookup(opts.DBPath); err != nil { + log.Printf("Enabling country lookup (using database: %s)", opts.DBPath) + if err := oracle.EnableLookupCountry(opts.DBPath); err != nil { log.Fatal(err) } } + + api := api.New(oracle) + api.CORS = opts.CORS api.Template = opts.Template log.Printf("Listening on %s", opts.Listen)