Skip to content

Commit a810d89

Browse files
committed
Register ops to AutocastCPU
* modify the directory structure: moved the autocast files from torchvision/csrc/ops/autocast/ to torchvision/csrc/ops/autocast/cuda * add the cpu directory under the autocast directory; * register deform_conv2d, nms, ps_roi_align, ps_roi_pool, roi_align, and roi_pool to AutocastCPU.
1 parent 719e120 commit a810d89

14 files changed

+284
-7
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def get_extensions():
142142
'*.cpp'))
143143
source_cpu = (
144144
glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) +
145+
glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cpu', '*.cpp')) +
145146
glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) +
146147
glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp'))
147148
)
@@ -170,7 +171,7 @@ def get_extensions():
170171
else:
171172
source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu'))
172173

173-
source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp'))
174+
source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cuda', '*.cpp'))
174175

175176
sources = main_file + source_cpu
176177
extension = CppExtension

test/test_ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def test_autocast(self, x_dtype, rois_dtype):
7171
with torch.cuda.amp.autocast():
7272
self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
7373

74+
@pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
75+
@pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
76+
def test_autocast_cpu(self, x_dtype, rois_dtype):
77+
with torch.cpu.amp.autocast():
78+
self.test_forward(torch.device("cpu"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
79+
7480
def _helper_boxes_shape(self, func):
7581
# test boxes as Tensor[N, 5]
7682
with pytest.raises(AssertionError):
@@ -278,6 +284,14 @@ def test_autocast(self, aligned, x_dtype, rois_dtype):
278284
self.test_forward(torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype,
279285
rois_dtype=rois_dtype)
280286

287+
@pytest.mark.parametrize('aligned', (True, False))
288+
@pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
289+
@pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
290+
def test_autocast_cpu(self, aligned, x_dtype, rois_dtype):
291+
with torch.cpu.amp.autocast():
292+
self.test_forward(torch.device("cpu"), contiguous=False, aligned=aligned, x_dtype=x_dtype,
293+
rois_dtype=rois_dtype)
294+
281295
def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
282296
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
283297
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index
@@ -514,13 +528,27 @@ def test_nms_cuda(self, iou, dtype=torch.float64):
514528
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
515529
assert is_eq, err_msg.format(iou)
516530

531+
517532
@needs_cuda
518533
@pytest.mark.parametrize("iou", (.2, .5, .8))
519534
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
520535
def test_autocast(self, iou, dtype):
521536
with torch.cuda.amp.autocast():
522537
self.test_nms_cuda(iou=iou, dtype=dtype)
523538

539+
@pytest.mark.parametrize("iou", (.2, .5, .8))
540+
@pytest.mark.parametrize("dtype", (torch.bfloat16,))
541+
def test_autocast_cpu(self, iou, dtype):
542+
with torch.cpu.amp.autocast():
543+
def test_nms_cpu(iou, dtype):
544+
boxes, scores = self._create_tensors_with_iou(1000, iou)
545+
boxes = boxes.to(dtype=dtype)
546+
scores = scores.to(dtype=dtype)
547+
out = ops.nms(boxes, scores, iou)
548+
outf = ops.nms(boxes.float(), scores.float(), iou)
549+
torch.testing.assert_close(out, outf)
550+
test_nms_cpu(iou=iou, dtype=dtype)
551+
524552
@needs_cuda
525553
def test_nms_cuda_float16(self):
526554
boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551],
@@ -774,6 +802,12 @@ def test_autocast(self, batch_sz, dtype):
774802
with torch.cuda.amp.autocast():
775803
self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
776804

805+
@pytest.mark.parametrize('batch_sz', (0, 33))
806+
@pytest.mark.parametrize('dtype', (torch.float, torch.half))
807+
def test_autocast_cpu(self, batch_sz, dtype):
808+
with torch.cpu.amp.autocast():
809+
self.test_forward(torch.device("cpu"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
810+
777811
def test_forward_scriptability(self):
778812
# Non-regression test for https://github.com/pytorch/vision/issues/4078
779813
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include "../../deform_conv2d.h"
2+
3+
#include <ATen/autocast_mode.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
at::Tensor deform_conv2d_autocast(
12+
const at::Tensor& input,
13+
const at::Tensor& weight,
14+
const at::Tensor& offset,
15+
const at::Tensor& mask,
16+
const at::Tensor& bias,
17+
int64_t stride_h,
18+
int64_t stride_w,
19+
int64_t pad_h,
20+
int64_t pad_w,
21+
int64_t dilation_h,
22+
int64_t dilation_w,
23+
int64_t groups,
24+
int64_t offset_groups,
25+
bool use_mask) {
26+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
27+
return deform_conv2d(
28+
at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
29+
at::autocast::cached_cast(at::kFloat, weight, c10::DeviceType::CPU),
30+
at::autocast::cached_cast(at::kFloat, offset, c10::DeviceType::CPU),
31+
at::autocast::cached_cast(at::kFloat, mask, c10::DeviceType::CPU),
32+
at::autocast::cached_cast(at::kFloat, bias, c10::DeviceType::CPU),
33+
stride_h,
34+
stride_w,
35+
pad_h,
36+
pad_w,
37+
dilation_h,
38+
dilation_w,
39+
groups,
40+
offset_groups,
41+
use_mask)
42+
.to(input.scalar_type());
43+
}
44+
45+
} // namespace
46+
47+
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
48+
m.impl(
49+
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
50+
TORCH_FN(deform_conv2d_autocast));
51+
}
52+
53+
} // namespace ops
54+
} // namespace vision
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include "../../nms.h"
2+
3+
#include <ATen/autocast_mode.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
at::Tensor nms_autocast(
12+
const at::Tensor& dets,
13+
const at::Tensor& scores,
14+
double iou_threshold) {
15+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
16+
return nms(
17+
at::autocast::cached_cast(at::kFloat, dets, c10::DeviceType::CPU),
18+
at::autocast::cached_cast(at::kFloat, scores, c10::DeviceType::CPU),
19+
iou_threshold);
20+
}
21+
22+
} // namespace
23+
24+
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
25+
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
26+
}
27+
28+
} // namespace ops
29+
} // namespace vision
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include "../../ps_roi_align.h"
2+
3+
#include <ATen/autocast_mode.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
12+
const at::Tensor& input,
13+
const at::Tensor& rois,
14+
double spatial_scale,
15+
int64_t pooled_height,
16+
int64_t pooled_width,
17+
int64_t sampling_ratio) {
18+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
19+
auto result = ps_roi_align(
20+
at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
21+
at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
22+
spatial_scale,
23+
pooled_height,
24+
pooled_width,
25+
sampling_ratio);
26+
27+
return std::make_tuple(
28+
std::get<0>(result).to(input.scalar_type()),
29+
std::get<1>(result).to(input.scalar_type()));
30+
}
31+
32+
} // namespace
33+
34+
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
35+
m.impl(
36+
TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
37+
TORCH_FN(ps_roi_align_autocast));
38+
}
39+
40+
} // namespace ops
41+
} // namespace vision
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "../../ps_roi_pool.h"
2+
3+
#include <ATen/autocast_mode.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
12+
const at::Tensor& input,
13+
const at::Tensor& rois,
14+
double spatial_scale,
15+
int64_t pooled_height,
16+
int64_t pooled_width) {
17+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
18+
auto result = ps_roi_pool(
19+
at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
20+
at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
21+
spatial_scale,
22+
pooled_height,
23+
pooled_width);
24+
25+
return std::make_tuple(
26+
std::get<0>(result).to(input.scalar_type()),
27+
std::get<1>(result).to(input.scalar_type()));
28+
}
29+
30+
} // namespace
31+
32+
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
33+
m.impl(
34+
TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
35+
TORCH_FN(ps_roi_pool_autocast));
36+
}
37+
38+
} // namespace ops
39+
} // namespace vision
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#include "../../roi_align.h"
2+
3+
#include <ATen/autocast_mode.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
at::Tensor roi_align_autocast(
12+
const at::Tensor& input,
13+
const at::Tensor& rois,
14+
double spatial_scale,
15+
int64_t pooled_height,
16+
int64_t pooled_width,
17+
int64_t sampling_ratio,
18+
bool aligned) {
19+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
20+
return roi_align(
21+
at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
22+
at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
23+
spatial_scale,
24+
pooled_height,
25+
pooled_width,
26+
sampling_ratio,
27+
aligned)
28+
.to(input.scalar_type());
29+
}
30+
31+
} // namespace
32+
33+
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
34+
m.impl(
35+
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
36+
TORCH_FN(roi_align_autocast));
37+
}
38+
39+
} // namespace ops
40+
} // namespace vision
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "../../roi_pool.h"
2+
3+
#include <ATen/autocast_mode.h>
4+
#include <torch/types.h>
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
12+
const at::Tensor& input,
13+
const at::Tensor& rois,
14+
double spatial_scale,
15+
int64_t pooled_height,
16+
int64_t pooled_width) {
17+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
18+
auto result = roi_pool(
19+
at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
20+
at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
21+
spatial_scale,
22+
pooled_height,
23+
pooled_width);
24+
25+
return std::make_tuple(
26+
std::get<0>(result).to(input.scalar_type()),
27+
std::get<1>(result).to(input.scalar_type()));
28+
}
29+
30+
} // namespace
31+
32+
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
33+
m.impl(
34+
TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
35+
TORCH_FN(roi_pool_autocast));
36+
}
37+
38+
} // namespace ops
39+
} // namespace vision

torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp renamed to torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "../deform_conv2d.h"
1+
#include "../../deform_conv2d.h"
22

33
#include <ATen/autocast_mode.h>
44
#include <torch/types.h>

torchvision/csrc/ops/autocast/nms_kernel.cpp renamed to torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "../nms.h"
1+
#include "../../nms.h"
22

33
#include <ATen/autocast_mode.h>
44
#include <torch/types.h>

0 commit comments

Comments
 (0)