diff --git a/middleware/http_tracing.go b/middleware/http_tracing.go index 54c99844..a1ac40f5 100644 --- a/middleware/http_tracing.go +++ b/middleware/http_tracing.go @@ -17,20 +17,28 @@ var _ = nethttp.MWURLTagFunc // Tracer is a middleware which traces incoming requests. type Tracer struct { RouteMatcher RouteMatcher + SourceIPs *SourceIPExtractor } // Wrap implements Interface func (t Tracer) Wrap(next http.Handler) http.Handler { - opMatcher := nethttp.OperationNameFunc(func(r *http.Request) string { - op := getRouteName(t.RouteMatcher, r) - if op == "" { - return "HTTP " + r.Method - } - - return fmt.Sprintf("HTTP %s - %s", r.Method, op) - }) + options := []nethttp.MWOption{ + nethttp.OperationNameFunc(func(r *http.Request) string { + op := getRouteName(t.RouteMatcher, r) + if op == "" { + return "HTTP " + r.Method + } + + return fmt.Sprintf("HTTP %s - %s", r.Method, op) + }), + } + if t.SourceIPs != nil { + options = append(options, nethttp.MWSpanObserver(func(sp opentracing.Span, r *http.Request) { + sp.SetTag("sourceIPs", t.SourceIPs.Get(r)) + })) + } - return nethttp.Middleware(opentracing.GlobalTracer(), next, opMatcher) + return nethttp.Middleware(opentracing.GlobalTracer(), next, options...) } // ExtractTraceID extracts the trace id, if any from the context. diff --git a/middleware/logging.go b/middleware/logging.go index 79368f9d..00270df9 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -15,16 +15,25 @@ import ( type Log struct { Log logging.Interface LogRequestHeaders bool // LogRequestHeaders true -> dump http headers at debug log level + SourceIPs *SourceIPExtractor } // logWithRequest information from the request and context as fields. func (l Log) logWithRequest(r *http.Request) logging.Interface { + localLog := l.Log traceID, ok := ExtractTraceID(r.Context()) if ok { - l.Log = l.Log.WithField("traceID", traceID) + localLog = localLog.WithField("traceID", traceID) } - return user.LogWith(r.Context(), l.Log) + if l.SourceIPs != nil { + ips := l.SourceIPs.Get(r) + if ips != "" { + localLog = localLog.WithField("sourceIPs", ips) + } + } + + return user.LogWith(r.Context(), localLog) } // Wrap implements Middleware diff --git a/middleware/source_ips.go b/middleware/source_ips.go new file mode 100644 index 00000000..17178d42 --- /dev/null +++ b/middleware/source_ips.go @@ -0,0 +1,141 @@ +package middleware + +import ( + "fmt" + "net" + "net/http" + "regexp" + "strings" +) + +// Parts copied and changed from gorilla mux proxy_headers.go + +var ( + // De-facto standard header keys. + xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") + xRealIP = http.CanonicalHeaderKey("X-Real-IP") +) + +var ( + // RFC7239 defines a new "Forwarded: " header designed to replace the + // existing use of X-Forwarded-* headers. + // e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 + forwarded = http.CanonicalHeaderKey("Forwarded") + // Allows for a sub-match of the first value after 'for=' to the next + // comma, semi-colon or space. The match is case-insensitive. + forRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`) +) + +// SourceIPExtractor extracts the source IPs from a HTTP request +type SourceIPExtractor struct { + // The header to search for + header string + // A regex that extracts the IP address from the header. + // It should contain at least one capturing group the first of which will be returned. + regex *regexp.Regexp +} + +// NewSourceIPs creates a new SourceIPs +func NewSourceIPs(header, regex string) (*SourceIPExtractor, error) { + if (header == "" && regex != "") || (header != "" && regex == "") { + return nil, fmt.Errorf("either both a header field and a regex have to be given or neither") + } + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("invalid regex given") + } + + return &SourceIPExtractor{ + header: header, + regex: re, + }, nil +} + +// extractHost returns the Host IP address without any port information +func extractHost(address string) string { + hostIP := net.ParseIP(address) + if hostIP != nil { + return hostIP.String() + } + var err error + hostStr, _, err := net.SplitHostPort(address) + if err != nil { + // Invalid IP address, just return it so it shows up in the logs + return address + } + return hostStr +} + +// Get returns any source addresses we can find in the request, comma-separated +func (sips SourceIPExtractor) Get(req *http.Request) string { + fwd := extractHost(sips.getIP(req)) + if fwd == "" { + if req.RemoteAddr == "" { + return "" + } + return extractHost(req.RemoteAddr) + } + // If RemoteAddr is empty just return the header + if req.RemoteAddr == "" { + return fwd + } + remoteIP := extractHost(req.RemoteAddr) + if fwd == remoteIP { + return remoteIP + } + // If both a header and RemoteAddr are present return them both, stripping off any port info from the RemoteAddr + return fmt.Sprintf("%v, %v", fwd, remoteIP) +} + +// getIP retrieves the IP from the RFC7239 Forwarded headers, +// X-Real-IP and X-Forwarded-For (in that order) or from the +// custom regex. +func (sips SourceIPExtractor) getIP(r *http.Request) string { + var addr string + + // Use the custom regex only if it was setup + if sips.header != "" { + hdr := r.Header.Get(sips.header) + if hdr == "" { + return "" + } + allMatches := sips.regex.FindAllStringSubmatch(hdr, 1) + if len(allMatches) == 0 { + return "" + } + firstMatch := allMatches[0] + // Check there is at least 1 submatch + if len(firstMatch) < 2 { + return "" + } + return firstMatch[1] + } + + if fwd := r.Header.Get(forwarded); fwd != "" { + // match should contain at least two elements if the protocol was + // specified in the Forwarded header. The first element will always be + // the 'for=' capture, which we ignore. In the case of multiple IP + // addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only + // extract the first, which should be the client IP. + if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 { + // IPv6 addresses in Forwarded headers are quoted-strings. We strip + // these quotes. + addr = strings.Trim(match[1], `"`) + } + } else if fwd := r.Header.Get(xRealIP); fwd != "" { + // X-Real-IP should only contain one IP address (the client making the + // request). + addr = fwd + } else if fwd := strings.ReplaceAll(r.Header.Get(xForwardedFor), " ", ""); fwd != "" { + // Only grab the first (client) address. Note that '192.168.0.1, + // 10.1.1.1' is a valid key for X-Forwarded-For where addresses after + // the first may represent forwarding proxies earlier in the chain. + s := strings.Index(fwd, ",") + if s == -1 { + s = len(fwd) + } + addr = fwd[:s] + } + + return addr +} diff --git a/middleware/source_ips_test.go b/middleware/source_ips_test.go new file mode 100644 index 00000000..455a5383 --- /dev/null +++ b/middleware/source_ips_test.go @@ -0,0 +1,266 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetSourceIPs(t *testing.T) { + tests := []struct { + name string + req *http.Request + want string + }{ + { + name: "no header", + req: &http.Request{RemoteAddr: "192.168.1.100:3454"}, + want: "192.168.1.100", + }, + { + name: "no header and remote has no port", + req: &http.Request{RemoteAddr: "192.168.1.100"}, + want: "192.168.1.100", + }, + { + name: "no header, remote address is invalid", + req: &http.Request{RemoteAddr: "192.168.100"}, + want: "192.168.100", + }, + { + name: "X-Forwarded-For and single forward address", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "X-Forwarded-For and single forward address which is same as remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"192.168.1.100"}, + }, + }, + want: "192.168.1.100", + }, + { + name: "single IPv6 X-Forwarded-For address", + req: &http.Request{ + RemoteAddr: "[2001:db9::1]:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"2001:db8::1"}, + }, + }, + want: "2001:db8::1, 2001:db9::1", + }, + { + name: "single X-Forwarded-For address no RemoteAddr", + req: &http.Request{ + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1"}, + }, + }, + want: "172.16.1.1", + }, + { + name: "multiple X-Forwarded-For with remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1, 10.10.13.20"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "multiple X-Forwarded-For with remote and no spaces", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1,10.10.13.20,10.11.16.46"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "multiple X-Forwarded-For with IPv6 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"[2001:db8:cafe::17]:4711, 10.10.13.20"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "no header, no remote", + req: &http.Request{}, + want: "", + }, + { + name: "X-Real-IP with IPv6 remote with port", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xRealIP): {"[2001:db8:cafe::17]:4711"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "X-Real-IP with IPv4 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xRealIP): {"192.169.1.200"}, + }, + }, + want: "192.169.1.200, 192.168.1.100", + }, + { + name: "X-Real-IP with IPv4 remote and X-Forwarded-For", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"[2001:db8:cafe::17]:4711, 10.10.13.20"}, + http.CanonicalHeaderKey(xRealIP): {"192.169.1.200"}, + }, + }, + want: "192.169.1.200, 192.168.1.100", + }, + { + name: "Forwarded with IPv4 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=192.169.1.200"}, + }, + }, + want: "192.169.1.200, 192.168.1.100", + }, + { + name: "Forwarded with IPv4 and proto and by fields", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=192.0.2.60;proto=http;by=203.0.113.43"}, + }, + }, + want: "192.0.2.60, 192.168.1.100", + }, + { + name: "Forwarded with IPv6 and IPv4 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=[2001:db8:cafe::17]:4711,for=192.169.1.200"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "Forwarded with X-Real-IP and X-Forwarded-For", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"[2001:db8:cafe::17]:4711, 10.10.13.20"}, + http.CanonicalHeaderKey(xRealIP): {"192.169.1.200"}, + http.CanonicalHeaderKey(forwarded): {"for=[2001:db8:cafe::17]:4711,for=192.169.1.200"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "Forwarded returns hostname", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=workstation.local"}, + }, + }, + want: "workstation.local, 192.168.1.100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceIPs, err := NewSourceIPs("", "") + require.NoError(t, err) + + if got := sourceIPs.Get(tt.req); got != tt.want { + t.Errorf("GetSource() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetSourceIPsWithCustomRegex(t *testing.T) { + tests := []struct { + name string + req *http.Request + want string + }{ + { + name: "no header", + req: &http.Request{RemoteAddr: "192.168.1.100:3454"}, + want: "192.168.1.100", + }, + { + name: "No matching entry in the header", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey("SomeHeader"): {"not matching"}, + }, + }, + want: "192.168.1.100", + }, + { + name: "one matching entry in the header", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey("SomeHeader"): {"172.16.1.1"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "multiple matching entries in the header, only first used", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey("SomeHeader"): {"172.16.1.1", "172.16.2.1"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceIPs, err := NewSourceIPs("SomeHeader", "((?:[0-9]{1,3}\\.){3}[0-9]{1,3})") + require.NoError(t, err) + + if got := sourceIPs.Get(tt.req); got != tt.want { + t.Errorf("GetSource() = %v, want %v", got, tt.want) + } + }) + } +} +func TestInvalid(t *testing.T) { + sourceIPs, err := NewSourceIPs("Header", "") + require.Empty(t, sourceIPs) + require.Error(t, err) + + sourceIPs, err = NewSourceIPs("", "a(.*)b") + require.Empty(t, sourceIPs) + require.Error(t, err) + + sourceIPs, err = NewSourceIPs("Header", "[*") + require.Empty(t, sourceIPs) + require.Error(t, err) +} diff --git a/server/server.go b/server/server.go index 4dc8ae0e..c8c44d98 100644 --- a/server/server.go +++ b/server/server.go @@ -76,9 +76,12 @@ type Config struct { GRPCServerTime time.Duration `yaml:"grpc_server_keepalive_time"` GRPCServerTimeout time.Duration `yaml:"grpc_server_keepalive_timeout"` - LogFormat logging.Format `yaml:"log_format"` - LogLevel logging.Level `yaml:"log_level"` - Log logging.Interface `yaml:"-"` + LogFormat logging.Format `yaml:"log_format"` + LogLevel logging.Level `yaml:"log_level"` + Log logging.Interface `yaml:"-"` + LogSourceIPs bool `yaml:"log_source_ips_enabled"` + LogSourceIPsHeader string `yaml:"log_source_ips_header"` + LogSourceIPsRegex string `yaml:"log_source_ips_regex"` // If not set, default signal handler is used. SignalHandler SignalHandler `yaml:"-"` @@ -120,6 +123,9 @@ func (cfg *Config) RegisterFlags(f *flag.FlagSet) { f.StringVar(&cfg.PathPrefix, "server.path-prefix", "", "Base path to serve all API routes from (e.g. /v1/)") cfg.LogFormat.RegisterFlags(f) cfg.LogLevel.RegisterFlags(f) + f.BoolVar(&cfg.LogSourceIPs, "server.log-source-ips-enabled", false, "Optionally log the source IPs.") + f.StringVar(&cfg.LogSourceIPsHeader, "server.log-source-ips-header", "", "Header field storing the source IPs. Only used if server.log-source-ips-enabled is true. If not set the default Forwarded, X-Real-IP and X-Forwarded-For headers are used") + f.StringVar(&cfg.LogSourceIPsRegex, "server.log-source-ips-regex", "", "Regex for matching the source IPs. Only used if server.log-source-ips-enabled is true. If not set the default Forwarded, X-Real-IP and X-Forwarded-For headers are used") } // Server wraps a HTTP and gRPC server, and some common initialization. @@ -249,12 +255,21 @@ func New(cfg Config) (*Server, error) { if cfg.RegisterInstrumentation { RegisterInstrumentation(router) } + var sourceIPs *middleware.SourceIPExtractor + if cfg.LogSourceIPs { + sourceIPs, err = middleware.NewSourceIPs(cfg.LogSourceIPsHeader, cfg.LogSourceIPsRegex) + if err != nil { + return nil, fmt.Errorf("error setting up source IP extraction: %v", err) + } + } httpMiddleware := []middleware.Interface{ middleware.Tracer{ RouteMatcher: router, + SourceIPs: sourceIPs, }, middleware.Log{ - Log: log, + Log: log, + SourceIPs: sourceIPs, }, middleware.Instrument{ Duration: requestDuration, diff --git a/server/server_test.go b/server/server_test.go index f48c0dd0..a4f8add9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -375,6 +375,67 @@ func TestTLSServer(t *testing.T) { require.EqualValues(t, &empty, grpcRes) } +type FakeLogger struct { + sourceIPs string +} + +func (f *FakeLogger) Debugf(format string, args ...interface{}) {} +func (f *FakeLogger) Debugln(args ...interface{}) {} + +func (f *FakeLogger) Infof(format string, args ...interface{}) {} +func (f *FakeLogger) Infoln(args ...interface{}) {} + +func (f *FakeLogger) Errorf(format string, args ...interface{}) {} +func (f *FakeLogger) Errorln(args ...interface{}) {} + +func (f *FakeLogger) Warnf(format string, args ...interface{}) {} +func (f *FakeLogger) Warnln(args ...interface{}) {} + +func (f *FakeLogger) WithField(key string, value interface{}) logging.Interface { + if key == "sourceIPs" { + f.sourceIPs = value.(string) + } + + return f +} + +func (f *FakeLogger) WithFields(fields logging.Fields) logging.Interface { + return f +} + +func TestLogSourceIPs(t *testing.T) { + var level logging.Level + level.Set("debug") + fake := FakeLogger{} + cfg := Config{ + HTTPListenAddress: "localhost", + HTTPListenPort: 9195, + GRPCListenAddress: "localhost", + HTTPMiddleware: []middleware.Interface{middleware.Logging}, + MetricsNamespace: "testing_mux", + LogLevel: level, + Log: &fake, + LogSourceIPs: true, + } + server, err := New(cfg) + require.NoError(t, err) + + server.HTTP.HandleFunc("/error500", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + }) + + go server.Run() + defer server.Shutdown() + + require.Empty(t, fake.sourceIPs) + + req, err := http.NewRequest("GET", "http://127.0.0.1:9195/error500", nil) + require.NoError(t, err) + http.DefaultClient.Do(req) + + require.Equal(t, fake.sourceIPs, "127.0.0.1") +} + func TestStopWithDisabledSignalHandling(t *testing.T) { cfg := Config{ HTTPListenAddress: "localhost",