Skip to content

Commit b4b21ef

Browse files
author
Jessica Lin
authored
Merge pull request #939 from CamiWilliams/basics-recipe-multiplesaveload
Saving and loading multiple models in one file recipe
2 parents de0b952 + 767b9a8 commit b4b21ef

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""
2+
Saving and loading multiple models in one file using PyTorch
3+
============================================================
4+
Saving and loading multiple models can be helpful for reusing models
5+
that you have previously trained.
6+
7+
Introduction
8+
------------
9+
When saving a model comprised of multiple ``torch.nn.Modules``, such as
10+
a GAN, a sequence-to-sequence model, or an ensemble of models, you must
11+
save a dictionary of each model’s state_dict and corresponding
12+
optimizer. You can also save any other items that may aid you in
13+
resuming training by simply appending them to the dictionary.
14+
To load the models, first initialize the models and optimizers, then
15+
load the dictionary locally using ``torch.load()``. From here, you can
16+
easily access the saved items by simply querying the dictionary as you
17+
would expect.
18+
In this recipe, we will demonstrate how to save multiple models to one
19+
file using PyTorch.
20+
21+
Setup
22+
-----
23+
Before we begin, we need to install ``torch`` if it isn’t already
24+
available.
25+
26+
::
27+
28+
pip install torch
29+
30+
"""
31+
32+
33+
34+
######################################################################
35+
# Steps
36+
# -----
37+
#
38+
# 1. Import all necessary libraries for loading our data
39+
# 2. Define and initialize the neural network
40+
# 3. Initialize the optimizer
41+
# 4. Save multiple models
42+
# 5. Load multiple models
43+
#
44+
# 1. Import necessary libraries for loading our data
45+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
46+
#
47+
# For this recipe, we will use ``torch`` and its subsidiaries ``torch.nn``
48+
# and ``torch.optim``.
49+
#
50+
51+
import torch
52+
import torch.nn as nn
53+
import torch.optim as optim
54+
55+
56+
######################################################################
57+
# 2. Define and initialize the neural network
58+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
59+
#
60+
# For the sake of example, we will create a neural network for training
61+
# images. To learn more see the Defining a Neural Network recipe. Build
62+
# two variables for the models to eventually save.
63+
#
64+
65+
class Net(nn.Module):
    """Small convolutional network used as the example model in this recipe.

    Two convolution + max-pool stages feed three fully connected layers
    that produce 10 output values per sample.
    """

    def __init__(self):
        """Create the convolution, pooling and fully connected layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a batch through the network and return the 10 outputs per sample.

        ``x`` must have 3 channels; the flatten step below assumes each
        sample has 16 * 5 * 5 features after the second pooling stage
        (which is the case for 32x32 inputs).
        """
        # Use ``nn.functional`` directly: only ``torch.nn`` is imported by
        # this file, there is no ``F`` alias in scope.
        relu = nn.functional.relu
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        # Flatten the feature maps to one vector per sample for the linear layers
        x = x.view(-1, 16 * 5 * 5)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x
83+
84+
# Two independent instances of the same architecture; both get saved later
netA, netB = Net(), Net()
86+
87+
88+
######################################################################
89+
# 3. Initialize the optimizer
90+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
91+
#
92+
# We will use SGD with momentum to build an optimizer for each model we
93+
# created.
94+
#
95+
96+
# One SGD-with-momentum optimizer per model, both using the same hyperparameters
sgd_settings = {"lr": 0.001, "momentum": 0.9}
optimizerA = optim.SGD(netA.parameters(), **sgd_settings)
optimizerB = optim.SGD(netB.parameters(), **sgd_settings)
98+
99+
100+
######################################################################
101+
# 4. Save multiple models
102+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
103+
#
104+
# Collect all relevant information and build your dictionary.
105+
#
106+
107+
# Destination file for the combined checkpoint
PATH = "model.pt"

# Collect every state_dict under a descriptive key, then write them all
# to a single file in one ``torch.save`` call.
saved_state = {}
saved_state['modelA_state_dict'] = netA.state_dict()
saved_state['modelB_state_dict'] = netB.state_dict()
saved_state['optimizerA_state_dict'] = optimizerA.state_dict()
saved_state['optimizerB_state_dict'] = optimizerB.state_dict()
torch.save(saved_state, PATH)
116+
117+
118+
######################################################################
119+
# 5. Load multiple models
120+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
121+
#
122+
# Remember to first initialize the models and optimizers, then load the
123+
# dictionary locally.
124+
#
125+
126+
# Re-create the model and optimizer objects that will receive the saved state
modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

# NOTE: ``torch.load`` unpickles the file, so only load checkpoints that
# come from a source you trust.
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
# Restore the optimizer state into the optimizers created above, which are
# the ones bound to the parameters of modelA and modelB.
optimModelA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimModelB.load_state_dict(checkpoint['optimizerB_state_dict'])

# Put the models in evaluation mode for inference ...
modelA.eval()
modelB.eval()
# - or -
# ... or in training mode to resume training.
modelA.train()
modelB.train()
142+
143+
144+
######################################################################
145+
# You must call ``model.eval()`` to set dropout and batch normalization
146+
# layers to evaluation mode before running inference. Failing to do this
147+
# will yield inconsistent inference results.
148+
#
149+
# If you wish to resume training, call ``model.train()`` to ensure these
150+
# layers are in training mode.
151+
#
152+
# Congratulations! You have successfully saved and loaded multiple models
153+
# in PyTorch.
154+
#
155+
# Learn More
156+
# ----------
157+
#
158+
# Take a look at these other recipes to continue your learning:
159+
#
160+
# - TBD
161+
# - TBD
162+
#

0 commit comments

Comments
 (0)