From 14d6778368eee46299d0a40af581e739d99b1d58 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 00:27:59 +0000 Subject: [PATCH 1/6] Renaming C++ files & methods according to recommended naming conventions and aligning them with Python's API. --- .../cpu/{ROIPool_cpu.cpp => roi_pool_cpu.cpp} | 20 +++++++++---------- torchvision/csrc/cpu/vision_cpu.h | 4 ++-- .../{ROIPool_cuda.cu => roi_pool_cuda.cu} | 20 +++++++++---------- torchvision/csrc/cuda/vision_cuda.h | 4 ++-- torchvision/csrc/{ROIPool.h => roi_pool.h} | 6 +++--- torchvision/csrc/vision.cpp | 16 +++++++-------- 6 files changed, 35 insertions(+), 35 deletions(-) rename torchvision/csrc/cpu/{ROIPool_cpu.cpp => roi_pool_cpu.cpp} (93%) rename torchvision/csrc/cuda/{ROIPool_cuda.cu => roi_pool_cuda.cu} (92%) rename torchvision/csrc/{ROIPool.h => roi_pool.h} (97%) diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/roi_pool_cpu.cpp similarity index 93% rename from torchvision/csrc/cpu/ROIPool_cpu.cpp rename to torchvision/csrc/cpu/roi_pool_cpu.cpp index 34da4f1d1cc..d5724c09f98 100644 --- a/torchvision/csrc/cpu/ROIPool_cpu.cpp +++ b/torchvision/csrc/cpu/roi_pool_cpu.cpp @@ -9,7 +9,7 @@ inline void add(T* address, const T& val) { } template -void RoIPoolForward( +void roi_pool_forward_kernel_impl( const T* input, const T spatial_scale, int channels, @@ -78,7 +78,7 @@ void RoIPoolForward( } template -void RoIPoolBackward( +void roi_pool_backward_kernel_impl( const T* grad_output, const int* argmax_data, int num_rois, @@ -120,7 +120,7 @@ void RoIPoolBackward( } // num_rois } -std::tuple ROIPool_forward_cpu( +std::tuple roi_pool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -131,7 +131,7 @@ std::tuple ROIPool_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIPool_forward_cpu"; + at::CheckedFrom c = "roi_pool_forward_cpu"; at::checkAllSameType(c, {input_t, rois_t}); int num_rois = rois.size(0); @@ -151,8 +151,8 @@ std::tuple ROIPool_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "ROIPool_forward", [&] { - RoIPoolForward( + input.scalar_type(), "roi_pool_forward", [&] { + roi_pool_forward_kernel_impl( input_.data_ptr(), spatial_scale, channels, @@ -168,7 +168,7 @@ std::tuple ROIPool_forward_cpu( return std::make_tuple(output, argmax); } -at::Tensor ROIPool_backward_cpu( +at::Tensor roi_pool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, @@ -188,7 +188,7 @@ at::Tensor ROIPool_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIPool_backward_cpu"; + at::CheckedFrom c = "roi_pool_backward_cpu"; at::checkAllSameType(c, {grad_t, rois_t}); auto num_rois = rois.size(0); @@ -209,8 +209,8 @@ at::Tensor ROIPool_backward_cpu( auto rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "ROIPool_backward", [&] { - RoIPoolBackward( + grad.scalar_type(), "roi_pool_backward", [&] { + roi_pool_backward_kernel_impl( grad.data_ptr(), argmax.data_ptr(), num_rois, diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index a2647c57aa5..4b1ece04df5 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,14 +4,14 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple ROIPool_forward_cpu( +VISION_API std::tuple roi_pool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, int64_t pooled_height, int64_t pooled_width); -VISION_API at::Tensor ROIPool_backward_cpu( +VISION_API at::Tensor roi_pool_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/roi_pool_cuda.cu similarity index 92% rename from torchvision/csrc/cuda/ROIPool_cuda.cu rename to torchvision/csrc/cuda/roi_pool_cuda.cu index 3131b9eea7e..56f56f08030 100644 --- a/torchvision/csrc/cuda/ROIPool_cuda.cu +++ b/torchvision/csrc/cuda/roi_pool_cuda.cu @@ -7,7 +7,7 @@ #include "cuda_helpers.h" template -__global__ void RoIPoolForward( +__global__ void roi_pool_forward_kernel_impl( int nthreads, const T* input, const T spatial_scale, @@ -72,7 +72,7 @@ __global__ void RoIPoolForward( } template -__global__ void RoIPoolBackward( +__global__ void roi_pool_backward_kernel_impl( int nthreads, const T* grad_output, const int* argmax_data, @@ -115,7 +115,7 @@ __global__ void RoIPoolBackward( } } -std::tuple ROIPool_forward_cuda( +std::tuple roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -128,7 +128,7 @@ std::tuple ROIPool_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "ROIPool_forward_cuda"; + at::CheckedFrom c = "roi_pool_forward_cuda"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -160,8 +160,8 @@ std::tuple ROIPool_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIPool_forward", [&] { - RoIPoolForward<<>>( + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "roi_pool_forward", [&] { + roi_pool_forward_kernel_impl<<>>( output_size, input_.data_ptr(), spatial_scale, @@ -178,7 +178,7 @@ std::tuple ROIPool_forward_cuda( return std::make_tuple(output, argmax); } -at::Tensor ROIPool_backward_cuda( +at::Tensor roi_pool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, @@ -197,7 +197,7 @@ at::Tensor ROIPool_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3}; - at::CheckedFrom c = "ROIPool_backward_cuda"; + at::CheckedFrom c = "roi_pool_backward_cuda"; at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -228,8 +228,8 @@ at::Tensor ROIPool_backward_cuda( auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIPool_backward", [&] { - RoIPoolBackward<<>>( + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "roi_pool_backward", [&] { + roi_pool_backward_kernel_impl<<>>( grad.numel(), grad.data_ptr(), argmax_.data_ptr(), diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 1ec187c3348..80a23c6a6d8 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,14 +4,14 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple ROIPool_forward_cuda( +VISION_API std::tuple roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width); -VISION_API at::Tensor ROIPool_backward_cuda( +VISION_API at::Tensor roi_pool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, diff --git a/torchvision/csrc/ROIPool.h b/torchvision/csrc/roi_pool.h similarity index 97% rename from torchvision/csrc/ROIPool.h rename to torchvision/csrc/roi_pool.h index 7950005f1bd..c5ae34fc1ca 100644 --- a/torchvision/csrc/ROIPool.h +++ b/torchvision/csrc/roi_pool.h @@ -26,7 +26,7 @@ std::tuple roi_pool( } #if defined(WITH_CUDA) || defined(WITH_HIP) -std::tuple ROIPool_autocast( +std::tuple roi_pool_autocast( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -165,7 +165,7 @@ class ROIPoolBackwardFunction } }; -std::tuple ROIPool_autograd( +std::tuple roi_pool_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -177,7 +177,7 @@ std::tuple ROIPool_autograd( return std::make_tuple(result[0], result[1]); } -at::Tensor ROIPool_backward_autograd( +at::Tensor roi_pool_backward_autograd( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index c41663f0736..d764ec9334b 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -8,13 +8,13 @@ #include #endif -#include "ROIPool.h" #include "deform_conv2d.h" #include "empty_tensor_op.h" #include "nms.h" #include "ps_roi_align.h" #include "ps_roi_pool.h" #include "roi_align.h" +#include "roi_pool.h" // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension @@ -71,8 +71,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cpu); m.impl("roi_align", roi_align_forward_cpu); m.impl("_roi_align_backward", roi_align_backward_cpu); - m.impl("roi_pool", ROIPool_forward_cpu); - m.impl("_roi_pool_backward", ROIPool_backward_cpu); + m.impl("roi_pool", roi_pool_forward_cpu); + m.impl("_roi_pool_backward", roi_pool_backward_cpu); } // TODO: Place this in a hypothetical separate torchvision_cuda library @@ -87,8 +87,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_cuda); m.impl("roi_align", roi_align_forward_cuda); m.impl("_roi_align_backward", roi_align_backward_cuda); - m.impl("roi_pool", ROIPool_forward_cuda); - m.impl("_roi_pool_backward", ROIPool_backward_cuda); + m.impl("roi_pool", roi_pool_forward_cuda); + m.impl("_roi_pool_backward", roi_pool_backward_cuda); } #endif @@ -100,7 +100,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("ps_roi_align", ps_roi_align_autocast); m.impl("ps_roi_pool", ps_roi_pool_autocast); m.impl("roi_align", roi_align_autocast); - m.impl("roi_pool", ROIPool_autocast); + m.impl("roi_pool", roi_pool_autocast); } #endif @@ -113,6 +113,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd); m.impl("roi_align", roi_align_autograd); m.impl("_roi_align_backward", roi_align_backward_autograd); - m.impl("roi_pool", ROIPool_autograd); - m.impl("_roi_pool_backward", ROIPool_backward_autograd); + m.impl("roi_pool", roi_pool_autograd); + m.impl("_roi_pool_backward", roi_pool_backward_autograd); } From 9bea6684f85e58602cfb7894ade961761f643e2a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 00:32:06 +0000 Subject: [PATCH 2/6] Adding all internal functions in anonymous namespaces. --- torchvision/csrc/cpu/roi_pool_cpu.cpp | 4 ++++ torchvision/csrc/cuda/roi_pool_cuda.cu | 4 ++++ torchvision/csrc/roi_pool.h | 6 ++++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/cpu/roi_pool_cpu.cpp b/torchvision/csrc/cpu/roi_pool_cpu.cpp index d5724c09f98..9710380e0bc 100644 --- a/torchvision/csrc/cpu/roi_pool_cpu.cpp +++ b/torchvision/csrc/cpu/roi_pool_cpu.cpp @@ -3,6 +3,8 @@ #include #include +namespace { + template inline void add(T* address, const T& val) { *address += val; @@ -120,6 +122,8 @@ void roi_pool_backward_kernel_impl( } // num_rois } +} // namespace + std::tuple roi_pool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/roi_pool_cuda.cu b/torchvision/csrc/cuda/roi_pool_cuda.cu index 56f56f08030..50cc9ef8723 100644 --- a/torchvision/csrc/cuda/roi_pool_cuda.cu +++ b/torchvision/csrc/cuda/roi_pool_cuda.cu @@ -6,6 +6,8 @@ #include "cuda_helpers.h" +namespace { + template __global__ void roi_pool_forward_kernel_impl( int nthreads, @@ -115,6 +117,8 @@ __global__ void roi_pool_backward_kernel_impl( } } +} // namespace + std::tuple roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/roi_pool.h b/torchvision/csrc/roi_pool.h index c5ae34fc1ca..f64b5f12937 100644 --- a/torchvision/csrc/roi_pool.h +++ b/torchvision/csrc/roi_pool.h @@ -11,8 +11,6 @@ #include "hip/vision_cuda.h" #endif -// TODO: put this stuff in torchvision namespace - std::tuple roi_pool( const at::Tensor& input, const at::Tensor& rois, @@ -73,6 +71,8 @@ at::Tensor _roi_pool_backward( width); } +namespace { + class ROIPoolFunction : public torch::autograd::Function { public: static torch::autograd::variable_list forward( @@ -165,6 +165,8 @@ class ROIPoolBackwardFunction } }; +} // namespace + std::tuple roi_pool_autograd( const at::Tensor& input, const at::Tensor& rois, From f206a8bfe10feea60c82233d2e82027d7524a606 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 00:36:27 +0000 Subject: [PATCH 3/6] Syncing variable names between the cpp files and their header files. --- torchvision/csrc/cuda/vision_cuda.h | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 80a23c6a6d8..d53afc961bb 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -7,18 +7,18 @@ VISION_API std::tuple roi_pool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); VISION_API at::Tensor roi_pool_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& argmax, - const double spatial_scale, - const int64_t pooled_height, - const int64_t pooled_width, - const int64_t batch_size, - const int64_t channels, - const int64_t height, - const int64_t width); + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); From b84d38f99c7effcf396c4ca7bcde2d63376dc293 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 00:49:22 +0000 Subject: [PATCH 4/6] Renaming C++/CUDA kernel files and moving operator code from header to cpp file. --- torchvision/csrc/cpu/{roi_pool_cpu.cpp => roi_pool_kernel.cpp} | 0 torchvision/csrc/cuda/{roi_pool_cuda.cu => roi_pool_kernel.cu} | 0 torchvision/csrc/{roi_pool.h => roi_pool.cpp} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename torchvision/csrc/cpu/{roi_pool_cpu.cpp => roi_pool_kernel.cpp} (100%) rename torchvision/csrc/cuda/{roi_pool_cuda.cu => roi_pool_kernel.cu} (100%) rename torchvision/csrc/{roi_pool.h => roi_pool.cpp} (100%) diff --git a/torchvision/csrc/cpu/roi_pool_cpu.cpp b/torchvision/csrc/cpu/roi_pool_kernel.cpp similarity index 100% rename from torchvision/csrc/cpu/roi_pool_cpu.cpp rename to torchvision/csrc/cpu/roi_pool_kernel.cpp diff --git a/torchvision/csrc/cuda/roi_pool_cuda.cu b/torchvision/csrc/cuda/roi_pool_kernel.cu similarity index 100% rename from torchvision/csrc/cuda/roi_pool_cuda.cu rename to torchvision/csrc/cuda/roi_pool_kernel.cu diff --git a/torchvision/csrc/roi_pool.h b/torchvision/csrc/roi_pool.cpp similarity index 100% rename from torchvision/csrc/roi_pool.h rename to torchvision/csrc/roi_pool.cpp From d19b2fabb155c5b40ffd0eed899f88fcc1fefe52 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 01:05:22 +0000 Subject: [PATCH 5/6] Create foreach cpp file a separate header file with "public" functions. --- torchvision/csrc/cpu/roi_pool_kernel.cpp | 2 + torchvision/csrc/cpu/roi_pool_kernel.h | 23 +++++++++ torchvision/csrc/cpu/vision_cpu.h | 19 -------- torchvision/csrc/cuda/roi_pool_kernel.cu | 1 + torchvision/csrc/cuda/roi_pool_kernel.h | 23 +++++++++ torchvision/csrc/cuda/vision_cuda.h | 19 -------- torchvision/csrc/roi_pool.cpp | 14 ++---- torchvision/csrc/roi_pool.h | 61 ++++++++++++++++++++++++ 8 files changed, 114 insertions(+), 48 deletions(-) create mode 100644 torchvision/csrc/cpu/roi_pool_kernel.h create mode 100644 torchvision/csrc/cuda/roi_pool_kernel.h create mode 100644 torchvision/csrc/roi_pool.h diff --git a/torchvision/csrc/cpu/roi_pool_kernel.cpp b/torchvision/csrc/cpu/roi_pool_kernel.cpp index 9710380e0bc..e083661bb14 100644 --- a/torchvision/csrc/cpu/roi_pool_kernel.cpp +++ b/torchvision/csrc/cpu/roi_pool_kernel.cpp @@ -3,6 +3,8 @@ #include #include +#include "roi_pool_kernel.h" + namespace { template diff --git a/torchvision/csrc/cpu/roi_pool_kernel.h b/torchvision/csrc/cpu/roi_pool_kernel.h new file mode 100644 index 00000000000..66fd993d5b4 --- /dev/null +++ b/torchvision/csrc/cpu/roi_pool_kernel.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API std::tuple roi_pool_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API at::Tensor roi_pool_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 4b1ece04df5..a772fa13f01 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -3,22 +3,3 @@ #include "../macros.h" // TODO: Delete this file once all the methods are gone - -VISION_API std::tuple roi_pool_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor roi_pool_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); diff --git a/torchvision/csrc/cuda/roi_pool_kernel.cu b/torchvision/csrc/cuda/roi_pool_kernel.cu index 50cc9ef8723..97dd93d2f42 100644 --- a/torchvision/csrc/cuda/roi_pool_kernel.cu +++ b/torchvision/csrc/cuda/roi_pool_kernel.cu @@ -5,6 +5,7 @@ #include #include "cuda_helpers.h" +#include "roi_pool_kernel.h" namespace { diff --git a/torchvision/csrc/cuda/roi_pool_kernel.h b/torchvision/csrc/cuda/roi_pool_kernel.h new file mode 100644 index 00000000000..3a99f7521bd --- /dev/null +++ b/torchvision/csrc/cuda/roi_pool_kernel.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API std::tuple roi_pool_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +VISION_API at::Tensor roi_pool_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index d53afc961bb..a772fa13f01 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -3,22 +3,3 @@ #include "../macros.h" // TODO: Delete this file once all the methods are gone - -VISION_API std::tuple roi_pool_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width); - -VISION_API at::Tensor roi_pool_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& argmax, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); diff --git a/torchvision/csrc/roi_pool.cpp b/torchvision/csrc/roi_pool.cpp index f64b5f12937..c8d70bd8940 100644 --- a/torchvision/csrc/roi_pool.cpp +++ b/torchvision/csrc/roi_pool.cpp @@ -1,14 +1,8 @@ -#pragma once +#include "roi_pool.h" +#include -#include "cpu/vision_cpu.h" - -#ifdef WITH_CUDA -#include "autocast.h" -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "autocast.h" -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include #endif std::tuple roi_pool( diff --git a/torchvision/csrc/roi_pool.h b/torchvision/csrc/roi_pool.h new file mode 100644 index 00000000000..f528ce6d7e0 --- /dev/null +++ b/torchvision/csrc/roi_pool.h @@ -0,0 +1,61 @@ +#pragma once + +#include "cpu/roi_pool_kernel.h" + +#ifdef WITH_CUDA +#include "cuda/roi_pool_kernel.h" +#endif +#ifdef WITH_HIP +#include "hip/roi_pool_kernel.h" +#endif + +// C++ Forward +std::tuple roi_pool( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +// Autocast Forward +#if defined(WITH_CUDA) || defined(WITH_HIP) +std::tuple roi_pool_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); +#endif + +// C++ Backward +at::Tensor _roi_pool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +// Autograd Forward and Backward +std::tuple roi_pool_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width); + +at::Tensor roi_pool_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& argmax, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); From d823bb8f484e07a5fc2e1cdf9b27e5e3f78cbab3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 01:10:28 +0000 Subject: [PATCH 6/6] Removing unnecessary repeated includes. --- torchvision/csrc/cpu/roi_pool_kernel.cpp | 5 +---- torchvision/csrc/cuda/roi_pool_kernel.cu | 3 +-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/torchvision/csrc/cpu/roi_pool_kernel.cpp b/torchvision/csrc/cpu/roi_pool_kernel.cpp index e083661bb14..389e9c90248 100644 --- a/torchvision/csrc/cpu/roi_pool_kernel.cpp +++ b/torchvision/csrc/cpu/roi_pool_kernel.cpp @@ -1,7 +1,4 @@ -#include -#include -#include -#include +#include #include "roi_pool_kernel.h" diff --git a/torchvision/csrc/cuda/roi_pool_kernel.cu b/torchvision/csrc/cuda/roi_pool_kernel.cu index 97dd93d2f42..c10dd0cf403 100644 --- a/torchvision/csrc/cuda/roi_pool_kernel.cu +++ b/torchvision/csrc/cuda/roi_pool_kernel.cu @@ -1,7 +1,6 @@ -#include -#include #include #include +#include #include #include "cuda_helpers.h"