Skip to content

Commit 4c073b0

Browse files
authored
[proto] Fixed bug in ScaleJitter with params (#6541)
* [proto] Fixed bug in ScaleJitter with params * Updated tests
1 parent 74feb19 commit 4c073b0

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

test/test_prototype_transforms.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,18 +1263,24 @@ def test__get_params(self, mocker):
12631263
scale_range = (0.5, 1.5)
12641264

12651265
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
1266-
12671266
sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size)
1268-
params = transform._get_params(sample)
12691267

1270-
assert "size" in params
1271-
size = params["size"]
1268+
n_samples = 5
1269+
for _ in range(n_samples):
12721270

1273-
assert isinstance(size, tuple) and len(size) == 2
1274-
height, width = size
1271+
params = transform._get_params(sample)
1272+
1273+
assert "size" in params
1274+
size = params["size"]
1275+
1276+
assert isinstance(size, tuple) and len(size) == 2
1277+
height, width = size
1278+
1279+
r_min = min(target_size[1] / image_size[0], target_size[0] / image_size[1]) * scale_range[0]
1280+
r_max = min(target_size[1] / image_size[0], target_size[0] / image_size[1]) * scale_range[1]
12751281

1276-
assert int(target_size[0] * scale_range[0]) <= height <= int(target_size[0] * scale_range[1])
1277-
assert int(target_size[1] * scale_range[0]) <= width <= int(target_size[1] * scale_range[1])
1282+
assert int(image_size[0] * r_min) <= height <= int(image_size[0] * r_max)
1283+
assert int(image_size[1] * r_min) <= width <= int(image_size[1] * r_max)
12781284

12791285
def test__transform(self, mocker):
12801286
interpolation_sentinel = mocker.MagicMock()

torchvision/prototype/transforms/_geometry.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -727,9 +727,10 @@ def __init__(
727727
def _get_params(self, sample: Any) -> Dict[str, Any]:
728728
_, orig_height, orig_width = query_chw(sample)
729729

730-
r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
731-
new_width = int(self.target_size[1] * r)
732-
new_height = int(self.target_size[0] * r)
730+
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
731+
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
732+
new_width = int(orig_width * r)
733+
new_height = int(orig_height * r)
733734

734735
return dict(size=(new_height, new_width))
735736

0 commit comments

Comments
 (0)