Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for cancellation while doing HLSL rewriting #695

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using ComputeSharp.SourceGeneration.Extensions;
using ComputeSharp.SourceGeneration.Helpers;
using ComputeSharp.SourceGeneration.Mappings;
Expand All @@ -28,20 +29,20 @@ private static partial class HlslSource
/// </summary>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
/// <param name="structDeclaration">The <see cref="StructDeclarationSyntax"/> node to process.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for <paramref name="structDeclaration"/>.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
/// <param name="inputCount">The number of inputs for the shader.</param>
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
/// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <returns>The HLSL source for the shader.</returns>
public static string GetHlslSource(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Compilation compilation,
StructDeclarationSyntax structDeclaration,
INamedTypeSymbol structDeclarationSymbol,
int inputCount,
ImmutableArray<int> inputSimpleIndices,
ImmutableArray<int> inputComplexIndices)
ImmutableArray<int> inputComplexIndices,
CancellationToken token)
{
// Detect any invalid properties
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);
Expand All @@ -52,6 +53,8 @@ public static string GetHlslSource(
Dictionary<IMethodSymbol, MethodDeclarationSyntax> instanceMethods = new(SymbolEqualityComparer.Default);
Dictionary<IFieldSymbol, string> constantDefinitions = new(SymbolEqualityComparer.Default);

token.ThrowIfCancellationRequested();

// Extract information on all captured fields
GetInstanceFields(
diagnostics,
Expand All @@ -60,20 +63,58 @@ public static string GetHlslSource(
out ImmutableArray<(string Name, string HlslType)> valueFields,
out ImmutableArray<(string Name, string HlslType, int Index)> resourceTextureFields);

// Explore the syntax tree and extract the processed info
token.ThrowIfCancellationRequested();

SemanticModelProvider semanticModelProvider = new(compilation);
(string entryPoint, ImmutableArray<(string Signature, string Definition)> processedMethods) = GetProcessedMethods(diagnostics, structDeclarationSymbol, semanticModelProvider, discoveredTypes, staticMethods, instanceMethods, constantDefinitions, out bool methodsNeedD2D1RequiresScenePosition);
ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> staticFields = GetStaticFields(diagnostics, semanticModelProvider, structDeclarationSymbol, discoveredTypes, constantDefinitions, out bool fieldsNeedD2D1RequiresScenePosition);

// Explore the syntax tree and extract the processed info
(string entryPoint, ImmutableArray<(string Signature, string Definition)> processedMethods) = GetProcessedMethods(
diagnostics,
structDeclarationSymbol,
semanticModelProvider,
discoveredTypes,
staticMethods,
instanceMethods,
constantDefinitions,
token,
out bool methodsNeedD2D1RequiresScenePosition);

token.ThrowIfCancellationRequested();

ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> staticFields = GetStaticFields(
diagnostics,
semanticModelProvider,
structDeclarationSymbol,
discoveredTypes,
constantDefinitions,
out bool fieldsNeedD2D1RequiresScenePosition);

token.ThrowIfCancellationRequested();

// Process the discovered types and constants
ImmutableArray<(string Name, string Definition)> declaredTypes = HlslDefinitionsSyntaxProcessor.GetDeclaredTypes(diagnostics, structDeclarationSymbol, discoveredTypes, instanceMethods);
ImmutableArray<(string Name, string Definition)> declaredTypes = HlslDefinitionsSyntaxProcessor.GetDeclaredTypes(
diagnostics,
structDeclarationSymbol,
discoveredTypes,
instanceMethods);

token.ThrowIfCancellationRequested();

ImmutableArray<(string Name, string Value)> definedConstants = HlslDefinitionsSyntaxProcessor.GetDefinedConstants(constantDefinitions);

token.ThrowIfCancellationRequested();

// Check whether the scene position is required
bool requiresScenePosition = GetD2DRequiresScenePositionInfo(structDeclarationSymbol);

// Emit diagnostics for incorrect scene position uses
ReportInvalidD2DRequiresScenePositionUse(diagnostics, structDeclarationSymbol, requiresScenePosition, methodsNeedD2D1RequiresScenePosition || fieldsNeedD2D1RequiresScenePosition);
ReportInvalidD2DRequiresScenePositionUse(
diagnostics,
structDeclarationSymbol,
requiresScenePosition,
methodsNeedD2D1RequiresScenePosition || fieldsNeedD2D1RequiresScenePosition);

token.ThrowIfCancellationRequested();

// Get the HLSL source
return GetHlslSource(
Expand Down Expand Up @@ -264,6 +305,7 @@ private static void GetInstanceFields(
/// <param name="instanceMethods">The collection of discovered instance methods for custom struct types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <param name="needsD2D1RequiresScenePosition">Whether or not the shader needs the <c>[D2DRequiresScenePosition]</c> annotation.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <returns>A sequence of processed methods in <paramref name="structDeclarationSymbol"/>, and the entry point.</returns>
private static (string EntryPoint, ImmutableArray<(string Signature, string Definition)> Methods) GetProcessedMethods(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Expand All @@ -273,6 +315,7 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
IDictionary<IMethodSymbol, MethodDeclarationSyntax> instanceMethods,
IDictionary<IFieldSymbol, string> constantDefinitions,
CancellationToken token,
out bool needsD2D1RequiresScenePosition)
{
using ImmutableArrayBuilder<(string, string)> methods = new();
Expand Down Expand Up @@ -305,6 +348,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
continue;
}

token.ThrowIfCancellationRequested();

// Create the source rewriter for the current method
ShaderSourceRewriter shaderSourceRewriter = new(
structDeclarationSymbol,
Expand All @@ -319,6 +364,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
// Rewrite the method syntax tree
MethodDeclarationSyntax? processedMethod = shaderSourceRewriter.Visit(methodDeclaration)!.WithoutTrivia();

token.ThrowIfCancellationRequested();

// Update the position requirement
needsD2D1RequiresScenePosition |= shaderSourceRewriter.NeedsD2DRequiresScenePositionAttribute;

Expand Down Expand Up @@ -346,6 +393,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
}
}

token.ThrowIfCancellationRequested();

// Process static methods as well
foreach (MethodDeclarationSyntax staticMethod in staticMethods.Values)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
string hlslSource = HlslSource.GetHlslSource(
diagnostics,
context.SemanticModel.Compilation,
(StructDeclarationSyntax)context.TargetNode,
typeSymbol,
inputCount,
inputSimpleIndices,
inputComplexIndices);
inputComplexIndices,
token);

token.ThrowIfCancellationRequested();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using ComputeSharp.SourceGeneration.Extensions;
using ComputeSharp.SourceGeneration.Helpers;
using ComputeSharp.SourceGeneration.Mappings;
Expand Down Expand Up @@ -29,22 +30,22 @@ internal static partial class HlslSource
/// </summary>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
/// <param name="structDeclaration">The <see cref="StructDeclarationSyntax"/> node to process.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for <paramref name="structDeclaration"/>.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
/// <param name="threadsX">The thread ids value for the X axis.</param>
/// <param name="threadsY">The thread ids value for the Y axis.</param>
/// <param name="threadsZ">The thread ids value for the Z axis.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <param name="isImplicitTextureUsed">Indicates whether the current shader uses an implicit texture.</param>
/// <param name="isSamplerUsed">Whether or not the static sampler is used.</param>
/// <param name="hlslSource">The resulting HLSL source for the current shader.</param>
public static void GetInfo(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Compilation compilation,
StructDeclarationSyntax structDeclaration,
INamedTypeSymbol structDeclarationSymbol,
int threadsX,
int threadsY,
int threadsZ,
CancellationToken token,
out bool isImplicitTextureUsed,
out bool isSamplerUsed,
out string hlslSource)
Expand All @@ -58,21 +59,66 @@ public static void GetInfo(
Dictionary<IMethodSymbol, MethodDeclarationSyntax> instanceMethods = new(SymbolEqualityComparer.Default);
Dictionary<IFieldSymbol, string> constantDefinitions = new(SymbolEqualityComparer.Default);

// Explore the syntax tree and extract the processed info
SemanticModelProvider semanticModelProvider = new(compilation);
INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: nameof(IComputeShader<byte>) });
// Setup the semantic model and basic properties
INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: "IComputeShader" });
bool isComputeShader = pixelShaderSymbol is null;
string? implicitTextureType = isComputeShader ? null : HlslKnownTypes.GetMappedNameForPixelShaderType(pixelShaderSymbol!);
(ImmutableArray<(string MetadataName, string Name, string HlslType)> resourceFields, ImmutableArray<(string Name, string HlslType)> valueFields) = GetInstanceFields(diagnostics, structDeclarationSymbol, discoveredTypes, isComputeShader);
ImmutableArray<(string Name, string Type, int? Count)> sharedBuffers = GetSharedBuffers(diagnostics, structDeclarationSymbol, discoveredTypes);
(string entryPoint, ImmutableArray<(string Signature, string Definition)> processedMethods, isSamplerUsed) = GetProcessedMethods(diagnostics, structDeclarationSymbol, semanticModelProvider, discoveredTypes, staticMethods, instanceMethods, constantDefinitions, isComputeShader);

token.ThrowIfCancellationRequested();

(ImmutableArray<(string MetadataName, string Name, string HlslType)> resourceFields, ImmutableArray<(string Name, string HlslType)> valueFields) = GetInstanceFields(
diagnostics,
structDeclarationSymbol,
discoveredTypes,
isComputeShader);

token.ThrowIfCancellationRequested();

ImmutableArray<(string Name, string Type, int? Count)> sharedBuffers = GetSharedBuffers(
diagnostics,
structDeclarationSymbol,
discoveredTypes);

token.ThrowIfCancellationRequested();

SemanticModelProvider semanticModelProvider = new(compilation);

(string entryPoint, ImmutableArray<(string Signature, string Definition)> processedMethods, isSamplerUsed) = GetProcessedMethods(
diagnostics,
structDeclarationSymbol,
semanticModelProvider,
discoveredTypes,
staticMethods,
instanceMethods,
constantDefinitions,
isComputeShader,
token);

token.ThrowIfCancellationRequested();

(string, string)? implicitSamplerField = isSamplerUsed ? ("SamplerState", "__sampler") : default((string, string)?);
ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> staticFields = GetStaticFields(diagnostics, semanticModelProvider, structDeclarationSymbol, discoveredTypes, constantDefinitions);
ImmutableArray<(string Name, string TypeDeclaration, string? Assignment)> staticFields = GetStaticFields(
diagnostics,
semanticModelProvider,
structDeclarationSymbol,
discoveredTypes,
constantDefinitions);

token.ThrowIfCancellationRequested();

// Process the discovered types and constants
ImmutableArray<(string Name, string Definition)> declaredTypes = HlslDefinitionsSyntaxProcessor.GetDeclaredTypes(diagnostics, structDeclarationSymbol, discoveredTypes, instanceMethods);
ImmutableArray<(string Name, string Definition)> declaredTypes = HlslDefinitionsSyntaxProcessor.GetDeclaredTypes(
diagnostics,
structDeclarationSymbol,
discoveredTypes,
instanceMethods);

token.ThrowIfCancellationRequested();

ImmutableArray<(string Name, string Value)> definedConstants = HlslDefinitionsSyntaxProcessor.GetDefinedConstants(constantDefinitions);

token.ThrowIfCancellationRequested();

// Check whether an implicit texture is used in the shader
isImplicitTextureUsed = implicitTextureType is not null;

Expand Down Expand Up @@ -312,6 +358,7 @@ private static (
/// <param name="instanceMethods">The collection of discovered instance methods for custom struct types.</param>
/// <param name="constantDefinitions">The collection of discovered constant definitions.</param>
/// <param name="isComputeShader">Indicates whether or not <paramref name="structDeclarationSymbol"/> represents a compute shader.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <returns>A sequence of processed methods in <paramref name="structDeclarationSymbol"/>, and the entry point.</returns>
private static (string EntryPoint, ImmutableArray<(string Signature, string Definition)> Methods, bool IsSamplerUser) GetProcessedMethods(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Expand All @@ -321,7 +368,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
IDictionary<IMethodSymbol, MethodDeclarationSyntax> instanceMethods,
IDictionary<IFieldSymbol, string> constantDefinitions,
bool isComputeShader)
bool isComputeShader,
CancellationToken token)
{
using ImmutableArrayBuilder<(string, string)> methods = new();

Expand All @@ -343,12 +391,12 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi

bool isShaderEntryPoint =
(isComputeShader &&
methodSymbol.Name == nameof(IComputeShader.Execute) &&
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnsVoid &&
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0) ||
(!isComputeShader &&
methodSymbol.Name == nameof(IComputeShader<byte>.Execute) &&
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnType is not null && // TODO: match for pixel type
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0);
Expand All @@ -359,6 +407,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
continue;
}

token.ThrowIfCancellationRequested();

// Create the source rewriter for the current method
ShaderSourceRewriter shaderSourceRewriter = new(
structDeclarationSymbol,
Expand All @@ -373,6 +423,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
// Rewrite the method syntax tree
MethodDeclarationSyntax? processedMethod = shaderSourceRewriter.Visit(methodDeclaration)!.WithoutTrivia();

token.ThrowIfCancellationRequested();

// Track the implicit sampler, if used
isSamplerUsed = isSamplerUsed || shaderSourceRewriter.IsSamplerUsed;

Expand Down Expand Up @@ -403,6 +455,8 @@ private static (string EntryPoint, ImmutableArray<(string Signature, string Defi
}
}

token.ThrowIfCancellationRequested();

// Process static methods as well
foreach (MethodDeclarationSyntax staticMethod in staticMethods.Values)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
HlslSource.GetInfo(
diagnostics,
context.SemanticModel.Compilation,
(StructDeclarationSyntax)context.TargetNode,
typeSymbol,
threadsX,
threadsY,
threadsZ,
token,
out bool isImplicitTextureUsed,
out bool isSamplerUsed,
out string hlslSource);
Expand Down
Loading
Loading