Add base forward grad logic #49097
@@ -0,0 +1,27 @@
#include <ATen/ATen.h>

namespace at {
namespace native {

/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
/// This function is backward differentiable.
at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
  TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
              "already has a forward gradient at the same level ", level, " is not supported.");

  auto dual_tensor = primal.view(primal.sizes());
  dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
  return dual_tensor;
}

/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as-is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
  return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
}

} // namespace native

} // namespace at
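
For orientation, here is a minimal sketch of how these two native functions are exercised through the Python frontend wired up in this PR (torch.autograd.forward_ad, imported as fwAD in the tests below). It mirrors what test_forward_level_cleanup checks and is illustrative only:

import torch
import torch.autograd.forward_ad as fwAD  # Python frontend over the native make_dual/unpack_dual

primal = torch.rand(3)
tangent = torch.rand(3)

# Forward-mode state is scoped to a dual level; nesting is not fully supported
# yet, so in practice the level used by these calls is always 0.
with fwAD.dual_level():
    # The dual's primal is a view of `primal`, and `tangent` is attached to it
    # at the current level.
    dual = fwAD.make_dual(primal, tangent)

    # unpack_dual returns (primal view, tangent). Because `primal` here is a
    # plain tensor with the same layout, the tangent is re-used as-is.
    p, t = fwAD.unpack_dual(dual)
    assert t is tangent

    # Attaching a second tangent to the same tensor at the same level is
    # rejected by the TORCH_CHECK in make_dual above.
    # fwAD.make_dual(dual, tangent)  # would raise a RuntimeError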
@@ -44,6 +44,17 @@ const at::Tensor& TensorImpl::grad() const {
  return autograd_meta_->grad();
}

const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const {
  // See TensorImpl::grad() above for explanation about the line below
  if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();

Review comment: @swolchok You don't happen to know a good way to avoid this goofiness? (I guess the big problem is we're returning …)

  return autograd_meta_->fw_grad(level, self);
}

void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
  if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}

TensorImpl::TensorImpl(
    Storage&& storage,
    DispatchKeySet key_set,
@@ -136,6 +136,8 @@ struct C10_API AutogradMetaInterface {
  virtual bool requires_grad() const = 0;
  virtual at::Tensor& mutable_grad() = 0;
  virtual const at::Tensor& grad() const = 0;
  virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0;
  virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0;
  virtual ~AutogradMetaInterface();
};
@@ -598,6 +600,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   */
  const at::Tensor& grad() const;

  /**
   * Return the accumulated gradient of a tensor. This gradient is computed
   * using forward mode AD.
   *
   * This is an internal API that should never be used by end users.
   *
   * The API is as follows:
   * - "level" specifies the level of forward AD nesting for which the
   *   gradient should be returned. Note that since levels are not fully
   *   supported yet, this argument should be 0. See documentation for
   *   torch::autograd::enter_dual_level for more details about forward AD nesting.
   * - "self" should represent the Tensor whose forward grad is accessed. It is
   *   required when dealing with views.
   */
  const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const;

  /**
   * Sets the forward gradient for this Tensor.
   * The given Tensor might not be used directly and its content will be copied.
   *
   * This is an internal API that should never be used by end users.
   *
   * The API is as follows:
   * - "new_grad" is a Tensor containing the new value of the gradient that
   *   should be set.
   * - "self" should represent the Tensor whose forward grad is accessed. It is
   *   required when dealing with views.
   * - "level" specifies the level of forward AD nesting for which the
   *   gradient should be set. Note that since levels are not fully supported
   *   yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level
   *   for more details about forward AD nesting.
   * - "is_inplace_op" is a boolean flag that tells whether this gradient was generated
   *   by an inplace operation or an out-of-place one. This allows better error checking.
   */
  void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op);

Review comments on the copy behavior described above: "We don't necessarily set the tensor in all cases; if self is a view, we may copy instead. This means that the forward gradient does not necessarily alias with the input tensor." / "If it has different strides, it will get restrided to match."

  /**
   * Return a typed data pointer to the actual data which this tensor refers to.
   * This checks that the requested type (from the template parameter) matches
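
To make the aliasing caveat from those review comments concrete, here is a hypothetical probe written against the user-facing fwAD API rather than this internal TensorImpl one. It prints rather than asserts, because the exact copy/restride rules are precisely what the comments above describe and may differ across versions:

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(4, 4)
same_layout = torch.rand(4, 4)       # contiguous, same sizes/strides as primal
other_layout = torch.rand(4, 4).t()  # transposed, same sizes but different strides

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, same_layout)
    stored = fwAD.unpack_dual(dual)[1]
    # With matching metadata the tangent is expected to be stored as-is.
    print("same layout aliases:", stored is same_layout)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, other_layout)
    stored = fwAD.unpack_dual(dual)[1]
    # With mismatched strides the tangent may be restrided/copied to match the
    # primal, in which case it no longer aliases the tensor passed in.
    print("other layout aliases:", stored is other_layout, "strides:", stored.stride())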
@@ -35,6 +35,7 @@
    IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
from torch.testing import randn_like
from torch.testing._internal.common_methods_invocations import (method_tests,
                                                                create_input, unpack_variables,
@@ -5350,6 +5351,26 @@ def fn(a, dim0_size=5):

        self.assertEqual(x.grad, y.grad)

Review comment: This test is added in this PR because otherwise only tests running with MKLDNN or XLA exercise this codepath.

    def test_view_with_multi_output(self):
        x = torch.randn(2, 2, 2, dtype=torch.double)

        x1 = torch.view_as_complex(x)
        # Taking an invalid view should always be allowed as long as it is not
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
            res[0] += torch.rand(2, requires_grad=True)

        x.requires_grad_(True)
        x1 = torch.view_as_complex(x)
        # Taking an invalid view should always be allowed as long as it is not
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
            res[0] += torch.rand(2, requires_grad=True)

    def as_identity(self):
        # view_as_real and view_as_complex behavior should be like an identity
        def func(z):
@@ -6348,6 +6369,66 @@ def foo(a):
        self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
        self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))

class TestAutogradForwardMode(TestCase):
    def test_forward_level_cleanup(self):
        import weakref

        def get_tensor_and_weak_ref():
            # Helper function to get a Tensor and a weak ref that tells us
            # if the c++ version of this Tensor is still alive or not.
            #
            # Create the following reference chain to do so:
            # - python Tensor t
            # - c++ Tensor corresponding to t
            # - c++ Node corresponding to t.grad_fn
            # - python dict of metadata from this Node
            # - an object in this dict that we can take a weakref of

            # Create a new Tensor and Node
            t = torch.rand(2, requires_grad=True).clone()
            # Create the metadata dict
            meta_dict = t.grad_fn.metadata
            # Create the object in the dict

            class Foo(object):
                pass
            my_obj = Foo()
            meta_dict[0] = my_obj

            # After exiting this function, the python Tensor t is the only
            # thing keeping ref alive
            ref = weakref.ref(my_obj)
            return t, ref

        # Sanity check that the helper function works as expected
        t, t_ref = get_tensor_and_weak_ref()
        self.assertIsNotNone(t_ref())

        del t
        self.assertIsNone(t_ref())

        # Main test code
        foo = torch.rand(2)

        with fwAD.dual_level():
            tangent, tangent_ref = get_tensor_and_weak_ref()
            self.assertIsNotNone(tangent_ref())

            dual = fwAD.make_dual(foo, tangent)
            self.assertIsNotNone(tangent_ref())

            # Make sure that the tangent we provided has been re-used as-is
            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)

            # Make sure that dual is keeping the tangent alive
            del tangent
            self.assertIsNotNone(tangent_ref())

            # Make sure that the dual level does not keep the c++
            # version of the tangent alive
            del dual
            self.assertIsNone(tangent_ref())

# Generic device type autograd tests.
class TestAutogradDeviceType(TestCase):