diff --git a/cmd/ipd/main.go b/cmd/ipd/main.go index 3098d30..70d5c52 100644 --- a/cmd/ipd/main.go +++ b/cmd/ipd/main.go @@ -14,13 +14,13 @@ import ( func main() { var opts struct { - CountryDBPath string `short:"f" long:"country-db" description:"Path to GeoIP country database" value-name:"FILE" default:""` - CityDBPath string `short:"c" long:"city-db" description:"Path to GeoIP city database" value-name:"FILE" default:""` - Listen string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"` - ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"` - PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"` - Template string `short:"t" long:"template" description:"Path to template" default:"index.html" value-name:"FILE"` - IPHeader string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)" value-name:"NAME"` + CountryDBPath string `short:"f" long:"country-db" description:"Path to GeoIP country database" value-name:"FILE" default:""` + CityDBPath string `short:"c" long:"city-db" description:"Path to GeoIP city database" value-name:"FILE" default:""` + Listen string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"` + ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"` + PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"` + Template string `short:"t" long:"template" description:"Path to template" default:"index.html" value-name:"FILE"` + IPHeaders []string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)" value-name:"NAME"` } _, err := flags.ParseArgs(&opts, os.Args) if err != nil { @@ -35,7 +35,7 @@ func main() { server := http.New(db) server.Template = opts.Template - server.IPHeader = opts.IPHeader + server.IPHeaders = opts.IPHeaders if opts.ReverseLookup { log.Println("Enabling reverse lookup") server.LookupAddr = iputil.LookupAddr @@ -44,8 +44,8 @@ func main() { log.Println("Enabling port lookup") server.LookupPort = iputil.LookupPort } - if opts.IPHeader != "" { - log.Printf("Trusting header %s to contain correct remote IP", opts.IPHeader) + if len(opts.IPHeaders) > 0 { + log.Printf("Trusting header(s) %+v to contain correct remote IP", opts.IPHeaders) } log.Printf("Listening on http://%s", opts.Listen) diff --git a/http/http.go b/http/http.go index c04f514..cac1010 100644 --- a/http/http.go +++ b/http/http.go @@ -22,7 +22,7 @@ const ( type Server struct { Template string - IPHeader string + IPHeaders []string LookupAddr func(net.IP) (string, error) LookupPort func(net.IP, uint64) error db database.Client @@ -47,8 +47,14 @@ func New(db database.Client) *Server { return &Server{db: db} } -func ipFromRequest(header string, r *http.Request) (net.IP, error) { - remoteIP := r.Header.Get(header) +func ipFromRequest(headers []string, r *http.Request) (net.IP, error) { + remoteIP := "" + for _, header := range headers { + remoteIP = r.Header.Get(header) + if remoteIP != "" { + break + } + } if remoteIP == "" { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -64,7 +70,7 @@ func ipFromRequest(header string, r *http.Request) (net.IP, error) { } func (s *Server) newResponse(r *http.Request) (Response, error) { - ip, err := ipFromRequest(s.IPHeader, r) + ip, err := ipFromRequest(s.IPHeaders, r) if err != nil { return Response{}, err } @@ -91,7 +97,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { if err != nil || port < 1 || port > 65355 { return PortResponse{Port: port}, fmt.Errorf("invalid port: %d", port) } - ip, err := ipFromRequest(s.IPHeader, r) + ip, err := ipFromRequest(s.IPHeaders, r) if err != nil { return PortResponse{Port: port}, err } @@ -104,7 +110,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) { } func (s *Server) CLIHandler(w http.ResponseWriter, r *http.Request) *appError { - ip, err := ipFromRequest(s.IPHeader, r) + ip, err := ipFromRequest(s.IPHeaders, r) if err != nil { return internalServerError(err) } diff --git a/http/http_test.go b/http/http_test.go index 66a3027..206d9ff 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -149,16 +149,17 @@ func TestJSONHandlers(t *testing.T) { func TestIPFromRequest(t *testing.T) { var tests = []struct { - remoteAddr string - headerKey string - headerValue string - trustedHeader string - out string + remoteAddr string + headerKey string + headerValue string + trustedHeaders []string + out string }{ - {"127.0.0.1:9999", "", "", "", "127.0.0.1"}, // No header given - {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "", "127.0.0.1"}, // Trusted header is empty - {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Foo-Bar", "127.0.0.1"}, // Trusted header does not match - {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Real-IP", "1.3.3.7"}, // Trusted header matches + {"127.0.0.1:9999", "", "", nil, "127.0.0.1"}, // No header given + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", nil, "127.0.0.1"}, // Trusted header is empty + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Foo-Bar"}, "127.0.0.1"}, // Trusted header does not match + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Trusted header matches + {"127.0.0.1:9999", "X-Forwarded-For", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Second trusted header matches } for _, tt := range tests { r := &http.Request{ @@ -166,7 +167,7 @@ func TestIPFromRequest(t *testing.T) { Header: http.Header{}, } r.Header.Add(tt.headerKey, tt.headerValue) - ip, err := ipFromRequest(tt.trustedHeader, r) + ip, err := ipFromRequest(tt.trustedHeaders, r) if err != nil { t.Fatal(err) }