1919first change the runtime to “GPU” or higher. Once you do, we need to
2020install ``torch`` if it isn’t already available.
2121
22- """
22+ ::
23+
24+ pip install torch
2325
24- pip install torch
26+ """
2527
2628
2729######################################################################
3537# 5. Save on a CPU, load on a GPU
3638# 6. Saving and loading ``DataParallel`` models
3739#
38- # **1) Import necessary libraries for loading our data**
40+ # 1. Import necessary libraries for loading our data
3941# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4042#
4143# For this recipe, we will use ``torch`` and its subsidiaries ``torch.nn``
4850
4951
5052######################################################################
51- # **2) Define and intialize the neural network**
53+ # 2. Define and intialize the neural network
5254# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
5355#
5456# For sake of example, we will create a neural network for training
@@ -79,7 +81,7 @@ def forward(self, x):
7981
8082
8183######################################################################
82- # **3) Save on GPU, Load on CPU**
84+ # 3. Save on GPU, Load on CPU
8385# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8486#
8587# When loading a model on a CPU that was trained with a GPU, pass
@@ -103,7 +105,7 @@ def forward(self, x):
103105# In this case, the storages underlying the tensors are dynamically
104106# remapped to the CPU device using the ``map_location`` argument.
105107#
106- # **4) Save on GPU, Load on GPU**
108+ # 4. Save on GPU, Load on GPU
107109# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
108110#
109111# When loading a model on a GPU that was trained and saved on GPU, simply
@@ -130,7 +132,7 @@ def forward(self, x):
130132# remember to manually overwrite tensors:
131133# ``my_tensor = my_tensor.to(torch.device('cuda'))``.
132134#
133- # **5) Save on CPU, Load on GPU**
135+ # 5. Save on CPU, Load on GPU
134136# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
135137#
136138# When loading a model on a GPU that was trained and saved on CPU, set the
@@ -157,7 +159,7 @@ def forward(self, x):
157159
158160
159161######################################################################
160- # **6) Saving ``torch.nn.DataParallel`` Models**
162+ # 6. Saving ``torch.nn.DataParallel`` Models
161163# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
162164#
163165# ``torch.nn.DataParallel`` is a model wrapper that enables parallel GPU
@@ -185,4 +187,4 @@ def forward(self, x):
185187#
186188# - TBD
187189# - TBD
188- #
190+ #
0 commit comments