
Trainer's current_epoch is not updated at on_fit_start when loading from a checkpoint #17712

@Meehai

Bug description

When resuming from a checkpoint (using Trainer().fit(..., ckpt_path=X)), the trainer's current_epoch is not yet updated at on_fit_start. It only reflects the restored value later: by the first on_train_epoch_start it is already correct (see the logs below).

By the time on_fit_start runs, the callbacks' load_state_dict has already been called, so the trainer already knows that this run is resumed from a checkpoint rather than started from scratch.
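A minimal sketch of how that ordering could be used from inside a callback, assuming pytorch_lightning 2.x: the callback records that its load_state_dict was called and checks the flag in on_fit_start. The class name ResumeAwareCallback and the saved state are hypothetical, and the sketch relies on the ordering visible in the logs below ([load_state_dict] printed before [on_fit_start]), which may not hold for every strategy. The fallback base class exists only so the logic can be exercised without Lightning installed.

```python
# Sketch (assumes pytorch_lightning 2.x): detect a resumed run by recording
# that Lightning restored this callback's state before on_fit_start.
try:
    from pytorch_lightning import Callback
except ImportError:  # fallback only so the logic can be exercised without Lightning
    Callback = object


class ResumeAwareCallback(Callback):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.resumed = False  # set to True once state is restored from a checkpoint

    def state_dict(self):
        # Keep this non-empty: Lightning appears to skip empty callback state,
        # in which case load_state_dict would never be called on resume.
        return {"saved": True}

    def load_state_dict(self, state):
        self.resumed = True

    def on_fit_start(self, trainer, pl_module):
        print(f"[on_fit_start] resumed={self.resumed}, epoch={trainer.current_epoch}")
```

Creating a fresh callback instance per Trainer, as the repro does, keeps the flag from leaking from one run into the next.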

My current workaround for the stale current_epoch (I have custom logic in a callback that should only run when continuing or fine-tuning a previous training run) is to check whether trainer.ckpt_path is not None.
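The trainer.ckpt_path workaround can be sketched as follows, assuming pytorch_lightning 2.x, where Trainer.ckpt_path holds the checkpoint path passed to fit and is None for a fresh run. FineTuneOnlyCallback and is_resumed are hypothetical names; the fallback base class only exists so the check can be exercised without Lightning installed.

```python
# Sketch (assumes pytorch_lightning 2.x): gate resume-only logic on trainer.ckpt_path.
try:
    from pytorch_lightning import Callback
except ImportError:  # fallback only so the logic can be exercised without Lightning
    Callback = object


def is_resumed(trainer) -> bool:
    """Return True when fit() was given a checkpoint path (hypothetical helper)."""
    return getattr(trainer, "ckpt_path", None) is not None


class FineTuneOnlyCallback(Callback):  # hypothetical name
    def on_fit_start(self, trainer, pl_module):
        if is_resumed(trainer):
            # Placeholder for logic that should only run when continuing a run.
            print(f"[on_fit_start] resuming from {trainer.ckpt_path}")
```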

Is this the intended behavior, or should trainer.current_epoch already be > 0 at on_fit_start?
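Until this is clarified, one way to see the checkpoint's epoch already at on_fit_start is to read it from the checkpoint itself. A minimal sketch, assuming pytorch_lightning 2.x, where Callback.on_load_checkpoint(trainer, pl_module, checkpoint) receives the loaded checkpoint dictionary and Lightning checkpoints store the epoch at save time under the "epoch" key. CheckpointEpochCallback is a hypothetical name, and the sketch assumes this hook fires before on_fit_start, as load_state_dict does in the logs below.

```python
# Sketch (assumes pytorch_lightning 2.x): capture the saved epoch while the
# checkpoint is being restored, so on_fit_start can use it.
try:
    from pytorch_lightning import Callback
except ImportError:  # fallback only so the logic can be exercised without Lightning
    Callback = object


class CheckpointEpochCallback(Callback):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.ckpt_epoch = None  # epoch stored in the loaded checkpoint, if any

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # The "epoch" entry is the trainer's current_epoch when the checkpoint was saved.
        self.ckpt_epoch = checkpoint.get("epoch")

    def on_fit_start(self, trainer, pl_module):
        if self.ckpt_epoch is not None:
            print(f"[on_fit_start] checkpoint was saved during epoch {self.ckpt_epoch}")
```

The stored value is the epoch during which the checkpoint was written, so it is not necessarily the epoch that training resumes at.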

What version are you seeing the problem on?

v2.0

How to reproduce the bug

from pytorch_lightning import LightningModule, Callback, Trainer
from torch.utils.data import Dataset, DataLoader
import torch as tr
from torch import nn
from torch.nn import functional as F

class Reader(Dataset):
    def __getitem__(self, ix):
        return tr.randn(5), tr.randn(1)

    def __len__(self):
        return 5

class MyCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"[on_train_epoch_start] epoch: {trainer.current_epoch}")

    def on_fit_start(self, trainer, pl_module):
        print(f"[on_fit_start] epoch: {trainer.current_epoch}")

    def load_state_dict(self, state):
        print("[load_state_dict]")

    def state_dict(self):
        print("[state_dict]")
        return {1: 2}

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(5, 1)

    def forward(self, x):
        return tr.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Regression loss; cross_entropy over a single output column is always 0.
        loss = F.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return tr.optim.Adam(self.parameters(), lr=0.02)

def main():
    loader = DataLoader(Reader(), batch_size=5)
    model = LitModel()
    # First run: train from scratch for 2 epochs (the default ModelCheckpoint saves a checkpoint).
    Trainer(max_epochs=2, callbacks=[MyCallback()], enable_progress_bar=False, enable_model_summary=False) \
        .fit(model, train_dataloaders=loader)
    print("----")
    # Second run: resume from that checkpoint and continue training up to max_epochs=5.
    Trainer(max_epochs=5, callbacks=[MyCallback()], enable_progress_bar=False, enable_model_summary=False) \
        .fit(model, train_dataloaders=loader, ckpt_path=model.trainer.checkpoint_callback.best_model_path)

if __name__ == "__main__":
    main()

Error messages and logs

python main.py 2>/dev/null

[on_fit_start] epoch: 0
[on_train_epoch_start] epoch: 0
[state_dict]
[on_train_epoch_start] epoch: 1
[state_dict]
----
[load_state_dict] <- called here, so we know this run is resumed from a checkpoint
[on_fit_start] epoch: 0 <- epoch is still 0 here, although training then resumes at epoch 2
[on_train_epoch_start] epoch: 2
[state_dict]
[on_train_epoch_start] epoch: 3
[state_dict]
[on_train_epoch_start] epoch: 4
[state_dict]

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: 11.7
  • Lightning:
    • lightning-module-enhanced: 0.1
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.0
    • torch: 2.0.0
    • torchinfo: 1.7.0
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
  • Packages:
    • absl-py: 1.2.0
    • aiohttp: 3.8.1
    • aiosignal: 1.2.0
    • alembic: 1.10.3
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.6.2
    • appdirs: 1.4.4
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.2.3
    • art: 5.8
    • astroid: 2.15.5
    • asttokens: 2.2.1
    • astunparse: 1.6.3
    • async-generator: 1.10
    • async-timeout: 4.0.2
    • attrs: 21.4.0
    • audioread: 2.1.9
    • automat: 22.10.0
    • backcall: 0.2.0
    • beautifulsoup4: 4.11.1
    • bleach: 6.0.0
    • cachetools: 5.2.0
    • certifi: 2022.12.7
    • cffi: 1.15.1
    • charset-normalizer: 2.1.0
    • cloudpickle: 2.2.1
    • cmake: 3.26.3
    • colorama: 0.4.6
    • coloredlogs: 15.0.1
    • comm: 0.1.2
    • constantly: 15.1.0
    • cryptography: 39.0.2
    • cssselect: 1.2.0
    • cycler: 0.11.0
    • databricks-cli: 0.17.6
    • debugpy: 1.6.6
    • decorator: 5.1.1
    • decord: 0.6.0
    • defusedxml: 0.7.1
    • dill: 0.3.5.1
    • docker: 6.0.1
    • entrypoints: 0.4
    • exceptiongroup: 1.1.0
    • executing: 1.2.0
    • fastcore: 1.5.28
    • fastjsonschema: 2.16.3
    • ffmpeg-python: 0.2.0
    • filelock: 3.7.1
    • flask: 2.2.3
    • flatbuffers: 1.12
    • flow-vis: 0.1
    • fonttools: 4.34.4
    • fqdn: 1.5.1
    • frozenlist: 1.3.0
    • fsspec: 2022.5.0
    • gast: 0.4.0
    • gdown: 4.6.0
    • gitdb: 4.0.10
    • gitpython: 3.1.31
    • google-auth: 2.9.1
    • google-auth-oauthlib: 0.4.6
    • google-pasta: 0.2.0
    • graphviz: 0.20.1
    • grpcio: 1.47.0
    • gunicorn: 20.1.0
    • h11: 0.14.0
    • h5py: 3.7.0
    • huggingface-hub: 0.12.0
    • humanfriendly: 10.0
    • hyperlink: 21.0.0
    • idna: 3.4
    • imageio: 2.19.5
    • imageio-ffmpeg: 0.4.7
    • incremental: 22.10.0
    • iniconfig: 1.1.1
    • ipykernel: 6.21.2
    • ipython: 8.10.0
    • ipython-genutils: 0.2.0
    • ipywidgets: 8.0.4
    • isoduration: 20.11.0
    • isort: 5.10.1
    • itemadapter: 0.7.0
    • itemloaders: 1.0.6
    • itsdangerous: 2.1.2
    • jedi: 0.18.2
    • jinja2: 3.1.2
    • jmespath: 1.0.1
    • joblib: 1.1.0
    • jsonpointer: 2.3
    • jsonschema: 4.17.3
    • jupyter: 1.0.0
    • jupyter-client: 8.0.3
    • jupyter-console: 6.6.1
    • jupyter-core: 5.2.0
    • jupyter-events: 0.6.3
    • jupyter-server: 2.3.0
    • jupyter-server-terminals: 0.4.4
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-widgets: 3.0.5
    • keras: 2.9.0
    • keras-preprocessing: 1.1.2
    • kiwisolver: 1.4.4
    • lazy-loader: 0.2
    • lazy-object-proxy: 1.7.1
    • libclang: 14.0.6
    • librosa: 0.9.2
    • lightning-module-enhanced: 0.1
    • lightning-utilities: 0.8.0
    • lit: 16.0.5
    • llvmlite: 0.38.1
    • lovely-numpy: 0.2.8
    • lovely-tensors: 0.1.14
    • lxml: 4.9.2
    • lycon: 0.2.0
    • markdown: 3.4.1
    • markupsafe: 2.1.2
    • matplotlib: 3.5.2
    • matplotlib-inline: 0.1.6
    • mccabe: 0.7.0
    • mistune: 2.0.5
    • mlflow: 2.2.2
    • mpmath: 1.2.1
    • multidict: 6.0.2
    • natsort: 8.2.0
    • nbclassic: 0.5.2
    • nbclient: 0.7.2
    • nbconvert: 7.2.9
    • nbformat: 5.7.3
    • nest-asyncio: 1.5.6
    • networkx: 2.8.5
    • notebook: 6.5.2
    • notebook-shim: 0.2.2
    • numba: 0.55.2
    • numpy: 1.24.2
    • nvidia-cublas-cu11: 11.10.3.66
    • nvidia-cuda-cupti-cu11: 11.7.101
    • nvidia-cuda-nvrtc-cu11: 11.7.99
    • nvidia-cuda-runtime-cu11: 11.7.99
    • nvidia-cudnn-cu11: 8.5.0.96
    • nvidia-cufft-cu11: 10.9.0.58
    • nvidia-curand-cu11: 10.2.10.91
    • nvidia-cusolver-cu11: 11.4.0.1
    • nvidia-cusparse-cu11: 11.7.4.91
    • nvidia-nccl-cu11: 2.14.3
    • nvidia-nvtx-cu11: 11.7.91
    • nwgraph: 0.2.1
    • nwutils: 0.3.4
    • omegaconf: 2.3.0
    • onnx: 1.13.1
    • onnxruntime: 1.14.0
    • opencv-python: 4.6.0.66
    • opt-einsum: 3.3.0
    • outcome: 1.2.0
    • overrides: 7.3.1
    • packaging: 21.3
    • pandarallel: 1.6.4
    • pandas: 1.5.1
    • pandocfilters: 1.5.0
    • parsel: 1.7.0
    • parso: 0.8.3
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.4.0
    • pims: 0.6.1
    • pip: 23.1.2
    • platformdirs: 2.5.2
    • pluggy: 1.0.0
    • plyvel: 1.5.0
    • pooch: 1.6.0
    • pool-resources: 0.1
    • prometheus-client: 0.16.0
    • prompt-toolkit: 3.0.37
    • protego: 0.2.1
    • protobuf: 3.20.1
    • psutil: 5.9.5
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • py: 1.11.0
    • pyarrow: 11.0.0
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycm: 3.8
    • pycparser: 2.21
    • pydeprecate: 0.3.2
    • pydispatcher: 2.0.7
    • pygments: 2.14.0
    • pylint: 2.17.4
    • pyopenssl: 23.0.0
    • pyparsing: 3.0.9
    • pyproj: 3.4.1
    • pyrsistent: 0.19.3
    • pysocks: 1.7.1
    • pytest: 7.3.1
    • python-dateutil: 2.8.2
    • python-json-logger: 2.0.7
    • pytorch-lightning: 2.0.0
    • pytz: 2023.3
    • pywavelets: 1.3.0
    • pyyaml: 6.0
    • pyzmq: 25.0.0
    • qtconsole: 5.4.0
    • qtpy: 2.3.0
    • querystring-parser: 1.2.4
    • queuelib: 1.6.2
    • requests: 2.28.2
    • requests-file: 1.5.1
    • requests-oauthlib: 1.3.1
    • resampy: 0.3.1
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rsa: 4.9
    • scikit-image: 0.20.0
    • scikit-learn: 1.1.1
    • scipy: 1.8.1
    • scrapy: 2.8.0
    • selenium: 4.8.2
    • send2trash: 1.8.0
    • service-identity: 21.1.0
    • setuptools: 65.6.3
    • shap: 0.41.0
    • six: 1.16.0
    • slicer: 0.0.7
    • slicerator: 1.1.0
    • smmap: 5.0.0
    • sniffio: 1.3.0
    • sortedcontainers: 2.4.0
    • soundfile: 0.10.3.post1
    • soupsieve: 2.3.2.post1
    • sqlalchemy: 2.0.9
    • sqlparse: 0.4.3
    • stack-data: 0.6.2
    • sympy: 1.11.1
    • tabulate: 0.9.0
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.1
    • tensorboardx: 2.5.1
    • tensorflow: 2.9.0
    • tensorflow-estimator: 2.9.0
    • tensorflow-io-gcs-filesystem: 0.29.0
    • termcolor: 2.1.1
    • terminado: 0.17.1
    • tf-estimator-nightly: 2.8.0.dev2021122109
    • threadpoolctl: 3.1.0
    • tifffile: 2022.5.4
    • timm: 0.6.12
    • tinycss2: 1.2.1
    • tldextract: 3.4.0
    • tomli: 2.0.1
    • tomlkit: 0.11.1
    • torch: 2.0.0
    • torchinfo: 1.7.0
    • torchmetrics: 0.11.4
    • torchvision: 0.15.1
    • tornado: 6.2
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • transforms3d: 0.4.1
    • trio: 0.22.0
    • trio-websocket: 0.9.2
    • triton: 2.0.0
    • twisted: 22.10.0
    • typing-extensions: 4.3.0
    • typing-utils: 0.1.0
    • uri-template: 1.2.0
    • urllib3: 1.26.15
    • w3lib: 2.1.1
    • wcwidth: 0.2.6
    • webcolors: 1.12
    • webencodings: 0.5.1
    • websocket-client: 1.5.1
    • werkzeug: 2.2.3
    • wheel: 0.37.1
    • widgetsnbextension: 4.0.5
    • wrapt: 1.14.1
    • wsproto: 1.2.0
    • yarl: 1.7.2
    • yt-dlp: 2023.3.4
    • zope.interface: 5.5.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.11
    • release: 5.19.0-41-generic
    • version: #42~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Apr 18 17:40:00 UTC 2

More info

No response

cc @lantiga
