Do not trust X-Real-IP header by default

Use -H option to whitelist header to trust for remote IP address. This
is useful when a reverse proxy is used in front of ipd.
This commit is contained in:
Martin Polden 2016-04-17 15:52:06 +02:00
parent 270ffec441
commit 3134de8260
4 changed files with 32 additions and 14 deletions

View File

@ -88,6 +88,7 @@ Application Options:
-r, --reverse-lookup Perform reverse hostname lookups
-p, --port-lookup Enable port lookup
-t, --template= Path to template (default: index.html)
-H, --trusted-header= Header to trust for remote IP, if present (e.g. X-Real-IP)
Help Options:
-h, --help Show this help message

View File

@ -24,8 +24,9 @@ var USER_AGENT_RE = regexp.MustCompile(
type API struct {
Template string
IPHeader string
oracle Oracle
ipFromRequest func(*http.Request) (net.IP, error)
ipFromRequest func(string, *http.Request) (net.IP, error)
}
type Response struct {
@ -48,8 +49,8 @@ func New(oracle Oracle) *API {
}
}
func ipFromRequest(r *http.Request) (net.IP, error) {
remoteIP := r.Header.Get("X-Real-IP")
func ipFromRequest(header string, r *http.Request) (net.IP, error) {
remoteIP := r.Header.Get(header)
if remoteIP == "" {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
@ -65,7 +66,7 @@ func ipFromRequest(r *http.Request) (net.IP, error) {
}
func (a *API) newResponse(r *http.Request) (Response, error) {
ip, err := a.ipFromRequest(r)
ip, err := a.ipFromRequest(a.IPHeader, r)
if err != nil {
return Response{}, err
}
@ -90,7 +91,7 @@ func (a *API) newResponse(r *http.Request) (Response, error) {
}
func (a *API) CLIHandler(w http.ResponseWriter, r *http.Request) *appError {
ip, err := a.ipFromRequest(r)
ip, err := a.ipFromRequest(a.IPHeader, r)
if err != nil {
return internalServerError(err)
}
@ -139,7 +140,7 @@ func (a *API) PortHandler(w http.ResponseWriter, r *http.Request) *appError {
if port < 1 || port > 65355 {
return badRequest(nil).WithMessage("Invalid port: " + vars["port"]).AsJSON()
}
ip, err := a.ipFromRequest(r)
ip, err := a.ipFromRequest(a.IPHeader, r)
if err != nil {
return internalServerError(err).AsJSON()
}

View File

@ -23,7 +23,7 @@ func (r *mockOracle) IsLookupPortEnabled() bool { return true }
func newTestAPI() *API {
return &API{
oracle: &mockOracle{},
ipFromRequest: func(*http.Request) (net.IP, error) {
ipFromRequest: func(string, *http.Request) (net.IP, error) {
return net.ParseIP("127.0.0.1"), nil
},
}
@ -110,19 +110,30 @@ func TestJSONHandlers(t *testing.T) {
func TestIPFromRequest(t *testing.T) {
var tests = []struct {
in *http.Request
out net.IP
remoteAddr string
headerKey string
headerValue string
trustedHeader string
out string
}{
{&http.Request{RemoteAddr: "1.3.3.7:9999"}, net.ParseIP("1.3.3.7")},
{&http.Request{Header: http.Header{"X-Real-Ip": []string{"1.3.3.7"}}}, net.ParseIP("1.3.3.7")},
{"127.0.0.1:9999", "", "", "", "127.0.0.1"}, // No header given
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "", "127.0.0.1"}, // Trusted header is empty
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Foo-Bar", "127.0.0.1"}, // Trusted header does not match
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Real-IP", "1.3.3.7"}, // Trusted header matches
}
for _, tt := range tests {
ip, err := ipFromRequest(tt.in)
r := &http.Request{
RemoteAddr: tt.remoteAddr,
Header: http.Header{},
}
r.Header.Add(tt.headerKey, tt.headerValue)
ip, err := ipFromRequest(tt.trustedHeader, r)
if err != nil {
t.Fatal(err)
}
if !ip.Equal(tt.out) {
t.Errorf("Expected %s, got %s", tt.out, ip)
out := net.ParseIP(tt.out)
if !ip.Equal(out) {
t.Errorf("Expected %s, got %s", out, ip)
}
}
}

View File

@ -17,6 +17,7 @@ func main() {
ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"`
PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"`
Template string `short:"t" long:"template" description:"Path to template" default:"index.html"`
IPHeader string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)"`
}
_, err := flags.ParseArgs(&opts, os.Args)
if err != nil {
@ -44,9 +45,13 @@ func main() {
log.Fatal(err)
}
}
if opts.IPHeader != "" {
log.Printf("Trusting header %s to contain correct remote IP", opts.IPHeader)
}
api := api.New(oracle)
api.Template = opts.Template
api.IPHeader = opts.IPHeader
log.Printf("Listening on %s", opts.Listen)
if err := api.ListenAndServe(opts.Listen); err != nil {