22import math
33import os
44import re
5-
6- from typing import get_type_hints
5+ from unittest import mock
76
87import numpy as np
98import PIL .Image
109import pytest
11-
1210import torch
1311
1412from common_utils import (
2725from torchvision .transforms .v2 import functional as F
2826from torchvision .transforms .v2 .functional ._geometry import _center_crop_compute_padding
2927from torchvision .transforms .v2 .functional ._meta import clamp_bounding_boxes , convert_format_bounding_boxes
28+ from torchvision .transforms .v2 .functional ._utils import _KERNEL_REGISTRY
3029from torchvision .transforms .v2 .utils import is_simple_tensor
3130from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
3231from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -424,12 +423,18 @@ def test_pil_output_type(self, info, args_kwargs):
424423 def test_dispatch_datapoint (self , info , args_kwargs , spy_on ):
425424 (datapoint , * other_args ), kwargs = args_kwargs .load ()
426425
427- method_name = info .id
428- method = getattr (datapoint , method_name )
429- datapoint_type = type (datapoint )
430- spy = spy_on (method , module = datapoint_type .__module__ , name = f"{ datapoint_type .__name__ } .{ method_name } " )
426+ input_type = type (datapoint )
427+
428+ wrapped_kernel = _KERNEL_REGISTRY [info .dispatcher ][input_type ]
431429
432- info .dispatcher (datapoint , * other_args , ** kwargs )
430+ # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
431+ # proper kernel was wrapped
432+ if hasattr (wrapped_kernel , "__wrapped__" ):
433+ assert wrapped_kernel .__wrapped__ is info .kernels [input_type ]
434+
435+ spy = mock .MagicMock (wraps = wrapped_kernel , name = wrapped_kernel .__name__ )
436+ with mock .patch .dict (_KERNEL_REGISTRY [info .dispatcher ], values = {input_type : spy }):
437+ info .dispatcher (datapoint , * other_args , ** kwargs )
433438
434439 spy .assert_called_once ()
435440
@@ -462,9 +467,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
462467 kernel_params = list (kernel_signature .parameters .values ())[1 :]
463468
464469 # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
465- # explicit passed to the kernel.
466- datapoint_type_metadata = datapoint_type .__annotations__ .keys ()
467- kernel_params = [param for param in kernel_params if param .name not in datapoint_type_metadata ]
470+ # explicitly passed to the kernel.
471+ input_type = {v : k for k , v in dispatcher_info .kernels .items ()}.get (kernel_info .kernel )
472+ explicit_metadata = {
473+ datapoints .BoundingBoxes : {"format" , "canvas_size" },
474+ }
475+ kernel_params = [param for param in kernel_params if param .name not in explicit_metadata .get (input_type , set ())]
468476
469477 dispatcher_params = iter (dispatcher_params )
470478 for dispatcher_param , kernel_param in zip (dispatcher_params , kernel_params ):
@@ -481,28 +489,6 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
481489
482490 assert dispatcher_param == kernel_param
483491
484- @pytest .mark .parametrize ("info" , DISPATCHER_INFOS , ids = lambda info : info .id )
485- def test_dispatcher_datapoint_signatures_consistency (self , info ):
486- try :
487- datapoint_method = getattr (datapoints ._datapoint .Datapoint , info .id )
488- except AttributeError :
489- pytest .skip ("Dispatcher doesn't support arbitrary datapoint dispatch." )
490-
491- dispatcher_signature = inspect .signature (info .dispatcher )
492- dispatcher_params = list (dispatcher_signature .parameters .values ())[1 :]
493-
494- datapoint_signature = inspect .signature (datapoint_method )
495- datapoint_params = list (datapoint_signature .parameters .values ())[1 :]
496-
497- # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
498- # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
499- # natively concrete dispatcher annotations.
500- datapoint_annotations = get_type_hints (datapoint_method )
501- for param in datapoint_params :
502- param ._annotation = datapoint_annotations [param .name ]
503-
504- assert dispatcher_params == datapoint_params
505-
506492 @pytest .mark .parametrize ("info" , DISPATCHER_INFOS , ids = lambda info : info .id )
507493 def test_unkown_type (self , info ):
508494 unkown_input = object ()
0 commit comments