Skip to content
This repository has been archived by the owner on Dec 7, 2020. It is now read-only.

Request ID #392

Merged
merged 3 commits into from
Jul 11, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

FEATURES:
* Added the ability to use a "any" operation on the roles rather then just "and" with the inclusion of a `require-any-role` [#PR389](https://github.com/gambol99/keycloak-proxy/pull/389)
* Added a `--enable-request-id` option to inject a request id into the upstream request [#PR392](https://github.com/gambol99/keycloak-proxy/pull/392)

#### **2.2.2**

Expand Down
8 changes: 7 additions & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ format:

bench:
@echo "--> Running go bench"
@go test -bench=.
@go test -bench=. -benchmem

coverage:
@echo "--> Running go coverage"
Expand All @@ -134,7 +134,7 @@ cover:
@go test --cover

spelling:
@echo "--> Chekcing the spelling"
@echo "--> Checking the spelling"
@which misspell 2>/dev/null ; if [ $$? -eq 1 ]; then \
go get -u github.com/client9/misspell/cmd/misspell; \
fi
Expand Down
8 changes: 8 additions & 0 deletions cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ func getCommandLineOptions() []cli.Flag {
Name: optName,
Usage: usage,
})
case reflect.Int:
flags = append(flags, cli.IntFlag{
Name: optName,
Usage: usage,
EnvVar: envName,
})
case reflect.Int64:
switch t.String() {
case "time.Duration":
Expand Down Expand Up @@ -170,6 +176,8 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) {
reflect.ValueOf(config).Elem().FieldByName(field.Name).SetString(cx.String(name))
case reflect.Slice:
reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.StringSlice(name)))
case reflect.Int:
reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.Int(name)))
case reflect.Int64:
switch field.Type.String() {
case "time.Duration":
Expand Down
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func newDefaultConfig() *Config {
OAuthURI: "/oauth",
OpenIDProviderTimeout: 30 * time.Second,
PreserveHost: false,
RequestIDHeader: "X-Request-ID",
ResponseHeaders: make(map[string]string),
SecureCookie: true,
ServerIdleTimeout: 120 * time.Second,
Expand Down
4 changes: 4 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ type Config struct {
Headers map[string]string `json:"headers" yaml:"headers" usage:"custom headers to the upstream request, key=value"`
// PreserveHost preserves the host header of the proxied request in the upstream request
PreserveHost bool `json:"preserve-host" yaml:"preserve-host" usage:"preserve the host header of the proxied request in the upstream request"`
// RequestIDHeader is the header name for request ids
RequestIDHeader string `json:"request-id-header" yaml:"request-id-header" usage:"the http header name for request id" env:"REQUEST_ID_HEADER"`
// ResponseHeader is a map of response headers to add to the response
ResponseHeaders map[string]string `json:"response-headers" yaml:"response-headers" usage:"custom headers to added to the http response key=value"`

// EnableRequestID indicates the proxy should add request id if none if found
EnableRequestID bool `json:"enable-request-id" yaml:"enable-request-id" usage:"indicates we should add a request id if none found" env:"ENABLE_REQUEST_ID"`
// EnableLogoutRedirect indicates we should redirect to the identity provider for logging out
EnableLogoutRedirect bool `json:"enable-logout-redirect" yaml:"enable-logout-redirect" usage:"indicates we should redirect to the identity provider for logging out"`
// EnableDefaultDeny indicates we should deny by default all requests
Expand Down
13 changes: 13 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ func entrypointMiddleware(next http.Handler) http.Handler {
})
}

// requestIDMiddleware is responsible for adding a request id if none found
func (r *oauthProxy) requestIDMiddleware(header string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if v := req.Header.Get(header); v == "" {
req.Header.Set(header, randomUUID())
}

next.ServeHTTP(w, req)
})
}
}

// loggingMiddleware is a custom http logger
func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Expand Down
16 changes: 14 additions & 2 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ type fakeRequest struct {
ExpectedContentContains string
ExpectedCookies map[string]string
ExpectedHeaders map[string]string
ExpectedProxyHeaders map[string]string
ExpectedLocation string
ExpectedNoProxyHeaders []string
ExpectedProxy bool
ExpectedProxyHeaders map[string]string
}

type fakeProxy struct {
Expand Down Expand Up @@ -230,9 +231,20 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) {
if c.ExpectedProxyHeaders != nil && len(c.ExpectedProxyHeaders) > 0 {
for k, v := range c.ExpectedProxyHeaders {
headers := upstream.Headers
assert.Equal(t, v, headers.Get(k), "case %d, expected proxy header %s=%s, got: %s", i, k, v, headers.Get(k))
switch v {
case "":
assert.NotEmpty(t, headers.Get(k), "case %d, expected the proxy header: %s to exist", i, k)
default:
assert.Equal(t, v, headers.Get(k), "case %d, expected proxy header %s=%s, got: %s", i, k, v, headers.Get(k))
}
}
}
if len(c.ExpectedNoProxyHeaders) > 0 {
for _, k := range c.ExpectedNoProxyHeaders {
assert.Empty(t, upstream.Headers.Get(k), "case %d, header: %s was not expected to exist", i, k)
}
}

if c.ExpectedContent != "" {
e := string(resp.Body())
assert.Equal(t, c.ExpectedContent, e, "case %d, expected content: %s, got: %s", i, c.ExpectedContent, e)
Expand Down
5 changes: 5 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ func (r *oauthProxy) createReverseProxy() error {
engine.MethodNotAllowed(emptyHandler)
engine.NotFound(emptyHandler)
engine.Use(middleware.Recoverer)
// @check if the request tracking id middleware is enabled
if r.config.EnableRequestID {
engine.Use(r.requestIDMiddleware(r.config.RequestIDHeader))
}
// @step: enable the entrypoint middleware
engine.Use(entrypointMiddleware)

if r.config.EnableLogging {
Expand Down
58 changes: 37 additions & 21 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,43 @@ func TestForbiddenTemplate(t *testing.T) {
newFakeProxy(cfg).RunTests(t, requests)
}

func TestRequestIDHeader(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableRequestID = true
requests := []fakeRequest{
{
URI: "/auth_all/test",
HasLogin: true,
ExpectedProxy: true,
Redirects: true,
ExpectedHeaders: map[string]string{
"X-Request-ID": "",
},
ExpectedCode: http.StatusOK,
},
}
newFakeProxy(c).RunTests(t, requests)
}

func TestAuthTokenHeaderDisabled(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableTokenHeader = false
p := newFakeProxy(c)
token := newTestToken(p.idp.getLocation())
signed, _ := p.idp.signToken(token.claims)

requests := []fakeRequest{
{
URI: "/auth_all/test",
RawToken: signed.Encode(),
ExpectedNoProxyHeaders: []string{"X-Auth-Token"},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
},
}
p.RunTests(t, requests)
}

func TestAudienceHeader(t *testing.T) {
c := newFakeKeycloakConfig()
c.NoRedirects = false
Expand Down Expand Up @@ -371,27 +408,6 @@ func TestAuthTokenHeaderEnabled(t *testing.T) {
p.RunTests(t, requests)
}

func TestAuthTokenHeaderDisabled(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableTokenHeader = false
p := newFakeProxy(c)
token := newTestToken(p.idp.getLocation())
signed, _ := p.idp.signToken(token.claims)

requests := []fakeRequest{
{
URI: "/auth_all/test",
RawToken: signed.Encode(),
ExpectedProxyHeaders: map[string]string{
"X-Auth-Token": "",
},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
},
}
p.RunTests(t, requests)
}

func TestDisableAuthorizationCookie(t *testing.T) {
c := newFakeKeycloakConfig()
c.EnableAuthorizationCookies = false
Expand Down
50 changes: 50 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"fmt"
"io"
"io/ioutil"
mrand "math/rand"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -77,6 +78,55 @@ func getRequestHostURL(r *http.Request) string {
return fmt.Sprintf("%s://%s", scheme, hostname)
}

const (
letterBytes = "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz0123456789"
letterIdxBits = 6
letterIdxMask = 1<<letterIdxBits - 1
letterIdxMax = 63 / letterIdxBits
)

var randomSource = mrand.NewSource(time.Now().UnixNano())

// randomBytes returns a random array of bytes
// @note: code taken from https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-golang
func randomBytes(n int) []byte {
b := make([]byte, n)
for i, cache, remain := n-1, randomSource.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = randomSource.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
b[i] = letterBytes[idx]
i--
}
cache >>= letterIdxBits
remain--
}

return b
}

// randomString returns a random string of x length
func randomString(length int) string {
return string(randomBytes(length))
}

// randomUUID returns a uuid from the random string
func randomUUID() string {
uuid := make([]byte, 36)
r := randomBytes(32)
i := 0
for x := range []int{8, 4, 4, 4, 12} {
copy(uuid, r[i:i+x])
if x != 12 {
copy(uuid, []byte("-"))
i = i + x
}
}

return string(uuid)
}

// readConfigFile reads and parses the configuration file
func readConfigFile(filename string, config *Config) error {
content, err := ioutil.ReadFile(filename)
Expand Down
43 changes: 43 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -64,6 +65,48 @@ func TestDecodeKeyPairs(t *testing.T) {
}
}

func TestRandom(t *testing.T) {
s := randomBytes(6)
assert.NotEmpty(t, s)
assert.Equal(t, 6, len(s))
}

func TestRandomString(t *testing.T) {
s := randomString(6)
assert.NotEmpty(t, s)
assert.Equal(t, 6, len(s))
}

func TestRandomUUID(t *testing.T) {
s := randomUUID()
assert.NotEmpty(t, s)
assert.Equal(t, 36, len(s))
}

func BenchmarkRandomBytes36(b *testing.B) {
for n := 0; n < b.N; n++ {
randomString(36)
}
}

func BenchmarkRandomString36(b *testing.B) {
for n := 0; n < b.N; n++ {
randomString(36)
}
}

func BenchmarkUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
uuid.New()
}
}

func BenchmarkRandomUUID(b *testing.B) {
for n := 0; n < b.N; n++ {
randomUUID()
}
}

func TestDefaultTo(t *testing.T) {
cs := []struct {
Value string
Expand Down