Skip to content

Commit

Permalink
use lock instead of CAS
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Dec 6, 2024
1 parent 6f674d5 commit 4c4e96b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;

private readonly object _lock = new object();

private volatile TaskCompletionSource<bool> _updateTaskSource;

private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { Constants.AsrsDefaultScope });

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);
Expand All @@ -44,8 +44,6 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand Down Expand Up @@ -95,6 +93,9 @@ public MicrosoftEntraAccessKey(Uri serverEndpoint,
TokenCredential = credential;

_httpClientFactory = httpClientFactory ?? HttpClientFactory.Instance;

_updateTaskSource = new();
_updateTaskSource.TrySetResult(false);
}

public virtual async Task<string> GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default)
Expand All @@ -121,31 +122,15 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!Initialized || NeedRefresh)
if (NeedRefresh)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = UpdateAccessKeyAsync(source.Token);
await UpdateAccessKeyAsync(source.Token).OrCancelAsync(ctoken);
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

if (Available)
{
return AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm);
}
else
{
while (true)
{
if (_updateState == UpdateTaskIdle)
{
return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
}
await Task.Delay(100, ctoken);
}
}
return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
}

internal void UpdateAccessKey(string kid, string keyStr)
Expand All @@ -155,25 +140,34 @@ internal void UpdateAccessKey(string kid, string keyStr)
Available = true;
}

internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
internal async Task<bool> UpdateAccessKeyAsync(CancellationToken ctoken = default)
{
if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
TaskCompletionSource<bool> source;
lock (_lock)
{
return;
if (!_updateTaskSource.Task.IsCompleted)
{
source = _updateTaskSource;
}
else
{
_updateTaskSource = new TaskCompletionSource<bool>();
_ = UpdateAccessKeyInternalAsync(_updateTaskSource, ctoken);
source = _updateTaskSource;
}
}
return await source.Task;
}

internal async Task UpdateAccessKeyInternalAsync(TaskCompletionSource<bool> tcs, CancellationToken ctoken)
{
for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
if (ctoken.IsCancellationRequested)
{
break;
}

var source = new CancellationTokenSource(GetAccessKeyTimeout);
try
{
await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(true);
return;
}
catch (OperationCanceledException e)
Expand All @@ -199,7 +193,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(false);
}

private static string GetExceptionMessage(Exception? exception)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public async Task TestNotInitialized()
var exception = await Assert.ThrowsAsync<TaskCanceledException>(
async () => await key.GenerateAccessTokenAsync("", [], TimeSpan.FromSeconds(1), AccessTokenAlgorithm.HS256, source.Token)
);
Assert.Contains("initialization timed out", exception.Message);
Assert.Contains("A task was canceled.", exception.Message);
}

[Theory]
Expand Down Expand Up @@ -190,9 +190,10 @@ public async Task TestLazyLoadAccessKeyFailed()

var task1 = key.GenerateAccessTokenAsync(DefaultAudience, [], TimeSpan.FromMinutes(1), AccessTokenAlgorithm.HS256);
var task2 = key.UpdateAccessKeyAsync();
Assert.True(task2.IsCompleted); // another task is in progress.
Assert.False(task2.IsCompleted);

await Assert.ThrowsAsync<AzureSignalRAccessTokenNotAuthorizedException>(async () => await task1);
Assert.False(await task2);

Assert.True(key.Initialized);
}
Expand Down

0 comments on commit 4c4e96b

Please sign in to comment.