diff --git a/cmd/ipd/main.go b/cmd/ipd/main.go index 453fdb1..3bf150a 100644 --- a/cmd/ipd/main.go +++ b/cmd/ipd/main.go @@ -34,31 +34,26 @@ func main() { } log.Level = level - db := database.Empty() - if opts.CountryDBPath != "" || opts.CityDBPath != "" { - db, err = database.New(opts.CountryDBPath, opts.CityDBPath) - if err != nil { - log.Fatal(err) - } + db, err := database.New(opts.CountryDBPath, opts.CityDBPath) + if err != nil { + log.Fatal(err) } - var lookupAddr http.LookupAddr - var lookupPort http.LookupPort + + server := http.New(db, log) + server.Template = opts.Template + server.IPHeader = opts.IPHeader if opts.ReverseLookup { log.Println("Enabling reverse lookup") - lookupAddr = iputil.LookupAddr + server.LookupAddr = iputil.LookupAddr } if opts.PortLookup { log.Println("Enabling port lookup") - lookupPort = iputil.LookupPort + server.LookupPort = iputil.LookupPort } if opts.IPHeader != "" { log.Printf("Trusting header %s to contain correct remote IP", opts.IPHeader) } - server := http.New(db, lookupAddr, lookupPort, log) - server.Template = opts.Template - server.IPHeader = opts.IPHeader - log.Printf("Listening on http://%s", opts.Listen) if err := server.ListenAndServe(opts.Listen); err != nil { log.Fatal(err) diff --git a/http/http.go b/http/http.go index dffa989..1f127f5 100644 --- a/http/http.go +++ b/http/http.go @@ -23,14 +23,11 @@ const ( textMediaType = "text/plain" ) -type LookupAddr func(net.IP) ([]string, error) -type LookupPort func(net.IP, uint64) error - type Server struct { Template string IPHeader string - lookupAddr LookupAddr - lookupPort LookupPort + LookupAddr func(net.IP) ([]string, error) + LookupPort func(net.IP, uint64) error db database.Client log *logrus.Logger } @@ -50,8 +47,8 @@ type PortResponse struct { Reachable bool `json:"reachable"` } -func New(db database.Client, lookupAddr LookupAddr, lookupPort LookupPort, log *logrus.Logger) *Server { - return &Server{lookupAddr: lookupAddr, lookupPort: lookupPort, db: db, log: log} +func New(db database.Client, log *logrus.Logger) *Server { + return &Server{db: db, log: log} } func ipFromRequest(header string, r *http.Request) (net.IP, error) { @@ -85,8 +82,8 @@ func (s *Server) newResponse(r *http.Request) (Response, error) { s.log.Debug(err) } var hostnames []string - if s.lookupAddr != nil { - h, err := s.lookupAddr(ip) + if s.LookupAddr != nil { + h, err := s.LookupAddr(ip) if err != nil { s.log.Debug(err) } @@ -115,7 +112,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { if err != nil { return PortResponse{Port: port}, err } - err = s.lookupPort(ip, port) + err = s.LookupPort(ip, port) return PortResponse{ IP: ip, Port: port, @@ -210,7 +207,7 @@ func (s *Server) DefaultHandler(w http.ResponseWriter, r *http.Request) *appErro response, r.Host, string(json), - s.lookupPort != nil, + s.LookupPort != nil, response.Country != "" && response.City != "", } if err := t.Execute(w, &data); err != nil { @@ -281,7 +278,7 @@ func (s *Server) Handler() http.Handler { r.Handle("/", appHandler(s.DefaultHandler)).Methods("GET") // Port testing - if s.lookupPort != nil { + if s.LookupPort != nil { r.Handle("/port/{port:[0-9]+}", appHandler(s.PortHandler)).Methods("GET") } diff --git a/http/http_test.go b/http/http_test.go index de1d407..52b92c9 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -24,7 +24,7 @@ func (t *testDb) City(net.IP) (string, error) { return "Bornyasherk", nil } func (t *testDb) IsEmpty() bool { return false } func testServer() *Server { - return &Server{db: &testDb{}, lookupAddr: lookupAddr, lookupPort: lookupPort} + return &Server{db: &testDb{}, LookupAddr: lookupAddr, LookupPort: lookupPort} } func httpGet(url string, acceptMediaType string, userAgent string) (string, int, error) { @@ -85,9 +85,9 @@ func TestCLIHandlers(t *testing.T) { func TestDisabledHandlers(t *testing.T) { log.SetOutput(ioutil.Discard) server := testServer() - server.lookupPort = nil - server.lookupAddr = nil - server.db = database.Empty() + server.LookupPort = nil + server.LookupAddr = nil + server.db, _ = database.New("", "") s := httptest.NewServer(server.Handler()) var tests = []struct { diff --git a/iputil/database/database.go b/iputil/database/database.go index 9aa427e..c431312 100644 --- a/iputil/database/database.go +++ b/iputil/database/database.go @@ -22,14 +22,6 @@ type geoip struct { city *geoip2.Reader } -type empty struct{} - -func (d *empty) Country(ip net.IP) (Country, error) { return Country{}, nil } -func (d *empty) City(ip net.IP) (string, error) { return "", nil } -func (d *empty) IsEmpty() bool { return true } - -func Empty() Client { return &empty{} } - func New(countryDB, cityDB string) (Client, error) { var country, city *geoip2.Reader if countryDB != "" {