From ba370c583f832046736f6fdabd20668d79e39d44 Mon Sep 17 00:00:00 2001 From: pixsperdavid Date: Mon, 14 Apr 2025 16:19:01 +0100 Subject: [PATCH] Add support for unions in nested types --- .../MemoryPackGenerator.Emitter.cs | 55 ++++++++++++++----- tests/MemoryPack.Tests/NestedUnionTest.cs | 39 +++++++++++++ 2 files changed, 79 insertions(+), 15 deletions(-) create mode 100644 tests/MemoryPack.Tests/NestedUnionTest.cs diff --git a/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs b/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs index d17735fa..15e24ba6 100644 --- a/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs +++ b/src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs @@ -310,20 +310,7 @@ public void Emit(StringBuilder writer, IGeneratorContext context) (false, false) => "class", }; - var containingTypeDeclarations = new List(); - var containingType = Symbol.ContainingType; - while (containingType is not null) - { - containingTypeDeclarations.Add((containingType.IsRecord, containingType.IsValueType) switch - { - (true, true) => $"partial record struct {containingType.Name}", - (true, false) => $"partial record {containingType.Name}", - (false, true) => $"partial struct {containingType.Name}", - (false, false) => $"partial class {containingType.Name}", - }); - containingType = containingType.ContainingType; - } - containingTypeDeclarations.Reverse(); + var containingTypeDeclarations = GetContainingTypeDeclarations(); var nullable = IsValueType ? "" : "?"; @@ -991,8 +978,17 @@ string EmitUnionTemplate(IGeneratorContext context) ? "Serialize(ref MemoryPackWriter" : "Serialize(ref MemoryPackWriter"; - var code = $$""" + var containingTypeDeclarations = GetContainingTypeDeclarations(); + + var containingTypesOpening = containingTypeDeclarations + .Select(d => $"{d}{Environment.NewLine}{{") + .NewLine(); + + var containingTypesClosing = Enumerable.Repeat("}", containingTypeDeclarations.Count) + .NewLine(); + var code = $$""" +{{containingTypesOpening}} partial {{classOrInterfaceOrRecord}} {{TypeName}} : IMemoryPackFormatterRegister { static partial void StaticConstructor(); @@ -1038,6 +1034,7 @@ public override void Deserialize(ref MemoryPackReader reader, {{scopedRef}} {{Ty } } } +{{containingTypesClosing}} """; return code; @@ -1268,6 +1265,34 @@ partial class {{TypeName}} : IMemoryPackFormatterRegister return code; } + + IReadOnlyList GetContainingTypeDeclarations() + { + var containingTypeDeclarations = new List(); + var containingType = Symbol.ContainingType; + while (containingType is not null) + { + if (containingType.TypeKind == TypeKind.Interface) + { + containingTypeDeclarations.Add($"partial interface {containingType.Name}"); + } + else + { + containingTypeDeclarations.Add((containingType.IsRecord, containingType.IsValueType) switch + { + (true, true) => $"partial record struct {containingType.Name}", + (true, false) => $"partial record {containingType.Name}", + (false, true) => $"partial struct {containingType.Name}", + (false, false) => $"partial class {containingType.Name}", + }); + } + + containingType = containingType.ContainingType; + } + containingTypeDeclarations.Reverse(); + + return containingTypeDeclarations; + } } public partial class MethodMeta diff --git a/tests/MemoryPack.Tests/NestedUnionTest.cs b/tests/MemoryPack.Tests/NestedUnionTest.cs new file mode 100644 index 00000000..9c1e35bb --- /dev/null +++ b/tests/MemoryPack.Tests/NestedUnionTest.cs @@ -0,0 +1,39 @@ +namespace MemoryPack.Tests; + +[MemoryPackable] +public partial record NestedUnion +{ + [MemoryPackable] + [MemoryPackUnion(0, typeof(NestedUnionA))] + [MemoryPackUnion(1, typeof(NestedUnionB))] + public partial interface INestedUnion + { + + } + + [MemoryPackable] + public partial record NestedUnionA(string Value) : INestedUnion; + + [MemoryPackable] + public partial record NestedUnionB(string Value) : INestedUnion; +} + +[MemoryPackable] +public partial record NestedUnionContainer +{ + public required NestedUnion.INestedUnion NestedUnion { get; init; } +} + +public class NestedUnionTest +{ + [Fact] + public void CanSerializeNestedUnion() + { + var data = new NestedUnionContainer { NestedUnion = new NestedUnion.NestedUnionB("Foo") }; + var bytes = MemoryPackSerializer.Serialize(data); + var result = MemoryPackSerializer.Deserialize(bytes); + result.Should().NotBeNull(); + result?.NestedUnion.Should().BeOfType(); + (result?.NestedUnion as NestedUnion.NestedUnionB)?.Value.Should().Be("Foo"); + } +}