Skip to content

Commit

Permalink
Make generated code pretty to read (and write) (#280)
Browse files Browse the repository at this point in the history
* Add snapshot testing of generated code

* Generate pretty code

* Group function fields together

* Comment re-binding section in "ReloadModule"

* Normalise sources to stabilise hash in generated code

* Remove irrelevant comment

* Implement ugly hack to workaround file deletion.

See also:

- <https://github.com/tonybaloney/CSnakes/runs/31639380191#r4s3>
  • Loading branch information
atifaziz authored Oct 18, 2024
1 parent dff06c3 commit 93a5434
Show file tree
Hide file tree
Showing 20 changed files with 3,020 additions and 22 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -486,4 +486,7 @@ $RECYCLE.BIN/
*.orig

.venv
site/
site/

# Snapshot test files
*.received.*
123 changes: 104 additions & 19 deletions src/CSnakes.SourceGeneration/PythonStaticGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

if (result)
{
IEnumerable<MethodDefinition> methods = ModuleReflection.MethodsFromFunctionDefinitions(functions, fileName);
var methods = ModuleReflection.MethodsFromFunctionDefinitions(functions, fileName).ToImmutableArray();
string source = FormatClassFromMethods(@namespace, pascalFileName, methods, fileName, functions, hash);
sourceContext.AddSource($"{pascalFileName}.py.cs", source);
sourceContext.ReportDiagnostic(Diagnostic.Create(new DiagnosticDescriptor("PSG002", "PythonStaticGenerator", $"Generated {pascalFileName}.py.cs", "PythonStaticGenerator", DiagnosticSeverity.Info, true), Location.None));
}
});
}

public static string FormatClassFromMethods(string @namespace, string pascalFileName, IEnumerable<MethodDefinition> methods, string fileName, PythonFunctionDefinition[] functions, ImmutableArray<byte> hash)
public static string FormatClassFromMethods(string @namespace, string pascalFileName, ImmutableArray<MethodDefinition> methods, string fileName, PythonFunctionDefinition[] functions, ImmutableArray<byte> hash)
{
var paramGenericArgs = methods
.Select(m => m.ParameterGenericArgs)
Expand All @@ -65,6 +65,7 @@ public static string FormatClassFromMethods(string @namespace, string pascalFile
return $$"""
// <auto-generated/>
#nullable enable
using CSnakes.Runtime;
using CSnakes.Runtime.Python;
Expand All @@ -78,6 +79,7 @@ public static string FormatClassFromMethods(string @namespace, string pascalFile
[assembly: MetadataUpdateHandler(typeof({{@namespace}}.{{pascalFileName}}Extensions))]
namespace {{@namespace}};
public static class {{pascalFileName}}Extensions
{
private static I{{pascalFileName}}? instance;
Expand All @@ -94,16 +96,19 @@ public static class {{pascalFileName}}Extensions
return instance;
}
public static void UpdateApplication(Type[]? updatedTypes) {
public static void UpdateApplication(Type[]? updatedTypes)
{
instance?.ReloadModule();
}
private class {{pascalFileName}}Internal : I{{pascalFileName}}
{
private PyObject module;
private readonly ILogger<IPythonEnvironment> logger;
{{string.Join(Environment.NewLine, functionNames.Select(f => $"private PyObject {f.Field};")) }}
{{ Lines(IndentationLevel.Two,
from f in functionNames
select $"private PyObject {f.Field};") }}
internal {{pascalFileName}}Internal(ILogger<IPythonEnvironment> logger)
{
Expand All @@ -112,42 +117,59 @@ private class {{pascalFileName}}Internal : I{{pascalFileName}}
{
logger.LogDebug("Importing module {ModuleName}", "{{fileName}}");
module = Import.ImportModule("{{fileName}}");
{{ string.Join(Environment.NewLine, functionNames.Select(f => $"this.{f.Field} = module.GetAttr(\"{f.Attr}\");")) }}
{{ Lines(IndentationLevel.Four,
from f in functionNames
select $"this.{f.Field} = module.GetAttr(\"{f.Attr}\");") }}
}
}
void IReloadableModuleImport.ReloadModule() {
void IReloadableModuleImport.ReloadModule()
{
logger.LogDebug("Reloading module {ModuleName}", "{{fileName}}");
using (GIL.Acquire())
{
Import.ReloadModule(ref module);
// Dispose old functions
{{string.Join(Environment.NewLine, functionNames.Select(f => $"this.{f.Field}.Dispose();"))}}
{{string.Join(Environment.NewLine, functionNames.Select(f => $"this.{f.Field} = module.GetAttr(\"{f.Attr}\");"))}}
{{ Lines(IndentationLevel.Four,
from f in functionNames
select $"this.{f.Field}.Dispose();") }}
// Bind to new functions
{{ Lines(IndentationLevel.Four,
from f in functionNames
select $"this.{f.Field} = module.GetAttr(\"{f.Attr}\");") }}
}
}
public void Dispose()
{
logger.LogDebug("Disposing module {ModuleName}", "{{fileName}}");
{{ string.Join(Environment.NewLine, functionNames.Select(f => $"this.{f.Field}.Dispose();")) }}
{{ Lines(IndentationLevel.Three,
from f in functionNames
select $"this.{f.Field}.Dispose();") }}
module.Dispose();
}
{{methods.Select(m => m.Syntax).Compile()}}
{{ Lines(IndentationLevel.Two, methods.Select(m => m.Syntax).Compile().TrimEnd()) }}
}
}
public interface I{{pascalFileName}} : IReloadableModuleImport
{
{{string.Join(Environment.NewLine, methods.Select(m => m.Syntax)
.Select(m => m.Identifier.Text == "ReloadModule"
// This prevents the warning:
// > warning CS0108: 'IFooBar.ReloadModule()' hides inherited member 'IReloadableModuleImport.ReloadModule()'. Use the new keyword if hiding was intended.
// because `IReloadableModuleImport` already has a `ReloadModule` method.
? m.AddModifiers(SyntaxFactory.Token(SyntaxKind.NewKeyword))
: m)
.Select(m => $"{m.WithBody(null).NormalizeWhitespace()};"))}}
{{ Lines(IndentationLevel.One,
from m in methods
select m.Syntax into m
select m.Identifier.Text == "ReloadModule"
// This prevents the warning:
// > warning CS0108: 'IFooBar.ReloadModule()' hides inherited member 'IReloadableModuleImport.ReloadModule()'. Use the new keyword if hiding was intended.
// because `IReloadableModuleImport` already has a `ReloadModule` method.
? m.AddModifiers(SyntaxFactory.Token(SyntaxKind.NewKeyword))
: m
into m
select $"{m.WithModifiers(m.Modifiers.RemoveAt(m.Modifiers.IndexOf(SyntaxKind.PublicKeyword)))
.WithBody(null)
.NormalizeWhitespace()};") }}
}
""";
}

Expand All @@ -165,4 +187,67 @@ private static string HexString(ReadOnlySpan<byte> bytes)

return new string(chars);
}

const string Space = " ";
const string Indent = $"{Space}{Space}{Space}{Space}";

private static readonly string[] Indents =
[
"",
Indent,
Indent + Indent,
Indent + Indent + Indent,
Indent + Indent + Indent + Indent,
];

private enum IndentationLevel { Zero = 0, One = 1, Two = 2, Three = 3, Four = 4 }

private static FormattableLines Lines(IndentationLevel level, string lines) =>
Lines(level, from line in SourceText.From(lines).Lines
select line.ToString());

private static FormattableLines Lines(IndentationLevel level, IEnumerable<string> lines) =>
new([..lines], Indents[(int)level], string.Empty);

private sealed class FormattableLines(ImmutableArray<string> lines,
string? prefix = null,
string? emptyPrefix = null) :
IFormattable
{
string IFormattable.ToString(string? format, IFormatProvider? formatProvider)
{
if (lines.Length == 0)
return string.Empty;

var writer = new StringWriter();
var lastIndex = lines.Length - 1;
for (var i = 0; i < lines.Length; i++)
{
var line = lines[i];

if (line.Length > 0)
{
if (prefix is { } somePrefix)
writer.Write(somePrefix);
}
else if (emptyPrefix is { } someEmptyPrefix)
{
writer.Write(someEmptyPrefix);
}

if (i < lastIndex)
writer.WriteLine(line);
else
writer.Write(line);
}

return writer.ToString();
}

public override string ToString()
{
IFormattable formattable = this;
return formattable.ToString(format: null, formatProvider: null);
}
}
}
7 changes: 5 additions & 2 deletions src/CSnakes.SourceGeneration/Reflection/ModuleReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ namespace CSnakes.Reflection;

public static class ModuleReflection
{
public static IEnumerable<MethodDefinition> MethodsFromFunctionDefinitions(PythonFunctionDefinition[] functions, string moduleName)
public static IEnumerable<MethodDefinition> MethodsFromFunctionDefinitions(IEnumerable<PythonFunctionDefinition> functions, string moduleName)
{
return functions.Select(function => MethodReflection.FromMethod(function, moduleName));
}

public static string Compile(this IEnumerable<MethodDeclarationSyntax> methods)
{
using StringWriter sw = new();
foreach (var method in methods)
foreach (var (i, method) in methods.Select((m, i) => (i, m)))
{
if (i > 0)
sw.WriteLine();
method.NormalizeWhitespace().WriteTo(sw);
sw.WriteLine();
}
return sw.ToString();
}
Expand Down
5 changes: 5 additions & 0 deletions src/CSnakes.Tests/CSnakes.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
<EnableUnsafeBinaryFormatterSerialization>true</EnableUnsafeBinaryFormatterSerialization>
</PropertyGroup>

<ItemGroup>
<EmbeddedResource Include="..\Integration.Tests\python\*.py" Link="python\%(Filename)%(Extension)" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="coverlet.collector" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Shouldly" />
<PackageReference Include="System.Net.Http" />
<PackageReference Include="System.Text.RegularExpressions" />
<PackageReference Include="xunit" />
Expand Down
65 changes: 65 additions & 0 deletions src/CSnakes.Tests/PythonStaticGeneratorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using CSnakes.Parser;
using CSnakes.Reflection;
using Microsoft.CodeAnalysis.Text;
using Shouldly;
using System.Collections.Immutable;
using System.Reflection;
using System.Text.RegularExpressions;

namespace CSnakes.Tests;

public class PythonStaticGeneratorTests
{
private static Assembly Assembly => typeof(GeneratedSignatureTests).Assembly;

public static readonly TheoryData<string> ResourceNames =
new(from name in Assembly.GetManifestResourceNames()
where name.EndsWith(".py")
select name);

[Theory]
[MemberData(nameof(ResourceNames))]
public void FormatClassFromMethods(string resourceName)
{
SourceText sourceText;

using (var stream = Assembly.GetManifestResourceStream(resourceName))
{
Assert.NotNull(stream);
using var reader = new StreamReader(stream);
string normalizedText = Regex.Replace(reader.ReadToEnd(), @"\r?\n", "\n");
sourceText = SourceText.From(normalizedText);
}

_ = PythonParser.TryParseFunctionDefinitions(sourceText, out var functions, out var errors);
Assert.Empty(errors);

var module = ModuleReflection.MethodsFromFunctionDefinitions(functions, "test").ToImmutableArray();
string compiledCode = PythonStaticGenerator.FormatClassFromMethods("Python.Generated.Tests", "TestClass", module, "test", functions, sourceText.GetContentHash());

try
{
compiledCode.ShouldMatchApproved(options =>
options.WithDiscriminator(// Just keep last part of the dotted name, e.g.:
// "CSnakes.Tests.python.test_args.py" -> "test_args"
Path.GetFileNameWithoutExtension(resourceName).Split('.').Last())
.SubFolder(GetType().Name)
.WithFilenameGenerator((info, d, type, ext) => $"{info.MethodName}{d}.{type}.{ext}")
.NoDiff());
}
catch (FileNotFoundException ex) when (ex.FileName is { } fn
&& fn.Contains(".received.", StringComparison.OrdinalIgnoreCase))
{
// `ShouldMatchApproved` deletes the received file when the condition is met:
// https://github.com/shouldly/shouldly/blob/4.2.1/src/Shouldly/ShouldlyExtensionMethods/ShouldMatchApprovedTestExtensions.cs#L70
//
// `File.Delete` is documented to never throw an exception if the file doesn't exist:
//
// > If the file to be deleted does not exist, no exception is thrown. Source:
// > https://learn.microsoft.com/en-us/dotnet/api/system.io.file.delete?view=net-8.0#remarks
//
// However, `FileNotFoundException` has been observed on some platforms during CI runs
// so we catch it and should be harmless to ignore.
}
}
}
Loading

0 comments on commit 93a5434

Please sign in to comment.