Skip to content

Commit 97bbaff

Browse files
bdhirshfacebook-github-bot
authored andcommitted
Updating all call-sites of the legacy dispatcher registration API in fbcode to the new API. (#48178)
Summary: Pull Request resolved: pytorch/pytorch#48178 I migrated all call sites that used the legacy dispatcher registration API (RegisterOperators()) to use the new API (TORCH_LIBRARY...). I found all call-sites by running `fbgs RegisterOperators()`. This includes several places, including other OSS code (nestedtensor, torchtext, torchvision). A few things to call out: For simple ops that only had one registered kernel without a dispatch key, I replaced them with: ``` TORCH_LIBRARY_FRAGMENT(ns, m) { m.def("opName", fn_name); } ``` For ops that registered to a specific dispatch key / had multiple kernels registered, I registered the common kernel (math/cpu) directly inside a `TORCH_LIBRARY_FRAGMENT` block, and registered any additional kernels from other files (e.g. cuda) in a separate `TORCH_LIBRARY_IMPL` block. ``` // cpu file TORCH_LIBRARY_FRAGMENT(ns, m) { m.def("opName(schema_inputs) -> schema_outputs"); m.impl("opName", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(cpu_kernel))); } // cuda file TORCH_LIBRARY_IMPL(ns, CUDA, m) { m.impl("opName", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(cuda_kernel))); } ``` Special cases: I found a few ops that used a (legacy) `CPUTensorId`/`CUDATensorId` dispatch key. Updated those to use CPU/CUDA- this seems safe because the keys are aliased to one another in `DispatchKey.h` There were a handful of ops that registered a functor (function class) to the legacy API. As far as I could tell we don't allow this case in the new API, mainly because you can accomplish the same thing more cleanly with lambdas. Rather than delete the class I wrote a wrapper function on top of the class, which I passed to the new API. There were a handful of ops that were registered only to a CUDA dispatch key. I put them inside a TORCH_LIBRARY_FRAGMENT block, and used a `def()` and `impl()` call like in case two above. Test Plan: Imported from OSS Reviewed By: ezyang Differential Revision: D25056090 Pulled By: bdhirsh fbshipit-source-id: 8f868b45f545e5da2f21924046e786850eba70d9
1 parent ff96c17 commit 97bbaff

File tree

2 files changed

+17
-20
lines changed

2 files changed

+17
-20
lines changed

torchvision/csrc/cpu/video_reader/VideoReader.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -669,12 +669,9 @@ torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
669669

670670
} // namespace video_reader
671671

672-
static auto registry = torch::RegisterOperators()
673-
.op("video_reader::read_video_from_memory",
674-
&video_reader::readVideoFromMemory)
675-
.op("video_reader::read_video_from_file",
676-
&video_reader::readVideoFromFile)
677-
.op("video_reader::probe_video_from_memory",
678-
&video_reader::probeVideoFromMemory)
679-
.op("video_reader::probe_video_from_file",
680-
&video_reader::probeVideoFromFile);
672+
TORCH_LIBRARY_FRAGMENT(video_reader, m) {
673+
m.def("read_video_from_memory", video_reader::readVideoFromMemory);
674+
m.def("read_video_from_file", video_reader::readVideoFromFile);
675+
m.def("probe_video_from_memory", video_reader::probeVideoFromMemory);
676+
m.def("probe_video_from_file", video_reader::probeVideoFromFile);
677+
}

torchvision/csrc/vision.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ int64_t _cuda_version() {
3939
#endif
4040
}
4141

42-
static auto registry =
43-
torch::RegisterOperators()
44-
.op("torchvision::nms", &nms)
45-
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor",
46-
&roi_align)
47-
.op("torchvision::roi_pool", &roi_pool)
48-
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
49-
.op("torchvision::ps_roi_align", &ps_roi_align)
50-
.op("torchvision::ps_roi_pool", &ps_roi_pool)
51-
.op("torchvision::deform_conv2d", &deform_conv2d)
52-
.op("torchvision::_cuda_version", &_cuda_version);
42+
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
43+
m.def("nms", nms);
44+
m.def("roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor",
45+
&roi_align);
46+
m.def("roi_pool", roi_pool);
47+
m.def("_new_empty_tensor_op", new_empty_tensor);
48+
m.def("ps_roi_align", ps_roi_align);
49+
m.def("ps_roi_pool", ps_roi_pool);
50+
m.def("deform_conv2d", deform_conv2d);
51+
m.def("_cuda_version", _cuda_version);
52+
}

0 commit comments

Comments
 (0)