This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 5e493dc

bdhirsh authored and facebook-github-bot committed
Updating all call-sites of the legacy dispatcher registration API in fbcode to the new API. (#48178)
Summary: Pull Request resolved: pytorch/pytorch#48178

I migrated all call sites that used the legacy dispatcher registration API (`RegisterOperators()`) to the new API (`TORCH_LIBRARY...`). I found all call sites by running `fbgs RegisterOperators()`. This touched several places, including other OSS code (nestedtensor, torchtext, torchvision).

A few things to call out:

For simple ops that had only one registered kernel and no dispatch key, I replaced the registration with:

```
TORCH_LIBRARY_FRAGMENT(ns, m) {
  m.def("opName", fn_name);
}
```

For ops that registered to a specific dispatch key or had multiple kernels registered, I registered the common kernel (math/cpu) directly inside a `TORCH_LIBRARY_FRAGMENT` block, and registered any additional kernels from other files (e.g. cuda) in a separate `TORCH_LIBRARY_IMPL` block:

```
// cpu file
TORCH_LIBRARY_FRAGMENT(ns, m) {
  m.def("opName(schema_inputs) -> schema_outputs");
  m.impl("opName", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(cpu_kernel)));
}

// cuda file
TORCH_LIBRARY_IMPL(ns, CUDA, m) {
  m.impl("opName", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(cuda_kernel)));
}
```

Special cases:

- A few ops used a legacy `CPUTensorId`/`CUDATensorId` dispatch key. I updated those to use `CPU`/`CUDA`; this seems safe because the keys are aliased to one another in `DispatchKey.h`.
- A handful of ops registered a functor (function class) with the legacy API. As far as I could tell, the new API doesn't allow this case, mainly because you can accomplish the same thing more cleanly with lambdas. Rather than delete the class, I wrote a wrapper function on top of it and passed that to the new API.
- A handful of ops were registered only to a CUDA dispatch key. I put them inside a `TORCH_LIBRARY_FRAGMENT` block and used `def()` and `impl()` calls, as in the second case above.
Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D25056090

Pulled By: bdhirsh

fbshipit-source-id: 8f868b45f545e5da2f21924046e786850eba70d9
1 parent 87793a3 commit 5e493dc

File tree

3 files changed: +24, -22 lines

nestedtensor/csrc/mha.cpp

Lines changed: 3 additions & 2 deletions
```diff
@@ -62,8 +62,9 @@ at::Tensor min_mha(
   return attn_output;
 }
 
-static auto registry =
-    torch::RegisterOperators().op("nestedtensor::min_mha", &min_mha);
+TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
+  m.def("min_mha", min_mha);
+}
 
 } // namespace nested_tensor
 } // namespace torch
```

nestedtensor/csrc/py_init.cpp

Lines changed: 18 additions & 18 deletions
```diff
@@ -145,36 +145,34 @@ inline std::vector<std::string> split_str(
   result.push_back(s);
   return result;
 }
-
-static auto registry =
-    torch::RegisterOperators()
-        .op("nestedtensor::is_nested_tensor_impl",
-            [](Tensor tensor) { return is_nested_tensor_impl(tensor); })
-        .op("nestedtensor::nested_dim",
+TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
+  m.def("is_nested_tensor_impl",
+      [](Tensor tensor) { return is_nested_tensor_impl(tensor); });
+  m.def("nested_dim",
       [](Tensor tensor) {
         return get_nested_tensor_impl(tensor)->nested_dim();
-      })
-        .op("nestedtensor::stack",
+      });
+  m.def("stack",
       [](std::vector<Tensor> tensors, int64_t dim) {
         return at::stack(TensorList(tensors), dim);
-      })
-        .op("nestedtensor::cat",
+      });
+  m.def("cat",
       [](std::vector<Tensor> tensors, int64_t dim) {
         return at::cat(TensorList(tensors), dim);
-      })
-        .op("nestedtensor::to_nested_tensor",
+      });
+  m.def("to_nested_tensor",
       [](Tensor tensor, c10::optional<int64_t> dim) {
         return get_nested_tensor_impl(tensor)->to_nested_tensor(dim);
-      })
-        .op("nestedtensor::sizes",
+      });
+  m.def("sizes",
       [](Tensor tensor) {
         return get_nested_tensor_impl(tensor)->opt_sizes();
-      })
-        .op("nestedtensor::len",
+      });
+  m.def("len",
       [](Tensor self) {
         return (int64_t)(get_nested_tensor_structure(self).degree());
-      })
-        .op("nestedtensor::str", [](Tensor tensor) {
+      });
+  m.def("str", [](Tensor tensor) {
     auto node = get_nested_tensor_structure(tensor);
     return NestedNode___str__(
         node,
@@ -203,6 +201,8 @@ static auto registry =
       return result;
     });
   });
+}
+
 } // namespace
 } // namespace nested_tensor
 } // namespace torch
```

nestedtensor/csrc/totensor.cpp

Lines changed: 3 additions & 2 deletions
```diff
@@ -109,10 +109,11 @@ Tensor NestedTensor_to_tensor(Tensor tensor, c10::optional<int64_t> dim_) {
   // return wrap_tensor_node(TensorNode(std::move(result)));
 }
 
-static auto registry = torch::RegisterOperators().op(
-    "nestedtensor::to_tensor",
+TORCH_LIBRARY_FRAGMENT(nestedtensor, m) {
+  m.def("to_tensor",
     [](Tensor tensor, c10::optional<int64_t> dim) {
       return NestedTensor_to_tensor(tensor, dim);
     });
+}
 
 } // namespace at
```
