Skip to content

Commit fbf3eff

Browse files
committed
simplify model forward transfer
1 parent e951aae commit fbf3eff

File tree

1 file changed

+16
-26
lines changed

1 file changed

+16
-26
lines changed

pytorch_lightning/core/memory.py

Lines changed: 16 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99
import torch.nn as nn
1010

1111
import pytorch_lightning as pl
12-
12+
from pytorch_lightning.utilities import transfer_batch_to_device
13+
from pytorch_lightning.utilities.apply_func import apply_to_collection
1314

1415
UNKNOWN_SIZE = 'unknown'
1516

@@ -59,12 +60,14 @@ def _register_hook(self):
5960
on the first forward pass. The hook will remove itself from the module, meaning that
6061
recursive models will only record their input- and output shapes once.
6162
"""
63+
6264
def hook(module, inp, out):
6365
if len(inp) == 1:
6466
inp = inp[0]
6567
self._in_size = parse_batch_shape(inp)
6668
self._out_size = parse_batch_shape(out)
6769
self._hook_handle.remove() # hook detaches itself from module
70+
6871
return self._module.register_forward_hook(hook)
6972

7073
@property
@@ -176,40 +179,27 @@ def summarize(self) -> Dict[str, LayerSummary]:
176179

177180
def _forward_example_input(self) -> None:
178181
""" Run the example input through each layer to get input- and output sizes. """
182+
model = self._model
183+
trainer = self._model.trainer
179184

180185
input_ = self._model.example_input_array
186+
input_ = transfer_batch_to_device(input_, self._model.device)
181187

182-
# TODO: should rethink this to add support for GPU, TPU, AMP, ... and avoid code duplication
183-
# or should it always be done on cpu?
184-
if self._model.on_gpu:
185-
device = next(self._model.parameters()).device
186-
# test if input is a list or a tuple
187-
if isinstance(input_, (list, tuple)):
188-
input_ = [input_i.to(device) if torch.is_tensor(input_i) else input_i
189-
for input_i in input_]
190-
else:
191-
input_ = input_.to(device)
192-
193-
# if model.trainer.use_amp and self.use_native_amp:
194-
# model.forward = torch.cuda.amp.autocast()(model.forward)
188+
if trainer is not None and trainer.use_amp:
189+
if model.use_native_amp:
190+
model.forward = torch.cuda.amp.autocast()(model.forward)
195191

196-
if self._model.trainer is not None and self._model.trainer.use_amp:
197-
# test if it is not a list or a tuple
198-
if isinstance(input_, (list, tuple)):
199-
input_ = [input_i.half() if torch.is_tensor(input_i) else input_i
200-
for input_i in input_]
201-
else:
202-
input_ = input_.half()
192+
input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))
203193

204-
mode = self._model.training
205-
self._model.eval()
194+
mode = model.training
195+
model.eval()
206196
with torch.no_grad():
207197
# let the model hooks collect the input- and output shapes
208198
if isinstance(input_, (list, tuple)):
209-
self._model(*input_)
199+
model(*input_)
210200
else:
211-
self._model(input_)
212-
self._model.train(mode) # restore mode of module
201+
model(input_)
202+
model.train(mode) # restore mode of module
213203

214204
def __str__(self):
215205
"""

0 commit comments

Comments (0)