Skip to content

Commit

Permalink
Merge pull request #3 from gambol99/merge
Browse files Browse the repository at this point in the history
Merge
  • Loading branch information
gambol99 authored Mar 31, 2018
2 parents 2111f98 + 836f676 commit 750fa6d
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 14 deletions.
7 changes: 6 additions & 1 deletion http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ func expires(date, expires string) (time.Duration, bool, error) {
return 0, false, nil
}

te, err := time.Parse(time.RFC1123, expires)
var te time.Time
var err error
if expires == "0" {
return 0, false, nil
}
te, err = time.Parse(time.RFC1123, expires)
if err != nil {
return 0, false, err
}
Expand Down
7 changes: 7 additions & 0 deletions http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ func TestExpiresPass(t *testing.T) {
wantTTL: 0,
wantOK: false,
},
// Expires set to false
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
exp: "0",
wantTTL: 0,
wantOK: false,
},
// Expires < Date
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
Expand Down
7 changes: 4 additions & 3 deletions jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package oidc

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -182,14 +181,16 @@ func (r *remoteKeySet) updateKeys() ([]jose.JSONWebKey, time.Time, error) {

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, time.Time{}, fmt.Errorf("oidc: read response body: %v", err)
return nil, time.Time{}, fmt.Errorf("unable to read response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
return nil, time.Time{}, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body)
}

var keySet jose.JSONWebKeySet
if err := json.Unmarshal(body, &keySet); err != nil {
err = unmarshalResp(resp, body, &keySet)
if err != nil {
return nil, time.Time{}, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
}

Expand Down
25 changes: 22 additions & 3 deletions oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"mime"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -93,18 +94,23 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
if err != nil {
return nil, err
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to read response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}
defer resp.Body.Close()

var p providerJSON
if err := json.Unmarshal(body, &p); err != nil {
err = unmarshalResp(resp, body, &p)
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
}

if p.Issuer != issuer {
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
}
Expand Down Expand Up @@ -307,3 +313,16 @@ func (j *jsonTime) UnmarshalJSON(b []byte) error {
*j = jsonTime(time.Unix(unix, 0))
return nil
}

func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
err := json.Unmarshal(body, &v)
if err == nil {
return nil
}
ct := r.Header.Get("Content-Type")
mediaType, _, parseErr := mime.ParseMediaType(ct)
if parseErr == nil && mediaType == "application/json" {
return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %v", err)
}
return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err)
}
28 changes: 22 additions & 6 deletions oidc/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,19 @@ func (g *fakeProviderConfigGetterSetter) Set(cfg ProviderConfig) error {
}

type fakeProviderConfigHandler struct {
cfg ProviderConfig
maxAge time.Duration
cfg ProviderConfig
maxAge time.Duration
noExpires bool
}

func (s *fakeProviderConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
b, _ := json.Marshal(&s.cfg)
if s.maxAge.Seconds() >= 0 {
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(s.maxAge.Seconds())))
}
if s.noExpires {
w.Header().Set("Expires", "0")
}
w.Header().Set("Content-Type", "application/json")
w.Write(b)
}
Expand Down Expand Up @@ -552,10 +556,11 @@ func TestHTTPProviderConfigGetter(t *testing.T) {
now := fc.Now().UTC()

tests := []struct {
dsc string
age time.Duration
cfg ProviderConfig
ok bool
dsc string
age time.Duration
cfg ProviderConfig
noExpires bool
ok bool
}{
// everything is good
{
Expand Down Expand Up @@ -596,6 +601,17 @@ func TestHTTPProviderConfigGetter(t *testing.T) {
},
ok: true,
},
// An expires header set to 0
{
dsc: "https://example.com",
age: time.Minute,
cfg: ProviderConfig{
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
ExpiresAt: now.Add(time.Minute),
},
ok: true,
noExpires: true,
},
}

for i, tt := range tests {
Expand Down
2 changes: 1 addition & 1 deletion verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func contains(sli []string, ele string) bool {
func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDToken, error) {
jws, err := jose.ParseSigned(rawIDToken)
if err != nil {
return nil, fmt.Errorf("oidc: mallformed jwt: %v", err)
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}

// Throw out tokens with invalid claims before trying to verify the token. This lets
Expand Down

0 comments on commit 750fa6d

Please sign in to comment.