diff --git a/api/api.go b/api/api.go index d1df5f5..98f12b6 100644 --- a/api/api.go +++ b/api/api.go @@ -74,7 +74,7 @@ func (a *API) newResponse(r *http.Request) (Response, error) { if err != nil { log.Print(err) } - hostnames, err := a.oracle.LookupAddr(ip.String()) + hostnames, err := a.oracle.LookupAddr(ip) if err != nil { log.Print(err) } diff --git a/api/api_test.go b/api/api_test.go index 7687699..c34c5ae 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -11,7 +11,7 @@ import ( type mockOracle struct{} -func (r *mockOracle) LookupAddr(string) ([]string, error) { return []string{"localhost"}, nil } +func (r *mockOracle) LookupAddr(net.IP) ([]string, error) { return []string{"localhost"}, nil } func (r *mockOracle) LookupCountry(net.IP) (string, error) { return "Elbonia", nil } func (r *mockOracle) LookupCity(net.IP) (string, error) { return "Bornyasherk", nil } func (r *mockOracle) LookupPort(net.IP, uint64) error { return nil } diff --git a/api/oracle.go b/api/oracle.go index 77b999e..7a750bc 100644 --- a/api/oracle.go +++ b/api/oracle.go @@ -3,13 +3,14 @@ package api import ( "fmt" "net" + "strings" "time" "github.com/oschwald/geoip2-golang" ) type Oracle interface { - LookupAddr(string) ([]string, error) + LookupAddr(net.IP) ([]string, error) LookupCountry(net.IP) (string, error) LookupCity(net.IP) (string, error) LookupPort(net.IP, uint64) error @@ -20,7 +21,7 @@ type Oracle interface { } type DefaultOracle struct { - lookupAddr func(string) ([]string, error) + lookupAddr func(net.IP) ([]string, error) lookupCountry func(net.IP) (string, error) lookupCity func(net.IP) (string, error) lookupPort func(net.IP, uint64) error @@ -32,15 +33,15 @@ type DefaultOracle struct { func NewOracle() *DefaultOracle { return &DefaultOracle{ - lookupAddr: func(string) ([]string, error) { return nil, nil }, + lookupAddr: func(net.IP) ([]string, error) { return nil, nil }, lookupCountry: func(net.IP) (string, error) { return "", nil }, lookupCity: func(net.IP) (string, error) { return "", nil }, lookupPort: func(net.IP, uint64) error { return nil }, } } -func (r *DefaultOracle) LookupAddr(address string) ([]string, error) { - return r.lookupAddr(address) +func (r *DefaultOracle) LookupAddr(ip net.IP) ([]string, error) { + return r.lookupAddr(ip) } func (r *DefaultOracle) LookupCountry(ip net.IP) (string, error) { @@ -56,7 +57,7 @@ func (r *DefaultOracle) LookupPort(ip net.IP, port uint64) error { } func (r *DefaultOracle) EnableLookupAddr() { - r.lookupAddr = net.LookupAddr + r.lookupAddr = lookupAddr r.lookupAddrEnabled = true } @@ -94,6 +95,14 @@ func (r *DefaultOracle) IsLookupCountryEnabled() bool { return r.lookupCountryEn func (r *DefaultOracle) IsLookupCityEnabled() bool { return r.lookupCityEnabled } func (r *DefaultOracle) IsLookupPortEnabled() bool { return r.lookupPortEnabled } +func lookupAddr(ip net.IP) ([]string, error) { + names, err := net.LookupAddr(ip.String()) + for i, _ := range names { + names[i] = strings.TrimRight(names[i], ".") // Always return unrooted name + } + return names, err +} + func lookupPort(ip net.IP, port uint64) error { address := fmt.Sprintf("[%s]:%d", ip, port) conn, err := net.DialTimeout("tcp", address, 2*time.Second)