Skip to content

Commit cfb8557

Browse files
authored
Revert "Revert "Register Torchvision Ops as Custom Ops (#1267)" (#1316)"
This reverts commit fe234fc.
1 parent fe234fc commit cfb8557

File tree

12 files changed

+199
-13
lines changed

12 files changed

+199
-13
lines changed

.travis.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ before_install:
4747
- pip install future
4848
- pip install pytest pytest-cov codecov
4949
- pip install mock
50+
- |
51+
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
52+
pip install onnxruntime
53+
fi
5054
- conda install av -c conda-forge
5155

5256

setup.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,21 @@ def get_extensions():
9696
source_models = [os.path.join(models_dir, s) for s in source_models]
9797
tests = test_file + source_models
9898

99+
custom_ops_sources = [os.path.join(extensions_dir, "custom_ops", "custom_ops.cpp"),
100+
os.path.join(extensions_dir, "cpu", "nms_cpu.cpp"),
101+
os.path.join(extensions_dir, "cpu", "ROIAlign_cpu.cpp"),
102+
os.path.join(extensions_dir, "cpu", "ROIPool_cpu.cpp")]
103+
custom_ops_sources_cuda = [os.path.join(extensions_dir, "cuda", "nms_cuda.cu"),
104+
os.path.join(extensions_dir, "cuda", "ROIAlign_cuda.cu"),
105+
os.path.join(extensions_dir, "cuda", "ROIPool_cuda.cu")]
106+
99107
define_macros = []
100108

101109
extra_compile_args = {}
102110
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
103111
extension = CUDAExtension
104112
sources += source_cuda
113+
custom_ops_sources += custom_ops_sources_cuda
105114
define_macros += [('WITH_CUDA', None)]
106115
nvcc_flags = os.getenv('NVCC_FLAGS', '')
107116
if nvcc_flags == '':
@@ -138,7 +147,14 @@ def get_extensions():
138147
include_dirs=tests_include_dirs,
139148
define_macros=define_macros,
140149
extra_compile_args=extra_compile_args,
141-
)
150+
),
151+
extension(
152+
"torchvision._custom_ops",
153+
sources=custom_ops_sources,
154+
include_dirs=include_dirs,
155+
define_macros=define_macros,
156+
extra_compile_args=extra_compile_args,
157+
),
142158
]
143159

144160
return ext_modules
@@ -179,5 +195,6 @@ def run(self):
179195
"scipy": ["scipy"],
180196
},
181197
ext_modules=get_extensions(),
182-
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension, 'clean': clean}
198+
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension,
199+
'clean': clean}
183200
)

test/test_onnx.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import io
2+
import torch
3+
from torchvision import ops
4+
5+
# onnxruntime requires python 3.5 or above
6+
try:
7+
import onnxruntime
8+
except ImportError:
9+
onnxruntime = None
10+
11+
import unittest
12+
13+
14+
@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
class ONNXExporterTester(unittest.TestCase):
    """Export torchvision ops to ONNX and compare ONNX Runtime with PyTorch.

    Each test builds a small module around one op, runs it in PyTorch,
    exports it with opset 10 and checks that ONNX Runtime produces the same
    values within tolerance. The whole class is skipped when ``onnxruntime``
    could not be imported.
    """

    @classmethod
    def setUpClass(cls):
        # Fixed seed so the randomly generated inputs are reproducible.
        torch.manual_seed(123)

    def run_model(self, model, inputs):
        """Run ``model`` in PyTorch, export it to ONNX and validate the export.

        Args:
            model: the ``torch.nn.Module`` under test; it is put in eval mode.
            inputs: a single tensor or a tuple of tensors passed to ``model``.
        """
        model.eval()

        # run pytorch model
        with torch.no_grad():
            if isinstance(inputs, torch.Tensor):
                inputs = (inputs,)
            outputs = model(*inputs)
            if isinstance(outputs, torch.Tensor):
                outputs = (outputs,)

        # export to onnx (kept in memory, nothing is written to disk)
        onnx_io = io.BytesIO()
        torch.onnx.export(model, inputs, onnx_io, do_constant_folding=True, opset_version=10)

        # validate the exported model with onnx runtime
        self.ort_validate(onnx_io, inputs, outputs)

    def ort_validate(self, onnx_io, inputs, outputs):
        """Run the exported model in ONNX Runtime and compare with ``outputs``.

        Args:
            onnx_io: ``io.BytesIO`` holding the serialized ONNX model.
            inputs: the (possibly nested) tensors that were fed to the model.
            outputs: the (possibly nested) tensors PyTorch produced.

        Raises:
            AssertionError: if an ONNX Runtime output differs from the PyTorch
                output beyond ``rtol=1e-03`` / ``atol=1e-05``.
        """
        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            if tensor.requires_grad:
                return tensor.detach().cpu().numpy()
            return tensor.cpu().numpy()

        inputs = [to_numpy(tensor) for tensor in inputs]
        outputs = [to_numpy(tensor) for tensor in outputs]

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # compute onnxruntime output prediction; inputs are matched to the
        # session's declared inputs by position
        session_inputs = ort_session.get_inputs()
        ort_inputs = {session_inputs[i].name: array for i, array in enumerate(inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i, expected in enumerate(outputs):
            torch.testing.assert_allclose(expected, ort_outs[i], rtol=1e-03, atol=1e-05)

    def test_nms(self):
        boxes = torch.rand(5, 4)
        # Grow the last two coordinates so they are never smaller than the first two.
        boxes[:, 2:] += torch.rand(5, 2)
        scores = torch.randn(5)

        class Module(torch.nn.Module):
            def forward(self, boxes, scores):
                return ops.nms(boxes, scores, 0.5)

        self.run_model(Module(), (boxes, scores))

    def test_roi_pool(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        pool_h = 5
        pool_w = 5
        # Exercise RoIPool here; the RoIAlign case lives in test_roi_align.
        model = ops.RoIPool((pool_h, pool_w), 2)
        self.run_model(model, (x, rois))

    def test_roi_align(self):
        x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
        model = ops.RoIAlign((5, 5), 1, 2)
        self.run_model(model, (x, single_roi))
85+
86+
87+
if __name__ == '__main__':
88+
unittest.main()

torchvision/csrc/ROIAlign.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
at::Tensor ROIAlign_forward(
1111
const at::Tensor& input, // Input feature map.
1212
const at::Tensor& rois, // List of ROIs to pool over.
13-
const float spatial_scale, // The scale of the image features. ROIs will be
13+
const double spatial_scale, // The scale of the image features. ROIs will be
1414
// scaled to this.
15-
const int pooled_height, // The height of the pooled feature map.
16-
const int pooled_width, // The width of the pooled feature
17-
const int sampling_ratio) // The number of points to sample in each bin
15+
const int64_t pooled_height, // The height of the pooled feature map.
16+
const int64_t pooled_width, // The width of the pooled feature
17+
const int64_t sampling_ratio) // The number of points to sample in each bin
1818
// along each axis.
1919
{
2020
if (input.type().is_cuda()) {

torchvision/csrc/ROIPool.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
1010
const at::Tensor& input,
1111
const at::Tensor& rois,
12-
const float spatial_scale,
13-
const int pooled_height,
14-
const int pooled_width) {
12+
const double spatial_scale,
13+
const int64_t pooled_height,
14+
const int64_t pooled_width) {
1515
if (input.type().is_cuda()) {
1616
#ifdef WITH_CUDA
1717
return ROIPool_forward_cuda(
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <torch/script.h>
2+
3+
#include "ROIAlign.h"
4+
#include "ROIPool.h"
5+
#include "nms.h"
6+
7+
using namespace at;
8+
9+
// Register the torchvision operators with PyTorch's operator registry.
// `registry` is a static object, so the registration runs when this library
// is loaded. roi_align is registered with an explicit schema string; nms and
// roi_pool are registered with only a name and a function pointer.
// NOTE(review): presumably the schema for nms/roi_pool is then derived from
// the C++ signatures (double / int64_t arguments) -- confirm against the
// torch::RegisterOperators documentation.
static auto registry =
    torch::RegisterOperators()
        .op("torchvision::nms", &nms)
        .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
            &ROIAlign_forward)
        .op("torchvision::roi_pool", &ROIPool_forward);

torchvision/csrc/nms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
at::Tensor nms(
99
const at::Tensor& dets,
1010
const at::Tensor& scores,
11-
const float iou_threshold) {
11+
const double iou_threshold) {
1212
if (dets.device().is_cuda()) {
1313
#ifdef WITH_CUDA
1414
if (dets.numel() == 0) {

torchvision/csrc/vision.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#endif
88

99
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10+
// TODO: remove nms from here since it is now registered
11+
// and used as a PyTorch custom op
1012
m.def("nms", &nms, "non-maximum suppression");
1113
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
1214
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");

torchvision/ops/_custom_ops.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import os
2+
import sys
3+
import imp
4+
import torch
5+
6+
7+
# Load the compiled `_custom_ops` extension library from the package
# directory one level above this file, so that its operators are registered
# with PyTorch and reachable through `torch.ops`.
lib_dir = os.path.join(os.path.dirname(__file__), '..')
file, path, description = imp.find_module("_custom_ops", [lib_dir])
# `imp.find_module` may hand back an open file object for the module it
# located. Only `path` is needed here, so release the handle instead of
# leaking it for the lifetime of the process.
if file is not None:
    file.close()
torch.ops.load_library(path)
11+
12+
13+
def register_custom_op():
    """Register ONNX symbolic functions for the torchvision custom ops.

    Maps ``torchvision::nms``, ``torchvision::roi_align`` and
    ``torchvision::roi_pool`` to the ONNX ``NonMaxSuppression``, ``RoiAlign``
    and ``MaxRoiPool`` operators respectively, all registered for opset 10.
    The ``torch.onnx`` helpers are imported inside the function so they are
    only resolved when registration actually runs.
    """
    from torch.onnx import register_custom_op_symbolic
    from torch.onnx.symbolic_helper import parse_args
    from torch.onnx.symbolic_opset9 import select, unsqueeze, squeeze, _cast_Long

    @parse_args('v', 'v', 'f')
    def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
        # Prepend one leading dimension to boxes and two to scores before
        # handing them to the ONNX op.
        boxes = unsqueeze(g, boxes, 0)
        scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
        # sys.maxsize as the per-class output limit: effectively no cap.
        max_output_per_class = g.op('Constant', value_t=torch.tensor([sys.maxsize], dtype=torch.long))
        iou_threshold = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float))
        nms_out = g.op('NonMaxSuppression', boxes, scores, max_output_per_class, iou_threshold)
        # Keep only column 2 of the result (the selected box index per the
        # ONNX NonMaxSuppression output layout) and drop the extra dimension.
        box_index_column = g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))
        return squeeze(g, select(g, nms_out, 1, box_index_column), 1)

    @parse_args('v', 'v', 'f', 'i', 'i', 'i')
    def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
        # Column 0 of `rois` is split off as the (int64) batch index input;
        # columns 1-4 are passed on as the boxes themselves.
        batch_column = g.op('Constant', value_t=torch.tensor([0], dtype=torch.long))
        batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, batch_column), 1), False)
        box_columns = g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))
        rois = select(g, rois, 1, box_columns)
        return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale,
                    output_height_i=pooled_height, output_width_i=pooled_width,
                    sampling_ratio_i=sampling_ratio)

    @parse_args('v', 'v', 'f', 'i', 'i')
    def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
        pooled = g.op('MaxRoiPool', input, rois,
                      pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale)
        # Two values are returned; only the first is produced by an ONNX
        # node, the second is left as None.
        return pooled, None

    register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, 10)
    register_custom_op_symbolic('torchvision::roi_align', roi_align, 10)
    register_custom_op_symbolic('torchvision::roi_pool', roi_pool, 10)
44+
45+
46+
register_custom_op()

torchvision/ops/boxes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from torchvision.extension import _lazy_import
2+
import torchvision.ops._custom_ops
33

44

55
def nms(boxes, scores, iou_threshold):
@@ -29,8 +29,7 @@ def nms(boxes, scores, iou_threshold):
2929
of the elements that have been kept
3030
by NMS, sorted in decreasing order of scores
3131
"""
32-
_C = _lazy_import()
33-
return _C.nms(boxes, scores, iou_threshold)
32+
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
3433

3534

3635
def batched_nms(boxes, scores, idxs, iou_threshold):

0 commit comments

Comments
 (0)