Commit cf3ca68

NicolasHug authored and facebook-github-bot committed
[fbsync] Dispatcher -> Functional (#7829)
Reviewed By: matteobettini
Differential Revision: D48642298
fbshipit-source-id: fc7c521bde09fea9450b5ac2b11b5b055432db65
1 parent f168463 commit cf3ca68

7 files changed: +99, -99 lines changed


gallery/plot_custom_datapoints.py

Lines changed: 3 additions & 3 deletions
@@ -49,7 +49,7 @@ class MyDatapoint(datapoints.Datapoint):
 from torchvision.transforms.v2 import functional as F
 
 
-@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
+@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
 def hflip_my_datapoint(my_dp, *args, **kwargs):
     print("Flipping!")
     out = my_dp.flip(-1)
@@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
 # .. note::
 #
 #    In our call to ``register_kernel`` above we used a string
-#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
+#    ``functional="hflip"`` to refer to the functional we want to hook into. We
 #    could also have used the functional *itself*, i.e.
-#    ``@register_kernel(dispatcher=F.hflip, ...)``.
+#    ``@register_kernel(functional=F.hflip, ...)``.
 #
 #    The functionals that you can be hooked into are the ones in
 #    ``torchvision.transforms.v2.functional`` and they are documented in
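
For reference, a minimal sketch of how the updated gallery example is used end to end. It relies only on the API touched by this diff (datapoints.Datapoint, F.register_kernel, F.hflip); the MyDatapoint class, tensor shape, and direct MyDatapoint(...) construction follow the gallery example and are illustrative, not additional API.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class MyDatapoint(datapoints.Datapoint):
    # Do-nothing subclass, as in the gallery example.
    pass


@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
def hflip_my_datapoint(my_dp, *args, **kwargs):
    # Kernel that F.hflip resolves for MyDatapoint inputs.
    print("Flipping!")
    return my_dp.flip(-1)


my_dp = MyDatapoint(torch.rand(3, 4, 4))  # illustrative shape
out = F.hflip(my_dp)                      # prints "Flipping!", returns the flipped data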

test/test_transforms_v2_refactored.py

Lines changed: 64 additions & 64 deletions
Large diffs are not rendered by default.

torchvision/transforms/v2/_augment.py

Lines changed: 2 additions & 2 deletions
@@ -91,13 +91,13 @@ def __init__(
 
         self._log_ratio = torch.log(torch.tensor(self.ratio))
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         img_c, img_h, img_w = query_chw(flat_inputs)
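
The method above is the pass-through override inside RandomErasing; only the parameter name changes. Purely as an illustration of the same idiom, and assuming the base class is importable as torchvision.transforms.v2.Transform at this revision, a hypothetical transform could intercept the call the same way. It leans on the private _call_kernel / _transform hooks, so treat it as a sketch, not a supported extension point.

import warnings
from typing import Any, Callable, Dict

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F


class LoggedHorizontalFlip(Transform):
    # Hypothetical transform mirroring the override pattern in this commit.

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Inspect the input, then delegate to the base class using the renamed parameter.
        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
            warnings.warn(f"Passing through input of type {type(inpt).__name__}.")
            return inpt
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._call_kernel(F.hflip, inpt)


out = LoggedHorizontalFlip()(torch.rand(3, 8, 8))  # flips a plain tensor via F.hflip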

torchvision/transforms/v2/_geometry.py

Lines changed: 4 additions & 4 deletions
@@ -358,13 +358,13 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.five_crop, inpt, self.size)
@@ -405,13 +405,13 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False)
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)
 
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
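
The two classes touched here simply forward to the crop functionals through self._call_kernel(...), as the _transform context line shows for F.five_crop. A quick, hedged sketch of the functional they delegate to; the image shape and crop size are illustrative only.

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 128, 128)            # illustrative CHW image tensor
crops = F.five_crop(img, size=(64, 64))  # four corners + center
print(len(crops))      # 5
print(crops[0].shape)  # torch.Size([3, 64, 64])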

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
         return kernel(inpt, *args, **kwargs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
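
The base-class change is again only the parameter rename; dispatch still keys the registry on the functional and on type(inpt) via _get_kernel(..., allow_passthrough=True). A small sketch of what that type-based resolution looks like from the caller's side, assuming the v2 functionals at this revision:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

t = torch.rand(3, 8, 8)                      # "simple" tensor input
img = datapoints.Image(torch.rand(3, 8, 8))  # datapoint input

# The same functional resolves a different type-specific kernel per input type.
print(type(F.hflip(t)))    # plain torch.Tensor in, plain torch.Tensor out
print(type(F.hflip(img)))  # Image in, Image out (the kernel wrapper re-wraps the result)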

torchvision/transforms/v2/functional/_meta.py

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
 ) -> torch.Tensor:
-    # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
+    # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
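
The comment above describes why convert_format_bounding_boxes behaves as both a kernel and a functional. A hedged usage sketch, assuming the usual positional boxes argument and the standard XYXY to XYWH conversion; the box values are made up.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

# Simple tensor input: old_format has to be passed explicitly.
boxes_xyxy = torch.tensor([[10.0, 10.0, 20.0, 30.0]])
boxes_xywh = F.convert_format_bounding_boxes(
    boxes_xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.XYWH,
)
print(boxes_xywh)  # tensor([[10., 10., 10., 20.]])

# For datapoints.BoundingBoxes inputs, old_format is read off the datapoint itself,
# which is exactly the kernel / functional hybrid the comment describes.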

torchvision/transforms/v2/functional/_utils.py

Lines changed: 23 additions & 23 deletions
@@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
 
 
-# {dispatcher: {input_type: type_specific_kernel}}
+# {functional: {input_type: type_specific_kernel}}
 _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
 
 
@@ -27,10 +27,10 @@ def wrapper(inpt, *args, **kwargs):
     return wrapper
 
 
-def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
-    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
+    registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
-        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
+        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")
 
     def decorator(kernel):
         registry[input_type] = (
@@ -43,14 +43,14 @@ def decorator(kernel):
     return decorator
 
 
-def _name_to_dispatcher(name):
+def _name_to_functional(name):
     import torchvision.transforms.v2.functional  # noqa
 
     try:
         return getattr(torchvision.transforms.v2.functional, name)
     except AttributeError:
         raise ValueError(
-            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
+            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None
 
 
@@ -59,21 +59,21 @@ def _name_to_dispatcher(name):
 }
 
 
-def register_kernel(dispatcher, datapoint_cls):
-    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
+def register_kernel(functional, datapoint_cls):
+    """Decorate a kernel to register it for a functional and a (custom) datapoint type.
 
     See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
     details.
     """
-    if isinstance(dispatcher, str):
-        dispatcher = _name_to_dispatcher(name=dispatcher)
+    if isinstance(functional, str):
+        functional = _name_to_functional(name=functional)
     elif not (
-        callable(dispatcher)
-        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
+        callable(functional)
+        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
     ):
         raise ValueError(
-            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
-            f"but got {dispatcher}."
+            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
+            f"but got {functional}."
         )
 
     if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
@@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls):
     if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
         raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
 
-    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
+    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)
 
 
-def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
-    registry = _KERNEL_REGISTRY.get(dispatcher)
+def _get_kernel(functional, input_type, *, allow_passthrough=False):
+    registry = _KERNEL_REGISTRY.get(functional)
     if not registry:
-        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
+        raise ValueError(f"No kernel registered for functional {functional.__name__}.")
 
     # In case we have an exact type match, we take a shortcut.
     if input_type in registry:
@@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
         return lambda inpt, *args, **kwargs: inpt
 
     raise TypeError(
-        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
+        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
         f"but got {input_type} instead."
     )
 
 
 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
-# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
-    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+# We could get rid of this by letting _register_kernel_internal take arbitrary functionals rather than wrap_kernel: bool
+def _register_five_ten_crop_kernel_internal(functional, input_type):
+    registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
-        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
+        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")
 
     def wrap(kernel):
         @functools.wraps(kernel)
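
Since this file carries most of the rename, here is a self-contained sketch of the mechanism it implements, without any torchvision imports: a registry keyed first by functional and then by input type, with an MRO fallback and an optional pass-through, mirroring _KERNEL_REGISTRY, _register_kernel_internal, and _get_kernel. The names and the toy Image/hflip types are simplified stand-ins, not the library code.

from typing import Callable, Dict, Type

# {functional: {input_type: type_specific_kernel}}, as in _KERNEL_REGISTRY
_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def register(functional: Callable, input_type: Type):
    """Decorator: register a kernel for (functional, input_type)."""
    registry = _REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"{functional.__name__} already has a kernel for {input_type}.")

    def decorator(kernel):
        registry[input_type] = kernel
        return kernel

    return decorator


def get_kernel(functional: Callable, input_type: Type, *, allow_passthrough: bool = False) -> Callable:
    """Resolve a kernel by exact type first, then by walking the MRO."""
    registry = _REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")
    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt
    raise TypeError(f"{functional.__name__} supports {set(registry)}, got {input_type}.")


# Tiny usage example with made-up types and a made-up "hflip" functional.
class Image(list):
    pass


def hflip(inpt):
    return get_kernel(hflip, type(inpt))(inpt)


@register(hflip, Image)
def hflip_image(img):
    return Image(reversed(img))


print(hflip(Image([1, 2, 3])))  # [3, 2, 1]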
