Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh AccessKey passively #2114

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServerNameProvider), () => serverNameProvider);
}

var synchronizer = configuration.Resolver.Resolve<IAccessKeySynchronizer>();
if (synchronizer == null)
{
synchronizer = new AccessKeySynchronizer(loggerFactory);
configuration.Resolver.Register(typeof(IAccessKeySynchronizer), () => synchronizer);
}

var endpoint = new ServiceEndpointManager(synchronizer, options, loggerFactory);
var endpoint = new ServiceEndpointManager(options, loggerFactory);
configuration.Resolver.Register(typeof(IServiceEndpointManager), () => endpoint);

var requestIdProvider = configuration.Resolver.Resolve<IConnectionRequestIdProvider>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase
{
private readonly ServiceOptions _options;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
ServiceOptions options,
public ServiceEndpointManager(ServiceOptions options,
ILoggerFactory loggerFactory) :
base(options,
loggerFactory?.CreateLogger<ServiceEndpointManager>())
{
_options = options;
_synchronizer = synchronizer;
}

public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint endpoint)
Expand All @@ -27,7 +23,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}
_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}
}

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;

private readonly object _lock = new object();

private volatile TaskCompletionSource<bool> _updateTaskSource;

private static readonly TokenRequestContext DefaultRequestContext = new TokenRequestContext(new string[] { Constants.AsrsDefaultScope });

private static readonly TimeSpan GetAccessKeyInterval = TimeSpan.FromMinutes(55);
Expand All @@ -40,12 +40,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand All @@ -54,8 +50,6 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool NeedRefresh => DateTime.UtcNow - _updateAt > (Available ? GetAccessKeyInterval : GetAccessKeyIntervalUnavailable);

public bool Available
Expand All @@ -70,7 +64,6 @@ private set
}
_updateAt = DateTime.UtcNow;
_isAuthorized = value;
_initializedTcs.TrySetResult(null);
}
}

Expand All @@ -95,6 +88,9 @@ public MicrosoftEntraAccessKey(Uri serverEndpoint,
TokenCredential = credential;

_httpClientFactory = httpClientFactory ?? HttpClientFactory.Instance;

_updateTaskSource = new(TaskCreationOptions.RunContinuationsAsynchronously);
_updateTaskSource.TrySetResult(false);
}

public virtual async Task<string> GetMicrosoftEntraTokenAsync(CancellationToken ctoken = default)
Expand All @@ -121,44 +117,67 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
var updateTask = Task.CompletedTask;
if (NeedRefresh)
terencefan marked this conversation as resolved.
Show resolved Hide resolved
{
_ = UpdateAccessKeyAsync();
updateTask = UpdateAccessKeyAsync();
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

if (!Available)
{
try
{
await updateTask.OrCancelAsync(ctoken);
}
catch (OperationCanceledException)
{
}
}
return Available
? AuthUtility.GenerateAccessToken(KeyBytes, Kid, audience, claims, lifetime, algorithm)
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException), LastException);
: throw new AzureSignalRAccessTokenNotAuthorizedException(TokenCredential, GetExceptionMessage(LastException, _keyBytes != null), LastException);
}

internal void UpdateAccessKey(string kid, string keyStr)
{
_keyBytes = Encoding.UTF8.GetBytes(keyStr);
_kid = kid;
Available = true;
}

internal async Task UpdateAccessKeyAsync()
{
if (!NeedRefresh)
lock (_lock)
{
return;
_updateTaskSource.TrySetResult(true);
terencefan marked this conversation as resolved.
Show resolved Hide resolved
}
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
internal async Task UpdateAccessKeyAsync()
{
TaskCompletionSource<bool> tcs;
lock (_lock)
{
terencefan marked this conversation as resolved.
Show resolved Hide resolved
return;
if (!_updateTaskSource.Task.IsCompleted)
{
tcs = _updateTaskSource;
}
else
{
_updateTaskSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
_ = UpdateAccessKeyInternalAsync(_updateTaskSource);
tcs = _updateTaskSource;
}
}
await tcs.Task;
}

private async Task UpdateAccessKeyInternalAsync(TaskCompletionSource<bool> tcs)
{
for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
var source = new CancellationTokenSource(GetAccessKeyTimeout);
using var source = new CancellationTokenSource(GetAccessKeyTimeout);
try
{
await UpdateAccessKeyInternalAsync(source.Token);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(true);
return;
}
catch (OperationCanceledException e)
Expand All @@ -168,14 +187,7 @@ internal async Task UpdateAccessKeyAsync()
catch (Exception e)
{
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
catch (OperationCanceledException)
{
break;
}
await Task.Delay(GetAccessKeyRetryInterval); // retry after interval.
}
}

Expand All @@ -184,15 +196,15 @@ internal async Task UpdateAccessKeyAsync()
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
tcs.TrySetResult(false);
}

private static string GetExceptionMessage(Exception? exception)
private static string GetExceptionMessage(Exception? exception, bool initialized)
{
return exception switch
{
AzureSignalRUnauthorizedException => AzureSignalRUnauthorizedException.ErrorMessageMicrosoftEntra,
_ => exception?.Message ?? "The access key has expired.",
_ => exception?.Message ?? (initialized ? "The access key has expired." : "The access key has not been initialized."),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ public Task<string> Generate(string audience, TimeSpan? lifetime = null)
{
return key.GetMicrosoftEntraTokenAsync();
}

return _accessKey.GenerateAccessTokenAsync(audience,
_claims,
lifetime ?? Constants.Periods.DefaultAccessTokenLifetime,
DefaultAlgorithm);
var time = lifetime ?? Constants.Periods.DefaultAccessTokenLifetime;
return _accessKey.GenerateAccessTokenAsync(audience, _claims, time, DefaultAlgorithm);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton(typeof(AzureSignalRMarkerService))
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<IAccessKeySynchronizer, AccessKeySynchronizer>()
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,12 @@ internal class ServiceEndpointManager : ServiceEndpointManagerBase

private readonly TimeSpan _scaleTimeout;

private readonly IAccessKeySynchronizer _synchronizer;

public ServiceEndpointManager(IAccessKeySynchronizer synchronizer,
IOptionsMonitor<ServiceOptions> optionsMonitor,
public ServiceEndpointManager(IOptionsMonitor<ServiceOptions> optionsMonitor,
ILoggerFactory loggerFactory) :
base(optionsMonitor.CurrentValue, loggerFactory.CreateLogger<ServiceEndpointManager>())
{
_options = optionsMonitor.CurrentValue;
_logger = loggerFactory?.CreateLogger<ServiceEndpointManager>() ?? throw new ArgumentNullException(nameof(loggerFactory));
_synchronizer = synchronizer;

optionsMonitor.OnChange(OnChange);
_scaleTimeout = _options.ServiceScaleTimeout;
Expand All @@ -40,8 +36,6 @@ public override IServiceEndpointProvider GetEndpointProvider(ServiceEndpoint end
{
return null;
}

_synchronizer.AddServiceEndpoint(endpoint);
return new ServiceEndpointProvider(endpoint, _options);
}

Expand All @@ -53,7 +47,6 @@ private void OnChange(ServiceOptions options)

private Task ReloadServiceEndpointsAsync(IEnumerable<ServiceEndpoint> serviceEndpoints)
{
_synchronizer.UpdateServiceEndpoints(serviceEndpoints);
return ReloadServiceEndpointsAsync(serviceEndpoints, _scaleTimeout);
}

Expand Down
Loading
Loading