@@ -25,10 +25,9 @@ static void Generate(TypeDeclarationSyntax syntax, Compilation compilation, stri
2525 return ;
2626 }
2727
28- // nested is not allowed
29- if ( IsNested ( syntax ) )
28+ if ( IsNested ( syntax ) && ! IsNestedContainingTypesPartial ( syntax ) )
3029 {
31- context . ReportDiagnostic ( Diagnostic . Create ( DiagnosticDescriptors . NestedNotAllow , syntax . Identifier . GetLocation ( ) , typeSymbol . Name ) ) ;
30+ context . ReportDiagnostic ( Diagnostic . Create ( DiagnosticDescriptors . NestedContainingTypesMustBePartial , syntax . Identifier . GetLocation ( ) , typeSymbol . Name ) ) ;
3231 return ;
3332 }
3433
@@ -157,6 +156,21 @@ static bool IsPartial(TypeDeclarationSyntax typeDeclaration)
157156 return typeDeclaration . Modifiers . Any ( m => m . IsKind ( SyntaxKind . PartialKeyword ) ) ;
158157 }
159158
159+ static bool IsNestedContainingTypesPartial ( TypeDeclarationSyntax typeDeclaration )
160+ {
161+ if ( typeDeclaration . Parent is TypeDeclarationSyntax parentTypeDeclaration )
162+ {
163+ if ( ! IsPartial ( parentTypeDeclaration ) )
164+ return false ;
165+
166+ return IsNestedContainingTypesPartial ( parentTypeDeclaration ) ;
167+ }
168+ else
169+ {
170+ return true ;
171+ }
172+ }
173+
160174 static bool IsNested ( TypeDeclarationSyntax typeDeclaration )
161175 {
162176 return typeDeclaration . Parent is TypeDeclarationSyntax ;
@@ -296,6 +310,21 @@ public void Emit(StringBuilder writer, IGeneratorContext context)
296310 ( false , false ) => "class" ,
297311 } ;
298312
313+ var containingTypeDeclarations = new List < string > ( ) ;
314+ var containingType = Symbol . ContainingType ;
315+ while ( containingType is not null )
316+ {
317+ containingTypeDeclarations . Add ( ( containingType . IsRecord , containingType . IsValueType ) switch
318+ {
319+ ( true , true ) => $ "partial record struct { containingType . Name } ",
320+ ( true , false ) => $ "partial record { containingType . Name } ",
321+ ( false , true ) => $ "partial struct { containingType . Name } ",
322+ ( false , false ) => $ "partial class { containingType . Name } ",
323+ } ) ;
324+ containingType = containingType . ContainingType ;
325+ }
326+ containingTypeDeclarations . Reverse ( ) ;
327+
299328 var nullable = IsValueType ? "" : "?" ;
300329
301330 string staticRegisterFormatterMethod , staticMemoryPackableMethod , scopedRef , constraint , registerBody , registerT ;
@@ -345,6 +374,12 @@ public void Emit(StringBuilder writer, IGeneratorContext context)
345374 ? "Serialize(ref MemoryPackWriter"
346375 : "Serialize<TBufferWriter>(ref MemoryPackWriter<TBufferWriter>" ;
347376
377+ foreach ( var declaration in containingTypeDeclarations )
378+ {
379+ writer . AppendLine ( declaration ) ;
380+ writer . AppendLine ( "{" ) ;
381+ }
382+
348383 writer . AppendLine ( $$ """
349384partial {{ classOrStructOrRecord }} {{ TypeName }} : IMemoryPackable<{{ TypeName }} >{{ fixedSizeInterface }}
350385{
@@ -420,6 +455,10 @@ public override void Deserialize(ref MemoryPackReader reader, {{scopedRef}} {{Ty
420455 writer . AppendLine ( code ) ;
421456 }
422457
458+ for ( int i = 0 ; i < containingTypeDeclarations . Count ; ++ i )
459+ {
460+ writer . AppendLine ( "}" ) ;
461+ }
423462 }
424463
425464 private string EmitDeserializeBody ( )
0 commit comments