Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def main() -> None:
zip_safe=False,
# Since PyTorch does not offer ABI compatibility we have to make sure
# that we use the same version that was used at build time.
install_requires=[f"torch=={torch.__version__}"],
# install_requires=[f"torch=={torch.__version__}"],
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
Expand Down
188 changes: 187 additions & 1 deletion src/cc/torchdistx/deferred_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class Op {
}

void materialize();
void materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device);
void materializeWithDevice(const c10::optional<c10::Device> device);

std::size_t num_outputs() const noexcept {
return num_outputs_;
Expand Down Expand Up @@ -220,7 +222,6 @@ Op Op::fromOperatorHandle(const OperatorHandle& handle, Stack s) {
};

const FunctionSchema& shm = handle.schema();

return Op{shm.name(), std::move(fn), shm.arguments().size(), shm.returns().size(), std::move(s)};
}

Expand Down Expand Up @@ -271,6 +272,70 @@ void Op::materialize() {
materialized_ = true;
}

void Op::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
  // Replays the recorded operation, overriding the recorded output shape with
  // `shape` (only for factory ops whose first argument is the shape) and, when
  // `device` is set, redirecting every device argument to the target device.
  // No-op if this op has already been materialized.
  if (materialized_) {
    return;
  }

  {
    // Replay under the thread-local state captured at record time.
    ThreadLocalStateGuard state_guard{*tls_};

    // Factory ops that take the output shape as their first argument. Built
    // once instead of re-allocating the list on every materialization.
    static const std::vector<std::string> kShapeFirstOps{
        "aten::randn", "aten::rand",  "aten::empty",
        "aten::ones",  "aten::zeros", "aten::full"};

    if (std::find(kShapeFirstOps.begin(), kShapeFirstOps.end(), name()) !=
        kShapeFirstOps.end()) {
      // Replace the recorded shape argument with the requested local shape.
      stack_[0] = IValue(shape);
    }

    if (device.has_value()) {
      // Rewrite every device argument on the stack to the target device.
      for (IValue& arg : stack_) {
        if (arg.isDevice()) {
          arg = IValue(*device);
        }
      }
    }

    fn_(stack_);
  }

  // Release the recorded closure and captured TLS; they are no longer needed.
  fn_ = nullptr;

  tls_ = nullopt;

  materialized_ = true;
}

void Op::materializeWithDevice(const c10::optional<c10::Device> device) {
  // Replays the recorded operation, redirecting every device argument on the
  // stack to `device` when one is given. No-op if already materialized.
  if (materialized_) {
    return;
  }

  {
    // Replay under the thread-local state captured at record time.
    ThreadLocalStateGuard state_guard{*tls_};

    if (device.has_value()) {
      // Rewrite each device argument to the requested target device.
      for (IValue& arg : stack_) {
        if (arg.isDevice()) {
          arg = IValue(*device);
        }
      }
    }

    fn_(stack_);
  }

  // Release the recorded closure and captured TLS; they are no longer needed.
  fn_ = nullptr;

  tls_ = nullopt;

  materialized_ = true;
}

const Tensor& Op::getOutput(std::size_t idx) const noexcept {
const Tensor* opt_out = nullptr;

Expand Down Expand Up @@ -343,6 +408,9 @@ class OpNode {
// Materializes the operation held by this node along with all the operations
// in its recorded call stack.
void materialize();
// with changed shape
void materializeWithShape(c10::IntArrayRef shape, c10::optional<c10::Device> device);
void materializeWithDevice(c10::optional<c10::Device> device);

private:
void buildCallStack();
Expand Down Expand Up @@ -527,6 +595,54 @@ void OpNode::materialize() {
call_stack_.clear();
}

void OpNode::materializeWithDevice(const c10::optional<c10::Device> device) {
  // Do not try to shortcut this function by checking if the node is already
  // materialized. A later in-place operation can still change the output of
  // this node.

  buildCallStack();

  for (OpNode* entry : call_stack_) {
    if (entry->op_.materialized()) {
      continue;
    }

    entry->materializeArguments();

    entry->op_.materializeWithDevice(device);

    // Make sure that we deallocate parts of the operation graph that are not
    // needed anymore.
    entry->detachDependencies();
  }

  call_stack_.clear();
}

void OpNode::materializeWithShape(c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
  // Do not try to shortcut this function by checking if the node is already
  // materialized. A later in-place operation can still change the output of
  // this node.

  buildCallStack();

  for (OpNode* entry : call_stack_) {
    if (entry->op_.materialized()) {
      continue;
    }

    entry->materializeArguments();

    entry->op_.materializeWithShape(shape, device);

    // Make sure that we deallocate parts of the operation graph that are not
    // needed anymore.
    entry->detachDependencies();
  }

  call_stack_.clear();
}

void OpNode::buildCallStack() {
OpNode* last_node = getLastInPlaceOpNode();

Expand Down Expand Up @@ -728,6 +844,42 @@ Tensor materialize(const Tensor& fake) {
return out;
}

Tensor materialize_with_device(const Tensor& fake, const c10::optional<c10::Device> device) {
  // Replays the op graph recorded for `fake`, optionally overriding the
  // device, and returns the materialized tensor.
  const OpOutputDescriptor& output_desc =
      getTensorRecord(fake).output_descriptor();

  output_desc.node()->materializeWithDevice(device);

  Tensor materialized = output_desc.node()->op().getOutput(output_desc.output_index());

  // Unfortunately there is no way for us to track calls to `requires_grad_()`,
  // so instead we explicitly set `requires_grad` after materialization.
  if (fake.is_leaf() && fake.requires_grad()) {
    materialized.set_requires_grad(true);
  }

  return materialized;
}

Tensor materialize_with_shape(const Tensor& fake, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
  // Replays the op graph recorded for `fake` with a replacement shape (and
  // optional device override), returning the materialized tensor.
  const OpOutputDescriptor& output_desc =
      getTensorRecord(fake).output_descriptor();

  output_desc.node()->materializeWithShape(shape, device);

  Tensor materialized = output_desc.node()->op().getOutput(output_desc.output_index());

  // Unfortunately there is no way for us to track calls to `requires_grad_()`,
  // so instead we explicitly set `requires_grad` after materialization.
  if (fake.is_leaf() && fake.requires_grad()) {
    materialized.set_requires_grad(true);
  }

  return materialized;
}

// The catch-all handler for the `DeferredInit` dispatch key.
class DeferredInitHandler {
public:
Expand Down Expand Up @@ -1032,6 +1184,12 @@ class ProxyVariableHooks : public VariableHooksInterface {
inner_->requires_grad_(self, value);
}

// Pure pass-through: forwards the autograd-not-implemented fallback to the
// wrapped hooks implementation (`inner_`); this proxy adds no behavior here.
void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
torch::jit::Stack* stack) const override {
inner_->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
}

VariableHooksInterface* inner() noexcept {
return inner_;
}
Expand Down Expand Up @@ -1164,6 +1322,7 @@ bool canMaterialize(const Tensor& tensor) noexcept {
return isFake(tensor) && unsafeAsFake(tensor).hasData(DispatchKey::DeferredInit);
}


Tensor materializeTensor(const Tensor& tensor) {
if (canMaterialize(tensor)) {
return detail::materialize(tensor);
Expand All @@ -1172,4 +1331,31 @@ Tensor materializeTensor(const Tensor& tensor) {
}
}

Tensor materializeTensorWithDevice(const at::Tensor& tensor, const c10::optional<c10::Device> device) {
  // Materializes `tensor` on the requested device when possible; tensors that
  // cannot be materialized are returned unchanged.
  if (!canMaterialize(tensor)) {
    return tensor;
  }
  return detail::materialize_with_device(tensor, device);
}
Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device) {
  // Materializes `tensor` with a replacement (local) shape and an optional
  // device override; tensors that cannot be materialized are returned as-is.
  if (!canMaterialize(tensor)) {
    return tensor;
  }
  return detail::materialize_with_shape(tensor, shape, device);
}

bool isGenByRandomOp(const Tensor& tensor) noexcept {
  // Returns true when `tensor` is a materializable fake tensor whose recorded
  // producing op is one of the random ops below; false otherwise.
  if (!canMaterialize(tensor)) {
    return false;
  }

  detail::TensorRecord& record = detail::getTensorRecord(tensor);
  const detail::OpOutputDescriptor& output_desc = record.output_descriptor();
  const auto name = output_desc.node()->op().name();

  // Static constant list: avoids allocating a std::vector<std::string> on
  // every call (this is a noexcept query that may run on hot paths).
  static constexpr const char* kRandomOps[] = {"aten::randn", "aten::rand",
                                               "aten::uniform_"};

  for (const char* op : kRandomOps) {
    if (name == op) {
      return true;
    }
  }
  return false;
}

} // namespace torchdistx
6 changes: 5 additions & 1 deletion src/cc/torchdistx/deferred_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <c10/core/DispatchKey.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/Device.h>

#include "macros.h"

Expand All @@ -27,9 +29,11 @@ TDX_API void leaveDeferredInit() noexcept;

// Indicates whether `tensor` has been constructed in a deferred-init context.
TDX_API bool canMaterialize(const at::Tensor& tensor) noexcept;

// Indicates whether the op recorded as the producer of `tensor` is a random
// op (e.g. `aten::randn`; see the implementation for the exact list).
TDX_API bool isGenByRandomOp(const at::Tensor& tensor) noexcept;
// Materializes `tensor`.
TDX_API at::Tensor materializeTensor(const at::Tensor& tensor);
// Materializes `tensor`, optionally overriding the device that was recorded
// at deferred-init time.
TDX_API at::Tensor materializeTensorWithDevice(const at::Tensor& tensor, const c10::optional<c10::Device> device = {});
// Materializes `tensor` with a replacement (local) shape — effective only for
// shape-taking factory ops — plus an optional device override.
TDX_API at::Tensor materializeTensorWithLocalShape(const at::Tensor& tensor, c10::IntArrayRef shape, const c10::optional<c10::Device> device = {});

// Temporarily disables deferred-init.
class TDX_API NoDeferredInit {
Expand Down
6 changes: 6 additions & 0 deletions src/python/torchdistx/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.types import _int, SymInt, _device
# `Sequence` was removed from `collections` in Python 3.10; it must come from
# `collections.abc` (or `typing`).
from collections.abc import Sequence
from typing import Union, Optional

def enter_deferred_init() -> None: ...
def leave_deferred_init() -> None: ...
def enter_fake_mode(fake_mode: bool) -> None: ...
def leave_fake_mode() -> None: ...
def is_fake(tensor: torch.Tensor) -> bool: ...
def is_gen_by_random_op(tensor: torch.Tensor) -> bool: ...
def can_materialize(tensor: torch.Tensor) -> bool: ...
# `Optional[...]` already adds `None`; the redundant inner `None` is dropped.
def materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: ...
def materialize_tensor_with_device(tensor: torch.Tensor, device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
def materialize_tensor_with_local_shape(tensor: torch.Tensor, shape: Sequence[Union[_int, SymInt]], device: Optional[Union[_device, str]] = None) -> torch.Tensor: ...
def meta_like(fake: torch.Tensor) -> torch.Tensor: ...
Loading