Skip to content

Commit

Permalink
Merge pull request #728 from Sergio0694/dev/fix-methods-forward-decla…
Browse files Browse the repository at this point in the history
…rations

Fix invalid HLSL for shader methods using custom types in signatures
  • Loading branch information
Sergio0694 authored Dec 21, 2023
2 parents 2f6ab50 + 5f429df commit f5e182c
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,13 @@ private static bool GetD2DRequiresScenePositionInfo(INamedTypeSymbol structDecla
/// <summary>
/// Produces the series of statements to build the current HLSL source.
/// </summary>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceHelper.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="resourceTextureFields">The sequence of captured resource textures for the current shader.</param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceHelper.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="executeMethod">The body of the entry point of the shader.</param>
/// <param name="inputCount">The number of shader inputs to declare.</param>
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
Expand All @@ -480,7 +480,7 @@ private static string GetHlslSource(
{
using IndentedTextWriter writer = new();

HlslSourceHelper.WriteHeader(writer);
HlslSourceSyntaxProcessor.WriteHeader(writer);

// Shader metadata
writer.WriteLine($"#define D2D_INPUT_COUNT {inputCount}");
Expand All @@ -503,15 +503,15 @@ private static string GetHlslSource(
writer.WriteLine();

// The FXC compiler does not support type forward declarations
HlslSourceHelper.WriteTopDeclarations(
HlslSourceSyntaxProcessor.WriteTopDeclarations(
writer,
definedConstants,
staticFields,
processedMethods,
typeDeclarations,
includeTypeForwardDeclarations: false);

HlslSourceHelper.WriteCapturedFields(writer, valueFields);
HlslSourceSyntaxProcessor.WriteCapturedFields(writer, valueFields);

// Resource textures
foreach (HlslResourceTextureField field in resourceTextureFields)
Expand All @@ -521,7 +521,7 @@ private static string GetHlslSource(
writer.WriteLine($"SamplerState __sampler__{field.Name} : register(s{field.Index});");
}

HlslSourceHelper.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);
HlslSourceSyntaxProcessor.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);

// Entry point
writer.WriteLine(executeMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Models\TypeAliases.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\HlslBytecodeSyntaxProcessor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\ConstantBufferSyntaxProcessor.Generation.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Helpers\HlslSourceHelper.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\HlslSourceSyntaxProcessor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\HlslDefinitionsSyntaxProcessor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxRewriters\HlslSourceRewriter.Tracking.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxRewriters\HlslSourceRewriter.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using System.Collections.Immutable;
using ComputeSharp.SourceGeneration.Helpers;

namespace ComputeSharp.SourceGeneration.Helpers;
namespace ComputeSharp.SourceGeneration.SyntaxProcessors;

/// <summary>
/// A helper type to write HLSL source.
/// A processor responsible for formatting shared HLSL source for all shader types.
/// </summary>
internal static class HlslSourceHelper
internal static class HlslSourceSyntaxProcessor
{
/// <summary>
/// Writes the header included at the top of each generated HLSL shader.
Expand Down Expand Up @@ -67,6 +68,18 @@ public static void WriteTopDeclarations(
writer.WriteLine(skipIfPresent: true);
}

// Declared types (these have to be declared early on in the shader so that even if
// forward declarations for types are not supported, like is the case for D2D shaders
// using the FXC compiler, the resulting HLSL code is valid in case any forward
// declaration of methods in the shader has one of these types in its signature).
foreach (HlslUserType userType in typeDeclarations)
{
writer.WriteLine(skipIfPresent: true);
writer.WriteLine(userType.Definition);
}

writer.WriteLine(skipIfPresent: true);

// Forward declarations of shader/static methods
foreach (HlslMethod method in processedMethods)
{
Expand All @@ -90,15 +103,6 @@ public static void WriteTopDeclarations(
}

writer.WriteLine(skipIfPresent: true);

// Declared types
foreach (HlslUserType userType in typeDeclarations)
{
writer.WriteLine(skipIfPresent: true);
writer.WriteLine(userType.Definition);
}

writer.WriteLine(skipIfPresent: true);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,14 +489,14 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
/// <param name="threadsX">The thread ids value for the X axis.</param>
/// <param name="threadsY">The thread ids value for the Y axis.</param>
/// <param name="threadsZ">The thread ids value for the Z axis.</param>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="resourceFields">The sequence of resource instance fields for the current shader.</param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceHelper.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="sharedBuffers">The sequence of shared buffers declared by the shader.</param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceHelper.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="isComputeShader">Whether or not the current shader type is a compute shader.</param>
/// <param name="implicitTextureType">The type of the implicit target texture, if present.</param>
/// <param name="isSamplerUsed">Whether the static sampler is used by the shader.</param>
Expand All @@ -521,14 +521,14 @@ private static string GetHlslSourceInfo(
{
using IndentedTextWriter writer = new();

HlslSourceHelper.WriteHeader(writer);
HlslSourceSyntaxProcessor.WriteHeader(writer);

// Group size constants
writer.WriteLine($"#define __GroupSize__get_X {threadsX}");
writer.WriteLine($"#define __GroupSize__get_Y {threadsY}");
writer.WriteLine($"#define __GroupSize__get_Z {threadsZ}");

HlslSourceHelper.WriteTopDeclarations(
HlslSourceSyntaxProcessor.WriteTopDeclarations(
writer,
definedConstants,
staticFields,
Expand All @@ -545,7 +545,7 @@ private static string GetHlslSourceInfo(
writer.WriteLine("uint __y;");
writer.WriteLineIf(isComputeShader, "uint __z;");

HlslSourceHelper.WriteCapturedFields(writer, valueFields);
HlslSourceSyntaxProcessor.WriteCapturedFields(writer, valueFields);
}

int constantBuffersCount = 1;
Expand Down Expand Up @@ -593,7 +593,7 @@ private static string GetHlslSourceInfo(
writer.WriteLine($"groupshared {buffer.Type} {buffer.Name} [{count}];");
}

HlslSourceHelper.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);
HlslSourceSyntaxProcessor.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);

// Entry point
writer.WriteLine("[NumThreads(__GroupSize__get_X, __GroupSize__get_Y, __GroupSize__get_Z)]");
Expand Down
4 changes: 2 additions & 2 deletions tests/ComputeSharp.D2D1.Tests/D2D1ReflectionServicesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void GetShaderInfo()
float4 value4 = __reserved__texture.Sample(__sampler____reserved__texture, float2(0, 0.5));
return input0 + input1 + input2 + value3 + value4;
}
""".Replace("\r\n", "\n"), shaderInfo.HlslSource);
""", shaderInfo.HlslSource);

CollectionAssert.AreEqual(D2D1PixelShader.LoadBytecode<ReflectedShader>().ToArray(), shaderInfo.HlslBytecode.ToArray());
}
Expand Down Expand Up @@ -119,7 +119,7 @@ public void GetShaderInfoWithDoublePrecisionFeature()
{
return (float4)(D2DGetInput(0) + (double4)amount);
}
""".Replace("\r\n", "\n"), shaderInfo.HlslSource);
""", shaderInfo.HlslSource);

CollectionAssert.AreEqual(D2D1PixelShader.LoadBytecode<ReflectedShaderWithDoubleOperations>().ToArray(), shaderInfo.HlslBytecode.ToArray());
}
Expand Down
76 changes: 76 additions & 0 deletions tests/ComputeSharp.D2D1.Tests/ShaderRewriterTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using ComputeSharp.D2D1.Interop;
using Microsoft.VisualStudio.TestTools.UnitTesting;

#pragma warning disable CA1822

namespace ComputeSharp.D2D1.Tests;

[TestClass]
public partial class ShaderRewriterTests
{
// See https://github.com/Sergio0694/ComputeSharp/issues/725
[TestMethod]
public void ShaderWithStructAndInstanceMethodUsingIt_IsRewrittenCorrectly()
{
D2D1ShaderInfo shaderInfo = D2D1ReflectionServices.GetShaderInfo<ClassWithShader.ShaderWithStructAndInstanceMethodUsingIt>();

Assert.AreEqual("""
// ================================================
// AUTO GENERATED
// ================================================
// This shader was created by ComputeSharp.
// See: https://github.com/Sergio0694/ComputeSharp.
#define D2D_INPUT_COUNT 0
#include "d2d1effecthelpers.hlsli"
struct ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data
{
int value;
};
void UseData(inout ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data data);
void UseData(inout ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data data)
{
++data.value;
}
D2D_PS_ENTRY(Execute)
{
ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data data = (ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data)0;
UseData(data);
return float4(data.value, data.value, data.value, data.value);
}
""", shaderInfo.HlslSource);
}

internal sealed partial class ClassWithShader
{
[D2DInputCount(0)]
[D2DShaderProfile(D2D1ShaderProfile.PixelShader50)]
[D2DGeneratedPixelShaderDescriptor]
internal readonly partial struct ShaderWithStructAndInstanceMethodUsingIt : ID2D1PixelShader
{
private struct Data
{
public int value;
}

public float4 Execute()
{
Data data = default;

UseData(ref data);

return new float4(data.value, data.value, data.value, data.value);
}

private void UseData(ref Data data)
{
++data.value;
}
}
}
}
26 changes: 13 additions & 13 deletions tests/ComputeSharp.Tests/ShaderCompilerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -992,19 +992,6 @@ public void ShaderWithAllSupportedMembers_IsProcessedCorrectly()
struct ComputeSharp_Tests_ShaderCompilerTests_StructType1;
struct ComputeSharp_Tests_ShaderCompilerTests_StructType2;
int InstanceMethodInShader();
static float StaticMethodInShader(float x);
static float ComputeSharp_Tests_ShaderCompilerTests_StructType1_StaticMethod(int x);
static float ComputeSharp_Tests_ShaderCompilerTests_StructType2_StaticMethod(int x);
static const float Init = abs(__ComputeSharp_Tests_ShaderCompilerTests_ShaderWithAllSupportedMembers__PI);
static int Temp;
static int ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_Temp;
static const float ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_PI2 = 3.14 * 2;
struct ComputeSharp_Tests_ShaderCompilerTests_StructType1
{
int X;
Expand All @@ -1020,6 +1007,19 @@ struct ComputeSharp_Tests_ShaderCompilerTests_StructType2
float Combine(ComputeSharp_Tests_ShaderCompilerTests_StructType1 other);
};
int InstanceMethodInShader();
static float StaticMethodInShader(float x);
static float ComputeSharp_Tests_ShaderCompilerTests_StructType1_StaticMethod(int x);
static float ComputeSharp_Tests_ShaderCompilerTests_StructType2_StaticMethod(int x);
static const float Init = abs(__ComputeSharp_Tests_ShaderCompilerTests_ShaderWithAllSupportedMembers__PI);
static int Temp;
static int ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_Temp;
static const float ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_PI2 = 3.14 * 2;
cbuffer _ : register(b0)
{
uint __x;
Expand Down

0 comments on commit f5e182c

Please sign in to comment.