diff --git a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx
index 487028d306a3fa..475d54061b4d0b 100644
--- a/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx
+++ b/src/libraries/System.Private.CoreLib/src/Resources/Strings.resx
@@ -4358,4 +4358,7 @@
VariantWrappers cannot be stored in Variants.
+
+ Creating an object wrapper for a COM instance with user state is not implemented.
+
diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
index 5d1bfacb65637f..4b74be3c56bb17 100644
--- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
+++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
@@ -961,6 +961,7 @@
+
@@ -2774,7 +2775,7 @@
-
+
@@ -2849,4 +2850,4 @@
-
+
\ No newline at end of file
diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs
index 1d87d2b2e9dbba..c7c046732046c3 100644
--- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.PlatformNotSupported.cs
@@ -25,6 +25,11 @@ public struct ComInterfaceEntry
protected abstract object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags);
+ protected virtual object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags wrapperFlags)
+ {
+ throw new PlatformNotSupportedException();
+ }
+
protected internal abstract void ReleaseObjects(IEnumerable objects);
public static unsafe bool TryGetComInstance(object obj, out IntPtr unknown)
@@ -58,6 +63,11 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb
throw new PlatformNotSupportedException();
}
+ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object? userState)
+ {
+ throw new PlatformNotSupportedException();
+ }
+
public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object wrapper)
{
throw new PlatformNotSupportedException();
diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
index 639f72b2d93416..636bff946229a7 100644
--- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
@@ -901,7 +901,28 @@ private static nuint AlignUp(nuint value, nuint alignment)
public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags)
{
object? obj;
- if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, null, out obj))
+ if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, wrapperMaybe: null, userState: NoUserState.Instance, out obj))
+ throw new ArgumentNullException(nameof(externalComObject));
+
+ return obj;
+ }
+
+ ///
+ /// Get the currently registered managed object or creates a new managed object and registers it.
+ ///
+ /// Object to import for usage into the .NET runtime.
+ /// Flags used to describe the external object.
+ /// A state object to use to help create the wrapping .NET object.
+ /// Returns a managed object associated with the supplied external COM object.
+ ///
+ /// If a managed object was previously created for the specified
+ /// using this instance, the previously created object will be returned.
+ /// If not, a new one will be created.
+ ///
+ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags, object? userState)
+ {
+ object? obj;
+ if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, IntPtr.Zero, flags, wrapperMaybe: null, userState, out obj))
throw new ArgumentNullException(nameof(externalComObject));
return obj;
@@ -943,7 +964,7 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create
ArgumentNullException.ThrowIfNull(wrapper);
object? obj;
- if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, inner, flags, wrapper, out obj))
+ if (!TryGetOrCreateObjectForComInstanceInternal(externalComObject, inner, flags, wrapper, userState: NoUserState.Instance, out obj))
throw new ArgumentNullException(nameof(externalComObject));
return obj;
@@ -1025,6 +1046,15 @@ private static void DetermineIdentityAndInner(
}
}
+ private sealed class NoUserState
+ {
+ public static readonly NoUserState Instance = new NoUserState();
+
+ private NoUserState()
+ {
+ }
+ }
+
///
/// Get the currently registered managed object or creates a new managed object and registers it.
///
@@ -1032,13 +1062,15 @@ private static void DetermineIdentityAndInner(
/// The inner instance if aggregation is involved
/// Flags used to describe the external object.
/// The to be used as the wrapper for the external object.
- /// The managed object associated with the supplied external COM object or null if it could not be created.
+ /// A state object provided by the user for creating the object, otherwise .
/// Returns true if a managed object could be retrieved/created, false otherwise
+ /// The managed object associated with the supplied external COM object or null if it could not be created.
private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
IntPtr externalComObject,
IntPtr innerMaybe,
CreateObjectFlags flags,
object? wrapperMaybe,
+ object? userState,
[NotNullWhen(true)] out object? retValue)
{
if (externalComObject == IntPtr.Zero)
@@ -1062,7 +1094,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
// and return.
if (flags.HasFlag(CreateObjectFlags.UniqueInstance))
{
- retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, ref referenceTrackerMaybe);
+ retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, userState, ref referenceTrackerMaybe);
return retValue is not null;
}
@@ -1116,7 +1148,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
// If the user didn't provide a wrapper and couldn't unwrap a managed object wrapper,
// create a new wrapper.
- retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, ref referenceTrackerMaybe);
+ retValue = CreateAndRegisterObjectForComInstance(identity, inner, flags, userState, ref referenceTrackerMaybe);
return retValue is not null;
}
finally
@@ -1137,15 +1169,32 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
IntPtr identity,
IntPtr inner,
CreateObjectFlags flags,
+ object? userState,
ref IntPtr referenceTrackerMaybe)
{
- object? retValue = CreateObject(identity, flags);
+ CreatedWrapperFlags wrapperFlags = CreatedWrapperFlags.None;
+
+ object? retValue = userState is NoUserState
+ ? CreateObject(identity, flags)
+ : CreateObject(identity, flags, userState, out wrapperFlags);
+
if (retValue is null)
{
// If ComWrappers instance cannot create wrapper, we can do nothing here.
return null;
}
+ if (wrapperFlags.HasFlag(CreatedWrapperFlags.NonWrapping))
+ {
+ return retValue;
+ }
+
+ if (wrapperFlags.HasFlag(CreatedWrapperFlags.TrackerObject))
+ {
+ // The user has determined after inspecting the COM object that it should have tracker support.
+ flags |= CreateObjectFlags.TrackerObject;
+ }
+
return RegisterObjectForComInstance(identity, inner, retValue, flags, ref referenceTrackerMaybe);
}
@@ -1391,6 +1440,23 @@ public static void RegisterForMarshalling(ComWrappers instance)
///
protected abstract object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags);
+ ///
+ /// Create a managed object for the object pointed at by respecting the values of .
+ ///
+ /// Object to import for usage into the .NET runtime.
+ /// Flags used to describe the external object.
+ /// User state provided by the call to .
+ /// Flags used to describe the created wrapper object.
+ /// Returns a managed object associated with the supplied external COM object.
+ ///
+ /// The default implementation throws .
+ /// If the object cannot be created and null is returned, the call to will throw a .
+ ///
+ protected virtual object? CreateObject(IntPtr externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags wrapperFlags)
+ {
+ throw new NotImplementedException(SR.NotImplemented_CreateObjectWithUserState);
+ }
+
///
/// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime.
///
diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs
index 337e955a57e708..b0a50ba4d17261 100644
--- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs
+++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreateObjectFlags.cs
@@ -6,7 +6,6 @@
namespace System.Runtime.InteropServices
{
-
///
/// Enumeration of flags for .
///
diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreatedWrapperFlags.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreatedWrapperFlags.cs
new file mode 100644
index 00000000000000..a8c68a497a85e2
--- /dev/null
+++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CreatedWrapperFlags.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Runtime.InteropServices
+{
+ ///
+ /// Enumeration of flags for .
+ ///
+ [Flags]
+ public enum CreatedWrapperFlags
+ {
+ None = 0,
+
+ ///
+ /// Indicate if the supplied external COM object implements the IReferenceTracker.
+ ///
+ TrackerObject = 1,
+
+ ///
+ /// The managed object doesn't keep the native object alive. It represents an equivalent value.
+ ///
+ ///
+ /// Using this flag results in the following changes:
+ /// will return false for the returned object.
+ /// The features provided by the flag will be disabled.
+ /// Integration between and the returned object via the native IWeakReferenceSource interface will not work.
+ /// behavior is implied.
+ /// Diagnostics tooling support to unwrap objects returned by `CreateObject` will not see this object as a wrapper.
+ /// The same object can be returned from `CreateObject` wrapping different COM objects.
+ ///
+ NonWrapping = 0x2
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
index eb2130e4f80bc8..c49c5525d18a1a 100644
--- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
+++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
@@ -478,6 +478,10 @@ public void FromManaged(object? managed) { }
public void Free() { }
}
}
+ [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("android")]
+ [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
+ [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("ios")]
+ [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("tvos")]
[System.CLSCompliantAttribute(false)]
public partial class StrategyBasedComWrappers : System.Runtime.InteropServices.ComWrappers
{
@@ -488,6 +492,7 @@ public StrategyBasedComWrappers() { }
protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownCacheStrategy CreateCacheStrategy() { throw null; }
protected static System.Runtime.InteropServices.Marshalling.IIUnknownCacheStrategy CreateDefaultCacheStrategy() { throw null; }
protected sealed override object CreateObject(nint externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags) { throw null; }
+ protected sealed override object? CreateObject(nint externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object? userState, out System.Runtime.InteropServices.CreatedWrapperFlags wrapperFlags) { throw null; }
protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceDetailsStrategy GetOrCreateInterfaceDetailsStrategy() { throw null; }
protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownStrategy GetOrCreateIUnknownStrategy() { throw null; }
protected sealed override void ReleaseObjects(System.Collections.IEnumerable objects) { }
@@ -762,16 +767,18 @@ public struct ComInterfaceDispatch
public System.IntPtr Vtable;
public unsafe static T GetInstance(ComInterfaceDispatch* dispatchPtr) where T : class { throw null; }
}
- public System.IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) { throw null; }
- protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count);
- public object GetOrCreateObjectForComInstance(System.IntPtr externalComObject, CreateObjectFlags flags) { throw null; }
- protected abstract object? CreateObject(System.IntPtr externalComObject, CreateObjectFlags flags);
- public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, CreateObjectFlags flags, object wrapper) { throw null; }
- public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, CreateObjectFlags flags, object wrapper, System.IntPtr inner) { throw null; }
+ public System.IntPtr GetOrCreateComInterfaceForObject(object instance, System.Runtime.InteropServices.CreateComInterfaceFlags flags) { throw null; }
+ protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, System.Runtime.InteropServices.CreateComInterfaceFlags flags, out int count);
+ public object GetOrCreateObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags) { throw null; }
+ public object GetOrCreateObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object? userState) { throw null; }
+ protected abstract object? CreateObject(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags);
+ protected virtual object? CreateObject(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object? userState, out System.Runtime.InteropServices.CreatedWrapperFlags wrapperFlags) { throw null; }
+ public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object wrapper) { throw null; }
+ public object GetOrRegisterObjectForComInstance(System.IntPtr externalComObject, System.Runtime.InteropServices.CreateObjectFlags flags, object wrapper, System.IntPtr inner) { throw null; }
protected abstract void ReleaseObjects(System.Collections.IEnumerable objects);
public static void RegisterForTrackerSupport(ComWrappers instance) { }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
- public static void RegisterForMarshalling(ComWrappers instance) { }
+ public static void RegisterForMarshalling(System.Runtime.InteropServices.ComWrappers instance) { }
public static void GetIUnknownImpl(out System.IntPtr fpQueryInterface, out System.IntPtr fpAddRef, out System.IntPtr fpRelease) { throw null; }
}
[System.FlagsAttribute]
@@ -790,6 +797,13 @@ public enum CreateObjectFlags
Aggregation = 4,
Unwrap = 8,
}
+ [System.FlagsAttribute]
+ public enum CreatedWrapperFlags
+ {
+ None = 0,
+ TrackerObject = 1,
+ NonWrapping = 2
+ }
[System.CLSCompliantAttribute(false)]
public readonly partial struct CULong : System.IEquatable
{
diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs
index 9741b482633c24..74c147a5c2c21f 100644
--- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs
+++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs
@@ -5,12 +5,17 @@
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
+using System.Runtime.Versioning;
namespace System.Runtime.InteropServices.Marshalling
{
///
/// A -based type that uses customizable strategy objects to implement COM object wrappers and managed object wrappers exposed to COM.
///
+ [UnsupportedOSPlatform("android")]
+ [UnsupportedOSPlatform("browser")]
+ [UnsupportedOSPlatform("ios")]
+ [UnsupportedOSPlatform("tvos")]
[CLSCompliant(false)]
public class StrategyBasedComWrappers : ComWrappers
{
@@ -84,7 +89,7 @@ static IIUnknownInterfaceDetailsStrategy GetInteropStrategy()
return null;
}
- ///
+ ///
protected sealed override unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags)
{
if (flags.HasFlag(CreateObjectFlags.TrackerObject)
@@ -101,6 +106,12 @@ protected sealed override unsafe object CreateObject(nint externalComObject, Cre
return rcw;
}
+ ///
+ protected sealed override object? CreateObject(nint externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags wrapperFlags)
+ {
+ return base.CreateObject(externalComObject, flags, userState, out wrapperFlags);
+ }
+
///
protected sealed override void ReleaseObjects(IEnumerable objects)
{
diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs
index a87ba2e79a6906..2751f427838e91 100644
--- a/src/tests/Interop/COM/ComWrappers/API/Program.cs
+++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs
@@ -18,6 +18,8 @@ namespace ComWrappersTests
public class Program : IDisposable
{
+ record class WrappedUserState(object? UserState);
+
class TestComWrappers : ComWrappers
{
private static IntPtr fpQueryInterface = default;
@@ -80,6 +82,41 @@ static TestComWrappers()
index++;
}
}
+ else if (obj is NotWrappedObject)
+ {
+ // Return a single vtable for the INotWrappedObject interface.
+ // Or two if the caller is requesting an IUnknown definition.
+ count = flags.HasFlag(CreateComInterfaceFlags.CallerDefinedIUnknown) ? 2 : 1;
+ entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NotWrappedObject), sizeof(ComInterfaceEntry) * count);
+
+ var vtbl = new IUnknownVtbl()
+ {
+ QueryInterface = fpQueryInterface,
+ AddRef = fpAddRef,
+ Release = fpRelease
+ };
+
+ int index = 0;
+
+ if (flags.HasFlag(CreateComInterfaceFlags.CallerDefinedIUnknown))
+ {
+ var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NotWrappedObject), sizeof(IUnknownVtbl));
+ Marshal.StructureToPtr(vtbl, vtblRaw, false);
+
+ entryRaw[index].IID = IUnknownVtbl.IID_IUnknown;
+ entryRaw[index].Vtable = vtblRaw;
+ index++;
+ }
+
+ {
+ var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NotWrappedObject), sizeof(IUnknownVtbl));
+ Marshal.StructureToPtr(vtbl, vtblRaw, false);
+
+ entryRaw[index].IID = typeof(INotWrappedObject).GUID;
+ entryRaw[index].Vtable = vtblRaw;
+ index++;
+ }
+ }
return entryRaw;
}
@@ -103,6 +140,39 @@ protected override object CreateObject(IntPtr externalComObject, CreateObjectFla
return null;
}
+ public bool CalledUserStateOverload { get; set; } = false;
+
+ public bool CallBaseCreateObject { get; set; } = false;
+
+ protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flags, object? userState, out CreatedWrapperFlags createdWrapperFlags)
+ {
+ CalledUserStateOverload = true;
+
+ if (CallBaseCreateObject)
+ {
+ return base.CreateObject(externalComObject, flags, userState, out createdWrapperFlags);
+ }
+
+ createdWrapperFlags = CreatedWrapperFlags.None;
+
+ int hr = Marshal.QueryInterface(externalComObject, typeof(INotWrappedObject).GUID, out IntPtr iNotWrappedObject);
+ if (hr == 0)
+ {
+ // This is a non-wrapped object, return the user state as an object.
+ Marshal.Release(iNotWrappedObject);
+ createdWrapperFlags = CreatedWrapperFlags.NonWrapping;
+ return new WrappedUserState(userState);
+ }
+
+ object result = CreateObject(externalComObject, flags);
+ if (result is ITrackerObjectWrapper trackerObj)
+ {
+ createdWrapperFlags = CreatedWrapperFlags.TrackerObject;
+ }
+
+ return result;
+ }
+
public const int ReleaseObjectsCallAck = unchecked((int)-1);
protected override void ReleaseObjects(IEnumerable objects)
@@ -402,9 +472,9 @@ public void ValidateFallbackQueryInterface()
Console.WriteLine($"Running {nameof(ValidateFallbackQueryInterface)}...");
var testObj = new Test()
- {
- EnableICustomQueryInterface = true
- };
+ {
+ EnableICustomQueryInterface = true
+ };
var wrappers = new TestComWrappers();
@@ -717,10 +787,10 @@ public enum FailureMode
switch (ComputeVtablesMode)
{
case FailureMode.ReturnInvalid:
- {
- count = -1;
- return null;
- }
+ {
+ count = -1;
+ return null;
+ }
case FailureMode.ThrowException:
throw new Exception() { HResult = ExceptionErrorCode };
default:
@@ -795,18 +865,13 @@ public void ValidateBadComWrapperImpl()
Marshal.Release(trackerObjRaw);
}
- [Fact]
- public void ValidateRuntimeTrackerScenario()
+ private void ValidateRuntimeTrackerScenarioCore(ComWrappers cw, Func createObjectFunc)
{
- Console.WriteLine($"Running {nameof(ValidateRuntimeTrackerScenario)}...");
-
- var cw = new TestComWrappers();
-
// Get an object from a tracker runtime.
IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject();
// Create a managed wrapper for the native object.
- var trackerObj = (ITrackerObjectWrapper)cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+ var trackerObj = (ITrackerObjectWrapper)createObjectFunc(trackerObjRaw);
// Ownership has been transferred to the wrapper.
Marshal.Release(trackerObjRaw);
@@ -843,6 +908,32 @@ public void ValidateRuntimeTrackerScenario()
ForceGC();
}
+ [Fact]
+ public void ValidateRuntimeTrackerScenario()
+ {
+ Console.WriteLine($"Running {nameof(ValidateRuntimeTrackerScenario)}...");
+
+ var cw = new TestComWrappers();
+
+ ValidateRuntimeTrackerScenarioCore(cw, (trackerObjRaw) =>
+ {
+ return cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+ });
+ }
+
+ [Fact]
+ public void ValidateRuntimeTrackerScenarioUserStateOverload()
+ {
+ Console.WriteLine($"Running {nameof(ValidateRuntimeTrackerScenarioUserStateOverload)}...");
+
+ var cw = new TestComWrappers();
+
+ ValidateRuntimeTrackerScenarioCore(cw, (trackerObjRaw) =>
+ {
+ return cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.None, userState: null);
+ });
+ }
+
[Fact]
public void ValidateQueryInterfaceAfterManagedObjectCollected()
{
@@ -1123,6 +1214,68 @@ CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out
return CustomQueryInterfaceResult.NotHandled;
}
}
+
+ [Fact]
+ public void UserStateOverloadNotCalledWhenNoUserStatePassed()
+ {
+ Console.WriteLine($"Running {nameof(UserStateOverloadNotCalledWhenNoUserStatePassed)}...");
+
+ var testObj = new Test();
+
+ var wrappers = new TestComWrappers();
+
+ // Allocate a wrapper for the object
+ IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(IntPtr.Zero, comWrapper);
+
+ var testObjFromNative = (ITestObjectWrapper)wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.None);
+
+ Assert.False(wrappers.CalledUserStateOverload);
+
+ testObjFromNative.FinalRelease();
+ }
+
+ [Theory]
+ [InlineData(null)]
+ [InlineData(1)]
+ [InlineData("testString")]
+ public void UserStatePassedThrough(object? userState)
+ {
+ Console.WriteLine($"Running {nameof(UserStatePassedThrough)}...");
+
+ var testObj = new NotWrappedObject();
+
+ var wrappers = new TestComWrappers();
+
+ // Allocate a wrapper for the object
+ IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(IntPtr.Zero, comWrapper);
+
+ var testObjFromNative = (WrappedUserState)wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.None, userState);
+
+ Assert.True(wrappers.CalledUserStateOverload);
+ Assert.Same(userState, testObjFromNative.UserState);
+
+ Assert.False(ComWrappers.TryGetComInstance(testObjFromNative, out _));
+ }
+
+ [Fact]
+ public void UserStateBaseImplementationThrows()
+ {
+ Console.WriteLine($"Running {nameof(UserStateBaseImplementationThrows)}...");
+
+ var testObj = new NotWrappedObject();
+
+ var wrappers = new TestComWrappers();
+
+ // Allocate a wrapper for the object
+ IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(IntPtr.Zero, comWrapper);
+
+ wrappers.CallBaseCreateObject = true;
+
+ Assert.Throws(() => wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.None, userState: null));
+ }
}
}
diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs
index dbcefa47ac175e..ca32876a04a2ba 100644
--- a/src/tests/Interop/COM/ComWrappers/Common.cs
+++ b/src/tests/Interop/COM/ComWrappers/Common.cs
@@ -337,6 +337,15 @@ CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out
}
}
+ [Guid("DA582249-EBf7-437E-BBF8-3B6775BFDB9D")]
+ public interface INotWrappedObject
+ {
+ }
+
+ sealed class NotWrappedObject : INotWrappedObject
+ {
+ }
+
class ComWrappersHelper
{
public static readonly Guid IID_IReferenceTracker = new Guid("11d3b13a-180e-4789-a8be-7712882893e6");