Skip to content

Commit

Permalink
Fix race condition in endpoint discovery (#1504)
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail authored Nov 23, 2021
1 parent a6f44ed commit 56f090f
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 3 deletions.
8 changes: 8 additions & 0 deletions .changelog/9667162dc94c43c7bb0c5b0bbcd5ef8a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "9667162d-c94c-43c7-bb0c-5b0bbcd5ef8a",
"type": "bugfix",
"description": "Fixed a race condition that caused concurrent calls relying on endpoint discovery to share the same `url.URL` reference in their operation's http.Request.",
"modules": [
"service/internal/endpoint-discovery"
]
}
5 changes: 4 additions & 1 deletion service/internal/endpoint-discovery/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
return Endpoint{}, false
}

c.endpoints.Store(endpointKey, endpoint)
ev := endpoint.(Endpoint)
ev.Prune()

c.endpoints.Store(endpointKey, ev)
return endpoint.(Endpoint), true
}

Expand Down
46 changes: 46 additions & 0 deletions service/internal/endpoint-discovery/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package endpointdiscovery

import (
"net/url"
"testing"
"time"
)

func TestEndpointCache_Get_prune(t *testing.T) {
c := NewEndpointCache(2)
c.Add(Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: &url.URL{
Host: "foo.amazonaws.com",
},
Expired: time.Now().Add(5 * time.Minute),
},
{
URL: &url.URL{
Host: "bar.amazonaws.com",
},
Expired: time.Now().Add(5 * -time.Minute),
},
},
})

load, _ := c.endpoints.Load("foo")
if ev := load.(Endpoint); len(ev.Addresses) != 2 {
t.Errorf("expected two weighted addresses")
}

weightedAddress, ok := c.Get("foo")
if !ok {
t.Errorf("expect weighted address, got none")
}
if e, a := "foo.amazonaws.com", weightedAddress.URL.Host; e != a {
t.Errorf("expect %v, got %v", e, a)
}

load, _ = c.endpoints.Load("foo")
if ev := load.(Endpoint); len(ev.Addresses) != 1 {
t.Errorf("expected one weighted address")
}
}
35 changes: 33 additions & 2 deletions service/internal/endpoint-discovery/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,44 @@ func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
we := e.Addresses[i]

if we.HasExpired() {
e.Addresses = append(e.Addresses[:i], e.Addresses[i+1:]...)
i--
continue
}

we.URL = cloneURL(we.URL)

return we, true
}

return WeightedAddress{}, false
}

// Prune will prune the expired addresses from the endpoint by allocating a new []WeightAddress.
// This is not concurrent safe, and should be called from a single owning thread.
func (e *Endpoint) Prune() bool {
validLen := e.Len()
if validLen == len(e.Addresses) {
return false
}
wa := make([]WeightedAddress, 0, validLen)
for i := range e.Addresses {
if e.Addresses[i].HasExpired() {
continue
}
wa = append(wa, e.Addresses[i])
}
e.Addresses = wa
return true
}

func cloneURL(u *url.URL) (clone *url.URL) {
clone = &url.URL{}

*clone = *u

if u.User != nil {
user := *u.User
clone.User = &user
}

return clone
}
123 changes: 123 additions & 0 deletions service/internal/endpoint-discovery/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package endpointdiscovery

import (
"net/url"
"reflect"
"strconv"
"testing"
"time"
)

func Test_cloneURL(t *testing.T) {
tests := []struct {
value *url.URL
wantClone *url.URL
}{
{
value: &url.URL{
Scheme: "https",
Opaque: "foo",
User: nil,
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
wantClone: &url.URL{
Scheme: "https",
Opaque: "foo",
User: nil,
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
},
{
value: &url.URL{
Scheme: "https",
Opaque: "foo",
User: url.UserPassword("NOT", "VALID"),
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
wantClone: &url.URL{
Scheme: "https",
Opaque: "foo",
User: url.UserPassword("NOT", "VALID"),
Host: "amazonaws.com",
Path: "/",
RawPath: "/",
ForceQuery: true,
RawQuery: "thing=value",
Fragment: "1234",
RawFragment: "1234",
},
},
}
for i, tt := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
gotClone := cloneURL(tt.value)
if gotClone == tt.value {
t.Errorf("expct clone URL to not be same pointer address")
}
if tt.value.User != nil {
if tt.value.User == gotClone.User {
t.Errorf("expct cloned Userinfo to not be same pointer address")
}
}
if !reflect.DeepEqual(gotClone, tt.wantClone) {
t.Errorf("cloneURL() = %v, want %v", gotClone, tt.wantClone)
}
})
}
}

func TestEndpoint_Prune(t *testing.T) {
endpoint := Endpoint{}

endpoint.Add(WeightedAddress{
URL: &url.URL{},
Expired: time.Now().Add(5 * time.Minute),
})

initial := endpoint.Addresses

if e, a := false, endpoint.Prune(); e != a {
t.Errorf("expect prune %v, got %v", e, a)
}

if e, a := &initial[0], &endpoint.Addresses[0]; e != a {
t.Errorf("expect slice address to be same")
}

endpoint.Add(WeightedAddress{
URL: &url.URL{},
Expired: time.Now().Add(5 * -time.Minute),
})

initial = endpoint.Addresses

if e, a := true, endpoint.Prune(); e != a {
t.Errorf("expect prune %v, got %v", e, a)
}

if e, a := &initial[0], &endpoint.Addresses[0]; e == a {
t.Errorf("expect slice address to be different")
}

if e, a := 1, endpoint.Len(); e != a {
t.Errorf("expect slice length %v, got %v", e, a)
}
}

0 comments on commit 56f090f

Please sign in to comment.