Skip to content

Commit

Permalink
Added nullability awareness to projections. (#7541)
Browse files Browse the repository at this point in the history
(cherry picked from commit 97ebf4b)
  • Loading branch information
michaelstaib committed Oct 1, 2024
1 parent 470e4dd commit 2a2943d
Show file tree
Hide file tree
Showing 18 changed files with 838 additions and 110 deletions.
1 change: 1 addition & 0 deletions src/GreenDonut/src/Core/GreenDonut.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
</PropertyGroup>

<ItemGroup>
<InternalsVisibleTo Include="HotChocolate.Execution" />
<InternalsVisibleTo Include="HotChocolate.Pagination.Batching" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
#if NET6_0_OR_GREATER
using System.Buffers;
using System.Buffers.Text;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Text;
using System.Runtime.CompilerServices;
using GreenDonut.Projections;
using HotChocolate.Execution.Projections;
using HotChocolate.Types;
using HotChocolate.Types.Descriptors.Definitions;
using HotChocolate.Utilities;

// ReSharper disable once CheckNamespace
namespace HotChocolate.Execution.Processing;
Expand All @@ -27,17 +33,136 @@ public static class HotChocolateExecutionSelectionExtensions
/// <returns>
/// Returns a selector expression that can be used for data projections.
/// </returns>
#if NET8_0_OR_GREATER
[Experimental(Experiments.Projections)]
#endif
public static Expression<Func<TValue, TValue>> AsSelector<TValue>(
this ISelection selection)
=> GetOrCreateExpression<TValue>(selection);
{
// we first check if we already have an expression for this selection,
// this would be the cheapest way to get the expression.
if(TryGetExpression<TValue>(selection, out var expression))
{
return expression;
}

// if we do not have an expression we need to create one.
// we first check what kind of field selection we have,
// connection, collection or single field.
var flags = ((ObjectField)selection.Field).Flags;

if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
{
var builder = new DefaultSelectorBuilder<TValue>();
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetConnectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
{
builder.Add(GetOrCreateExpression<TValue>(buffer[i]));
}
ArrayPool<ISelection>.Shared.Return(buffer);
return GetOrCreateExpression<TValue>(selection, builder);
}

if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
{
var builder = new DefaultSelectorBuilder<TValue>();
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetCollectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
{
builder.Add(GetOrCreateExpression<TValue>(buffer[i]));
}
ArrayPool<ISelection>.Shared.Return(buffer);
return GetOrCreateExpression<TValue>(selection, builder);
}

return GetOrCreateExpression<TValue>(selection);
}

private static Expression<Func<TValue, TValue>> GetOrCreateExpression<TValue>(
ISelection selection)
{
return selection.DeclaringOperation.GetOrAddState(
=> selection.DeclaringOperation.GetOrAddState(
CreateExpressionKey(selection.Id),
static (_, ctx) => ctx._builder.BuildExpression<TValue>(ctx.selection),
(_builder, selection));

#if NET8_0_OR_GREATER
[Experimental(Experiments.Projections)]
#endif
private static Expression<Func<TValue, TValue>> GetOrCreateExpression<TValue>(
ISelection selection,
ISelectorBuilder builder)
=> selection.DeclaringOperation.GetOrAddState(
CreateExpressionKey(selection.Id),
static (_, ctx) => ctx.builder.TryCompile<TValue>()!,
(builder, selection));

private static bool TryGetExpression<TValue>(
ISelection selection,
[NotNullWhen(true)] out Expression<Func<TValue, TValue>>? expression)
=> selection.DeclaringOperation.TryGetState(CreateExpressionKey(selection.Id), out expression);

private static int GetConnectionSelections(ISelection selection, Span<ISelection> buffer)
{
var pageType = (ObjectType)selection.Field.Type.NamedType();
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
var count = 0;

foreach (var connectionChild in connectionSelections.Selections)
{
if (connectionChild.Field.Name.EqualsOrdinal("nodes"))
{
if (buffer.Length == count)
{
throw new InvalidOperationException("To many alias selections of nodes and edges.");
}

buffer[count++] = connectionChild;
}
else if (connectionChild.Field.Name.EqualsOrdinal("edges"))
{
var edgeType = (ObjectType)connectionChild.Field.Type.NamedType();
var edgeSelections = connectionChild.DeclaringOperation.GetSelectionSet(connectionChild, edgeType);

foreach (var edgeChild in edgeSelections.Selections)
{
if (edgeChild.Field.Name.EqualsOrdinal("node"))
{
if (buffer.Length == count)
{
throw new InvalidOperationException("To many alias selections of nodes and edges.");
}

buffer[count++] = edgeChild;
}
}
}
}

return count;
}

private static int GetCollectionSelections(ISelection selection, Span<ISelection> buffer)
{
var pageType = (ObjectType)selection.Field.Type.NamedType();
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
var count = 0;

foreach (var connectionChild in connectionSelections.Selections)
{
if (connectionChild.Field.Name.EqualsOrdinal("items"))
{
if (buffer.Length == count)
{
throw new InvalidOperationException("To many alias selections of items.");
}

buffer[count++] = connectionChild;
}
}

return count;
}

private static string CreateExpressionKey(int key)
Expand All @@ -63,7 +188,7 @@ private static int EstimateIntLength(int value)
}

// if the number is negative we need one more digit for the sign
var length = (value < 0) ? 1 : 0;
var length = value < 0 ? 1 : 0;

// we add the number of digits the number has to the length of the number.
length += (int)Math.Floor(Math.Log10(Math.Abs(value)) + 1);
Expand Down
19 changes: 19 additions & 0 deletions src/HotChocolate/Core/src/Execution/Processing/Operation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ public long CreateIncludeFlags(IVariableValueCollection variables)
return context;
}

public bool TryGetState<TState>(out TState? state)
{
var key = typeof(TState).FullName ?? throw new InvalidOperationException();
return TryGetState(key, out state);
}

public bool TryGetState<TState>(string key, out TState? state)
{
if(_contextData.TryGetValue(key, out var value)
&& value is TState casted)
{
state = casted;
return true;
}

state = default;
return false;
}

public TState GetOrAddState<TState>(Func<TState> createState)
=> GetOrAddState<TState, object?>(_ => createState(), null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@

using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using HotChocolate.Execution;
using HotChocolate.Execution.Processing;
using HotChocolate.Pagination;
using HotChocolate.Types;
using HotChocolate.Types.Descriptors.Definitions;
using HotChocolate.Utilities;

// ReSharper disable once CheckNamespace
namespace GreenDonut.Projections;
Expand Down Expand Up @@ -44,6 +40,16 @@ public static ISelectionDataLoader<TKey, TValue> Select<TKey, TValue>(
ISelection selection)
where TKey : notnull
{
if (dataLoader == null)
{
throw new ArgumentNullException(nameof(dataLoader));
}

if (selection == null)
{
throw new ArgumentNullException(nameof(selection));
}

var expression = selection.AsSelector<TValue>();
return dataLoader.Select(expression);
}
Expand Down Expand Up @@ -71,99 +77,19 @@ public static IPagingDataLoader<TKey, Page<TValue>> Select<TKey, TValue>(
ISelection selection)
where TKey : notnull
{
var flags = ((ObjectField)selection.Field).Flags;

if ((flags & FieldFlags.Connection) == FieldFlags.Connection)
if (dataLoader == null)
{
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetConnectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
{
var expression = buffer[i].AsSelector<TValue>();
dataLoader.Select(expression);
}
ArrayPool<ISelection>.Shared.Return(buffer);
}
else if ((flags & FieldFlags.CollectionSegment) == FieldFlags.CollectionSegment)
{
var buffer = ArrayPool<ISelection>.Shared.Rent(16);
var count = GetCollectionSelections(selection, buffer);
for (var i = 0; i < count; i++)
{
var expression = buffer[i].AsSelector<TValue>();
dataLoader.Select(expression);
}
ArrayPool<ISelection>.Shared.Return(buffer);
}
else
{
var expression = selection.AsSelector<TValue>();
dataLoader.Select(expression);
throw new ArgumentNullException(nameof(dataLoader));
}

return dataLoader;
}

private static int GetConnectionSelections(ISelection selection, Span<ISelection> buffer)
{
var pageType = (ObjectType)selection.Field.Type.NamedType();
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
var count = 0;

foreach (var connectionChild in connectionSelections.Selections)
{
if (connectionChild.Field.Name.EqualsOrdinal("nodes"))
{
if (buffer.Length == count)
{
throw new InvalidOperationException("To many alias selections of nodes and edges.");
}

buffer[count++] = connectionChild;
}
else if (connectionChild.Field.Name.EqualsOrdinal("edges"))
{
var edgeType = (ObjectType)selection.Field.Type.NamedType();
var edgeSelections = selection.DeclaringOperation.GetSelectionSet(connectionChild, edgeType);

foreach (var edgeChild in edgeSelections.Selections)
{
if (edgeChild.Field.Name.EqualsOrdinal("node"))
{
if (buffer.Length == count)
{
throw new InvalidOperationException("To many alias selections of nodes and edges.");
}

buffer[count++] = edgeChild;
}
}
}
}

return count;
}

private static int GetCollectionSelections(ISelection selection, Span<ISelection> buffer)
{
var pageType = (ObjectType)selection.Field.Type.NamedType();
var connectionSelections = selection.DeclaringOperation.GetSelectionSet(selection, pageType);
var count = 0;

foreach (var connectionChild in connectionSelections.Selections)
if (selection == null)
{
if (connectionChild.Field.Name.EqualsOrdinal("items"))
{
if (buffer.Length == count)
{
throw new InvalidOperationException("To many alias selections of items.");
}

buffer[count++] = connectionChild;
}
throw new ArgumentNullException(nameof(selection));
}

return count;
var expression = selection.AsSelector<TValue>();
dataLoader.Select(expression);
return dataLoader;
}
}
#endif
Loading

0 comments on commit 2a2943d

Please sign in to comment.