diff --git a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs index e8e97f2eb31973..b6bb5af4bc2fd0 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/UnsafeAccessors.cs @@ -318,6 +318,12 @@ private static bool DoesMethodMatchUnsafeAccessorDeclaration(ref GenerationConte return false; } + // Validate generic parameter. + if (declSig.GenericParameterCount != maybeSig.GenericParameterCount) + { + return false; + } + // Validate argument count and return type if (context.Kind == UnsafeAccessorKind.Constructor) { diff --git a/src/coreclr/vm/prestub.cpp b/src/coreclr/vm/prestub.cpp index 99086a462a26fb..0d7fb7f875453a 100644 --- a/src/coreclr/vm/prestub.cpp +++ b/src/coreclr/vm/prestub.cpp @@ -1190,14 +1190,26 @@ namespace return false; } - // Handle generic param count - DWORD declGenericCount = 0; - DWORD methodGenericCount = 0; + // Handle generic signature if (callConvDecl & IMAGE_CEE_CS_CALLCONV_GENERIC) + { + if (!(callConvMethod & IMAGE_CEE_CS_CALLCONV_GENERIC)) + return false; + + DWORD declGenericCount = 0; + DWORD methodGenericCount = 0; IfFailThrow(CorSigUncompressData_EndPtr(pSig1, pEndSig1, &declGenericCount)); - if (callConvMethod & IMAGE_CEE_CS_CALLCONV_GENERIC) IfFailThrow(CorSigUncompressData_EndPtr(pSig2, pEndSig2, &methodGenericCount)); + if (declGenericCount != methodGenericCount) + return false; + } + else if (callConvMethod & IMAGE_CEE_CS_CALLCONV_GENERIC) + { + // Method is generic but declaration is not + return false; + } + DWORD declArgCount; DWORD methodArgCount; IfFailThrow(CorSigUncompressData_EndPtr(pSig1, pEndSig1, &declArgCount)); @@ -3541,7 +3553,7 @@ static PCODE getHelperForStaticBase(Module * pModule, CORCOMPILE_FIXUP_BLOB_KIND bool threadStatic = (kind == ENCODE_THREAD_STATIC_BASE_NONGC_HELPER || kind == ENCODE_THREAD_STATIC_BASE_GC_HELPER); CorInfoHelpFunc helper; - + if (threadStatic) { if (GCStatic) diff --git a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs index 311550810224c8..1504a87af72850 100644 --- a/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs +++ b/src/tests/baseservices/compilerservices/UnsafeAccessors/UnsafeAccessorsTests.Generics.cs @@ -181,21 +181,76 @@ public static void Verify_Generic_AccessFieldClass() } } + class AmbiguousMethodName + { + private void M() { } + private void M() { } + private void N() { } + + private static void SM() { } + private static void SM() { } + private static void SN() { } + } + + static class AccessorsAmbiguousMethodName + { + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + public extern static void CallM(AmbiguousMethodName a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "M")] + public extern static void CallM(AmbiguousMethodName a); + + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "N")] + public extern static void CallN_MissingMethod(AmbiguousMethodName a); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + public extern static void CallSM(AmbiguousMethodName a); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SM")] + public extern static void CallSM(AmbiguousMethodName a); + + [UnsafeAccessor(UnsafeAccessorKind.StaticMethod, Name = "SN")] + public extern static void CallSN_MissingMethod(AmbiguousMethodName a); + } + + [Fact] + public static void Verify_Generic_AmbiguousMethodName() + { + Console.WriteLine($"Running {nameof(Verify_Generic_AmbiguousMethodName)}"); + + { + AmbiguousMethodName a = new(); + AccessorsAmbiguousMethodName.CallM(a); + AccessorsAmbiguousMethodName.CallM(a); + AccessorsAmbiguousMethodName.CallM(a); + AccessorsAmbiguousMethodName.CallM(a); + Assert.Throws(() => AccessorsAmbiguousMethodName.CallN_MissingMethod(a)); + } + + { + AccessorsAmbiguousMethodName.CallSM(null); + AccessorsAmbiguousMethodName.CallSM(null); + AccessorsAmbiguousMethodName.CallSM(null); + AccessorsAmbiguousMethodName.CallSM(null); + Assert.Throws(() => AccessorsAmbiguousMethodName.CallSN_MissingMethod(null)); + } + } + class Base { - protected virtual string CreateMessageGeneric(T t) => $"{nameof(Base)}:{t}"; + protected virtual string CreateMessage(T t) => $"{nameof(Base)}<>:{t}"; } - class GenericBase : Base + class GenericBase : Base { - protected virtual string CreateMessage(T t) => $"{nameof(GenericBase)}:{t}"; - protected override string CreateMessageGeneric(U u) => $"{nameof(GenericBase)}:{u}"; + protected virtual string CreateMessage(U u) => $"{nameof(GenericBase)}:{u}"; + protected override string CreateMessage(V v) => $"{nameof(GenericBase)}<>:{v}"; } sealed class Derived1 : GenericBase { protected override string CreateMessage(string u) => $"{nameof(Derived1)}:{u}"; - protected override string CreateMessageGeneric(U t) => $"{nameof(Derived1)}:{t}"; + protected override string CreateMessage(W w) => $"{nameof(Derived1)}<>:{w}"; } sealed class Derived2 : GenericBase @@ -209,33 +264,33 @@ public static void Verify_Generic_InheritanceMethodResolution() Console.WriteLine($"Running {nameof(Verify_Generic_InheritanceMethodResolution)}"); { Base a = new(); - Assert.Equal($"{nameof(Base)}:1", CreateMessage(a, 1)); - Assert.Equal($"{nameof(Base)}:{expect}", CreateMessage(a, expect)); - Assert.Equal($"{nameof(Base)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + Assert.Equal($"{nameof(Base)}<>:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Base)}<>:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Base)}<>:{nameof(Struct)}", CreateMessage(a, new Struct())); } { GenericBase a = new(); - Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); - Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); - Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + Assert.Equal($"{nameof(GenericBase)}<>:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}<>:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}<>:{nameof(Struct)}", CreateMessage(a, new Struct())); } { GenericBase a = new(); - Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); - Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); - Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + Assert.Equal($"{nameof(GenericBase)}<>:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}<>:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}<>:{nameof(Struct)}", CreateMessage(a, new Struct())); } { GenericBase a = new(); - Assert.Equal($"{nameof(GenericBase)}:1", CreateMessage(a, 1)); - Assert.Equal($"{nameof(GenericBase)}:{expect}", CreateMessage(a, expect)); - Assert.Equal($"{nameof(GenericBase)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + Assert.Equal($"{nameof(GenericBase)}<>:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(GenericBase)}<>:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(GenericBase)}<>:{nameof(Struct)}", CreateMessage(a, new Struct())); } { Derived1 a = new(); - Assert.Equal($"{nameof(Derived1)}:1", CreateMessage(a, 1)); - Assert.Equal($"{nameof(Derived1)}:{expect}", CreateMessage(a, expect)); - Assert.Equal($"{nameof(Derived1)}:{nameof(Struct)}", CreateMessage(a, new Struct())); + Assert.Equal($"{nameof(Derived1)}<>:1", CreateMessage(a, 1)); + Assert.Equal($"{nameof(Derived1)}<>:{expect}", CreateMessage(a, expect)); + Assert.Equal($"{nameof(Derived1)}<>:{nameof(Struct)}", CreateMessage(a, new Struct())); } { // Verify resolution of generic override logic. @@ -245,7 +300,7 @@ public static void Verify_Generic_InheritanceMethodResolution() Assert.Equal($"{nameof(GenericBase)}:{expect}", Accessors.CreateMessage(a2, expect)); } - [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessageGeneric")] + [UnsafeAccessor(UnsafeAccessorKind.Method, Name = "CreateMessage")] extern static string CreateMessage(Base b, W w); }