Changes from all commits (28 commits)
2b72f72 Add base forward grad logic (albanD, Dec 9, 2020)
c69e6da Update on "Add base forward grad logic" (albanD, Dec 9, 2020)
e66b605 Update on "Add base forward grad logic" (albanD, Dec 9, 2020)
2cea4d9 Update on "Add base forward grad logic" (albanD, Dec 9, 2020)
f8e4761 Update on "Add base forward grad logic" (albanD, Dec 9, 2020)
949e783 Update on "Add base forward grad logic" (albanD, Dec 10, 2020)
5393431 Update on "Add base forward grad logic" (albanD, Dec 11, 2020)
e5997d5 Update on "Add base forward grad logic" (albanD, Dec 14, 2020)
248f0f0 Update on "Add base forward grad logic" (albanD, Dec 14, 2020)
289ab9b Update on "Add base forward grad logic" (albanD, Dec 15, 2020)
7676dab Update on "Add base forward grad logic" (albanD, Dec 15, 2020)
f2f69e7 Update on "Add base forward grad logic" (albanD, Dec 15, 2020)
8eb6fe0 Update on "Add base forward grad logic" (albanD, Dec 16, 2020)
2c84322 Update on "Add base forward grad logic" (albanD, Dec 16, 2020)
107dc11 Update on "Add base forward grad logic" (albanD, Dec 16, 2020)
4fb045e Update on "Add base forward grad logic" (albanD, Dec 16, 2020)
28c25e2 Update on "Add base forward grad logic" (albanD, Dec 16, 2020)
4139298 Update on "Add base forward grad logic" (albanD, Dec 16, 2020)
9c8492f Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
8d22120 Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
dd72a26 Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
98b99a8 Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
b87f778 Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
004826a Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
bc8b23c Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
163b1eb Update on "Add base forward grad logic" (albanD, Dec 17, 2020)
a55e021 Update on "Add base forward grad logic" (albanD, Dec 18, 2020)
de9d986 Update on "Add base forward grad logic" (albanD, Dec 18, 2020)
5 changes: 5 additions & 0 deletions aten/src/ATen/core/Formatting.cpp
@@ -292,6 +292,11 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi
stream << ", axis: " << tensor_.q_per_channel_axis();
}
}

auto& fw_grad = tensor.fw_grad(/* level */ 0);
if (fw_grad.defined()) {
stream << ", tangent:" << std::endl << fw_grad;
}
stream << " ]";
}
return stream;
1 change: 1 addition & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -509,4 +509,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("_version", CppFunction::makeFallthrough());
m.impl("requires_grad_", CppFunction::makeFallthrough());
m.impl("retain_grad", CppFunction::makeFallthrough());
m.impl("_fw_primal", CppFunction::makeFallthrough());
}
27 changes: 27 additions & 0 deletions aten/src/ATen/native/AutogradComposite.cpp
@@ -0,0 +1,27 @@
#include <ATen/ATen.h>

namespace at {
namespace native {

/// This function can be used to create a dual Tensor that holds a tangent to compute forward mode gradients.
/// Note that the dual Tensor's primal is a view of the given primal and the given tangent is used as-is.
/// This function is backward differentiable.
at::Tensor make_dual(const at::Tensor& primal, const at::Tensor& tangent, int64_t level) {
TORCH_CHECK(!primal.fw_grad(level).defined(), "Making a dual Tensor based on a Tensor that "
"already has a forward gradient at the same level ", level, " is not supported.");

auto dual_tensor = primal.view(primal.sizes());
dual_tensor.set_fw_grad(tangent, level, /* is_inplace_op */ false);
return dual_tensor;
}

/// This function can be used to unpack a given dual Tensor to get its primal and tangent. The returned primal
/// is a view of the dual and the tangent is returned as is.
/// This function is backward differentiable.
std::tuple<at::Tensor, at::Tensor> unpack_dual(const at::Tensor& tensor, int64_t level) {
return std::tuple<at::Tensor, at::Tensor>(tensor._fw_primal(level), tensor.fw_grad(level));
}

} // namespace native

} // namespace at
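
To make the semantics above concrete, here is a minimal Python sketch using the torch.autograd.forward_ad module added by this PR (imported as fwAD in the tests further down). It only relies on behavior documented or tested in this diff: the tangent is reused as-is when its metadata matches the primal's.

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(2, 2)
tangent = torch.rand(2, 2)

with fwAD.dual_level():
    # Build a dual Tensor: the primal value plus a tangent at the current level.
    dual = fwAD.make_dual(primal, tangent)

    # Unpack it back into its (primal, tangent) parts.
    p, t = fwAD.unpack_dual(dual)

    # The tangent had matching metadata, so it was stored as-is and we get
    # back the very same Tensor object (this mirrors the test added below).
    assert t is tangent
    # "p" is a fresh view produced by _fw_primal, not the original primal object.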
4 changes: 4 additions & 0 deletions aten/src/ATen/native/VariableMethodStubs.cpp
@@ -40,5 +40,9 @@ void retain_grad(Tensor& self) {
AT_ERROR("retain_grad is not implemented for Tensor");
}

Tensor _fw_primal(const Tensor& self, int64_t level) {
AT_ERROR("_fw_primal is not implemented for Tensor");
}

} // namespace native
} // namespace at
14 changes: 14 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -105,6 +105,20 @@
manual_kernel_registration: True
variants: method

- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
use_c10_dispatcher: full
variants: method
dispatch:
DefaultBackend: _fw_primal

- func: make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
use_c10_dispatcher: full
variants: function

- func: unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
use_c10_dispatcher: full
variants: function

- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
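
The Tensor(a) annotations in these schemas record that make_dual's output and unpack_dual's primal alias their inputs, matching the "is a view of" comments in AutogradComposite.cpp above. A short, hedged sketch of what that aliasing means at the Python level; it assumes only the usual rule that a full view shares its base's data pointer.

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(3)
tangent = torch.rand(3)

with fwAD.dual_level():
    # make_dual returns a view of the given primal, so no data is copied.
    dual = fwAD.make_dual(primal, tangent)
    assert dual.data_ptr() == primal.data_ptr()

    # The primal returned by unpack_dual is in turn a view of the dual.
    p, _ = fwAD.unpack_dual(dual)
    assert p.data_ptr() == dual.data_ptr()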
17 changes: 17 additions & 0 deletions aten/src/ATen/templates/TensorBody.h
@@ -599,6 +599,23 @@ class CAFFE2_API Tensor {
return impl_->grad();
}

// The Forward AD API functions below are low level and are not meant to be used by end
// users, who should use the API provided in torch/csrc/autograd.h instead.

/// This function returns the forward gradient for this Tensor at the given level.
const Tensor& fw_grad(uint64_t level) const {
return impl_->fw_grad(level, *this);
}

/// This function can be used to set the value of the forward grad.
/// Note that the given new_grad might not be used directly if it has different
/// metadata (size/stride/storage offset) compared to this Tensor. In that case,
/// new_grad's content will be copied into a new Tensor.
void set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) {
impl_->set_fw_grad(new_grad, *this, level, is_inplace_op);
}


// STOP. Thinking of adding a method here, which only makes use
// of other ATen methods? Define it in native_functions.yaml.

11 changes: 11 additions & 0 deletions c10/core/TensorImpl.cpp
@@ -44,6 +44,17 @@ const at::Tensor& TensorImpl::grad() const {
return autograd_meta_->grad();
}

const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const {
// See TensorImpl::grad() above for explanation about the line below
if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
[Review comment, Contributor] @swolchok You don't happen to know a good way to avoid this goofiness? (I guess the big problem is we're returning a const at::Tensor& reference)

return autograd_meta_->fw_grad(level, self);
}

void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}

TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,
38 changes: 38 additions & 0 deletions c10/core/TensorImpl.h
@@ -136,6 +136,8 @@ struct C10_API AutogradMetaInterface {
virtual bool requires_grad() const = 0;
virtual at::Tensor& mutable_grad() = 0;
virtual const at::Tensor& grad() const = 0;
virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const = 0;
virtual void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) = 0;
virtual ~AutogradMetaInterface();
};

@@ -598,6 +600,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
const at::Tensor& grad() const;

/**
* Return the accumulated gradient of a tensor. This gradient is computed
* using forward mode AD.
*
* This is an internal API that should never be used by end users.
*
* The API is as follows:
* - "level" allows to specify the level of forward AD nesting for which the
* gradient should be returned. Note that since levels are not fully
* supported yet, this argument should be 0. See documentation for
* torch::autograd::enter_dual_level for more details about forward AD nesting.
* - "self" should represent the Tensor whose forward grad is accessed. It is
required when dealing with views.
*/
const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) const;

/**
* Sets the forward gradient for this Tensor.
* The given Tensor might not be used directly and its content will be copied.
"We don't necessarily set the tensor in all cases; if self is a view, we may copy instead. This means that the forward gradient does not necessarily alias with the input tensor"

"if it has different strides, it will get restrided to match"

*
* This is an internal API that should never be used by end users.
*
* The API is as follows:
* - "new_grad" is a Tensor containing the new value of the gradient that should
* be set
* - "self" should reprensent the Tensor whose forward grad is accessed. It is
* required when dealing with view.
* - "level" allows to specify the level of forward AD nesting for which the
* gradient should be set. Note that since levels are not fully supported
* yet, this argument should be 0. See documentation for torch::autograd::enter_dual_level
* for more details about forward AD nesting.
* - "is_inplace_op" is a boolean flag that tells if this gradient was generated
* by an inplace operation or an out of place one. This allows better error checking.
*/
void set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op);

/**
* Return a typed data pointer to the actual data which this tensor refers to.
* This checks that the requested type (from the template parameter) matches
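
The copy-on-metadata-mismatch behavior documented for set_fw_grad is observable through the public API. A hedged sketch: the first half mirrors the test added in test_autograd.py below, while the second half assumes only the documented behavior that a tangent whose layout differs from the primal is copied into a fresh Tensor rather than stored as-is.

import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    # Tangent with the same sizes/strides as the primal: stored as-is.
    primal = torch.rand(2, 3)
    matching = torch.rand(2, 3)
    dual = fwAD.make_dual(primal, matching)
    assert fwAD.unpack_dual(dual)[1] is matching

    # Tangent with different strides (a transposed view): per the doc comment
    # above, its content should be copied into a new Tensor with the primal's
    # layout, so the stored tangent is no longer the same Python object.
    mismatched = torch.rand(3, 2).t()
    dual2 = fwAD.make_dual(torch.rand(2, 3), mismatched)
    stored = fwAD.unpack_dual(dual2)[1]
    print(stored is mismatched)             # expected: False
    print(torch.equal(stored, mismatched))  # expected: True (same values)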
81 changes: 81 additions & 0 deletions test/test_autograd.py
@@ -35,6 +35,7 @@
IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
from torch.testing import randn_like
from torch.testing._internal.common_methods_invocations import (method_tests,
create_input, unpack_variables,
@@ -5350,6 +5351,26 @@ def fn(a, dim0_size=5):

self.assertEqual(x.grad, y.grad)

def test_view_with_multi_output(self):
[Review comment, Collaborator (Author)] This test is added in this PR because otherwise this codepath is only exercised by test runs with MKLDNN or XLA.

x = torch.randn(2, 2, 2, dtype=torch.double)

x1 = torch.view_as_complex(x)
# Taking an invalid view should always be allowed as long as it is not
# modified inplace
res = x1.unbind(0)

with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
res[0] += torch.rand(2, requires_grad=True)

x.requires_grad_(True)
x1 = torch.view_as_complex(x)
# Taking an invalid view should always be allowed as long as it is not
# modified inplace
res = x1.unbind(0)

with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
res[0] += torch.rand(2, requires_grad=True)

def as_identity(self):
# view_as_real and view_as_complex behavior should be like an identity
def func(z):
@@ -6348,6 +6369,66 @@ def foo(a):
self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))

class TestAutogradForwardMode(TestCase):
def test_forward_level_cleanup(self):
import weakref

def get_tensor_and_weak_ref():
# Helper function to get a Tensor and a weak ref that tells us
# if the c++ version of this Tensor is still alive or not.
#
# Create the following reference chain to do so:
# - python Tensor t
# - c++ Tensor corresponding to t
# - c++ Node corresponding to t.grad_fn
# - python dict of metadata from this Node
# - an object in this dict that we can take a weakref of


# Create a new Tensor and Node
t = torch.rand(2, requires_grad=True).clone()
# Create the metadata dict
meta_dict = t.grad_fn.metadata
# Create the object in the dict

class Foo(object):
pass
my_obj = Foo()
meta_dict[0] = my_obj

# After exiting this function, the python Tensor t is the only
# thing keeping ref alive
ref = weakref.ref(my_obj)
return t, ref

# Sanity check that the helper function works as expected
t, t_ref = get_tensor_and_weak_ref()
self.assertIsNotNone(t_ref())

del t
self.assertIsNone(t_ref())

# Main test code
foo = torch.rand(2)

with fwAD.dual_level():
tangent, tangent_ref = get_tensor_and_weak_ref()
self.assertIsNotNone(tangent_ref())

dual = fwAD.make_dual(foo, tangent)
self.assertIsNotNone(tangent_ref())

# Make sure that the tangent we provided has been re-used as is
self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)

# Make sure that dual is keeping the tangent alive
del tangent
self.assertIsNotNone(tangent_ref())

# Make sure that the dual level does not keep the c++
# version of the tangent alive
del dual
self.assertIsNone(tangent_ref())

# Generic device type autograd tests.
class TestAutogradDeviceType(TestCase):
7 changes: 5 additions & 2 deletions test/test_namedtuple_return_api.py
@@ -12,7 +12,7 @@
all_operators_with_namedtuple_return = {
'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh'
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual"
}


@@ -65,6 +65,7 @@ def test_namedtuple_return(self):
op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
op(operators=['unpack_dual'], input=(a, 0), names=('primal', 'tangent'), hasout=False),
]

for op in operators:
@@ -75,7 +76,9 @@
for i, name in enumerate(op.names):
self.assertIs(getattr(ret, name), ret[i])
else:
ret = getattr(a, f)(*op.input)
# Handle ops that are not methods
func = getattr(a, f) if hasattr(a, f) else getattr(torch, f)
ret = func(*op.input)
for i, name in enumerate(op.names):
self.assertIs(getattr(ret, name), ret[i])
if op.hasout:
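
Since unpack_dual is registered here as an operator with a namedtuple return, its outputs can be accessed by field name as well as by index. A small sketch, calling the op-level torch.unpack_dual binding the same way this test does (as a torch function, with an explicit level):

import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    dual = fwAD.make_dual(torch.rand(2), torch.rand(2))

    out = torch.unpack_dual(dual, 0)  # level 0 is the only level supported so far
    # Field names mirror the schema: (Tensor(a) primal, Tensor tangent).
    assert out.primal is out[0]
    assert out.tangent is out[1]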
3 changes: 2 additions & 1 deletion tools/autograd/gen_python_functions.py
@@ -80,7 +80,8 @@
'nonzero(_(out|numpy))?',
'set_data',
'.*_overrideable', # overrideable functions for backend extension
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_'
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_',
'_fw_primal'
]

# These function signatures are not exposed to Python. Note that this signature
2 changes: 1 addition & 1 deletion tools/autograd/gen_trace_type.py
@@ -25,7 +25,7 @@
# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
MANUAL_AUTOGRAD_AND_TRACER = set([
'resize_', 'resize_as_', 'detach', 'detach_', 'copy_',
'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', '_fw_primal',
])

# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
10 changes: 6 additions & 4 deletions tools/autograd/gen_variable_type.py
@@ -689,7 +689,7 @@ def wrap_output(return_values, var):

if len(differentiable_output_vars) == 0:
# no output is differentiable (.indices() for SparseTensors for example)
rhs_value = 'as_view({}, {}, /* is_differentiable */ false)'.format(view_info, var)
rhs_value = f'as_view({view_info}, {var}, /* is_bw_differentiable */ false, /* is_fw_differentiable */ false)'
elif len(differentiable_output_vars) == 1:
# Single differentiable output (Tensor or Tensor[])
return_info = differentiable_outputs[0]
Expand All @@ -704,13 +704,15 @@ def wrap_output(return_values, var):
creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
else:
creation_meta = "CreationMeta::MULTI_OUTPUT_NODE"
call += ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
"/* creation_meta */ {});\n").format(view_info, var, creation_meta)
call += ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
"/* creation_meta */ {});").format(view_info, var, creation_meta)
rhs_value = 'std::move({})'.format(var)
else:
call += emit_view_lambda()
creation_meta = "GradMode::is_enabled() ? CreationMeta::DEFAULT: CreationMeta::NO_GRAD_MODE"
rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_differentiable */ true, "
rhs_value = ("as_view(/* base */ {}, /* output */ {}, /* is_bw_differentiable */ true, "
"/* is_fw_differentiable */ true, "
"/* view_func */ func, /* creation_meta */ {})").format(view_info, var, creation_meta)
else:
# This could be supported but we don't need it at the moment, so keeping things simple.
2 changes: 2 additions & 0 deletions tools/build_variables.bzl
@@ -90,6 +90,8 @@ core_sources_common = [
"torch/csrc/autograd/profiler_legacy.cpp",
"torch/csrc/autograd/profiler_kineto.cpp",
"torch/csrc/autograd/profiler_utils.cpp",
"torch/csrc/autograd/autograd_meta.cpp",
"torch/csrc/autograd/forward_grad.cpp",
"torch/csrc/jit/frontend/edit_distance.cpp",
"torch/csrc/jit/frontend/string_to_type.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",
6 changes: 6 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -522,6 +522,12 @@ def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...

# Defined in torch/csrc/jit/python/script_init.cpp
class LoggerBase(object):
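
These private bindings are what the Python-side forward AD helpers build on. The sketch below shows the supported context-manager usage from the tests, plus a rough illustration of the declared enter/exit calls; the manual part is an assumption based only on the names and signatures declared above, not on documented behavior.

import torch
import torch.autograd.forward_ad as fwAD

# Supported path: the context manager enters and exits a dual level for you.
with fwAD.dual_level():
    dual = fwAD.make_dual(torch.rand(2), torch.rand(2))
    _, tangent = fwAD.unpack_dual(dual)

# Roughly what the context manager wraps, per the stub declarations above
# (illustrative only; nested levels are not supported yet, so level is 0).
level = torch._C._enter_dual_level()
try:
    pass  # forward-mode computation at this level
finally:
    torch._C._exit_dual_level(level)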