33import inspect
44import math
55import re
6- from typing import get_type_hints
76from unittest import mock
87
98import numpy as np
@@ -178,28 +177,28 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
178177 """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
179178 preserved in doing so. For bounding boxes also checks that the format is preserved.
180179 """
181- if isinstance (input , datapoints ._datapoint .Datapoint ):
182- if dispatcher in {F .resize , F .adjust_brightness }:
180+ input_type = type (input )
181+
182+ if isinstance (input , datapoints .Datapoint ):
183+ wrapped_kernel = _KERNEL_REGISTRY [dispatcher ][input_type ]
184+
185+ # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
186+ # proper kernel was wrapped
187+ if hasattr (wrapped_kernel , "__wrapped__" ):
188+ assert wrapped_kernel .__wrapped__ is kernel
189+
190+ spy = mock .MagicMock (wraps = wrapped_kernel , name = wrapped_kernel .__name__ )
191+ with mock .patch .dict (_KERNEL_REGISTRY [dispatcher ], values = {input_type : spy }):
183192 output = dispatcher (input , * args , ** kwargs )
184- else :
185- # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly,
186- # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel.
187- spy = mock .MagicMock (wraps = kernel , name = kernel .__name__ )
188- with mock .patch .object (F , kernel .__name__ , spy ):
189- # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class.
190- # Since that is not the case here, we need to prefix f"_{cls.__name__}"
191- # See https://docs.python.org/3/tutorial/classes.html#private-variables for details
192- with mock .patch .object (datapoints ._datapoint .Datapoint , "_Datapoint__F" , new = F ):
193- output = dispatcher (input , * args , ** kwargs )
194193
195- spy .assert_called_once ()
194+ spy .assert_called_once ()
196195 else :
197196 with mock .patch (f"{ dispatcher .__module__ } .{ kernel .__name__ } " , wraps = kernel ) as spy :
198197 output = dispatcher (input , * args , ** kwargs )
199198
200199 spy .assert_called_once ()
201200
202- assert isinstance (output , type ( input ) )
201+ assert isinstance (output , input_type )
203202
204203 if isinstance (input , datapoints .BoundingBoxes ):
205204 assert output .format == input .format
@@ -214,34 +213,32 @@ def check_dispatcher(
214213 check_dispatch = True ,
215214 ** kwargs ,
216215):
216+ unknown_input = object ()
217217 with mock .patch ("torch._C._log_api_usage_once" , wraps = torch ._C ._log_api_usage_once ) as spy :
218- dispatcher (input , * args , ** kwargs )
218+ with pytest .raises (TypeError , match = re .escape (str (type (unknown_input )))):
219+ dispatcher (unknown_input , * args , ** kwargs )
219220
220221 spy .assert_any_call (f"{ dispatcher .__module__ } .{ dispatcher .__name__ } " )
221222
222- unknown_input = object ()
223- with pytest .raises (TypeError , match = re .escape (str (type (unknown_input )))):
224- dispatcher (unknown_input , * args , ** kwargs )
225-
226223 if check_scripted_smoke :
227224 _check_dispatcher_scripted_smoke (dispatcher , input , * args , ** kwargs )
228225
229226 if check_dispatch :
230227 _check_dispatcher_dispatch (dispatcher , kernel , input , * args , ** kwargs )
231228
232229
233- def _check_dispatcher_kernel_signature_match (dispatcher , * , kernel , input_type ):
230+ def check_dispatcher_kernel_signature_match (dispatcher , * , kernel , input_type ):
234231 """Checks if the signature of the dispatcher matches the kernel signature."""
235- dispatcher_signature = inspect .signature (dispatcher )
236- dispatcher_params = list (dispatcher_signature .parameters .values ())[1 :]
237-
238- kernel_signature = inspect .signature (kernel )
239- kernel_params = list (kernel_signature .parameters .values ())[1 :]
232+ dispatcher_params = list (inspect .signature (dispatcher ).parameters .values ())[1 :]
233+ kernel_params = list (inspect .signature (kernel ).parameters .values ())[1 :]
240234
241- if issubclass (input_type , datapoints ._datapoint . Datapoint ):
235+ if issubclass (input_type , datapoints .Datapoint ):
242236 # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
243237 # explicitly passed to the kernel.
244- kernel_params = [param for param in kernel_params if param .name not in input_type .__annotations__ .keys ()]
238+ explicit_metadata = {
239+ datapoints .BoundingBoxes : {"format" , "canvas_size" },
240+ }
241+ kernel_params = [param for param in kernel_params if param .name not in explicit_metadata .get (input_type , set ())]
245242
246243 dispatcher_params = iter (dispatcher_params )
247244 for dispatcher_param , kernel_param in zip (dispatcher_params , kernel_params ):
@@ -264,32 +261,6 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
264261 assert dispatcher_param == kernel_param
265262
266263
267- def _check_dispatcher_datapoint_signature_match (dispatcher ):
268- """Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class."""
269- if dispatcher in {F .resize , F .adjust_brightness }:
270- return
271- dispatcher_signature = inspect .signature (dispatcher )
272- dispatcher_params = list (dispatcher_signature .parameters .values ())[1 :]
273-
274- datapoint_method = getattr (datapoints ._datapoint .Datapoint , dispatcher .__name__ )
275- datapoint_signature = inspect .signature (datapoint_method )
276- datapoint_params = list (datapoint_signature .parameters .values ())[1 :]
277-
278- # Some annotations in the `datapoints._datapoint` module
279- # are stored as strings. The block below makes them concrete again (non-strings), so they can be compared to the
280- # natively concrete dispatcher annotations.
281- datapoint_annotations = get_type_hints (datapoint_method )
282- for param in datapoint_params :
283- param ._annotation = datapoint_annotations [param .name ]
284-
285- assert dispatcher_params == datapoint_params
286-
287-
288- def check_dispatcher_signatures_match (dispatcher , * , kernel , input_type ):
289- _check_dispatcher_kernel_signature_match (dispatcher , kernel = kernel , input_type = input_type )
290- _check_dispatcher_datapoint_signature_match (dispatcher )
291-
292-
293264def _check_transform_v1_compatibility (transform , input ):
294265 """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
295266 ``get_params`` method, is scriptable, and the scripted version can be called without error."""
@@ -461,7 +432,7 @@ def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
461432 * [f"- { name } " for name in names ],
462433 "" ,
463434 f"If available, register the kernels with @_register_kernel_internal({ dispatcher .__name__ } , ...)." ,
464- f"If not, register explicit no-ops with @_register_explicit_noops ({ ', ' .join (names )} )" ,
435+ f"If not, register explicit no-ops with @_register_explicit_noop ({ ', ' .join (names )} )" ,
465436 ]
466437 )
467438 )
@@ -602,7 +573,7 @@ def test_dispatcher(self, size, kernel, make_input):
602573 ],
603574 )
604575 def test_dispatcher_signature (self , kernel , input_type ):
605- check_dispatcher_signatures_match (F .resize , kernel = kernel , input_type = input_type )
576+ check_dispatcher_kernel_signature_match (F .resize , kernel = kernel , input_type = input_type )
606577
607578 @pytest .mark .parametrize ("size" , OUTPUT_SIZES )
608579 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
@@ -800,7 +771,7 @@ def test_noop(self, size, make_input):
800771
801772 # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
802773 # is a good reason to break this, feel free to downgrade to an equality check.
803- if isinstance (input , datapoints ._datapoint . Datapoint ):
774+ if isinstance (input , datapoints .Datapoint ):
804775 # We can't test identity directly, since that checks for the identity of the Python object. Since all
805776 # datapoints unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
806777 # that the underlying storage is the same
@@ -884,7 +855,7 @@ def test_dispatcher(self, kernel, make_input):
884855 ],
885856 )
886857 def test_dispatcher_signature (self , kernel , input_type ):
887- check_dispatcher_signatures_match (F .horizontal_flip , kernel = kernel , input_type = input_type )
858+ check_dispatcher_kernel_signature_match (F .horizontal_flip , kernel = kernel , input_type = input_type )
888859
889860 @pytest .mark .parametrize (
890861 "make_input" ,
@@ -1067,7 +1038,7 @@ def test_dispatcher(self, kernel, make_input):
10671038 ],
10681039 )
10691040 def test_dispatcher_signature (self , kernel , input_type ):
1070- check_dispatcher_signatures_match (F .affine , kernel = kernel , input_type = input_type )
1041+ check_dispatcher_kernel_signature_match (F .affine , kernel = kernel , input_type = input_type )
10711042
10721043 @pytest .mark .parametrize (
10731044 "make_input" ,
@@ -1363,7 +1334,7 @@ def test_dispatcher(self, kernel, make_input):
13631334 ],
13641335 )
13651336 def test_dispatcher_signature (self , kernel , input_type ):
1366- check_dispatcher_signatures_match (F .vertical_flip , kernel = kernel , input_type = input_type )
1337+ check_dispatcher_kernel_signature_match (F .vertical_flip , kernel = kernel , input_type = input_type )
13671338
13681339 @pytest .mark .parametrize (
13691340 "make_input" ,
@@ -1520,7 +1491,7 @@ def test_dispatcher(self, kernel, make_input):
15201491 ],
15211492 )
15221493 def test_dispatcher_signature (self , kernel , input_type ):
1523- check_dispatcher_signatures_match (F .rotate , kernel = kernel , input_type = input_type )
1494+ check_dispatcher_kernel_signature_match (F .rotate , kernel = kernel , input_type = input_type )
15241495
15251496 @pytest .mark .parametrize (
15261497 "make_input" ,
@@ -1971,7 +1942,7 @@ def test_dispatcher(self, kernel, make_input):
19711942 ],
19721943 )
19731944 def test_dispatcher_signature (self , kernel , input_type ):
1974- check_dispatcher_signatures_match (F .adjust_brightness , kernel = kernel , input_type = input_type )
1945+ check_dispatcher_kernel_signature_match (F .adjust_brightness , kernel = kernel , input_type = input_type )
19751946
19761947 @pytest .mark .parametrize ("brightness_factor" , _CORRECTNESS_BRIGHTNESS_FACTORS )
19771948 def test_image_correctness (self , brightness_factor ):
0 commit comments