From 85c5f4ed6fa3e784409b3c153767600659e81e4c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 22:14:29 +0000 Subject: [PATCH 1/5] Renaming C++ files & methods according to recommended naming conventions and aligning them with Python's API. --- ...{PSROIPool_cpu.cpp => ps_roi_pool_cpu.cpp} | 20 +++++++++---------- torchvision/csrc/cpu/vision_cpu.h | 4 ++-- ...{PSROIPool_cuda.cu => ps_roi_pool_cuda.cu} | 20 +++++++++---------- torchvision/csrc/cuda/vision_cuda.h | 4 ++-- .../csrc/{PSROIPool.h => ps_roi_pool.h} | 6 +++--- torchvision/csrc/vision.cpp | 16 +++++++-------- 6 files changed, 35 insertions(+), 35 deletions(-) rename torchvision/csrc/cpu/{PSROIPool_cpu.cpp => ps_roi_pool_cpu.cpp} (94%) rename torchvision/csrc/cuda/{PSROIPool_cuda.cu => ps_roi_pool_cuda.cu} (94%) rename torchvision/csrc/{PSROIPool.h => ps_roi_pool.h} (97%) diff --git a/torchvision/csrc/cpu/PSROIPool_cpu.cpp b/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp similarity index 94% rename from torchvision/csrc/cpu/PSROIPool_cpu.cpp rename to torchvision/csrc/cpu/ps_roi_pool_cpu.cpp index c6e0a64cac3..bb27b22dd10 100644 --- a/torchvision/csrc/cpu/PSROIPool_cpu.cpp +++ b/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp @@ -9,7 +9,7 @@ inline void add(T* address, const T& val) { } template -void PSROIPoolForward( +void ps_roi_pool_forward_kernel_impl( const T* input, const T spatial_scale, int channels, @@ -79,7 +79,7 @@ void PSROIPoolForward( } template -void PSROIPoolBackward( +void ps_roi_pool_backward_kernel_impl( const T* grad_output, const int* channel_mapping, int num_rois, @@ -143,7 +143,7 @@ void PSROIPoolBackward( } } -std::tuple PSROIPool_forward_cpu( +std::tuple ps_roi_pool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -157,7 +157,7 @@ std::tuple PSROIPool_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "PSROIPool_forward_cpu"; + at::CheckedFrom c = "ps_roi_pool_forward_cpu"; at::checkAllSameType(c, {input_t, rois_t}); int num_rois = rois.size(0); @@ -182,8 +182,8 @@ std::tuple PSROIPool_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "PSROIPool_forward", [&] { - PSROIPoolForward( + input.scalar_type(), "ps_roi_pool_forward", [&] { + ps_roi_pool_forward_kernel_impl( input_.data_ptr(), spatial_scale, channels, @@ -200,7 +200,7 @@ std::tuple PSROIPool_forward_cpu( return std::make_tuple(output, channel_mapping); } -at::Tensor PSROIPool_backward_cpu( +at::Tensor ps_roi_pool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -221,7 +221,7 @@ at::Tensor PSROIPool_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "PSROIPool_backward_cpu"; + at::CheckedFrom c = "ps_roi_pool_backward_cpu"; at::checkAllSameType(c, {grad_t, rois_t}); auto num_rois = rois.size(0); @@ -237,8 +237,8 @@ at::Tensor PSROIPool_backward_cpu( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "PSROIPool_backward", [&] { - PSROIPoolBackward( + grad.scalar_type(), "ps_roi_pool_backward", [&] { + ps_roi_pool_backward_kernel_impl( grad_.data_ptr(), channel_mapping.data_ptr(), num_rois, diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 22119b5e292..92ed6b58382 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,14 +4,14 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple PSROIPool_forward_cpu( +VISION_API std::tuple ps_roi_pool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width); -VISION_API at::Tensor PSROIPool_backward_cpu( +VISION_API at::Tensor ps_roi_pool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/cuda/PSROIPool_cuda.cu b/torchvision/csrc/cuda/ps_roi_pool_cuda.cu similarity index 94% rename from torchvision/csrc/cuda/PSROIPool_cuda.cu rename to torchvision/csrc/cuda/ps_roi_pool_cuda.cu index ab6a50b009c..6b6e3bf831a 100644 --- a/torchvision/csrc/cuda/PSROIPool_cuda.cu +++ b/torchvision/csrc/cuda/ps_roi_pool_cuda.cu @@ -7,7 +7,7 @@ #include "cuda_helpers.h" template -__global__ void PSROIPoolForward( +__global__ void ps_roi_pool_forward_kernel_impl( int nthreads, const T* input, const T spatial_scale, @@ -73,7 +73,7 @@ __global__ void PSROIPoolForward( } template -__global__ void PSROIPoolBackward( +__global__ void ps_roi_pool_backward_kernel_impl( int nthreads, const T* grad_output, const int* channel_mapping, @@ -132,7 +132,7 @@ __global__ void PSROIPoolBackward( } } -std::tuple PSROIPool_forward_cuda( +std::tuple ps_roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -146,7 +146,7 @@ std::tuple PSROIPool_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "PSROIPool_forward_cuda"; + at::CheckedFrom c = "ps_roi_pool_forward_cuda"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -183,8 +183,8 @@ std::tuple PSROIPool_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "PSROIPool_forward", [&] { - PSROIPoolForward<<>>( + input.scalar_type(), "ps_roi_pool_forward", [&] { + ps_roi_pool_forward_kernel_impl<<>>( output_size, input_.data_ptr(), spatial_scale, @@ -202,7 +202,7 @@ std::tuple PSROIPool_forward_cuda( return std::make_tuple(output, channel_mapping); } -at::Tensor PSROIPool_backward_cuda( +at::Tensor ps_roi_pool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -223,7 +223,7 @@ at::Tensor PSROIPool_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "PSROIPool_backward_cuda"; + at::CheckedFrom c = "ps_roi_pool_backward_cuda"; at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -251,8 +251,8 @@ at::Tensor PSROIPool_backward_cuda( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "PSROIPool_backward", [&] { - PSROIPoolBackward<<>>( + grad.scalar_type(), "ps_roi_pool_backward", [&] { + ps_roi_pool_backward_kernel_impl<<>>( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index c80386a8db1..613cf86fce6 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,14 +4,14 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple PSROIPool_forward_cuda( +VISION_API std::tuple ps_roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width); -VISION_API at::Tensor PSROIPool_backward_cuda( +VISION_API at::Tensor ps_roi_pool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/PSROIPool.h b/torchvision/csrc/ps_roi_pool.h similarity index 97% rename from torchvision/csrc/PSROIPool.h rename to torchvision/csrc/ps_roi_pool.h index c3ced9e7842..2d5bb7e2d62 100644 --- a/torchvision/csrc/PSROIPool.h +++ b/torchvision/csrc/ps_roi_pool.h @@ -26,7 +26,7 @@ std::tuple ps_roi_pool( } #if defined(WITH_CUDA) || defined(WITH_HIP) -std::tuple PSROIPool_autocast( +std::tuple ps_roi_pool_autocast( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -166,7 +166,7 @@ class PSROIPoolBackwardFunction } }; -std::tuple PSROIPool_autograd( +std::tuple ps_roi_pool_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -178,7 +178,7 @@ std::tuple PSROIPool_autograd( return std::make_tuple(result[0], result[1]); } -at::Tensor PSROIPool_backward_autograd( +at::Tensor ps_roi_pool_backward_autograd( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index c5c204aac2b..6f540c6832e 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -8,13 +8,13 @@ #include #endif -#include "PSROIPool.h" #include "ROIAlign.h" #include "ROIPool.h" #include "deform_conv2d.h" #include "empty_tensor_op.h" #include "nms.h" #include "ps_roi_align.h" +#include "ps_roi_pool.h" // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension @@ -67,8 +67,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("nms", nms_cpu); m.impl("ps_roi_align", ps_roi_align_forward_cpu); m.impl("_ps_roi_align_backward", ps_roi_align_backward_cpu); - m.impl("ps_roi_pool", PSROIPool_forward_cpu); - m.impl("_ps_roi_pool_backward", PSROIPool_backward_cpu); + m.impl("ps_roi_pool", ps_roi_pool_forward_cpu); + m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cpu); m.impl("roi_align", ROIAlign_forward_cpu); m.impl("_roi_align_backward", ROIAlign_backward_cpu); m.impl("roi_pool", ROIPool_forward_cpu); @@ -83,8 +83,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("nms", nms_cuda); m.impl("ps_roi_align", ps_roi_align_forward_cuda); m.impl("_ps_roi_align_backward", ps_roi_align_backward_cuda); - m.impl("ps_roi_pool", PSROIPool_forward_cuda); - m.impl("_ps_roi_pool_backward", PSROIPool_backward_cuda); + m.impl("ps_roi_pool", ps_roi_pool_forward_cuda); + m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cuda); m.impl("roi_align", ROIAlign_forward_cuda); m.impl("_roi_align_backward", ROIAlign_backward_cuda); m.impl("roi_pool", ROIPool_forward_cuda); @@ -98,7 +98,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("deform_conv2d", deform_conv2d_autocast); m.impl("nms", nms_autocast); m.impl("ps_roi_align", ps_roi_align_autocast); - m.impl("ps_roi_pool", PSROIPool_autocast); + m.impl("ps_roi_pool", ps_roi_pool_autocast); m.impl("roi_align", ROIAlign_autocast); m.impl("roi_pool", ROIPool_autocast); } @@ -109,8 +109,8 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); m.impl("ps_roi_align", ps_roi_align_autograd); m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd); - m.impl("ps_roi_pool", PSROIPool_autograd); - m.impl("_ps_roi_pool_backward", PSROIPool_backward_autograd); + m.impl("ps_roi_pool", ps_roi_pool_autograd); + m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd); m.impl("roi_align", ROIAlign_autograd); m.impl("_roi_align_backward", ROIAlign_backward_autograd); m.impl("roi_pool", ROIPool_autograd); From 9b7bd824901b6c9ec9f03d25867687a5a90a5bea Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 22:18:43 +0000 Subject: [PATCH 2/5] Adding all internal functions in anonymous namespaces. --- torchvision/csrc/cpu/ps_roi_pool_cpu.cpp | 4 ++++ torchvision/csrc/cuda/ps_roi_pool_cuda.cu | 4 ++++ torchvision/csrc/ps_roi_pool.h | 6 ++++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp b/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp index bb27b22dd10..d67c159048a 100644 --- a/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp +++ b/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp @@ -3,6 +3,8 @@ #include #include +namespace { + template inline void add(T* address, const T& val) { *address += val; @@ -143,6 +145,8 @@ void ps_roi_pool_backward_kernel_impl( } } +} // namespace + std::tuple ps_roi_pool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/ps_roi_pool_cuda.cu b/torchvision/csrc/cuda/ps_roi_pool_cuda.cu index 6b6e3bf831a..4ac6bb8491b 100644 --- a/torchvision/csrc/cuda/ps_roi_pool_cuda.cu +++ b/torchvision/csrc/cuda/ps_roi_pool_cuda.cu @@ -6,6 +6,8 @@ #include "cuda_helpers.h" +namespace { + template __global__ void ps_roi_pool_forward_kernel_impl( int nthreads, @@ -132,6 +134,8 @@ __global__ void ps_roi_pool_backward_kernel_impl( } } +} // namespace + std::tuple ps_roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ps_roi_pool.h b/torchvision/csrc/ps_roi_pool.h index 2d5bb7e2d62..0fae65a4b10 100644 --- a/torchvision/csrc/ps_roi_pool.h +++ b/torchvision/csrc/ps_roi_pool.h @@ -11,8 +11,6 @@ #include "hip/vision_cuda.h" #endif -// TODO: put this stuff in torchvision namespace - std::tuple ps_roi_pool( const at::Tensor& input, const at::Tensor& rois, @@ -74,6 +72,8 @@ at::Tensor _ps_roi_pool_backward( width); } +namespace { + class PSROIPoolFunction : public torch::autograd::Function { public: static torch::autograd::variable_list forward( @@ -166,6 +166,8 @@ class PSROIPoolBackwardFunction } }; +} // namespace + std::tuple ps_roi_pool_autograd( const at::Tensor& input, const at::Tensor& rois, From 6e6fd1a232471e901d50f49f78d779d4dea5062b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 22:30:11 +0000 Subject: [PATCH 3/5] Renaming C++/CUDA kernel files and moving operator code from header to cpp file. --- .../csrc/cpu/{ps_roi_pool_cpu.cpp => ps_roi_pool_kernel.cpp} | 0 .../csrc/cuda/{ps_roi_pool_cuda.cu => ps_roi_pool_kernel.cu} | 0 torchvision/csrc/{ps_roi_pool.h => ps_roi_pool.cpp} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename torchvision/csrc/cpu/{ps_roi_pool_cpu.cpp => ps_roi_pool_kernel.cpp} (100%) rename torchvision/csrc/cuda/{ps_roi_pool_cuda.cu => ps_roi_pool_kernel.cu} (100%) rename torchvision/csrc/{ps_roi_pool.h => ps_roi_pool.cpp} (100%) diff --git a/torchvision/csrc/cpu/ps_roi_pool_cpu.cpp b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp similarity index 100% rename from torchvision/csrc/cpu/ps_roi_pool_cpu.cpp rename to torchvision/csrc/cpu/ps_roi_pool_kernel.cpp diff --git a/torchvision/csrc/cuda/ps_roi_pool_cuda.cu b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu similarity index 100% rename from torchvision/csrc/cuda/ps_roi_pool_cuda.cu rename to torchvision/csrc/cuda/ps_roi_pool_kernel.cu diff --git a/torchvision/csrc/ps_roi_pool.h b/torchvision/csrc/ps_roi_pool.cpp similarity index 100% rename from torchvision/csrc/ps_roi_pool.h rename to torchvision/csrc/ps_roi_pool.cpp From be21d920999300f804430ceae093d083dd8f782d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 22:40:43 +0000 Subject: [PATCH 4/5] Create foreach cpp file a separate header file with "public" functions. --- torchvision/csrc/cpu/ps_roi_pool_kernel.cpp | 2 + torchvision/csrc/cpu/ps_roi_pool_kernel.h | 23 ++++++++ torchvision/csrc/cpu/vision_cpu.h | 19 ------- torchvision/csrc/cuda/ps_roi_pool_kernel.cu | 1 + torchvision/csrc/cuda/ps_roi_pool_kernel.h | 23 ++++++++ torchvision/csrc/cuda/vision_cuda.h | 19 ------- torchvision/csrc/ps_roi_pool.cpp | 14 ++--- torchvision/csrc/ps_roi_pool.h | 61 +++++++++++++++++++++ 8 files changed, 114 insertions(+), 48 deletions(-) create mode 100644 torchvision/csrc/cpu/ps_roi_pool_kernel.h create mode 100644 torchvision/csrc/cuda/ps_roi_pool_kernel.h create mode 100644 torchvision/csrc/ps_roi_pool.h diff --git a/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp index d67c159048a..e7dc51e6565 100644 --- a/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp @@ -3,6 +3,8 @@ #include #include +#include "ps_roi_pool_kernel.h" + namespace { template diff --git a/torchvision/csrc/cpu/ps_roi_pool_kernel.h b/torchvision/csrc/cpu/ps_roi_pool_kernel.h new file mode 100644 index 00000000000..14a4e22681a --- /dev/null +++ b/torchvision/csrc/cpu/ps_roi_pool_kernel.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API std::tuple ps_roi_pool_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API at::Tensor ps_roi_pool_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 92ed6b58382..baf64f89689 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,25 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple ps_roi_pool_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor ps_roi_pool_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - VISION_API at::Tensor ROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/ps_roi_pool_kernel.cu b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu index 4ac6bb8491b..dfdf2accf42 100644 --- a/torchvision/csrc/cuda/ps_roi_pool_kernel.cu +++ b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu @@ -5,6 +5,7 @@ #include #include "cuda_helpers.h" +#include "ps_roi_pool_kernel.h" namespace { diff --git a/torchvision/csrc/cuda/ps_roi_pool_kernel.h b/torchvision/csrc/cuda/ps_roi_pool_kernel.h new file mode 100644 index 00000000000..e97f0ee7065 --- /dev/null +++ b/torchvision/csrc/cuda/ps_roi_pool_kernel.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API std::tuple ps_roi_pool_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API at::Tensor ps_roi_pool_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 613cf86fce6..8d411b9c67e 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,25 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple ps_roi_pool_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor ps_roi_pool_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - VISION_API at::Tensor ROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ps_roi_pool.cpp b/torchvision/csrc/ps_roi_pool.cpp index 0fae65a4b10..76fb2d04be7 100644 --- a/torchvision/csrc/ps_roi_pool.cpp +++ b/torchvision/csrc/ps_roi_pool.cpp @@ -1,14 +1,8 @@ -#pragma once +#include "ps_roi_pool.h" +#include -#include "cpu/vision_cpu.h" - -#ifdef WITH_CUDA -#include "autocast.h" -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "autocast.h" -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include #endif std::tuple ps_roi_pool( diff --git a/torchvision/csrc/ps_roi_pool.h b/torchvision/csrc/ps_roi_pool.h new file mode 100644 index 00000000000..0c8baef4a9a --- /dev/null +++ b/torchvision/csrc/ps_roi_pool.h @@ -0,0 +1,61 @@ +#pragma once + +#include "cpu/ps_roi_pool_kernel.h" + +#ifdef WITH_CUDA +#include "cuda/ps_roi_pool_kernel.h" +#endif +#ifdef WITH_HIP +#include "hip/ps_roi_pool_kernel.h" +#endif + +// C++ Forward +std::tuple ps_roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +// Autocast Forward +#if defined(WITH_CUDA) || defined(WITH_HIP) +std::tuple ps_roi_pool_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); +#endif + +// C++ Backward +at::Tensor _ps_roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +// Autograd Forward and Backward +std::tuple ps_roi_pool_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +at::Tensor ps_roi_pool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); From 9c61021aeedf4a7220e763b760f537895aff3f52 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 22:43:39 +0000 Subject: [PATCH 5/5] Removing unnecessary repeated includes. --- torchvision/csrc/cpu/ps_roi_pool_kernel.cpp | 5 ----- torchvision/csrc/cuda/ps_roi_pool_kernel.cu | 2 -- 2 files changed, 7 deletions(-) diff --git a/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp index e7dc51e6565..171de9edc6a 100644 --- a/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/cpu/ps_roi_pool_kernel.cpp @@ -1,8 +1,3 @@ -#include -#include -#include -#include - #include "ps_roi_pool_kernel.h" namespace { diff --git a/torchvision/csrc/cuda/ps_roi_pool_kernel.cu b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu index dfdf2accf42..aa1c834e059 100644 --- a/torchvision/csrc/cuda/ps_roi_pool_kernel.cu +++ b/torchvision/csrc/cuda/ps_roi_pool_kernel.cu @@ -1,5 +1,3 @@ -#include -#include #include #include #include