From 02203b2d67aa6dd80f77967913c79c3e35482cd9 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 13 May 2023 20:57:05 -0700 Subject: [PATCH 01/10] Add deterministic, pure-Python roi_align implementation See https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266 for discussion and motivation. Signed-off-by: Edward Z. Yang --- test/test_ops.py | 32 ++++++-- torchvision/ops/roi_align.py | 141 ++++++++++++++++++++++++++++++++++- 2 files changed, 163 insertions(+), 10 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5f8f8098c21..5f3d32049b8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,6 +11,7 @@ import torch.fx import torch.nn.functional as F from common_utils import assert_equal, cpu_and_gpu, needs_cuda +from torch.testing._internal.common_utils import DeterministicGuard from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -83,7 +84,7 @@ class RoIOpTester(ABC): @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs): + def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs): x_dtype = self.dtype if x_dtype is None else x_dtype rois_dtype = self.dtype if rois_dtype is None else rois_dtype pool_size = 5 @@ -99,7 +100,8 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar ) pool_h, pool_w = pool_size, pool_size - y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) + with DeterministicGuard(deterministic): + y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) # the following should be true whether we're running an autocast test or not. assert y.dtype == x.dtype gt_y = self.expected_fn( @@ -140,7 +142,8 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_backward(self, seed, device, contiguous): + @pytest.mark.parametrize("deterministic", (False,)) + def test_backward(self, seed, device, contiguous, deterministic): torch.random.manual_seed(seed) pool_size = 2 x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) @@ -155,7 +158,9 @@ def func(z): script_func = self.get_script_fn(rois, pool_size) - gradcheck(func, (x,)) + with DeterministicGuard(deterministic): + gradcheck(func, (x,)) + gradcheck(script_func, (x,)) @needs_cuda @@ -402,21 +407,32 @@ def test_boxes_shape(self): @pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None): + @pytest.mark.parametrize("deterministic", (True, False)) + def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None): + if deterministic and device == "cpu": + pytest.skip("cpu is always deterministic, don't retest") super().test_forward( - device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned + device=device, contiguous=contiguous, deterministic=deterministic, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned ) @needs_cuda @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) - def test_autocast(self, aligned, x_dtype, rois_dtype): + def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward( - torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype + torch.device("cuda"), contiguous=False, deterministic=deterministic, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype ) + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) + def test_backward(self, seed, device, contiguous, deterministic): + super().test_backward(seed, device, contiguous, deterministic) + def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 42e93cca211..83a7e7ee8bb 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -1,16 +1,149 @@ from typing import List, Union import torch +import torch._dynamo import torch.fx from torch import nn, Tensor from torch.jit.annotations import BroadcastingList2 from torch.nn.modules.utils import _pair -from torchvision.extension import _assert_has_ops +from torchvision.extension import _has_ops from ..utils import _log_api_usage_once from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format +# NB: all tensor inputs +def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask): + from functorch.dim import dims + + # deal with inverse element out of feature map boundary + y = y.clamp(min=0) + x = x.clamp(min=0) + y_low = y.int() + x_low = x.int() + y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1) + y_low = torch.where(y_low >= height - 1, height - 1, y_low) + y = torch.where(y_low >= height - 1, y.to(input.dtype), y) + + x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1) + x_low = torch.where(x_low >= width - 1, width - 1, x_low) + x = torch.where(x_low >= width - 1, x.to(input.dtype), x) + + ly = y - y_low + lx = x - x_low + hy = 1. - ly + hx = 1. - lx + + # do bilinear interpolation, but respect the masking! + # TODO: It's possible the masking here is unnecessary if y and + # x were clamped appropriately; hard to tell + def masked_index(y, x): + if ymask is not None: + assert xmask is not None + y = torch.where(ymask, y, 0) + x = torch.where(xmask, x, 0) + return input[roi_batch_ind, c, y, x] + + v1 = masked_index(y_low, x_low) + v2 = masked_index(y_low, x_high) + v3 = masked_index(y_high, x_low) + v4 = masked_index(y_high, x_high) + w1 = hy * hx + w2 = hy * lx + w3 = ly * hx + w4 = ly * lx; + + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + return val + +# TODO: this doesn't actually cache +# TODO: main library should make this easier to do +def maybe_cast(tensor): + if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double: + return tensor.float() + else: + return tensor + +# This is a slow but pure Python and differentiable implementation of +# roi_align. It potentially is a good basis for Inductor compilation +# (but I have not benchmarked it) but today it is solely used for the +# fact that its backwards can be implemented deterministically. +# +# It is transcribed directly off of the roi_align CUDA kernel, see +# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266 +@torch._dynamo.allow_in_graph +def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): + from functorch.dim import dims + + orig_dtype = input.dtype + + input = maybe_cast(input) + rois = maybe_cast(rois) + + _, _, height, width = input.size() + + n, c, ph, pw = dims(4) + ph.size = pooled_height + pw.size = pooled_width + offset_rois = rois[n] + roi_batch_ind = offset_rois[0].int() + offset = 0.5 if aligned else 0.0 + roi_start_w = offset_rois[1] * spatial_scale - offset + roi_start_h = offset_rois[2] * spatial_scale - offset + roi_end_w = offset_rois[3] * spatial_scale - offset + roi_end_h = offset_rois[4] * spatial_scale - offset + + roi_width = roi_end_w - roi_start_w + roi_height = roi_end_h - roi_start_h + if not aligned: + roi_width = torch.clamp(roi_width, min=1.0) + roi_height = torch.clamp(roi_height, min=1.0) + + bin_size_h = roi_height / pooled_height + bin_size_w = roi_width / pooled_width + + exact_sampling = sampling_ratio > 0 + + roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) + roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) + + iy, ix = dims(2) + + if exact_sampling: + count = max(roi_bin_grid_h * roi_bin_grid_w, 1) + iy.size = roi_bin_grid_h + ix.size = roi_bin_grid_w + ymask = None + xmask = None + else: + count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) + # When doing adaptive sampling, the number of samples we need to do + # is data-dependent based on how big the ROIs are. This is a bit + # awkward because first-class dims can't actually handle this. + # So instead, we inefficiently suppose that we needed to sample ALL + # the points and mask out things that turned out to be unnecessary + iy.size = height + ix.size = width + ymask = iy < roi_bin_grid_h + xmask = ix < roi_bin_grid_w + + y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h + x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w + val = _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask) + + # Mask out samples that weren't actually adaptively needed + if not exact_sampling: + val = torch.where(ymask, val, 0) + val = torch.where(xmask, val, 0) + + output = val.sum((iy, ix)) + output /= count + + output = output.to(orig_dtype) + + return output.order(n, c, ph, pw) + + @torch.fx.wrap def roi_align( input: Tensor, @@ -54,12 +187,16 @@ def roi_align( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(roi_align) - _assert_has_ops() check_roi_boxes_shape(boxes) rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) + if not torch.jit.is_scripting(): + if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): + return _roi_align( + input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned + ) return torch.ops.torchvision.roi_align( input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned ) From f114d2753124b8e4ed89e1c16cba3640de11d0f3 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 14 May 2023 05:53:05 -0700 Subject: [PATCH 02/10] Remove expecttest dep Signed-off-by: Edward Z. Yang --- test/test_ops.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5f3d32049b8..8a1fa8e0e08 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,6 @@ import torch.fx import torch.nn.functional as F from common_utils import assert_equal, cpu_and_gpu, needs_cuda -from torch.testing._internal.common_utils import DeterministicGuard from PIL import Image from torch import nn, Tensor from torch.autograd import gradcheck @@ -20,6 +19,26 @@ from torchvision.models.feature_extraction import get_graph_node_names +# Context manager for setting deterministic flag and automatically +# resetting it to its original value +class DeterministicGuard: + def __init__(self, deterministic, *, warn_only=False): + self.deterministic = deterministic + self.warn_only = warn_only + + def __enter__(self): + self.deterministic_restore = torch.are_deterministic_algorithms_enabled() + self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled() + torch.use_deterministic_algorithms( + self.deterministic, + warn_only=self.warn_only) + + def __exit__(self, exception_type, exception_value, traceback): + torch.use_deterministic_algorithms( + self.deterministic_restore, + warn_only=self.warn_only_restore) + + class RoIOpTesterModuleWrapper(nn.Module): def __init__(self, obj): super().__init__() From 9e5cb7bbe34b2eea2a8fad31096acc8fc1c4f66a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 14 May 2023 09:18:08 -0700 Subject: [PATCH 03/10] lintfix Signed-off-by: Edward Z. Yang --- torchvision/ops/roi_align.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 83a7e7ee8bb..0f82633dcfa 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -14,8 +14,6 @@ # NB: all tensor inputs def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask): - from functorch.dim import dims - # deal with inverse element out of feature map boundary y = y.clamp(min=0) x = x.clamp(min=0) From f7612bff3d4cea083b2a0b4a1712e5503cafaece Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 14 May 2023 09:18:35 -0700 Subject: [PATCH 04/10] formatting Signed-off-by: Edward Z. Yang --- test/test_ops.py | 25 ++++++++++++++----------- torchvision/ops/roi_align.py | 12 ++++++------ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 8a1fa8e0e08..879a027ab48 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -29,14 +29,10 @@ def __init__(self, deterministic, *, warn_only=False): def __enter__(self): self.deterministic_restore = torch.are_deterministic_algorithms_enabled() self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled() - torch.use_deterministic_algorithms( - self.deterministic, - warn_only=self.warn_only) + torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only) def __exit__(self, exception_type, exception_value, traceback): - torch.use_deterministic_algorithms( - self.deterministic_restore, - warn_only=self.warn_only_restore) + torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore) class RoIOpTesterModuleWrapper(nn.Module): @@ -408,7 +404,6 @@ def expected_fn( grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) for channel in range(0, n_channels): - val = 0 for iy in range(0, grid_h): y = start_h + (iy + 0.5) * bin_h / grid_h @@ -431,7 +426,12 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, if deterministic and device == "cpu": pytest.skip("cpu is always deterministic, don't retest") super().test_forward( - device=device, contiguous=contiguous, deterministic=deterministic, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned + device=device, + contiguous=contiguous, + deterministic=deterministic, + x_dtype=x_dtype, + rois_dtype=rois_dtype, + aligned=aligned, ) @needs_cuda @@ -442,7 +442,12 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward( - torch.device("cuda"), contiguous=False, deterministic=deterministic, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype + torch.device("cuda"), + contiguous=False, + deterministic=deterministic, + aligned=aligned, + x_dtype=x_dtype, + rois_dtype=rois_dtype, ) @pytest.mark.parametrize("seed", range(10)) @@ -1013,7 +1018,6 @@ def test_compare_cpu_cuda_grads(self, contiguous): weight = init_weight for d in ["cpu", "cuda"]: - out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d)) out.mean().backward() if true_cpu_grads is None: @@ -1409,7 +1413,6 @@ class TestGeneralizedBoxIouLoss: @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_giou_loss(self, dtype, device): - box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) # Identical boxes should have loss of 0 diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 0f82633dcfa..e18353c689d 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -29,8 +29,8 @@ def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, x ly = y - y_low lx = x - x_low - hy = 1. - ly - hx = 1. - lx + hy = 1.0 - ly + hx = 1.0 - lx # do bilinear interpolation, but respect the masking! # TODO: It's possible the masking here is unnecessary if y and @@ -49,11 +49,12 @@ def masked_index(y, x): w1 = hy * hx w2 = hy * lx w3 = ly * hx - w4 = ly * lx; + w4 = ly * lx val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 return val + # TODO: this doesn't actually cache # TODO: main library should make this easier to do def maybe_cast(tensor): @@ -62,6 +63,7 @@ def maybe_cast(tensor): else: return tensor + # This is a slow but pure Python and differentiable implementation of # roi_align. It potentially is a good basis for Inductor compilation # (but I have not benchmarked it) but today it is solely used for the @@ -192,9 +194,7 @@ def roi_align( rois = convert_boxes_to_roi_format(rois) if not torch.jit.is_scripting(): if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): - return _roi_align( - input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned - ) + return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) return torch.ops.torchvision.roi_align( input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned ) From 4d64bc380c6cd8105224d8971df3b1bbb568e532 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 06:32:08 -0700 Subject: [PATCH 05/10] CR comments Signed-off-by: Edward Z. Yang --- test/test_ops.py | 7 +++++-- torchvision/ops/roi_align.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 879a027ab48..417a0c73003 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -157,8 +157,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - @pytest.mark.parametrize("deterministic", (False,)) - def test_backward(self, seed, device, contiguous, deterministic): + def test_backward(self, seed, device, contiguous, deterministic=False): torch.random.manual_seed(seed) pool_size = 2 x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) @@ -440,6 +439,8 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): + if deterministic and device == "cpu": + pytest.skip("cpu is always deterministic, don't retest") with torch.cuda.amp.autocast(): self.test_forward( torch.device("cuda"), @@ -455,6 +456,8 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("deterministic", (True, False)) def test_backward(self, seed, device, contiguous, deterministic): + if deterministic and device == "cpu": + pytest.skip("cpu is always deterministic, don't retest") super().test_backward(seed, device, contiguous, deterministic) def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index e18353c689d..58d7e5710df 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -12,7 +12,7 @@ from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format -# NB: all tensor inputs +# NB: all inputs are tensors def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask): # deal with inverse element out of feature map boundary y = y.clamp(min=0) @@ -67,7 +67,8 @@ def maybe_cast(tensor): # This is a slow but pure Python and differentiable implementation of # roi_align. It potentially is a good basis for Inductor compilation # (but I have not benchmarked it) but today it is solely used for the -# fact that its backwards can be implemented deterministically. +# fact that its backwards can be implemented deterministically, +# which is needed for the PT2 benchmark suite. # # It is transcribed directly off of the roi_align CUDA kernel, see # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266 @@ -195,6 +196,7 @@ def roi_align( if not torch.jit.is_scripting(): if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) + assert _assert_has_ops() return torch.ops.torchvision.roi_align( input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned ) From 2121bb0d7201e4906a26d25d0ef8ac30b94afbc8 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 08:13:26 -0700 Subject: [PATCH 06/10] fix import braino Signed-off-by: Edward Z. Yang --- torchvision/ops/roi_align.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 58d7e5710df..4f849e47ac7 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from torch.jit.annotations import BroadcastingList2 from torch.nn.modules.utils import _pair -from torchvision.extension import _has_ops +from torchvision.extension import _assert_has_ops, _has_ops from ..utils import _log_api_usage_once from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format From 0b51bd57e213980a7591add65ff4e5a75d0307e9 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 08:15:08 -0700 Subject: [PATCH 07/10] fix another braino Signed-off-by: Edward Z. Yang --- torchvision/ops/roi_align.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 4f849e47ac7..df070751a6c 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -196,7 +196,7 @@ def roi_align( if not torch.jit.is_scripting(): if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) - assert _assert_has_ops() + _assert_has_ops() return torch.ops.torchvision.roi_align( input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned ) From 296bca682c8fd7fc9b91c0f38e0a475ec68b0863 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 10:18:23 -0700 Subject: [PATCH 08/10] Convert the code to stop using first class dims Signed-off-by: Edward Z. Yang --- torchvision/ops/roi_align.py | 120 +++++++++++++++++++++++------------ 1 file changed, 78 insertions(+), 42 deletions(-) diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index df070751a6c..53be85c5b4c 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -13,7 +13,16 @@ # NB: all inputs are tensors -def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask): +def _bilinear_interpolate( + input, # [N, C, H, W] + roi_batch_ind, # [K] + y, # [K, PH, IY] + x, # [K, PW, IX] + ymask, # [IY] + xmask, # [IX] +): + _, channels, height, width = input.size() + # deal with inverse element out of feature map boundary y = y.clamp(min=0) x = x.clamp(min=0) @@ -35,21 +44,35 @@ def _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, x # do bilinear interpolation, but respect the masking! # TODO: It's possible the masking here is unnecessary if y and # x were clamped appropriately; hard to tell - def masked_index(y, x): + def masked_index( + y, # [K, PH, IY] + x, # [K, PW, IX] + ): if ymask is not None: assert xmask is not None - y = torch.where(ymask, y, 0) - x = torch.where(xmask, x, 0) - return input[roi_batch_ind, c, y, x] + y = torch.where(ymask[None, None, :], y, 0) + x = torch.where(xmask[None, None, :], x, 0) + return input[ + roi_batch_ind[:, None, None, None, None, None], + torch.arange(channels, device=input.device)[None, :, None, None, None, None], + y[:, None, :, None, :, None], # prev [K, PH, IY] + x[:, None, None, :, None, :], # prev [K, PW, IX] + ] # [K, C, PH, PW, IY, IX] v1 = masked_index(y_low, x_low) v2 = masked_index(y_low, x_high) v3 = masked_index(y_high, x_low) v4 = masked_index(y_high, x_high) - w1 = hy * hx - w2 = hy * lx - w3 = ly * hx - w4 = ly * lx + # all ws preemptively [K, C, PH, PW, IY, IX] + def outer_prod(y, x): + return ( + y[:, None, :, None, :, None] * + x[:, None, None, :, None, :] + ) + w1 = outer_prod(hy, hx) + w2 = outer_prod(hy, lx) + w3 = outer_prod(ly, hx) + w4 = outer_prod(ly, lx) val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 return val @@ -74,8 +97,6 @@ def maybe_cast(tensor): # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266 @torch._dynamo.allow_in_graph def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): - from functorch.dim import dims - orig_dtype = input.dtype input = maybe_cast(input) @@ -83,37 +104,41 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling _, _, height, width = input.size() - n, c, ph, pw = dims(4) - ph.size = pooled_height - pw.size = pooled_width - offset_rois = rois[n] - roi_batch_ind = offset_rois[0].int() + ph = torch.arange(pooled_height, device=input.device) # [PH] + pw = torch.arange(pooled_width, device=input.device) # [PW] + + # input: [N, C, H, W] + # rois: [K, 5] + + roi_batch_ind = rois[:, 0].int() # [K] offset = 0.5 if aligned else 0.0 - roi_start_w = offset_rois[1] * spatial_scale - offset - roi_start_h = offset_rois[2] * spatial_scale - offset - roi_end_w = offset_rois[3] * spatial_scale - offset - roi_end_h = offset_rois[4] * spatial_scale - offset + roi_start_w = rois[:, 1] * spatial_scale - offset # [K] + roi_start_h = rois[:, 2] * spatial_scale - offset # [K] + roi_end_w = rois[:, 3] * spatial_scale - offset # [K] + roi_end_h = rois[:, 4] * spatial_scale - offset # [K] - roi_width = roi_end_w - roi_start_w - roi_height = roi_end_h - roi_start_h + roi_width = roi_end_w - roi_start_w # [K] + roi_height = roi_end_h - roi_start_h # [K] if not aligned: - roi_width = torch.clamp(roi_width, min=1.0) - roi_height = torch.clamp(roi_height, min=1.0) + roi_width = torch.clamp(roi_width, min=1.0) # [K] + roi_height = torch.clamp(roi_height, min=1.0) # [K] - bin_size_h = roi_height / pooled_height - bin_size_w = roi_width / pooled_width + bin_size_h = roi_height / pooled_height # [K] + bin_size_w = roi_width / pooled_width # [K] exact_sampling = sampling_ratio > 0 - roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) - roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) + roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) # scalar or [K] + roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) # scalar or [K] + """ iy, ix = dims(2) + """ if exact_sampling: count = max(roi_bin_grid_h * roi_bin_grid_w, 1) - iy.size = roi_bin_grid_h - ix.size = roi_bin_grid_w + iy = torch.arange(roi_bin_grid_h, device=input.device) # [IY] + ix = torch.arange(roi_bin_grid_w, device=input.device) # [IX] ymask = None xmask = None else: @@ -123,26 +148,37 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling # awkward because first-class dims can't actually handle this. # So instead, we inefficiently suppose that we needed to sample ALL # the points and mask out things that turned out to be unnecessary - iy.size = height - ix.size = width - ymask = iy < roi_bin_grid_h - xmask = ix < roi_bin_grid_w - - y = roi_start_h + ph * bin_size_h + (iy + 0.5) * bin_size_h / roi_bin_grid_h - x = roi_start_w + pw * bin_size_w + (ix + 0.5) * bin_size_w / roi_bin_grid_w - val = _bilinear_interpolate(input, roi_batch_ind, c, height, width, y, x, ymask, xmask) + iy = torch.arange(height, device=input.device) # [IY] + ix = torch.arange(width, device=input.device) # [IX] + ymask = iy < roi_bin_grid_h # [IY] + xmask = ix < roi_bin_grid_w # [IX] + + def from_K(t): + return t[:, None, None] + + y = ( + from_K(roi_start_h) + + ph[None, :, None] * from_K(bin_size_h) + + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + ) # [K, PH, IY] + x = ( + from_K(roi_start_w) + + pw[None, :, None] * from_K(bin_size_w) + + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + ) # [K, PW, IX] + val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX] # Mask out samples that weren't actually adaptively needed if not exact_sampling: - val = torch.where(ymask, val, 0) - val = torch.where(xmask, val, 0) + val = torch.where(ymask[None, None, None, None, :, None], val, 0) + val = torch.where(xmask[None, None, None, None, None, :], val, 0) - output = val.sum((iy, ix)) + output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW] output /= count output = output.to(orig_dtype) - return output.order(n, c, ph, pw) + return output @torch.fx.wrap From 41c1ff66e23e336433884b7e25ed318a883b40a8 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 12:33:00 -0700 Subject: [PATCH 09/10] Minor fixups Signed-off-by: Edward Z. Yang --- test/test_ops.py | 2 -- torchvision/ops/roi_align.py | 45 ++++++++++++++++++------------------ 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 417a0c73003..463ebb333ff 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -439,8 +439,6 @@ def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): - if deterministic and device == "cpu": - pytest.skip("cpu is always deterministic, don't retest") with torch.cuda.amp.autocast(): self.test_forward( torch.device("cuda"), diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 53be85c5b4c..a3c90421de1 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -18,8 +18,8 @@ def _bilinear_interpolate( roi_batch_ind, # [K] y, # [K, PH, IY] x, # [K, PW, IX] - ymask, # [IY] - xmask, # [IX] + ymask, # [K, IY] + xmask, # [K, IX] ): _, channels, height, width = input.size() @@ -50,8 +50,8 @@ def masked_index( ): if ymask is not None: assert xmask is not None - y = torch.where(ymask[None, None, :], y, 0) - x = torch.where(xmask[None, None, :], x, 0) + y = torch.where(ymask[:, None, :], y, 0) + x = torch.where(xmask[:, None, :], x, 0) return input[ roi_batch_ind[:, None, None, None, None, None], torch.arange(channels, device=input.device)[None, :, None, None, None, None], @@ -63,12 +63,11 @@ def masked_index( v2 = masked_index(y_low, x_high) v3 = masked_index(y_high, x_low) v4 = masked_index(y_high, x_high) + # all ws preemptively [K, C, PH, PW, IY, IX] def outer_prod(y, x): - return ( - y[:, None, :, None, :, None] * - x[:, None, None, :, None, :] - ) + return y[:, None, :, None, :, None] * x[:, None, None, :, None, :] + w1 = outer_prod(hy, hx) w2 = outer_prod(hy, lx) w3 = outer_prod(ly, hx) @@ -136,13 +135,13 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling """ if exact_sampling: - count = max(roi_bin_grid_h * roi_bin_grid_w, 1) + count = max(roi_bin_grid_h * roi_bin_grid_w, 1) # scalar iy = torch.arange(roi_bin_grid_h, device=input.device) # [IY] ix = torch.arange(roi_bin_grid_w, device=input.device) # [IX] ymask = None xmask = None else: - count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) + count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) # [K] # When doing adaptive sampling, the number of samples we need to do # is data-dependent based on how big the ROIs are. This is a bit # awkward because first-class dims can't actually handle this. @@ -150,31 +149,31 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling # the points and mask out things that turned out to be unnecessary iy = torch.arange(height, device=input.device) # [IY] ix = torch.arange(width, device=input.device) # [IX] - ymask = iy < roi_bin_grid_h # [IY] - xmask = ix < roi_bin_grid_w # [IX] + ymask = iy[None, :] < roi_bin_grid_h[:, None] # [K, IY] + xmask = ix[None, :] < roi_bin_grid_w[:, None] # [K, IX] def from_K(t): return t[:, None, None] y = ( - from_K(roi_start_h) + - ph[None, :, None] * from_K(bin_size_h) + - (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) - ) # [K, PH, IY] + from_K(roi_start_h) + + ph[None, :, None] * from_K(bin_size_h) + + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + ) # [K, PH, IY] x = ( - from_K(roi_start_w) + - pw[None, :, None] * from_K(bin_size_w) + - (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) - ) # [K, PW, IX] + from_K(roi_start_w) + + pw[None, :, None] * from_K(bin_size_w) + + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + ) # [K, PW, IX] val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX] # Mask out samples that weren't actually adaptively needed if not exact_sampling: - val = torch.where(ymask[None, None, None, None, :, None], val, 0) - val = torch.where(xmask[None, None, None, None, None, :], val, 0) + val = torch.where(ymask[:, None, None, None, :, None], val, 0) + val = torch.where(xmask[:, None, None, None, None, :], val, 0) output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW] - output /= count + output /= count[:, None, None, None] output = output.to(orig_dtype) From d7f4e80cfd4c229adfa960e65ffcbd80261adedc Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 15 May 2023 12:36:40 -0700 Subject: [PATCH 10/10] one more fix Signed-off-by: Edward Z. Yang --- torchvision/ops/roi_align.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index a3c90421de1..be8ec8aea74 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -173,7 +173,10 @@ def from_K(t): val = torch.where(xmask[:, None, None, None, None, :], val, 0) output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW] - output /= count[:, None, None, None] + if isinstance(count, torch.Tensor): + output /= count[:, None, None, None] + else: + output /= count output = output.to(orig_dtype)