|
9 | 9 | import torch.nn as nn |
10 | 10 |
|
11 | 11 | import pytorch_lightning as pl |
12 | | - |
| 12 | +from pytorch_lightning.utilities import transfer_batch_to_device |
| 13 | +from pytorch_lightning.utilities.apply_func import apply_to_collection |
13 | 14 |
|
14 | 15 | UNKNOWN_SIZE = 'unknown' |
15 | 16 |
|
@@ -59,12 +60,14 @@ def _register_hook(self): |
59 | 60 | on the first forward pass. The hook will remove itself from the module, meaning that |
60 | 61 | recursive models will only record their input- and output shapes once. |
61 | 62 | """ |
| 63 | + |
62 | 64 | def hook(module, inp, out): |
63 | 65 | if len(inp) == 1: |
64 | 66 | inp = inp[0] |
65 | 67 | self._in_size = parse_batch_shape(inp) |
66 | 68 | self._out_size = parse_batch_shape(out) |
67 | 69 | self._hook_handle.remove() # hook detaches itself from module |
| 70 | + |
68 | 71 | return self._module.register_forward_hook(hook) |
69 | 72 |
|
70 | 73 | @property |
@@ -176,40 +179,27 @@ def summarize(self) -> Dict[str, LayerSummary]: |
176 | 179 |
|
177 | 180 | def _forward_example_input(self) -> None: |
178 | 181 | """ Run the example input through each layer to get input- and output sizes. """ |
| 182 | + model = self._model |
| 183 | + trainer = self._model.trainer |
179 | 184 |
|
180 | 185 | input_ = self._model.example_input_array |
| 186 | + input_ = transfer_batch_to_device(input_, self._model.device) |
181 | 187 |
|
182 | | - # TODO: should rethink this to add support for GPU, TPU, AMP, ... and avoid code duplication |
183 | | - # or should it always be done on cpu? |
184 | | - if self._model.on_gpu: |
185 | | - device = next(self._model.parameters()).device |
186 | | - # test if input is a list or a tuple |
187 | | - if isinstance(input_, (list, tuple)): |
188 | | - input_ = [input_i.to(device) if torch.is_tensor(input_i) else input_i |
189 | | - for input_i in input_] |
190 | | - else: |
191 | | - input_ = input_.to(device) |
192 | | - |
193 | | - # if model.trainer.use_amp and self.use_native_amp: |
194 | | - # model.forward = torch.cuda.amp.autocast()(model.forward) |
| 188 | + if trainer is not None and trainer.use_amp: |
| 189 | + if model.use_native_amp: |
| 190 | + model.forward = torch.cuda.amp.autocast()(model.forward) |
195 | 191 |
|
196 | | - if self._model.trainer is not None and self._model.trainer.use_amp: |
197 | | - # test if it is not a list or a tuple |
198 | | - if isinstance(input_, (list, tuple)): |
199 | | - input_ = [input_i.half() if torch.is_tensor(input_i) else input_i |
200 | | - for input_i in input_] |
201 | | - else: |
202 | | - input_ = input_.half() |
| 192 | + input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype)) |
203 | 193 |
|
204 | | - mode = self._model.training |
205 | | - self._model.eval() |
| 194 | + mode = model.training |
| 195 | + model.eval() |
206 | 196 | with torch.no_grad(): |
207 | 197 | # let the model hooks collect the input- and output shapes |
208 | 198 | if isinstance(input_, (list, tuple)): |
209 | | - self._model(*input_) |
| 199 | + model(*input_) |
210 | 200 | else: |
211 | | - self._model(input_) |
212 | | - self._model.train(mode) # restore mode of module |
| 201 | + model(input_) |
| 202 | + model.train(mode) # restore mode of module |
213 | 203 |
|
214 | 204 | def __str__(self): |
215 | 205 | """ |
|
0 commit comments