From 803e90efba5636373106301e24e5337be4320cf7 Mon Sep 17 00:00:00 2001
From: CaoE
Date: Tue, 14 Sep 2021 10:00:05 +0800
Subject: [PATCH 1/2] Register ops to AutocastCPU

* Modify the directory structure: move the existing autocast files from
  torchvision/csrc/ops/autocast/ to torchvision/csrc/ops/autocast/cuda.
* Add a cpu directory under the autocast directory.
* Register deform_conv2d, nms, ps_roi_align, ps_roi_pool, roi_align, and
  roi_pool to AutocastCPU.
---
 setup.py                                      |  9 +--
 test/test_ops.py                              | 58 ++++++++++---------
 .../ops/autocast/cpu/deform_conv2d_kernel.cpp | 56 ++++++++++++++++++
 .../csrc/ops/autocast/cpu/nms_kernel.cpp      | 29 ++++++++++
 .../ops/autocast/cpu/ps_roi_align_kernel.cpp  | 41 +++++++++++++
 .../ops/autocast/cpu/ps_roi_pool_kernel.cpp   | 39 +++++++++++++
 .../ops/autocast/cpu/roi_align_kernel.cpp     | 40 +++++++++++++
 .../csrc/ops/autocast/cpu/roi_pool_kernel.cpp | 39 +++++++++++++
 .../{ => cuda}/deform_conv2d_kernel.cpp       |  2 +-
 .../ops/autocast/{ => cuda}/nms_kernel.cpp    |  2 +-
 .../{ => cuda}/ps_roi_align_kernel.cpp        |  2 +-
 .../{ => cuda}/ps_roi_pool_kernel.cpp         |  2 +-
 .../autocast/{ => cuda}/roi_align_kernel.cpp  |  2 +-
 .../autocast/{ => cuda}/roi_pool_kernel.cpp   |  2 +-
 14 files changed, 286 insertions(+), 37 deletions(-)
 create mode 100644 torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp
 create mode 100644 torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp
 rename torchvision/csrc/ops/autocast/{ => cuda}/deform_conv2d_kernel.cpp (97%)
 rename torchvision/csrc/ops/autocast/{ => cuda}/nms_kernel.cpp (96%)
 rename torchvision/csrc/ops/autocast/{ => cuda}/ps_roi_align_kernel.cpp (96%)
 rename torchvision/csrc/ops/autocast/{ => cuda}/ps_roi_pool_kernel.cpp (96%)
 rename torchvision/csrc/ops/autocast/{ => cuda}/roi_align_kernel.cpp (96%)
 rename torchvision/csrc/ops/autocast/{ => cuda}/roi_pool_kernel.cpp (97%)

diff --git a/setup.py b/setup.py
index a83e555acda..9ed85499829 100644
--- a/setup.py
+++ b/setup.py
@@ -133,9 +133,10 @@ def get_extensions():
         os.path.join(extensions_dir, "ops", "*.cpp")
     )
     source_cpu = (
-        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
-        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
-        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
+        glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) +
+        glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cpu', '*.cpp')) +
+        glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) +
+        glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp'))
     )
 
     is_rocm_pytorch = False
@@ -163,7 +164,7 @@ def get_extensions():
     else:
         source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))
 
-    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
+    source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cuda', '*.cpp'))
 
     sources = main_file + source_cpu
     extension = CppExtension
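
Note: the effect of these AutocastCPU registrations is that the torchvision C++ ops
take part in CPU autocast the same way they already do in CUDA autocast: inside an
autocast region the wrapper upcasts floating-point inputs to float32, runs the regular
kernel, and, for ops that return floating-point tensors, casts the result back to the
input dtype. A minimal usage sketch, assuming a torchvision build that includes this
patch (illustrative only, not part of the diff):

    import torch
    import torchvision.ops as ops

    boxes = torch.tensor(
        [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]], dtype=torch.bfloat16
    )
    scores = torch.tensor([0.9, 0.8], dtype=torch.bfloat16)

    with torch.cpu.amp.autocast():
        # The AutocastCPU wrapper upcasts boxes/scores to float32 before the CPU
        # nms kernel runs, so reduced-precision inputs should be accepted here.
        keep = ops.nms(boxes, scores, iou_threshold=0.5)
    print(keep)  # indices of the kept boxes (int64, so no cast back is needed)
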
diff --git a/test/test_ops.py b/test/test_ops.py
index 64329936b72..5116837e732 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1,7 +1,7 @@
 import math
 import os
 from abc import ABC, abstractmethod
-from functools import lru_cache
+from functools import lru_cache, partial
 from typing import Tuple
 
 import numpy as np
@@ -65,12 +65,13 @@ def func(z):
         gradcheck(func, (x,))
         gradcheck(script_func, (x,))
 
-    @needs_cuda
-    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
-    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
-    def test_autocast(self, x_dtype, rois_dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
+    @pytest.mark.parametrize('device', cpu_and_gpu())
+    @pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
+    def test_autocast(self, device, x_dtype, rois_dtype):
+        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        with cm():
+            self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
 
     def _helper_boxes_shape(self, func):
         # test boxes as Tensor[N, 5]
@@ -284,15 +285,15 @@ def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=Non
             device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
         )
 
-    @needs_cuda
-    @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
-    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
-    def test_autocast(self, aligned, x_dtype, rois_dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(
-                torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
-            )
+    @pytest.mark.parametrize('device', cpu_and_gpu())
+    @pytest.mark.parametrize('aligned', (True, False))
+    @pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
+    def test_autocast(self, device, aligned, x_dtype, rois_dtype):
+        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        with cm():
+            self.test_forward(torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype,
+                              rois_dtype=rois_dtype)
 
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
         rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
@@ -532,12 +533,14 @@ def test_nms_cuda(self, iou, dtype=torch.float64):
         is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)
 
-    @needs_cuda
-    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
+    @pytest.mark.parametrize('device', cpu_and_gpu())
+    @pytest.mark.parametrize("iou", (.2, .5, .8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
-    def test_autocast(self, iou, dtype):
-        with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+    def test_autocast(self, device, iou, dtype):
+        test_fn = self.test_nms_ref if device == 'cpu' else partial(self.test_nms_cuda, dtype=dtype)
+        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        with cm():
+            test_fn(iou=iou)
 
     @needs_cuda
     def test_nms_cuda_float16(self):
@@ -823,12 +826,13 @@ def test_compare_cpu_cuda_grads(self, contiguous):
         res_grads = init_weight.grad.to("cpu")
         torch.testing.assert_close(true_cpu_grads, res_grads)
 
-    @needs_cuda
-    @pytest.mark.parametrize("batch_sz", (0, 33))
-    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
-    def test_autocast(self, batch_sz, dtype):
-        with torch.cuda.amp.autocast():
-            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
+    @pytest.mark.parametrize('device', cpu_and_gpu())
+    @pytest.mark.parametrize('batch_sz', (0, 33))
+    @pytest.mark.parametrize('dtype', (torch.float, torch.half))
+    def test_autocast(self, device, batch_sz, dtype):
+        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        with cm():
+            self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype)
 
     def test_forward_scriptability(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/4078
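
Note: every updated test follows the same pattern: parametrize over cpu_and_gpu() and
pick the matching autocast context manager by hand. A hypothetical helper (not part of
this patch) that factors out that repeated selection could look like this:

    import torch

    def autocast_for(device):
        # The two context managers the tests in this patch switch between.
        return torch.cpu.amp.autocast() if device == "cpu" else torch.cuda.amp.autocast()

    # Usage inside a test body:
    #     with autocast_for(device):
    #         self.test_forward(torch.device(device), contiguous=False, ...)
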
diff --git a/torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp
new file mode 100644
index 00000000000..e1aa77a4b1e
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/cpu/deform_conv2d_kernel.cpp
@@ -0,0 +1,56 @@
+#include "../../deform_conv2d.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+at::Tensor deform_conv2d_autocast(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t offset_groups,
+    bool use_mask) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
+  return deform_conv2d(
+             at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
+             at::autocast::cached_cast(
+                 at::kFloat, weight, c10::DeviceType::CPU),
+             at::autocast::cached_cast(
+                 at::kFloat, offset, c10::DeviceType::CPU),
+             at::autocast::cached_cast(at::kFloat, mask, c10::DeviceType::CPU),
+             at::autocast::cached_cast(at::kFloat, bias, c10::DeviceType::CPU),
+             stride_h,
+             stride_w,
+             pad_h,
+             pad_w,
+             dilation_h,
+             dilation_w,
+             groups,
+             offset_groups,
+             use_mask)
+      .to(input.scalar_type());
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
+      TORCH_FN(deform_conv2d_autocast));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp
new file mode 100644
index 00000000000..3f31c9bd2cc
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/cpu/nms_kernel.cpp
@@ -0,0 +1,29 @@
+#include "../../nms.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+at::Tensor nms_autocast(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
+  return nms(
+      at::autocast::cached_cast(at::kFloat, dets, c10::DeviceType::CPU),
+      at::autocast::cached_cast(at::kFloat, scores, c10::DeviceType::CPU),
+      iou_threshold);
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
+}
+
+} // namespace ops
+} // namespace vision
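
Note: all six CPU wrappers share one structure: exclude AutocastCPU from the dispatch
key set so the nested call is not intercepted again, upcast every floating-point tensor
argument to float32 via at::autocast::cached_cast, call the regular op, and cast
floating-point results back to the input dtype. For intuition only, a rough Python
analogue of that pattern (this is not how the C++ dispatcher is implemented, and it
assumes a recent PyTorch with torch.autocast):

    import torch

    def autocast_cpu_like(op):
        # Rough analogue of the C++ wrappers above; handles single-tensor results.
        # (The tuple-returning ops cast each element back the same way.)
        def wrapped(*tensors, **kwargs):
            casted = [t.float() if t.is_floating_point() else t for t in tensors]
            # Roughly what ExcludeDispatchKeyGuard achieves: keep the inner call
            # from going through the autocast wrapper again.
            with torch.autocast(device_type="cpu", enabled=False):
                out = op(*casted, **kwargs)
            return out.to(tensors[0].dtype) if out.is_floating_point() else out
        return wrapped
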
diff --git a/torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp
new file mode 100644
index 00000000000..565e14c49c8
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/cpu/ps_roi_align_kernel.cpp
@@ -0,0 +1,41 @@
+#include "../../ps_roi_align.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
+  auto result = ps_roi_align(
+      at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
+      at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio);
+
+  return std::make_tuple(
+      std::get<0>(result).to(input.scalar_type()),
+      std::get<1>(result).to(input.scalar_type()));
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
+      TORCH_FN(ps_roi_align_autocast));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp
new file mode 100644
index 00000000000..45f598f71f3
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/cpu/ps_roi_pool_kernel.cpp
@@ -0,0 +1,39 @@
+#include "../../ps_roi_pool.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
+  auto result = ps_roi_pool(
+      at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
+      at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
+      spatial_scale,
+      pooled_height,
+      pooled_width);
+
+  return std::make_tuple(
+      std::get<0>(result).to(input.scalar_type()),
+      std::get<1>(result).to(input.scalar_type()));
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"),
+      TORCH_FN(ps_roi_pool_autocast));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp
new file mode 100644
index 00000000000..5c320a68538
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/cpu/roi_align_kernel.cpp
@@ -0,0 +1,40 @@
+#include "../../roi_align.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+at::Tensor roi_align_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    bool aligned) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
+  return roi_align(
+             at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
+             at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
+             spatial_scale,
+             pooled_height,
+             pooled_width,
+             sampling_ratio,
+             aligned)
+      .to(input.scalar_type());
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
+      TORCH_FN(roi_align_autocast));
+}
+
+} // namespace ops
+} // namespace vision
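
Note: the pooling ops above (and roi_pool below) return a pair of tensors at the C++
level, and the wrappers cast both elements back to the input's dtype, mirroring the
existing CUDA autocast kernels; the Python-level torchvision.ops wrappers expose only
the pooled output. A small dtype round-trip sketch under CPU autocast (illustrative
only; assumes a build with these kernels):

    import torch
    import torchvision.ops as ops

    x = torch.randn(1, 4, 8, 8, dtype=torch.half)  # channels divisible by 2 * 2
    rois = torch.tensor([[0.0, 0.0, 0.0, 4.0, 4.0]], dtype=torch.half)

    with torch.cpu.amp.autocast():
        out = ops.ps_roi_align(x, rois, output_size=2)

    print(out.dtype)  # expected: torch.float16, i.e. cast back to the input dtype
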
diff --git a/torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp
new file mode 100644
index 00000000000..e2c75b3efd4
--- /dev/null
+++ b/torchvision/csrc/ops/autocast/cpu/roi_pool_kernel.cpp
@@ -0,0 +1,39 @@
+#include "../../roi_pool.h"
+
+#include <ATen/autocast_mode.h>
+#include <torch/types.h>
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+std::tuple<at::Tensor, at::Tensor> roi_pool_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastCPU);
+  auto result = roi_pool(
+      at::autocast::cached_cast(at::kFloat, input, c10::DeviceType::CPU),
+      at::autocast::cached_cast(at::kFloat, rois, c10::DeviceType::CPU),
+      spatial_scale,
+      pooled_height,
+      pooled_width);
+
+  return std::make_tuple(
+      std::get<0>(result).to(input.scalar_type()),
+      std::get<1>(result).to(input.scalar_type()));
+}
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
+      TORCH_FN(roi_pool_autocast));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp
similarity index 97%
rename from torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
rename to torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp
index 28c325be9b1..5cc19b41948 100644
--- a/torchvision/csrc/ops/autocast/deform_conv2d_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/cuda/deform_conv2d_kernel.cpp
@@ -1,4 +1,4 @@
-#include "../deform_conv2d.h"
+#include "../../deform_conv2d.h"
 
 #include <ATen/autocast_mode.h>
 #include <torch/types.h>
diff --git a/torchvision/csrc/ops/autocast/nms_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp
similarity index 96%
rename from torchvision/csrc/ops/autocast/nms_kernel.cpp
rename to torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp
index 3a0ead004fd..269fd2e9a2a 100644
--- a/torchvision/csrc/ops/autocast/nms_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/cuda/nms_kernel.cpp
@@ -1,4 +1,4 @@
-#include "../nms.h"
+#include "../../nms.h"
 
 #include <ATen/autocast_mode.h>
 #include <torch/types.h>
diff --git a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/ps_roi_align_kernel.cpp
similarity index 96%
rename from torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
rename to torchvision/csrc/ops/autocast/cuda/ps_roi_align_kernel.cpp
index c93b26c8ad3..d8b553dc0e9 100644
--- a/torchvision/csrc/ops/autocast/ps_roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/cuda/ps_roi_align_kernel.cpp
@@ -1,4 +1,4 @@
-#include "../ps_roi_align.h"
+#include "../../ps_roi_align.h"
 
 #include <ATen/autocast_mode.h>
 #include <torch/types.h>
diff --git a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/ps_roi_pool_kernel.cpp
similarity index 96%
rename from torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
rename to torchvision/csrc/ops/autocast/cuda/ps_roi_pool_kernel.cpp
index 1421680ea98..b2816e6eb7c 100644
--- a/torchvision/csrc/ops/autocast/ps_roi_pool_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/cuda/ps_roi_pool_kernel.cpp
@@ -1,4 +1,4 @@
-#include "../ps_roi_pool.h"
+#include "../../ps_roi_pool.h"
 
 #include <ATen/autocast_mode.h>
 #include <torch/types.h>
diff --git a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/roi_align_kernel.cpp
similarity index 96%
rename from torchvision/csrc/ops/autocast/roi_align_kernel.cpp
rename to torchvision/csrc/ops/autocast/cuda/roi_align_kernel.cpp
index 95457224ac0..129d3d8e76b 100644
--- a/torchvision/csrc/ops/autocast/roi_align_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/cuda/roi_align_kernel.cpp
@@ -1,4 +1,4 @@
-#include "../roi_align.h"
+#include "../../roi_align.h"
 
 #include <ATen/autocast_mode.h>
 #include <torch/types.h>
diff --git a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp b/torchvision/csrc/ops/autocast/cuda/roi_pool_kernel.cpp
similarity index 97%
rename from torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
rename to torchvision/csrc/ops/autocast/cuda/roi_pool_kernel.cpp
index d317c38c792..6c4863fe542 100644
--- a/torchvision/csrc/ops/autocast/roi_pool_kernel.cpp
+++ b/torchvision/csrc/ops/autocast/cuda/roi_pool_kernel.cpp
@@ -1,4 +1,4 @@
-#include "../roi_pool.h"
+#include "../../roi_pool.h"
 
 #include <ATen/autocast_mode.h>
 #include <torch/types.h>
From 8ca7e75ad075ae987ceea55acf61506cc7178ebb Mon Sep 17 00:00:00 2001
From: CaoE
Date: Tue, 12 Oct 2021 13:05:18 +0800
Subject: [PATCH 2/2] fix code format

---
 setup.py         | 10 +++++-----
 test/test_ops.py | 39 ++++++++++++++++++++-------------------
 2 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/setup.py b/setup.py
index 9ed85499829..d8333089170 100644
--- a/setup.py
+++ b/setup.py
@@ -133,10 +133,10 @@ def get_extensions():
         os.path.join(extensions_dir, "ops", "*.cpp")
     )
     source_cpu = (
-        glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) +
-        glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cpu', '*.cpp')) +
-        glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) +
-        glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp'))
+        glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "cpu", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
     )
 
     is_rocm_pytorch = False
@@ -164,7 +164,7 @@ def get_extensions():
     else:
         source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))
 
-    source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', 'cuda', '*.cpp'))
+    source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "cuda", "*.cpp"))
 
     sources = main_file + source_cpu
     extension = CppExtension
diff --git a/test/test_ops.py b/test/test_ops.py
index 5116837e732..83f8773d1d9 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -65,11 +65,11 @@ def func(z):
         gradcheck(func, (x,))
         gradcheck(script_func, (x,))
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
-    @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
     def test_autocast(self, device, x_dtype, rois_dtype):
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
             self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
 
@@ -285,15 +285,16 @@ def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=Non
             device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
         )
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize('aligned', (True, False))
-    @pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
-    @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
     def test_autocast(self, device, aligned, x_dtype, rois_dtype):
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
-            self.test_forward(torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype,
-                              rois_dtype=rois_dtype)
+            self.test_forward(
+                torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
+            )
 
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
         rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
@@ -533,12 +534,12 @@ def test_nms_cuda(self, iou, dtype=torch.float64):
         is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize("iou", (.2, .5, .8))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, device, iou, dtype):
-        test_fn = self.test_nms_ref if device == 'cpu' else partial(self.test_nms_cuda, dtype=dtype)
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        test_fn = self.test_nms_ref if device == "cpu" else partial(self.test_nms_cuda, dtype=dtype)
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
             test_fn(iou=iou)
 
@@ -826,11 +827,11 @@ def test_compare_cpu_cuda_grads(self, contiguous):
         res_grads = init_weight.grad.to("cpu")
         torch.testing.assert_close(true_cpu_grads, res_grads)
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize('batch_sz', (0, 33))
-    @pytest.mark.parametrize('dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("batch_sz", (0, 33))
+    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, device, batch_sz, dtype):
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
             self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype)
 
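
Note: after this formatting pass, the autocast tests all reduce to the same
device-parametrized shape (double quotes, black-style wrapping). A condensed,
standalone sketch of that final pattern follows; the test name is hypothetical, and
cpu_and_gpu is redefined here only to keep the snippet self-contained (the real tests
use the helper the test suite already imports from its common_utils module):

    import pytest
    import torch
    import torchvision.ops as ops

    def cpu_and_gpu():
        return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ())

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
    def test_roi_align_autocast(device, x_dtype):
        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
        x = torch.randn(1, 1, 8, 8, dtype=x_dtype, device=device)
        rois = torch.tensor([[0.0, 0.0, 0.0, 4.0, 4.0]], dtype=x_dtype, device=device)
        with cm():
            out = ops.roi_align(x, rois, output_size=(2, 2))
        # The autocast wrappers cast the result back to the input dtype on both
        # device types, so the output dtype should round-trip.
        assert out.dtype == x_dtype
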