diff --git a/test/test_ops.py b/test/test_ops.py index 5f8f8098c21..463ebb333ff 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -19,6 +19,22 @@ from torchvision.models.feature_extraction import get_graph_node_names +# Context manager for setting deterministic flag and automatically +# resetting it to its original value +class DeterministicGuard: + def __init__(self, deterministic, *, warn_only=False): + self.deterministic = deterministic + self.warn_only = warn_only + + def __enter__(self): + self.deterministic_restore = torch.are_deterministic_algorithms_enabled() + self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled() + torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only) + + def __exit__(self, exception_type, exception_value, traceback): + torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore) + + class RoIOpTesterModuleWrapper(nn.Module): def __init__(self, obj): super().__init__() @@ -83,7 +99,7 @@ class RoIOpTester(ABC): @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs): + def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs): x_dtype = self.dtype if x_dtype is None else x_dtype rois_dtype = self.dtype if rois_dtype is None else rois_dtype pool_size = 5 @@ -99,7 +115,8 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar ) pool_h, pool_w = pool_size, pool_size - y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) + with DeterministicGuard(deterministic): + y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) # the following should be true whether we're running an autocast test or not. assert y.dtype == x.dtype gt_y = self.expected_fn( @@ -140,7 +157,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa @pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_backward(self, seed, device, contiguous): + def test_backward(self, seed, device, contiguous, deterministic=False): torch.random.manual_seed(seed) pool_size = 2 x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) @@ -155,7 +172,9 @@ def func(z): script_func = self.get_script_fn(rois, pool_size) - gradcheck(func, (x,)) + with DeterministicGuard(deterministic): + gradcheck(func, (x,)) + gradcheck(script_func, (x,)) @needs_cuda @@ -384,7 +403,6 @@ def expected_fn( grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) for channel in range(0, n_channels): - val = 0 for iy in range(0, grid_h): y = start_h + (iy + 0.5) * bin_h / grid_h @@ -402,21 +420,44 @@ def test_boxes_shape(self): @pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("contiguous", (True, False)) - def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None): + @pytest.mark.parametrize("deterministic", (True, False)) + def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None): + if deterministic and device == "cpu": + pytest.skip("cpu is always deterministic, don't retest") super().test_forward( - device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned + device=device, + contiguous=contiguous, + deterministic=deterministic, + x_dtype=x_dtype, + rois_dtype=rois_dtype, + aligned=aligned, ) @needs_cuda @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) - def test_autocast(self, aligned, x_dtype, rois_dtype): + def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward( - torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype + torch.device("cuda"), + contiguous=False, + deterministic=deterministic, + aligned=aligned, + x_dtype=x_dtype, + rois_dtype=rois_dtype, ) + @pytest.mark.parametrize("seed", range(10)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("deterministic", (True, False)) + def test_backward(self, seed, device, contiguous, deterministic): + if deterministic and device == "cpu": + pytest.skip("cpu is always deterministic, don't retest") + super().test_backward(seed, device, contiguous, deterministic) + def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index @@ -978,7 +1019,6 @@ def test_compare_cpu_cuda_grads(self, contiguous): weight = init_weight for d in ["cpu", "cuda"]: - out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d)) out.mean().backward() if true_cpu_grads is None: @@ -1374,7 +1414,6 @@ class TestGeneralizedBoxIouLoss: @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_giou_loss(self, dtype, device): - box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) # Identical boxes should have loss of 0 diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 42e93cca211..be8ec8aea74 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -1,16 +1,188 @@ from typing import List, Union import torch +import torch._dynamo import torch.fx from torch import nn, Tensor from torch.jit.annotations import BroadcastingList2 from torch.nn.modules.utils import _pair -from torchvision.extension import _assert_has_ops +from torchvision.extension import _assert_has_ops, _has_ops from ..utils import _log_api_usage_once from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format +# NB: all inputs are tensors +def _bilinear_interpolate( + input, # [N, C, H, W] + roi_batch_ind, # [K] + y, # [K, PH, IY] + x, # [K, PW, IX] + ymask, # [K, IY] + xmask, # [K, IX] +): + _, channels, height, width = input.size() + + # deal with inverse element out of feature map boundary + y = y.clamp(min=0) + x = x.clamp(min=0) + y_low = y.int() + x_low = x.int() + y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1) + y_low = torch.where(y_low >= height - 1, height - 1, y_low) + y = torch.where(y_low >= height - 1, y.to(input.dtype), y) + + x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1) + x_low = torch.where(x_low >= width - 1, width - 1, x_low) + x = torch.where(x_low >= width - 1, x.to(input.dtype), x) + + ly = y - y_low + lx = x - x_low + hy = 1.0 - ly + hx = 1.0 - lx + + # do bilinear interpolation, but respect the masking! + # TODO: It's possible the masking here is unnecessary if y and + # x were clamped appropriately; hard to tell + def masked_index( + y, # [K, PH, IY] + x, # [K, PW, IX] + ): + if ymask is not None: + assert xmask is not None + y = torch.where(ymask[:, None, :], y, 0) + x = torch.where(xmask[:, None, :], x, 0) + return input[ + roi_batch_ind[:, None, None, None, None, None], + torch.arange(channels, device=input.device)[None, :, None, None, None, None], + y[:, None, :, None, :, None], # prev [K, PH, IY] + x[:, None, None, :, None, :], # prev [K, PW, IX] + ] # [K, C, PH, PW, IY, IX] + + v1 = masked_index(y_low, x_low) + v2 = masked_index(y_low, x_high) + v3 = masked_index(y_high, x_low) + v4 = masked_index(y_high, x_high) + + # all ws preemptively [K, C, PH, PW, IY, IX] + def outer_prod(y, x): + return y[:, None, :, None, :, None] * x[:, None, None, :, None, :] + + w1 = outer_prod(hy, hx) + w2 = outer_prod(hy, lx) + w3 = outer_prod(ly, hx) + w4 = outer_prod(ly, lx) + + val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + return val + + +# TODO: this doesn't actually cache +# TODO: main library should make this easier to do +def maybe_cast(tensor): + if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double: + return tensor.float() + else: + return tensor + + +# This is a slow but pure Python and differentiable implementation of +# roi_align. It potentially is a good basis for Inductor compilation +# (but I have not benchmarked it) but today it is solely used for the +# fact that its backwards can be implemented deterministically, +# which is needed for the PT2 benchmark suite. +# +# It is transcribed directly off of the roi_align CUDA kernel, see +# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266 +@torch._dynamo.allow_in_graph +def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): + orig_dtype = input.dtype + + input = maybe_cast(input) + rois = maybe_cast(rois) + + _, _, height, width = input.size() + + ph = torch.arange(pooled_height, device=input.device) # [PH] + pw = torch.arange(pooled_width, device=input.device) # [PW] + + # input: [N, C, H, W] + # rois: [K, 5] + + roi_batch_ind = rois[:, 0].int() # [K] + offset = 0.5 if aligned else 0.0 + roi_start_w = rois[:, 1] * spatial_scale - offset # [K] + roi_start_h = rois[:, 2] * spatial_scale - offset # [K] + roi_end_w = rois[:, 3] * spatial_scale - offset # [K] + roi_end_h = rois[:, 4] * spatial_scale - offset # [K] + + roi_width = roi_end_w - roi_start_w # [K] + roi_height = roi_end_h - roi_start_h # [K] + if not aligned: + roi_width = torch.clamp(roi_width, min=1.0) # [K] + roi_height = torch.clamp(roi_height, min=1.0) # [K] + + bin_size_h = roi_height / pooled_height # [K] + bin_size_w = roi_width / pooled_width # [K] + + exact_sampling = sampling_ratio > 0 + + roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) # scalar or [K] + roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) # scalar or [K] + + """ + iy, ix = dims(2) + """ + + if exact_sampling: + count = max(roi_bin_grid_h * roi_bin_grid_w, 1) # scalar + iy = torch.arange(roi_bin_grid_h, device=input.device) # [IY] + ix = torch.arange(roi_bin_grid_w, device=input.device) # [IX] + ymask = None + xmask = None + else: + count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) # [K] + # When doing adaptive sampling, the number of samples we need to do + # is data-dependent based on how big the ROIs are. This is a bit + # awkward because first-class dims can't actually handle this. + # So instead, we inefficiently suppose that we needed to sample ALL + # the points and mask out things that turned out to be unnecessary + iy = torch.arange(height, device=input.device) # [IY] + ix = torch.arange(width, device=input.device) # [IX] + ymask = iy[None, :] < roi_bin_grid_h[:, None] # [K, IY] + xmask = ix[None, :] < roi_bin_grid_w[:, None] # [K, IX] + + def from_K(t): + return t[:, None, None] + + y = ( + from_K(roi_start_h) + + ph[None, :, None] * from_K(bin_size_h) + + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + ) # [K, PH, IY] + x = ( + from_K(roi_start_w) + + pw[None, :, None] * from_K(bin_size_w) + + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + ) # [K, PW, IX] + val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX] + + # Mask out samples that weren't actually adaptively needed + if not exact_sampling: + val = torch.where(ymask[:, None, None, None, :, None], val, 0) + val = torch.where(xmask[:, None, None, None, None, :], val, 0) + + output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW] + if isinstance(count, torch.Tensor): + output /= count[:, None, None, None] + else: + output /= count + + output = output.to(orig_dtype) + + return output + + @torch.fx.wrap def roi_align( input: Tensor, @@ -54,12 +226,15 @@ def roi_align( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(roi_align) - _assert_has_ops() check_roi_boxes_shape(boxes) rois = boxes output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) + if not torch.jit.is_scripting(): + if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda): + return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned) + _assert_has_ops() return torch.ops.torchvision.roi_align( input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned )