Skip to content

Commit 83ce243

Browse files
committed
Add support for unions in nested types
1 parent fed6d9f commit 83ce243

File tree

2 files changed

+79
-15
lines changed

2 files changed

+79
-15
lines changed

src/MemoryPack.Generator/MemoryPackGenerator.Emitter.cs

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -310,20 +310,7 @@ public void Emit(StringBuilder writer, IGeneratorContext context)
310310
(false, false) => "class",
311311
};
312312

313-
var containingTypeDeclarations = new List<string>();
314-
var containingType = Symbol.ContainingType;
315-
while (containingType is not null)
316-
{
317-
containingTypeDeclarations.Add((containingType.IsRecord, containingType.IsValueType) switch
318-
{
319-
(true, true) => $"partial record struct {containingType.Name}",
320-
(true, false) => $"partial record {containingType.Name}",
321-
(false, true) => $"partial struct {containingType.Name}",
322-
(false, false) => $"partial class {containingType.Name}",
323-
});
324-
containingType = containingType.ContainingType;
325-
}
326-
containingTypeDeclarations.Reverse();
313+
var containingTypeDeclarations = GetContainingTypeDeclarations();
327314

328315
var nullable = IsValueType ? "" : "?";
329316

@@ -991,8 +978,17 @@ string EmitUnionTemplate(IGeneratorContext context)
991978
? "Serialize(ref MemoryPackWriter"
992979
: "Serialize<TBufferWriter>(ref MemoryPackWriter<TBufferWriter>";
993980

994-
var code = $$"""
981+
var containingTypeDeclarations = GetContainingTypeDeclarations();
982+
983+
var containingTypesOpening = containingTypeDeclarations
984+
.Select(d => $"{d}{Environment.NewLine}{{")
985+
.NewLine();
986+
987+
var containingTypesClosing = Enumerable.Repeat("}", containingTypeDeclarations.Count)
988+
.NewLine();
995989

990+
var code = $$"""
991+
{{containingTypesOpening}}
996992
partial {{classOrInterfaceOrRecord}} {{TypeName}} : IMemoryPackFormatterRegister
997993
{
998994
static partial void StaticConstructor();
@@ -1038,6 +1034,7 @@ public override void Deserialize(ref MemoryPackReader reader, {{scopedRef}} {{Ty
10381034
}
10391035
}
10401036
}
1037+
{{containingTypesClosing}}
10411038
""";
10421039

10431040
return code;
@@ -1268,6 +1265,34 @@ partial class {{TypeName}} : IMemoryPackFormatterRegister
12681265

12691266
return code;
12701267
}
1268+
1269+
IReadOnlyList<string> GetContainingTypeDeclarations()
1270+
{
1271+
var containingTypeDeclarations = new List<string>();
1272+
var containingType = Symbol.ContainingType;
1273+
while (containingType is not null)
1274+
{
1275+
if (containingType.TypeKind == TypeKind.Interface)
1276+
{
1277+
containingTypeDeclarations.Add($"partial interface {containingType.Name}");
1278+
}
1279+
else
1280+
{
1281+
containingTypeDeclarations.Add((containingType.IsRecord, containingType.IsValueType) switch
1282+
{
1283+
(true, true) => $"partial record struct {containingType.Name}",
1284+
(true, false) => $"partial record {containingType.Name}",
1285+
(false, true) => $"partial struct {containingType.Name}",
1286+
(false, false) => $"partial class {containingType.Name}",
1287+
});
1288+
}
1289+
1290+
containingType = containingType.ContainingType;
1291+
}
1292+
containingTypeDeclarations.Reverse();
1293+
1294+
return containingTypeDeclarations;
1295+
}
12711296
}
12721297

12731298
public partial class MethodMeta
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
namespace MemoryPack.Tests;
2+
3+
[MemoryPackable]
4+
public partial record NestedUnion
5+
{
6+
[MemoryPackable]
7+
[MemoryPackUnion(0, typeof(NestedUnionA))]
8+
[MemoryPackUnion(1, typeof(NestedUnionB))]
9+
public partial interface INestedUnion
10+
{
11+
12+
}
13+
14+
[MemoryPackable]
15+
public partial record NestedUnionA(string Value) : INestedUnion;
16+
17+
[MemoryPackable]
18+
public partial record NestedUnionB(string Value) : INestedUnion;
19+
}
20+
21+
[MemoryPackable]
22+
public partial record NestedUnionContainer
23+
{
24+
public required NestedUnion.INestedUnion NestedUnion { get; init; }
25+
}
26+
27+
public class NestedUnionTest
28+
{
29+
[Fact]
30+
public void CanSerializeNestedUnion()
31+
{
32+
var data = new NestedUnionContainer { NestedUnion = new NestedUnion.NestedUnionB("Foo") };
33+
var bytes = MemoryPackSerializer.Serialize(data);
34+
var result = MemoryPackSerializer.Deserialize<NestedUnionContainer>(bytes);
35+
result.Should().NotBeNull();
36+
result?.NestedUnion.Should().BeOfType<NestedUnion.NestedUnionB>();
37+
(result?.NestedUnion as NestedUnion.NestedUnionB)?.Value.Should().Be("Foo");
38+
}
39+
}

0 commit comments

Comments
 (0)