diff --git a/api/api.go b/api/api.go index 9efee2b..3c6d7cb 100644 --- a/api/api.go +++ b/api/api.go @@ -10,7 +10,9 @@ import ( "net/http" "path/filepath" "regexp" + "strconv" "strings" + "time" "github.com/gorilla/mux" geoip2 "github.com/oschwald/geoip2-golang" @@ -25,6 +27,7 @@ type API struct { Template string lookupAddr func(string) ([]string, error) lookupCountry func(net.IP) (string, error) + testPort func(net.IP, uint64) error ipFromRequest func(*http.Request) (net.IP, error) } @@ -34,10 +37,17 @@ type Response struct { Hostname string `json:"hostname,omitempty"` } +type TestPortResponse struct { + IP net.IP `json:"ip"` + Port uint64 `json:"port"` + Reachable bool `json:"reachable"` +} + func New() *API { return &API{ lookupAddr: func(addr string) (names []string, err error) { return nil, nil }, lookupCountry: func(ip net.IP) (string, error) { return "", nil }, + testPort: func(ip net.IP, port uint64) error { return nil }, ipFromRequest: ipFromRequest, } } @@ -57,6 +67,10 @@ func (a *API) EnableReverseLookup() { a.lookupAddr = net.LookupAddr } +func (a *API) EnablePortTesting() { + a.testPort = testPort +} + func ipFromRequest(r *http.Request) (net.IP, error) { var host string realIP := r.Header.Get("X-Real-IP") @@ -76,6 +90,14 @@ func ipFromRequest(r *http.Request) (net.IP, error) { return ip, nil } +func testPort(ip net.IP, port uint64) error { + address := fmt.Sprintf("%s:%d", ip, port) + if _, err := net.DialTimeout("tcp", address, 2*time.Second); err != nil { + return err + } + return nil +} + func lookupCountry(db *geoip2.Reader, ip net.IP) (string, error) { if db == nil { return "", nil @@ -140,6 +162,34 @@ func (a *API) JSONHandler(w http.ResponseWriter, r *http.Request) *appError { return nil } +func (a *API) TestPortHandler(w http.ResponseWriter, r *http.Request) *appError { + vars := mux.Vars(r) + port, err := strconv.ParseUint(vars["port"], 10, 16) + if err != nil { + return badRequest(err).WithMessage("Invalid port: " + vars["port"]).AsJSON() + } + if port < 1 || port > 65355 { + return badRequest(nil).WithMessage("Invalid port: " + vars["port"]).AsJSON() + } + ip, err := a.ipFromRequest(r) + if err != nil { + return internalServerError(err).AsJSON() + } + err = testPort(ip, port) + response := TestPortResponse{ + IP: ip, + Port: port, + Reachable: err == nil, + } + b, err := json.Marshal(response) + if err != nil { + return internalServerError(err).AsJSON() + } + w.Header().Set("Content-Type", APPLICATION_JSON) + w.Write(b) + return nil +} + func (a *API) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { response, err := a.newResponse(r) if err != nil { @@ -218,6 +268,9 @@ func (a *API) Handlers() http.Handler { // Browser r.Handle("/", appHandler(a.DefaultHandler)).Methods("GET") + // Port testing + r.Handle("/port/{port:[0-9]+}", appHandler(a.TestPortHandler)).Methods("GET") + // Not found handler which returns JSON when appropriate r.NotFoundHandler = appHandler(a.NotFoundHandler) diff --git a/api/api_test.go b/api/api_test.go index c0b877e..68554dc 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -22,6 +22,9 @@ func newTestAPI() *API { ipFromRequest: func(*http.Request) (net.IP, error) { return net.ParseIP("127.0.0.1"), nil }, + testPort: func(net.IP, uint64) error { + return nil + }, } } diff --git a/api/error.go b/api/error.go index dfbd9bb..8d878cd 100644 --- a/api/error.go +++ b/api/error.go @@ -21,6 +21,10 @@ func notFound(err error) *appError { return &appError{Error: err, Code: http.StatusNotFound} } +func badRequest(err error) *appError { + return &appError{Error: err, Code: http.StatusBadRequest} +} + func (e *appError) AsJSON() *appError { e.ContentType = APPLICATION_JSON return e diff --git a/index.html b/index.html index 290b6f0..fcaf9dc 100644 --- a/index.html +++ b/index.html @@ -78,6 +78,20 @@ Date: Fri, 15 Apr 2016 17:26:53 GMT "ip": "{{ .IP }}" } +

Testing port connectivity (only supports JSON output):

+
+http --json localhost:8080/port/8080
+HTTP/1.1 200 OK
+Content-Length: 47
+Content-Type: application/json
+Date: Fri, 15 Apr 2016 18:47:20 GMT
+
+{
+    "ip": "127.0.0.1",
+    "port": 8080,
+    "reachable": true
+}
+        
diff --git a/main.go b/main.go index 6499b7a..53d8da7 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ func main() { Listen string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"` CORS bool `short:"x" long:"cors" description:"Allow requests from other domains"` ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"` + PortTesting bool `short:"p" long:"port-testing" description:"Enable port testing"` Template string `short:"t" long:"template" description:"Path to template" default:"index.html"` } _, err := flags.ParseArgs(&opts, os.Args) @@ -28,6 +29,10 @@ func main() { log.Println("Enabling reverse lookup") api.EnableReverseLookup() } + if opts.PortTesting { + log.Println("Enabling port testing") + api.EnablePortTesting() + } if opts.DBPath != "" { log.Printf("Enabling country lookup (using database: %s)\n", opts.DBPath) if err := api.EnableCountryLookup(opts.DBPath); err != nil {