|
29 | 29 | from pytorch_lightning.loggers.base import DummyLogger |
30 | 30 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
31 | 31 | from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr |
| 32 | +from pytorch_lightning.utilities import rank_zero_warn |
| 33 | +from pytorch_lightning.utilities.cloud_io import get_filesystem |
32 | 34 |
|
33 | 35 | # check if ipywidgets is installed before importing tqdm.auto |
34 | 36 | # to ensure it won't fail and a progress bar is displayed |
|
41 | 43 | def _run_lr_finder_internally(trainer, model: LightningModule): |
42 | 44 | """ Call lr finder internally during Trainer.fit() """ |
43 | 45 | lr_finder = lr_find(trainer, model) |
| 46 | + |
| 47 | + if lr_finder is None: |
| 48 | + return |
| 49 | + |
44 | 50 | lr = lr_finder.suggestion() |
45 | 51 |
|
46 | 52 | # TODO: log lr.results to self.logger |
@@ -130,7 +136,11 @@ def lr_find( |
130 | 136 | trainer.fit(model) |
131 | 137 |
|
132 | 138 | """ |
133 | | - save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp.ckpt') |
| 139 | + if trainer.fast_dev_run: |
| 140 | + rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning) |
| 141 | + return |
| 142 | + |
| 143 | + save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') |
134 | 144 |
|
135 | 145 | __lr_finder_dump_params(trainer, model) |
136 | 146 |
|
@@ -181,8 +191,11 @@ def lr_find( |
181 | 191 | lr_finder._total_batch_idx = trainer.total_batch_idx # for debug purpose |
182 | 192 |
|
183 | 193 | # Reset model state |
184 | | - trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) |
185 | | - os.remove(save_path) |
| 194 | + if trainer.is_global_zero: |
| 195 | + trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu) |
| 196 | + fs = get_filesystem(str(save_path)) |
| 197 | + if fs.exists(save_path): |
| 198 | + fs.rm(save_path) |
186 | 199 |
|
187 | 200 | # Finish by resetting variables so trainer is ready to fit model |
188 | 201 | __lr_finder_restore_params(trainer, model) |
|
0 commit comments