diff --git a/src/libraries/System.Configuration.ConfigurationManager/tests/System.Configuration.ConfigurationManager.Tests.csproj b/src/libraries/System.Configuration.ConfigurationManager/tests/System.Configuration.ConfigurationManager.Tests.csproj index 0a6b23dad7e658..6880adcbf6c2ff 100644 --- a/src/libraries/System.Configuration.ConfigurationManager/tests/System.Configuration.ConfigurationManager.Tests.csproj +++ b/src/libraries/System.Configuration.ConfigurationManager/tests/System.Configuration.ConfigurationManager.Tests.csproj @@ -59,7 +59,7 @@ - + diff --git a/src/libraries/System.Configuration.ConfigurationManager/tests/System/Configuration/CustomHostTests.cs b/src/libraries/System.Configuration.ConfigurationManager/tests/System/Configuration/CustomHostTests.cs index e8f69c198cf82d..3d32a020502a92 100644 --- a/src/libraries/System.Configuration.ConfigurationManager/tests/System/Configuration/CustomHostTests.cs +++ b/src/libraries/System.Configuration.ConfigurationManager/tests/System/Configuration/CustomHostTests.cs @@ -19,7 +19,8 @@ public void FilePathIsPopulatedCorrectly() { RemoteExecutor.Invoke(() => { - MakeAssemblyGetEntryAssemblyReturnNull(); + Assembly.SetEntryAssembly(null); + Assert.Null(Assembly.GetEntryAssembly()); string expectedFilePathEnding = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "dotnet.exe.config" : @@ -29,17 +30,5 @@ public void FilePathIsPopulatedCorrectly() Assert.EndsWith(expectedFilePathEnding, config.FilePath); }).Dispose(); } - - /// - /// Makes Assembly.GetEntryAssembly() return null using private reflection. - /// - private static void MakeAssemblyGetEntryAssemblyReturnNull() - { - typeof(Assembly) - .GetField("s_forceNullEntryPoint", BindingFlags.NonPublic | BindingFlags.Static) - .SetValue(null, true); - - Assert.Null(Assembly.GetEntryAssembly()); - } } } diff --git a/src/libraries/System.Diagnostics.TraceSource/tests/System.Diagnostics.TraceSource.Tests/DefaultTraceListenerClassTests.cs b/src/libraries/System.Diagnostics.TraceSource/tests/System.Diagnostics.TraceSource.Tests/DefaultTraceListenerClassTests.cs index efbb26a53a7a9e..ef84bec4fd2121 100644 --- a/src/libraries/System.Diagnostics.TraceSource/tests/System.Diagnostics.TraceSource.Tests/DefaultTraceListenerClassTests.cs +++ b/src/libraries/System.Diagnostics.TraceSource/tests/System.Diagnostics.TraceSource.Tests/DefaultTraceListenerClassTests.cs @@ -159,7 +159,8 @@ public void EntryAssemblyName_Null_NotIncludedInTrace() { RemoteExecutor.Invoke(() => { - MakeAssemblyGetEntryAssemblyReturnNull(); + Assembly.SetEntryAssembly(null); + Assert.Null(Assembly.GetEntryAssembly()); var listener = new TestDefaultTraceListener(); Trace.Listeners.Add(listener); @@ -167,17 +168,5 @@ public void EntryAssemblyName_Null_NotIncludedInTrace() Assert.Equal("Error: 0 : hello world", listener.Output.Trim()); }).Dispose(); } - - /// - /// Makes Assembly.GetEntryAssembly() return null using private reflection. - /// - private static void MakeAssemblyGetEntryAssemblyReturnNull() - { - typeof(Assembly) - .GetField("s_forceNullEntryPoint", BindingFlags.NonPublic | BindingFlags.Static) - .SetValue(null, true); - - Assert.Null(Assembly.GetEntryAssembly()); - } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Reflection/Assembly.cs b/src/libraries/System.Private.CoreLib/src/System/Reflection/Assembly.cs index d467fd6fb08242..c11d07f8e0b2cd 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Reflection/Assembly.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Reflection/Assembly.cs @@ -217,13 +217,36 @@ public override string ToString() return type.Module?.Assembly; } - // internal test hook - private static bool s_forceNullEntryPoint; + private static object? s_overriddenEntryAssembly; + + /// + /// Sets the application's entry assembly to the provided assembly object. + /// + /// + /// Assembly object that represents the application's new entry assembly. + /// + /// + /// The assembly passed to this function has to be a runtime defined Assembly + /// type object. Otherwise, an exception will be thrown. + /// + public static void SetEntryAssembly(Assembly? assembly) + { + if (assembly is null) + { + s_overriddenEntryAssembly = string.Empty; + return; + } + + if (assembly is not RuntimeAssembly) + throw new ArgumentException(SR.Argument_MustBeRuntimeAssembly); + + s_overriddenEntryAssembly = assembly; + } public static Assembly? GetEntryAssembly() { - if (s_forceNullEntryPoint) - return null; + if (s_overriddenEntryAssembly is not null) + return s_overriddenEntryAssembly as Assembly; return GetEntryAssemblyInternal(); } diff --git a/src/libraries/System.Runtime/ref/System.Runtime.cs b/src/libraries/System.Runtime/ref/System.Runtime.cs index fcfa67c001d8f5..eb2a12b198b716 100644 --- a/src/libraries/System.Runtime/ref/System.Runtime.cs +++ b/src/libraries/System.Runtime/ref/System.Runtime.cs @@ -11157,6 +11157,7 @@ public virtual void GetObjectData(System.Runtime.Serialization.SerializationInfo [System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("Types and members the loaded assembly depends on might be removed")] [System.ObsoleteAttribute("ReflectionOnly loading is not supported and throws PlatformNotSupportedException.", DiagnosticId="SYSLIB0018", UrlFormat="https://aka.ms/dotnet-warnings/{0}")] public static System.Reflection.Assembly ReflectionOnlyLoadFrom(string assemblyFile) { throw null; } + public static void SetEntryAssembly(System.Reflection.Assembly? assembly) { throw null; } public override string ToString() { throw null; } [System.Diagnostics.CodeAnalysis.RequiresUnreferencedCodeAttribute("Types and members the loaded assembly depends on might be removed")] public static System.Reflection.Assembly UnsafeLoadFrom(string assemblyFile) { throw null; } diff --git a/src/libraries/System.Runtime/tests/System.Reflection.Tests/AssemblyTests.cs b/src/libraries/System.Runtime/tests/System.Reflection.Tests/AssemblyTests.cs index 600a4733b40c9c..2134b65251df40 100644 --- a/src/libraries/System.Runtime/tests/System.Reflection.Tests/AssemblyTests.cs +++ b/src/libraries/System.Runtime/tests/System.Reflection.Tests/AssemblyTests.cs @@ -7,10 +7,12 @@ using System.Globalization; using System.IO; using System.Linq; +using System.Reflection.Emit; using System.Reflection.Tests; using System.Runtime.CompilerServices; using System.Security; using System.Text; +using Microsoft.DotNet.RemoteExecutor; using Xunit; [assembly: @@ -181,6 +183,32 @@ public void GetEntryAssembly() Assert.True(correct, $"Unexpected assembly name {assembly}"); } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void SetEntryAssembly() + { + Assert.NotNull(Assembly.GetEntryAssembly()); + + RemoteExecutor.Invoke(() => + { + Assembly.SetEntryAssembly(null); + Assert.Null(Assembly.GetEntryAssembly()); + + Assembly testAssembly = typeof(AssemblyTests).Assembly; + + Assembly.SetEntryAssembly(testAssembly); + Assert.Equal(Assembly.GetEntryAssembly(), testAssembly); + + var invalidAssembly = new PersistedAssemblyBuilder( + new AssemblyName("NotaRuntimeAssemblyTest"), + typeof(object).Assembly + ); + + Assert.Throws( + () => Assembly.SetEntryAssembly(invalidAssembly) + ); + }).Dispose(); + } + [Fact] public void GetFile() { diff --git a/src/libraries/System.Runtime/tests/System.Reflection.Tests/System.Reflection.Tests.csproj b/src/libraries/System.Runtime/tests/System.Reflection.Tests/System.Reflection.Tests.csproj index 9553a7dacc4828..9c380a90bfbd23 100644 --- a/src/libraries/System.Runtime/tests/System.Reflection.Tests/System.Reflection.Tests.csproj +++ b/src/libraries/System.Runtime/tests/System.Reflection.Tests/System.Reflection.Tests.csproj @@ -2,6 +2,7 @@ $(NetCoreAppCurrent) true + true true