Skip to content

Commit

Permalink
Update IMDS credential provider to handle credentials fail to refresh (
Browse files Browse the repository at this point in the history
…#1634)

Adds support for IMDS assume role credential provider to handle
credentials that fail to refresh, and are approaching their expiry time.

Updates expiry time of IMDS credentials to have an initial ceiling of 1
hour instead of the duration returned by IMDS. This allows the SDK to
proactively retrieve updated IMDS credentials.
  • Loading branch information
jasdel authored Mar 23, 2022
1 parent f83f630 commit cbd1bab
Show file tree
Hide file tree
Showing 9 changed files with 712 additions and 47 deletions.
8 changes: 8 additions & 0 deletions .changelog/08d93344ca9c4036aefcf73146edaf55.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "08d93344-ca9c-4036-aefc-f73146edaf55",
"type": "feature",
"description": "Update CredentialsCache to make use of two new optional CredentialsProvider interfaces to give the cache, per provider, behavior how the cache handles credentials that fail to refresh, and adjusting expires time. See [aws.CredentialsCache](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws#CredentialsCache) for more details.",
"modules": [
"."
]
}
9 changes: 9 additions & 0 deletions .changelog/9913162361ed41fe867f56fb4ee75e8e.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"id": "99131623-61ed-41fe-867f-56fb4ee75e8e",
"type": "feature",
"description": "Update `ec2rolecreds` package's `Provider` to implememnt support for CredentialsCache new optional caching strategy interfaces, HandleFailRefreshCredentialsCacheStrategy and AdjustExpiresByCredentialsCacheStrategy.",
"modules": [
".",
"credentials"
]
}
133 changes: 106 additions & 27 deletions aws/credential_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"context"
"fmt"
"sync/atomic"
"time"

Expand All @@ -24,11 +25,13 @@ type CredentialsCacheOptions struct {
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration

// ExpiryWindowJitterFrac provides a mechanism for randomizing the expiration of credentials
// within the configured ExpiryWindow by a random percentage. Valid values are between 0.0 and 1.0.
// ExpiryWindowJitterFrac provides a mechanism for randomizing the
// expiration of credentials within the configured ExpiryWindow by a random
// percentage. Valid values are between 0.0 and 1.0.
//
// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac is 0.5 then credentials will be set to
// expire between 30 to 60 seconds prior to their actual expiration time.
// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
// is 0.5 then credentials will be set to expire between 30 to 60 seconds
// prior to their actual expiration time.
//
// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
Expand All @@ -39,17 +42,29 @@ type CredentialsCacheOptions struct {

// CredentialsCache provides caching and concurrency safe credentials retrieval
// via the provider's retrieve method.
//
// CredentialsCache will look for optional interfaces on the Provider to adjust
// how the credential cache handles credentials caching.
//
// * HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
// credential refresh failures. This could return an updated Credentials
// value, or attempt another means of retrieving credentials.
//
// * AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
// credentials Expires is modified. This could modify how the Credentials
// Expires is adjusted based on the CredentialsCache ExpiryWindow option.
// Such as providing a floor not to reduce the Expires below.
type CredentialsCache struct {
// provider is the CredentialProvider implementation to be wrapped by the CredentialCache.
provider CredentialsProvider

options CredentialsCacheOptions
creds atomic.Value
sf singleflight.Group
}

// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider is expected to not be nil. A variadic
// list of one or more functions can be provided to modify the CredentialsCache configuration. This allows for
// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
// is expected to not be nil. A variadic list of one or more functions can be
// provided to modify the CredentialsCache configuration. This allows for
// configuration of credential expiry window and jitter.
func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
options := CredentialsCacheOptions{}
Expand Down Expand Up @@ -81,8 +96,8 @@ func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *C
//
// Returns and error if the provider's retrieve method returns an error.
func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
if creds := p.getCreds(); creds != nil {
return *creds, nil
if creds, ok := p.getCreds(); ok && !creds.Expired() {
return creds, nil
}

resCh := p.sf.DoChan("", func() (interface{}, error) {
Expand All @@ -97,43 +112,107 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
}

func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
if creds := p.getCreds(); creds != nil {
return *creds, nil
currCreds, ok := p.getCreds()
if ok && !currCreds.Expired() {
return currCreds, nil
}

newCreds, err := p.provider.Retrieve(ctx)
if err != nil {
handleFailToRefresh := defaultHandleFailToRefresh
if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
handleFailToRefresh = cs.HandleFailToRefresh
}
newCreds, err = handleFailToRefresh(ctx, currCreds, err)
if err != nil {
return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
}
}

creds, err := p.provider.Retrieve(ctx)
if err == nil {
if creds.CanExpire {
randFloat64, err := sdkrand.CryptoRandFloat64()
if err != nil {
return Credentials{}, err
}
jitter := time.Duration(randFloat64 * p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
creds.Expires = creds.Expires.Add(-(p.options.ExpiryWindow - jitter))
if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
adjustExpiresBy := defaultAdjustExpiresBy
if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
adjustExpiresBy = cs.AdjustExpiresBy
}

randFloat64, err := sdkrand.CryptoRandFloat64()
if err != nil {
return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
}

p.creds.Store(&creds)
var jitter time.Duration
if p.options.ExpiryWindowJitterFrac > 0 {
jitter = time.Duration(randFloat64 *
p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
}

newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
if err != nil {
return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
}
}

return creds, err
p.creds.Store(&newCreds)
return newCreds, nil
}

func (p *CredentialsCache) getCreds() *Credentials {
// getCreds returns the currently stored credentials and true. Returning false
// if no credentials were stored.
func (p *CredentialsCache) getCreds() (Credentials, bool) {
v := p.creds.Load()
if v == nil {
return nil
return Credentials{}, false
}

c := v.(*Credentials)
if c != nil && c.HasKeys() && !c.Expired() {
return c
if c == nil || !c.HasKeys() {
return Credentials{}, false
}

return nil
return *c, true
}

// Invalidate will invalidate the cached credentials. The next call to Retrieve
// will cause the provider's Retrieve method to be called.
func (p *CredentialsCache) Invalidate() {
p.creds.Store((*Credentials)(nil))
}

// HandleFailRefreshCredentialsCacheStrategy is an interface for
// CredentialsCache to allow CredentialsProvider how failed to refresh
// credentials is handled.
type HandleFailRefreshCredentialsCacheStrategy interface {
// Given the previously cached Credentials, if any, and refresh error, may
// returns new or modified set of Credentials, or error.
//
// Credential caches may use default implementation if nil.
HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
}

// defaultHandleFailToRefresh returns the passed in error.
func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
return Credentials{}, err
}

// AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
// to allow CredentialsProvider to intercept adjustments to Credentials expiry
// based on expectations and use cases of CredentialsProvider.
//
// Credential caches may use default implementation if nil.
type AdjustExpiresByCredentialsCacheStrategy interface {
// Given a Credentials as input, applying any mutations and
// returning the potentially updated Credentials, or error.
AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
}

// defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
// and returns the updated credentials value. If Credentials value's CanExpire
// is false, the passed in credentials are returned unchanged.
func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
if !creds.CanExpire {
return creds, nil
}

creds.Expires = creds.Expires.Add(dur)
return creds, nil
}
Loading

0 comments on commit cbd1bab

Please sign in to comment.