Skip to content

Commit d6f78ab

Browse files
committed
Added check for index upper bound
1 parent 07f3374 commit d6f78ab

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

test/test_ops.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,20 +304,20 @@ def test_qroialign(self):
304304
pool_size = 5
305305
img_size = 10
306306
n_channels = 2
307-
num_batches = 2
307+
num_imgs = 2
308308
dtype = torch.float
309309

310310
def make_rois(num_rois=1000):
311311
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
312-
rois[:, 0] = torch.randint(0, num_batches, size=(num_rois,)) # set batch index
312+
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index
313313
rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate
314314
return rois
315315

316316
for aligned in (True, False):
317317
for scale, zero_point in ((1, 0), (2, 10), (0.1, 50)):
318318
for qdtype in (torch.qint8, torch.quint8, torch.qint32):
319319

320-
x = torch.randint(50, 100, size=(num_batches, n_channels, img_size, img_size)).to(dtype)
320+
x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
321321
qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)
322322

323323
rois = make_rois()
@@ -364,6 +364,13 @@ def make_rois(num_rois=1000):
364364
t_scale = torch.full_like(abs_diff, fill_value=scale)
365365
self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))
366366

367+
x = torch.randint(50, 100, size=(129, 3, 10, 10)).to(dtype)
368+
qx = torch.quantize_per_tensor(x, scale=0, zero_point=1, dtype=torch.qint8)
369+
rois = make_rois(10)
370+
qrois = torch.quantize_per_tensor(rois, scale=0, zero_point=1, dtype=torch.qint8)
371+
with self.assertRaisesRegex(RuntimeError, "There are 129 input images in the batch, but the RoIs tensor"):
372+
ops.roi_align(qx, qrois, output_size=pool_size)
373+
367374

368375
class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
369376
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):

torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ void qroi_align_forward_kernel_impl(
3636

3737
const T* offset_rois = rois + n * 5;
3838
int roi_batch_ind = at::native::dequantize_val(
39-
rois_scale, rois_zp, offset_rois[0]); // FIXME: This can be out of the
40-
// range of the quantized type!!
39+
rois_scale, rois_zp, offset_rois[0]);
4140

4241
// Do not using rounding; this implementation detail is critical
4342
float offset = aligned ? 0.5 : 0.;
@@ -172,6 +171,16 @@ at::Tensor qroi_align_forward_kernel(
172171
return output;
173172

174173
AT_DISPATCH_QINT_TYPES(input.scalar_type(), "qroi_align_forward_kernel", [&] {
174+
// Note: q_max relates to the input tensor, but we need that of the rois
175+
// tensor. They're the same since we make sure rois and input have the same
176+
// type above.
177+
uint64_t max_indexable = std::numeric_limits<underlying_t>::max() + 1;
178+
std::string err_msg = "There are " + std::to_string(input.size(0)) +
179+
" input images in the batch, but the RoIs tensor can only index up to " +
180+
std::to_string(max_indexable) +
181+
" images. Try to reduce the batch size.";
182+
TORCH_CHECK(input.size(0) <= max_indexable, err_msg);
183+
175184
qroi_align_forward_kernel_impl<scalar_t>(
176185
num_rois,
177186
input,

0 commit comments

Comments
 (0)