🐛 Bug
I am trying to fine-tune a pre-trained model that is saved as TorchScript. Unfortunately, it looks like Lightning's LayerSummary does not support scripted modules:
To Reproduce
Run
import torch
from torch import nn
import pytorch_lightning as pl


class Module(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.jit.script(nn.Linear(5, 1))  # Notice the scripting!

    def configure_optimizers(self):
        return torch.optim.Adam(self.linear.parameters())

    def training_step(self, batch, batch_idx):
        return self.linear(batch).sum()


if __name__ == "__main__":
    m = Module()
    datasets = [torch.rand([5]) for __ in range(100)]
    train_loader = torch.utils.data.DataLoader(datasets, batch_size=8)
    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=1,
    )
    trainer.fit(m, train_loader)

fails with
Traceback (most recent call last):
File "scratch.py", line 29, in <module>
trainer.fit(m, train_loader)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 513, in fit
self.dispatch()
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in dispatch
self.accelerator.start_training(self)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 111, in start_training
self._results = trainer.run_train()
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 609, in run_train
self._pre_training_routine()
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 595, in _pre_training_routine
ref_model.summarize(mode=self.weights_summary)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/lightning.py", line 1456, in summarize
model_summary = ModelSummary(self, mode=mode)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 184, in __init__
self._layer_summary = self.summarize()
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 236, in summarize
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 236, in <genexpr>
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 67, in __init__
self._hook_handle = self._register_hook()
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 91, in _register_hook
return self._module.register_forward_hook(hook)
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/torch/jit/_script.py", line 723, in fail
raise RuntimeError(name + " is not supported on ScriptModules")
RuntimeError: register_forward_hook is not supported on ScriptModules
Exception ignored in: <bound method LayerSummary.__del__ of <pytorch_lightning.core.memory.LayerSummary object at 0x7f5815a7a9e8>>
Traceback (most recent call last):
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 72, in __del__
self.detach_hook()
File "/home/alandu/miniconda3/envs/ctrldev/lib/python3.6/site-packages/pytorch_lightning/core/memory.py", line 98, in detach_hook
if self._hook_handle is not None:
AttributeError: 'LayerSummary' object has no attribute '_hook_handle'
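The root cause looks like a PyTorch limitation rather than anything Lightning-specific: hooks cannot be registered on ScriptModules at all. A minimal sketch of just that failure, without Lightning (this is my reading of the traceback):

import torch
from torch import nn

scripted = torch.jit.script(nn.Linear(5, 1))

def hook(module, inputs, output):
    pass

# Raises: RuntimeError: register_forward_hook is not supported on ScriptModules
scripted.register_forward_hook(hook)

The secondary AttributeError above is just fallout from the first error: LayerSummary.__init__ raises before self._hook_handle is ever assigned, so detach_hook trips over the missing attribute when __del__ runs during cleanup.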
Expected behavior
This should work just as it does with a plain nn.Linear.
For now, I can work around this by setting weights_summary=None.
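For anyone hitting the same thing, the workaround is just a Trainer flag (note the argument is weights_summary, plural; the script is otherwise identical to the repro above):

trainer = pl.Trainer(
    num_sanity_val_steps=0,
    max_epochs=1,
    weights_summary=None,  # skip the ModelSummary/LayerSummary pass entirely
)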
Environment
* CUDA:
- GPU:
- available: False
- version: None
* Packages:
- numpy: 1.19.5
- pyTorch_debug: False
- pyTorch_version: 1.8.0
- pytorch-lightning: 1.2.2
- tqdm: 4.59.0
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.6.10
- version: #1 SMP Fri Feb 26 16:21:30 UTC 2021