Skip to content
This repository has been archived by the owner on Jan 24, 2019. It is now read-only.

Commit

Permalink
Merge pull request #360 from jehiah/csrf_validation_360
Browse files Browse the repository at this point in the history
CSRF protection for OAuth flow.
  • Loading branch information
jehiah authored Mar 29, 2017
2 parents 6c690b6 + 55085d9 commit 4464655
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 44 deletions.
16 changes: 16 additions & 0 deletions cookie/nonce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package cookie

import (
"crypto/rand"
"fmt"
)

func Nonce() (nonce string, err error) {
b := make([]byte, 16)
_, err = rand.Read(b)
if err != nil {
return
}
nonce = fmt.Sprintf("%x", b)
return
}
96 changes: 68 additions & 28 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var SignatureHeaders []string = []string{
type OAuthProxy struct {
CookieSeed string
CookieName string
CSRFCookieName string
CookieDomain string
CookieSecure bool
CookieHttpOnly bool
Expand Down Expand Up @@ -174,6 +175,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {

return &OAuthProxy{
CookieName: opts.CookieName,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
CookieSeed: opts.CookieSecret,
CookieDomain: opts.CookieDomain,
CookieSecure: opts.CookieSecure,
Expand Down Expand Up @@ -245,7 +247,22 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
return
}

func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
if value != "" {
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
if len(value) > 4096 {
// Cookies cannot be larger than 4kb
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
}
}
return p.makeCookie(req, p.CookieName, value, expiration, now)
}

func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
}

func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host
if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h
Expand All @@ -257,15 +274,8 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
domain = p.CookieDomain
}

if value != "" {
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
if len(value) > 4096 {
// Cookies cannot be larger than 4kb
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
}
}
return &http.Cookie{
Name: p.CookieName,
Name: name,
Value: value,
Path: "/",
Domain: domain,
Expand All @@ -275,12 +285,20 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
}
}

func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now()))
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}

func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
}

func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()))
}

func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
}

func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
Expand Down Expand Up @@ -309,7 +327,7 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p
if err != nil {
return err
}
p.SetCookie(rw, req, value)
p.SetSessionCookie(rw, req, value)
return nil
}

Expand Down Expand Up @@ -339,7 +357,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
}

func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
p.ClearCookie(rw, req)
p.ClearSessionCookie(rw, req)
rw.WriteHeader(code)

redirect_url := req.URL.RequestURI()
Expand Down Expand Up @@ -384,20 +402,18 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
return "", false
}

func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) {
err := req.ParseForm()

func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
err = req.ParseForm()
if err != nil {
return "", err
return
}

redirect := req.FormValue("rd")

if redirect == "" {
redirect = req.Form.Get("rd")
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
redirect = "/"
}

return redirect, err
return
}

func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
Expand Down Expand Up @@ -459,18 +475,24 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
}

func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
p.ClearCookie(rw, req)
p.ClearSessionCookie(rw, req)
http.Redirect(rw, req, "/", 302)
}

func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
nonce, err := cookie.Nonce()
if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error())
return
}
p.SetCSRFCookie(rw, req, nonce)
redirect, err := p.GetRedirect(req)
if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error())
return
}
redirectURI := p.GetRedirectURI(req.Host)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302)
}

func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
Expand All @@ -495,8 +517,26 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return
}

redirect := req.Form.Get("state")
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
s := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(s) != 2 {
p.ErrorPage(rw, 500, "Internal Error", "Invalid State")
return
}
nonce := s[0]
redirect := s[1]
c, err := req.Cookie(p.CSRFCookieName)
if err != nil {
p.ErrorPage(rw, 403, "Permission Denied", err.Error())
return
}
p.ClearCSRFCookie(rw, req)
if c.Value != nonce {
log.Printf("%s csrf token mismatch, potential attack", remoteAddr)
p.ErrorPage(rw, 403, "Permission Denied", "csrf failed")
return
}

if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
redirect = "/"
}

Expand Down Expand Up @@ -595,7 +635,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
}

if clearSession {
p.ClearCookie(rw, req)
p.ClearSessionCookie(rw, req)
}

if session == nil {
Expand Down
37 changes: 26 additions & 11 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,14 @@ func TestBasicAuthPassword(t *testing.T) {
})

rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
strings.NewReader(""))
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
proxy.ServeHTTP(rw, req)
cookie := rw.HeaderMap["Set-Cookie"][0]
if rw.Code >= 400 {
t.Fatalf("expected 3xx got %d", rw.Code)
}
cookie := rw.HeaderMap["Set-Cookie"][1]

cookieName := proxy.CookieName
var value string
Expand All @@ -196,9 +200,11 @@ func TestBasicAuthPassword(t *testing.T) {
Expires: time.Now().Add(time.Duration(24)),
HttpOnly: true,
})
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))

rw = httptest.NewRecorder()
proxy.ServeHTTP(rw, req)

expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
assert.Equal(t, expectedHeader, rw.Body.String())
provider_server.Close()
Expand Down Expand Up @@ -263,13 +269,14 @@ func (pat_test *PassAccessTokenTest) Close() {
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
strings.NewReader(""))
if err != nil {
return 0, ""
}
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
pat_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
return rw.Code, rw.HeaderMap["Set-Cookie"][1]
}

func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
Expand Down Expand Up @@ -314,14 +321,18 @@ func TestForwardAccessTokenUpstream(t *testing.T) {

// A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint()
assert.Equal(t, 302, code)
if code != 302 {
t.Fatalf("expected 302; got %d", code)
}
assert.NotEqual(t, nil, cookie)

// Now we make a regular request; the access_token from the cookie is
// forwarded as the "X-Forwarded-Access-Token" header. The token is
// read by the test provider server and written in the response body.
code, payload := pat_test.getRootEndpoint(cookie)
assert.Equal(t, 200, code)
if code != 200 {
t.Fatalf("expected 200; got %d", code)
}
assert.Equal(t, "my_auth_token", payload)
}

Expand All @@ -333,13 +344,17 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {

// A successful validation will redirect and set the auth cookie.
code, cookie := pat_test.getCallbackEndpoint()
assert.Equal(t, 302, code)
if code != 302 {
t.Fatalf("expected 302; got %d", code)
}
assert.NotEqual(t, nil, cookie)

// Now we make a regular request, but the access token header should
// not be present.
code, payload := pat_test.getRootEndpoint(cookie)
assert.Equal(t, 200, code)
if code != 200 {
t.Fatalf("expected 200; got %d", code)
}
assert.Equal(t, "No access token found.", payload)
}

Expand Down Expand Up @@ -457,15 +472,15 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
}

func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
}

func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
if err != nil {
return err
}
p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
return nil
}

Expand Down Expand Up @@ -697,7 +712,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
if err != nil {
panic(err)
}
cookie := proxy.MakeCookie(req, value, proxy.CookieExpire, time.Now())
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
req.AddCookie(cookie)
// This is used by the upstream to validate the signature.
st.authenticator.auth = hmacauth.NewHmacAuth(
Expand Down
7 changes: 2 additions & 5 deletions providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io/ioutil"
"net/http"
"net/url"
"strings"

"github.com/bitly/oauth2_proxy/cookie"
)
Expand Down Expand Up @@ -79,7 +78,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
}

// GetLoginURL with typical oauth parameters
func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
var a url.URL
a = *p.LoginURL
params, _ := url.ParseQuery(a.RawQuery)
Expand All @@ -88,9 +87,7 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
params.Add("scope", p.Scope)
params.Set("client_id", p.ClientID)
params.Set("response_type", "code")
if strings.HasPrefix(finalRedirect, "/") && !strings.HasPrefix(finalRedirect,"//") {
params.Add("state", finalRedirect)
}
params.Add("state", state)
a.RawQuery = params.Encode()
return a.String()
}
Expand Down

0 comments on commit 4464655

Please sign in to comment.