Commit e9e5776

Authored by Jessica Lin
Merge pull request #934 from CamiWilliams/basics-recipe-zerogradients
Zeroing out gradients recipe
2 parents c2d76d2 + 7b3d5c0 commit e9e5776

1 file changed: 100 additions, 57 deletions
@@ -1,48 +1,66 @@
-# -*- coding: utf-8 -*-
-"""Zeroing out gradients in PyTorch.ipynb
-
-Automatically generated by Colaboratory.
-
-Original file is located at
-https://colab.research.google.com/drive/1K2m6BkzNRB2rAN1sELFOP3yRZquaUAAg
-
+"""
 Zeroing out gradients in PyTorch
-=======================
-It is beneficial to zero out gradients when building a neural network. This is because by default, gradients are accumulated in buffers (i.e, not overwritten) whenever ``.backward()`` is called.
+================================
+It is beneficial to zero out gradients when building a neural network.
+This is because, by default, gradients are accumulated in buffers (i.e.,
+not overwritten) whenever ``.backward()`` is called.
 
 Introduction
----
-When training your neural network, models are able to increase their accuracy through gradient decent. In short, gradient descent is the process of minimizing our loss (or error) by tweaking the weights and biases in our model.
+------------
+When training your neural network, models are able to increase their
+accuracy through gradient descent. In short, gradient descent is the
+process of minimizing our loss (or error) by tweaking the weights and
+biases in our model.
+
+``torch.Tensor`` is the central class of PyTorch. When you create a
+tensor, if you set its attribute ``.requires_grad`` as ``True``, the
+package tracks all operations on it. This happens on subsequent backward
+passes. The gradient for this tensor will be accumulated into the
+``.grad`` attribute. The accumulation (or sum) of all the gradients is
+calculated when ``.backward()`` is called on the loss tensor.
+
+There are cases where it may be necessary to zero out the gradients of a
+tensor. For example, when you start your training loop, you should zero
+out the gradients so that you can perform this tracking correctly.
+In this recipe, we will learn how to zero out gradients using the
+PyTorch library. We will demonstrate how to do this by training a neural
+network on the ``CIFAR10`` dataset built into PyTorch.
 
-``torch.Tensor`` is the central class of PyTorch. When you create a tensor,
-if you set its attribute ``.requires_grad`` as ``True``, the package tracks all operations on it. This happens on subsequent backward passes. The gradient for this tensor will be accumulated into ``.grad`` attribute. The accumulation (or sum) of all the gradients is calculated when .backward() is called on the loss tensor.
+Setup
+-----
+Since we will be training a model in this recipe, if you are in a runnable
+notebook, it is best to switch the runtime to GPU or TPU.
+Before we begin, we need to install ``torch`` and ``torchvision`` if
+they aren’t already available.
 
-There are cases where it may be necessary to zero-out the gradients of a tensor. For example: when you start your training loop, you should zero out the gradients so that you can perform this tracking correctly.
+::
 
-In this recipe, we will learn how to zero out gradients using the PyTorch library. We will demonstrate how to do this by training a neural network on the ``CIFAR10`` dataset built into PyTorch.
+   pip install torchvision
 
-Setup
----
-Since we will be training data in this recipe, if you are in a runable notebook, it is best to switch the runtime to GPU or TPU.
 
-Before we begin, we need to install ``torch`` and ``torchvision`` if they aren't already available.
 """
 
-pip install torchvision
-
-"""Steps
------------------
-Steps 1 through 4 set up our data and neural network for training. The process of zeroing out the gradients happens in step 5. If you already have your data and neural network built, skip to 5.
 
-1. Import all necessary libraries for loading our data
-2. Load and normalize the dataset
-3. Build the neural network
-4. Define the loss function
-5. Zero the gradients while training the network
-
-### **1) Import necessary libraries for loading our data**
-For this recipe, we will just be using ``torch`` and ``torchvision`` to access the dataset.
-"""
+######################################################################
+# Steps
+# -----
+#
+# Steps 1 through 4 set up our data and neural network for training. The
+# process of zeroing out the gradients happens in step 5. If you already
+# have your data and neural network built, skip to 5.
+#
+# 1. Import all necessary libraries for loading our data
+# 2. Load and normalize the dataset
+# 3. Build the neural network
+# 4. Define the loss function
+# 5. Zero the gradients while training the network
+#
+# 1. Import necessary libraries for loading our data
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# For this recipe, we will just be using ``torch`` and ``torchvision`` to
+# access the dataset.
+#
 
 import torch
 
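To make the accumulation behavior the new introduction describes concrete, here is a minimal sketch (not part of this commit) of how `.grad` grows across repeated `.backward()` calls and how zeroing resets it:

```python
import torch

# A toy tensor with gradient tracking enabled.
x = torch.ones(3, requires_grad=True)

# First backward pass: d((2x).sum())/dx is 2 for every element.
(2 * x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])

# A second backward pass accumulates into .grad rather than overwriting it.
(2 * x).sum().backward()
print(x.grad)  # tensor([4., 4., 4.])

# Zeroing the buffer restores a clean slate for the next pass.
x.grad.zero_()
(2 * x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```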
@@ -54,9 +72,14 @@
 import torchvision
 import torchvision.transforms as transforms
 
-"""### **2) Load and normalize the dataset**
-PyTorch features various built-in datasets (see the Loading Data recipe for more information).
-"""
+
+######################################################################
+# 2. Load and normalize the dataset
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# PyTorch features various built-in datasets (see the Loading Data recipe
+# for more information).
+#
 
 transform = transforms.Compose(
     [transforms.ToTensor(),
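The hunk above ends mid-expression; the rest of the data-loading code is unchanged context that the diff elides. For orientation, a hedged sketch of how such a CIFAR10 pipeline is typically completed (the normalization constants and batch size are illustrative assumptions, not read from this commit):

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Convert images to tensors and scale each channel to [-1, 1];
# the 0.5 mean/std values are illustrative assumptions.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Download the built-in CIFAR10 training set and wrap it in a loader.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
```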
@@ -75,9 +98,14 @@
 classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
 
-"""### **3) Build the neural network**
-We will use a convolutional neural network. To learn more see the Defining a Neural Network recipe.
-"""
+
+######################################################################
+# 3. Build the neural network
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We will use a convolutional neural network. To learn more see the
+# Defining a Neural Network recipe.
+#
 
 class Net(nn.Module):
     def __init__(self):
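The body of `Net` is likewise elided by the diff. The surviving context lines (`self.fc3`, the tail of `forward`) are consistent with the standard CIFAR10 tutorial architecture; a sketch under that assumption:

```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Two conv layers with max pooling, then three fully connected layers.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 10 output classes for CIFAR10

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```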
@@ -98,19 +126,30 @@ def forward(self, x):
         x = self.fc3(x)
         return x
 
-"""### **4) Define a Loss function and optimizer**
-Let’s use a Classification Cross-Entropy loss and SGD with momentum.
-"""
+
+######################################################################
+# 4. Define a Loss function and optimizer
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Let’s use a Classification Cross-Entropy loss and SGD with momentum.
+#
 
 net = Net()
 criterion = nn.CrossEntropyLoss()
 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
 
-"""### **5) Zero the gradients while training the network**
-This is when things start to get interesting. We simply have to loop over our data iterator, and feed the inputs to the network and optimize.
 
-Notice that for each entity of data, we zero out the gradients. This is to ensure that we aren't tracking any unnecessary information when we train our neural network.
-"""
+######################################################################
+# 5. Zero the gradients while training the network
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# This is when things start to get interesting. We simply have to loop
+# over our data iterator, feed the inputs to the network, and optimize.
+#
+# Notice that for each batch of data, we zero out the gradients. This is
+# to ensure that we aren’t tracking any unnecessary information when we
+# train our neural network.
+#
 
 for epoch in range(2):  # loop over the dataset multiple times
 
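The training-loop body falls between hunks, so the diff shows only its first line. A hedged sketch of the pattern step 5 describes, with `optimizer.zero_grad()` at the top of every iteration (names assume the `trainloader`, `net`, `criterion`, and `optimizer` defined earlier):

```python
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        # Zero the gradient buffers so this batch's gradients are not
        # mixed with those accumulated from the previous batch.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, and a parameter update.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
```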
@@ -137,14 +176,18 @@ def forward(self, x):
 
 print('Finished Training')
 
-"""You can also use ``model.zero_grad()``. This is the same as using ``optimizer.zero_grad()`` as long as all your model parameters are in that optimizer. Use your best judgement to decide which one to use.
-
-Congratulations! You have successfully zeroed out gradients PyTorch.
-
-Learn More
-----------------------------
-Take a look at these other recipes to continue your learning:
 
-* TBD
-* TBD
-"""
+######################################################################
+# You can also use ``model.zero_grad()``. This is the same as using
+# ``optimizer.zero_grad()`` as long as all your model parameters are in
+# that optimizer. Use your best judgement to decide which one to use.
+#
+# Congratulations! You have successfully zeroed out gradients in PyTorch.
+#
+# Learn More
+# ----------
+#
+# Take a look at these other recipes to continue your learning:
+#
+# - TBD
+# - TBD
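As a footnote to the closing remark: the two zeroing calls it compares look like this in practice (assuming the recipe's `net` and `optimizer`), and they are interchangeable as long as the optimizer was built over all of the model's parameters.

```python
optimizer.zero_grad()  # zeros .grad on every parameter the optimizer tracks
net.zero_grad()        # zeros .grad on every parameter of the module net
```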
