
Commit 107dc11

Update on "Add base forward grad logic"
RFC: pytorch/rfcs#11

This PR adds the basic logic to handle forward grads as dual Tensors. It contains the following:
- Mechanism to save dual state on a Tensor and clear it up when the dual level ends
- C++ and Python user-facing API
- Updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. This was discussed and it was agreed that more levels are not needed for the first version of this PR.
- We could save one ViewInfo creation when both the forward and backward views have the same base, by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We could skip tracking forward views if the base has a forward grad, by adding extra logic in the `as_view` method. This is left out to keep this PR concise.

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo to hold view information shared by forward and backward, updates the DifferentiableViewMeta to use it, and updates the `as_view` function to handle both forward and backward views.
- New forward grad classes that handle storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which allows us to reduce performance issues while this is in development.
- Lowest-level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal, which needs to be a differentiable function (and so lives in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- Python API: [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utilities for formulas and updated manual functions that respect the new view system as well as forward grads: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5), [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves forward grads properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

[ghstack-poisoned]
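For orientation, here is a minimal sketch of the user-facing Python API introduced here (`fwAD.dual_level`, `fwAD.make_dual`, `fwAD.unpack_dual`), mirroring the usage exercised by the new test in test/test_autograd.py below. It is illustrative only: most derivative formulas land in the next PR of the stack, so only the dual bookkeeping is shown.

```python
# Minimal sketch of the Python API added in this PR; usage mirrors the new
# test in test/test_autograd.py and is illustrative only.
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(2)
tangent = torch.rand(2)

with fwAD.dual_level():
    # Attach the tangent to the primal, producing a "dual" Tensor.
    dual = fwAD.make_dual(primal, tangent)

    # unpack_dual gives back the primal and the tangent; the tangent we
    # provided is re-used as is.
    _, t = fwAD.unpack_dual(dual)
    assert t is tangent

# When the dual level exits, the forward gradients registered at that level
# are reset (see forward_grad.h below).
```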
1 parent 2c84322 · commit 107dc11

4 files changed: 118 additions & 4 deletions

test/test_autograd.py

Lines changed: 60 additions & 0 deletions
@@ -35,6 +35,7 @@
                                                   IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
+import torch.autograd.forward_ad as fwAD
 from torch.testing import randn_like
 from torch.testing._internal.common_methods_invocations import (method_tests,
                                                                  create_input, unpack_variables,
@@ -6187,6 +6188,65 @@ def foo(a):
         self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
         self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))

+class TestAutogradForwardMode(TestCase):
+    def test_forward_level_cleanup(self):
+        import weakref
+
+        def get_tensor_and_weak_ref():
+            # Helper function to get a Tensor and a weak ref that tells us
+            # if the c++ version of this Tensor is still alive or not.
+            #
+            # Create the following reference chain to do so:
+            #   - python Tensor t
+            #   - c++ Tensor corresponding to t
+            #   - c++ Node corresponding to t.grad_fn
+            #   - python dict of metadata from this Node
+            #   - an object in this dict that we can take a weakref of
+
+
+            # Create a new Tensor and Node
+            t = torch.rand(2, requires_grad=True).clone()
+            # Create the metadata dict
+            meta_dict = t.grad_fn.metadata
+            # Create the object in the dict
+            class Foo(object):
+                pass
+            my_obj = Foo()
+            meta_dict[0] = my_obj
+
+            # After exiting this function, the python Tensor t is the only
+            # thing keeping ref alive
+            ref = weakref.ref(my_obj)
+            return t, ref
+
+        # Sanity check that the helper function works as expected
+        t, t_ref = get_tensor_and_weak_ref()
+        self.assertIsNotNone(t_ref())
+
+        del t
+        self.assertIsNone(t_ref())
+
+        # Main test code
+        foo = torch.rand(2)
+
+        with fwAD.dual_level():
+            tangent, tangent_ref = get_tensor_and_weak_ref()
+            self.assertIsNotNone(tangent_ref())
+
+            dual = fwAD.make_dual(foo, tangent)
+            self.assertIsNotNone(tangent_ref())
+
+            # Make sure that the tangent we provided has been re-used as is
+            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)
+
+            # Make sure that dual is keeping the tangent alive
+            del tangent
+            self.assertIsNotNone(tangent_ref())
+
+            # Make sure that the dual level does not keep the c++
+            # version of the tangent alive
+            del dual
+            self.assertIsNone(tangent_ref())

 # Generic device type autograd tests.
 class TestAutogradDeviceType(TestCase):

torch/csrc/autograd/forward_grad.cpp

Lines changed: 9 additions & 0 deletions
@@ -44,6 +44,15 @@ std::shared_ptr<ForwardADLevel> ForwardADLevel::get_by_idx(uint64_t idx) {
   return all_forward_levels_[idx];
 }

+std::shared_ptr<ForwardADLevel> ForwardADLevel::try_get_by_idx(uint64_t idx) {
+  std::lock_guard<std::mutex> lock(all_forward_levels_mutex_);
+  if (idx < all_forward_levels_.size()) {
+    return all_forward_levels_[idx];
+  } else {
+    return nullptr;
+  }
+}
+
 ForwardADLevel::~ForwardADLevel() {
   std::lock_guard<std::mutex> lock(mutex_);
   auto it = grads_.begin();

torch/csrc/autograd/forward_grad.h

Lines changed: 37 additions & 3 deletions
@@ -7,6 +7,7 @@ namespace torch { namespace autograd {

 struct ForwardGrad;

+
 // This file contains two classes that are used to store forward AD gradients and
 // ensure that they are scoped properly.
 // Because forward AD runs concurrently with the evaluation of the function, we need
@@ -30,13 +31,21 @@ struct ForwardGrad;
 // On the other hand, the level, when it is released, will reset all the gradients for this
 // level on all the ForwardGrad.

+
+// Data structures in this file are optimized for this maximum number of levels.
+// The number of levels corresponds to the degree of the gradient being
+// computed using forward AD and we don't expect more than second order gradients
+// to be common.
+#define EXPECTED_MAX_LEVEL 2
+
 struct TORCH_API ForwardADLevel {
   ForwardADLevel(uint64_t idx): idx_(idx) {}
   ~ForwardADLevel();

   static uint64_t get_next_idx();
   static void release_idx(uint64_t idx);
   static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
+  static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);

   void erase(const std::shared_ptr<ForwardGrad>& grad) {
     std::lock_guard<std::mutex> lock(mutex_);
@@ -58,9 +67,33 @@ struct TORCH_API ForwardADLevel {
 struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {

   ForwardGrad() {}
-  ~ForwardGrad() {
-    for (auto& c: content_) {
-      ForwardADLevel::get_by_idx(c.first)->erase(shared_from_this());
+
+  // This function must only be called when AutogradMeta is being destructed
+  // as it ensures that:
+  //   - The only (potential) other references to this ForwardGrad are the
+  //     different levels it is registered to
+  //   - No other thread will try to call `set_value` or `value` ever from now on
+  //   - Any of the ForwardADLevel that this ForwardGrad is registered with might
+  //     call `reset` at any point during this function
+  void clear() {
+    c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;
+
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      for (auto& c: content_) {
+        levels_idx.push_back(c.first);
+      }
+    }
+
+    for (auto l_idx: levels_idx) {
+      // Use "try" version here as another thread might have deleted this
+      // level before we got here
+      // This is an owning reference as we want to keep the level alive
+      // until we successfully unregister ourselves
+      auto level = ForwardADLevel::try_get_by_idx(l_idx);
+      if (level) {
+        level->erase(shared_from_this());
+      }
     }
   }

@@ -95,6 +128,7 @@ struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {


 private:
+  // TODO(albanD): replace this with a SmallVector
  std::unordered_map<uint64_t, at::Tensor> content_;
  mutable std::mutex mutex_;

torch/csrc/autograd/variable.h

Lines changed: 12 additions & 1 deletion
@@ -194,7 +194,8 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
   std::shared_ptr<Node> grad_fn_;
   std::weak_ptr<Node> grad_accumulator_;

-  // This field is lazily initialized
+  // This field is lazily initialized and is used to store all the
+  // forward AD gradients associated with this Tensor
   // Any transition from not_initialized to initialized
   // must be protected by mutex_
   std::shared_ptr<ForwardGrad> fw_grad_;
@@ -266,6 +267,16 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
         !grad_fn_ || !requires_grad_,
         "requires_grad should be false if grad_fn is set");
   }
+
+  ~AutogradMeta() {
+    // If AutogradMeta is being destroyed, it means that no other thread can hold a reference to its
+    // corresponding Tensor. It implies that no other thread can be using this object and so there is
+    // no need to lock mutex_ here.
+    if (fw_grad_) {
+      fw_grad_->clear();
+    }
+
+  }
 };

 struct TORCH_API ViewInfo {