diff --git a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx index 487028d306a3fa..475d54061b4d0b 100644 --- a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx +++ b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx @@ -4358,4 +4358,7 @@ VariantWrappers cannot be stored in Variants. + + Creating an object wrapper for a COM instance with user state is not implemented. + diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index 5d1bfacb65637f..4b74be3c56bb17 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -961,6 +961,7 @@ + @@ -2774,7 +2775,7 @@ - + @@ -2849,4 +2850,4 @@ - + \ No newline at end of file diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs index 1d87d2b2e9dbba..c7c046732046c3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs @@ -25,6 +25,11 @@ public struct ComInterfaceEntry protected abstract object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags); + protected virtual object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags wrapperFlags) + { + throw new PlatformNotSupportedException(); + } + protected internal abstract void ReleaseObjects(IEnumerable objects); public static unsafe bool TryGetComInstance(object obj, out IntPtr unknown) @@ -58,6 +63,11 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb throw new PlatformNotSupportedException(); } + public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object? userState) + { + throw new PlatformNotSupportedException(); + } + public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object wrapper) { throw new PlatformNotSupportedException(); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 639f72b2d93416..636bff946229a7 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -901,7 +901,28 @@ private static nuint AlignUp(nuint value, nuint alignment) public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags) { object? obj; - if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, null, out obj)) + if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, wrapperMaybe: null, userState: NoUserState.Instance, out obj)) + throw new ArgumentNullException(nameof(externalComObject)); + + return obj; + } + + /// + /// Get the currently registered managed object or creates a new managed object and registers it. + /// + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// A state object to use to help create the wrapping .NET object. + /// Returns a managed object associated with the supplied external COM object. + /// + /// If a managed object was previously created for the specified + /// using this instance, the previously created object will be returned. + /// If not, a new one will be created. + /// + public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object? userState) + { + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, wrapperMaybe: null, userState, out obj)) throw new ArgumentNullException(nameof(externalComObject)); return obj; @@ -943,7 +964,7 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create ArgumentNullException.ThrowIfNull(wrapper); object? obj; - if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, inner, flags, wrapper, out obj)) + if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, inner, flags, wrapper, userState: NoUserState.Instance, out obj)) throw new ArgumentNullException(nameof(externalComObject)); return obj; @@ -1025,6 +1046,15 @@ private static void DetermineIdentityAndInner( } } + private sealed class NoUserState + { + public static readonly NoUserState Instance = new NoUserState(); + + private NoUserState() + { + } + } + /// /// Get the currently registered managed object or creates a new managed object and registers it. /// @@ -1032,13 +1062,15 @@ private static void DetermineIdentityAndInner( /// The inner instance if aggregation is involved /// Flags used to describe the external object. /// The to be used as the wrapper for the external object. - /// The managed object associated with the supplied external COM object or null if it could not be created. + /// A state object provided by the user for creating the object, otherwise . /// Returns true if a managed object could be retrieved/created, false otherwise + /// The managed object associated with the supplied external COM object or null if it could not be created. private unsafe bool TryGetOrCreateObjectForComInstanceInternal( IntPtr externalComObject, IntPtr innerMaybe, CreateObjectFlags flags, object? wrapperMaybe, + object? userState, [NotNullWhen(true)] out object? retValue) { if (externalComObject == IntPtr.Zero) @@ -1062,7 +1094,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( // and return. if (flags.HasFlag(CreateObjectFlags.UniqueInstance)) { - retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, ref referenceTrackerMaybe); + retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, userState, ref referenceTrackerMaybe); return retValue is not null; } @@ -1116,7 +1148,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( // If the user didn't provide a wrapper and couldn't unwrap a managed object wrapper, // create a new wrapper. - retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, ref referenceTrackerMaybe); + retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, userState, ref referenceTrackerMaybe); return retValue is not null; } finally @@ -1137,15 +1169,32 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( IntPtr identity, IntPtr inner, CreateObjectFlags flags, + object? userState, ref IntPtr referenceTrackerMaybe) { - object? retValue = CreateObject(identity, flags); + CreatedWrapperFlags wrapperFlags = CreatedWrapperFlags.None; + + object? retValue = userState is NoUserState + ? CreateObject(identity, flags) + : CreateObject(identity, flags, userState, out wrapperFlags); + if (retValue is null) { // If ComWrappers instance cannot create wrapper, we can do nothing here. return null; } + if (wrapperFlags.HasFlag(CreatedWrapperFlags.NonWrapping)) + { + return retValue; + } + + if (wrapperFlags.HasFlag(CreatedWrapperFlags.TrackerObject)) + { + // The user has determined after inspecting the COM object that it should have tracker support. + flags |= CreateObjectFlags.TrackerObject; + } + return RegisterObjectForComInstance(identity, inner, retValue, flags, ref referenceTrackerMaybe); } @@ -1391,6 +1440,23 @@ public static void RegisterForMarshalling(ComWrappers instance) /// protected abstract object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags); + /// + /// Create a managed object for the object pointed at by respecting the values of . + /// + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// User state provided by the call to . + /// Flags used to describe the created wrapper object. + /// Returns a managed object associated with the supplied external COM object. + /// + /// The default implementation throws . + /// If the object cannot be created and null is returned, the call to will throw a . + /// + protected virtual object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags wrapperFlags) + { + throw new NotImplementedException(SR.NotImplemented_CreateObjectWithUserState); + } + /// /// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime. /// diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs index 337e955a57e708..b0a50ba4d17261 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs @@ -6,7 +6,6 @@ namespace System.Runtime.InteropServices { - /// /// Enumeration of flags for . /// diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreatedWrapperFlags.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreatedWrapperFlags.cs new file mode 100644 index 00000000000000..a8c68a497a85e2 --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreatedWrapperFlags.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Runtime.InteropServices +{ + /// + /// Enumeration of flags for . + /// + [Flags] + public enum CreatedWrapperFlags + { + None = 0, + + /// + /// Indicate if the supplied external COM object implements the IReferenceTracker. + /// + TrackerObject = 1, + + /// + /// The managed object doesn't keep the native object alive. It represents an equivalent value. + /// + /// + /// Using this flag results in the following changes: + /// will return false for the returned object. + /// The features provided by the flag will be disabled. + /// Integration between and the returned object via the native IWeakReferenceSource interface will not work. + /// behavior is implied. + /// Diagnostics tooling support to unwrap objects returned by `CreateObject` will not see this object as a wrapper. + /// The same object can be returned from `CreateObject` wrapping different COM objects. + /// + NonWrapping = 0x2 + } +} diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index eb2130e4f80bc8..c49c5525d18a1a 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -478,6 +478,10 @@ public void FromManaged(object? managed) { } public void Free() { } } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("android")] + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("ios")] + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("tvos")] [System.CLSCompliantAttribute(false)] public partial class StrategyBasedComWrappers : System.Runtime.InteropServices.ComWrappers { @@ -488,6 +492,7 @@ public StrategyBasedComWrappers() { } protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownCacheStrategy CreateCacheStrategy() { throw null; } protected static System.Runtime.InteropServices.Marshalling.IIUnknownCacheStrategy CreateDefaultCacheStrategy() { throw null; } protected sealed override object CreateObject(nint externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags) { throw null; } + protected sealed override object? CreateObject(nint externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object? userState, out System.Runtime.InteropServices.CreatedWrapperFlags wrapperFlags) { throw null; } protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceDetailsStrategy GetOrCreateInterfaceDetailsStrategy() { throw null; } protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownStrategy GetOrCreateIUnknownStrategy() { throw null; } protected sealed override void ReleaseObjects(System.Collections.IEnumerable objects) { } @@ -762,16 +767,18 @@ public struct ComInterfaceDispatch public System.IntPtr Vtable; public unsafe static T GetInstance(ComInterfaceDispatch* dispatchPtr) where T : class { throw null; } } - public System.IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) { throw null; } - protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count); - public object GetOrCreateObjectForComInstance(System.IntPtr externalComObject, CreateObjectFlags flags) { throw null; } - protected abstract object? CreateObject(System.IntPtr externalComObject, CreateObjectFlags flags); - public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, CreateObjectFlags flags, object wrapper) { throw null; } - public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, CreateObjectFlags flags, object wrapper, System.IntPtr inner) { throw null; } + public System.IntPtr GetOrCreateComInterfaceForObject(object instance, System.Runtime.InteropServices.CreateComInterfaceFlags flags) { throw null; } + protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, System.Runtime.InteropServices.CreateComInterfaceFlags flags, out int count); + public object GetOrCreateObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags) { throw null; } + public object GetOrCreateObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object? userState) { throw null; } + protected abstract object? CreateObject(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags); + protected virtual object? CreateObject(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object? userState, out System.Runtime.InteropServices.CreatedWrapperFlags wrapperFlags) { throw null; } + public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object wrapper) { throw null; } + public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object wrapper, System.IntPtr inner) { throw null; } protected abstract void ReleaseObjects(System.Collections.IEnumerable objects); public static void RegisterForTrackerSupport(ComWrappers instance) { } [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")] - public static void RegisterForMarshalling(ComWrappers instance) { } + public static void RegisterForMarshalling(System.Runtime.InteropServices.ComWrappers instance) { } public static void GetIUnknownImpl(out System.IntPtr fpQueryInterface, out System.IntPtr fpAddRef, out System.IntPtr fpRelease) { throw null; } } [System.FlagsAttribute] @@ -790,6 +797,13 @@ public enum CreateObjectFlags Aggregation = 4, Unwrap = 8, } + [System.FlagsAttribute] + public enum CreatedWrapperFlags + { + None = 0, + TrackerObject = 1, + NonWrapping = 2 + } [System.CLSCompliantAttribute(false)] public readonly partial struct CULong : System.IEquatable { diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs index 9741b482633c24..74c147a5c2c21f 100644 --- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs +++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs @@ -5,12 +5,17 @@ using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Runtime.CompilerServices; +using System.Runtime.Versioning; namespace System.Runtime.InteropServices.Marshalling { /// /// A -based type that uses customizable strategy objects to implement COM object wrappers and managed object wrappers exposed to COM. /// + [UnsupportedOSPlatform("android")] + [UnsupportedOSPlatform("browser")] + [UnsupportedOSPlatform("ios")] + [UnsupportedOSPlatform("tvos")] [CLSCompliant(false)] public class StrategyBasedComWrappers : ComWrappers { @@ -84,7 +89,7 @@ static IIUnknownInterfaceDetailsStrategy GetInteropStrategy() return null; } - /// + /// protected sealed override unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags) { if (flags.HasFlag(CreateObjectFlags.TrackerObject) @@ -101,6 +106,12 @@ protected sealed override unsafe object CreateObject(nint externalComObject, Cre return rcw; } + /// + protected sealed override object? CreateObject(nint externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags wrapperFlags) + { + return base.CreateObject(externalComObject, flags, userState, out wrapperFlags); + } + /// protected sealed override void ReleaseObjects(IEnumerable objects) { diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index a87ba2e79a6906..2751f427838e91 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -18,6 +18,8 @@ namespace ComWrappersTests public class Program : IDisposable { + record class WrappedUserState(object? UserState); + class TestComWrappers : ComWrappers { private static IntPtr fpQueryInterface = default; @@ -80,6 +82,41 @@ static TestComWrappers() index++; } } + else if (obj is NotWrappedObject) + { + // Return a single vtable for the INotWrappedObject interface. + // Or two if the caller is requesting an IUnknown definition. + count = flags.HasFlag(CreateComInterfaceFlags.CallerDefinedIUnknown) ? 2 : 1; + entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NotWrappedObject), sizeof(ComInterfaceEntry) * count); + + var vtbl = new IUnknownVtbl() + { + QueryInterface = fpQueryInterface, + AddRef = fpAddRef, + Release = fpRelease + }; + + int index = 0; + + if (flags.HasFlag(CreateComInterfaceFlags.CallerDefinedIUnknown)) + { + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NotWrappedObject), sizeof(IUnknownVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + entryRaw[index].IID = IUnknownVtbl.IID_IUnknown; + entryRaw[index].Vtable = vtblRaw; + index++; + } + + { + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NotWrappedObject), sizeof(IUnknownVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + entryRaw[index].IID = typeof(INotWrappedObject).GUID; + entryRaw[index].Vtable = vtblRaw; + index++; + } + } return entryRaw; } @@ -103,6 +140,39 @@ protected override object CreateObject(IntPtr externalComObject, CreateObjectFla return null; } + public bool CalledUserStateOverload { get; set; } = false; + + public bool CallBaseCreateObject { get; set; } = false; + + protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags createdWrapperFlags) + { + CalledUserStateOverload = true; + + if (CallBaseCreateObject) + { + return base.CreateObject(externalComObject, flags, userState, out createdWrapperFlags); + } + + createdWrapperFlags = CreatedWrapperFlags.None; + + int hr = Marshal.QueryInterface(externalComObject, typeof(INotWrappedObject).GUID, out IntPtr iNotWrappedObject); + if (hr == 0) + { + // This is a non-wrapped object, return the user state as an object. + Marshal.Release(iNotWrappedObject); + createdWrapperFlags = CreatedWrapperFlags.NonWrapping; + return new WrappedUserState(userState); + } + + object result = CreateObject(externalComObject, flags); + if (result is ITrackerObjectWrapper trackerObj) + { + createdWrapperFlags = CreatedWrapperFlags.TrackerObject; + } + + return result; + } + public const int ReleaseObjectsCallAck = unchecked((int)-1); protected override void ReleaseObjects(IEnumerable objects) @@ -402,9 +472,9 @@ public void ValidateFallbackQueryInterface() Console.WriteLine($"Running {nameof(ValidateFallbackQueryInterface)}..."); var testObj = new Test() - { - EnableICustomQueryInterface = true - }; + { + EnableICustomQueryInterface = true + }; var wrappers = new TestComWrappers(); @@ -717,10 +787,10 @@ public enum FailureMode switch (ComputeVtablesMode) { case FailureMode.ReturnInvalid: - { - count = -1; - return null; - } + { + count = -1; + return null; + } case FailureMode.ThrowException: throw new Exception() { HResult = ExceptionErrorCode }; default: @@ -795,18 +865,13 @@ public void ValidateBadComWrapperImpl() Marshal.Release(trackerObjRaw); } - [Fact] - public void ValidateRuntimeTrackerScenario() + private void ValidateRuntimeTrackerScenarioCore(ComWrappers cw, Func createObjectFunc) { - Console.WriteLine($"Running {nameof(ValidateRuntimeTrackerScenario)}..."); - - var cw = new TestComWrappers(); - // Get an object from a tracker runtime. IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); // Create a managed wrapper for the native object. - var trackerObj = (ITrackerObjectWrapper)cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject); + var trackerObj = (ITrackerObjectWrapper)createObjectFunc(trackerObjRaw); // Ownership has been transferred to the wrapper. Marshal.Release(trackerObjRaw); @@ -843,6 +908,32 @@ public void ValidateRuntimeTrackerScenario() ForceGC(); } + [Fact] + public void ValidateRuntimeTrackerScenario() + { + Console.WriteLine($"Running {nameof(ValidateRuntimeTrackerScenario)}..."); + + var cw = new TestComWrappers(); + + ValidateRuntimeTrackerScenarioCore(cw, (trackerObjRaw) => + { + return cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject); + }); + } + + [Fact] + public void ValidateRuntimeTrackerScenarioUserStateOverload() + { + Console.WriteLine($"Running {nameof(ValidateRuntimeTrackerScenarioUserStateOverload)}..."); + + var cw = new TestComWrappers(); + + ValidateRuntimeTrackerScenarioCore(cw, (trackerObjRaw) => + { + return cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.None, userState: null); + }); + } + [Fact] public void ValidateQueryInterfaceAfterManagedObjectCollected() { @@ -1123,6 +1214,68 @@ CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out return CustomQueryInterfaceResult.NotHandled; } } + + [Fact] + public void UserStateOverloadNotCalledWhenNoUserStatePassed() + { + Console.WriteLine($"Running {nameof(UserStateOverloadNotCalledWhenNoUserStatePassed)}..."); + + var testObj = new Test(); + + var wrappers = new TestComWrappers(); + + // Allocate a wrapper for the object + IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None); + Assert.NotEqual(IntPtr.Zero, comWrapper); + + var testObjFromNative = (ITestObjectWrapper)wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.None); + + Assert.False(wrappers.CalledUserStateOverload); + + testObjFromNative.FinalRelease(); + } + + [Theory] + [InlineData(null)] + [InlineData(1)] + [InlineData("testString")] + public void UserStatePassedThrough(object? userState) + { + Console.WriteLine($"Running {nameof(UserStatePassedThrough)}..."); + + var testObj = new NotWrappedObject(); + + var wrappers = new TestComWrappers(); + + // Allocate a wrapper for the object + IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None); + Assert.NotEqual(IntPtr.Zero, comWrapper); + + var testObjFromNative = (WrappedUserState)wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.None, userState); + + Assert.True(wrappers.CalledUserStateOverload); + Assert.Same(userState, testObjFromNative.UserState); + + Assert.False(ComWrappers.TryGetComInstance(testObjFromNative, out _)); + } + + [Fact] + public void UserStateBaseImplementationThrows() + { + Console.WriteLine($"Running {nameof(UserStateBaseImplementationThrows)}..."); + + var testObj = new NotWrappedObject(); + + var wrappers = new TestComWrappers(); + + // Allocate a wrapper for the object + IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None); + Assert.NotEqual(IntPtr.Zero, comWrapper); + + wrappers.CallBaseCreateObject = true; + + Assert.Throws(() => wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.None, userState: null)); + } } } diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs index dbcefa47ac175e..ca32876a04a2ba 100644 --- a/src/tests/Interop/COM/ComWrappers/Common.cs +++ b/src/tests/Interop/COM/ComWrappers/Common.cs @@ -337,6 +337,15 @@ CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out } } + [Guid("DA582249-EBf7-437E-BBF8-3B6775BFDB9D")] + public interface INotWrappedObject + { + } + + sealed class NotWrappedObject : INotWrappedObject + { + } + class ComWrappersHelper { public static readonly Guid IID_IReferenceTracker = new Guid("11d3b13a-180e-4789-a8be-7712882893e6");