Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
legacy_transforms.Resize,
[
ArgsKwargs(32),
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
Expand Down
15 changes: 10 additions & 5 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ def __init__(
) -> None:
super().__init__()

self.size = (
[size]
if isinstance(size, int)
else _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
)
if isinstance(size, int):
size = [size]
elif isinstance(size, (list, tuple)) and len(size) in {1, 2}:
size = list(size)
else:
raise ValueError(
f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead."
)
self.size = size

self.interpolation = interpolation
self.max_size = max_size
self.antialias = antialias
Expand Down