From 35061bfe83e5f71f7e4f9e74f8a23e1589d89dee Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sat, 10 Feb 2018 13:24:32 +0100 Subject: [PATCH] Restructure --- cmd/ipd/main.go | 15 ++--- {api => http}/error.go | 2 +- api/api.go => http/http.go | 96 +++++++++++++++------------- api/api_test.go => http/http_test.go | 10 +-- {api => http}/oracle.go | 2 +- 5 files changed, 63 insertions(+), 62 deletions(-) rename {api => http}/error.go (98%) rename api/api.go => http/http.go (65%) rename api/api_test.go => http/http_test.go (96%) rename {api => http}/oracle.go (99%) diff --git a/cmd/ipd/main.go b/cmd/ipd/main.go index d896ce7..71e8bf5 100644 --- a/cmd/ipd/main.go +++ b/cmd/ipd/main.go @@ -1,15 +1,12 @@ package main import ( - "net/http" - flags "github.com/jessevdk/go-flags" "os" + "github.com/mpolden/ipd/http" "github.com/sirupsen/logrus" - - "github.com/mpolden/ipd/api" ) func main() { @@ -35,7 +32,7 @@ func main() { } log.Level = level - oracle := api.NewOracle() + oracle := http.NewOracle() if opts.ReverseLookup { log.Println("Enabling reverse lookup") oracle.EnableLookupAddr() @@ -60,12 +57,12 @@ func main() { log.Printf("Trusting header %s to contain correct remote IP", opts.IPHeader) } - api := api.New(oracle, log) - api.Template = opts.Template - api.IPHeader = opts.IPHeader + server := http.New(oracle, log) + server.Template = opts.Template + server.IPHeader = opts.IPHeader log.Printf("Listening on http://%s", opts.Listen) - if err := http.ListenAndServe(opts.Listen, api.Router()); err != nil { + if err := server.ListenAndServe(opts.Listen); err != nil { log.Fatal(err) } } diff --git a/api/error.go b/http/error.go similarity index 98% rename from api/error.go rename to http/error.go index a6ad60e..72c6fce 100644 --- a/api/error.go +++ b/http/error.go @@ -1,4 +1,4 @@ -package api +package http import "net/http" diff --git a/api/api.go b/http/http.go similarity index 65% rename from api/api.go rename to http/http.go index fbddf3f..41f8a65 100644 --- a/api/api.go +++ b/http/http.go @@ -1,4 +1,4 @@ -package api +package http import ( "encoding/json" @@ -23,7 +23,7 @@ const ( textMediaType = "text/plain" ) -type API struct { +type Server struct { Template string IPHeader string oracle Oracle @@ -45,8 +45,8 @@ type PortResponse struct { Reachable bool `json:"reachable"` } -func New(oracle Oracle, logger *logrus.Logger) *API { - return &API{oracle: oracle, log: logger} +func New(oracle Oracle, logger *logrus.Logger) *Server { + return &Server{oracle: oracle, log: logger} } func ipToDecimal(ip net.IP) *big.Int { @@ -75,27 +75,27 @@ func ipFromRequest(header string, r *http.Request) (net.IP, error) { return ip, nil } -func (a *API) newResponse(r *http.Request) (Response, error) { - ip, err := ipFromRequest(a.IPHeader, r) +func (s *Server) newResponse(r *http.Request) (Response, error) { + ip, err := ipFromRequest(s.IPHeader, r) if err != nil { return Response{}, err } ipDecimal := ipToDecimal(ip) - country, err := a.oracle.LookupCountry(ip) + country, err := s.oracle.LookupCountry(ip) if err != nil { - a.log.Debug(err) + s.log.Debug(err) } - countryISO, err := a.oracle.LookupCountryISO(ip) + countryISO, err := s.oracle.LookupCountryISO(ip) if err != nil { - a.log.Debug(err) + s.log.Debug(err) } - city, err := a.oracle.LookupCity(ip) + city, err := s.oracle.LookupCity(ip) if err != nil { - a.log.Debug(err) + s.log.Debug(err) } - hostnames, err := a.oracle.LookupAddr(ip) + hostnames, err := s.oracle.LookupAddr(ip) if err != nil { - a.log.Debug(err) + s.log.Debug(err) } return Response{ IP: ip, @@ -107,7 +107,7 @@ func (a *API) newResponse(r *http.Request) (Response, error) { }, nil } -func (a *API) newPortResponse(r *http.Request) (PortResponse, error) { +func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { vars := mux.Vars(r) port, err := strconv.ParseUint(vars["port"], 10, 16) if err != nil { @@ -116,11 +116,11 @@ func (a *API) newPortResponse(r *http.Request) (PortResponse, error) { if port < 1 || port > 65355 { return PortResponse{Port: port}, fmt.Errorf("invalid port: %d", port) } - ip, err := ipFromRequest(a.IPHeader, r) + ip, err := ipFromRequest(s.IPHeader, r) if err != nil { return PortResponse{Port: port}, err } - err = a.oracle.LookupPort(ip, port) + err = s.oracle.LookupPort(ip, port) return PortResponse{ IP: ip, Port: port, @@ -128,8 +128,8 @@ func (a *API) newPortResponse(r *http.Request) (PortResponse, error) { }, nil } -func (a *API) CLIHandler(w http.ResponseWriter, r *http.Request) *appError { - ip, err := ipFromRequest(a.IPHeader, r) +func (s *Server) CLIHandler(w http.ResponseWriter, r *http.Request) *appError { + ip, err := ipFromRequest(s.IPHeader, r) if err != nil { return internalServerError(err) } @@ -137,8 +137,8 @@ func (a *API) CLIHandler(w http.ResponseWriter, r *http.Request) *appError { return nil } -func (a *API) CLICountryHandler(w http.ResponseWriter, r *http.Request) *appError { - response, err := a.newResponse(r) +func (s *Server) CLICountryHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) if err != nil { return internalServerError(err) } @@ -146,8 +146,8 @@ func (a *API) CLICountryHandler(w http.ResponseWriter, r *http.Request) *appErro return nil } -func (a *API) CLICountryISOHandler(w http.ResponseWriter, r *http.Request) *appError { - response, err := a.newResponse(r) +func (s *Server) CLICountryISOHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) if err != nil { return internalServerError(err) } @@ -155,8 +155,8 @@ func (a *API) CLICountryISOHandler(w http.ResponseWriter, r *http.Request) *appE return nil } -func (a *API) CLICityHandler(w http.ResponseWriter, r *http.Request) *appError { - response, err := a.newResponse(r) +func (s *Server) CLICityHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) if err != nil { return internalServerError(err) } @@ -164,8 +164,8 @@ func (a *API) CLICityHandler(w http.ResponseWriter, r *http.Request) *appError { return nil } -func (a *API) JSONHandler(w http.ResponseWriter, r *http.Request) *appError { - response, err := a.newResponse(r) +func (s *Server) JSONHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) if err != nil { return internalServerError(err).AsJSON() } @@ -178,8 +178,8 @@ func (a *API) JSONHandler(w http.ResponseWriter, r *http.Request) *appError { return nil } -func (a *API) PortHandler(w http.ResponseWriter, r *http.Request) *appError { - response, err := a.newPortResponse(r) +func (s *Server) PortHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newPortResponse(r) if err != nil { return badRequest(err).WithMessage(fmt.Sprintf("Invalid port: %d", response.Port)).AsJSON() } @@ -192,12 +192,12 @@ func (a *API) PortHandler(w http.ResponseWriter, r *http.Request) *appError { return nil } -func (a *API) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { - response, err := a.newResponse(r) +func (s *Server) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { + response, err := s.newResponse(r) if err != nil { return internalServerError(err) } - t, err := template.New(filepath.Base(a.Template)).ParseFiles(a.Template) + t, err := template.New(filepath.Base(s.Template)).ParseFiles(s.Template) if err != nil { return internalServerError(err) } @@ -205,14 +205,14 @@ func (a *API) DefaultHandler(w http.ResponseWriter, r *http.Request) *appError { Host string Response Oracle - }{r.Host, response, a.oracle} + }{r.Host, response, s.oracle} if err := t.Execute(w, &data); err != nil { return internalServerError(err) } return nil } -func (a *API) NotFoundHandler(w http.ResponseWriter, r *http.Request) *appError { +func (s *Server) NotFoundHandler(w http.ResponseWriter, r *http.Request) *appError { err := notFound(nil).WithMessage("404 page not found") if r.Header.Get("accept") == jsonMediaType { err = err.AsJSON() @@ -253,29 +253,33 @@ func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (a *API) Router() http.Handler { +func (s *Server) Handler() http.Handler { r := mux.NewRouter() // JSON - r.Handle("/", appHandler(a.JSONHandler)).Methods("GET").Headers("Accept", jsonMediaType) - r.Handle("/json", appHandler(a.JSONHandler)).Methods("GET") + r.Handle("/", appHandler(s.JSONHandler)).Methods("GET").Headers("Accept", jsonMediaType) + r.Handle("/json", appHandler(s.JSONHandler)).Methods("GET") // CLI - r.Handle("/", appHandler(a.CLIHandler)).Methods("GET").MatcherFunc(cliMatcher) - r.Handle("/", appHandler(a.CLIHandler)).Methods("GET").Headers("Accept", textMediaType) - r.Handle("/ip", appHandler(a.CLIHandler)).Methods("GET") - r.Handle("/country", appHandler(a.CLICountryHandler)).Methods("GET") - r.Handle("/country-iso", appHandler(a.CLICountryISOHandler)).Methods("GET") - r.Handle("/city", appHandler(a.CLICityHandler)).Methods("GET") + r.Handle("/", appHandler(s.CLIHandler)).Methods("GET").MatcherFunc(cliMatcher) + r.Handle("/", appHandler(s.CLIHandler)).Methods("GET").Headers("Accept", textMediaType) + r.Handle("/ip", appHandler(s.CLIHandler)).Methods("GET") + r.Handle("/country", appHandler(s.CLICountryHandler)).Methods("GET") + r.Handle("/country-iso", appHandler(s.CLICountryISOHandler)).Methods("GET") + r.Handle("/city", appHandler(s.CLICityHandler)).Methods("GET") // Browser - r.Handle("/", appHandler(a.DefaultHandler)).Methods("GET") + r.Handle("/", appHandler(s.DefaultHandler)).Methods("GET") // Port testing - r.Handle("/port/{port:[0-9]+}", appHandler(a.PortHandler)).Methods("GET") + r.Handle("/port/{port:[0-9]+}", appHandler(s.PortHandler)).Methods("GET") // Not found handler which returns JSON when appropriate - r.NotFoundHandler = appHandler(a.NotFoundHandler) + r.NotFoundHandler = appHandler(s.NotFoundHandler) return r } + +func (s *Server) ListenAndServe(addr string) error { + return http.ListenAndServe(addr, s.Handler()) +} diff --git a/api/api_test.go b/http/http_test.go similarity index 96% rename from api/api_test.go rename to http/http_test.go index a4b0c61..f4ce1a3 100644 --- a/api/api_test.go +++ b/http/http_test.go @@ -1,4 +1,4 @@ -package api +package http import ( "io/ioutil" @@ -22,8 +22,8 @@ func (r *mockOracle) IsLookupCountryEnabled() bool { return true } func (r *mockOracle) IsLookupCityEnabled() bool { return true } func (r *mockOracle) IsLookupPortEnabled() bool { return true } -func newTestAPI() *API { - return &API{oracle: &mockOracle{}} +func newTestAPI() *Server { + return &Server{oracle: &mockOracle{}} } func httpGet(url string, acceptMediaType string, userAgent string) (string, int, error) { @@ -49,7 +49,7 @@ func httpGet(url string, acceptMediaType string, userAgent string) (string, int, func TestCLIHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) - s := httptest.NewServer(newTestAPI().Router()) + s := httptest.NewServer(newTestAPI().Handler()) var tests = []struct { url string @@ -83,7 +83,7 @@ func TestCLIHandlers(t *testing.T) { func TestJSONHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) - s := httptest.NewServer(newTestAPI().Router()) + s := httptest.NewServer(newTestAPI().Handler()) var tests = []struct { url string diff --git a/api/oracle.go b/http/oracle.go similarity index 99% rename from api/oracle.go rename to http/oracle.go index 1030241..1654758 100644 --- a/api/oracle.go +++ b/http/oracle.go @@ -1,4 +1,4 @@ -package api +package http import ( "fmt"