Skip to content

Commit

Permalink
Merge pull request #703 from Sergio0694/dev/external-static-fields
Browse files Browse the repository at this point in the history
Implement full support for static fields
  • Loading branch information
Sergio0694 authored Dec 6, 2023
2 parents 361b0b4 + b832874 commit f56458d
Show file tree
Hide file tree
Showing 11 changed files with 400 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public static string GetHlslSource(
Dictionary<IMethodSymbol, MethodDeclarationSyntax> instanceMethods = new(SymbolEqualityComparer.Default);
Dictionary<IMethodSymbol, (MethodDeclarationSyntax, MethodDeclarationSyntax)> constructors = new(SymbolEqualityComparer.Default);
Dictionary<IFieldSymbol, string> constantDefinitions = new(SymbolEqualityComparer.Default);
Dictionary<IFieldSymbol, (string, string, string?)> staticFieldDefinitions = new(SymbolEqualityComparer.Default);

token.ThrowIfCancellationRequested();

Expand All @@ -78,6 +79,7 @@ public static string GetHlslSource(
instanceMethods,
constructors,
constantDefinitions,
staticFieldDefinitions,
token,
out bool methodsNeedD2D1RequiresScenePosition);

Expand All @@ -89,6 +91,7 @@ public static string GetHlslSource(
structDeclarationSymbol,
discoveredTypes,
constantDefinitions,
staticFieldDefinitions,
token,
out bool fieldsNeedD2D1RequiresScenePosition);

Expand Down Expand Up @@ -238,6 +241,7 @@ private static void GetInstanceFields(
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <param name="staticFieldDefinitions">The collection of discovered static field definitions.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <param name="needsD2D1RequiresScenePosition">Whether or not the shader needs the <c>[D2DRequiresScenePosition]</c> annotation.</param>
/// <returns>A sequence of static constant fields in <paramref name="structDeclarationSymbol"/>.</returns>
Expand All @@ -247,6 +251,7 @@ private static void GetInstanceFields(
INamedTypeSymbol structDeclarationSymbol,
ICollection<INamedTypeSymbol> discoveredTypes,
IDictionary<IFieldSymbol, string> constantDefinitions,
IDictionary<IFieldSymbol, (string, string, string?)> staticFieldDefinitions,
CancellationToken token,
out bool needsD2D1RequiresScenePosition)
{
Expand All @@ -256,52 +261,35 @@ private static void GetInstanceFields(

foreach (ISymbol memberSymbol in structDeclarationSymbol.GetMembers())
{
// Find all declared static fields in the type
if (memberSymbol is not IFieldSymbol { IsImplicitlyDeclared: false, IsStatic: true, IsConst: false, } fieldSymbol)
if (memberSymbol is not IFieldSymbol fieldSymbol)
{
continue;
}

if (!fieldSymbol.TryGetSyntaxNode(token, out VariableDeclaratorSyntax? variableDeclarator))
{
continue;
}

// Constant properties must be of a primitive, vector or matrix type
if (fieldSymbol.Type is not INamedTypeSymbol typeSymbol ||
!HlslKnownTypes.IsKnownHlslType(typeSymbol.GetFullyQualifiedMetadataName()))
{
diagnostics.Add(InvalidShaderStaticFieldType, variableDeclarator, structDeclarationSymbol, fieldSymbol.Name, fieldSymbol.Type);

continue;
}

_ = HlslKnownKeywords.TryGetMappedName(fieldSymbol.Name, out string? mapping);

string typeDeclaration = fieldSymbol.IsReadOnly switch
{
true => $"static const {HlslKnownTypes.GetMappedName(typeSymbol)}",
false => $"static {HlslKnownTypes.GetMappedName(typeSymbol)}"
};

token.ThrowIfCancellationRequested();

StaticFieldRewriter staticFieldRewriter = new(
if (HlslDefinitionsSyntaxProcessor.TryGetStaticField(
structDeclarationSymbol,
fieldSymbol,
semanticModel,
discoveredTypes,
constantDefinitions,
staticFieldDefinitions,
diagnostics,
token);

ExpressionSyntax? processedDeclaration = staticFieldRewriter.Visit(variableDeclarator);

token.ThrowIfCancellationRequested();

string? assignment = processedDeclaration?.NormalizeWhitespace(eol: "\n").ToFullString();
token,
out string? name,
out string? typeDeclaration,
out string? assignmentExpression,
out StaticFieldRewriter? staticFieldRewriter))
{
needsD2D1RequiresScenePosition |= staticFieldRewriter.NeedsD2DRequiresScenePositionAttribute;

needsD2D1RequiresScenePosition |= staticFieldRewriter.NeedsD2DRequiresScenePositionAttribute;
builder.Add((name, typeDeclaration, assignmentExpression));
}
}

builder.Add((mapping ?? fieldSymbol.Name, typeDeclaration, assignment));
// Also gather the external static fields (same as in the DX12 generator)
foreach ((string, string, string?) externalField in staticFieldDefinitions.Values)
{
builder.Add(externalField);
}

return builder.ToImmutable();
Expand All @@ -318,6 +306,7 @@ private static void GetInstanceFields(
/// <param name="instanceMethods">The collection of discovered instance methods for custom struct types.</param>
/// <param name="constructors">The collection of discovered constructors for custom struct types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <param name="staticFieldDefinitions">The collection of discovered static field definitions.</param>
/// <param name="needsD2D1RequiresScenePosition">Whether or not the shader needs the <c>[D2DRequiresScenePosition]</c> annotation.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <returns>A sequence of processed methods in <paramref name="structDeclarationSymbol"/>, and the entry point.</returns>
Expand All @@ -330,6 +319,7 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
IDictionary<IMethodSymbol, MethodDeclarationSyntax> instanceMethods,
IDictionary<IMethodSymbol, (MethodDeclarationSyntax, MethodDeclarationSyntax)> constructors,
IDictionary<IFieldSymbol, string> constantDefinitions,
IDictionary<IFieldSymbol, (string, string, string?)> staticFieldDefinitions,
CancellationToken token,
out bool needsD2D1RequiresScenePosition)
{
Expand Down Expand Up @@ -374,6 +364,7 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
instanceMethods,
constructors,
constantDefinitions,
staticFieldDefinitions,
diagnostics,
token,
isShaderEntryPoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,19 +397,19 @@ partial class DiagnosticDescriptors
helpLinkUri: "https://github.com/Sergio0694/ComputeSharp");

/// <summary>
/// Gets a <see cref="DiagnosticDescriptor"/> for an invalid shader static readonly field type.
/// Gets a <see cref="DiagnosticDescriptor"/> for an invalid static field type.
/// <para>
/// Format: <c>"The pixel shader of type {0} contains a static readonly field "{1}" of an invalid type {2} (only primitive, vector and matrix types are supported)"</c>.
/// Format: <c>"The pixel shader of type {0} contains or references a static field "{1}" of an invalid type {2} (only primitive, vector and matrix types are supported)"</c>.
/// </para>
/// </summary>
public static readonly DiagnosticDescriptor InvalidShaderStaticFieldType = new(
id: "CMPSD2D0030",
title: "Invalid shader static readonly field type",
messageFormat: """The pixel shader of type {0} contains a static readonly field "{1}" of an invalid type {2} (only primitive, vector and matrix types are supported)""",
title: "Invalid shader static field type",
messageFormat: """The pixel shader of type {0} contains or references a static field "{1}" of an invalid type {2} (only primitive, vector and matrix types are supported)""",
category: "ComputeSharp.D2D1.Shaders",
defaultSeverity: DiagnosticSeverity.Error,
isEnabledByDefault: true,
description: "A type representing a pixel shader contains a static readonly field of a type that is not supported.",
description: "A type representing a pixel shader contains or references a static field of a type that is not supported.",
helpLinkUri: "https://github.com/Sergio0694/ComputeSharp");

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,14 @@ private partial SyntaxNode RewriteSampledTextureAccess(IInvocationOperation oper
}

/// <inheritdoc/>
private partial void TrackKnownPropertyAccess(IMemberReferenceOperation operation, MemberAccessExpressionSyntax node, string mappedName)
partial void TrackKnownMethodInvocation(string metadataName)
{
// No special tracking is needed for D2D1 shaders
NeedsD2DRequiresScenePositionAttribute |= HlslKnownMethods.NeedsD2DRequiresScenePositionAttribute(metadataName);
}

/// <inheritdoc/>
private partial void TrackKnownMethodInvocation(string metadataName)
partial void TrackExternalStaticField(StaticFieldRewriter staticFieldRewriter)
{
// Track whether the method needs [D2DRequiresScenePosition]
if (HlslKnownMethods.NeedsD2DRequiresScenePositionAttribute(metadataName))
{
NeedsD2DRequiresScenePositionAttribute = true;
}
NeedsD2DRequiresScenePositionAttribute |= staticFieldRewriter.NeedsD2DRequiresScenePositionAttribute;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using ComputeSharp.SourceGeneration.Extensions;
using ComputeSharp.SourceGeneration.Helpers;
using ComputeSharp.SourceGeneration.Mappings;
using ComputeSharp.SourceGeneration.Models;
using ComputeSharp.SourceGeneration.SyntaxRewriters;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down Expand Up @@ -38,6 +41,100 @@ internal static class HlslDefinitionsSyntaxProcessor
return builder.ToImmutable();
}

/// <summary>
/// Tries to get and rewrite a given static field to be used in a shader.
/// </summary>
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="fieldSymbol">The symbol for the field to analyze.</param>
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <param name="staticFieldDefinitions">The collection of discovered static field definitions.</param>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <param name="name">The mapped name for the field.</param>
/// <param name="typeDeclaration">The type declaration for the field.</param>
/// <param name="assignmentExpression">The assignment expression for the field, if present.</param>
/// <param name="staticFieldRewriter">The <see cref="StaticFieldRewriter"/> instance used to rewrite the field expression.</param>
/// <returns>Whether the field was processed successfully and is valid.</returns>
public static bool TryGetStaticField(
INamedTypeSymbol structDeclarationSymbol,
IFieldSymbol fieldSymbol,
SemanticModelProvider semanticModel,
ICollection<INamedTypeSymbol> discoveredTypes,
IDictionary<IFieldSymbol, string> constantDefinitions,
IDictionary<IFieldSymbol, (string, string, string?)> staticFieldDefinitions,
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
CancellationToken token,
[NotNullWhen(true)] out string? name,
[NotNullWhen(true)] out string? typeDeclaration,
out string? assignmentExpression,
[NotNullWhen(true)] out StaticFieldRewriter? staticFieldRewriter)
{
if (fieldSymbol.IsImplicitlyDeclared || !fieldSymbol.IsStatic || fieldSymbol.IsConst)
{
goto Failure;
}

if (!fieldSymbol.TryGetSyntaxNode(token, out VariableDeclaratorSyntax? variableDeclarator))
{
goto Failure;
}

// Static fields must be of a primitive, vector or matrix type
if (fieldSymbol.Type is not INamedTypeSymbol typeSymbol ||
!HlslKnownTypes.IsKnownHlslType(typeSymbol.GetFullyQualifiedMetadataName()))
{
diagnostics.Add(InvalidShaderStaticFieldType, variableDeclarator, structDeclarationSymbol, fieldSymbol.Name, fieldSymbol.Type);

goto Failure;
}

_ = HlslKnownKeywords.TryGetMappedName(fieldSymbol.Name, out string? mapping);

// The field name is either the mapped name (if a reserved name) or just the field name.
// This method is shared across external fields too, and callers can just override this.
name = mapping ?? fieldSymbol.Name;

// Readonly fields are rewritten to static const fields, and mutable fields are just static.
// Note that there's no protection for mutable static fields that may have been written to
// in C# elsewhere. Shader authors should be aware that those writes would not appear in HLSL,
// as each shader invocation would only see the initial assignment value (or the default value).
typeDeclaration = fieldSymbol.IsReadOnly switch
{
true => $"static const {HlslKnownTypes.GetMappedName(typeSymbol)}",
false => $"static {HlslKnownTypes.GetMappedName(typeSymbol)}"
};

token.ThrowIfCancellationRequested();

// Create the rewriter to use, which is also returned to callers so they can extract any additional
// info if needed. For instance, the D2D generator will check if the dispatch position is required.
staticFieldRewriter = new StaticFieldRewriter(
semanticModel,
discoveredTypes,
constantDefinitions,
staticFieldDefinitions,
diagnostics,
token);

ExpressionSyntax? processedDeclaration = staticFieldRewriter.Visit(variableDeclarator);

token.ThrowIfCancellationRequested();

assignmentExpression = processedDeclaration?.NormalizeWhitespace(eol: "\n").ToFullString();

return true;

Failure:
name = null;
typeDeclaration = null;
assignmentExpression = null;
staticFieldRewriter = null;

return false;
}

/// <summary>
/// Gets the sequence of processed discovered custom types.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,21 @@ internal abstract partial class HlslSourceRewriter : CSharpSyntaxRewriter
/// <param name="semanticModel">The <see cref="Microsoft.CodeAnalysis.SemanticModel"/> instance for the target syntax tree.</param>
/// <param name="discoveredTypes">The set of discovered custom types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <param name="staticFieldDefinitions">The collection of discovered static field definitions.</param>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="token">The <see cref="System.Threading.CancellationToken"/> value for the current operation.</param>
protected HlslSourceRewriter(
SemanticModelProvider semanticModel,
ICollection<INamedTypeSymbol> discoveredTypes,
IDictionary<IFieldSymbol, string> constantDefinitions,
IDictionary<IFieldSymbol, (string, string, string?)> staticFieldDefinitions,
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
CancellationToken token)
{
SemanticModel = semanticModel;
DiscoveredTypes = discoveredTypes;
ConstantDefinitions = constantDefinitions;
StaticFieldDefinitions = staticFieldDefinitions;
Diagnostics = diagnostics;
CancellationToken = token;
}
Expand All @@ -64,6 +67,11 @@ protected HlslSourceRewriter(
/// </summary>
protected IDictionary<IFieldSymbol, string> ConstantDefinitions { get; }

/// <summary>
/// Gets the collection of discovered static field definitions.
/// </summary>
protected IDictionary<IFieldSymbol, (string, string, string?)> StaticFieldDefinitions { get; }

/// <summary>
/// Gets the collection of produced <see cref="DiagnosticInfo"/> instances.
/// </summary>
Expand Down Expand Up @@ -382,11 +390,15 @@ public sealed override SyntaxNode VisitDefaultExpression(DefaultExpressionSyntax
}

/// <inheritdoc/>
public sealed override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
IdentifierNameSyntax updatedNode = (IdentifierNameSyntax)base.VisitIdentifierName(node)!;

if (SemanticModel.For(node).GetOperation(node) is IFieldReferenceOperation operation &&
// Only gather constants directly accessed by name. We can also pre-filter to exclude invocations
// and member access expressions, as those will be handled separately. Doing so avoids unnecessarily
// retrieving semantic information for every identifier, which would otherwise be fairly expensive.
if (node.Parent is not (InvocationExpressionSyntax or MemberAccessExpressionSyntax) &&
SemanticModel.For(node).GetOperation(node) is IFieldReferenceOperation operation &&
operation.Field.IsConst &&
operation.Type!.TypeKind != TypeKind.Enum)
{
Expand Down
Loading

0 comments on commit f56458d

Please sign in to comment.