Skip to content

Commit

Permalink
Refresh AccessKey passively (#2114)
Browse files Browse the repository at this point in the history
* Refresh AccessKey passively

* revert unrelated codes
  • Loading branch information
terencefan authored Dec 11, 2024
1 parent f7ae346 commit 1f0ef6c
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 366 deletions.
9 changes: 1 addition & 8 deletions src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServerNameProvider), () => serverNameProvider);
}

var synchronizer = configuration.Resolver.Resolve<IAccessKeySynchronizer>();
if (synchronizer == null)
{
synchronizer = new AccessKeySynchronizer(loggerFactory);
configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer);
}

var endpoint = new ServiceEndpointManager(synchronizer, options, loggerFactory);
var endpoint = new ServiceEndpointManager(options, loggerFactory);
configuration.Resolver.Register(typeof(IServiceEndpointManager), () => endpoint);

var requestIdProvider = configuration.Resolver.Resolve<IConnectionRequestIdProvider>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase
{
private readonly ServiceOptions _options;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
ServiceOptions options,
public ServiceEndpointManager(ServiceOptions options,
ILoggerFactory loggerFactory) :
base(options,
loggerFactory?.CreateLogger<ServiceEndpointManager>())
{
_options = options;
_synchronizer = synchronizer;
}

public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint)
Expand All @@ -27,7 +23,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}
_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}
}

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;

private readonly object _lock = new object();

private volatile TaskCompletionSource<bool> _updateTaskSource;

private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { Constants.AsrsDefaultScope });

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);
Expand All @@ -40,12 +40,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand All @@ -54,8 +50,6 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable);

public bool Available
Expand All @@ -70,7 +64,6 @@ private set
}
_updateAt = DateTime.UtcNow;
_isAuthorized = value;
_initializedTcs.TrySetResult(null);
}
}

Expand All @@ -95,6 +88,9 @@ public MicrosoftEntraAccessKey(Uri serverEndpoint,
TokenCredential = credential;

_httpClientFactory = httpClientFactory ?? HttpClientFactory.Instance;

_updateTaskSource = new(TaskCreationOptions.RunContinuationsAsynchronously);
_updateTaskSource.TrySetResult(false);
}

public virtual async Task<string> GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default)
Expand All @@ -121,44 +117,67 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
var updateTask = Task.CompletedTask;
if (NeedRefresh)
{
_ = UpdateAccessKeyAsync();
updateTask = UpdateAccessKeyAsync();
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

if (!Available)
{
try
{
await updateTask.OrCancelAsync(ctoken);
}
catch (OperationCanceledException)
{
}
}
return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException, _keyBytes != null), LastException);
}

internal void UpdateAccessKey(string kid, string keyStr)
{
_keyBytes = Encoding.UTF8.GetBytes(keyStr);
_kid = kid;
Available = true;
}

internal async Task UpdateAccessKeyAsync()
{
if (!NeedRefresh)
lock (_lock)
{
return;
_updateTaskSource.TrySetResult(true);
}
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
internal async Task UpdateAccessKeyAsync()
{
TaskCompletionSource<bool> tcs;
lock (_lock)
{
return;
if (!_updateTaskSource.Task.IsCompleted)
{
tcs = _updateTaskSource;
}
else
{
_updateTaskSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_ = UpdateAccessKeyInternalAsync(_updateTaskSource);
tcs = _updateTaskSource;
}
}
await tcs.Task;
}

private async Task UpdateAccessKeyInternalAsync(TaskCompletionSource<bool> tcs)
{
for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
var source = new CancellationTokenSource(GetAccessKeyTimeout);
using var source = new CancellationTokenSource(GetAccessKeyTimeout);
try
{
await UpdateAccessKeyInternalAsync(source.Token);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(true);
return;
}
catch (OperationCanceledException e)
Expand All @@ -168,14 +187,7 @@ internal async Task UpdateAccessKeyAsync()
catch (Exception e)
{
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
catch (OperationCanceledException)
{
break;
}
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
}

Expand All @@ -184,15 +196,15 @@ internal async Task UpdateAccessKeyAsync()
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(false);
}

private static string GetExceptionMessage(Exception? exception)
private static string GetExceptionMessage(Exception? exception, bool initialized)
{
return exception switch
{
AzureSignalRUnauthorizedException => AzureSignalRUnauthorizedException.ErrorMessageMicrosoftEntra,
_ => exception?.Message ?? "The access key has expired.",
_ => exception?.Message ?? (initialized ? "The access key has expired." : "The access key has not been initialized."),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ public Task<string> Generate(string audience, TimeSpan? lifetime = null)
{
return key.GetMicrosoftEntraTokenAsync();
}

return _accessKey.GenerateAccessTokenAsync(audience,
_claims,
lifetime ?? Constants.Periods.DefaultAccessTokenLifetime,
DefaultAlgorithm);
var time = lifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
return _accessKey.GenerateAccessTokenAsync(audience, _claims, time, DefaultAlgorithm);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton(typeof(AzureSignalRMarkerService))
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<IAccessKeySynchronizer, AccessKeySynchronizer>()
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase

private readonly TimeSpan _scaleTimeout;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
IOptionsMonitor<ServiceOptions> optionsMonitor,
public ServiceEndpointManager(IOptionsMonitor<ServiceOptions> optionsMonitor,
ILoggerFactory loggerFactory) :
base(optionsMonitor.CurrentValue, loggerFactory.CreateLogger<ServiceEndpointManager>())
{
_options = optionsMonitor.CurrentValue;
_logger = loggerFactory?.CreateLogger<ServiceEndpointManager>() ?? throw new ArgumentNullException(nameof(loggerFactory));
_synchronizer = synchronizer;

optionsMonitor.OnChange(OnChange);
_scaleTimeout = _options.ServiceScaleTimeout;
Expand All @@ -40,8 +36,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}

_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}

Expand All @@ -53,7 +47,6 @@ private void OnChange(ServiceOptions options)

private Task ReloadServiceEndpointsAsync(IEnumerable<ServiceEndpoint> serviceEndpoints)
{
_synchronizer.UpdateServiceEndpoints(serviceEndpoints);
return ReloadServiceEndpointsAsync(serviceEndpoints, _scaleTimeout);
}

Expand Down
Loading

0 comments on commit 1f0ef6c

Please sign in to comment.