Skip to content
This repository has been archived by the owner on Dec 7, 2020. It is now read-only.

Commit

Permalink
Merge pull request #392 from gambol99/request_id
Browse files Browse the repository at this point in the history
Request ID
  • Loading branch information
gambol99 authored Jul 11, 2018
2 parents 89caa51 + 5a4ab10 commit c4d677a
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

FEATURES:
* Added the ability to use a "any" operation on the roles rather then just "and" with the inclusion of a `require-any-role` [#PR389](https://github.com/gambol99/keycloak-proxy/pull/389)
* Added a `--enable-request-id` option to inject a request id into the upstream request [#PR392](https://github.com/gambol99/keycloak-proxy/pull/392)

#### **2.2.2**

Expand Down
8 changes: 7 additions & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ format:

bench:
@echo "--> Running go bench"
@go test -bench=.
@go test -bench=. -benchmem

coverage:
@echo "--> Running go coverage"
Expand All @@ -134,7 +134,7 @@ cover:
@go test --cover

spelling:
@echo "--> Chekcing the spelling"
@echo "--> Checking the spelling"
@which misspell 2>/dev/null ; if [ $$? -eq 1 ]; then \
go get -u github.com/client9/misspell/cmd/misspell; \
fi
Expand Down
8 changes: 8 additions & 0 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ func getCommandLineOptions() []cli.Flag {
Name: optName,
Usage: usage,
})
case reflect.Int:
flags = append(flags, cli.IntFlag{
Name: optName,
Usage: usage,
EnvVar: envName,
})
case reflect.Int64:
switch t.String() {
case "time.Duration":
Expand Down Expand Up @@ -170,6 +176,8 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) {
reflect.ValueOf(config).Elem().FieldByName(field.Name).SetString(cx.String(name))
case reflect.Slice:
reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.StringSlice(name)))
case reflect.Int:
reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.Int(name)))
case reflect.Int64:
switch field.Type.String() {
case "time.Duration":
Expand Down
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func newDefaultConfig() *Config {
OAuthURI: "/oauth",
OpenIDProviderTimeout: 30 * time.Second,
PreserveHost: false,
RequestIDHeader: "X-Request-ID",
ResponseHeaders: make(map[string]string),
SecureCookie: true,
ServerIdleTimeout: 120 * time.Second,
Expand Down
4 changes: 4 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ type Config struct {
Headers map[string]string `json:"headers" yaml:"headers" usage:"custom headers to the upstream request, key=value"`
// PreserveHost preserves the host header of the proxied request in the upstream request
PreserveHost bool `json:"preserve-host" yaml:"preserve-host" usage:"preserve the host header of the proxied request in the upstream request"`
// RequestIDHeader is the header name for request ids
RequestIDHeader string `json:"request-id-header" yaml:"request-id-header" usage:"the http header name for request id" env:"REQUEST_ID_HEADER"`
// ResponseHeader is a map of response headers to add to the response
ResponseHeaders map[string]string `json:"response-headers" yaml:"response-headers" usage:"custom headers to added to the http response key=value"`

// EnableRequestID indicates the proxy should add request id if none if found
EnableRequestID bool `json:"enable-request-id" yaml:"enable-request-id" usage:"indicates we should add a request id if none found" env:"ENABLE_REQUEST_ID"`
// EnableLogoutRedirect indicates we should redirect to the identity provider for logging out
EnableLogoutRedirect bool `json:"enable-logout-redirect" yaml:"enable-logout-redirect" usage:"indicates we should redirect to the identity provider for logging out"`
// EnableDefaultDeny indicates we should deny by default all requests
Expand Down
2 changes: 1 addition & 1 deletion forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler {

// @step: add the proxy forwarding headers
req.Header.Add("X-Forwarded-For", realIP(req))
req.Header.Set("X-Forwarded-Host", req.URL.Host)
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("X-Forwarded-Proto", req.Header.Get("X-Forwarded-Proto"))

// @step: add any custom headers to the request
Expand Down
14 changes: 14 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/PuerkitoBio/purell"
"github.com/gambol99/go-oidc/jose"
"github.com/go-chi/chi/middleware"
uuid "github.com/satori/go.uuid"
"github.com/unrolled/secure"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
Expand Down Expand Up @@ -66,6 +67,19 @@ func entrypointMiddleware(next http.Handler) http.Handler {
})
}

// requestIDMiddleware is responsible for adding a request id if none found
func (r *oauthProxy) requestIDMiddleware(header string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if v := req.Header.Get(header); v == "" {
req.Header.Set(header, uuid.NewV1().String())
}

next.ServeHTTP(w, req)
})
}
}

// loggingMiddleware is a custom http logger
func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Expand Down
16 changes: 14 additions & 2 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ type fakeRequest struct {
ExpectedContentContains string
ExpectedCookies map[string]string
ExpectedHeaders map[string]string
ExpectedProxyHeaders map[string]string
ExpectedLocation string
ExpectedNoProxyHeaders []string
ExpectedProxy bool
ExpectedProxyHeaders map[string]string
}

type fakeProxy struct {
Expand Down Expand Up @@ -230,9 +231,20 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) {
if c.ExpectedProxyHeaders != nil && len(c.ExpectedProxyHeaders) > 0 {
for k, v := range c.ExpectedProxyHeaders {
headers := upstream.Headers
assert.Equal(t, v, headers.Get(k), "case %d, expected proxy header %s=%s, got: %s", i, k, v, headers.Get(k))
switch v {
case "":
assert.NotEmpty(t, headers.Get(k), "case %d, expected the proxy header: %s to exist", i, k)
default:
assert.Equal(t, v, headers.Get(k), "case %d, expected proxy header %s=%s, got: %s", i, k, v, headers.Get(k))
}
}
}
if len(c.ExpectedNoProxyHeaders) > 0 {
for _, k := range c.ExpectedNoProxyHeaders {
assert.Empty(t, upstream.Headers.Get(k), "case %d, header: %s was not expected to exist", i, k)
}
}

if c.ExpectedContent != "" {
e := string(resp.Body())
assert.Equal(t, c.ExpectedContent, e, "case %d, expected content: %s, got: %s", i, c.ExpectedContent, e)
Expand Down
6 changes: 6 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ func (r *oauthProxy) createReverseProxy() error {
engine.MethodNotAllowed(emptyHandler)
engine.NotFound(emptyHandler)
engine.Use(middleware.Recoverer)
// @check if the request tracking id middleware is enabled
if r.config.EnableRequestID {
r.log.Info("enabled the correlation request id middlware")
engine.Use(r.requestIDMiddleware(r.config.RequestIDHeader))
}
// @step: enable the entrypoint middleware
engine.Use(entrypointMiddleware)

if r.config.EnableLogging {
Expand Down
58 changes: 37 additions & 21 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,43 @@ func TestForbiddenTemplate(t *testing.T) {
newFakeProxy(cfg).RunTests(t, requests)
}

func TestRequestIDHeader(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableRequestID = true
requests := []fakeRequest{
{
URI: "/auth_all/test",
HasLogin: true,
ExpectedProxy: true,
Redirects: true,
ExpectedHeaders: map[string]string{
"X-Request-ID": "",
},
ExpectedCode: http.StatusOK,
},
}
newFakeProxy(c).RunTests(t, requests)
}

func TestAuthTokenHeaderDisabled(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableTokenHeader = false
p := newFakeProxy(c)
token := newTestToken(p.idp.getLocation())
signed, _ := p.idp.signToken(token.claims)

requests := []fakeRequest{
{
URI: "/auth_all/test",
RawToken: signed.Encode(),
ExpectedNoProxyHeaders: []string{"X-Auth-Token"},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
},
}
p.RunTests(t, requests)
}

func TestAudienceHeader(t *testing.T) {
c := newFakeKeycloakConfig()
c.NoRedirects = false
Expand Down Expand Up @@ -371,27 +408,6 @@ func TestAuthTokenHeaderEnabled(t *testing.T) {
p.RunTests(t, requests)
}

func TestAuthTokenHeaderDisabled(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableTokenHeader = false
p := newFakeProxy(c)
token := newTestToken(p.idp.getLocation())
signed, _ := p.idp.signToken(token.claims)

requests := []fakeRequest{
{
URI: "/auth_all/test",
RawToken: signed.Encode(),
ExpectedProxyHeaders: map[string]string{
"X-Auth-Token": "",
},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
},
}
p.RunTests(t, requests)
}

func TestDisableAuthorizationCookie(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableAuthorizationCookies = false
Expand Down
8 changes: 8 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"testing"
"time"

uuid "github.com/satori/go.uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -64,6 +65,13 @@ func TestDecodeKeyPairs(t *testing.T) {
}
}

func BenchmarkUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
s := uuid.NewV1()
s.String()
}
}

func TestDefaultTo(t *testing.T) {
cs := []struct {
Value string
Expand Down

0 comments on commit c4d677a

Please sign in to comment.