From 1efde17791bfc12c317678cd296e9fa837aacee2 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sun, 18 Mar 2018 22:15:51 +0100 Subject: [PATCH] Replace gorilla/mux with own router --- http/http.go | 38 ++++++++++------------ http/http_test.go | 8 ++--- http/router.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 25 deletions(-) create mode 100644 http/router.go diff --git a/http/http.go b/http/http.go index 078dc70..9e1b8ac 100644 --- a/http/http.go +++ b/http/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "html/template" + "path/filepath" "github.com/mpolden/ipd/iputil" "github.com/mpolden/ipd/iputil/database" @@ -14,8 +15,6 @@ import ( "net/http" "strconv" "strings" - - "github.com/gorilla/mux" ) const ( @@ -100,8 +99,8 @@ func (s *Server) newResponse(r *http.Request) (Response, error) { } func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { - vars := mux.Vars(r) - port, err := strconv.ParseUint(vars["port"], 10, 16) + lastElement := filepath.Base(r.URL.Path) + port, err := strconv.ParseUint(lastElement, 10, 16) if err != nil { return PortResponse{Port: port}, err } @@ -214,7 +213,7 @@ func (s *Server) DefaultHandler(w http.ResponseWriter, r *http.Request) *appErro return nil } -func (s *Server) NotFoundHandler(w http.ResponseWriter, r *http.Request) *appError { +func NotFoundHandler(w http.ResponseWriter, r *http.Request) *appError { err := notFound(nil).WithMessage("404 page not found") if r.Header.Get("accept") == jsonMediaType { err = err.AsJSON() @@ -222,7 +221,7 @@ func (s *Server) NotFoundHandler(w http.ResponseWriter, r *http.Request) *appErr return err } -func cliMatcher(r *http.Request, rm *mux.RouteMatch) bool { +func cliMatcher(r *http.Request) bool { ua := useragent.Parse(r.UserAgent()) switch ua.Product { case "curl", "HTTPie", "Wget", "fetch libfetch", "Go", "Go-http-client", "ddclient": @@ -256,34 +255,31 @@ func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (s *Server) Handler() http.Handler { - r := mux.NewRouter() + r := NewRouter() // JSON - r.Handle("/", appHandler(s.JSONHandler)).Methods("GET").Headers("Accept", jsonMediaType) - r.Handle("/json", appHandler(s.JSONHandler)).Methods("GET") + r.Route("GET", "/", s.JSONHandler).Header("Accept", jsonMediaType) + r.Route("GET", "/json", s.JSONHandler) // CLI - r.Handle("/", appHandler(s.CLIHandler)).Methods("GET").MatcherFunc(cliMatcher) - r.Handle("/", appHandler(s.CLIHandler)).Methods("GET").Headers("Accept", textMediaType) - r.Handle("/ip", appHandler(s.CLIHandler)).Methods("GET") + r.Route("GET", "/", s.CLIHandler).MatcherFunc(cliMatcher) + r.Route("GET", "/", s.CLIHandler).Header("Accept", textMediaType) + r.Route("GET", "/ip", s.CLIHandler) if !s.db.IsEmpty() { - r.Handle("/country", appHandler(s.CLICountryHandler)).Methods("GET") - r.Handle("/country-iso", appHandler(s.CLICountryISOHandler)).Methods("GET") - r.Handle("/city", appHandler(s.CLICityHandler)).Methods("GET") + r.Route("GET", "/country", s.CLICountryHandler) + r.Route("GET", "/country-iso", s.CLICountryISOHandler) + r.Route("GET", "/city", s.CLICityHandler) } // Browser - r.Handle("/", appHandler(s.DefaultHandler)).Methods("GET") + r.Route("GET", "/", s.DefaultHandler) // Port testing if s.LookupPort != nil { - r.Handle("/port/{port:[0-9]+}", appHandler(s.PortHandler)).Methods("GET") + r.RoutePrefix("GET", "/port/", s.PortHandler) } - // Not found handler which returns JSON when appropriate - r.NotFoundHandler = appHandler(s.NotFoundHandler) - - return r + return r.Handler() } func (s *Server) ListenAndServe(addr string) error { diff --git a/http/http_test.go b/http/http_test.go index 52b92c9..b4a9e05 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -126,7 +126,7 @@ func TestJSONHandlers(t *testing.T) { status int }{ {s.URL, `{"ip":"127.0.0.1","ip_decimal":2130706433,"country":"Elbonia","country_iso":"EB","city":"Bornyasherk","hostname":"localhost"}`, 200}, - {s.URL + "/port/foo", `{"error":"404 page not found"}`, 404}, + {s.URL + "/port/foo", `{"error":"Invalid port: 0"}`, 400}, {s.URL + "/port/0", `{"error":"Invalid port: 0"}`, 400}, {s.URL + "/port/65356", `{"error":"Invalid port: 65356"}`, 400}, {s.URL + "/port/31337", `{"ip":"127.0.0.1","port":31337,"reachable":true}`, 200}, @@ -139,10 +139,10 @@ func TestJSONHandlers(t *testing.T) { t.Fatal(err) } if status != tt.status { - t.Errorf("Expected %d, got %d", tt.status, status) + t.Errorf("Expected %d for %s, got %d", tt.status, tt.url, status) } if out != tt.out { - t.Errorf("Expected %q, got %q", tt.out, out) + t.Errorf("Expected %q for %s, got %q", tt.out, tt.url, out) } } } @@ -198,7 +198,7 @@ func TestCLIMatcher(t *testing.T) { } for _, tt := range tests { r := &http.Request{Header: http.Header{"User-Agent": []string{tt.in}}} - if got := cliMatcher(r, nil); got != tt.out { + if got := cliMatcher(r); got != tt.out { t.Errorf("Expected %t, got %t for %q", tt.out, got, tt.in) } } diff --git a/http/router.go b/http/router.go new file mode 100644 index 0000000..786363d --- /dev/null +++ b/http/router.go @@ -0,0 +1,80 @@ +package http + +import ( + "net/http" + "strings" +) + +type router struct { + routes []*route +} + +type route struct { + method string + path string + prefix bool + matcherFuncs []func(*http.Request) bool + handler appHandler +} + +func NewRouter() *router { + return &router{} +} + +func (r *router) Route(method, path string, handler appHandler) *route { + route := route{ + method: method, + path: path, + handler: handler, + } + r.routes = append(r.routes, &route) + return &route +} + +func (r *router) RoutePrefix(method, path string, handler appHandler) *route { + route := r.Route(method, path, handler) + route.prefix = true + return route +} + +func (r *router) Handler() http.Handler { + return appHandler(func(w http.ResponseWriter, req *http.Request) *appError { + for _, route := range r.routes { + if route.match(req) { + return route.handler(w, req) + } + } + return NotFoundHandler(w, req) + }) +} + +func (r *route) Header(header, value string) *route { + return r.MatcherFunc(func(req *http.Request) bool { + return req.Header.Get(header) == value + }) +} + +func (r *route) MatcherFunc(f func(*http.Request) bool) *route { + r.matcherFuncs = append(r.matcherFuncs, f) + return r +} + +func (r *route) match(req *http.Request) bool { + if req.Method != r.method { + return false + } + if r.prefix { + if !strings.HasPrefix(req.URL.Path, r.path) { + return false + } + } else if r.path != req.URL.Path { + return false + } + match := len(r.matcherFuncs) == 0 + for _, f := range r.matcherFuncs { + if match = f(req); match { + break + } + } + return match +}