diff --git a/cpp/autograd/CMakeLists.txt b/cpp/autograd/CMakeLists.txt
new file mode 100644
index 0000000000..332c51ab09
--- /dev/null
+++ b/cpp/autograd/CMakeLists.txt
@@ -0,0 +1,21 @@
+cmake_minimum_required(VERSION 2.8)
+
+project(autograd)
+set(CMAKE_CXX_STANDARD 14)
+
+find_package(Torch REQUIRED)
+
+add_executable(${PROJECT_NAME} "autograd.cpp")
+target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")
+
+# The following code block is suggested to be used on Windows.
+# According to https://github.com/pytorch/pytorch/issues/25457,
+# the DLLs need to be copied to avoid memory errors.
+if (MSVC)
+  file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
+  add_custom_command(TARGET ${PROJECT_NAME}
+                     POST_BUILD
+                     COMMAND ${CMAKE_COMMAND} -E copy_if_different
+                     ${TORCH_DLLS}
+                     $<TARGET_FILE_DIR:${PROJECT_NAME}>)
+endif (MSVC)
diff --git a/cpp/autograd/README.md b/cpp/autograd/README.md
new file mode 100644
index 0000000000..6061d442a1
--- /dev/null
+++ b/cpp/autograd/README.md
@@ -0,0 +1,78 @@
+# C++ autograd example
+
+`autograd.cpp` contains several examples of using autograd in the PyTorch C++ frontend.
+
+To build the code, run the following commands from your terminal:
+
+```shell
+$ cd autograd
+$ mkdir build
+$ cd build
+$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
+$ make
+```
+
+where `/path/to/libtorch` should be the path to the unzipped *LibTorch*
+distribution, which you can get from the [PyTorch
+homepage](https://pytorch.org/get-started/locally/).
+
+Execute the compiled binary to run the examples:
+
+```shell
+$ ./autograd
+====== Running: "Basic autograd operations" ======
+ 1 1
+ 1 1
+[ CPUFloatType{2,2} ]
+ 3 3
+ 3 3
+[ CPUFloatType{2,2} ]
+AddBackward1
+ 27 27
+ 27 27
+[ CPUFloatType{2,2} ]
+MulBackward1
+27
+[ CPUFloatType{} ]
+MeanBackward0
+false
+true
+SumBackward0
+ 4.5000 4.5000
+ 4.5000 4.5000
+[ CPUFloatType{2,2} ]
+ 813.6625
+ 1015.0142
+ -664.8849
+[ CPUFloatType{3} ]
+MulBackward1
+ 204.8000
+ 2048.0000
+ 0.2048
+[ CPUFloatType{3} ]
+true
+true
+false
+true
+false
+true
+
+====== Running "Computing higher-order gradients in C++" ======
+ 0.0025 0.0946 0.1474 0.1387
+ 0.0238 -0.0018 0.0259 0.0094
+ 0.0513 -0.0549 -0.0604 0.0210
+[ CPUFloatType{3,4} ]
+
+====== Running "Using custom autograd function in C++" ======
+-3.5513 3.7160 3.6477
+-3.5513 3.7160 3.6477
+[ CPUFloatType{2,3} ]
+ 0.3095 1.4035 -0.0349
+ 0.3095 1.4035 -0.0349
+ 0.3095 1.4035 -0.0349
+ 0.3095 1.4035 -0.0349
+[ CPUFloatType{4,3} ]
+ 5.5000
+ 5.5000
+[ CPUFloatType{2} ]
+```
diff --git a/cpp/autograd/autograd.cpp b/cpp/autograd/autograd.cpp
new file mode 100644
index 0000000000..2ef3c4dfcc
--- /dev/null
+++ b/cpp/autograd/autograd.cpp
@@ -0,0 +1,191 @@
+#include <torch/torch.h>
+#include <iostream>
+
+using namespace torch::autograd;
+
+void basic_autograd_operations_example() {
+  std::cout << "====== Running: \"Basic autograd operations\" ======" << std::endl;
+
+  // Create a tensor and set ``torch::requires_grad()`` to track computation with it
+  auto x = torch::ones({2, 2}, torch::requires_grad());
+  std::cout << x << std::endl;
+
+  // Do a tensor operation:
+  auto y = x + 2;
+  std::cout << y << std::endl;
+
+  // ``y`` was created as a result of an operation, so it has a ``grad_fn``.
+  std::cout << y.grad_fn()->name() << std::endl;
+
+  // Do more operations on ``y``
+  auto z = y * y * 3;
+  auto out = z.mean();
+
+  std::cout << z << std::endl;
+  std::cout << z.grad_fn()->name() << std::endl;
+  std::cout << out << std::endl;
+  std::cout << out.grad_fn()->name() << std::endl;
+
+  // ``.requires_grad_( ... )`` changes an existing tensor's ``requires_grad`` flag in-place.
+  auto a = torch::randn({2, 2});
+  a = ((a * 3) / (a - 1));
+  std::cout << a.requires_grad() << std::endl;
+
+  a.requires_grad_(true);
+  std::cout << a.requires_grad() << std::endl;
+
+  auto b = (a * a).sum();
+  std::cout << b.grad_fn()->name() << std::endl;
+
+  // Let's backprop now. Because ``out`` contains a single scalar, ``out.backward()``
+  // is equivalent to ``out.backward(torch::tensor(1.))``.
+  out.backward();
+
+  // Print gradients d(out)/dx
+  std::cout << x.grad() << std::endl;
+
+  // Now let's take a look at an example of vector-Jacobian product:
+  x = torch::randn(3, torch::requires_grad());
+
+  y = x * 2;
+  while (y.norm().item<double>() < 1000) {
+    y = y * 2;
+  }
+
+  std::cout << y << std::endl;
+  std::cout << y.grad_fn()->name() << std::endl;
+
+  // If we want the vector-Jacobian product, pass the vector to ``backward`` as argument:
+  auto v = torch::tensor({0.1, 1.0, 0.0001}, torch::kFloat);
+  y.backward(v);
+
+  std::cout << x.grad() << std::endl;
+
+  // You can also stop autograd from tracking history on tensors that require gradients
+  // either by putting ``torch::NoGradGuard`` in a code block
+  std::cout << x.requires_grad() << std::endl;
+  std::cout << x.pow(2).requires_grad() << std::endl;
+
+  {
+    torch::NoGradGuard no_grad;
+    std::cout << x.pow(2).requires_grad() << std::endl;
+  }
+
+  // Or by using ``.detach()`` to get a new tensor with the same content but that does
+  // not require gradients:
+  std::cout << x.requires_grad() << std::endl;
+  y = x.detach();
+  std::cout << y.requires_grad() << std::endl;
+  std::cout << x.eq(y).all().item<bool>() << std::endl;
+}
+
+void compute_higher_order_gradients_example() {
+  std::cout << "====== Running \"Computing higher-order gradients in C++\" ======" << std::endl;
+
+  // One of the applications of higher-order gradients is calculating gradient penalty.
+  // Let's see an example of it using ``torch::autograd::grad``:
+
+  auto model = torch::nn::Linear(4, 3);
+
+  auto input = torch::randn({3, 4}).requires_grad_(true);
+  auto output = model(input);
+
+  // Calculate loss
+  auto target = torch::randn({3, 3});
+  auto loss = torch::nn::MSELoss()(output, target);
+
+  // Use norm of gradients as penalty
+  auto grad_output = torch::ones_like(output);
+  auto gradient = torch::autograd::grad({output}, {input}, /*grad_outputs=*/{grad_output}, /*create_graph=*/true)[0];
+  auto gradient_penalty = torch::pow((gradient.norm(2, /*dim=*/1) - 1), 2).mean();
+
+  // Add gradient penalty to loss
+  auto combined_loss = loss + gradient_penalty;
+  combined_loss.backward();
+
+  std::cout << input.grad() << std::endl;
+}
+
+// Inherit from Function<LinearFunction>
+class LinearFunction : public Function<LinearFunction> {
+ public:
+  // Note that both forward and backward are static functions
+
+  // bias is an optional argument
+  static torch::Tensor forward(
+      AutogradContext *ctx, torch::Tensor input, torch::Tensor weight, torch::Tensor bias = torch::Tensor()) {
+    ctx->save_for_backward({input, weight, bias});
+    auto output = input.mm(weight.t());
+    if (bias.defined()) {
+      output += bias.unsqueeze(0).expand_as(output);
+    }
+    return output;
+  }
+
+  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+    auto weight = saved[1];
+    auto bias = saved[2];
+
+    auto grad_output = grad_outputs[0];
+    auto grad_input = grad_output.mm(weight);
+    auto grad_weight = grad_output.t().mm(input);
+    auto grad_bias = torch::Tensor();
+    if (bias.defined()) {
+      grad_bias = grad_output.sum(0);
+    }
+
+    return {grad_input, grad_weight, grad_bias};
+  }
+};
+
+class MulConstant : public Function<MulConstant> {
+ public:
+  static torch::Tensor forward(AutogradContext *ctx, torch::Tensor tensor, double constant) {
+    // ctx is a context object that can be used to stash information
+    // for backward computation
+    ctx->saved_data["constant"] = constant;
+    return tensor * constant;
+  }
+
+  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
+    // We return as many input gradients as there were arguments.
+    // Gradients of non-tensor arguments to forward must be `torch::Tensor()`.
+    return {grad_outputs[0] * ctx->saved_data["constant"].toDouble(), torch::Tensor()};
+  }
+};
+
+void custom_autograd_function_example() {
+  std::cout << "====== Running \"Using custom autograd function in C++\" ======" << std::endl;
+  {
+    auto x = torch::randn({2, 3}).requires_grad_();
+    auto weight = torch::randn({4, 3}).requires_grad_();
+    auto y = LinearFunction::apply(x, weight);
+    y.sum().backward();
+
+    std::cout << x.grad() << std::endl;
+    std::cout << weight.grad() << std::endl;
+  }
+  {
+    auto x = torch::randn({2}).requires_grad_();
+    auto y = MulConstant::apply(x, 5.5);
+    y.sum().backward();
+
+    std::cout << x.grad() << std::endl;
+  }
+}
+
+int main() {
+  std::cout << std::boolalpha;
+
+  basic_autograd_operations_example();
+
+  std::cout << "\n";
+
+  compute_higher_order_gradients_example();
+
+  std::cout << "\n";
+
+  custom_autograd_function_example();
+}
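
One way to double-check the two custom functions in `autograd.cpp` is to compare their gradients against the ones libtorch's built-in operators produce. The sketch below is not part of the patch above; it assumes `LinearFunction` and `MulConstant` are visible in the same translation unit, and the helper name `check_custom_functions` is invented for illustration.

```cpp
#include <torch/torch.h>
#include <iostream>

// Hypothetical sanity check, assuming LinearFunction and MulConstant from
// autograd.cpp are declared earlier in this file.
void check_custom_functions() {
  std::cout << std::boolalpha;
  torch::manual_seed(0);

  // LinearFunction computes y = x * w^T, so its gradients should match the
  // ones autograd derives for the equivalent built-in expression.
  auto x = torch::randn({2, 3}, torch::requires_grad());
  auto w = torch::randn({4, 3}, torch::requires_grad());

  LinearFunction::apply(x, w).sum().backward();
  auto gx_custom = x.grad().clone();
  auto gw_custom = w.grad().clone();

  // Autograd accumulates into .grad(), so zero it before the second pass.
  x.grad().zero_();
  w.grad().zero_();
  x.mm(w.t()).sum().backward();

  std::cout << torch::allclose(gx_custom, x.grad()) << std::endl;  // expected: true
  std::cout << torch::allclose(gw_custom, w.grad()) << std::endl;  // expected: true

  // MulConstant scales by 5.5, so d(sum(5.5 * t))/dt is 5.5 everywhere.
  auto t = torch::randn({2}, torch::requires_grad());
  MulConstant::apply(t, 5.5).sum().backward();
  std::cout << torch::allclose(t.grad(), torch::full_like(t, 5.5)) << std::endl;  // expected: true
}
```

Zeroing the accumulated gradients between the two backward passes is what makes the comparison meaningful, since a second `backward()` call would otherwise add to the gradients left over from the first.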