Skip to content

Commit 11268ca

Browse files
KsenijaSdatumbox
andauthored
[ONNX] Fix roi_align ONNX export (#3355)
* add tests * fix bug * remove tests * fix comment * fix comment * add warning * fix syntax error * fix python lint Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 9bccd5a commit 11268ca

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

test/test_onnx.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ def test_roi_align(self):
129129
model = ops.RoIAlign((5, 5), 1, 2)
130130
self.run_model(model, [(x, single_roi)])
131131

132+
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
133+
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
134+
model = ops.RoIAlign((5, 5), 1, -1)
135+
self.run_model(model, [(x, single_roi)])
136+
132137
def test_roi_align_aligned(self):
133138
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
134139
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
@@ -150,6 +155,11 @@ def test_roi_align_aligned(self):
150155
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
151156
self.run_model(model, [(x, single_roi)])
152157

158+
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
159+
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
160+
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
161+
self.run_model(model, [(x, single_roi)])
162+
153163
@unittest.skip # Issue in exporting ROIAlign with aligned = True for malformed boxes
154164
def test_roi_align_malformed_boxes(self):
155165
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)

torchvision/ops/_register_onnx_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampli
2929
" ONNX forces ROIs to be 1x1 or larger.")
3030
scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float)
3131
rois = g.op("Sub", rois, scale)
32+
33+
# ONNX doesn't support negative sampling_ratio
34+
if sampling_ratio < 0:
35+
warnings.warn("ONNX doesn't support negative sampling ratio,"
36+
"therefore is is set to 0 in order to be exported.")
37+
sampling_ratio = 0
3238
return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale,
3339
output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio)
3440

0 commit comments

Comments
 (0)