From 3134de826096836208b3401e38c599f6992a8656 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sun, 17 Apr 2016 15:52:06 +0200 Subject: [PATCH] Do not trust X-Real-IP header by default Use -H option to whitelist header to trust for remote IP address. This is useful when a reverse proxy is used in front of ipd. --- README.md | 1 + api/api.go | 13 +++++++------ api/api_test.go | 27 +++++++++++++++++++-------- main.go | 5 +++++ 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 955e45e..5958027 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ Application Options: -r, --reverse-lookup Perform reverse hostname lookups -p, --port-lookup Enable port lookup -t, --template= Path to template (default: index.html) + -H, --trusted-header= Header to trust for remote IP, if present (e.g. X-Real-IP) Help Options: -h, --help Show this help message diff --git a/api/api.go b/api/api.go index e822dfc..bf4ab5d 100644 --- a/api/api.go +++ b/api/api.go @@ -24,8 +24,9 @@ var USER_AGENT_RE = regexp.MustCompile( type API struct { Template string + IPHeader string oracle Oracle - ipFromRequest func(*http.Request) (net.IP, error) + ipFromRequest func(string, *http.Request) (net.IP, error) } type Response struct { @@ -48,8 +49,8 @@ func New(oracle Oracle) *API { } } -func ipFromRequest(r *http.Request) (net.IP, error) { - remoteIP := r.Header.Get("X-Real-IP") +func ipFromRequest(header string, r *http.Request) (net.IP, error) { + remoteIP := r.Header.Get(header) if remoteIP == "" { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -65,7 +66,7 @@ func ipFromRequest(r *http.Request) (net.IP, error) { } func (a *API) newResponse(r *http.Request) (Response, error) { - ip, err := a.ipFromRequest(r) + ip, err := a.ipFromRequest(a.IPHeader, r) if err != nil { return Response{}, err } @@ -90,7 +91,7 @@ func (a *API) newResponse(r *http.Request) (Response, error) { } func (a *API) CLIHandler(w http.ResponseWriter, r *http.Request) *appError { - ip, err := a.ipFromRequest(r) + ip, err := a.ipFromRequest(a.IPHeader, r) if err != nil { return internalServerError(err) } @@ -139,7 +140,7 @@ func (a *API) PortHandler(w http.ResponseWriter, r *http.Request) *appError { if port < 1 || port > 65355 { return badRequest(nil).WithMessage("Invalid port: " + vars["port"]).AsJSON() } - ip, err := a.ipFromRequest(r) + ip, err := a.ipFromRequest(a.IPHeader, r) if err != nil { return internalServerError(err).AsJSON() } diff --git a/api/api_test.go b/api/api_test.go index 245869a..25df69e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -23,7 +23,7 @@ func (r *mockOracle) IsLookupPortEnabled() bool { return true } func newTestAPI() *API { return &API{ oracle: &mockOracle{}, - ipFromRequest: func(*http.Request) (net.IP, error) { + ipFromRequest: func(string, *http.Request) (net.IP, error) { return net.ParseIP("127.0.0.1"), nil }, } @@ -110,19 +110,30 @@ func TestJSONHandlers(t *testing.T) { func TestIPFromRequest(t *testing.T) { var tests = []struct { - in *http.Request - out net.IP + remoteAddr string + headerKey string + headerValue string + trustedHeader string + out string }{ - {&http.Request{RemoteAddr: "1.3.3.7:9999"}, net.ParseIP("1.3.3.7")}, - {&http.Request{Header: http.Header{"X-Real-Ip": []string{"1.3.3.7"}}}, net.ParseIP("1.3.3.7")}, + {"127.0.0.1:9999", "", "", "", "127.0.0.1"}, // No header given + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "", "127.0.0.1"}, // Trusted header is empty + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Foo-Bar", "127.0.0.1"}, // Trusted header does not match + {"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Real-IP", "1.3.3.7"}, // Trusted header matches } for _, tt := range tests { - ip, err := ipFromRequest(tt.in) + r := &http.Request{ + RemoteAddr: tt.remoteAddr, + Header: http.Header{}, + } + r.Header.Add(tt.headerKey, tt.headerValue) + ip, err := ipFromRequest(tt.trustedHeader, r) if err != nil { t.Fatal(err) } - if !ip.Equal(tt.out) { - t.Errorf("Expected %s, got %s", tt.out, ip) + out := net.ParseIP(tt.out) + if !ip.Equal(out) { + t.Errorf("Expected %s, got %s", out, ip) } } } diff --git a/main.go b/main.go index 1e541a5..ed8e762 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ func main() { ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"` PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"` Template string `short:"t" long:"template" description:"Path to template" default:"index.html"` + IPHeader string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)"` } _, err := flags.ParseArgs(&opts, os.Args) if err != nil { @@ -44,9 +45,13 @@ func main() { log.Fatal(err) } } + if opts.IPHeader != "" { + log.Printf("Trusting header %s to contain correct remote IP", opts.IPHeader) + } api := api.New(oracle) api.Template = opts.Template + api.IPHeader = opts.IPHeader log.Printf("Listening on %s", opts.Listen) if err := api.ListenAndServe(opts.Listen); err != nil {