Skip to content

Commit c095522

Browse files
authored
Merge branch 'main' into cxx-14-17
2 parents 20dd6a6 + a1ec864 commit c095522

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

torchvision/models/detection/transform.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float:
2424

2525
def _resize_image_and_masks(
2626
image: Tensor,
27-
self_min_size: float,
28-
self_max_size: float,
27+
self_min_size: int,
28+
self_max_size: int,
2929
target: Optional[Dict[str, Tensor]] = None,
3030
fixed_size: Optional[Tuple[int, int]] = None,
3131
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
@@ -40,14 +40,24 @@ def _resize_image_and_masks(
4040
if fixed_size is not None:
4141
size = [fixed_size[1], fixed_size[0]]
4242
else:
43-
min_size = torch.min(im_shape).to(dtype=torch.float32)
44-
max_size = torch.max(im_shape).to(dtype=torch.float32)
45-
scale = torch.min(self_min_size / min_size, self_max_size / max_size)
43+
if torch.jit.is_scripting() or torchvision._is_tracing():
44+
min_size = torch.min(im_shape).to(dtype=torch.float32)
45+
max_size = torch.max(im_shape).to(dtype=torch.float32)
46+
self_min_size_f = float(self_min_size)
47+
self_max_size_f = float(self_max_size)
48+
scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
49+
50+
if torchvision._is_tracing():
51+
scale_factor = _fake_cast_onnx(scale)
52+
else:
53+
scale_factor = scale.item()
4654

47-
if torchvision._is_tracing():
48-
scale_factor = _fake_cast_onnx(scale)
4955
else:
50-
scale_factor = scale.item()
56+
# Do it the normal way
57+
min_size = min(im_shape)
58+
max_size = max(im_shape)
59+
scale_factor = min(self_min_size / min_size, self_max_size / max_size)
60+
5161
recompute_scale_factor = True
5262

5363
image = torch.nn.functional.interpolate(
@@ -159,8 +169,7 @@ def normalize(self, image: Tensor) -> Tensor:
159169
def torch_choice(self, k: List[int]) -> int:
160170
"""
161171
Implements `random.choice` via torch ops, so it can be compiled with
162-
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
163-
is fixed.
172+
TorchScript and we use PyTorch's RNG (not native RNG)
164173
"""
165174
index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
166175
return k[index]
@@ -174,11 +183,10 @@ def resize(
174183
if self.training:
175184
if self._skip_resize:
176185
return image, target
177-
size = float(self.torch_choice(self.min_size))
186+
size = self.torch_choice(self.min_size)
178187
else:
179-
# FIXME assume for now that testing uses the largest scale
180-
size = float(self.min_size[-1])
181-
image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)
188+
size = self.min_size[-1]
189+
image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
182190

183191
if target is None:
184192
return image, target

torchvision/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def make_grid(
2929
value_range: Optional[Tuple[int, int]] = None,
3030
scale_each: bool = False,
3131
pad_value: float = 0.0,
32-
**kwargs,
3332
) -> torch.Tensor:
3433
"""
3534
Make a grid of images.

0 commit comments

Comments
 (0)