diff --git a/setup.py b/setup.py index a83e555acda..d8333089170 100644 --- a/setup.py +++ b/setup.py @@ -134,6 +134,7 @@ def get_extensions(): ) source_cpu = ( glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) @@ -163,7 +164,7 @@ def get_extensions(): else: source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu")) - source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp")) + source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "cuda", "*.cpp")) sources = main_file + source_cpu extension = CppExtension diff --git a/test/test_ops.py b/test/test_ops.py index 64329936b72..83f8773d1d9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,7 +1,7 @@ import math import os from abc import ABC, abstractmethod -from functools import lru_cache +from functools import lru_cache, partial from typing import Tuple import numpy as np @@ -65,12 +65,13 @@ def func(z): gradcheck(func, (x,)) gradcheck(script_func, (x,)) - @needs_cuda + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) - def test_autocast(self, x_dtype, rois_dtype): - with torch.cuda.amp.autocast(): - self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) + def test_autocast(self, device, x_dtype, rois_dtype): + cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast + with cm(): + self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) def _helper_boxes_shape(self, func): # test boxes as Tensor[N, 5] @@ -284,14 +285,15 @@ def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=Non device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned ) - @needs_cuda + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) - def test_autocast(self, aligned, x_dtype, rois_dtype): - with torch.cuda.amp.autocast(): + def test_autocast(self, device, aligned, x_dtype, rois_dtype): + cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast + with cm(): self.test_forward( - torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype + torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype ) def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): @@ -532,12 +534,14 @@ def test_nms_cuda(self, iou, dtype=torch.float64): is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol) assert is_eq, err_msg.format(iou) - @needs_cuda + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) - def test_autocast(self, iou, dtype): - with torch.cuda.amp.autocast(): - self.test_nms_cuda(iou=iou, dtype=dtype) + def test_autocast(self, device, iou, dtype): + test_fn = self.test_nms_ref if device == "cpu" else partial(self.test_nms_cuda, dtype=dtype) + cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast + with cm(): + test_fn(iou=iou) @needs_cuda def test_nms_cuda_float16(self): @@ -823,12 +827,13 @@ def test_compare_cpu_cuda_grads(self, contiguous): res_grads = init_weight.grad.to("cpu") torch.testing.assert_close(true_cpu_grads, res_grads) - @needs_cuda + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) - def test_autocast(self, batch_sz, dtype): - with torch.cuda.amp.autocast(): - self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) + def test_autocast(self, device, batch_sz, dtype): + cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast + with cm(): + self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype) def test_forward_scriptability(self): # Non-regression test for https://github.com/pytorch/vision/issues/4078 diff --git a/torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp new file mode 100644 index 00000000000..e1aa77a4b1e --- /dev/null +++ b/torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp @@ -0,0 +1,56 @@ +#include "../../deform_conv2d.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +at::Tensor deform_conv2d_autocast( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t groups, + int64_t offset_groups, + bool use_mask) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU); + return deform_conv2d( + at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU), + at::autocast::cached_cast( + at::kFloat, weight, c10::DeviceType::CPU), + at::autocast::cached_cast( + at::kFloat, offset, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, mask, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, bias, c10::DeviceType::CPU), + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask) + .to(input.scalar_type()); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_autocast)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp new file mode 100644 index 00000000000..3f31c9bd2cc --- /dev/null +++ b/torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp @@ -0,0 +1,29 @@ +#include "../../nms.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +at::Tensor nms_autocast( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU); + return nms( + at::autocast::cached_cast(at::kFloat, dets, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, scores, c10::DeviceType::CPU), + iou_threshold); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp new file mode 100644 index 00000000000..565e14c49c8 --- /dev/null +++ b/torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp @@ -0,0 +1,41 @@ +#include "../../ps_roi_align.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_align_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU); + auto result = ps_roi_align( + at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU), + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), + TORCH_FN(ps_roi_align_autocast)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp new file mode 100644 index 00000000000..45f598f71f3 --- /dev/null +++ b/torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp @@ -0,0 +1,39 @@ +#include "../../ps_roi_pool.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +std::tuple ps_roi_pool_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU); + auto result = ps_roi_pool( + at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU), + spatial_scale, + pooled_height, + pooled_width); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), + TORCH_FN(ps_roi_pool_autocast)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp new file mode 100644 index 00000000000..5c320a68538 --- /dev/null +++ b/torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp @@ -0,0 +1,40 @@ +#include "../../roi_align.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +at::Tensor roi_align_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU); + return roi_align( + at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU), + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + aligned) + .to(input.scalar_type()); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_align"), + TORCH_FN(roi_align_autocast)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp new file mode 100644 index 00000000000..e2c75b3efd4 --- /dev/null +++ b/torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp @@ -0,0 +1,39 @@ +#include "../../roi_pool.h" + +#include +#include + +namespace vision { +namespace ops { + +namespace { + +std::tuple roi_pool_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU); + auto result = roi_pool( + at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU), + at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU), + spatial_scale, + pooled_height, + pooled_width); + + return std::make_tuple( + std::get<0>(result).to(input.scalar_type()), + std::get<1>(result).to(input.scalar_type())); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::roi_pool"), + TORCH_FN(roi_pool_autocast)); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp similarity index 97% rename from torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp rename to torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp index 28c325be9b1..5cc19b41948 100644 --- a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp +++ b/torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp @@ -1,4 +1,4 @@ -#include "../deform_conv2d.h" +#include "../../deform_conv2d.h" #include #include diff --git a/torchvision/csrc/ops/autocast/nms_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp similarity index 96% rename from torchvision/csrc/ops/autocast/nms_kernel.cpp rename to torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp index 3a0ead004fd..269fd2e9a2a 100644 --- a/torchvision/csrc/ops/autocast/nms_kernel.cpp +++ b/torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp @@ -1,4 +1,4 @@ -#include "../nms.h" +#include "../../nms.h" #include #include diff --git a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/ps_roi_align_kernel.cpp similarity index 96% rename from torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp rename to torchvision/csrc/ops/autocast/cuda/ps_roi_align_kernel.cpp index c93b26c8ad3..d8b553dc0e9 100644 --- a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp +++ b/torchvision/csrc/ops/autocast/cuda/ps_roi_align_kernel.cpp @@ -1,4 +1,4 @@ -#include "../ps_roi_align.h" +#include "../../ps_roi_align.h" #include #include diff --git a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/ps_roi_pool_kernel.cpp similarity index 96% rename from torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp rename to torchvision/csrc/ops/autocast/cuda/ps_roi_pool_kernel.cpp index 1421680ea98..b2816e6eb7c 100644 --- a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/ops/autocast/cuda/ps_roi_pool_kernel.cpp @@ -1,4 +1,4 @@ -#include "../ps_roi_pool.h" +#include "../../ps_roi_pool.h" #include #include diff --git a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/roi_align_kernel.cpp similarity index 96% rename from torchvision/csrc/ops/autocast/roi_align_kernel.cpp rename to torchvision/csrc/ops/autocast/cuda/roi_align_kernel.cpp index 95457224ac0..129d3d8e76b 100644 --- a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp +++ b/torchvision/csrc/ops/autocast/cuda/roi_align_kernel.cpp @@ -1,4 +1,4 @@ -#include "../roi_align.h" +#include "../../roi_align.h" #include #include diff --git a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/roi_pool_kernel.cpp similarity index 97% rename from torchvision/csrc/ops/autocast/roi_pool_kernel.cpp rename to torchvision/csrc/ops/autocast/cuda/roi_pool_kernel.cpp index d317c38c792..6c4863fe542 100644 --- a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp +++ b/torchvision/csrc/ops/autocast/cuda/roi_pool_kernel.cpp @@ -1,4 +1,4 @@ -#include "../roi_pool.h" +#include "../../roi_pool.h" #include #include