Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix invalid HLSL for shader methods using custom types in signatures #728

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,13 @@ private static bool GetD2DRequiresScenePositionInfo(INamedTypeSymbol structDecla
/// <summary>
/// Produces the series of statements to build the current HLSL source.
/// </summary>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceHelper.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="resourceTextureFields">The sequence of captured resource textures for the current shader.</param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceHelper.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="executeMethod">The body of the entry point of the shader.</param>
/// <param name="inputCount">The number of shader inputs to declare.</param>
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
Expand All @@ -480,7 +480,7 @@ private static string GetHlslSource(
{
using IndentedTextWriter writer = new();

HlslSourceHelper.WriteHeader(writer);
HlslSourceSyntaxProcessor.WriteHeader(writer);

// Shader metadata
writer.WriteLine($"#define D2D_INPUT_COUNT {inputCount}");
Expand All @@ -503,15 +503,15 @@ private static string GetHlslSource(
writer.WriteLine();

// The FXC compiler does not support type forward declarations
HlslSourceHelper.WriteTopDeclarations(
HlslSourceSyntaxProcessor.WriteTopDeclarations(
writer,
definedConstants,
staticFields,
processedMethods,
typeDeclarations,
includeTypeForwardDeclarations: false);

HlslSourceHelper.WriteCapturedFields(writer, valueFields);
HlslSourceSyntaxProcessor.WriteCapturedFields(writer, valueFields);

// Resource textures
foreach (HlslResourceTextureField field in resourceTextureFields)
Expand All @@ -521,7 +521,7 @@ private static string GetHlslSource(
writer.WriteLine($"SamplerState __sampler__{field.Name} : register(s{field.Index});");
}

HlslSourceHelper.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);
HlslSourceSyntaxProcessor.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);

// Entry point
writer.WriteLine(executeMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Models\TypeAliases.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\HlslBytecodeSyntaxProcessor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\ConstantBufferSyntaxProcessor.Generation.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Helpers\HlslSourceHelper.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\HlslSourceSyntaxProcessor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxProcessors\HlslDefinitionsSyntaxProcessor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxRewriters\HlslSourceRewriter.Tracking.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SyntaxRewriters\HlslSourceRewriter.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using System.Collections.Immutable;
using ComputeSharp.SourceGeneration.Helpers;

namespace ComputeSharp.SourceGeneration.Helpers;
namespace ComputeSharp.SourceGeneration.SyntaxProcessors;

/// <summary>
/// A helper type to write HLSL source.
/// A processor responsible for formatting shared HLSL source for all shader types.
/// </summary>
internal static class HlslSourceHelper
internal static class HlslSourceSyntaxProcessor
{
/// <summary>
/// Writes the header included at the top of each generated HLSL shader.
Expand Down Expand Up @@ -67,6 +68,18 @@ public static void WriteTopDeclarations(
writer.WriteLine(skipIfPresent: true);
}

// Declared types (these have to be declared early on in the shader so that even if
// forward declarations for types are not supported, like is the case for D2D shaders
// using the FXC compiler, the resulting HLSL code is valid in case any forward
// declaration of methods in the shader has one of these types in its signature).
foreach (HlslUserType userType in typeDeclarations)
{
writer.WriteLine(skipIfPresent: true);
writer.WriteLine(userType.Definition);
}

writer.WriteLine(skipIfPresent: true);

// Forward declarations of shader/static methods
foreach (HlslMethod method in processedMethods)
{
Expand All @@ -90,15 +103,6 @@ public static void WriteTopDeclarations(
}

writer.WriteLine(skipIfPresent: true);

// Declared types
foreach (HlslUserType userType in typeDeclarations)
{
writer.WriteLine(skipIfPresent: true);
writer.WriteLine(userType.Definition);
}

writer.WriteLine(skipIfPresent: true);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,14 +489,14 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
/// <param name="threadsX">The thread ids value for the X axis.</param>
/// <param name="threadsY">The thread ids value for the Y axis.</param>
/// <param name="threadsZ">The thread ids value for the Z axis.</param>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="definedConstants"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='definedConstants']/node()"/></param>
/// <param name="resourceFields">The sequence of resource instance fields for the current shader.</param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceHelper.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="valueFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteCapturedFields" path="/param[@name='valueFields']/node()"/></param>
/// <param name="staticFields"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='staticFields']/node()"/></param>
/// <param name="sharedBuffers">The sequence of shared buffers declared by the shader.</param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceHelper.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceHelper.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="processedMethods"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='processedMethods']/node()"/></param>
/// <param name="typeDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteTopDeclarations" path="/param[@name='typeDeclarations']/node()"/></param>
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="isComputeShader">Whether or not the current shader type is a compute shader.</param>
/// <param name="implicitTextureType">The type of the implicit target texture, if present.</param>
/// <param name="isSamplerUsed">Whether the static sampler is used by the shader.</param>
Expand All @@ -521,14 +521,14 @@ private static string GetHlslSourceInfo(
{
using IndentedTextWriter writer = new();

HlslSourceHelper.WriteHeader(writer);
HlslSourceSyntaxProcessor.WriteHeader(writer);

// Group size constants
writer.WriteLine($"#define __GroupSize__get_X {threadsX}");
writer.WriteLine($"#define __GroupSize__get_Y {threadsY}");
writer.WriteLine($"#define __GroupSize__get_Z {threadsZ}");

HlslSourceHelper.WriteTopDeclarations(
HlslSourceSyntaxProcessor.WriteTopDeclarations(
writer,
definedConstants,
staticFields,
Expand All @@ -545,7 +545,7 @@ private static string GetHlslSourceInfo(
writer.WriteLine("uint __y;");
writer.WriteLineIf(isComputeShader, "uint __z;");

HlslSourceHelper.WriteCapturedFields(writer, valueFields);
HlslSourceSyntaxProcessor.WriteCapturedFields(writer, valueFields);
}

int constantBuffersCount = 1;
Expand Down Expand Up @@ -593,7 +593,7 @@ private static string GetHlslSourceInfo(
writer.WriteLine($"groupshared {buffer.Type} {buffer.Name} [{count}];");
}

HlslSourceHelper.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);
HlslSourceSyntaxProcessor.WriteMethodDeclarations(writer, processedMethods, typeMethodDeclarations);

// Entry point
writer.WriteLine("[NumThreads(__GroupSize__get_X, __GroupSize__get_Y, __GroupSize__get_Z)]");
Expand Down
4 changes: 2 additions & 2 deletions tests/ComputeSharp.D2D1.Tests/D2D1ReflectionServicesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void GetShaderInfo()
float4 value4 = __reserved__texture.Sample(__sampler____reserved__texture, float2(0, 0.5));
return input0 + input1 + input2 + value3 + value4;
}
""".Replace("\r\n", "\n"), shaderInfo.HlslSource);
""", shaderInfo.HlslSource);

CollectionAssert.AreEqual(D2D1PixelShader.LoadBytecode<ReflectedShader>().ToArray(), shaderInfo.HlslBytecode.ToArray());
}
Expand Down Expand Up @@ -119,7 +119,7 @@ public void GetShaderInfoWithDoublePrecisionFeature()
{
return (float4)(D2DGetInput(0) + (double4)amount);
}
""".Replace("\r\n", "\n"), shaderInfo.HlslSource);
""", shaderInfo.HlslSource);

CollectionAssert.AreEqual(D2D1PixelShader.LoadBytecode<ReflectedShaderWithDoubleOperations>().ToArray(), shaderInfo.HlslBytecode.ToArray());
}
Expand Down
76 changes: 76 additions & 0 deletions tests/ComputeSharp.D2D1.Tests/ShaderRewriterTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using ComputeSharp.D2D1.Interop;
using Microsoft.VisualStudio.TestTools.UnitTesting;

#pragma warning disable CA1822

namespace ComputeSharp.D2D1.Tests;

[TestClass]
public partial class ShaderRewriterTests
{
// See https://github.com/Sergio0694/ComputeSharp/issues/725
[TestMethod]
public void ShaderWithStructAndInstanceMethodUsingIt_IsRewrittenCorrectly()
{
D2D1ShaderInfo shaderInfo = D2D1ReflectionServices.GetShaderInfo<ClassWithShader.ShaderWithStructAndInstanceMethodUsingIt>();

Assert.AreEqual("""
// ================================================
// AUTO GENERATED
// ================================================
// This shader was created by ComputeSharp.
// See: https://github.com/Sergio0694/ComputeSharp.

#define D2D_INPUT_COUNT 0

#include "d2d1effecthelpers.hlsli"

struct ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data
{
int value;
};

void UseData(inout ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data data);

void UseData(inout ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data data)
{
++data.value;
}

D2D_PS_ENTRY(Execute)
{
ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data data = (ComputeSharp_D2D1_Tests_ShaderRewriterTests_ClassWithShader_ShaderWithStructAndInstanceMethodUsingIt_Data)0;
UseData(data);
return float4(data.value, data.value, data.value, data.value);
}
""", shaderInfo.HlslSource);
}

internal sealed partial class ClassWithShader
{
[D2DInputCount(0)]
[D2DShaderProfile(D2D1ShaderProfile.PixelShader50)]
[D2DGeneratedPixelShaderDescriptor]
internal readonly partial struct ShaderWithStructAndInstanceMethodUsingIt : ID2D1PixelShader
{
private struct Data
{
public int value;
}

public float4 Execute()
{
Data data = default;

UseData(ref data);

return new float4(data.value, data.value, data.value, data.value);
}

private void UseData(ref Data data)
{
++data.value;
}
}
}
}
26 changes: 13 additions & 13 deletions tests/ComputeSharp.Tests/ShaderCompilerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -992,19 +992,6 @@ public void ShaderWithAllSupportedMembers_IsProcessedCorrectly()
struct ComputeSharp_Tests_ShaderCompilerTests_StructType1;
struct ComputeSharp_Tests_ShaderCompilerTests_StructType2;

int InstanceMethodInShader();

static float StaticMethodInShader(float x);

static float ComputeSharp_Tests_ShaderCompilerTests_StructType1_StaticMethod(int x);

static float ComputeSharp_Tests_ShaderCompilerTests_StructType2_StaticMethod(int x);

static const float Init = abs(__ComputeSharp_Tests_ShaderCompilerTests_ShaderWithAllSupportedMembers__PI);
static int Temp;
static int ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_Temp;
static const float ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_PI2 = 3.14 * 2;

struct ComputeSharp_Tests_ShaderCompilerTests_StructType1
{
int X;
Expand All @@ -1020,6 +1007,19 @@ struct ComputeSharp_Tests_ShaderCompilerTests_StructType2
float Combine(ComputeSharp_Tests_ShaderCompilerTests_StructType1 other);
};

int InstanceMethodInShader();

static float StaticMethodInShader(float x);

static float ComputeSharp_Tests_ShaderCompilerTests_StructType1_StaticMethod(int x);

static float ComputeSharp_Tests_ShaderCompilerTests_StructType2_StaticMethod(int x);

static const float Init = abs(__ComputeSharp_Tests_ShaderCompilerTests_ShaderWithAllSupportedMembers__PI);
static int Temp;
static int ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_Temp;
static const float ComputeSharp_Tests_ShaderCompilerTests_ExternalContainerClass_PI2 = 3.14 * 2;

cbuffer _ : register(b0)
{
uint __x;
Expand Down