Open
Labels: bug, checkpointing, ver: 2.0.x
Description
Bug description
When resuming from a checkpoint (via Trainer().fit(..., ckpt_path=X)), trainer.current_epoch is not yet updated when on_fit_start runs. It only reflects the restored value later: the first on_train_epoch_start already reports it, as the log below shows.
Yet the callbacks' load_state_dict has already been called at that point, so the run is clearly resumed from a checkpoint rather than started from scratch.
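Building on that observation, a callback can record whether its own state was restored and consult that flag in on_fit_start. This is a minimal sketch, assuming (as the log below suggests) that load_state_dict runs before on_fit_start, and assuming Lightning only calls load_state_dict when the checkpoint holds a non-empty state for this callback. The class name ResumeAwareCallback and the {"saved": True} state are hypothetical; the ImportError fallback only exists so the sketch can be exercised without Lightning installed.

```python
try:
    from pytorch_lightning import Callback  # real base class when Lightning is installed
except ImportError:
    Callback = object  # fallback so the sketch can be run without Lightning


class ResumeAwareCallback(Callback):
    """Remembers whether this callback's state was restored from a checkpoint (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.resumed = False

    def state_dict(self):
        # Assumption: a non-empty state is needed for Lightning to call load_state_dict on resume.
        return {"saved": True}

    def load_state_dict(self, state_dict):
        # Per the log in this report, this runs before on_fit_start when ckpt_path is given.
        self.resumed = True

    def on_fit_start(self, trainer, pl_module):
        if self.resumed:
            print("[on_fit_start] resuming from a checkpoint")
        else:
            print("[on_fit_start] training from scratch")
```

Pass a fresh instance to each Trainer, as the reproduction below does; reusing one instance across fit() calls would carry the flag over.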
My workaround (I have custom logic in a callback that should only apply when continuing or fine-tuning a run) is to check whether trainer.ckpt_path is not None.
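The ckpt_path workaround can be sketched as follows. It is a minimal illustration rather than an officially recommended pattern: FineTuneOnlyCallback and its fine_tuning attribute are hypothetical names, and the sketch relies only on the behavior described above, namely that trainer.ckpt_path is None for a fresh run and set when fit() received ckpt_path.

```python
try:
    from pytorch_lightning import Callback  # real base class when Lightning is installed
except ImportError:
    Callback = object  # fallback so the sketch can be run without Lightning


class FineTuneOnlyCallback(Callback):
    """Enables custom logic only when fit() was given a ckpt_path (hypothetical example)."""

    def __init__(self):
        super().__init__()
        self.fine_tuning = False

    def on_fit_start(self, trainer, pl_module):
        # None for a run from scratch; the checkpoint path when resuming.
        self.fine_tuning = trainer.ckpt_path is not None
        if self.fine_tuning:
            print(f"[on_fit_start] continuing from {trainer.ckpt_path}")
```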
Is this the intended behavior, or should trainer.current_epoch already be > 0 at on_fit_start?
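If the callback needs the epoch number itself rather than just knowing that a resume happened, one possible workaround is to read it from the checkpoint dictionary in the on_load_checkpoint callback hook. This is a sketch under assumptions: that on_load_checkpoint runs at the same point as load_state_dict (before on_fit_start, per the log below), and that the checkpoint's "epoch" entry holds trainer.current_epoch at save time. That value is not necessarily the epoch training resumes at (in the log below, a checkpoint written during epoch 1 resumes at epoch 2). CheckpointEpochCallback is a hypothetical name.

```python
try:
    from pytorch_lightning import Callback  # real base class when Lightning is installed
except ImportError:
    Callback = object  # fallback so the sketch can be run without Lightning


class CheckpointEpochCallback(Callback):
    """Captures the epoch recorded in the checkpoint before on_fit_start (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.ckpt_epoch = None  # stays None when training starts from scratch

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Assumption: "epoch" is the trainer's current_epoch at the time the checkpoint was saved.
        self.ckpt_epoch = checkpoint.get("epoch")

    def on_fit_start(self, trainer, pl_module):
        print(f"[on_fit_start] current_epoch={trainer.current_epoch}, "
              f"checkpoint epoch={self.ckpt_epoch}")
```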
What version are you seeing the problem on?
v2.0
How to reproduce the bug
from pytorch_lightning import LightningModule, Callback, Trainer
from torch.utils.data import Dataset, DataLoader
import torch as tr
from torch import nn
from torch.nn import functional as F


class Reader(Dataset):
    """Tiny random dataset: 5 samples of (5 features, 1 target)."""

    def __getitem__(self, ix):
        return tr.randn(5), tr.randn(1)

    def __len__(self):
        return 5


class MyCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"[on_train_epoch_start] epoch: {trainer.current_epoch}")

    def on_fit_start(self, trainer, pl_module):
        print(f"[on_fit_start] epoch: {trainer.current_epoch}")

    def load_state_dict(self, state):
        print("[load_state_dict]")

    def state_dict(self):
        print("[state_dict]")
        return {1: 2}


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(5, 1)

    def forward(self, x):
        return tr.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Regression target, so MSE (cross_entropy over a single output column is always 0)
        loss = F.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return tr.optim.Adam(self.parameters(), lr=0.02)


def main():
    loader = DataLoader(Reader(), batch_size=5)
    model = LitModel()
    # First run: train from scratch for 2 epochs (0 and 1)
    trainer = Trainer(max_epochs=2, callbacks=[MyCallback()], enable_progress_bar=False, enable_model_summary=False)
    trainer.fit(model, train_dataloaders=loader)
    print("----")
    # Second run: resume from the checkpoint written by the first run and continue up to 5 epochs
    ckpt_path = trainer.checkpoint_callback.best_model_path
    Trainer(max_epochs=5, callbacks=[MyCallback()], enable_progress_bar=False, enable_model_summary=False) \
        .fit(model, train_dataloaders=loader, ckpt_path=ckpt_path)


if __name__ == "__main__":
    main()
Error messages and logs
python main.py 2>/dev/null
[on_fit_start] epoch: 0
[on_train_epoch_start] epoch: 0
[state_dict]
[on_train_epoch_start] epoch: 1
[state_dict]
----
[load_state_dict] <- called here, so we know this run is resumed from a checkpoint
[on_fit_start] epoch: 0 <- epoch is still 0 here, yet training resumes from epoch 2
[on_train_epoch_start] epoch: 2
[state_dict]
[on_train_epoch_start] epoch: 3
[state_dict]
[on_train_epoch_start] epoch: 4
[state_dict]
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: 11.7
- Lightning:
- lightning-module-enhanced: 0.1
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.0.0
- torch: 2.0.0
- torchinfo: 1.7.0
- torchmetrics: 0.11.4
- torchvision: 0.15.1
- Packages:
- absl-py: 1.2.0
- aiohttp: 3.8.1
- aiosignal: 1.2.0
- alembic: 1.10.3
- antlr4-python3-runtime: 4.9.3
- anyio: 3.6.2
- appdirs: 1.4.4
- argon2-cffi: 21.3.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.2.3
- art: 5.8
- astroid: 2.15.5
- asttokens: 2.2.1
- astunparse: 1.6.3
- async-generator: 1.10
- async-timeout: 4.0.2
- attrs: 21.4.0
- audioread: 2.1.9
- automat: 22.10.0
- backcall: 0.2.0
- beautifulsoup4: 4.11.1
- bleach: 6.0.0
- cachetools: 5.2.0
- certifi: 2022.12.7
- cffi: 1.15.1
- charset-normalizer: 2.1.0
- cloudpickle: 2.2.1
- cmake: 3.26.3
- colorama: 0.4.6
- coloredlogs: 15.0.1
- comm: 0.1.2
- constantly: 15.1.0
- cryptography: 39.0.2
- cssselect: 1.2.0
- cycler: 0.11.0
- databricks-cli: 0.17.6
- debugpy: 1.6.6
- decorator: 5.1.1
- decord: 0.6.0
- defusedxml: 0.7.1
- dill: 0.3.5.1
- docker: 6.0.1
- entrypoints: 0.4
- exceptiongroup: 1.1.0
- executing: 1.2.0
- fastcore: 1.5.28
- fastjsonschema: 2.16.3
- ffmpeg-python: 0.2.0
- filelock: 3.7.1
- flask: 2.2.3
- flatbuffers: 1.12
- flow-vis: 0.1
- fonttools: 4.34.4
- fqdn: 1.5.1
- frozenlist: 1.3.0
- fsspec: 2022.5.0
- gast: 0.4.0
- gdown: 4.6.0
- gitdb: 4.0.10
- gitpython: 3.1.31
- google-auth: 2.9.1
- google-auth-oauthlib: 0.4.6
- google-pasta: 0.2.0
- graphviz: 0.20.1
- grpcio: 1.47.0
- gunicorn: 20.1.0
- h11: 0.14.0
- h5py: 3.7.0
- huggingface-hub: 0.12.0
- humanfriendly: 10.0
- hyperlink: 21.0.0
- idna: 3.4
- imageio: 2.19.5
- imageio-ffmpeg: 0.4.7
- incremental: 22.10.0
- iniconfig: 1.1.1
- ipykernel: 6.21.2
- ipython: 8.10.0
- ipython-genutils: 0.2.0
- ipywidgets: 8.0.4
- isoduration: 20.11.0
- isort: 5.10.1
- itemadapter: 0.7.0
- itemloaders: 1.0.6
- itsdangerous: 2.1.2
- jedi: 0.18.2
- jinja2: 3.1.2
- jmespath: 1.0.1
- joblib: 1.1.0
- jsonpointer: 2.3
- jsonschema: 4.17.3
- jupyter: 1.0.0
- jupyter-client: 8.0.3
- jupyter-console: 6.6.1
- jupyter-core: 5.2.0
- jupyter-events: 0.6.3
- jupyter-server: 2.3.0
- jupyter-server-terminals: 0.4.4
- jupyterlab-pygments: 0.2.2
- jupyterlab-widgets: 3.0.5
- keras: 2.9.0
- keras-preprocessing: 1.1.2
- kiwisolver: 1.4.4
- lazy-loader: 0.2
- lazy-object-proxy: 1.7.1
- libclang: 14.0.6
- librosa: 0.9.2
- lightning-module-enhanced: 0.1
- lightning-utilities: 0.8.0
- lit: 16.0.5
- llvmlite: 0.38.1
- lovely-numpy: 0.2.8
- lovely-tensors: 0.1.14
- lxml: 4.9.2
- lycon: 0.2.0
- markdown: 3.4.1
- markupsafe: 2.1.2
- matplotlib: 3.5.2
- matplotlib-inline: 0.1.6
- mccabe: 0.7.0
- mistune: 2.0.5
- mlflow: 2.2.2
- mpmath: 1.2.1
- multidict: 6.0.2
- natsort: 8.2.0
- nbclassic: 0.5.2
- nbclient: 0.7.2
- nbconvert: 7.2.9
- nbformat: 5.7.3
- nest-asyncio: 1.5.6
- networkx: 2.8.5
- notebook: 6.5.2
- notebook-shim: 0.2.2
- numba: 0.55.2
- numpy: 1.24.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- nwgraph: 0.2.1
- nwutils: 0.3.4
- omegaconf: 2.3.0
- onnx: 1.13.1
- onnxruntime: 1.14.0
- opencv-python: 4.6.0.66
- opt-einsum: 3.3.0
- outcome: 1.2.0
- overrides: 7.3.1
- packaging: 21.3
- pandarallel: 1.6.4
- pandas: 1.5.1
- pandocfilters: 1.5.0
- parsel: 1.7.0
- parso: 0.8.3
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.4.0
- pims: 0.6.1
- pip: 23.1.2
- platformdirs: 2.5.2
- pluggy: 1.0.0
- plyvel: 1.5.0
- pooch: 1.6.0
- pool-resources: 0.1
- prometheus-client: 0.16.0
- prompt-toolkit: 3.0.37
- protego: 0.2.1
- protobuf: 3.20.1
- psutil: 5.9.5
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py: 1.11.0
- pyarrow: 11.0.0
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycm: 3.8
- pycparser: 2.21
- pydeprecate: 0.3.2
- pydispatcher: 2.0.7
- pygments: 2.14.0
- pylint: 2.17.4
- pyopenssl: 23.0.0
- pyparsing: 3.0.9
- pyproj: 3.4.1
- pyrsistent: 0.19.3
- pysocks: 1.7.1
- pytest: 7.3.1
- python-dateutil: 2.8.2
- python-json-logger: 2.0.7
- pytorch-lightning: 2.0.0
- pytz: 2023.3
- pywavelets: 1.3.0
- pyyaml: 6.0
- pyzmq: 25.0.0
- qtconsole: 5.4.0
- qtpy: 2.3.0
- querystring-parser: 1.2.4
- queuelib: 1.6.2
- requests: 2.28.2
- requests-file: 1.5.1
- requests-oauthlib: 1.3.1
- resampy: 0.3.1
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rsa: 4.9
- scikit-image: 0.20.0
- scikit-learn: 1.1.1
- scipy: 1.8.1
- scrapy: 2.8.0
- selenium: 4.8.2
- send2trash: 1.8.0
- service-identity: 21.1.0
- setuptools: 65.6.3
- shap: 0.41.0
- six: 1.16.0
- slicer: 0.0.7
- slicerator: 1.1.0
- smmap: 5.0.0
- sniffio: 1.3.0
- sortedcontainers: 2.4.0
- soundfile: 0.10.3.post1
- soupsieve: 2.3.2.post1
- sqlalchemy: 2.0.9
- sqlparse: 0.4.3
- stack-data: 0.6.2
- sympy: 1.11.1
- tabulate: 0.9.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- tensorboardx: 2.5.1
- tensorflow: 2.9.0
- tensorflow-estimator: 2.9.0
- tensorflow-io-gcs-filesystem: 0.29.0
- termcolor: 2.1.1
- terminado: 0.17.1
- tf-estimator-nightly: 2.8.0.dev2021122109
- threadpoolctl: 3.1.0
- tifffile: 2022.5.4
- timm: 0.6.12
- tinycss2: 1.2.1
- tldextract: 3.4.0
- tomli: 2.0.1
- tomlkit: 0.11.1
- torch: 2.0.0
- torchinfo: 1.7.0
- torchmetrics: 0.11.4
- torchvision: 0.15.1
- tornado: 6.2
- tqdm: 4.65.0
- traitlets: 5.9.0
- transforms3d: 0.4.1
- trio: 0.22.0
- trio-websocket: 0.9.2
- triton: 2.0.0
- twisted: 22.10.0
- typing-extensions: 4.3.0
- typing-utils: 0.1.0
- uri-template: 1.2.0
- urllib3: 1.26.15
- w3lib: 2.1.1
- wcwidth: 0.2.6
- webcolors: 1.12
- webencodings: 0.5.1
- websocket-client: 1.5.1
- werkzeug: 2.2.3
- wheel: 0.37.1
- widgetsnbextension: 4.0.5
- wrapt: 1.14.1
- wsproto: 1.2.0
- yarl: 1.7.2
- yt-dlp: 2023.3.4
- zope.interface: 5.5.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.10.11
- release: 5.19.0-41-generic
- version: #42~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Apr 18 17:40:00 UTC 2
More info
No response
cc @lantiga