
Commit c2d76d2

Authored by Jessica Lin
Merge pull request #949 from CamiWilliams/basics-recipe-acrossdevices
Saving and loading models across devices recipe
2 parents 0e020a0 + 961624d commit c2d76d2

1 file changed: 190 additions, 0 deletions
"""
Saving and loading models across devices in PyTorch
===================================================

There may be instances where you want to save and load your neural
networks across different devices.

Introduction
------------

Saving and loading models across devices is relatively straightforward
using PyTorch. In this recipe, we will experiment with saving and
loading models across CPUs and GPUs.

Setup
-----

In order for every code block to run properly in this recipe, you must
first change the runtime to “GPU” or higher. Once you do, we need to
install ``torch`` if it isn’t already available.

::

   pip install torch

"""

######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries for loading our data
# 2. Define and initialize the neural network
# 3. Save on a GPU, load on a CPU
# 4. Save on a GPU, load on a GPU
# 5. Save on a CPU, load on a GPU
# 6. Saving and loading ``DataParallel`` models
#
# 1. Import necessary libraries for loading our data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will use ``torch`` and its subpackages ``torch.nn``
# and ``torch.optim``, along with ``torch.nn.functional`` for the
# activation functions used in the network below.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


######################################################################
# 2. Define and initialize the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For the sake of example, we will create a neural network for training
# images. To learn more, see the Defining a Neural Network recipe.
#

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)


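######################################################################
# A freshly constructed model lives on the CPU until it is explicitly
# moved, which is worth checking before experimenting with devices; a
# quick way to see where its parameters currently are:
#

print(next(net.parameters()).device)   # expected: cpu

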
######################################################################
# 3. Save on GPU, Load on CPU
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# When loading a model on a CPU that was trained with a GPU, pass
# ``torch.device('cpu')`` to the ``map_location`` argument in the
# ``torch.load()`` function.
#

# Specify a path to save to
PATH = "model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device('cpu')
model = Net()
model.load_state_dict(torch.load(PATH, map_location=device))


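######################################################################
# As a quick sanity check, the CPU-loaded model can be run on a dummy
# batch. The shape below assumes the 32x32 RGB images that ``Net``
# expects; adjust it for your own data.
#

model.eval()                                  # switch to inference mode
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 32, 32)   # stays on the CPU
    output = model(dummy_input)
print(output.shape)                           # torch.Size([1, 10])

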
######################################################################
# In this case, the storages underlying the tensors are dynamically
# remapped to the CPU device using the ``map_location`` argument.
#
# 4. Save on GPU, Load on GPU
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# When loading a model on a GPU that was trained and saved on a GPU, simply
# convert the initialized model to a CUDA optimized model using
# ``model.to(torch.device('cuda'))``.
#
# Be sure to use the ``.to(torch.device('cuda'))`` function on all model
# inputs to prepare the data for the model; a sketch follows the code
# block below.
#

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
model.load_state_dict(torch.load(PATH))
model.to(device)


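######################################################################
# For example, an input tensor created on the CPU has to be overwritten
# with its GPU copy before the forward pass; a minimal sketch with a
# random 32x32 RGB batch (the shape is only an assumption about your data):
#

model.eval()
input = torch.randn(1, 3, 32, 32)   # created on the CPU
input = input.to(device)            # overwrite with the GPU copy
with torch.no_grad():
    output = model(input)
print(output.device)                # cuda:0

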
######################################################################
# Note that calling ``my_tensor.to(device)`` returns a new copy of
# ``my_tensor`` on GPU. It does NOT overwrite ``my_tensor``. Therefore,
# remember to manually overwrite tensors:
# ``my_tensor = my_tensor.to(torch.device('cuda'))``.
#
# 5. Save on CPU, Load on GPU
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# When loading a model on a GPU that was trained and saved on a CPU, set the
# ``map_location`` argument in the ``torch.load()`` function to
# ``cuda:device_id``. This loads the model to a given GPU device.
#
# Be sure to call ``model.to(torch.device('cuda'))`` to convert the
# model’s parameter tensors to CUDA tensors.
#
# Finally, also be sure to use the ``.to(torch.device('cuda'))`` function
# on all model inputs to prepare the data for the CUDA optimized model.
#

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
# Choose whatever GPU device number you want
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
model.to(device)


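######################################################################
# The string passed to ``map_location`` selects the target GPU, so the
# same checkpoint can be sent to any device id you like. A minimal
# sketch that targets a second GPU (``cuda:1``) only when one is present:
#

if torch.cuda.device_count() > 1:
    model_on_gpu1 = Net()
    model_on_gpu1.load_state_dict(torch.load(PATH, map_location="cuda:1"))
    model_on_gpu1.to(torch.device("cuda:1"))

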
######################################################################
# 6. Saving ``torch.nn.DataParallel`` Models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU
# utilization.
#
# To save a ``DataParallel`` model generically, save the
# ``model.module.state_dict()``. This way, you have the flexibility to
# load the model any way you want to any device you want.
#

# Wrap the model in ``DataParallel`` so that ``net.module`` refers to the
# underlying network
net = nn.DataParallel(net)

# Save
torch.save(net.module.state_dict(), PATH)

# Load to whatever device you want
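# For example (a minimal sketch): because the checkpoint holds the plain,
# unwrapped state_dict, it can be restored into an ordinary ``Net`` on the
# CPU and, if GPUs are available, re-wrapped in ``DataParallel``.
model = Net()
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))

if torch.cuda.is_available():
    model = nn.DataParallel(model)   # re-wrap for multi-GPU use
    model.to(torch.device('cuda'))

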
######################################################################
# Congratulations! You have successfully saved and loaded models across
# devices in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - TBD
# - TBD
#
