
Unable to resume from checkpoint when using apex #11488

@dtmoodie

Description

🐛 Bug

When I try to resume training a model that was trained with apex, the checkpoint fails to load.

To Reproduce

Train a model with trainer.fit using the following params:

precision: 16
amp_level: O2
amp_backend: apex

Then attempt to continue training from that checkpoint by passing it to trainer.fit via ckpt_path.
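
For reference, a minimal sketch of the two runs. The toy model and data here are placeholders standing in for the real YOLO training code, not taken from the original script:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(pl.LightningModule):
    """Placeholder LightningModule standing in for the real model."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def loader():
    return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)


# First run: train with apex O2 mixed precision and save a checkpoint.
trainer = pl.Trainer(gpus=1, max_epochs=1, precision=16, amp_backend="apex", amp_level="O2")
trainer.fit(ToyModel(), train_dataloaders=loader())
trainer.save_checkpoint("apex.ckpt")

# Second run: resuming from that checkpoint fails in restore_training_state().
trainer = pl.Trainer(gpus=1, max_epochs=2, precision=16, amp_backend="apex", amp_level="O2")
trainer.fit(ToyModel(), train_dataloaders=loader(), ckpt_path="apex.ckpt")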
The error that I get is:

File "/home/dan/code/ml/yolo/train_lightning.py", line 299, in <module>
    trainer.fit(model, dataset, ckpt_path=manager.args.resume)
  File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
   return trainer_fn(*args, **kwargs)
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
   self._run(model, ckpt_path=ckpt_path)
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1189, in _run
   self.checkpoint_connector.restore_training_state()
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 185, in restore_training_state
   self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
 File "/home/dan/.local/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/apex_amp.py", line 97, in on_load_checkpoint
   amp.load_state_dict(checkpoint["amp_scaling_state"])
 File "/opt/conda/lib/python3.8/site-packages/apex/amp/frontend.py", line 375, in load_state_dict
   if len(state_dict) != len(_amp_state.loss_scalers):

AttributeError: 'AmpState' object has no attribute 'loss_scalers'

Expected behavior

Training resumes as it would without apex.
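
Not part of the original report, but a possible interim workaround sketch. It assumes the restore fails because amp.load_state_dict() is called while apex has not yet created its loss scalers in the new process, and it simply skips restoring the saved amp scaling state in that case (losing only the stored loss scale, which apex re-derives during training):

import pytorch_lightning as pl
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin

try:
    from apex.amp._amp_state import _amp_state
except ImportError:  # apex not installed
    _amp_state = None


class SkipAmpStateRestore(ApexMixedPrecisionPlugin):
    def on_load_checkpoint(self, checkpoint):
        # Only hand the saved "amp_scaling_state" back to apex if its loss
        # scalers already exist; otherwise skip the restore instead of crashing.
        if _amp_state is not None and hasattr(_amp_state, "loss_scalers"):
            super().on_load_checkpoint(checkpoint)


# Usage sketch: pass the plugin to the Trainer in place of the
# amp_backend/amp_level flags; how these flags interact with a custom
# precision plugin may differ between PL versions.
trainer = pl.Trainer(gpus=1, plugins=[SkipAmpStateRestore(amp_level="O2")])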

Environment

Please copy and paste the output from our environment collection script:

  • CUDA:
    - GPU: NVIDIA GeForce RTX 3090
    - available: True
    - version: 11.4
  • Packages:
    - numpy: 1.21.2
    - pyTorch_debug: False
    - pyTorch_version: 1.10.0a0+0aef44c
    - pytorch-lightning: 1.5.6
    - tqdm: 4.62.3
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.8.12
    - version: #44~20.04.2-Ubuntu SMP Tue Oct 26 18:07:44 UTC 2021
  • How you installed PyTorch (conda, pip, source): pip

cc @carmocca @justusschock @awaelchli @akihironitta @rohitgr7
