diff --git a/CMakeLists.txt b/CMakeLists.txt index e6f4adc..9431050 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,7 +84,7 @@ if(Python3_EXECUTABLE) ) endif() -find_package(Torch 1.13 REQUIRED PATHS ${torch_cmake_prefix_path}) +find_package(Torch 2.1.0 REQUIRED PATHS ${torch_cmake_prefix_path}) # ------------------------------------------------------------ # Targets diff --git a/src/cc/torchdistx/deferred_init.cc b/src/cc/torchdistx/deferred_init.cc index 961b638..5bf09d6 100644 --- a/src/cc/torchdistx/deferred_init.cc +++ b/src/cc/torchdistx/deferred_init.cc @@ -1032,6 +1032,12 @@ class ProxyVariableHooks : public VariableHooksInterface { inner_->requires_grad_(self, value); } + void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op, + c10::DispatchKeySet dispatch_keys, + torch::jit::Stack* stack) const override { + inner_->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack); + } + VariableHooksInterface* inner() noexcept { return inner_; } diff --git a/src/python/torchdistx/_C/fake.cc b/src/python/torchdistx/_C/fake.cc index 70e0e3c..3b7c8d5 100644 --- a/src/python/torchdistx/_C/fake.cc +++ b/src/python/torchdistx/_C/fake.cc @@ -8,7 +8,7 @@ #include #include -#include +#include #include #include @@ -22,7 +22,7 @@ void pyEnterFakeMode(bool fake_cuda) { // subsystem which would fail and prevent us from instantiating CUDA devices. if (fake_cuda) { if (!at::hasCUDA()) { - torch::utils::set_requires_cuda_init(false); + torch::utils::set_requires_device_init(at::kCUDA, false); } } } @@ -31,7 +31,7 @@ void pyLeaveFakeMode() { leaveFakeMode(); if (!isFakeModeActive() && !at::hasCUDA()) { - torch::utils::set_requires_cuda_init(true); + torch::utils::set_requires_device_init(at::kCUDA,true); } }