99 raise unittest .SkipTest ("Windows-specific test" )
1010
1111
12- from _ctypes import COMError
12+ from _ctypes import COMError , CopyComPointer
1313from ctypes import HRESULT
1414
1515
@@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2):
7878)
7979
8080
81+ def create_shelllink_persist (typ ):
82+ ppst = typ ()
83+ # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
84+ ole32 .CoCreateInstance (
85+ byref (CLSID_ShellLink ),
86+ None ,
87+ CLSCTX_SERVER ,
88+ byref (IID_IPersist ),
89+ byref (ppst ),
90+ )
91+ return ppst
92+
93+
8194class ForeignFunctionsThatWillCallComMethodsTests (unittest .TestCase ):
8295 def setUp (self ):
8396 # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
@@ -88,19 +101,6 @@ def tearDown(self):
88101 ole32 .CoUninitialize ()
89102 gc .collect ()
90103
91- @staticmethod
92- def create_shelllink_persist (typ ):
93- ppst = typ ()
94- # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
95- ole32 .CoCreateInstance (
96- byref (CLSID_ShellLink ),
97- None ,
98- CLSCTX_SERVER ,
99- byref (IID_IPersist ),
100- byref (ppst ),
101- )
102- return ppst
103-
104104 def test_without_paramflags_and_iid (self ):
105105 class IUnknown (c_void_p ):
106106 QueryInterface = proto_query_interface ()
@@ -110,7 +110,7 @@ class IUnknown(c_void_p):
110110 class IPersist (IUnknown ):
111111 GetClassID = proto_get_class_id ()
112112
113- ppst = self . create_shelllink_persist (IPersist )
113+ ppst = create_shelllink_persist (IPersist )
114114
115115 clsid = GUID ()
116116 hr_getclsid = ppst .GetClassID (byref (clsid ))
@@ -142,7 +142,7 @@ class IUnknown(c_void_p):
142142 class IPersist (IUnknown ):
143143 GetClassID = proto_get_class_id (((OUT , "pClassID" ),))
144144
145- ppst = self . create_shelllink_persist (IPersist )
145+ ppst = create_shelllink_persist (IPersist )
146146
147147 clsid = ppst .GetClassID ()
148148 self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
@@ -167,7 +167,7 @@ class IUnknown(c_void_p):
167167 class IPersist (IUnknown ):
168168 GetClassID = proto_get_class_id (((OUT , "pClassID" ),), IID_IPersist )
169169
170- ppst = self . create_shelllink_persist (IPersist )
170+ ppst = create_shelllink_persist (IPersist )
171171
172172 clsid = ppst .GetClassID ()
173173 self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
@@ -184,5 +184,103 @@ class IPersist(IUnknown):
184184 self .assertEqual (0 , ppst .Release ())
185185
186186
187+ class CopyComPointerTests (unittest .TestCase ):
188+ def setUp (self ):
189+ ole32 .CoInitializeEx (None , COINIT_APARTMENTTHREADED )
190+
191+ class IUnknown (c_void_p ):
192+ QueryInterface = proto_query_interface (None , IID_IUnknown )
193+ AddRef = proto_add_ref ()
194+ Release = proto_release ()
195+
196+ class IPersist (IUnknown ):
197+ GetClassID = proto_get_class_id (((OUT , "pClassID" ),), IID_IPersist )
198+
199+ self .IUnknown = IUnknown
200+ self .IPersist = IPersist
201+
202+ def tearDown (self ):
203+ ole32 .CoUninitialize ()
204+ gc .collect ()
205+
206+ def test_both_are_null (self ):
207+ src = self .IPersist ()
208+ dst = self .IPersist ()
209+
210+ hr = CopyComPointer (src , byref (dst ))
211+
212+ self .assertEqual (S_OK , hr )
213+
214+ self .assertIsNone (src .value )
215+ self .assertIsNone (dst .value )
216+
217+ def test_src_is_nonnull_and_dest_is_null (self ):
218+ # The reference count of the COM pointer created by `CoCreateInstance`
219+ # is initially 1.
220+ src = create_shelllink_persist (self .IPersist )
221+ dst = self .IPersist ()
222+
223+ # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
224+ # The refcount of `src` is incremented from 1 to 2 here.
225+ hr = CopyComPointer (src , byref (dst ))
226+
227+ self .assertEqual (S_OK , hr )
228+ self .assertEqual (src .value , dst .value )
229+
230+ # This indicates that the refcount was 2 before the `Release` call.
231+ self .assertEqual (1 , src .Release ())
232+
233+ clsid = dst .GetClassID ()
234+ self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
235+
236+ self .assertEqual (0 , dst .Release ())
237+
238+ def test_src_is_null_and_dest_is_nonnull (self ):
239+ src = self .IPersist ()
240+ dst_orig = create_shelllink_persist (self .IPersist )
241+ dst = self .IPersist ()
242+ CopyComPointer (dst_orig , byref (dst ))
243+ self .assertEqual (1 , dst_orig .Release ())
244+
245+ clsid = dst .GetClassID ()
246+ self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
247+
248+ # This does NOT affects the refcount of `dst_orig`.
249+ hr = CopyComPointer (src , byref (dst ))
250+
251+ self .assertEqual (S_OK , hr )
252+ self .assertIsNone (dst .value )
253+
254+ with self .assertRaises (ValueError ):
255+ dst .GetClassID () # NULL COM pointer access
256+
257+ # This indicates that the refcount was 1 before the `Release` call.
258+ self .assertEqual (0 , dst_orig .Release ())
259+
260+ def test_both_are_nonnull (self ):
261+ src = create_shelllink_persist (self .IPersist )
262+ dst_orig = create_shelllink_persist (self .IPersist )
263+ dst = self .IPersist ()
264+ CopyComPointer (dst_orig , byref (dst ))
265+ self .assertEqual (1 , dst_orig .Release ())
266+
267+ self .assertEqual (dst .value , dst_orig .value )
268+ self .assertNotEqual (src .value , dst .value )
269+
270+ hr = CopyComPointer (src , byref (dst ))
271+
272+ self .assertEqual (S_OK , hr )
273+ self .assertEqual (src .value , dst .value )
274+ self .assertNotEqual (dst .value , dst_orig .value )
275+
276+ self .assertEqual (1 , src .Release ())
277+
278+ clsid = dst .GetClassID ()
279+ self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
280+
281+ self .assertEqual (0 , dst .Release ())
282+ self .assertEqual (0 , dst_orig .Release ())
283+
284+
187285if __name__ == '__main__' :
188286 unittest .main ()
0 commit comments