
Commit a5fdab9

Author: Jessica Lin

Merge branch 'master' into master

2 parents 39ff9d8 + b9f3b2e

File tree

13 files changed: +752 -12 lines changed


cpp/autograd/CMakeLists.txt

Lines changed: 21 additions & 0 deletions
```cmake
cmake_minimum_required(VERSION 2.8)

project(autograd)
set(CMAKE_CXX_STANDARD 14)

find_package(Torch REQUIRED)

add_executable(${PROJECT_NAME} "autograd.cpp")
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

# The following block is suggested for use on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
if (MSVC)
  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
  add_custom_command(TARGET ${PROJECT_NAME}
                     POST_BUILD
                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
                     ${TORCH_DLLS}
                     $<TARGET_FILE_DIR:${PROJECT_NAME}>)
endif (MSVC)
```

cpp/autograd/README.md

Lines changed: 78 additions & 0 deletions
# C++ autograd example

`autograd.cpp` contains several examples of using autograd in the PyTorch C++ frontend.

To build the code, run the following commands from your terminal:

```shell
$ cd autograd
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` should be the path to the unzipped *LibTorch* distribution, which you can get from the [PyTorch homepage](https://pytorch.org/get-started/locally/).

Execute the compiled binary to run:

```shell
$ ./autograd
====== Running: "Basic autograd operations" ======
 1  1
 1  1
[ CPUFloatType{2,2} ]
 3  3
 3  3
[ CPUFloatType{2,2} ]
AddBackward1
 27  27
 27  27
[ CPUFloatType{2,2} ]
MulBackward1
27
[ CPUFloatType{} ]
MeanBackward0
false
true
SumBackward0
 4.5000  4.5000
 4.5000  4.5000
[ CPUFloatType{2,2} ]
  813.6625
 1015.0142
 -664.8849
[ CPUFloatType{3} ]
MulBackward1
  204.8000
 2048.0000
    0.2048
[ CPUFloatType{3} ]
true
true
false
true
false
true

====== Running "Computing higher-order gradients in C++" ======
 0.0025  0.0946  0.1474  0.1387
 0.0238 -0.0018  0.0259  0.0094
 0.0513 -0.0549 -0.0604  0.0210
[ CPUFloatType{3,4} ]

====== Running "Using custom autograd function in C++" ======
-3.5513  3.7160  3.6477
-3.5513  3.7160  3.6477
[ CPUFloatType{2,3} ]
 0.3095  1.4035 -0.0349
 0.3095  1.4035 -0.0349
 0.3095  1.4035 -0.0349
 0.3095  1.4035 -0.0349
[ CPUFloatType{4,3} ]
 5.5000
 5.5000
[ CPUFloatType{2} ]
```

cpp/autograd/autograd.cpp

Lines changed: 191 additions & 0 deletions
```cpp
#include <torch/torch.h>
#include <iostream>

using namespace torch::autograd;

void basic_autograd_operations_example() {
  std::cout << "====== Running: \"Basic autograd operations\" ======" << std::endl;

  // Create a tensor and set ``torch::requires_grad()`` to track computation with it
  auto x = torch::ones({2, 2}, torch::requires_grad());
  std::cout << x << std::endl;

  // Do a tensor operation:
  auto y = x + 2;
  std::cout << y << std::endl;

  // ``y`` was created as a result of an operation, so it has a ``grad_fn``.
  std::cout << y.grad_fn()->name() << std::endl;

  // Do more operations on ``y``
  auto z = y * y * 3;
  auto out = z.mean();

  std::cout << z << std::endl;
  std::cout << z.grad_fn()->name() << std::endl;
  std::cout << out << std::endl;
  std::cout << out.grad_fn()->name() << std::endl;

  // ``.requires_grad_( ... )`` changes an existing tensor's ``requires_grad`` flag in-place.
  auto a = torch::randn({2, 2});
  a = ((a * 3) / (a - 1));
  std::cout << a.requires_grad() << std::endl;

  a.requires_grad_(true);
  std::cout << a.requires_grad() << std::endl;

  auto b = (a * a).sum();
  std::cout << b.grad_fn()->name() << std::endl;

  // Let's backprop now. Because ``out`` contains a single scalar, ``out.backward()``
  // is equivalent to ``out.backward(torch::tensor(1.))``.
  out.backward();

  // Print gradients d(out)/dx
  std::cout << x.grad() << std::endl;

  // Now let's take a look at an example of a vector-Jacobian product:
  x = torch::randn(3, torch::requires_grad());

  y = x * 2;
  while (y.norm().item<double>() < 1000) {
    y = y * 2;
  }

  std::cout << y << std::endl;
  std::cout << y.grad_fn()->name() << std::endl;

  // If we want the vector-Jacobian product, pass the vector to ``backward`` as an argument:
  auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::kFloat);
  y.backward(v);

  std::cout << x.grad() << std::endl;

  // You can also stop autograd from tracking history on tensors that require gradients,
  // either by wrapping the code in a ``torch::NoGradGuard`` block
  std::cout << x.requires_grad() << std::endl;
  std::cout << x.pow(2).requires_grad() << std::endl;

  {
    torch::NoGradGuard no_grad;
    std::cout << x.pow(2).requires_grad() << std::endl;
  }

  // or by using ``.detach()`` to get a new tensor with the same content that does
  // not require gradients:
  std::cout << x.requires_grad() << std::endl;
  y = x.detach();
  std::cout << y.requires_grad() << std::endl;
  std::cout << x.eq(y).all().item<bool>() << std::endl;
}

void compute_higher_order_gradients_example() {
  std::cout << "====== Running \"Computing higher-order gradients in C++\" ======" << std::endl;

  // One of the applications of higher-order gradients is calculating gradient penalty.
  // Let's see an example of it using ``torch::autograd::grad``:

  auto model = torch::nn::Linear(4, 3);

  auto input = torch::randn({3, 4}).requires_grad_(true);
  auto output = model(input);

  // Calculate loss
  auto target = torch::randn({3, 3});
  auto loss = torch::nn::MSELoss()(output, target);

  // Use norm of gradients as penalty
  auto grad_output = torch::ones_like(output);
  auto gradient = torch::autograd::grad({output}, {input}, /*grad_outputs=*/{grad_output}, /*create_graph=*/true)[0];
  auto gradient_penalty = torch::pow((gradient.norm(2, /*dim=*/1) - 1), 2).mean();

  // Add gradient penalty to loss
  auto combined_loss = loss + gradient_penalty;
  combined_loss.backward();

  std::cout << input.grad() << std::endl;
}

// Inherit from Function
class LinearFunction : public Function<LinearFunction> {
 public:
  // Note that both forward and backward are static functions

  // bias is an optional argument
  static torch::Tensor forward(
      AutogradContext *ctx, torch::Tensor input, torch::Tensor weight, torch::Tensor bias = torch::Tensor()) {
    ctx->save_for_backward({input, weight, bias});
    auto output = input.mm(weight.t());
    if (bias.defined()) {
      output += bias.unsqueeze(0).expand_as(output);
    }
    return output;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto bias = saved[2];

    auto grad_output = grad_outputs[0];
    auto grad_input = grad_output.mm(weight);
    auto grad_weight = grad_output.t().mm(input);
    auto grad_bias = torch::Tensor();
    if (bias.defined()) {
      grad_bias = grad_output.sum(0);
    }

    return {grad_input, grad_weight, grad_bias};
  }
};

class MulConstant : public Function<MulConstant> {
 public:
  static torch::Tensor forward(AutogradContext *ctx, torch::Tensor tensor, double constant) {
    // ctx is a context object that can be used to stash information
    // for backward computation
    ctx->saved_data["constant"] = constant;
    return tensor * constant;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    // We return as many input gradients as there were arguments.
    // Gradients of non-tensor arguments to forward must be `torch::Tensor()`.
    return {grad_outputs[0] * ctx->saved_data["constant"].toDouble(), torch::Tensor()};
  }
};

void custom_autograd_function_example() {
  std::cout << "====== Running \"Using custom autograd function in C++\" ======" << std::endl;
  {
    auto x = torch::randn({2, 3}).requires_grad_();
    auto weight = torch::randn({4, 3}).requires_grad_();
    auto y = LinearFunction::apply(x, weight);
    y.sum().backward();

    std::cout << x.grad() << std::endl;
    std::cout << weight.grad() << std::endl;
  }
  {
    auto x = torch::randn({2}).requires_grad_();
    auto y = MulConstant::apply(x, 5.5);
    y.sum().backward();

    std::cout << x.grad() << std::endl;
  }
}

int main() {
  std::cout << std::boolalpha;

  basic_autograd_operations_example();

  std::cout << "\n";

  compute_higher_order_gradients_example();

  std::cout << "\n";

  custom_autograd_function_example();
}
```
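
A hand-written `backward` is easy to get wrong, so one practical sanity check is to compare a custom function's gradients against those of the equivalent built-in operations. The sketch below is not part of this commit; it assumes the `LinearFunction` class (and the includes) from `autograd.cpp` above are in scope:

```cpp
// Sketch only: verify LinearFunction's analytic gradients against the
// gradients of the equivalent built-in expression input.mm(weight.t()).
void check_linear_function_gradients() {
  auto x = torch::randn({2, 3});
  auto w = torch::randn({4, 3});

  // Gradients through the custom autograd function.
  auto x1 = x.clone().requires_grad_();
  auto w1 = w.clone().requires_grad_();
  LinearFunction::apply(x1, w1).sum().backward();

  // Gradients through ordinary autograd-tracked ops.
  auto x2 = x.clone().requires_grad_();
  auto w2 = w.clone().requires_grad_();
  x2.mm(w2.t()).sum().backward();

  // Both paths should agree (expect "true" twice under std::boolalpha).
  std::cout << x1.grad().allclose(x2.grad()) << std::endl;
  std::cout << w1.grad().allclose(w2.grad()) << std::endl;
}
```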

cpp/dcgan/dcgan.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -121,9 +121,9 @@ int main(int argc, const char* argv[]) {
       torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
 
   torch::optim::Adam generator_optimizer(
-      generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+      generator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
   torch::optim::Adam discriminator_optimizer(
-      discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+      discriminator->parameters(), torch::optim::AdamOptions(2e-4).betas(std::make_tuple(0.5, 0.5)));
 
   if (kRestoreFromCheckpoint) {
     torch::load(generator, "generator-checkpoint.pt");
```
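
For context, this change tracks the newer LibTorch optimizer API, where the separate `beta1`/`beta2` setters on `AdamOptions` were replaced by a single `betas` option taking a `(beta1, beta2)` tuple. A minimal sketch of the updated construction, assuming an arbitrary placeholder module (`net` below is illustrative, not from this commit):

```cpp
#include <torch/torch.h>

int main() {
  // Placeholder module; any torch::nn module's parameters() would do.
  torch::nn::Linear net(8, 2);

  // New-style Adam options: betas is set as a (beta1, beta2) tuple
  // rather than via the removed .beta1(...)/.beta2(...) setters.
  torch::optim::Adam optimizer(
      net->parameters(),
      torch::optim::AdamOptions(/*lr=*/2e-4).betas(std::make_tuple(0.5, 0.999)));
}
```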

distributed/ddp/README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -189,4 +189,4 @@ that in turn produces the following output
 ```
 
 # Conclusions
-As the author of a distributed data parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using the launcher, the mechanics of setting up distributed training can be significantly simplified.
+As the author of a distributed data parallel application, your code needs to be aware of two types of resources: compute nodes and the GPUs within each node. The process of setting up bookkeeping to track how the set of GPUs is mapped to the processes of your application can be tedious and error-prone. We hope that by structuring your application as shown in this example and using the launcher, the mechanics of setting up distributed training can be significantly simplified.
````
