From bcaaffbacc41342429ccd002aece151b0fd4e227 Mon Sep 17 00:00:00 2001
From: xinyi
Date: Thu, 22 Feb 2024 17:20:25 +0800
Subject: [PATCH 1/4] support torch 2.1

feat(dmodule): support parallelized dtensor init
feat(dtensor): support querying random ops
feat(dtensor): support deferred init on device
---
 src/cc/torchdistx/deferred_init.cc        | 111 +++++++++++++++++++++-
 src/cc/torchdistx/deferred_init.h         |   5 +-
 src/python/torchdistx/_C.pyi              |   5 +
 src/python/torchdistx/_C/deferred_init.cc |  47 ++++++++-
 4 files changed, 165 insertions(+), 3 deletions(-)

diff --git a/src/cc/torchdistx/deferred_init.cc b/src/cc/torchdistx/deferred_init.cc
index 961b638..0a29d9f 100644
--- a/src/cc/torchdistx/deferred_init.cc
+++ b/src/cc/torchdistx/deferred_init.cc
@@ -173,6 +173,7 @@ class Op {
   }
 
   void materialize();
+  void materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::Device> device);
 
   std::size_t num_outputs() const noexcept {
     return num_outputs_;
@@ -220,7 +221,6 @@ Op Op::fromOperatorHandle(const OperatorHandle& handle, Stack s) {
   };
 
   const FunctionSchema& shm = handle.schema();
-
   return Op{shm.name(), std::move(fn), shm.arguments().size(), shm.returns().size(), std::move(s)};
 }
 
@@ -271,6 +271,44 @@ void Op::materialize() {
   materialized_ = true;
 }
 
+void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::Device> device) {
+  if (materialized_) {
+    return;
+  }
+
+  {
+    ThreadLocalStateGuard state_guard{*tls_};
+
+    auto replace_first_shape = [&](c10::IntArrayRef sp) {
+      IValue local_shape(sp);
+      stack_[0] = local_shape;
+    };
+
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::empty", "aten::ones", "aten::zeros", "aten::full" };
+
+    if (std::find(op_white_list.begin(), op_white_list.end(), name()) != op_white_list.end()) {
+      // The op is a whitelisted factory op; its first argument is the shape.
+      replace_first_shape(shape);
+    }
+
+    if (device.has_value()) {  // set target device
+      for (size_t i = 0; i < stack_.size(); i++) {
+        if (stack_[i].isDevice()) {
+          stack_[i] = IValue(device.value());
+        }
+      }
+    }
+
+    fn_(stack_);
+  }
+
+  fn_ = nullptr;
+
+  tls_ = nullopt;
+
+  materialized_ = true;
+}
+
 const Tensor& Op::getOutput(std::size_t idx) const noexcept {
   const Tensor* opt_out = nullptr;
 
@@ -343,6 +381,8 @@ class OpNode {
   // Materializes the operation held by this node along with all the operations
   // in its recorded call stack.
   void materialize();
+  // Same as `materialize()`, but with an overridden (local) shape.
+  void materializeWithShape(c10::IntArrayRef shape, c10::optional<at::Device> device);
 
  private:
   void buildCallStack();
@@ -527,6 +567,30 @@ void OpNode::materialize() {
   call_stack_.clear();
 }
 
+void OpNode::materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::Device> device) {
+  // Do not try to shortcut this function by checking if the node is already
+  // materialized. A later in-place operation can still change the output of
+  // this node.
+
+  buildCallStack();
+
+  for (OpNode* node : call_stack_) {
+    if (node->op_.materialized()) {
+      continue;
+    }
+
+    node->materializeArguments();
+
+    node->op_.materializeWithShape(shape, device);
+
+    // Make sure that we deallocate parts of the operation graph that are not
+    // needed anymore.
+    node->detachDependencies();
+  }
+
+  call_stack_.clear();
+}
+
 void OpNode::buildCallStack() {
   OpNode* last_node = getLastInPlaceOpNode();
 
@@ -728,6 +792,24 @@ Tensor materialize(const Tensor& fake) {
   return out;
 }
 
+Tensor materialize_with_shape(const Tensor& fake, c10::IntArrayRef shape, const c10::optional<at::Device> device) {
+  TensorRecord& record = getTensorRecord(fake);
+
+  const OpOutputDescriptor& output_desc = record.output_descriptor();
+
+  output_desc.node()->materializeWithShape(shape, device);
+
+  Tensor out = output_desc.node()->op().getOutput(output_desc.output_index());
+
+  // Unfortunately there is no way for us to track calls to `requires_grad_()`,
+  // so instead we explicitly set `requires_grad` after materialization.
+  if (fake.is_leaf() && fake.requires_grad()) {
+    out.set_requires_grad(true);
+  }
+
+  return out;
+}
+
 // The catch-all handler for the `DeferredInit` dispatch key.
 class DeferredInitHandler {
  public:
@@ -1032,6 +1114,12 @@ class ProxyVariableHooks : public VariableHooksInterface {
     inner_->requires_grad_(self, value);
   }
 
+  void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op,
+                                               c10::DispatchKeySet dispatch_keys,
+                                               torch::jit::Stack* stack) const override {
+    inner_->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
+  }
+
   VariableHooksInterface* inner() noexcept {
     return inner_;
   }
@@ -1164,6 +1252,7 @@ bool canMaterialize(const Tensor& tensor) noexcept {
   return isFake(tensor) && unsafeAsFake(tensor).hasData(DispatchKey::DeferredInit);
 }
 
+
 Tensor materializeTensor(const Tensor& tensor) {
   if (canMaterialize(tensor)) {
     return detail::materialize(tensor);
@@ -1172,4 +1261,24 @@ Tensor materializeTensor(const Tensor& tensor) {
   }
 }
 
+Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<at::Device> device) {
+  if (canMaterialize(tensor)) {
+    return detail::materialize_with_shape(tensor, shape, device);
+  } else {
+    return tensor;
+  }
+}
+
+bool isGenByRandomOp(const Tensor& tensor) noexcept {
+  if (canMaterialize(tensor)) {
+    detail::TensorRecord& record = detail::getTensorRecord(tensor);
+    const detail::OpOutputDescriptor& output_desc = record.output_descriptor();
+    auto name = output_desc.node()->op().name();
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand"};
+    return std::find(op_white_list.begin(), op_white_list.end(), name) != op_white_list.end();
+  } else {
+    return false;
+  }
+}
+
 } // namespace torchdistx
diff --git a/src/cc/torchdistx/deferred_init.h b/src/cc/torchdistx/deferred_init.h
index afda4cb..4c654cf 100644
--- a/src/cc/torchdistx/deferred_init.h
+++ b/src/cc/torchdistx/deferred_init.h
@@ -8,6 +8,8 @@
 #include
 #include
+#include <c10/core/Device.h>
+#include <c10/util/Optional.h>
 
 #include "macros.h"
@@ -27,9 +29,10 @@ TDX_API void leaveDeferredInit() noexcept;
 
 // Indicates whether `tensor` has been constructed in a deferred-init context.
 TDX_API bool canMaterialize(const at::Tensor& tensor) noexcept;
-
+TDX_API bool isGenByRandomOp(const at::Tensor& tensor) noexcept;
 // Materializes `tensor`.
 TDX_API at::Tensor materializeTensor(const at::Tensor& tensor);
+TDX_API at::Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<at::Device> device = {});
 
 // Temporarily disables deferred-init.
 class TDX_API NoDeferredInit {
diff --git a/src/python/torchdistx/_C.pyi b/src/python/torchdistx/_C.pyi
index 82c8421..9e4462a 100644
--- a/src/python/torchdistx/_C.pyi
+++ b/src/python/torchdistx/_C.pyi
@@ -5,12 +5,17 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
+from torch.types import _int, SymInt, _device
+from collections.abc import Sequence
+from typing import Union, Optional
 
 def enter_deferred_init() -> None: ...
 def leave_deferred_init() -> None: ...
 def enter_fake_mode(fake_mode: bool) -> None: ...
 def leave_fake_mode() -> None: ...
 def is_fake(tensor: torch.Tensor) -> bool: ...
+def is_gen_by_random_op(tensor: torch.Tensor) -> bool: ...
 def can_materialize(tensor: torch.Tensor) -> bool: ...
 def materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: ...
+def materialize_tensor_with_local_shape(tensor: torch.Tensor, shape: Sequence[Union[_int, SymInt]], device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
 def meta_like(fake: torch.Tensor) -> torch.Tensor: ...
diff --git a/src/python/torchdistx/_C/deferred_init.cc b/src/python/torchdistx/_C/deferred_init.cc
index 7eb73b8..d810544 100644
--- a/src/python/torchdistx/_C/deferred_init.cc
+++ b/src/python/torchdistx/_C/deferred_init.cc
@@ -8,6 +8,7 @@
 #include
 #include
+#include <pybind11/stl.h>
 #include
 #include
 #include
@@ -94,14 +95,58 @@ py::object materializeVariable(const py::object& var) {
   return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
 }
 
+// Materializing a tensor in Python requires an extra step. We need to ensure
+// that the materialized tensor has the same Python class (e.g. `Variable` or
+// `Parameter`) as the original tensor. In the DTensor case we also need to
+// replace the recorded global shape with the parallelized local shape.
+py::object materializeVariableWithLocalShape(const py::object& var, const py::object& shape, const c10::optional<at::Device> device) {
+  PyObject* naked_var = var.ptr();
+  auto c_shape = shape.cast<std::vector<int64_t>>();
+
+  if (!THPVariable_Check(naked_var)) {
+    throw TypeError{"`var` has to be a `Variable`, but got `%s`.", Py_TYPE(naked_var)->tp_name};
+  }
+
+  const Tensor& data = THPVariable_Unpack(naked_var);
+
+  auto materialize = [=](const Tensor& tensor, c10::IntArrayRef sp) {
+    py::gil_scoped_release guard{};
+
+    return materializeTensorWithLocalShape(tensor, sp, device);
+  };
+
+  Tensor materialized_data = materialize(data, at::IntArrayRef(c_shape));
+
+  // Check if we have really materialized `data`. Materializing a regular tensor
+  // is a no-op, so we can simply return.
+  if (materialized_data.is_same(data)) {
+    return var;
+  }
+
+  // We might have already materialized `data`. Make sure that we preserve its
+  // identity on the Python side and avoid creating a new Python tensor.
+  c10::optional<PyObject*> opt_materialized_var =
+      materialized_data.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+  if (opt_materialized_var.has_value()) {
+    return py::reinterpret_borrow<py::object>(*opt_materialized_var);
+  }
+
+  // Otherwise ensure that our materialized tensor has the same Python class as
+  // the original tensor.
+  return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
+}
+
 } // namespace
 
 void initDeferredInitFunctions(py::module& m) {
   m.def("enter_deferred_init", enterDeferredInit);
   m.def("leave_deferred_init", leaveDeferredInit);
-
   m.def("can_materialize", canMaterialize);
+  m.def("is_gen_by_random_op", isGenByRandomOp);
   m.def("materialize_tensor", materializeVariable);
+  m.def("materialize_tensor_with_local_shape", materializeVariableWithLocalShape);
 }
 
 } // namespace torchdistx::python
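The intended use from the Python side looks roughly like the sketch below. This is illustrative only: the module, shard shape, and device are hypothetical, `deferred_init` is the existing torchdistx wrapper, and the `_C` functions are the bindings added in this patch.

    import torch
    from torchdistx import _C
    from torchdistx.deferred_init import deferred_init

    # Record the init graph without allocating real storage.
    fc = deferred_init(torch.nn.Linear, 1024, 1024)

    weight = fc.weight
    if _C.can_materialize(weight):
        # Replay the recorded ops with the first (shape) argument of the
        # whitelisted factory op rewritten to the local shard's shape,
        # optionally on a chosen device.
        local = _C.materialize_tensor_with_local_shape(weight, (512, 1024), "cpu")
        print(local.shape)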
From 86190b66084aec92592527aba75fd4ac15a91615 Mon Sep 17 00:00:00 2001
From: xinyi
Date: Thu, 28 Mar 2024 08:33:32 +0800
Subject: [PATCH 2/4] fix(deferinit): support torch 2.2 deferred init with defer_device
---
 src/python/torchdistx/_C/fake.cc | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/python/torchdistx/_C/fake.cc b/src/python/torchdistx/_C/fake.cc
index 70e0e3c..538a3b5 100644
--- a/src/python/torchdistx/_C/fake.cc
+++ b/src/python/torchdistx/_C/fake.cc
@@ -8,7 +8,12 @@
 #include
 #include
+#include <torch/version.h>
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3
+#include <torch/csrc/utils/device_lazy_init.h>
+#else
 #include <torch/csrc/utils/cuda_lazy_init.h>
+#endif
 
 #include
 #include
@@ -22,7 +27,12 @@ void pyEnterFakeMode(bool fake_cuda) {
   // subsystem which would fail and prevent us from instantiating CUDA devices.
   if (fake_cuda) {
     if (!at::hasCUDA()) {
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3
+      torch::utils::set_requires_device_init(at::kCUDA, false);
+#else
       torch::utils::set_requires_cuda_init(false);
+#endif
     }
   }
 }
@@ -31,7 +41,11 @@ void pyLeaveFakeMode() {
   leaveFakeMode();
 
   if (!isFakeModeActive() && !at::hasCUDA()) {
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3
+    torch::utils::set_requires_device_init(at::kCUDA, true);
+#else
    torch::utils::set_requires_cuda_init(true);
+#endif
  }
 }
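What the lazy-init guard enables, roughly (illustrative; `fake_mode` is the existing torchdistx context manager, only the torch >= 2.3 code path is new in this patch): creating "CUDA" tensors on a machine with no CUDA runtime.

    import torch
    from torchdistx.fake import fake_mode

    with fake_mode(fake_cuda=True):
        # No real CUDA runtime is touched; suppressing the lazy device init
        # prevents the initialization that would otherwise fail here.
        t = torch.ones(4, device="cuda")

    print(t.device)  # cuda:0, but the tensor is fake (no storage)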
From 67ed2e9f2bba6e3b1c2f667ff6d0a4293c340fd1 Mon Sep 17 00:00:00 2001
From: xinyi
Date: Thu, 25 Apr 2024 20:41:04 +0800
Subject: [PATCH 3/4] fix(dtensor): treat uniform_ as a random op

fix: recognize tensors generated by random ops when deferred init is set
---
 src/cc/torchdistx/deferred_init.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/cc/torchdistx/deferred_init.cc b/src/cc/torchdistx/deferred_init.cc
index 0a29d9f..311b70e 100644
--- a/src/cc/torchdistx/deferred_init.cc
+++ b/src/cc/torchdistx/deferred_init.cc
@@ -284,7 +284,7 @@ void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::De
       stack_[0] = local_shape;
     };
 
-    std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::empty", "aten::ones", "aten::zeros", "aten::full" };
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::empty", "aten::ones", "aten::zeros", "aten::full"};
 
     if (std::find(op_white_list.begin(), op_white_list.end(), name()) != op_white_list.end()) {
@@ -1274,7 +1274,7 @@ bool isGenByRandomOp(const Tensor& tensor) noexcept {
     detail::TensorRecord& record = detail::getTensorRecord(tensor);
     const detail::OpOutputDescriptor& output_desc = record.output_descriptor();
     auto name = output_desc.node()->op().name();
-    std::vector<std::string> op_white_list{"aten::randn", "aten::rand"};
+    std::vector<std::string> op_white_list{"aten::randn", "aten::rand", "aten::uniform_"};
     return std::find(op_white_list.begin(), op_white_list.end(), name) != op_white_list.end();
   } else {
     return false;
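A minimal sketch of the behavior this fix targets, using the raw `_C` bindings declared in patch 1 (the tensor and shape are illustrative, and this assumes the in-place op becomes the recorded producer of the tensor):

    import torch
    from torchdistx import _C

    # Record ops without executing them.
    _C.enter_deferred_init()
    t = torch.empty(4, 4).uniform_()  # recorded as aten::empty + aten::uniform_
    _C.leave_deferred_init()

    # Before this patch only aten::randn/aten::rand were matched, so tensors
    # randomized in place via uniform_ were misclassified as non-random.
    print(_C.is_gen_by_random_op(t))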
From ea8c9adfaa90486bd94936bd623aa484a9cb16e3 Mon Sep 17 00:00:00 2001
From: xinyi
Date: Thu, 20 Jun 2024 15:19:33 +0800
Subject: [PATCH 4/4] support materializing on a target device
---
 setup.py                                  |  2 +-
 src/cc/torchdistx/deferred_init.cc        | 77 +++++++++++++++++++++++
 src/cc/torchdistx/deferred_init.h         |  1 +
 src/python/torchdistx/_C.pyi              |  1 +
 src/python/torchdistx/_C/deferred_init.cc | 41 ++++++++++++
 5 files changed, 121 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a4102b5..62b05ac 100644
--- a/setup.py
+++ b/setup.py
@@ -173,7 +173,7 @@ def main() -> None:
         zip_safe=False,
         # Since PyTorch does not offer ABI compatibility we have to make sure
         # that we use the same version that was used at build time.
-        install_requires=[f"torch=={torch.__version__}"],
+        # install_requires=[f"torch=={torch.__version__}"],
         classifiers=[
             "Development Status :: 3 - Alpha",
             "Intended Audience :: Developers",
diff --git a/src/cc/torchdistx/deferred_init.cc b/src/cc/torchdistx/deferred_init.cc
index 311b70e..1cc2a2c 100644
--- a/src/cc/torchdistx/deferred_init.cc
+++ b/src/cc/torchdistx/deferred_init.cc
@@ -174,6 +174,7 @@ class Op {
 
   void materialize();
   void materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::Device> device);
+  void materializeWithDevice(const c10::optional<at::Device> device);
 
   std::size_t num_outputs() const noexcept {
     return num_outputs_;
@@ -309,6 +310,32 @@ void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::De
   materialized_ = true;
 }
 
+void Op::materializeWithDevice(const c10::optional<at::Device> device) {
+  if (materialized_) {
+    return;
+  }
+
+  {
+    ThreadLocalStateGuard state_guard{*tls_};
+
+    if (device.has_value()) {  // set target device
+      for (size_t i = 0; i < stack_.size(); i++) {
+        if (stack_[i].isDevice()) {
+          stack_[i] = IValue(device.value());
+        }
+      }
+    }
+
+    fn_(stack_);
+  }
+
+  fn_ = nullptr;
+
+  tls_ = nullopt;
+
+  materialized_ = true;
+}
+
 const Tensor& Op::getOutput(std::size_t idx) const noexcept {
   const Tensor* opt_out = nullptr;
 
@@ -383,6 +410,7 @@ class OpNode {
   void materialize();
   // Same as `materialize()`, but with an overridden (local) shape.
   void materializeWithShape(c10::IntArrayRef shape, c10::optional<at::Device> device);
+  void materializeWithDevice(c10::optional<at::Device> device);
 
  private:
   void buildCallStack();
@@ -567,6 +595,30 @@ void OpNode::materialize() {
   call_stack_.clear();
 }
 
+void OpNode::materializeWithDevice(const c10::optional<at::Device> device) {
+  // Do not try to shortcut this function by checking if the node is already
+  // materialized. A later in-place operation can still change the output of
+  // this node.
+
+  buildCallStack();
+
+  for (OpNode* node : call_stack_) {
+    if (node->op_.materialized()) {
+      continue;
+    }
+
+    node->materializeArguments();
+
+    node->op_.materializeWithDevice(device);
+
+    // Make sure that we deallocate parts of the operation graph that are not
+    // needed anymore.
+    node->detachDependencies();
+  }
+
+  call_stack_.clear();
+}
+
 void OpNode::materializeWithShape(c10::IntArrayRef shape, const c10::optional<at::Device> device) {
   // Do not try to shortcut this function by checking if the node is already
   // materialized. A later in-place operation can still change the output of
@@ -792,6 +844,24 @@ Tensor materialize(const Tensor& fake) {
   return out;
 }
 
+Tensor materialize_with_device(const Tensor& fake, const c10::optional<at::Device> device) {
+  TensorRecord& record = getTensorRecord(fake);
+
+  const OpOutputDescriptor& output_desc = record.output_descriptor();
+
+  output_desc.node()->materializeWithDevice(device);
+
+  Tensor out = output_desc.node()->op().getOutput(output_desc.output_index());
+
+  // Unfortunately there is no way for us to track calls to `requires_grad_()`,
+  // so instead we explicitly set `requires_grad` after materialization.
+  if (fake.is_leaf() && fake.requires_grad()) {
+    out.set_requires_grad(true);
+  }
+
+  return out;
+}
+
 Tensor materialize_with_shape(const Tensor& fake, c10::IntArrayRef shape, const c10::optional<at::Device> device) {
   TensorRecord& record = getTensorRecord(fake);
 
   const OpOutputDescriptor& output_desc = record.output_descriptor();
@@ -1261,6 +1331,13 @@ Tensor materializeTensor(const Tensor& tensor) {
   }
 }
 
+Tensor materializeTensorWithDevice(const at::Tensor& tensor, const c10::optional<at::Device> device) {
+  if (canMaterialize(tensor)) {
+    return detail::materialize_with_device(tensor, device);
+  } else {
+    return tensor;
+  }
+}
 Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<at::Device> device) {
   if (canMaterialize(tensor)) {
     return detail::materialize_with_shape(tensor, shape, device);
diff --git a/src/cc/torchdistx/deferred_init.h b/src/cc/torchdistx/deferred_init.h
index 4c654cf..680a626 100644
--- a/src/cc/torchdistx/deferred_init.h
+++ b/src/cc/torchdistx/deferred_init.h
@@ -32,6 +32,7 @@ TDX_API bool canMaterialize(const at::Tensor& tensor) noexcept;
 TDX_API bool isGenByRandomOp(const at::Tensor& tensor) noexcept;
 // Materializes `tensor`.
 TDX_API at::Tensor materializeTensor(const at::Tensor& tensor);
+TDX_API at::Tensor materializeTensorWithDevice(const at::Tensor& tensor, const c10::optional<at::Device> device = {});
 TDX_API at::Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<at::Device> device = {});
 
 // Temporarily disables deferred-init.
diff --git a/src/python/torchdistx/_C.pyi b/src/python/torchdistx/_C.pyi
index 9e4462a..8138a55 100644
--- a/src/python/torchdistx/_C.pyi
+++ b/src/python/torchdistx/_C.pyi
@@ -17,5 +17,6 @@ def is_fake(tensor: torch.Tensor) -> bool: ...
 def is_gen_by_random_op(tensor: torch.Tensor) -> bool: ...
 def can_materialize(tensor: torch.Tensor) -> bool: ...
 def materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: ...
+def materialize_tensor_with_device(tensor: torch.Tensor, device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
 def materialize_tensor_with_local_shape(tensor: torch.Tensor, shape: Sequence[Union[_int, SymInt]], device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
 def meta_like(fake: torch.Tensor) -> torch.Tensor: ...
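Per the stub, `device` may be passed as a `torch.device`, a string, or omitted. Hypothetical call sites (assuming `w` is a tensor recorded under deferred init):

    w_cpu = _C.materialize_tensor_with_device(w, "cpu")
    w_gpu = _C.materialize_tensor_with_device(w, torch.device("cuda", 0))
    w_def = _C.materialize_tensor_with_device(w)  # keep the recorded device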
diff --git a/src/python/torchdistx/_C/deferred_init.cc b/src/python/torchdistx/_C/deferred_init.cc
index d810544..8013325 100644
--- a/src/python/torchdistx/_C/deferred_init.cc
+++ b/src/python/torchdistx/_C/deferred_init.cc
@@ -138,6 +138,46 @@ py::object materializeVariableWithLocalShape(const py::object& var, const py::ob
   return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
 }
 
+// Materializing a tensor in Python requires an extra step. We need to ensure
+// that the materialized tensor has the same Python class (e.g. `Variable` or
+// `Parameter`) as the original tensor. In the DTensor case we also allow
+// overriding the device the tensor is materialized on.
+py::object materializeVariableWithDevice(const py::object& var, const c10::optional<at::Device> device) {
+  PyObject* naked_var = var.ptr();
+
+  if (!THPVariable_Check(naked_var)) {
+    throw TypeError{"`var` has to be a `Variable`, but got `%s`.", Py_TYPE(naked_var)->tp_name};
+  }
+
+  const Tensor& data = THPVariable_Unpack(naked_var);
+
+  auto materialize = [=](const Tensor& tensor) {
+    py::gil_scoped_release guard{};
+
+    return materializeTensorWithDevice(tensor, device);
+  };
+
+  Tensor materialized_data = materialize(data);
+
+  // Check if we have really materialized `data`. Materializing a regular tensor
+  // is a no-op, so we can simply return.
+  if (materialized_data.is_same(data)) {
+    return var;
+  }
+
+  // We might have already materialized `data`. Make sure that we preserve its
+  // identity on the Python side and avoid creating a new Python tensor.
+  c10::optional<PyObject*> opt_materialized_var =
+      materialized_data.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter());
+  if (opt_materialized_var.has_value()) {
+    return py::reinterpret_borrow<py::object>(*opt_materialized_var);
+  }
+
+  // Otherwise ensure that our materialized tensor has the same Python class as
+  // the original tensor.
+  return makeVariable(Py_TYPE(naked_var), std::move(materialized_data));
+}
+
 } // namespace
 
 void initDeferredInitFunctions(py::module& m) {
   m.def("enter_deferred_init", enterDeferredInit);
   m.def("leave_deferred_init", leaveDeferredInit);
   m.def("can_materialize", canMaterialize);
   m.def("is_gen_by_random_op", isGenByRandomOp);
   m.def("materialize_tensor", materializeVariable);
+  m.def("materialize_tensor_with_device", materializeVariableWithDevice);
   m.def("materialize_tensor_with_local_shape", materializeVariableWithLocalShape);
 }
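Taken together, a rough end-to-end sketch of the workflow this series enables (the module and target device are illustrative, not part of the patches): record initialization without storage, then replay it directly onto a target device.

    import torch
    from torchdistx import _C
    from torchdistx.deferred_init import deferred_init

    model = deferred_init(torch.nn.Linear, 4096, 4096)

    for name, param in model.named_parameters():
        if _C.can_materialize(param):
            # Replays the recorded op graph with every recorded device
            # argument rewritten to cuda:0, so the parameter is created
            # directly on the GPU instead of on CPU followed by a copy.
            out = _C.materialize_tensor_with_device(param, torch.device("cuda", 0))
            print(name, out.device)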