From 8bd973e98b8f54e63df4176f28179b3b09179f58 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 29 Jun 2020 14:13:16 -0700 Subject: [PATCH 1/2] Switch torchvision registrations to new operator registration API. This is still registering everything as catchalls, so we're really just moving deck chairs around, but payoff is coming soon. Signed-off-by: Edward Z. Yang --- torchvision/csrc/vision.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 9debc3da9b6..74cb9c9f321 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -42,14 +42,14 @@ int64_t _cuda_version() { #endif } -static auto registry = - torch::RegisterOperators() - .op("torchvision::nms", &nms) - .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor", - &roi_align) - .op("torchvision::roi_pool", &roi_pool) - .op("torchvision::_new_empty_tensor_op", &new_empty_tensor) - .op("torchvision::ps_roi_align", &ps_roi_align) - .op("torchvision::ps_roi_pool", &ps_roi_pool) - .op("torchvision::deform_conv2d", &deform_conv2d) - .op("torchvision::_cuda_version", &_cuda_version); +TORCH_LIBRARY(torchvision, m) { + m.def("nms", &nms); + m.def("roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor", + &roi_align); + m.def("roi_pool", &roi_pool); + m.def("_new_empty_tensor_op", &new_empty_tensor); + m.def("ps_roi_align", &ps_roi_align); + m.def("ps_roi_pool", &ps_roi_pool); + m.def("deform_conv2d", &deform_conv2d); + m.def("_cuda_version", &_cuda_version); +} From 6155d0c626cfd8af5d2e6f660b43fac6b162df43 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 29 Jun 2020 15:07:32 -0700 Subject: [PATCH 2/2] Port roi_align to actually use dispatcher Signed-off-by: Edward Z. Yang --- torchvision/csrc/ROIAlign.h | 139 ++++++++++++++++--------- torchvision/csrc/cpu/ROIAlign_cpu.cpp | 24 ++--- torchvision/csrc/cpu/vision_cpu.h | 24 ++--- torchvision/csrc/cuda/ROIAlign_cuda.cu | 24 ++--- torchvision/csrc/cuda/vision_cuda.h | 24 ++--- torchvision/csrc/vision.cpp | 24 ++++- 6 files changed, 161 insertions(+), 98 deletions(-) diff --git a/torchvision/csrc/ROIAlign.h b/torchvision/csrc/ROIAlign.h index 78dcb101dce..7a856f34d63 100644 --- a/torchvision/csrc/ROIAlign.h +++ b/torchvision/csrc/ROIAlign.h @@ -9,8 +9,9 @@ #include "hip/vision_cuda.h" #endif -// Interface for Python -at::Tensor ROIAlign_forward( +// TODO: put this stuff in torchvision namespace + +at::Tensor roi_align( const at::Tensor& input, // Input feature map. const at::Tensor& rois, // List of ROIs to pool over. const double spatial_scale, // The scale of the image features. ROIs will be @@ -21,21 +22,10 @@ at::Tensor ROIAlign_forward( const bool aligned) // The flag for pixel shift // along each axis. { - if (input.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return ROIAlign_forward_cuda( - input, - rois, - spatial_scale, - pooled_height, - pooled_width, - sampling_ratio, - aligned); -#else - AT_ERROR("Not compiled with GPU support"); -#endif - } - return ROIAlign_forward_cpu( + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::roi_align", "") + .typed(); + return op.call( input, rois, spatial_scale, @@ -45,37 +35,23 @@ at::Tensor ROIAlign_forward( aligned); } -at::Tensor ROIAlign_backward( +at::Tensor _roi_align_backward( const at::Tensor& grad, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, const bool aligned) { - if (grad.is_cuda()) { -#if defined(WITH_CUDA) || defined(WITH_HIP) - return ROIAlign_backward_cuda( - grad, - rois, - spatial_scale, - pooled_height, - pooled_width, - batch_size, - channels, - height, - width, - sampling_ratio, - aligned); -#else - AT_ERROR("Not compiled with GPU support"); -#endif - } - return ROIAlign_backward_cpu( + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_roi_align_backward", "") + .typed(); + return op.call( grad, rois, spatial_scale, @@ -107,7 +83,8 @@ class ROIAlignFunction : public torch::autograd::Function { ctx->saved_data["aligned"] = aligned; ctx->saved_data["input_shape"] = input.sizes(); ctx->save_for_backward({rois}); - auto result = ROIAlign_forward( + at::AutoNonVariableTypeMode g; + auto result = roi_align( input, rois, spatial_scale, @@ -125,7 +102,7 @@ class ROIAlignFunction : public torch::autograd::Function { auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); - auto grad_in = ROIAlign_backward( + auto grad_in = _roi_align_backward( grad_output[0], rois, ctx->saved_data["spatial_scale"].toDouble(), @@ -147,7 +124,47 @@ class ROIAlignFunction : public torch::autograd::Function { } }; -at::Tensor roi_align( +// TODO: There should be an easier way to do this +class ROIAlignBackwardFunction + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + torch::autograd::Variable grad, + torch::autograd::Variable rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, + const bool aligned) { + at::AutoNonVariableTypeMode g; + auto result = _roi_align_backward( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned); + return {result}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + TORCH_CHECK(0, "double backwards on roi_align not supported"); + } +}; + +at::Tensor ROIAlign_autograd( const at::Tensor& input, const at::Tensor& rois, const double spatial_scale, @@ -164,3 +181,29 @@ at::Tensor roi_align( sampling_ratio, aligned)[0]; } + +at::Tensor ROIAlign_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, + const bool aligned) { + return ROIAlignBackwardFunction::apply( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio, + aligned)[0]; +} diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 325221df65b..75d3e7a90b4 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -381,10 +381,10 @@ void ROIAlignBackward( at::Tensor ROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, const bool aligned) { AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); @@ -430,14 +430,14 @@ at::Tensor ROIAlign_forward_cpu( at::Tensor ROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, const bool aligned) { AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index d81a51a59c4..64aa1ae2119 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -23,23 +23,23 @@ at::Tensor ROIPool_backward_cpu( at::Tensor ROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, const bool aligned); at::Tensor ROIAlign_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, const bool aligned); std::tuple PSROIPool_forward_cpu( diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index 298af06c708..8f8bcd10d48 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -307,10 +307,10 @@ __global__ void RoIAlignBackward( at::Tensor ROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, const bool aligned) { AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor"); @@ -368,14 +368,14 @@ at::Tensor ROIAlign_forward_cuda( at::Tensor ROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, const bool aligned) { AT_ASSERTM(grad.is_cuda(), "grad must be a CUDA tensor"); AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor"); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 5f0ff05246b..834ba51a4cf 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -9,23 +9,23 @@ at::Tensor ROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio, const bool aligned); at::Tensor ROIAlign_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, - const float spatial_scale, - const int pooled_height, - const int pooled_width, - const int batch_size, - const int channels, - const int height, - const int width, - const int sampling_ratio, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio, const bool aligned); std::tuple ROIPool_forward_cuda( diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 74cb9c9f321..7f56bdb51a1 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -44,8 +44,10 @@ int64_t _cuda_version() { TORCH_LIBRARY(torchvision, m) { m.def("nms", &nms); - m.def("roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor", - &roi_align); + m.def( + "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor"); + m.def( + "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); m.def("roi_pool", &roi_pool); m.def("_new_empty_tensor_op", &new_empty_tensor); m.def("ps_roi_align", &ps_roi_align); @@ -53,3 +55,21 @@ TORCH_LIBRARY(torchvision, m) { m.def("deform_conv2d", &deform_conv2d); m.def("_cuda_version", &_cuda_version); } + +TORCH_LIBRARY_IMPL(torchvision, CPU, m) { + m.impl("roi_align", ROIAlign_forward_cpu); + m.impl("_roi_align_backward", ROIAlign_backward_cpu); +} + +// TODO: Place this in a hypothetical separate torchvision_cuda library +#if defined(WITH_CUDA) || defined(WITH_HIP) +TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { + m.impl("roi_align", ROIAlign_forward_cuda); + m.impl("_roi_align_backward", ROIAlign_backward_cuda); +} +#endif + +TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { + m.impl("roi_align", ROIAlign_autograd); + m.impl("_roi_align_backward", ROIAlign_backward_autograd); +}