|
| 1 | +#include <torch/torch.h> |
| 2 | +#include <iostream> |
| 3 | + |
| 4 | +using namespace torch::autograd; |
| 5 | + |
| 6 | +void basic_autograd_operations_example() { |
| 7 | + std::cout << "====== Running: \"Basic autograd operations\" ======" << std::endl; |
| 8 | + |
| 9 | + // Create a tensor and set ``torch::requires_grad()`` to track computation with it |
| 10 | + auto x = torch::ones({2, 2}, torch::requires_grad()); |
| 11 | + std::cout << x << std::endl; |
| 12 | + |
| 13 | + // Do a tensor operation: |
| 14 | + auto y = x + 2; |
| 15 | + std::cout << y << std::endl; |
| 16 | + |
| 17 | + // ``y`` was created as a result of an operation, so it has a ``grad_fn``. |
| 18 | + std::cout << y.grad_fn()->name() << std::endl; |
| 19 | + |
| 20 | + // Do more operations on ``y`` |
| 21 | + auto z = y * y * 3; |
| 22 | + auto out = z.mean(); |
| 23 | + |
| 24 | + std::cout << z << std::endl; |
| 25 | + std::cout << z.grad_fn()->name() << std::endl; |
| 26 | + std::cout << out << std::endl; |
| 27 | + std::cout << out.grad_fn()->name() << std::endl; |
| 28 | + |
| 29 | + // ``.requires_grad_( ... )`` changes an existing tensor's ``requires_grad`` flag in-place. |
| 30 | + auto a = torch::randn({2, 2}); |
| 31 | + a = ((a * 3) / (a - 1)); |
| 32 | + std::cout << a.requires_grad() << std::endl; |
| 33 | + |
| 34 | + a.requires_grad_(true); |
| 35 | + std::cout << a.requires_grad() << std::endl; |
| 36 | + |
| 37 | + auto b = (a * a).sum(); |
| 38 | + std::cout << b.grad_fn()->name() << std::endl; |
| 39 | + |
| 40 | + // Let's backprop now. Because ``out`` contains a single scalar, ``out.backward()`` |
| 41 | + // is equivalent to ``out.backward(torch::tensor(1.))``. |
| 42 | + out.backward(); |
| 43 | + |
| 44 | + // Print gradients d(out)/dx |
| 45 | + std::cout << x.grad() << std::endl; |
| 46 | + |
| 47 | + // Now let's take a look at an example of vector-Jacobian product: |
| 48 | + x = torch::randn(3, torch::requires_grad()); |
| 49 | + |
| 50 | + y = x * 2; |
| 51 | + while (y.norm().item<double>() < 1000) { |
| 52 | + y = y * 2; |
| 53 | + } |
| 54 | + |
| 55 | + std::cout << y << std::endl; |
| 56 | + std::cout << y.grad_fn()->name() << std::endl; |
| 57 | + |
| 58 | + // If we want the vector-Jacobian product, pass the vector to ``backward`` as argument: |
| 59 | + auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::kFloat); |
| 60 | + y.backward(v); |
| 61 | + |
| 62 | + std::cout << x.grad() << std::endl; |
| 63 | + |
| 64 | + // You can also stop autograd from tracking history on tensors that require gradients |
| 65 | + // either by putting ``torch::NoGradGuard`` in a code block |
| 66 | + std::cout << x.requires_grad() << std::endl; |
| 67 | + std::cout << x.pow(2).requires_grad() << std::endl; |
| 68 | + |
| 69 | + { |
| 70 | + torch::NoGradGuard no_grad; |
| 71 | + std::cout << x.pow(2).requires_grad() << std::endl; |
| 72 | + } |
| 73 | + |
| 74 | + // Or by using ``.detach()`` to get a new tensor with the same content but that does |
| 75 | + // not require gradients: |
| 76 | + std::cout << x.requires_grad() << std::endl; |
| 77 | + y = x.detach(); |
| 78 | + std::cout << y.requires_grad() << std::endl; |
| 79 | + std::cout << x.eq(y).all().item<bool>() << std::endl; |
| 80 | +} |
| 81 | + |
| 82 | +void compute_higher_order_gradients_example() { |
| 83 | + std::cout << "====== Running \"Computing higher-order gradients in C++\" ======" << std::endl; |
| 84 | + |
| 85 | + // One of the applications of higher-order gradients is calculating gradient penalty. |
| 86 | + // Let's see an example of it using ``torch::autograd::grad``: |
| 87 | + |
| 88 | + auto model = torch::nn::Linear(4, 3); |
| 89 | + |
| 90 | + auto input = torch::randn({3, 4}).requires_grad_(true); |
| 91 | + auto output = model(input); |
| 92 | + |
| 93 | + // Calculate loss |
| 94 | + auto target = torch::randn({3, 3}); |
| 95 | + auto loss = torch::nn::MSELoss()(output, target); |
| 96 | + |
| 97 | + // Use norm of gradients as penalty |
| 98 | + auto grad_output = torch::ones_like(output); |
| 99 | + auto gradient = torch::autograd::grad({output}, {input}, /*grad_outputs=*/{grad_output}, /*create_graph=*/true)[0]; |
| 100 | + auto gradient_penalty = torch::pow((gradient.norm(2, /*dim=*/1) - 1), 2).mean(); |
| 101 | + |
| 102 | + // Add gradient penalty to loss |
| 103 | + auto combined_loss = loss + gradient_penalty; |
| 104 | + combined_loss.backward(); |
| 105 | + |
| 106 | + std::cout << input.grad() << std::endl; |
| 107 | +} |
| 108 | + |
| 109 | +// Inherit from Function |
| 110 | +class LinearFunction : public Function<LinearFunction> { |
| 111 | + public: |
| 112 | + // Note that both forward and backward are static functions |
| 113 | + |
| 114 | + // bias is an optional argument |
| 115 | + static torch::Tensor forward( |
| 116 | + AutogradContext *ctx, torch::Tensor input, torch::Tensor weight, torch::Tensor bias = torch::Tensor()) { |
| 117 | + ctx->save_for_backward({input, weight, bias}); |
| 118 | + auto output = input.mm(weight.t()); |
| 119 | + if (bias.defined()) { |
| 120 | + output += bias.unsqueeze(0).expand_as(output); |
| 121 | + } |
| 122 | + return output; |
| 123 | + } |
| 124 | + |
| 125 | + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { |
| 126 | + auto saved = ctx->get_saved_variables(); |
| 127 | + auto input = saved[0]; |
| 128 | + auto weight = saved[1]; |
| 129 | + auto bias = saved[2]; |
| 130 | + |
| 131 | + auto grad_output = grad_outputs[0]; |
| 132 | + auto grad_input = grad_output.mm(weight); |
| 133 | + auto grad_weight = grad_output.t().mm(input); |
| 134 | + auto grad_bias = torch::Tensor(); |
| 135 | + if (bias.defined()) { |
| 136 | + grad_bias = grad_output.sum(0); |
| 137 | + } |
| 138 | + |
| 139 | + return {grad_input, grad_weight, grad_bias}; |
| 140 | + } |
| 141 | +}; |
| 142 | + |
| 143 | +class MulConstant : public Function<MulConstant> { |
| 144 | + public: |
| 145 | + static torch::Tensor forward(AutogradContext *ctx, torch::Tensor tensor, double constant) { |
| 146 | + // ctx is a context object that can be used to stash information |
| 147 | + // for backward computation |
| 148 | + ctx->saved_data["constant"] = constant; |
| 149 | + return tensor * constant; |
| 150 | + } |
| 151 | + |
| 152 | + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { |
| 153 | + // We return as many input gradients as there were arguments. |
| 154 | + // Gradients of non-tensor arguments to forward must be `torch::Tensor()`. |
| 155 | + return {grad_outputs[0] * ctx->saved_data["constant"].toDouble(), torch::Tensor()}; |
| 156 | + } |
| 157 | +}; |
| 158 | + |
| 159 | +void custom_autograd_function_example() { |
| 160 | + std::cout << "====== Running \"Using custom autograd function in C++\" ======" << std::endl; |
| 161 | + { |
| 162 | + auto x = torch::randn({2, 3}).requires_grad_(); |
| 163 | + auto weight = torch::randn({4, 3}).requires_grad_(); |
| 164 | + auto y = LinearFunction::apply(x, weight); |
| 165 | + y.sum().backward(); |
| 166 | + |
| 167 | + std::cout << x.grad() << std::endl; |
| 168 | + std::cout << weight.grad() << std::endl; |
| 169 | + } |
| 170 | + { |
| 171 | + auto x = torch::randn({2}).requires_grad_(); |
| 172 | + auto y = MulConstant::apply(x, 5.5); |
| 173 | + y.sum().backward(); |
| 174 | + |
| 175 | + std::cout << x.grad() << std::endl; |
| 176 | + } |
| 177 | +} |
| 178 | + |
| 179 | +int main() { |
| 180 | + std::cout << std::boolalpha; |
| 181 | + |
| 182 | + basic_autograd_operations_example(); |
| 183 | + |
| 184 | + std::cout << "\n"; |
| 185 | + |
| 186 | + compute_higher_order_gradients_example(); |
| 187 | + |
| 188 | + std::cout << "\n"; |
| 189 | + |
| 190 | + custom_autograd_function_example(); |
| 191 | +} |
0 commit comments