
Commit e5997d5

Update on "Add base forward grad logic"
RFC: pytorch/rfcs#11

This PR adds the basic logic to handle forward grads as dual Tensors. It contains the following:
- Mechanism to save dual state on a Tensor and clear it up when the dual level ends
- C++ and Python user-facing API
- Updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. This was discussed, and it was agreed that more levels are not needed for the first version of this PR.
- We could save one ViewInfo creation when both the forward and backward views have the same base, by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We could skip tracking forward views if the base has a forward grad, by adding extra logic in the `as_view` method. This is left out to keep this PR concise.

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d).
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325).
- Lowest level API and binding between Tensor and AutogradMeta: [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677).
- API to access the forward primal, which needs to be a differentiable function (and so lives in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243).
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9).
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d).
- Python API: [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8).
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3).
- Utility for formulas and updated manual functions to respect the new view system as well as forward grads: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5), [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433).
- Ensure SavedVariable saves forward grads properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030).

[ghstack-poisoned]
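For readers who want to see the user-facing surface before diving into the diffs, below is a minimal sketch of how the Python API added in forward_ad.py is intended to be used. The names (dual_level, make_dual, unpack_dual) are assumed from this stack's Python API, and since most forward formulas only land in the next PR, op coverage at this point is limited to the manually implemented formulas.

```python
# Minimal sketch of the forward-mode AD ("dual tensor") API added by this PR.
# Names assumed from forward_ad.py; only dual level 0 is supported at this stage.
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)
tangent = torch.randn(3)  # direction for the Jacobian-vector product

with fwAD.dual_level():
    # Attach `tangent` as the forward grad of `primal` at the current level.
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 2
    # Recover the primal output and its forward grad.
    out_primal, out_tangent = fwAD.unpack_dual(out)
# Exiting the dual level clears the forward grads saved on the tensors.
```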
2 parents 5393431 + 8397a62 commit e5997d5

File tree

313 files changed (+16465, -4665 lines)


.circleci/scripts/binary_ios_upload.sh

Lines changed: 7 additions & 1 deletion
@@ -34,7 +34,13 @@ touch version.txt
 echo $(date +%s) > version.txt
 zip -r ${ZIPFILE} install src version.txt LICENSE
 # upload to aws
-brew install awscli
+# Install conda then 'conda install' awscli
+curl --retry 3 -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
+chmod +x ~/conda.sh
+/bin/bash ~/conda.sh -b -p ~/anaconda
+export PATH="~/anaconda/bin:${PATH}"
+source ~/anaconda/bin/activate
+conda install -c conda-forge awscli --yes
 set +x
 export AWS_ACCESS_KEY_ID=${AWS_S3_ACCESS_KEY_FOR_PYTORCH_BINARY_UPLOAD}
 export AWS_SECRET_ACCESS_KEY=${AWS_S3_ACCESS_SECRET_FOR_PYTORCH_BINARY_UPLOAD}

.jenkins/pytorch/codegen-test.sh

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@ python -m tools.setup_helpers.generate_code \
 mkdir -p "$OUT"/pyi/torch/_C
 mkdir -p "$OUT"/pyi/torch/nn
 python -m tools.pyi.gen_pyi \
-  --declarations-path "$OUT"/torch/share/ATen/Declarations.yaml \
   --native-functions-path aten/src/ATen/native/native_functions.yaml \
   --deprecated-functions-path tools/autograd/deprecated.yaml \
   --out "$OUT"/pyi

.jenkins/pytorch/multigpu-test.sh

Lines changed: 1 addition & 0 deletions
@@ -21,4 +21,5 @@ time python test/run_test.py --verbose -i distributed/test_jit_c10d
 time python test/run_test.py --verbose -i distributed/test_distributed_fork
 time python test/run_test.py --verbose -i distributed/test_c10d
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn
+time python test/run_test.py --verbose -i distributed/rpc/test_tensorpipe_agent
 assert_git_not_dirty

BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -544,6 +544,7 @@ header_template_rule(
     substitutions = {
         "@AT_MKLDNN_ENABLED@": "1",
         "@AT_MKL_ENABLED@": "0",
+        "@AT_FFTW_ENABLED@": "0",
         "@AT_NNPACK_ENABLED@": "0",
         "@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
         "@USE_BLAS@": "1",

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -684,8 +684,8 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
   int main() {
     float a[] = {1.0, 1.0};
     float32x4x2_t v;
-    v.val[0] = vcombine_f32 (vcreate_f32 (__AARCH64_UINT64_C (0)), vcreate_f32 (__AARCH64_UINT64_C (0)));
-    v.val[1] = vcombine_f32 (vcreate_f32 (__AARCH64_UINT64_C (0)), vcreate_f32 (__AARCH64_UINT64_C (0)));
+    v.val[0] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL));
+    v.val[1] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL));
     vst1q_f32_x2(a, v);
     return 0;
   }" HAS_VST1)

aten/src/ATen/BatchingRegistrations.cpp

Lines changed: 31 additions & 0 deletions
@@ -233,6 +233,32 @@ Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
   return self_physical.newLogicalFromPhysical(result);
 }
 
+Tensor& fill_inplace_scalar_batching_rule(Tensor& self, Scalar value) {
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  self_physical.tensor().fill_(value);
+  return self;
+}
+
+Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
+  auto value_batched = isBatchedTensor(value);
+
+  if (value_batched) {
+    auto physical_args =
+        BroadcastingVmapTransform::logicalToPhysical({self, value});
+    physical_args[0].tensor().copy_(physical_args[1].tensor());
+  } else {
+    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+    self_physical.tensor().fill_(value);
+  }
+  return self;
+}
+
+Tensor& zero_inplace_batching_rule(Tensor &self) {
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  self_physical.tensor().zero_();
+  return self;
+}
+
 Tensor squeeze_batching_rule(const Tensor& self) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   auto physical_sizes = self_physical.tensor().sizes();
@@ -971,6 +997,11 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("is_complex", native::is_complex);
   m.impl("conj", native::conj);
 
+  // inplace operations
+  m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
+  m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
+  m.impl("zero_", zero_inplace_batching_rule);
+
   // view operations
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("chunk", chunk_batching_rule);

aten/src/ATen/Dispatch.h

Lines changed: 294 additions & 157 deletions
Large diffs are not rendered by default.

aten/src/ATen/LegacyTHFunctionsCPU.cpp

Lines changed: 0 additions & 47 deletions
@@ -832,53 +832,6 @@ std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
     }
     return std::tuple<Tensor, Tensor>(res1, res2);
 }
-std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type);
-            auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_geev(res1_, res2_, self_, eigenvectors);
-            break;
-        }
-        case ScalarType::Float: {
-            auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type);
-            auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_geev(res1_, res2_, self_, eigenvectors);
-            break;
-        }
-        default:
-            AT_ERROR("_th_eig_out not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return std::tuple<Tensor &, Tensor &>(res1, res2);
-}
-std::tuple<Tensor,Tensor> _th_eig(const Tensor & self, bool eigenvectors) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-    auto res1_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
-    auto res1 = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(res1_));
-    auto res2_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
-    auto res2 = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(res2_));
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_geev(res1_, res2_, self_, eigenvectors);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_geev(res1_, res2_, self_, eigenvectors);
-            break;
-        }
-        default:
-            AT_ERROR("_th_eig not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return std::tuple<Tensor, Tensor>(res1, res2);
-}
 Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);

aten/src/ATen/LegacyTHFunctionsCPU.h

Lines changed: 0 additions & 2 deletions
@@ -38,8 +38,6 @@ Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scala
 Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max);
 std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
 std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
-std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);
-std::tuple<Tensor,Tensor> _th_eig(const Tensor & self, bool eigenvectors);
 Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
 Tensor _th_potri(const Tensor & self, bool upper);
 std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);

aten/src/ATen/MemoryOverlap.cpp

Lines changed: 12 additions & 0 deletions
@@ -75,4 +75,16 @@ void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) {
     "Please clone() the tensor before performing the operation.");
 }
 
+void assert_no_overlap(const Tensor& a, const Tensor& b) {
+  assert_no_overlap(a.unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
+}
+
+void assert_no_overlap(TensorImpl* a, TensorImpl* b) {
+  const auto lap = get_overlap_status(a, b);
+  TORCH_CHECK(lap != MemOverlapStatus::PARTIAL && lap != MemOverlapStatus::FULL,
+    "unsupported operation: some elements of the input tensor and "
+    "the written-to tensor refer to a single memory location. "
+    "Please clone() the tensor before performing the operation.");
+}
+
 }
