diff --git a/src/libraries/System.Reflection.Emit/src/System/Reflection/Emit/SignatureHelper.cs b/src/libraries/System.Reflection.Emit/src/System/Reflection/Emit/SignatureHelper.cs index 625704f781faf0..0cacbcc8c3f1a7 100644 --- a/src/libraries/System.Reflection.Emit/src/System/Reflection/Emit/SignatureHelper.cs +++ b/src/libraries/System.Reflection.Emit/src/System/Reflection/Emit/SignatureHelper.cs @@ -223,12 +223,76 @@ private static void WriteSignatureForType(SignatureTypeEncoder signature, Type t { signature.GenericTypeParameter(type.GenericParameterPosition); } + else if (type.IsFunctionPointer) + { + WriteSignatureForFunctionPointerType(signature, type, module); + } else { WriteSimpleSignature(signature, type, module); } } + private static void WriteSignatureForFunctionPointerType(SignatureTypeEncoder signature, Type type, ModuleBuilderImpl module) + { + SignatureCallingConvention callConv = SignatureCallingConvention.Default; + FunctionPointerAttributes attribs = FunctionPointerAttributes.None; + + Type returnType = type.GetFunctionPointerReturnType(); + Type[] paramTypes = type.GetFunctionPointerParameterTypes(); + + if (type.IsUnmanagedFunctionPointer) + { + callConv = SignatureCallingConvention.Unmanaged; + + if (type.GetFunctionPointerCallingConventions() is Type[] conventions && conventions.Length == 1) + { + switch (conventions[0].FullName) + { + case "System.Runtime.CompilerServices.CallConvCdecl": + callConv = SignatureCallingConvention.CDecl; + break; + case "System.Runtime.CompilerServices.CallConvStdcall": + callConv = SignatureCallingConvention.StdCall; + break; + case "System.Runtime.CompilerServices.CallConvThiscall": + callConv = SignatureCallingConvention.ThisCall; + break; + case "System.Runtime.CompilerServices.CallConvFastcall": + callConv = SignatureCallingConvention.FastCall; + break; + } + } + } + + MethodSignatureEncoder sigEncoder = signature.FunctionPointer(callConv, attribs); + sigEncoder.Parameters(paramTypes.Length, out ReturnTypeEncoder retTypeEncoder, out ParametersEncoder paramsEncoder); + + CustomModifiersEncoder retModifiersEncoder = retTypeEncoder.CustomModifiers(); + + if (returnType.GetOptionalCustomModifiers() is Type[] retModOpts) + WriteCustomModifiers(retModifiersEncoder, retModOpts, isOptional: true, module); + + if (returnType.GetRequiredCustomModifiers() is Type[] retModReqs) + WriteCustomModifiers(retModifiersEncoder, retModReqs, isOptional: false, module); + + WriteSignatureForType(retTypeEncoder.Type(), returnType, module); + + foreach (Type paramType in paramTypes) + { + ParameterTypeEncoder paramEncoder = paramsEncoder.AddParameter(); + CustomModifiersEncoder paramModifiersEncoder = paramEncoder.CustomModifiers(); + + if (paramType.GetOptionalCustomModifiers() is Type[] paramModOpts) + WriteCustomModifiers(paramModifiersEncoder, paramModOpts, isOptional: true, module); + + if (paramType.GetRequiredCustomModifiers() is Type[] paramModReqs) + WriteCustomModifiers(paramModifiersEncoder, paramModReqs, isOptional: false, module); + + WriteSignatureForType(paramEncoder.Type(), paramType, module); + } + } + private static void WriteSimpleSignature(SignatureTypeEncoder signature, Type type, ModuleBuilderImpl module) { CoreTypeId? typeId = module.GetTypeIdFromCoreTypes(type); diff --git a/src/libraries/System.Reflection.Emit/tests/PersistedAssemblyBuilder/AssemblySaveTypeBuilderTests.cs b/src/libraries/System.Reflection.Emit/tests/PersistedAssemblyBuilder/AssemblySaveTypeBuilderTests.cs index 1894d8d5f35b86..4645dae1b7a92c 100644 --- a/src/libraries/System.Reflection.Emit/tests/PersistedAssemblyBuilder/AssemblySaveTypeBuilderTests.cs +++ b/src/libraries/System.Reflection.Emit/tests/PersistedAssemblyBuilder/AssemblySaveTypeBuilderTests.cs @@ -5,8 +5,11 @@ using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; +using System.Numerics; using System.Reflection.Metadata; using System.Reflection.PortableExecutable; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using Xunit; namespace System.Reflection.Emit.Tests @@ -789,6 +792,155 @@ public void CreateGenericTypeFromMetadataLoadContextSignatureTypes() Assert.Equal("ValueTypeChildren", fields[1].Name); Assert.True(fields[1].FieldType.GetGenericArguments()[0].IsValueType); } + + [Fact] + public void SaveFunctionPointerFields() + { + using TempFile file = TempFile.Create(); + using MetadataLoadContext mlc = new MetadataLoadContext(new CoreMetadataAssemblyResolver()); + + PersistedAssemblyBuilder ab = AssemblySaveTools.PopulateAssemblyAndModule(out ModuleBuilder mb); + TypeBuilder tb = mb.DefineType("TestType", TypeAttributes.Public | TypeAttributes.Class); + + // delegate* + Type funcPtr1 = typeof(delegate*); + tb.DefineField("FuncPtr1", funcPtr1, FieldAttributes.Public | FieldAttributes.Static); + + // delegate* unmanaged[Cdecl] + Type funcPtr4 = new ModifiedTypeHelpers.FunctionPointer( + typeof(delegate* unmanaged[Cdecl]), + [typeof(CallConvCdecl)]); + tb.DefineField("FuncPtr2", funcPtr4, FieldAttributes.Public | FieldAttributes.Static); + + // delegate* unmanaged[Stdcall] + Type funcPtr5 = new ModifiedTypeHelpers.FunctionPointer( + typeof(delegate* unmanaged[Stdcall]), + [typeof(CallConvStdcall)], + customParameterTypes: [typeof(string), new ModifiedTypeHelpers.ModifiedType(typeof(int).MakeByRefType(), [typeof(InAttribute)], [])]); + tb.DefineField("FuncPtr3", funcPtr5, FieldAttributes.Public | FieldAttributes.Static); + + tb.CreateType(); + ab.Save(file.Path); + + Assembly assemblyFromDisk = mlc.LoadFromAssemblyPath(file.Path); + Type testType = assemblyFromDisk.Modules.First().GetType("TestType"); + Assert.NotNull(testType); + + FieldInfo field1 = testType.GetField("FuncPtr1"); + Assert.NotNull(field1); + Assert.True(field1.FieldType.IsFunctionPointer); + Assert.False(field1.FieldType.IsUnmanagedFunctionPointer); + Type[] paramTypes1 = field1.FieldType.GetFunctionPointerParameterTypes(); + Assert.Equal(1, paramTypes1.Length); + Assert.Equal(typeof(int).FullName, paramTypes1[0].FullName); + Assert.Equal(typeof(int).FullName, field1.FieldType.GetFunctionPointerReturnType().FullName); + + FieldInfo field2 = testType.GetField("FuncPtr2"); + Type field2Type = field2.GetModifiedFieldType(); + Assert.NotNull(field2); + Assert.True(field2Type.IsFunctionPointer); + Assert.True(field2Type.IsUnmanagedFunctionPointer); + Type[] paramTypes2 = field2Type.GetFunctionPointerParameterTypes(); + Assert.Equal(2, paramTypes2.Length); + Assert.Equal(typeof(int).FullName, paramTypes2[0].FullName); + Assert.Equal(typeof(float).FullName, paramTypes2[1].FullName); + Assert.Equal(typeof(double).FullName, field2Type.GetFunctionPointerReturnType().FullName); + Type[] callingConventions2 = field2Type.GetFunctionPointerCallingConventions(); + Assert.Contains(callingConventions2, t => t.FullName == typeof(CallConvCdecl).FullName); + + FieldInfo field3 = testType.GetField("FuncPtr3"); + Type field3Type = field3.GetModifiedFieldType(); + Assert.NotNull(field3); + Assert.True(field3Type.IsFunctionPointer); + Assert.True(field3Type.IsUnmanagedFunctionPointer); + Type[] paramTypes3 = field3Type.GetFunctionPointerParameterTypes(); + Assert.Equal(2, paramTypes3.Length); + Assert.Equal(typeof(string).FullName, paramTypes3[0].FullName); + Assert.Equal(typeof(int).MakeByRefType().FullName, paramTypes3[1].FullName); + Assert.Contains(paramTypes3[1].GetRequiredCustomModifiers(), t => t.FullName == typeof(InAttribute).FullName); + Assert.Equal(typeof(void).FullName, field3Type.GetFunctionPointerReturnType().FullName); + Type[] callingConventions3 = field3Type.GetFunctionPointerCallingConventions(); + Assert.Contains(callingConventions3, t => t.FullName == typeof(CallConvStdcall).FullName); + } + + [Fact] + public void ConsumeFunctionPointerFields() + { + // public unsafe class Container + // { + // public static delegate* Method; + // + // public static int Add(int a, int b) => a + b; + // public static void Init() => Method = &Add; + // } + + TempFile assembly1Path = TempFile.Create(); + PersistedAssemblyBuilder assembly1 = new(new AssemblyName("Assembly1"), typeof(object).Assembly); + ModuleBuilder mod1 = assembly1.DefineDynamicModule("Module1"); + TypeBuilder containerType = mod1.DefineType("Container", TypeAttributes.Public | TypeAttributes.Class); + FieldBuilder methodField = containerType.DefineField("Method", typeof(delegate*), FieldAttributes.Public | FieldAttributes.Static); + MethodBuilder addMethod = containerType.DefineMethod("Add", MethodAttributes.Public | MethodAttributes.Static); + addMethod.SetParameters(typeof(int), typeof(int)); + addMethod.SetReturnType(typeof(int)); + ILGenerator addMethodIL = addMethod.GetILGenerator(); + addMethodIL.Emit(OpCodes.Ldarg_0); + addMethodIL.Emit(OpCodes.Ldarg_1); + addMethodIL.Emit(OpCodes.Add); + addMethodIL.Emit(OpCodes.Ret); + MethodBuilder initMethod = containerType.DefineMethod("Init", MethodAttributes.Public | MethodAttributes.Static); + initMethod.SetReturnType(typeof(void)); + ILGenerator initMethodIL = initMethod.GetILGenerator(); + initMethodIL.Emit(OpCodes.Ldftn, addMethod); + initMethodIL.Emit(OpCodes.Stsfld, methodField); + initMethodIL.Emit(OpCodes.Ret); + containerType.CreateType(); + assembly1.Save(assembly1Path.Path); + + // class Program + // { + // public static int Main() + // { + // Container.Init(); + // return Container.Method(2, 3); + // } + // } + + TestAssemblyLoadContext context = new(); + + TempFile assembly2Path = TempFile.Create(); + Assembly assembly1FromDisk = context.LoadFromAssemblyPath(assembly1Path.Path); + PersistedAssemblyBuilder assembly2 = new(new AssemblyName("Assembly2"), typeof(object).Assembly); + ModuleBuilder mod2 = assembly2.DefineDynamicModule("Module2"); + TypeBuilder programType = mod2.DefineType("Program"); + MethodBuilder mainMethod = programType.DefineMethod("Main", MethodAttributes.Public | MethodAttributes.Static); + mainMethod.SetReturnType(typeof(int)); + ILGenerator il = mainMethod.GetILGenerator(); + il.Emit(OpCodes.Ldsfld, typeof(ClassWithFunctionPointerFields).GetField("field1")); + il.Emit(OpCodes.Pop); + // References to fields with unmanaged calling convention are broken + // [ActiveIssue("https://github.com/dotnet/runtime/issues/120909")] + // il.Emit(OpCodes.Ldsfld, typeof(ClassWithFunctionPointerFields).GetField("field2")); + // il.Emit(OpCodes.Pop); + // il.Emit(OpCodes.Ldsfld, typeof(ClassWithFunctionPointerFields).GetField("field3")); + // il.Emit(OpCodes.Pop); + // il.Emit(OpCodes.Ldsfld, typeof(ClassWithFunctionPointerFields).GetField("field4")); + // il.Emit(OpCodes.Pop); + il.Emit(OpCodes.Call, assembly1FromDisk.GetType("Container").GetMethod("Init")); + il.Emit(OpCodes.Ldc_I4_2); + il.Emit(OpCodes.Ldc_I4_3); + il.Emit(OpCodes.Ldsfld, assembly1FromDisk.GetType("Container").GetField("Method")); + il.EmitCalli(OpCodes.Calli, CallingConventions.Standard, typeof(int), [typeof(int), typeof(int)], null); + il.Emit(OpCodes.Ret); + programType.CreateType(); + assembly2.Save(assembly2Path.Path); + + Assembly assembly2FromDisk = context.LoadFromAssemblyPath(assembly2Path.Path); + int result = (int)assembly2FromDisk.GetType("Program").GetMethod("Main").Invoke(null, null); + Assert.Equal(5, result); + + assembly1Path.Dispose(); + assembly2Path.Dispose(); + } } // Test Types @@ -836,4 +988,12 @@ public class ClassWithFields : EmptyTestClass public EmptyTestClass field1; public byte field2; } + + public unsafe class ClassWithFunctionPointerFields + { + public static delegate* field1; + public static delegate* unmanaged field2; + public static delegate* unmanaged[Cdecl] field3; + public static delegate* unmanaged[Cdecl, SuppressGCTransition], Vector> field4; + } } diff --git a/src/libraries/System.Reflection.Emit/tests/Utilities.cs b/src/libraries/System.Reflection.Emit/tests/Utilities.cs index 4baf34722a49d9..446b33f8628aa9 100644 --- a/src/libraries/System.Reflection.Emit/tests/Utilities.cs +++ b/src/libraries/System.Reflection.Emit/tests/Utilities.cs @@ -164,4 +164,54 @@ public static string GetFullName(string name) return name; } } + + public static class ModifiedTypeHelpers + { + public class FunctionPointer : TypeDelegator + { + private readonly Type[] callingConventions; + private readonly Type returnType; + private readonly Type[] parameterTypes; + private readonly Type[] requiredModifiers; + private readonly Type[] optionalModifiers; + + public FunctionPointer( + Type baseFunctionPointerType, + Type[] conventions = null, + Type customReturnType = null, + Type[] customParameterTypes = null, + Type[] fnPtrRequiredMods = null, + Type[] fnPtrOptionalMods = null) + : base(baseFunctionPointerType) + { + callingConventions = conventions ?? []; + returnType = customReturnType ?? baseFunctionPointerType.GetFunctionPointerReturnType(); + parameterTypes = customParameterTypes ?? baseFunctionPointerType.GetFunctionPointerParameterTypes(); + requiredModifiers = fnPtrRequiredMods ?? []; + optionalModifiers = fnPtrOptionalMods ?? []; + } + + public override Type[] GetFunctionPointerCallingConventions() => callingConventions; + public override Type GetFunctionPointerReturnType() => returnType; + public override Type[] GetFunctionPointerParameterTypes() => parameterTypes; + public override Type[] GetRequiredCustomModifiers() => requiredModifiers; + public override Type[] GetOptionalCustomModifiers() => optionalModifiers; + } + + public class ModifiedType : TypeDelegator + { + private readonly Type[] requiredModifiers; + private readonly Type[] optionalModifiers; + + public ModifiedType(Type delegatingType, Type[] requiredMods = null, Type[] optionalMods = null) + : base(delegatingType) + { + requiredModifiers = requiredMods ?? []; + optionalModifiers = optionalMods ?? []; + } + + public override Type[] GetRequiredCustomModifiers() => requiredModifiers; + public override Type[] GetOptionalCustomModifiers() => optionalModifiers; + } + } }