Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@
from torchvision.models.feature_extraction import get_graph_node_names


# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
    """Temporarily set PyTorch's deterministic-algorithms flags.

    On entry the current values of ``torch.are_deterministic_algorithms_enabled()``
    and ``torch.is_deterministic_algorithms_warn_only_enabled()`` are saved and the
    requested settings are applied. On exit the saved values are re-applied, whether
    or not the body raised.

    Args:
        deterministic: value passed to ``torch.use_deterministic_algorithms``.
        warn_only: keyword-only; forwarded as its ``warn_only`` argument.
    """

    def __init__(self, deterministic, *, warn_only=False):
        self.deterministic = deterministic
        self.warn_only = warn_only

    def __enter__(self):
        # Snapshot the global state first so __exit__ can put it back.
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
        # Return the guard so that ``with DeterministicGuard(...) as guard`` binds it.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None (falsy), so an exception raised in the body still propagates.
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)


class RoIOpTesterModuleWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
Expand Down Expand Up @@ -83,7 +99,7 @@ class RoIOpTester(ABC):

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
pool_size = 5
Expand All @@ -99,7 +115,8 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar
)

pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
with DeterministicGuard(deterministic):
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
# the following should be true whether we're running an autocast test or not.
assert y.dtype == x.dtype
gt_y = self.expected_fn(
Expand Down Expand Up @@ -140,7 +157,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous):
def test_backward(self, seed, device, contiguous, deterministic=False):
torch.random.manual_seed(seed)
pool_size = 2
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
Expand All @@ -155,7 +172,9 @@ def func(z):

script_func = self.get_script_fn(rois, pool_size)

gradcheck(func, (x,))
with DeterministicGuard(deterministic):
gradcheck(func, (x,))

gradcheck(script_func, (x,))

@needs_cuda
Expand Down Expand Up @@ -384,7 +403,6 @@ def expected_fn(
grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))

for channel in range(0, n_channels):

val = 0
for iy in range(0, grid_h):
y = start_h + (iy + 0.5) * bin_h / grid_h
Expand All @@ -402,21 +420,44 @@ def test_boxes_shape(self):
@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None):
@pytest.mark.parametrize("deterministic", (True, False))
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
if deterministic and device == "cpu":
pytest.skip("cpu is always deterministic, don't retest")
super().test_forward(
device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
device=device,
contiguous=contiguous,
deterministic=deterministic,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
aligned=aligned,
)

@needs_cuda
@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
def test_autocast(self, aligned, x_dtype, rois_dtype):
def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
with torch.cuda.amp.autocast():
self.test_forward(
torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
torch.device("cuda"),
contiguous=False,
deterministic=deterministic,
aligned=aligned,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_backward(self, seed, device, contiguous, deterministic):
    """Run the inherited backward test with deterministic mode both on and off."""
    # The deterministic variant is not repeated for the "cpu" device.
    redundant_on_cpu = deterministic and device == "cpu"
    if redundant_on_cpu:
        pytest.skip("cpu is always deterministic, don't retest")
    super().test_backward(seed=seed, device=device, contiguous=contiguous, deterministic=deterministic)

def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index
Expand Down Expand Up @@ -978,7 +1019,6 @@ def test_compare_cpu_cuda_grads(self, contiguous):
weight = init_weight

for d in ["cpu", "cuda"]:

out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
out.mean().backward()
if true_cpu_grads is None:
Expand Down Expand Up @@ -1374,7 +1414,6 @@ class TestGeneralizedBoxIouLoss:
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_giou_loss(self, dtype, device):

box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

# Identical boxes should have loss of 0
Expand Down
179 changes: 177 additions & 2 deletions torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,188 @@
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from torchvision.extension import _assert_has_ops, _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


# NB: all inputs are tensors
def _bilinear_interpolate(
input, # [N, C, H, W]
roi_batch_ind, # [K]
y, # [K, PH, IY]
x, # [K, PW, IX]
ymask, # [K, IY]
xmask, # [K, IX]
):
_, channels, height, width = input.size()

# deal with inverse element out of feature map boundary
y = y.clamp(min=0)
x = x.clamp(min=0)
y_low = y.int()
x_low = x.int()
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = torch.where(x_low >= width - 1, width - 1, x_low)
x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx

# do bilinear interpolation, but respect the masking!
# TODO: It's possible the masking here is unnecessary if y and
# x were clamped appropriately; hard to tell
def masked_index(
y, # [K, PH, IY]
x, # [K, PW, IX]
):
if ymask is not None:
assert xmask is not None
y = torch.where(ymask[:, None, :], y, 0)
x = torch.where(xmask[:, None, :], x, 0)
return input[
roi_batch_ind[:, None, None, None, None, None],
torch.arange(channels, device=input.device)[None, :, None, None, None, None],
y[:, None, :, None, :, None], # prev [K, PH, IY]
x[:, None, None, :, None, :], # prev [K, PW, IX]
] # [K, C, PH, PW, IY, IX]

v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)

# all ws preemptively [K, C, PH, PW, IY, IX]
def outer_prod(y, x):
return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

w1 = outer_prod(hy, hx)
w2 = outer_prod(hy, lx)
w3 = outer_prod(ly, hx)
w4 = outer_prod(ly, lx)

val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val


# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    """Return ``tensor`` as float32 when autocast is on and it is a non-double CUDA tensor.

    In every other case the tensor is returned unchanged.
    """
    needs_upcast = torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double
    return tensor.float() if needs_upcast else tensor


# This is a slow, pure-Python, differentiable implementation of
# roi_align. It is potentially a good basis for Inductor compilation
# (this has not been benchmarked), but today it is used solely because
# its backward pass can be implemented deterministically, which is
# needed for the PT2 benchmark suite.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Pure-PyTorch RoI Align.

    Args:
        input: feature map, ``[N, C, H, W]``.
        rois: ``[K, 5]``; column 0 is the batch index, columns 1-4 are the box
            corners (x1, y1, x2, y2) before scaling.
        spatial_scale: factor multiplied into the box coordinates.
        pooled_height: number of output rows (PH).
        pooled_width: number of output columns (PW).
        sampling_ratio: samples per bin along each axis when > 0; otherwise the
            count is chosen per ROI as ``ceil(roi_size / pooled_size)``.
        aligned: when true, shift the scaled box by -0.5 and skip the minimum
            size-1 clamp on the ROI extent.

    Returns:
        Tensor of shape ``[K, C, PH, PW]`` in the dtype ``input`` had on entry.
    """
    result_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    # input: [N, C, H, W]
    # rois: [K, 5]
    _, _, height, width = input.size()
    device = input.device

    ph = torch.arange(pooled_height, device=device)  # [PH]
    pw = torch.arange(pooled_width, device=device)  # [PW]

    batch_index = rois[:, 0].int()  # [K]
    shift = 0.5 if aligned else 0.0
    start_w = rois[:, 1] * spatial_scale - shift  # [K]
    start_h = rois[:, 2] * spatial_scale - shift  # [K]
    end_w = rois[:, 3] * spatial_scale - shift  # [K]
    end_h = rois[:, 4] * spatial_scale - shift  # [K]

    extent_w = end_w - start_w  # [K]
    extent_h = end_h - start_h  # [K]
    if not aligned:
        extent_w = torch.clamp(extent_w, min=1.0)  # [K]
        extent_h = torch.clamp(extent_h, min=1.0)  # [K]

    bin_h = extent_h / pooled_height  # [K]
    bin_w = extent_w / pooled_width  # [K]

    fixed_grid = sampling_ratio > 0

    if fixed_grid:
        # Same number of samples for every ROI: plain Python scalars.
        grid_h = sampling_ratio
        grid_w = sampling_ratio
        count = max(grid_h * grid_w, 1)  # scalar
        iy = torch.arange(grid_h, device=device)  # [IY]
        ix = torch.arange(grid_w, device=device)  # [IX]
        ymask = None
        xmask = None
    else:
        grid_h = torch.ceil(extent_h / pooled_height)  # [K]
        grid_w = torch.ceil(extent_w / pooled_width)  # [K]
        count = torch.clamp(grid_h * grid_w, min=1)  # [K]
        # The number of samples is data-dependent here, so every candidate
        # sample index up to the feature-map size is generated and the ones a
        # given ROI does not need are masked out afterwards.
        # NOTE(review): only `height` / `width` candidates exist, so a ROI whose
        # adaptive grid exceeds that is sampled with fewer points than `count`
        # — confirm this is acceptable for callers.
        iy = torch.arange(height, device=device)  # [IY]
        ix = torch.arange(width, device=device)  # [IX]
        ymask = iy[None, :] < grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < grid_w[:, None]  # [K, IX]

    # Sample coordinates: ROI origin + bin offset + sub-bin sample centre.
    y = (
        start_h[:, None, None]
        + ph[None, :, None] * bin_h[:, None, None]
        + (iy[None, None, :] + 0.5) * (bin_h / grid_h)[:, None, None]
    )  # [K, PH, IY]
    x = (
        start_w[:, None, None]
        + pw[None, :, None] * bin_w[:, None, None]
        + (ix[None, None, :] + 0.5) * (bin_w / grid_w)[:, None, None]
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, batch_index, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not fixed_grid:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    # Average over the samples; `count` is per-ROI in the adaptive case.
    divisor = count[:, None, None, None] if isinstance(count, torch.Tensor) else count
    output /= divisor

    return output.to(result_dtype)


@torch.fx.wrap
def roi_align(
input: Tensor,
Expand Down Expand Up @@ -54,12 +226,15 @@ def roi_align(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(roi_align)
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
if not torch.jit.is_scripting():
if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
_assert_has_ops()
return torch.ops.torchvision.roi_align(
input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
)
Expand Down