Commit 0e020a0

Author: Jessica Lin
Merge pull request #943 from CamiWilliams/basics-recipe-warmstart
Warmstarting a model using parameters from a different model recipe
2 parents b4b21ef + c793ed6 commit 0e020a0

File tree

1 file changed: +142 -0 lines changed
@@ -0,0 +1,142 @@
"""
Warmstarting a model using parameters from a different model in PyTorch
========================================================================
Partially loading a model, or loading a partial model, are common
scenarios when transfer learning or when training a new complex model.
Leveraging trained parameters, even if only a few are usable, will help
to warmstart the training process and hopefully help your model converge
much faster than training from scratch.

Introduction
------------
Whether you are loading from a partial ``state_dict``, which is missing
some keys, or loading a ``state_dict`` with more keys than the model
that you are loading into, you can set the ``strict`` argument to
``False`` in the ``load_state_dict()`` function to ignore non-matching
keys.
In this recipe, we will experiment with warmstarting a model using the
parameters of a different model.
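
For example, a minimal sketch (assuming a ``model`` instance and a saved
checkpoint at ``PATH`` already exist): ``load_state_dict()`` returns a
named tuple reporting whichever keys did not line up.

::

   # strict=False skips keys that do not match between the
   # checkpoint and the model
   incompatible = model.load_state_dict(torch.load(PATH), strict=False)
   print(incompatible.missing_keys, incompatible.unexpected_keys)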

Setup
-----
Before we begin, we need to install ``torch`` if it isn’t already
available.

::

   pip install torch


"""


######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries for loading our data
# 2. Define and initialize the neural networks A and B
# 3. Save model A
# 4. Load into model B
#
# 1. Import necessary libraries for loading our data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will use ``torch`` and its submodules ``torch.nn``,
# ``torch.nn.functional``, and ``torch.optim``.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


######################################################################
# 2. Define and initialize the neural networks A and B
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For the sake of example, we will create a neural network for training
# images. To learn more, see the Defining a Neural Network recipe. We
# will create two neural networks so that we can load the parameters
# saved from one (type A) into the other (type B).
#

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()


######################################################################
# 3. Save model A
# ~~~~~~~~~~~~~~~~~~~
#

# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)
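
######################################################################
# Before loading the checkpoint elsewhere, you can inspect which
# parameter keys it actually contains; ``torch.load(PATH)`` returns the
# saved ``state_dict`` dictionary:

# prints the layer names saved from netA, e.g. conv1.weight, conv1.bias, ...
print(torch.load(PATH).keys())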


######################################################################
# 4. Load into model B
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# If you want to load parameters from one layer to another, but some
# keys do not match, simply change the name of the parameter keys in
# the ``state_dict`` that you are loading to match the keys in the
# model that you are loading into, as shown in the sketch after this
# call.
#

netB.load_state_dict(torch.load(PATH), strict=False)
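
######################################################################
# A minimal sketch of that renaming approach; the prefixes
# ``"layerA."`` and ``"layerB."`` are hypothetical placeholders for
# whatever key names actually differ between your two models:

state_dict = torch.load(PATH)
# hypothetical rename: rewrite each key so it matches the target model
renamed_state_dict = {
    k.replace("layerA.", "layerB."): v for k, v in state_dict.items()
}
netB.load_state_dict(renamed_state_dict, strict=False)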


######################################################################
# You can see that all keys matched successfully! Since ``NetA`` and
# ``NetB`` share the same architecture, ``strict=False`` had no
# non-matching keys to skip.
#
# Congratulations! You have successfully warmstarted a model using
# parameters from a different model in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - TBD
# - TBD
