Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed initialization race-condition in generated code. #7834

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public string WriteBeginClass(string typeName)
_writer.WriteIndentedLine("internal static class {0}", typeName);
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();
_writer.WriteIndentedLine("private static readonly object _sync = new object();");
_writer.WriteIndentedLine("private static bool _bindingsInitialized;");
return typeName;
}
Expand Down Expand Up @@ -94,15 +95,18 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL

if (first)
{
_writer.WriteIndentedLine("if (_bindingsInitialized)");
_writer.WriteIndentedLine("if (!_bindingsInitialized)");
_writer.WriteIndentedLine("{");
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine("return;");
}
_writer.IncreaseIndent();

_writer.WriteIndentedLine("lock (_sync)");
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

_writer.WriteIndentedLine("if (!_bindingsInitialized)");
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

_writer.WriteIndentedLine("}");
_writer.WriteIndentedLine("_bindingsInitialized = true;");
_writer.WriteLine();
_writer.WriteIndentedLine(
"const global::{0} bindingFlags =",
Expand All @@ -120,6 +124,8 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
_writer.WriteIndentedLine("var type = typeof({0});", method.ContainingType.ToFullyQualified());
_writer.WriteIndentedLine("global::System.Reflection.MethodInfo resolver = default!;");
_writer.WriteIndentedLine("global::System.Reflection.ParameterInfo[] parameters = default!;");

_writer.WriteIndentedLine("_bindingsInitialized = true;");
first = false;
}

Expand Down Expand Up @@ -182,8 +188,8 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
using (_writer.WriteForEach("binding", $"_args_{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(
"binding.Kind == global::{0}.Argument",
WellKnownTypes.ArgumentKind))
"binding.Kind == global::{0}.Argument",
WellKnownTypes.ArgumentKind))
{
_writer.WriteIndentedLine("argumentCount++;");
}
Expand All @@ -204,8 +210,8 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
".SetMessage(\"The node resolver `{0}.{1}` mustn't have more than one " +
"argument. Node resolvers can only have a single argument called `id`.\")",
".SetMessage(\"The node resolver `{0}.{1}` mustn't have more than one "
+ "argument. Node resolvers can only have a single argument called `id`.\")",
resolver.Member.ContainingType.ToDisplayString(),
resolver.Member.Name);
_writer.WriteIndentedLine(".Build());");
Expand All @@ -214,6 +220,16 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
}
}
}

if (!first)
{
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
}
}

_writer.WriteIndentedLine("}");
Expand All @@ -224,8 +240,7 @@ private static string ToFullyQualifiedString(
IMethodSymbol resolverMethod,
ILocalTypeLookup typeLookup)
{
if (type.TypeKind is TypeKind.Error &&
typeLookup.TryGetTypeName(type, resolverMethod, out var typeDisplayName))
if (type.TypeKind is TypeKind.Error && typeLookup.TryGetTypeName(type, resolverMethod, out var typeDisplayName))
{
return typeDisplayName;
}
Expand Down Expand Up @@ -269,9 +284,9 @@ private void AddStaticStandardResolver(
ILocalTypeLookup typeLookup)
{
using (_writer.WriteMethod(
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(condition: "!_bindingsInitialized"))
{
Expand Down Expand Up @@ -341,9 +356,9 @@ private void AddStaticStandardResolver(
private void AddStaticPureResolver(Resolver resolver, IMethodSymbol resolverMethod, ILocalTypeLookup typeLookup)
{
using (_writer.WriteMethod(
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(condition: "!_bindingsInitialized"))
{
Expand Down Expand Up @@ -436,9 +451,9 @@ private void AddStaticPureResolver(Resolver resolver, IMethodSymbol resolverMeth
private void AddStaticPropertyResolver(Resolver resolver)
{
using (_writer.WriteMethod(
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(condition: "!_bindingsInitialized"))
{
Expand Down Expand Up @@ -489,13 +504,13 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
{
var parameter = resolver.Parameters[i];

if(resolver.IsNodeResolver
&& parameter.Kind is ResolverParameterKind.Argument or ResolverParameterKind.Unknown
&& (parameter.Name == "id" || parameter.Key == "id"))
if (resolver.IsNodeResolver
&& parameter.Kind is ResolverParameterKind.Argument or ResolverParameterKind.Unknown
&& (parameter.Name == "id" || parameter.Key == "id"))
{
_writer.WriteIndentedLine(
"var args{0} = context.GetLocalState<{1}>(" +
"global::HotChocolate.WellKnownContextData.InternalId);",
"var args{0} = context.GetLocalState<{1}>("
+ "global::HotChocolate.WellKnownContextData.InternalId);",
i,
parameter.Type.ToFullyQualified());
continue;
Expand Down Expand Up @@ -524,8 +539,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
break;
case ResolverParameterKind.EventMessage:
_writer.WriteIndentedLine(
"var args{0} = context.GetScopedState<{1}>(" +
"global::HotChocolate.WellKnownContextData.EventMessage);",
"var args{0} = context.GetScopedState<{1}>("
+ "global::HotChocolate.WellKnownContextData.EventMessage);",
i,
parameter.Type.ToFullyQualified());
break;
Expand Down Expand Up @@ -593,8 +608,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
}
case ResolverParameterKind.SetGlobalState:
_writer.WriteIndentedLine(
"var args{0} = new HotChocolate.SetState<{1}>(" +
"value => context.SetGlobalState(\"{2}\", value));",
"var args{0} = new HotChocolate.SetState<{1}>("
+ "value => context.SetGlobalState(\"{2}\", value));",
i,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
parameter.Key);
Expand Down Expand Up @@ -633,8 +648,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
}
case ResolverParameterKind.SetScopedState:
_writer.WriteIndentedLine(
"var args{0} = new HotChocolate.SetState<{1}>(" +
"value => context.SetScopedState(\"{2}\", value));",
"var args{0} = new HotChocolate.SetState<{1}>("
+ "value => context.SetScopedState(\"{2}\", value));",
i,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
parameter.Key);
Expand Down Expand Up @@ -673,8 +688,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
}
case ResolverParameterKind.SetLocalState:
_writer.WriteIndentedLine(
"var args{0} = new HotChocolate.SetState<{1}>(" +
"value => context.SetLocalState(\"{2}\", value));",
"var args{0} = new HotChocolate.SetState<{1}>("
+ "value => context.SetLocalState(\"{2}\", value));",
i,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
parameter.Key);
Expand Down
68 changes: 68 additions & 0 deletions src/HotChocolate/Core/test/Types.Analyzers.Tests/TestMe.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// <auto-generated/>

#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using HotChocolate;
using HotChocolate.Types;
using HotChocolate.Execution.Configuration;
using HotChocolate.Internal;

namespace HotChocolate.Types
{
internal static class EntityInterfaceResolvers2
{
private static readonly object _sync = new object();
private static bool _bindingsInitialized;
private readonly static global::HotChocolate.Internal.IParameterBinding[] _args_EntityInterface_IdString = new global::HotChocolate.Internal.IParameterBinding[1];

public static void InitializeBindings(global::HotChocolate.Internal.IParameterBindingResolver bindingResolver)
{
if (!_bindingsInitialized)
{
lock (_sync)
{
if (!_bindingsInitialized)
{
const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::HotChocolate.Types.EntityInterface);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;

resolver = type.GetMethod(
"IdString",
bindingFlags,
new global::System.Type[] { typeof(global::HotChocolate.Types.IEntity) })!;
parameters = resolver.GetParameters();
_args_EntityInterface_IdString[0] = bindingResolver.GetBinding(parameters[0]);

_bindingsInitialized = true;
}
}
}
}

public static HotChocolate.Resolvers.FieldResolverDelegates EntityInterface_IdString()
{
if(!_bindingsInitialized)
{
throw new global::System.InvalidOperationException("The bindings must be initialized before the resolvers can be created.");
}
return new global::HotChocolate.Resolvers.FieldResolverDelegates(pureResolver: EntityInterface_IdString_Resolver);
}

private static global::System.Object? EntityInterface_IdString_Resolver(global::HotChocolate.Resolvers.IResolverContext context)
{
var args0 = context.Parent<global::HotChocolate.Types.IEntity>();
var result = global::HotChocolate.Types.EntityInterface.IdString(args0);
return result;
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,43 @@ namespace TestNamespace
{
internal static class BookNodeResolvers
{
private static readonly object _sync = new object();
private static bool _bindingsInitialized;
private readonly static global::HotChocolate.Internal.IParameterBinding[] _args_BookNode_GetAuthorAsync = new global::HotChocolate.Internal.IParameterBinding[2];

public static void InitializeBindings(global::HotChocolate.Internal.IParameterBindingResolver bindingResolver)
{
if (_bindingsInitialized)
if (!_bindingsInitialized)
{
return;
}
_bindingsInitialized = true;

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.BookNode);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;

resolver = type.GetMethod(
"GetAuthorAsync",
bindingFlags,
new global::System.Type[]
lock (_sync)
{
typeof(global::TestNamespace.Book),
typeof(global::System.Threading.CancellationToken)
})!;
parameters = resolver.GetParameters();
_args_BookNode_GetAuthorAsync[0] = bindingResolver.GetBinding(parameters[0]);
_args_BookNode_GetAuthorAsync[1] = bindingResolver.GetBinding(parameters[1]);
if (!_bindingsInitialized)
{

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.BookNode);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;
_bindingsInitialized = true;

resolver = type.GetMethod(
"GetAuthorAsync",
bindingFlags,
new global::System.Type[]
{
typeof(global::TestNamespace.Book),
typeof(global::System.Threading.CancellationToken)
})!;
parameters = resolver.GetParameters();
_args_BookNode_GetAuthorAsync[0] = bindingResolver.GetBinding(parameters[0]);
_args_BookNode_GetAuthorAsync[1] = bindingResolver.GetBinding(parameters[1]);
}
}
}
}

public static HotChocolate.Resolvers.FieldResolverDelegates BookNode_GetAuthorAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,41 @@ namespace TestNamespace
{
internal static class TestTypeResolvers
{
private static readonly object _sync = new object();
private static bool _bindingsInitialized;
private readonly static global::HotChocolate.Internal.IParameterBinding[] _args_TestType_GetTest = new global::HotChocolate.Internal.IParameterBinding[1];

public static void InitializeBindings(global::HotChocolate.Internal.IParameterBindingResolver bindingResolver)
{
if (_bindingsInitialized)
if (!_bindingsInitialized)
{
return;
}
_bindingsInitialized = true;

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.TestType);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;

resolver = type.GetMethod(
"GetTest",
bindingFlags,
new global::System.Type[]
lock (_sync)
{
typeof(int)
})!;
parameters = resolver.GetParameters();
_args_TestType_GetTest[0] = bindingResolver.GetBinding(parameters[0]);
if (!_bindingsInitialized)
{

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.TestType);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;
_bindingsInitialized = true;

resolver = type.GetMethod(
"GetTest",
bindingFlags,
new global::System.Type[]
{
typeof(int)
})!;
parameters = resolver.GetParameters();
_args_TestType_GetTest[0] = bindingResolver.GetBinding(parameters[0]);
}
}
}
}

public static HotChocolate.Resolvers.FieldResolverDelegates TestType_GetTest()
Expand Down
Loading
Loading