From a4501deb73036c03c2f29c2dc5cd02e955440381 Mon Sep 17 00:00:00 2001 From: Terence Fan Date: Thu, 5 Dec 2024 14:17:07 +0800 Subject: [PATCH] return TokenCredential instead of AccessKey in parser --- .../Endpoints/ServiceEndpoint.cs | 13 +- .../Utilities/ConnectionStringParser.cs | 69 ++++---- .../Utilities/ParsedConnectionString.cs | 12 +- .../Auth/ConnectionStringParserTests.cs | 151 +++++++++++------- 4 files changed, 137 insertions(+), 108 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs index b5bf4cd3b..5dd4afc0d 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/ServiceEndpoint.cs @@ -103,6 +103,13 @@ public ServiceEndpoint(string nameWithEndpointType, string connectionString) : t (Name, EndpointType) = Parse(nameWithEndpointType); } + private static IAccessKey BuildAccessKey(ParsedConnectionString parsed) + { + return string.IsNullOrEmpty(parsed.AccessKey) + ? new MicrosoftEntraAccessKey(parsed.Endpoint, parsed.TokenCredential, parsed.ServerEndpoint) + : new AccessKey(parsed.Endpoint, parsed.AccessKey); + } + /// /// Connection string constructor /// @@ -116,12 +123,12 @@ public ServiceEndpoint(string connectionString, EndpointType type = EndpointType throw new ArgumentException($"'{nameof(connectionString)}' cannot be null or whitespace.", nameof(connectionString)); } ConnectionString = connectionString; - - var result = ConnectionStringParser.Parse(connectionString); EndpointType = type; Name = name; - _accessKey = result.AccessKey; + var result = ConnectionStringParser.Parse(connectionString); + + _accessKey = BuildAccessKey(result); _serviceEndpoint = result.Endpoint; _clientEndpoint = result.ClientEndpoint; _serverEndpoint = result.ServerEndpoint; diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs index 60ebc2742..0ca052a5a 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Text.RegularExpressions; +using Azure.Core; using Azure.Identity; namespace Microsoft.Azure.SignalR; @@ -39,6 +40,7 @@ internal static class ConnectionStringParser private const string TypeAzure = "azure"; + [Obsolete] private const string TypeAzureAD = "aad"; private const string TypeAzureApp = "azure.app"; @@ -107,14 +109,9 @@ internal static ParsedConnectionString Parse(string connectionString) // parse and validate port. if (dict.TryGetValue(PortProperty, out var s)) { - if (int.TryParse(s, out var port) && port > 0 && port <= 0xFFFF) - { - builder.Port = port; - } - else - { - throw new ArgumentException(InvalidPortValue, nameof(port)); - } + builder.Port = int.TryParse(s, out var port) && port > 0 && port <= 0xFFFF + ? port + : throw new ArgumentException(InvalidPortValue, nameof(port)); } Uri? clientEndpointUri = null; @@ -140,19 +137,22 @@ internal static ParsedConnectionString Parse(string connectionString) // try building accesskey. dict.TryGetValue(AuthTypeProperty, out var type); - var accessKey = type?.ToLower() switch - { - TypeAzureAD => BuildAzureADAccessKey(builder.Uri, serverEndpointUri, dict), - TypeAzure => BuildAzureAccessKey(builder.Uri, serverEndpointUri, dict), - TypeAzureApp => BuildAzureAppAccessKey(builder.Uri, serverEndpointUri, dict), - TypeAzureMsi => BuildAzureMsiAccessKey(builder.Uri, serverEndpointUri, dict), - _ => BuildAccessKey(builder.Uri, dict), + var tokenCredential = type?.ToLower() switch + { + TypeAzureApp => BuildApplicationCredential(dict), + TypeAzureMsi => BuildManagedIdentityCredential(dict), +#pragma warning disable CS0612 // Type or member is obsolete + TypeAzureAD => BuildAzureTokenCredential(dict), +#pragma warning restore CS0612 // Type or member is obsolete + _ => new DefaultAzureCredential(), }; - return new ParsedConnectionString(builder.Uri) + dict.TryGetValue(AccessKeyProperty, out var accessKey); + + return new ParsedConnectionString(builder.Uri, tokenCredential) { - ClientEndpoint = clientEndpointUri, AccessKey = accessKey, + ClientEndpoint = clientEndpointUri, ServerEndpoint = serverEndpointUri }; } @@ -163,7 +163,8 @@ private static bool TryCreateEndpointUri(string endpoint, out Uri? uriResult) && (uriResult.Scheme == Uri.UriSchemeHttp || uriResult.Scheme == Uri.UriSchemeHttps); } - private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) + [Obsolete] + private static TokenCredential BuildAzureTokenCredential(Dictionary dict) { if (dict.TryGetValue(ClientIdProperty, out var clientId)) { @@ -171,11 +172,11 @@ private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri? serverEndpointUri, { if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new MicrosoftEntraAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); + return new ClientSecretCredential(tenantId, clientId, clientSecret); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new MicrosoftEntraAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); + return new ClientCertificateCredential(tenantId, clientId, clientCertPath); } else { @@ -184,28 +185,16 @@ private static IAccessKey BuildAzureADAccessKey(Uri uri, Uri? serverEndpointUri, } else { - return new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri); + return new ManagedIdentityCredential(clientId); } } else { - return new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); + return new ManagedIdentityCredential(); } } - private static IAccessKey BuildAccessKey(Uri uri, Dictionary dict) - { - return dict.TryGetValue(AccessKeyProperty, out var key) - ? new AccessKey(uri, key) - : throw new ArgumentException(MissingAccessKeyProperty, AccessKeyProperty); - } - - private static IAccessKey BuildAzureAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) - { - return new MicrosoftEntraAccessKey(uri, new DefaultAzureCredential(), serverEndpointUri); - } - - private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) + private static TokenCredential BuildApplicationCredential(Dictionary dict) { if (!dict.TryGetValue(ClientIdProperty, out var clientId)) { @@ -219,20 +208,20 @@ private static IAccessKey BuildAzureAppAccessKey(Uri uri, Uri? serverEndpointUri if (dict.TryGetValue(ClientSecretProperty, out var clientSecret)) { - return new MicrosoftEntraAccessKey(uri, new ClientSecretCredential(tenantId, clientId, clientSecret), serverEndpointUri); + return new ClientSecretCredential(tenantId, clientId, clientSecret); } else if (dict.TryGetValue(ClientCertProperty, out var clientCertPath)) { - return new MicrosoftEntraAccessKey(uri, new ClientCertificateCredential(tenantId, clientId, clientCertPath), serverEndpointUri); + return new ClientCertificateCredential(tenantId, clientId, clientCertPath); } throw new ArgumentException(MissingClientSecretProperty, ClientSecretProperty); } - private static IAccessKey BuildAzureMsiAccessKey(Uri uri, Uri? serverEndpointUri, Dictionary dict) + private static TokenCredential BuildManagedIdentityCredential(Dictionary dict) { return dict.TryGetValue(ClientIdProperty, out var clientId) - ? new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(clientId), serverEndpointUri) - : new MicrosoftEntraAccessKey(uri, new ManagedIdentityCredential(), serverEndpointUri); + ? new ManagedIdentityCredential(clientId) + : new ManagedIdentityCredential(); } private static Dictionary ToDictionary(string connectionString) diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs index fdc6ae07b..7acc4aae0 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/ParsedConnectionString.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using Azure.Core; namespace Microsoft.Azure.SignalR; @@ -11,14 +12,17 @@ internal class ParsedConnectionString { internal Uri Endpoint { get; } - internal IAccessKey? AccessKey { get; set; } + internal string? AccessKey { get; init; } - internal Uri? ClientEndpoint { get; set; } + internal TokenCredential TokenCredential { get; } - internal Uri? ServerEndpoint { get; set; } + internal Uri? ClientEndpoint { get; init; } - public ParsedConnectionString(Uri endpoint) + internal Uri? ServerEndpoint { get; init; } + + public ParsedConnectionString(Uri endpoint, TokenCredential tokenCredential) { Endpoint = endpoint; + TokenCredential = tokenCredential; } } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs index 66c484748..29d0f99aa 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/ConnectionStringParserTests.cs @@ -4,34 +4,52 @@ using System; using System.Collections; using System.Collections.Generic; - +using System.Reflection; using Azure.Identity; using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests.Auth; +#nullable enable + [Collection("Auth")] public class ConnectionStringParserTests { - private const string ClientEndpoint = "http://bbb"; - private const string DefaultKey = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; private const string HttpEndpoint = "http://aaa"; private const string HttpsEndpoint = "https://aaa"; + private const string ClientEndpoint = "http://bbb"; + private const string ServerEndpoint = "http://ccc"; - public static IEnumerable ServerEndpointTestData + private const string TestTenantId = "aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc"; + + private const string TestEndpoint = "https://aaa"; + + public static IEnumerable ServerEndpointTestData { get { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}", ServerEndpoint, 80 }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}:500", $"{ServerEndpoint}:500", 500 }; + yield return new object?[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}", ServerEndpoint, 80 }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;serverEndpoint={ServerEndpoint}:500", $"{ServerEndpoint}:500", 500 }; + } + } + + public static IEnumerable ClientEndpointTestData + { + get + { + yield return new object?[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}", ClientEndpoint, 80 }; + yield return new object?[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}:500", $"{ClientEndpoint}:500", 500 }; } } @@ -62,7 +80,6 @@ public void InvalidServerEndpoint(string connectionString) Assert.Contains("Invalid value for serverEndpoint property, it must be a valid URI. (Parameter 'serverEndpoint')", exception.Message); } - [Theory] [InlineData("Endpoint=xxx")] [InlineData("AccessKey=xxx")] @@ -104,24 +121,10 @@ public void InvalidVersion(string connectionString, string version) } [Theory] - [InlineData("endpoint=https://aaa;AuthType=aad;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] - [InlineData("endpoint=https://aaa;AuthType=azure.app;clientId=foo;clientSecret=bar;tenantId=aaaaaaaa-bbbb-bbbb-bbbb-cccccccccccc")] - public void TestAzureApplication(string connectionString) - { - var r = ConnectionStringParser.Parse(connectionString); - - var key = Assert.IsType(r.AccessKey); - Assert.IsType(key.TokenCredential); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); - Assert.Null(r.ClientEndpoint); - } - - [Theory] - [ClassData(typeof(ClientEndpointTestData))] + [MemberData(nameof(ClientEndpointTestData))] public void TestClientEndpoint(string connectionString, string expectedClientEndpoint, int? expectedPort) { var r = ConnectionStringParser.Parse(connectionString); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); var expectedUri = expectedClientEndpoint == null ? null : new Uri(expectedClientEndpoint); Assert.Equal(expectedUri, r.ClientEndpoint); Assert.Equal(expectedPort, r.ClientEndpoint?.Port); @@ -132,45 +135,85 @@ public void TestClientEndpoint(string connectionString, string expectedClientEnd public void TestServerEndpoint(string connectionString, string expectedServerEndpoint, int? expectedPort) { var r = ConnectionStringParser.Parse(connectionString); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); var expectedUri = expectedServerEndpoint == null ? null : new Uri(expectedServerEndpoint); Assert.Equal(expectedUri, r.ServerEndpoint); Assert.Equal(expectedPort, r.ServerEndpoint?.Port); } [Theory] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientId=xxxx;")] // should ignore the clientId - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;tenantId=xxxx;")] // should ignore the tenantId - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure;clientSecret=xxxx;")] // should ignore the clientSecret - internal void TestDefaultAzureCredential(string expectedEndpoint, string connectionString) + [InlineData($"endpoint=https://aaa;AuthType=aad;clientId=foo;clientSecret=bar;tenantId={TestTenantId}")] + [InlineData($"endpoint=https://aaa;AuthType=azure.app;clientId=foo;clientSecret=bar;tenantId={TestTenantId}")] + public void TestClientSecretCredential(string connectionString) { var r = ConnectionStringParser.Parse(connectionString); + Assert.Null(r.AccessKey); + Assert.Null(r.ClientEndpoint); + Assert.Null(r.ServerEndpoint); + var credential = Assert.IsType(r.TokenCredential); + + var tenantIdField = typeof(ClientSecretCredential).GetProperty("TenantId", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.Equal(TestTenantId, Assert.IsType(tenantIdField?.GetValue(credential))); + + var clientIdField = typeof(ClientSecretCredential).GetProperty("ClientId", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.Equal("foo", Assert.IsType(clientIdField?.GetValue(credential))); - Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); - Assert.IsType(key.TokenCredential); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + var clientSecretField = typeof(ClientSecretCredential).GetProperty("ClientSecret", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.Equal("bar", Assert.IsType(clientSecretField?.GetValue(credential))); } [Theory] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;")] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientId=123;")] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;tenantId=xxxx;")] // should ignore the tenantId - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=aad;clientSecret=xxxx;")] // should ignore the clientSecret - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;")] - [InlineData("https://aaa", "endpoint=https://aaa;AuthType=azure.msi;clientId=123;")] - internal void TestManagedIdentity(string expectedEndpoint, string connectionString) + [InlineData($"endpoint=https://aaa;AuthType=aad;clientId=foo;clientCert=bar;tenantId={TestTenantId}")] + [InlineData($"endpoint=https://aaa;AuthType=azure.app;clientId=foo;clientCert=bar;tenantId={TestTenantId}")] + public void TestClientCertificateCredential(string connectionString) { var r = ConnectionStringParser.Parse(connectionString); + Assert.Null(r.AccessKey); + Assert.Null(r.ClientEndpoint); + Assert.Null(r.ServerEndpoint); + var credential = Assert.IsType(r.TokenCredential); + + var tenantIdField = typeof(ClientCertificateCredential).GetProperty("TenantId", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.Equal(TestTenantId, Assert.IsType(tenantIdField?.GetValue(credential))); + + var clientIdField = typeof(ClientCertificateCredential).GetProperty("ClientId", BindingFlags.Instance | BindingFlags.NonPublic); + Assert.Equal("foo", Assert.IsType(clientIdField?.GetValue(credential))); + } + + [Theory] + [InlineData($"endpoint={TestEndpoint};AuthType=azure;clientId=xxxx;")] // should ignore the clientId + [InlineData($"endpoint={TestEndpoint};AuthType=azure;tenantId=xxxx;")] // should ignore the tenantId + [InlineData($"endpoint={TestEndpoint};AuthType=azure;clientSecret=xxxx;")] // should ignore the clientSecret + internal void TestDefaultAzureCredential(string connectionString) + { + var r = ConnectionStringParser.Parse(connectionString); + Assert.Equal(TestEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); + + Assert.Null(r.AccessKey); + Assert.Null(r.ClientEndpoint); + Assert.Null(r.ServerEndpoint); + Assert.IsType(r.TokenCredential); + } + + [Theory] + [InlineData($"endpoint={TestEndpoint};AuthType=aad;")] + [InlineData($"endpoint={TestEndpoint};AuthType=aad;clientId=123;")] + [InlineData($"endpoint={TestEndpoint};AuthType=aad;tenantId=xxxx;")] // should ignore the tenantId + [InlineData($"endpoint={TestEndpoint};AuthType=aad;clientSecret=xxxx;")] // should ignore the clientSecret + [InlineData($"endpoint={TestEndpoint};AuthType=azure.msi;")] + [InlineData($"endpoint={TestEndpoint};AuthType=azure.msi;clientId=123;")] + internal void TestManagedIdentityCredential(string connectionString) + { + var r = ConnectionStringParser.Parse(connectionString); + Assert.Equal(TestEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - Assert.Equal(expectedEndpoint, r.Endpoint.AbsoluteUri.TrimEnd('/')); - var key = Assert.IsType(r.AccessKey); - Assert.IsType(key.TokenCredential); - Assert.Same(r.Endpoint, r.AccessKey.Endpoint); + Assert.Null(r.AccessKey); Assert.Null(r.ClientEndpoint); + Assert.Null(r.ServerEndpoint); + Assert.IsType(r.TokenCredential); } [Theory] + [Obsolete] [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo", "https://foo/api/v1/auth/accesskey")] [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123", "https://foo:123/api/v1/auth/accesskey")] [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo/bar", "https://foo/bar/api/v1/auth/accesskey")] @@ -178,25 +221,11 @@ internal void TestManagedIdentity(string expectedEndpoint, string connectionStri [InlineData("endpoint=https://aaa;AuthType=aad;serverendpoint=https://foo:123/bar/", "https://foo:123/bar/api/v1/auth/accesskey")] internal void TestAzureADWithServerEndpoint(string connectionString, string expectedAuthorizeUrl) { - var r = ConnectionStringParser.Parse(connectionString); - var key = Assert.IsType(r.AccessKey); + var endpoint = new ServiceEndpoint(connectionString); + var key = Assert.IsType(endpoint.AccessKey); Assert.Equal(expectedAuthorizeUrl, key.GetAccessKeyUrl, StringComparer.OrdinalIgnoreCase); } - public class ClientEndpointTestData : IEnumerable - { - public IEnumerator GetEnumerator() - { - yield return new object[] { $"endpoint={HttpEndpoint};accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey}", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400", null, null }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}", ClientEndpoint, 80 }; - yield return new object[] { $"endpoint={HttpEndpoint}:500;accesskey={DefaultKey};port=400;clientEndpoint={ClientEndpoint}:500", $"{ClientEndpoint}:500", 500 }; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - } - public class EndpointEndWithSlash : IEnumerable { public IEnumerator GetEnumerator()