Skip to content

Commit

Permalink
Improved selector when using batching. (#7542)
Browse files Browse the repository at this point in the history
(cherry picked from commit 314c9d8)
  • Loading branch information
michaelstaib committed Oct 1, 2024
1 parent 4270746 commit 1496620
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System.Linq.Expressions;

namespace HotChocolate.Pagination.Expressions;

internal sealed class ExtractOrderPropertiesVisitor : ExpressionVisitor
{
private const string _orderByMethod = "OrderBy";
private const string _thenByMethod = "ThenBy";
private const string _orderByDescendingMethod = "OrderByDescending";
private const string _thenByDescendingMethod = "ThenByDescending";
private bool _isOrderScope;

public List<MemberExpression> OrderProperties { get; } = [];

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (node.Method.Name == _orderByMethod ||
node.Method.Name == _thenByMethod ||
node.Method.Name == _orderByDescendingMethod ||
node.Method.Name == _thenByDescendingMethod)
{
_isOrderScope = true;

var lambda = StripQuotes(node.Arguments[1]);
Visit(lambda.Body);

_isOrderScope = false;
}

return base.VisitMethodCall(node);
}

protected override Expression VisitMember(MemberExpression node)
{
if (_isOrderScope)
{
// we only collect members that are within an order method.
OrderProperties.Add(node);
}

return base.VisitMember(node);
}

private static LambdaExpression StripQuotes(Expression expression)
{
while (expression.NodeType == ExpressionType.Quote)
{
expression = ((UnaryExpression)expression).Operand;
}

return (LambdaExpression)expression;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using System.Linq.Expressions;

namespace HotChocolate.Pagination.Expressions;

internal sealed class ExtractSelectExpressionVisitor : ExpressionVisitor
{
private const string _selectMethod = "Select";

public LambdaExpression? Selector { get; private set; }

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (node.Method.Name == _selectMethod && node.Arguments.Count == 2)
{
var lambda = StripQuotes(node.Arguments[1]);
if (lambda.Type.IsGenericType
&& lambda.Type.GetGenericTypeDefinition() == typeof(Func<,>))
{
// we make sure that the selector is of type Expression<Func<T, T>>
// otherwise we are not interested in it.
var genericArgs = lambda.Type.GetGenericArguments();
if (genericArgs[0] == genericArgs[1])
{
Selector = lambda;
}
}
}

return base.VisitMethodCall(node);
}

private static LambdaExpression StripQuotes(Expression e)
{
while (e.NodeType == ExpressionType.Quote)
{
e = ((UnaryExpression)e).Operand;
}

return (LambdaExpression)e;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System.Linq.Expressions;

namespace HotChocolate.Pagination.Expressions;

internal static class QueryHelpers
{
public static IQueryable<T> EnsureOrderPropsAreSelected<T>(
IQueryable<T> query)
{
var selector = ExtractCurrentSelector(query);
if (selector is null)
{
return query;
}

var orderByProperties = ExtractOrderProperties(query);
if(orderByProperties.Count == 0)
{
return query;
}

var updatedSelector = AddPropertiesInSelector(selector, orderByProperties);
return ReplaceSelector(query, updatedSelector);
}

private static Expression<Func<T, T>>? ExtractCurrentSelector<T>(
IQueryable<T> query)
{
var visitor = new ExtractSelectExpressionVisitor();
visitor.Visit(query.Expression);
return visitor.Selector as Expression<Func<T, T>>;
}

private static Expression<Func<T, T>> AddPropertiesInSelector<T>(
Expression<Func<T, T>> selector,
List<MemberExpression> properties)
{
var parameter = selector.Parameters[0];
var bindings = ((MemberInitExpression)selector.Body).Bindings.Cast<MemberAssignment>().ToList();

foreach (var property in properties)
{
var propertyName = property.Member.Name;
if(property.Expression is not ParameterExpression parameterExpression
|| bindings.Any(b => b.Member.Name == propertyName))
{
continue;
}

var replacer = new ReplacerParameterVisitor(parameterExpression, parameter);
var rewrittenProperty = (MemberExpression)replacer.Visit(property);
bindings.Add(Expression.Bind(rewrittenProperty.Member, rewrittenProperty));
}

var newBody = Expression.MemberInit(Expression.New(typeof(T)), bindings);
return Expression.Lambda<Func<T, T>>(newBody, parameter);
}

private static List<MemberExpression> ExtractOrderProperties<T>(
IQueryable<T> query)
{
var visitor = new ExtractOrderPropertiesVisitor();
visitor.Visit(query.Expression);
return visitor.OrderProperties;
}

private static IQueryable<T> ReplaceSelector<T>(
IQueryable<T> query,
Expression<Func<T, T>> newSelector)
{
var visitor = new ReplaceSelectorVisitor<T>(newSelector);
var newExpression = visitor.Visit(query.Expression);
return query.Provider.CreateQuery<T>(newExpression);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System.Linq.Expressions;

namespace HotChocolate.Pagination.Expressions;

internal sealed class ReplaceSelectorVisitor<T>(
Expression<Func<T, T>> newSelector)
: ExpressionVisitor
{
private const string _selectMethod = "Select";

protected override Expression VisitMethodCall(MethodCallExpression node)
{
if (node.Method.Name == _selectMethod && node.Arguments.Count == 2)
{
return Expression.Call(
node.Method.DeclaringType!,
node.Method.Name,
[typeof(T), typeof(T)],
node.Arguments[0],
newSelector);
}

return base.VisitMethodCall(node);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System.Linq.Expressions;

namespace HotChocolate.Pagination.Expressions;

internal sealed class ReplacerParameterVisitor(
ParameterExpression oldParameter,
ParameterExpression newParameter)
: ExpressionVisitor
{
protected override Expression VisitParameter(ParameterExpression node)
=> node == oldParameter
? newParameter
: base.VisitParameter(node);
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public static async ValueTask<Page<T>> ToPageAsync<T>(
throw new ArgumentNullException(nameof(source));
}

source = QueryHelpers.EnsureOrderPropsAreSelected(source);

var keys = ParseDataSetKeys(source);

if (keys.Length == 0)
Expand Down Expand Up @@ -215,6 +217,8 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
nameof(arguments));
}

source = QueryHelpers.EnsureOrderPropsAreSelected(source);

// we need to move the ordering into the select expression we are constructing
// so that the groupBy will not remove it. The first thing we do here is to extract the order expressions
// and to create a new expression that will not contain it anymore.
Expand Down

0 comments on commit 1496620

Please sign in to comment.