From 452caaaf9a46f36390f00e7f66066c3b4edfaf22 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 17:45:56 -0700 Subject: [PATCH 1/3] Add non-TS'able _resize_image_and_masks variant with less tensor ops We did some horrible things to _resize_image_and_masks to make it TorchScriptable, and those horrible things cause weird divergences when you send the float computation to a real compiler that is willing to do fastmath optimizations to floating point, see https://github.com/pytorch/pytorch/issues/93598 This PR adds a non TS-goopified version of the operator which doesn't have this problem, since it does the size compute the "normal way" (and consequently, doesn't get fastmath'ified). Signed-off-by: Edward Z. Yang --- torchvision/models/detection/transform.py | 64 ++++++++++++++++++++--- 1 file changed, 57 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 589d5e45bdc..e9b3389415a 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -71,6 +71,47 @@ def _resize_image_and_masks( return image, target +def _resize_image_and_masks_simple( + image: Tensor, + self_min_size: int, + self_max_size: int, + target: Optional[Dict[str, Tensor]] = None, + fixed_size: Optional[Tuple[int, int]] = None, +) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + im_shape = image.shape[-2:] + + size: Optional[List[int]] = None + scale_factor: Optional[float] = None + recompute_scale_factor: Optional[bool] = None + if fixed_size is not None: + size = [fixed_size[1], fixed_size[0]] + else: + min_size = min(im_shape) + max_size = max(im_shape) + scale_factor = min(self_min_size / min_size, self_max_size / max_size) + recompute_scale_factor = True + + image = torch.nn.functional.interpolate( + image[None], + size=size, + scale_factor=scale_factor, + mode="bilinear", + recompute_scale_factor=recompute_scale_factor, + align_corners=False, + )[0] + + if target is None: + return image, target + + if "masks" in target: + mask = target["masks"] + mask = torch.nn.functional.interpolate( + mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor + )[:, 0].byte() + target["masks"] = mask + return image, target + + class GeneralizedRCNNTransform(nn.Module): """ Performs input / target transformation before feeding the data to a GeneralizedRCNN @@ -171,14 +212,23 @@ def resize( target: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: h, w = image.shape[-2:] - if self.training: - if self._skip_resize: - return image, target - size = float(self.torch_choice(self.min_size)) + if not torch.jit.is_scripting(): + if self.training: + if self._skip_resize: + return image, target + size = random.choice(self.min_size) + else: + size = self.min_size[-1] + image, target = _resize_image_and_masks_simple(image, size, self.max_size, target, self.fixed_size) else: - # FIXME assume for now that testing uses the largest scale - size = float(self.min_size[-1]) - image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size) + if self.training: + if self._skip_resize: + return image, target + size = float(self.torch_choice(self.min_size)) + else: + # FIXME assume for now that testing uses the largest scale + size = float(self.min_size[-1]) + image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size) if target is None: return image, target From 8878e0348c9ecf74a17db1b1fa9974300c5f3563 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 16 May 2023 07:06:26 -0700 Subject: [PATCH 2/3] simplify the duplication Signed-off-by: Edward Z. Yang --- torchvision/models/detection/transform.py | 90 ++++++----------------- 1 file changed, 24 insertions(+), 66 deletions(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index e9b3389415a..be02f46ba29 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float: def _resize_image_and_masks( image: Tensor, - self_min_size: float, - self_max_size: float, + self_min_size: int, + self_max_size: int, target: Optional[Dict[str, Tensor]] = None, fixed_size: Optional[Tuple[int, int]] = None, ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: @@ -40,55 +40,24 @@ def _resize_image_and_masks( if fixed_size is not None: size = [fixed_size[1], fixed_size[0]] else: - min_size = torch.min(im_shape).to(dtype=torch.float32) - max_size = torch.max(im_shape).to(dtype=torch.float32) - scale = torch.min(self_min_size / min_size, self_max_size / max_size) + if torch.jit.is_scripting(): + min_size = torch.min(im_shape).to(dtype=torch.float32) + max_size = torch.max(im_shape).to(dtype=torch.float32) + self_min_size_f = float(self_min_size) + self_max_size_f = float(self_max_size) + scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size) + + if torchvision._is_tracing(): + scale_factor = _fake_cast_onnx(scale) + else: + scale_factor = scale.item() - if torchvision._is_tracing(): - scale_factor = _fake_cast_onnx(scale) else: - scale_factor = scale.item() - recompute_scale_factor = True - - image = torch.nn.functional.interpolate( - image[None], - size=size, - scale_factor=scale_factor, - mode="bilinear", - recompute_scale_factor=recompute_scale_factor, - align_corners=False, - )[0] - - if target is None: - return image, target - - if "masks" in target: - mask = target["masks"] - mask = torch.nn.functional.interpolate( - mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor - )[:, 0].byte() - target["masks"] = mask - return image, target - + # Do it the normal way + min_size = min(im_shape) + max_size = max(im_shape) + scale_factor = min(self_min_size / min_size, self_max_size / max_size) -def _resize_image_and_masks_simple( - image: Tensor, - self_min_size: int, - self_max_size: int, - target: Optional[Dict[str, Tensor]] = None, - fixed_size: Optional[Tuple[int, int]] = None, -) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: - im_shape = image.shape[-2:] - - size: Optional[List[int]] = None - scale_factor: Optional[float] = None - recompute_scale_factor: Optional[bool] = None - if fixed_size is not None: - size = [fixed_size[1], fixed_size[0]] - else: - min_size = min(im_shape) - max_size = max(im_shape) - scale_factor = min(self_min_size / min_size, self_max_size / max_size) recompute_scale_factor = True image = torch.nn.functional.interpolate( @@ -200,8 +169,7 @@ def normalize(self, image: Tensor) -> Tensor: def torch_choice(self, k: List[int]) -> int: """ Implements `random.choice` via torch ops, so it can be compiled with - TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 - is fixed. + TorchScript and we use PyTorch's RNG (not native RNG) """ index = int(torch.empty(1).uniform_(0.0, float(len(k))).item()) return k[index] @@ -212,23 +180,13 @@ def resize( target: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: h, w = image.shape[-2:] - if not torch.jit.is_scripting(): - if self.training: - if self._skip_resize: - return image, target - size = random.choice(self.min_size) - else: - size = self.min_size[-1] - image, target = _resize_image_and_masks_simple(image, size, self.max_size, target, self.fixed_size) + if self.training: + if self._skip_resize: + return image, target + size = self.torch_choice(self.min_size) else: - if self.training: - if self._skip_resize: - return image, target - size = float(self.torch_choice(self.min_size)) - else: - # FIXME assume for now that testing uses the largest scale - size = float(self.min_size[-1]) - image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size) + size = self.min_size[-1] + image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size) if target is None: return image, target From d8043a9868dc9b1e1a578d1572ddbfd28f1b21a5 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 19 May 2023 17:49:24 -0700 Subject: [PATCH 3/3] fix onnx Signed-off-by: Edward Z. Yang --- torchvision/models/detection/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index be02f46ba29..658c9e83455 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -40,7 +40,7 @@ def _resize_image_and_masks( if fixed_size is not None: size = [fixed_size[1], fixed_size[0]] else: - if torch.jit.is_scripting(): + if torch.jit.is_scripting() or torchvision._is_tracing(): min_size = torch.min(im_shape).to(dtype=torch.float32) max_size = torch.max(im_shape).to(dtype=torch.float32) self_min_size_f = float(self_min_size)