If "Trainer.fit stopped:" is encountered during pytorch_lightning.tuner.lr_finder.lr_find the following train epoch is skipped #15603

@DominikSpiljak

Bug description

I'm trying to use the LearningRateFinder before every training epoch. I'm basing my LearningRateFinder on the snippet used in these release notes. After the first LearningRateFinder call, the training epoch runs normally, but after every subsequent LearningRateFinder call, the info message "Trainer.fit stopped: max_steps=100 reached" is shown and neither the LearningRateFinder call nor that training epoch is run.

How to reproduce the bug

I managed to reproduce the bug using this snippet. The script runs for 10 epochs, and LearningRateFinder is run before each of the first 5 epochs. Notice how the first epoch runs normally, the next 4 are skipped, and then training continues normally.

import os

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateFinder
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.learning_rate = 0.1

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=self.learning_rate)


class FineTuneLearningRateFinder(LearningRateFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        return

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.lr_find(trainer, pl_module)


def run():
    num_samples = 10000

    train_data = DataLoader(RandomDataset(32, num_samples), batch_size=2)
    val_data = DataLoader(RandomDataset(32, num_samples), batch_size=2)
    test_data = DataLoader(RandomDataset(32, num_samples), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=100,
        limit_val_batches=100,
        limit_test_batches=100,
        num_sanity_val_steps=2,
        enable_model_summary=False,
        max_epochs=10,
        callbacks=[
            FineTuneLearningRateFinder(
                milestones=list(range(5)), early_stop_threshold=None, update_attr=True
            )
        ],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
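
For reference, the condition in on_train_epoch_start can be checked in isolation, without Lightning. This is only a sketch: the helper name epochs_that_trigger_lr_find is made up for illustration and is not part of pytorch_lightning. It lists the epochs at which the callback above would call lr_find for the settings used in run().

```python
def epochs_that_trigger_lr_find(max_epochs, milestones):
    """Return the epochs for which the callback's condition is true.

    Hypothetical helper: it mirrors the check
    `trainer.current_epoch in self.milestones or trainer.current_epoch == 0`
    from FineTuneLearningRateFinder.on_train_epoch_start.
    """
    return [
        epoch
        for epoch in range(max_epochs)
        if epoch in milestones or epoch == 0
    ]


# Same settings as in run(): max_epochs=10, milestones=list(range(5)).
triggered = epochs_that_trigger_lr_find(max_epochs=10, milestones=list(range(5)))
print(triggered)
```

So lr_find is expected to run at the start of epochs 0 through 4, and epochs 5 through 9 should train without it.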

Error messages and logs

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/setup.py:178: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.
  category=PossibleUserWarning,
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/lightning_logs

Epoch 9: 100%
200/200 [00:01<00:00, 106.58it/s, loss=-533, v_num=0]
Finding best initial lr: 100%
100/100 [00:00<00:00, 390.98it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
INFO:pytorch_lightning.tuner.lr_finder:Learning rate set to 0.8317637711026709
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_668bc567-49f5-4a03-8b0b-811a58609110.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_668bc567-49f5-4a03-8b0b-811a58609110.ckpt

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_a774da07-3d9f-4cbd-9207-b232cfe82ece.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_a774da07-3d9f-4cbd-9207-b232cfe82ece.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_20dd38bb-6035-4ab1-b538-73a83d381164.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_20dd38bb-6035-4ab1-b538-73a83d381164.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_f0d63158-d375-4554-a1bf-560ad31ce4f0.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_f0d63158-d375-4554-a1bf-560ad31ce4f0.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_b03c5be5-0a79-41e9-a9f5-e8909b628461.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_b03c5be5-0a79-41e9-a9f5-e8909b628461.ckpt

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.

Testing DataLoader 0: 100%
100/100 [00:00<00:00, 155.45it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           -471.1915588378906
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
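
To make the pattern in the log easier to see, here is a small tally of the relevant messages. It is just a sketch: the lines are shortened copies of the INFO/ERROR messages above, and the tally function and marker names are made up for illustration.

```python
# Shortened copies of the messages from the log above, in order.
LOG_LINES = (
    ["`Trainer.fit` stopped: `max_steps=100` reached.",
     "Learning rate set to 0.8317637711026709"]
    + ["`Trainer.fit` stopped: `max_steps=100` reached.",
       "Failed to compute suggestion for learning rate"] * 4
    + ["`Trainer.fit` stopped: `max_epochs=10` reached."]
)

# Substrings used to classify each message (labels are illustrative).
MARKERS = {
    "lr_find_stopped": "`max_steps=100` reached",
    "lr_suggested": "Learning rate set to",
    "lr_failed": "Failed to compute suggestion",
}


def tally(lines):
    """Count how many lines contain each marker substring."""
    return {
        name: sum(marker in line for line in lines)
        for name, marker in MARKERS.items()
    }


counts = tally(LOG_LINES)
print(counts)
```

The max_steps message appears five times, once per lr_find call, but only the first call produces a learning rate suggestion. The other four fail, which matches the four skipped epochs.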

Environment

* CUDA:
	- GPU:
		- Tesla T4
	- available:         True
	- version:           11.3
* Lightning:
	- lightning-lite:    1.8.0.post1
	- lightning-utilities: 0.3.0
	- pytorch-lightning: 1.8.0.post1
	- torch:             1.12.1+cu113
	- torchaudio:        0.12.1+cu113
	- torchmetrics:      0.10.2
	- torchsummary:      1.5.1
	- torchtext:         0.13.1
	- torchvision:       0.13.1+cu113
* Packages:
	- absl-py:           1.3.0
	- aeppl:             0.0.33
	- aesara:            2.7.9
	- aiohttp:           3.8.3
	- aiosignal:         1.2.0
	- alabaster:         0.7.12
	- albumentations:    1.2.1
	- altair:            4.2.0
	- appdirs:           1.4.4
	- arviz:             0.12.1
	- astor:             0.8.1
	- astropy:           4.3.1
	- astunparse:        1.6.3
	- async-timeout:     4.0.2
	- asynctest:         0.13.0
	- atari-py:          0.2.9
	- atomicwrites:      1.4.1
	- attrs:             22.1.0
	- audioread:         3.0.0
	- autograd:          1.5
	- babel:             2.11.0
	- backcall:          0.2.0
	- beautifulsoup4:    4.6.3
	- bleach:            5.0.1
	- blis:              0.7.9
	- bokeh:             2.3.3
	- branca:            0.5.0
	- bs4:               0.0.1
	- cachecontrol:      0.12.11
	- cached-property:   1.5.2
	- cachetools:        4.2.4
	- catalogue:         2.0.8
	- certifi:           2022.9.24
	- cffi:              1.15.1
	- cftime:            1.6.2
	- chardet:           3.0.4
	- charset-normalizer: 2.1.1
	- click:             7.1.2
	- clikit:            0.6.2
	- cloudpickle:       1.5.0
	- cmake:             3.22.6
	- cmdstanpy:         1.0.8
	- colorcet:          3.0.1
	- colorlover:        0.3.0
	- community:         1.0.0b1
	- confection:        0.0.3
	- cons:              0.4.5
	- contextlib2:       0.5.5
	- convertdate:       2.4.0
	- crashtest:         0.3.1
	- crcmod:            1.7
	- cufflinks:         0.17.3
	- cupy-cuda11x:      11.0.0
	- cvxopt:            1.3.0
	- cvxpy:             1.2.1
	- cycler:            0.11.0
	- cymem:             2.0.7
	- cython:            0.29.32
	- daft:              0.0.4
	- dask:              2022.2.0
	- datascience:       0.17.5
	- debugpy:           1.0.0
	- decorator:         4.4.2
	- defusedxml:        0.7.1
	- descartes:         1.1.0
	- dill:              0.3.6
	- distributed:       2022.2.0
	- dlib:              19.24.0
	- dm-tree:           0.1.7
	- dnspython:         2.2.1
	- docutils:          0.17.1
	- dopamine-rl:       1.0.5
	- earthengine-api:   0.1.330
	- easydict:          1.10
	- ecos:              2.0.10
	- editdistance:      0.5.3
	- en-core-web-sm:    3.4.1
	- entrypoints:       0.4
	- ephem:             4.1.3
	- et-xmlfile:        1.1.0
	- etils:             0.9.0
	- etuples:           0.3.8
	- fa2:               0.3.5
	- fastai:            2.7.10
	- fastcore:          1.5.27
	- fastdownload:      0.0.7
	- fastdtw:           0.3.4
	- fastjsonschema:    2.16.2
	- fastprogress:      1.0.3
	- fastrlock:         0.8.1
	- feather-format:    0.4.1
	- filelock:          3.8.0
	- fire:              0.4.0
	- firebase-admin:    4.4.0
	- fix-yahoo-finance: 0.0.22
	- flask:             1.1.4
	- flatbuffers:       1.12
	- folium:            0.12.1.post1
	- frozenlist:        1.3.1
	- fsspec:            2022.10.0
	- future:            0.16.0
	- gast:              0.4.0
	- gdal:              2.2.2
	- gdown:             4.4.0
	- gensim:            3.6.0
	- geographiclib:     1.52
	- geopy:             1.17.0
	- gin-config:        0.5.0
	- glob2:             0.7
	- google:            2.0.3
	- google-api-core:   1.31.6
	- google-api-python-client: 1.12.11
	- google-auth:       1.35.0
	- google-auth-httplib2: 0.0.4
	- google-auth-oauthlib: 0.4.6
	- google-cloud-bigquery: 1.21.0
	- google-cloud-bigquery-storage: 1.1.2
	- google-cloud-core: 1.0.3
	- google-cloud-datastore: 1.8.0
	- google-cloud-firestore: 1.7.0
	- google-cloud-language: 1.2.0
	- google-cloud-storage: 1.18.1
	- google-cloud-translate: 1.5.0
	- google-colab:      1.0.0
	- google-pasta:      0.2.0
	- google-resumable-media: 0.4.1
	- googleapis-common-protos: 1.56.4
	- googledrivedownloader: 0.4
	- graphviz:          0.10.1
	- greenlet:          2.0.0.post0
	- grpcio:            1.50.0
	- gspread:           3.4.2
	- gspread-dataframe: 3.0.8
	- gym:               0.25.2
	- gym-notices:       0.0.8
	- h5py:              3.1.0
	- heapdict:          1.0.1
	- hijri-converter:   2.2.4
	- holidays:          0.16
	- holoviews:         1.14.9
	- html5lib:          1.0.1
	- httpimport:        0.5.18
	- httplib2:          0.17.4
	- httpstan:          4.6.1
	- humanize:          0.5.1
	- hyperopt:          0.1.2
	- idna:              2.10
	- imageio:           2.9.0
	- imagesize:         1.4.1
	- imbalanced-learn:  0.8.1
	- imblearn:          0.0
	- imgaug:            0.4.0
	- importlib-metadata: 4.13.0
	- importlib-resources: 5.10.0
	- imutils:           0.5.4
	- inflect:           2.1.0
	- intel-openmp:      2022.2.0
	- intervaltree:      2.1.0
	- ipykernel:         5.3.4
	- ipython:           7.9.0
	- ipython-genutils:  0.2.0
	- ipython-sql:       0.3.9
	- ipywidgets:        7.7.1
	- itsdangerous:      1.1.0
	- jax:               0.3.23
	- jaxlib:            0.3.22+cuda11.cudnn805
	- jieba:             0.42.1
	- jinja2:            2.11.3
	- joblib:            1.2.0
	- jpeg4py:           0.1.4
	- jsonschema:        4.3.3
	- jupyter-client:    6.1.12
	- jupyter-console:   6.1.0
	- jupyter-core:      4.11.2
	- jupyterlab-widgets: 3.0.3
	- kaggle:            1.5.12
	- kapre:             0.3.7
	- keras:             2.9.0
	- keras-preprocessing: 1.1.2
	- keras-vis:         0.4.1
	- kiwisolver:        1.4.4
	- korean-lunar-calendar: 0.3.1
	- langcodes:         3.3.0
	- libclang:          14.0.6
	- librosa:           0.8.1
	- lightgbm:          2.2.3
	- lightning-lite:    1.8.0.post1
	- lightning-utilities: 0.3.0
	- llvmlite:          0.39.1
	- lmdb:              0.99
	- locket:            1.0.0
	- logical-unification: 0.4.5
	- lunarcalendar:     0.0.9
	- lxml:              4.9.1
	- markdown:          3.4.1
	- markupsafe:        2.0.1
	- marshmallow:       3.18.0
	- matplotlib:        3.2.2
	- matplotlib-venn:   0.11.7
	- minikanren:        1.0.3
	- missingno:         0.5.1
	- mistune:           0.8.4
	- mizani:            0.7.3
	- mkl:               2019.0
	- mlxtend:           0.14.0
	- more-itertools:    9.0.0
	- moviepy:           0.2.3.5
	- mpmath:            1.2.1
	- msgpack:           1.0.4
	- multidict:         6.0.2
	- multipledispatch:  0.6.0
	- multitasking:      0.0.11
	- murmurhash:        1.0.9
	- music21:           5.5.0
	- natsort:           5.5.0
	- nbconvert:         5.6.1
	- nbformat:          5.7.0
	- netcdf4:           1.6.1
	- networkx:          2.6.3
	- nibabel:           3.0.2
	- nltk:              3.7
	- notebook:          5.7.16
	- numba:             0.56.4
	- numexpr:           2.8.4
	- numpy:             1.21.6
	- oauth2client:      4.1.3
	- oauthlib:          3.2.2
	- okgrade:           0.4.3
	- opencv-contrib-python: 4.6.0.66
	- opencv-python:     4.6.0.66
	- opencv-python-headless: 4.6.0.66
	- openpyxl:          3.0.10
	- opt-einsum:        3.3.0
	- osqp:              0.6.2.post0
	- packaging:         21.3
	- palettable:        3.3.0
	- pandas:            1.3.5
	- pandas-datareader: 0.9.0
	- pandas-gbq:        0.13.3
	- pandas-profiling:  1.4.1
	- pandocfilters:     1.5.0
	- panel:             0.12.1
	- param:             1.12.2
	- parso:             0.8.3
	- partd:             1.3.0
	- pastel:            0.2.1
	- pathlib:           1.0.1
	- pathy:             0.6.2
	- patsy:             0.5.3
	- pep517:            0.13.0
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            7.1.2
	- pip:               21.1.3
	- pip-tools:         6.2.0
	- plotly:            5.5.0
	- plotnine:          0.8.0
	- pluggy:            0.7.1
	- pooch:             1.6.0
	- portpicker:        1.3.9
	- prefetch-generator: 1.0.1
	- preshed:           3.0.8
	- prettytable:       3.5.0
	- progressbar2:      3.38.0
	- prometheus-client: 0.15.0
	- promise:           2.3
	- prompt-toolkit:    2.0.10
	- prophet:           1.1.1
	- protobuf:          3.17.3
	- psutil:            5.4.8
	- psycopg2:          2.9.5
	- ptyprocess:        0.7.0
	- py:                1.11.0
	- pyarrow:           6.0.1
	- pyasn1:            0.4.8
	- pyasn1-modules:    0.2.8
	- pycocotools:       2.0.6
	- pycparser:         2.21
	- pyct:              0.4.8
	- pydantic:          1.10.2
	- pydata-google-auth: 1.4.0
	- pydot:             1.3.0
	- pydot-ng:          2.0.0
	- pydotplus:         2.0.2
	- pydrive:           1.3.1
	- pyemd:             0.5.1
	- pyerfa:            2.0.0.1
	- pygments:          2.6.1
	- pygobject:         3.26.1
	- pylev:             1.4.0
	- pymc:              4.1.4
	- pymeeus:           0.5.11
	- pymongo:           4.3.2
	- pymystem3:         0.2.0
	- pyopengl:          3.1.6
	- pyparsing:         3.0.9
	- pyrsistent:        0.19.2
	- pysimdjson:        3.2.0
	- pysndfile:         1.3.8
	- pysocks:           1.7.1
	- pystan:            3.3.0
	- pytest:            3.6.4
	- python-apt:        0.0.0
	- python-dateutil:   2.8.2
	- python-louvain:    0.16
	- python-slugify:    6.1.2
	- python-utils:      3.4.5
	- pytorch-lightning: 1.8.0.post1
	- pytz:              2022.6
	- pyviz-comms:       2.2.1
	- pywavelets:        1.3.0
	- pyyaml:            6.0
	- pyzmq:             23.2.1
	- qdldl:             0.1.5.post2
	- qudida:            0.0.4
	- regex:             2022.6.2
	- requests:          2.23.0
	- requests-oauthlib: 1.3.1
	- resampy:           0.4.2
	- rpy2:              3.5.5
	- rsa:               4.9
	- scikit-image:      0.18.3
	- scikit-learn:      1.0.2
	- scipy:             1.7.3
	- screen-resolution-extra: 0.0.0
	- scs:               3.2.2
	- seaborn:           0.11.2
	- send2trash:        1.8.0
	- setuptools:        57.4.0
	- setuptools-git:    1.2
	- shapely:           1.8.5.post1
	- six:               1.15.0
	- sklearn-pandas:    1.8.0
	- smart-open:        5.2.1
	- snowballstemmer:   2.2.0
	- sortedcontainers:  2.4.0
	- soundfile:         0.11.0
	- spacy:             3.4.2
	- spacy-legacy:      3.0.10
	- spacy-loggers:     1.0.3
	- sphinx:            1.8.6
	- sphinxcontrib-serializinghtml: 1.1.5
	- sphinxcontrib-websupport: 1.2.4
	- sqlalchemy:        1.4.42
	- sqlparse:          0.4.3
	- srsly:             2.4.5
	- statsmodels:       0.12.2
	- sympy:             1.7.1
	- tables:            3.7.0
	- tabulate:          0.8.10
	- tblib:             1.7.0
	- tenacity:          8.1.0
	- tensorboard:       2.9.1
	- tensorboard-data-server: 0.6.1
	- tensorboard-plugin-wit: 1.8.1
	- tensorflow:        2.9.2
	- tensorflow-datasets: 4.6.0
	- tensorflow-estimator: 2.9.0
	- tensorflow-gcs-config: 2.9.1
	- tensorflow-hub:    0.12.0
	- tensorflow-io-gcs-filesystem: 0.27.0
	- tensorflow-metadata: 1.10.0
	- tensorflow-probability: 0.17.0
	- termcolor:         2.1.0
	- terminado:         0.13.3
	- testpath:          0.6.0
	- text-unidecode:    1.3
	- textblob:          0.15.3
	- thinc:             8.1.5
	- threadpoolctl:     3.1.0
	- tifffile:          2021.11.2
	- toml:              0.10.2
	- tomli:             2.0.1
	- toolz:             0.12.0
	- torch:             1.12.1+cu113
	- torchaudio:        0.12.1+cu113
	- torchmetrics:      0.10.2
	- torchsummary:      1.5.1
	- torchtext:         0.13.1
	- torchvision:       0.13.1+cu113
	- tornado:           6.0.4
	- tqdm:              4.64.1
	- traitlets:         5.1.1
	- tweepy:            3.10.0
	- typeguard:         2.7.1
	- typer:             0.4.2
	- typing-extensions: 4.1.1
	- tzlocal:           1.5.1
	- uritemplate:       3.0.1
	- urllib3:           1.24.3
	- vega-datasets:     0.9.0
	- wasabi:            0.10.1
	- wcwidth:           0.2.5
	- webargs:           8.2.0
	- webencodings:      0.5.1
	- werkzeug:          1.0.1
	- wheel:             0.38.1
	- widgetsnbextension: 3.6.1
	- wordcloud:         1.8.2.2
	- wrapt:             1.14.1
	- xarray:            0.20.2
	- xarray-einstats:   0.2.2
	- xgboost:           0.90
	- xkit:              0.0.0
	- xlrd:              1.1.0
	- xlwt:              1.3.0
	- yarl:              1.8.1
	- yellowbrick:       1.5
	- zict:              2.2.0
	- zipp:              3.10.0
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- 
	- processor:         x86_64
	- python:            3.7.15
	- version:           #1 SMP Fri Aug 26 08:44:51 UTC 2022

More info

No response

cc @awaelchli

Labels: bug, callback
