Commit 0615686

Will Feng committed
[WIP]
1 parent 234bcff commit 0615686

3 files changed: +292 -0 lines changed

cpp/autograd/CMakeLists.txt

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)

project(autograd)
set(CMAKE_CXX_STANDARD 14)

find_package(Torch REQUIRED)

add_executable(${PROJECT_NAME} "autograd.cpp")
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
if (MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET ${PROJECT_NAME}
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${TORCH_DLLS}
                     $<TARGET_FILE_DIR:${PROJECT_NAME}>)
endif (MSVC)

cpp/autograd/README.md

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# C++ autograd example

`autograd.cpp` contains several examples of using autograd in the PyTorch C++ frontend.

To build the code, run the following commands from your terminal:

```shell
$ cd autograd
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
distribution, which you can get from the [PyTorch
homepage](https://pytorch.org/get-started/locally/).
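
For example, assuming the archive was unpacked to `$HOME/libtorch` (a hypothetical location; substitute your own), the configure step would be:

```shell
$ cmake -DCMAKE_PREFIX_PATH=$HOME/libtorch ..
```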

Execute the compiled binary to run:

```shell
$ ./autograd
====== Running: "Basic autograd operations" ======
1 1
1 1
[ CPUFloatType{2,2} ]
3 3
3 3
[ CPUFloatType{2,2} ]
AddBackward1
27 27
27 27
[ CPUFloatType{2,2} ]
MulBackward1
27
[ CPUFloatType{} ]
MeanBackward0
false
true
SumBackward0
4.5000 4.5000
4.5000 4.5000
[ CPUFloatType{2,2} ]
-731.0470
963.0721
1236.4192
[ CPUFloatType{3} ]
MulBackward1
102.4000
1024.0000
0.1024
[ CPUFloatType{3} ]
true
true
false
true
false
true

====== Running "Computing higher-order gradients in C++" ======
-0.0384 0.1510 -0.0288 0.0872
-0.0105 -0.0936 -0.0553 -0.0222
0.0589 -0.0848 -0.0730 0.0070
[ CPUFloatType{3,4} ]

====== Running "Using custom autograd function in C++" ======
-0.6962 -1.7728 1.4167
-0.6962 -1.7728 1.4167
[ CPUFloatType{2,3} ]
1.5162 1.6421 1.2691
1.5162 1.6421 1.2691
1.5162 1.6421 1.2691
1.5162 1.6421 1.2691
[ CPUFloatType{4,3} ]
5.5000
5.5000
[ CPUFloatType{2} ]
```

cpp/autograd/autograd.cpp

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
#include <torch/torch.h>
#include <iostream>

using namespace torch::autograd;

void basic_autograd_operations_example() {
  std::cout << "====== Running: \"Basic autograd operations\" ======" << std::endl;

  // Create a tensor and set ``torch::requires_grad()`` to track computation with it
  auto x = torch::ones({2, 2}, torch::requires_grad());
  std::cout << x << std::endl;

  // Do a tensor operation:
  auto y = x + 2;
  std::cout << y << std::endl;

  // ``y`` was created as a result of an operation, so it has a ``grad_fn``.
  std::cout << y.grad_fn()->name() << std::endl;

  // Do more operations on ``y``
  auto z = y * y * 3;
  auto out = z.mean();

  std::cout << z << std::endl;
  std::cout << z.grad_fn()->name() << std::endl;
  std::cout << out << std::endl;
  std::cout << out.grad_fn()->name() << std::endl;

  // ``.requires_grad_( ... )`` changes an existing Tensor's ``requires_grad`` flag in-place.
  auto a = torch::randn({2, 2});
  a = ((a * 3) / (a - 1));
  std::cout << a.requires_grad() << std::endl;

  a.requires_grad_(true);
  std::cout << a.requires_grad() << std::endl;

  auto b = (a * a).sum();
  std::cout << b.grad_fn()->name() << std::endl;

  // Let's backprop now. Because ``out`` contains a single scalar, ``out.backward()``
  // is equivalent to ``out.backward(torch::tensor(1.))``.
  out.backward();

  // Print gradients d(out)/dx
  std::cout << x.grad() << std::endl;

  // Now let's take a look at an example of vector-Jacobian product:
  x = torch::randn(3, torch::requires_grad());

  y = x * 2;
  while (y.norm().item<double>() < 1000) {
    y = y * 2;
  }

  std::cout << y << std::endl;
  std::cout << y.grad_fn()->name() << std::endl;

  // If we want the vector-Jacobian product, pass the vector to ``backward`` as argument:
  auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::kFloat);
  y.backward(v);
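  // ``x.grad()`` now holds the vector-Jacobian product J^T * v, i.e. the gradient
  // of ``(y * v).sum()`` with respect to ``x``.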

  std::cout << x.grad() << std::endl;

  // You can also stop autograd from tracking history on Tensors with ``.requires_grad() == true``
  // either by putting ``torch::NoGradGuard`` in a code block
  std::cout << x.requires_grad() << std::endl;
  std::cout << x.pow(2).requires_grad() << std::endl;

  {
    torch::NoGradGuard no_grad;
    std::cout << x.pow(2).requires_grad() << std::endl;
  }

  // Or by using ``.detach()`` to get a new Tensor with the same content but that does
  // not require gradients:
  std::cout << x.requires_grad() << std::endl;
  y = x.detach();
  std::cout << y.requires_grad() << std::endl;
  std::cout << x.eq(y).all().item<bool>() << std::endl;
}

void compute_higher_order_gradients_example() {
  std::cout << "====== Running \"Computing higher-order gradients in C++\" ======" << std::endl;

  // One of the applications of higher-order gradients is calculating gradient penalty.
  // Let's see an example of it using ``torch::autograd::grad``:

  auto model = torch::nn::Linear(4, 3);

  auto input = torch::randn({3, 4}).requires_grad_(true);
  auto output = model(input);

  // Calculate loss
  auto target = torch::randn({3, 3});
  auto loss = torch::nn::MSELoss()(output, target);

  // Use norm of gradients as penalty
  auto grad_output = torch::ones_like(output);
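  // ``create_graph=true`` below makes the computed gradient itself part of the
  // autograd graph, so the gradient penalty can be backpropagated through it.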
  auto gradient = torch::autograd::grad({output}, {input}, /*grad_outputs=*/{grad_output},
                                        /*retain_graph=*/true, /*create_graph=*/true,
                                        /*allow_unused=*/true)[0];
  gradient = gradient.view({-1, 1});
  auto gradient_penalty = torch::pow((gradient.norm(2, /*dim=*/1) - 1), 2).mean();

  // Add gradient penalty to loss
  auto combined_loss = loss + gradient_penalty;
  combined_loss.backward();

  std::cout << input.grad() << std::endl;
}

// Inherit from Function
class LinearFunction : public Function<LinearFunction> {
 public:
  // Note that both forward and backward are static functions

  // bias is an optional argument
  static Variable forward(AutogradContext *ctx, Variable input, Variable weight, Variable bias = Variable()) {
    ctx->save_for_backward({input, weight, bias});
    auto output = input.mm(weight.t());
    if (bias.defined()) {
      output += bias.unsqueeze(0).expand_as(output);
    }
    return output;
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto bias = saved[2];

    auto grad_output = grad_outputs[0];
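    // Forward computed output = input.mm(weight.t()) (+ bias), so by the chain rule:
    //   grad_input  = grad_output.mm(weight)
    //   grad_weight = grad_output.t().mm(input)
    //   grad_bias   = grad_output.sum(0)  (summed over the batch dimension)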
    auto grad_input = grad_output.mm(weight);
    auto grad_weight = grad_output.t().mm(input);
    auto grad_bias = Variable();
    if (bias.defined()) {
      grad_bias = grad_output.sum(0);
    }

    return {grad_input, grad_weight, grad_bias};
  }
};

class MulConstant : public Function<MulConstant> {
 public:
  static Variable forward(AutogradContext *ctx, Variable variable, double constant) {
    // ctx is a context object that can be used to stash information
    // for backward computation
    ctx->saved_data["constant"] = constant;
    return variable * constant;
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
    // We return as many input gradients as there were arguments.
    // Gradients of non-Tensor arguments to forward must be `Variable()`.
    return {grad_outputs[0] * ctx->saved_data["constant"].toDouble(), Variable()};
  }
};

void custom_autograd_function_example() {
  std::cout << "====== Running \"Using custom autograd function in C++\" ======" << std::endl;
  {
    auto x = torch::randn({2, 3}).requires_grad_();
    auto weight = torch::randn({4, 3}).requires_grad_();
    auto y = LinearFunction::apply(x, weight);
    y.sum().backward();

    std::cout << x.grad() << std::endl;
    std::cout << weight.grad() << std::endl;
  }
  {
    auto x = torch::randn({2}).requires_grad_();
    auto y = MulConstant::apply(x, 5.5);
    y.sum().backward();

    std::cout << x.grad() << std::endl;
  }
}

int main() {
  std::cout << std::boolalpha;

  basic_autograd_operations_example();

  std::cout << "\n";

  compute_higher_order_gradients_example();

  std::cout << "\n";

  custom_autograd_function_example();
}
