Closed
Description
Bug description
I'm trying to run the LearningRateFinder before every training epoch. My LearningRateFinder is based on the snippet used in these release notes. After the first LearningRateFinder call, the training epoch runs normally, but after every subsequent LearningRateFinder call the info message "Trainer.fit stopped: max_steps=100 reached" is shown and neither the LearningRateFinder sweep nor that training epoch is run.
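One possible reading of the symptom, sketched below in plain Python with no Lightning imports: if the finder's step limit is an absolute budget (100) rather than a budget relative to the steps already taken, the sweep has nothing left to run once the first training epoch has advanced the step counter to 100. The function `run_budgeted_loop` and the constant `NUM_TRAINING` are hypothetical names used only for illustration, and whether Lightning 1.8.0 actually sets the limit this way is an assumption, not something confirmed in this report.

```python
def run_budgeted_loop(global_step, max_steps):
    """Toy stand-in for a training loop that stops at an absolute step budget.

    Returns how many steps were actually executed.
    """
    steps_run = 0
    while global_step < max_steps:
        global_step += 1
        steps_run += 1
    return steps_run


# Sweep length; chosen to match the `max_steps=100` message in the log.
NUM_TRAINING = 100

# Before epoch 0 no steps have been taken, so the sweep runs in full.
first_call = run_budgeted_loop(global_step=0, max_steps=NUM_TRAINING)

# Before epoch 1 the step counter is already 100 (limit_train_batches=100),
# so an absolute budget of 100 is exhausted immediately.
later_call = run_budgeted_loop(global_step=100, max_steps=NUM_TRAINING)

# A budget relative to the current step would still allow a full sweep.
relative_call = run_budgeted_loop(global_step=100, max_steps=100 + NUM_TRAINING)

print(first_call, later_call, relative_call)
```

A sweep that runs zero steps would also be consistent with the "not enough points" error that follows each skipped call in the log.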
How to reproduce the bug
I managed to reproduce the bug using this snippet. The script runs for 10 epochs, and the LearningRateFinder is run before each of the first 5 epochs. Notice how the first epoch runs normally, the next 4 are skipped, and then training continues normally.
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateFinder
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.learning_rate = 0.1

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=self.learning_rate)


class FineTuneLearningRateFinder(LearningRateFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        return

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.lr_find(trainer, pl_module)


def run():
    num_samples = 10000
    train_data = DataLoader(RandomDataset(32, num_samples), batch_size=2)
    val_data = DataLoader(RandomDataset(32, num_samples), batch_size=2)
    test_data = DataLoader(RandomDataset(32, num_samples), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=100,
        limit_val_batches=100,
        limit_test_batches=100,
        num_sanity_val_steps=2,
        enable_model_summary=False,
        max_epochs=10,
        callbacks=[
            FineTuneLearningRateFinder(
                milestones=list(range(5)), early_stop_threshold=None, update_attr=True
            )
        ],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
Error messages and logs
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/setup.py:178: PossibleUserWarning: GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.
category=PossibleUserWarning,
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/lightning_logs
Epoch 9: 100% 200/200 [00:01<00:00, 106.58it/s, loss=-533, v_num=0]
Finding best initial lr: 100% 100/100 [00:00<00:00, 390.98it/s]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
INFO:pytorch_lightning.tuner.lr_finder:Learning rate set to 0.8317637711026709
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_668bc567-49f5-4a03-8b0b-811a58609110.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_668bc567-49f5-4a03-8b0b-811a58609110.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_a774da07-3d9f-4cbd-9207-b232cfe82ece.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_a774da07-3d9f-4cbd-9207-b232cfe82ece.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_20dd38bb-6035-4ab1-b538-73a83d381164.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_20dd38bb-6035-4ab1-b538-73a83d381164.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_f0d63158-d375-4554-a1bf-560ad31ce4f0.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_f0d63158-d375-4554-a1bf-560ad31ce4f0.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=100` reached.
ERROR:pytorch_lightning.tuner.lr_finder:Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/.lr_find_b03c5be5-0a79-41e9-a9f5-e8909b628461.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Restored all states from the checkpoint file at /content/.lr_find_b03c5be5-0a79-41e9-a9f5-e8909b628461.ckpt
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
Testing DataLoader 0: 100% 100/100 [00:00<00:00, 155.45it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_loss -471.1915588378906
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
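As a side check on the log above, the value in "Learning rate set to 0.8317637711026709" can be reproduced as a point on an exponential sweep. The sketch below assumes a sweep from 1e-8 to 1 over 100 steps, which are taken here to be the callback's defaults since the reproduction script does not override them. The helper `exponential_lr` and the choice of step index 99 are hypothetical and only show that the logged number sits near the top of such a sweep. This is not Lightning's own code.

```python
import math

# Assumed sweep settings (not overridden in the reproduction script).
MIN_LR = 1e-8
MAX_LR = 1.0
NUM_TRAINING = 100


def exponential_lr(step, min_lr=MIN_LR, max_lr=MAX_LR, num_training=NUM_TRAINING):
    """Learning rate at `step` of a geometric sweep from min_lr to max_lr."""
    fraction = step / num_training
    return min_lr * (max_lr / min_lr) ** fraction


# Value at 99% of the way through the sweep, i.e. 10 ** -0.08.
lr_at_99 = exponential_lr(99)
print(lr_at_99)

# Compare against the value reported in the log.
print(math.isclose(lr_at_99, 0.8317637711026709, rel_tol=1e-9))
```

If these assumptions hold, the first (successful) sweep suggested a learning rate almost at its upper bound of 1.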
Environment
* CUDA:
- GPU:
- Tesla T4
- available: True
- version: 11.3
* Lightning:
- lightning-lite: 1.8.0.post1
- lightning-utilities: 0.3.0
- pytorch-lightning: 1.8.0.post1
- torch: 1.12.1+cu113
- torchaudio: 0.12.1+cu113
- torchmetrics: 0.10.2
- torchsummary: 1.5.1
- torchtext: 0.13.1
- torchvision: 0.13.1+cu113
* Packages:
- absl-py: 1.3.0
- aeppl: 0.0.33
- aesara: 2.7.9
- aiohttp: 3.8.3
- aiosignal: 1.2.0
- alabaster: 0.7.12
- albumentations: 1.2.1
- altair: 4.2.0
- appdirs: 1.4.4
- arviz: 0.12.1
- astor: 0.8.1
- astropy: 4.3.1
- astunparse: 1.6.3
- async-timeout: 4.0.2
- asynctest: 0.13.0
- atari-py: 0.2.9
- atomicwrites: 1.4.1
- attrs: 22.1.0
- audioread: 3.0.0
- autograd: 1.5
- babel: 2.11.0
- backcall: 0.2.0
- beautifulsoup4: 4.6.3
- bleach: 5.0.1
- blis: 0.7.9
- bokeh: 2.3.3
- branca: 0.5.0
- bs4: 0.0.1
- cachecontrol: 0.12.11
- cached-property: 1.5.2
- cachetools: 4.2.4
- catalogue: 2.0.8
- certifi: 2022.9.24
- cffi: 1.15.1
- cftime: 1.6.2
- chardet: 3.0.4
- charset-normalizer: 2.1.1
- click: 7.1.2
- clikit: 0.6.2
- cloudpickle: 1.5.0
- cmake: 3.22.6
- cmdstanpy: 1.0.8
- colorcet: 3.0.1
- colorlover: 0.3.0
- community: 1.0.0b1
- confection: 0.0.3
- cons: 0.4.5
- contextlib2: 0.5.5
- convertdate: 2.4.0
- crashtest: 0.3.1
- crcmod: 1.7
- cufflinks: 0.17.3
- cupy-cuda11x: 11.0.0
- cvxopt: 1.3.0
- cvxpy: 1.2.1
- cycler: 0.11.0
- cymem: 2.0.7
- cython: 0.29.32
- daft: 0.0.4
- dask: 2022.2.0
- datascience: 0.17.5
- debugpy: 1.0.0
- decorator: 4.4.2
- defusedxml: 0.7.1
- descartes: 1.1.0
- dill: 0.3.6
- distributed: 2022.2.0
- dlib: 19.24.0
- dm-tree: 0.1.7
- dnspython: 2.2.1
- docutils: 0.17.1
- dopamine-rl: 1.0.5
- earthengine-api: 0.1.330
- easydict: 1.10
- ecos: 2.0.10
- editdistance: 0.5.3
- en-core-web-sm: 3.4.1
- entrypoints: 0.4
- ephem: 4.1.3
- et-xmlfile: 1.1.0
- etils: 0.9.0
- etuples: 0.3.8
- fa2: 0.3.5
- fastai: 2.7.10
- fastcore: 1.5.27
- fastdownload: 0.0.7
- fastdtw: 0.3.4
- fastjsonschema: 2.16.2
- fastprogress: 1.0.3
- fastrlock: 0.8.1
- feather-format: 0.4.1
- filelock: 3.8.0
- fire: 0.4.0
- firebase-admin: 4.4.0
- fix-yahoo-finance: 0.0.22
- flask: 1.1.4
- flatbuffers: 1.12
- folium: 0.12.1.post1
- frozenlist: 1.3.1
- fsspec: 2022.10.0
- future: 0.16.0
- gast: 0.4.0
- gdal: 2.2.2
- gdown: 4.4.0
- gensim: 3.6.0
- geographiclib: 1.52
- geopy: 1.17.0
- gin-config: 0.5.0
- glob2: 0.7
- google: 2.0.3
- google-api-core: 1.31.6
- google-api-python-client: 1.12.11
- google-auth: 1.35.0
- google-auth-httplib2: 0.0.4
- google-auth-oauthlib: 0.4.6
- google-cloud-bigquery: 1.21.0
- google-cloud-bigquery-storage: 1.1.2
- google-cloud-core: 1.0.3
- google-cloud-datastore: 1.8.0
- google-cloud-firestore: 1.7.0
- google-cloud-language: 1.2.0
- google-cloud-storage: 1.18.1
- google-cloud-translate: 1.5.0
- google-colab: 1.0.0
- google-pasta: 0.2.0
- google-resumable-media: 0.4.1
- googleapis-common-protos: 1.56.4
- googledrivedownloader: 0.4
- graphviz: 0.10.1
- greenlet: 2.0.0.post0
- grpcio: 1.50.0
- gspread: 3.4.2
- gspread-dataframe: 3.0.8
- gym: 0.25.2
- gym-notices: 0.0.8
- h5py: 3.1.0
- heapdict: 1.0.1
- hijri-converter: 2.2.4
- holidays: 0.16
- holoviews: 1.14.9
- html5lib: 1.0.1
- httpimport: 0.5.18
- httplib2: 0.17.4
- httpstan: 4.6.1
- humanize: 0.5.1
- hyperopt: 0.1.2
- idna: 2.10
- imageio: 2.9.0
- imagesize: 1.4.1
- imbalanced-learn: 0.8.1
- imblearn: 0.0
- imgaug: 0.4.0
- importlib-metadata: 4.13.0
- importlib-resources: 5.10.0
- imutils: 0.5.4
- inflect: 2.1.0
- intel-openmp: 2022.2.0
- intervaltree: 2.1.0
- ipykernel: 5.3.4
- ipython: 7.9.0
- ipython-genutils: 0.2.0
- ipython-sql: 0.3.9
- ipywidgets: 7.7.1
- itsdangerous: 1.1.0
- jax: 0.3.23
- jaxlib: 0.3.22+cuda11.cudnn805
- jieba: 0.42.1
- jinja2: 2.11.3
- joblib: 1.2.0
- jpeg4py: 0.1.4
- jsonschema: 4.3.3
- jupyter-client: 6.1.12
- jupyter-console: 6.1.0
- jupyter-core: 4.11.2
- jupyterlab-widgets: 3.0.3
- kaggle: 1.5.12
- kapre: 0.3.7
- keras: 2.9.0
- keras-preprocessing: 1.1.2
- keras-vis: 0.4.1
- kiwisolver: 1.4.4
- korean-lunar-calendar: 0.3.1
- langcodes: 3.3.0
- libclang: 14.0.6
- librosa: 0.8.1
- lightgbm: 2.2.3
- lightning-lite: 1.8.0.post1
- lightning-utilities: 0.3.0
- llvmlite: 0.39.1
- lmdb: 0.99
- locket: 1.0.0
- logical-unification: 0.4.5
- lunarcalendar: 0.0.9
- lxml: 4.9.1
- markdown: 3.4.1
- markupsafe: 2.0.1
- marshmallow: 3.18.0
- matplotlib: 3.2.2
- matplotlib-venn: 0.11.7
- minikanren: 1.0.3
- missingno: 0.5.1
- mistune: 0.8.4
- mizani: 0.7.3
- mkl: 2019.0
- mlxtend: 0.14.0
- more-itertools: 9.0.0
- moviepy: 0.2.3.5
- mpmath: 1.2.1
- msgpack: 1.0.4
- multidict: 6.0.2
- multipledispatch: 0.6.0
- multitasking: 0.0.11
- murmurhash: 1.0.9
- music21: 5.5.0
- natsort: 5.5.0
- nbconvert: 5.6.1
- nbformat: 5.7.0
- netcdf4: 1.6.1
- networkx: 2.6.3
- nibabel: 3.0.2
- nltk: 3.7
- notebook: 5.7.16
- numba: 0.56.4
- numexpr: 2.8.4
- numpy: 1.21.6
- oauth2client: 4.1.3
- oauthlib: 3.2.2
- okgrade: 0.4.3
- opencv-contrib-python: 4.6.0.66
- opencv-python: 4.6.0.66
- opencv-python-headless: 4.6.0.66
- openpyxl: 3.0.10
- opt-einsum: 3.3.0
- osqp: 0.6.2.post0
- packaging: 21.3
- palettable: 3.3.0
- pandas: 1.3.5
- pandas-datareader: 0.9.0
- pandas-gbq: 0.13.3
- pandas-profiling: 1.4.1
- pandocfilters: 1.5.0
- panel: 0.12.1
- param: 1.12.2
- parso: 0.8.3
- partd: 1.3.0
- pastel: 0.2.1
- pathlib: 1.0.1
- pathy: 0.6.2
- patsy: 0.5.3
- pep517: 0.13.0
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 7.1.2
- pip: 21.1.3
- pip-tools: 6.2.0
- plotly: 5.5.0
- plotnine: 0.8.0
- pluggy: 0.7.1
- pooch: 1.6.0
- portpicker: 1.3.9
- prefetch-generator: 1.0.1
- preshed: 3.0.8
- prettytable: 3.5.0
- progressbar2: 3.38.0
- prometheus-client: 0.15.0
- promise: 2.3
- prompt-toolkit: 2.0.10
- prophet: 1.1.1
- protobuf: 3.17.3
- psutil: 5.4.8
- psycopg2: 2.9.5
- ptyprocess: 0.7.0
- py: 1.11.0
- pyarrow: 6.0.1
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycocotools: 2.0.6
- pycparser: 2.21
- pyct: 0.4.8
- pydantic: 1.10.2
- pydata-google-auth: 1.4.0
- pydot: 1.3.0
- pydot-ng: 2.0.0
- pydotplus: 2.0.2
- pydrive: 1.3.1
- pyemd: 0.5.1
- pyerfa: 2.0.0.1
- pygments: 2.6.1
- pygobject: 3.26.1
- pylev: 1.4.0
- pymc: 4.1.4
- pymeeus: 0.5.11
- pymongo: 4.3.2
- pymystem3: 0.2.0
- pyopengl: 3.1.6
- pyparsing: 3.0.9
- pyrsistent: 0.19.2
- pysimdjson: 3.2.0
- pysndfile: 1.3.8
- pysocks: 1.7.1
- pystan: 3.3.0
- pytest: 3.6.4
- python-apt: 0.0.0
- python-dateutil: 2.8.2
- python-louvain: 0.16
- python-slugify: 6.1.2
- python-utils: 3.4.5
- pytorch-lightning: 1.8.0.post1
- pytz: 2022.6
- pyviz-comms: 2.2.1
- pywavelets: 1.3.0
- pyyaml: 6.0
- pyzmq: 23.2.1
- qdldl: 0.1.5.post2
- qudida: 0.0.4
- regex: 2022.6.2
- requests: 2.23.0
- requests-oauthlib: 1.3.1
- resampy: 0.4.2
- rpy2: 3.5.5
- rsa: 4.9
- scikit-image: 0.18.3
- scikit-learn: 1.0.2
- scipy: 1.7.3
- screen-resolution-extra: 0.0.0
- scs: 3.2.2
- seaborn: 0.11.2
- send2trash: 1.8.0
- setuptools: 57.4.0
- setuptools-git: 1.2
- shapely: 1.8.5.post1
- six: 1.15.0
- sklearn-pandas: 1.8.0
- smart-open: 5.2.1
- snowballstemmer: 2.2.0
- sortedcontainers: 2.4.0
- soundfile: 0.11.0
- spacy: 3.4.2
- spacy-legacy: 3.0.10
- spacy-loggers: 1.0.3
- sphinx: 1.8.6
- sphinxcontrib-serializinghtml: 1.1.5
- sphinxcontrib-websupport: 1.2.4
- sqlalchemy: 1.4.42
- sqlparse: 0.4.3
- srsly: 2.4.5
- statsmodels: 0.12.2
- sympy: 1.7.1
- tables: 3.7.0
- tabulate: 0.8.10
- tblib: 1.7.0
- tenacity: 8.1.0
- tensorboard: 2.9.1
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- tensorflow: 2.9.2
- tensorflow-datasets: 4.6.0
- tensorflow-estimator: 2.9.0
- tensorflow-gcs-config: 2.9.1
- tensorflow-hub: 0.12.0
- tensorflow-io-gcs-filesystem: 0.27.0
- tensorflow-metadata: 1.10.0
- tensorflow-probability: 0.17.0
- termcolor: 2.1.0
- terminado: 0.13.3
- testpath: 0.6.0
- text-unidecode: 1.3
- textblob: 0.15.3
- thinc: 8.1.5
- threadpoolctl: 3.1.0
- tifffile: 2021.11.2
- toml: 0.10.2
- tomli: 2.0.1
- toolz: 0.12.0
- torch: 1.12.1+cu113
- torchaudio: 0.12.1+cu113
- torchmetrics: 0.10.2
- torchsummary: 1.5.1
- torchtext: 0.13.1
- torchvision: 0.13.1+cu113
- tornado: 6.0.4
- tqdm: 4.64.1
- traitlets: 5.1.1
- tweepy: 3.10.0
- typeguard: 2.7.1
- typer: 0.4.2
- typing-extensions: 4.1.1
- tzlocal: 1.5.1
- uritemplate: 3.0.1
- urllib3: 1.24.3
- vega-datasets: 0.9.0
- wasabi: 0.10.1
- wcwidth: 0.2.5
- webargs: 8.2.0
- webencodings: 0.5.1
- werkzeug: 1.0.1
- wheel: 0.38.1
- widgetsnbextension: 3.6.1
- wordcloud: 1.8.2.2
- wrapt: 1.14.1
- xarray: 0.20.2
- xarray-einstats: 0.2.2
- xgboost: 0.90
- xkit: 0.0.0
- xlrd: 1.1.0
- xlwt: 1.3.0
- yarl: 1.8.1
- yellowbrick: 1.5
- zict: 2.2.0
- zipp: 3.10.0
* System:
- OS: Linux
- architecture:
- 64bit
-
- processor: x86_64
- python: 3.7.15
- version: #1 SMP Fri Aug 26 08:44:51 UTC 2022
More info
No response
cc @awaelchli