Skip to content

Commit e4b401e

Browse files
committed
Merge branch 'feat-sync_step' of https://github.com/borisdayma/pytorch-lightning into feat-sync_step
2 parents ecdda57 + d1b74a8 commit e4b401e

28 files changed

+390
-170
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added a check for optimizer attached to lr_scheduler ([#5338](https://github.com/PyTorchLightning/pytorch-lightning/pull/5338))
13+
14+
- Added support for `resume_from_checkpoint` to accept a non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))
15+
1216

1317
### Changed
1418

@@ -21,6 +25,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2125

2226
### Fixed
2327

28+
- Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))
29+
30+
- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))
31+
2432

2533

2634
## [1.1.2] - 2020-12-23

docs/source/conf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,14 @@ def setup(app):
294294
# Ignoring Third-party packages
295295
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
296296
def package_list_from_file(file):
297+
"""List up package name (not containing version and extras) from a package list file
298+
"""
297299
mocked_packages = []
298300
with open(file, 'r') as fp:
299301
for ln in fp.readlines():
300-
found = [ln.index(ch) for ch in list(',=<>#') if ch in ln]
302+
# Example: `tqdm>=4.41.0` => `tqdm`
303+
# `[` is for package with extras
304+
found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln]
301305
pkg = ln[:min(found)] if found else ln
302306
if pkg.rstrip():
303307
mocked_packages.append(pkg.rstrip())

docs/source/debugging.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@ The point is to detect any bugs in the training/validation loop without having t
2828
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
2929

3030
.. testcode::
31-
31+
3232
# runs 1 train, val, test batch and program ends
3333
trainer = Trainer(fast_dev_run=True)
3434

3535
# runs 7 train, val, test batches and program ends
3636
trainer = Trainer(fast_dev_run=7)
3737

38+
.. note::
39+
40+
This argument will disable tuner, checkpoint callbacks, early stopping callbacks,
41+
loggers and logger callbacks like ``LearningRateLogger``, and will run for only 1 epoch.
42+
3843
----------------
3944

4045
Inspect gradient norms

docs/source/trainer.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,9 @@ Under the hood the pseudocode looks like this when running *fast_dev_run* with a
666666
.. note::
667667
668668
This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will
669-
disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be
670-
used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't
671-
disable anything.
669+
disable tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like
670+
``LearningRateLogger``, and will run for only 1 epoch. This must be used only for debugging purposes.
671+
``limit_train/val/test_batches`` only limits the number of batches and won't disable anything.
672672
673673
flush_logs_every_n_steps
674674
^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1328,7 +1328,8 @@ resume_from_checkpoint
13281328
13291329
|
13301330
1331-
To resume training from a specific checkpoint pass in the path here.
1331+
To resume training from a specific checkpoint pass in the path here. If resuming from a mid-epoch
1332+
checkpoint, training will start from the beginning of the next epoch.
13321333
13331334
.. testcode::
13341335

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ dependencies:
3030
- future>=0.17.1
3131
- PyYAML>=5.1
3232
- tqdm>=4.41.0
33-
- fsspec>=0.8.0
33+
- fsspec[http]>=0.8.1
3434
#- tensorboard>=2.2.0 # not needed, already included in pytorch
3535

3636
# Optional

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pytorch_lightning import _logger as log
2929
from pytorch_lightning.callbacks.base import Callback
3030
from pytorch_lightning.metrics.metric import Metric
31-
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn
31+
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE
3232

3333

3434
class EarlyStopping(Callback):
@@ -166,10 +166,10 @@ def on_validation_end(self, trainer, pl_module):
166166
self._run_early_stopping_check(trainer, pl_module)
167167

168168
def on_validation_epoch_end(self, trainer, pl_module):
169-
if trainer.running_sanity_check:
169+
if trainer.fast_dev_run or trainer.running_sanity_check:
170170
return
171171

172-
if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
172+
if self._validate_condition_metric(trainer.callback_metrics):
173173
# turn off early stopping in on_train_epoch_end
174174
self.based_on_eval_results = True
175175

@@ -178,24 +178,19 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
178178
if self.based_on_eval_results:
179179
return
180180

181-
# early stopping can also work in the train loop when there is no val loop
182-
should_check_early_stop = False
183-
184-
# fallback to monitor key in result dict
185-
if trainer.logger_connector.callback_metrics.get(self.monitor, None) is not None:
186-
should_check_early_stop = True
187-
188-
if should_check_early_stop:
189-
self._run_early_stopping_check(trainer, pl_module)
181+
self._run_early_stopping_check(trainer, pl_module)
190182

191183
def _run_early_stopping_check(self, trainer, pl_module):
192184
"""
193185
Checks whether the early stopping condition is met
194186
and if so tells the trainer to stop the training.
195187
"""
196-
logs = trainer.logger_connector.callback_metrics
188+
logs = trainer.callback_metrics
197189

198-
if not self._validate_condition_metric(logs):
190+
if (
191+
trainer.fast_dev_run # disable early_stopping with fast_dev_run
192+
or not self._validate_condition_metric(logs) # short circuit if metric not present
193+
):
199194
return # short circuit if metric not present
200195

201196
current = logs.get(self.monitor)

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import shutil
2525
import subprocess
2626
import time
27-
from typing import List, Tuple, Dict
27+
from typing import Dict, List, Tuple
2828

2929
from pytorch_lightning.callbacks.base import Callback
3030
from pytorch_lightning.utilities import rank_zero_only
@@ -213,5 +213,4 @@ def _should_log(trainer) -> bool:
213213
or trainer.should_stop
214214
)
215215

216-
should_log = should_log and not trainer.fast_dev_run
217216
return should_log

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ class LearningRateMonitor(Callback):
3333
Automatically monitor and logs learning rate for learning rate schedulers during training.
3434
3535
Args:
36-
logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
37-
at the same interval, set to `None` to log at individual interval
38-
according to the `interval` key of each scheduler. Defaults to ``None``.
36+
logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
37+
at the same interval, set to ``None`` to log at individual interval
38+
according to the ``interval`` key of each scheduler. Defaults to ``None``.
3939
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
40-
has the `momentum` attribute. Defaults to ``False``.
40+
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
4141
4242
Example::
4343
@@ -47,17 +47,19 @@ class LearningRateMonitor(Callback):
4747
>>> trainer = Trainer(callbacks=[lr_monitor])
4848
4949
Logging names are automatically determined based on optimizer class name.
50-
In case of multiple optimizers of same type, they will be named `Adam`,
51-
`Adam-1` etc. If a optimizer has multiple parameter groups they will
52-
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
53-
`name` keyword in the construction of the learning rate schdulers
50+
In case of multiple optimizers of same type, they will be named ``Adam``,
51+
``Adam-1`` etc. If an optimizer has multiple parameter groups they will
52+
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
53+
``name`` keyword in the construction of the learning rate schedulers
5454
5555
Example::
5656
5757
def configure_optimizers(self):
5858
optimizer = torch.optim.Adam(...)
59-
lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
60-
'name': 'my_logging_name'}
59+
lr_scheduler = {
60+
'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
61+
'name': 'my_logging_name'
62+
}
6163
return [optimizer], [lr_scheduler]
6264
6365
"""
@@ -80,16 +82,28 @@ def on_train_start(self, trainer, *args, **kwargs):
8082
"""
8183
if not trainer.logger:
8284
raise MisconfigurationException(
83-
'Cannot use LearningRateMonitor callback with Trainer that has no logger.'
85+
'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'
8486
)
8587

8688
if not trainer.lr_schedulers:
8789
rank_zero_warn(
88-
'You are using LearningRateMonitor callback with models that'
90+
'You are using `LearningRateMonitor` callback with models that'
8991
' have no learning rate schedulers. Please see documentation'
9092
' for `configure_optimizers` method.', RuntimeWarning
9193
)
9294

95+
if self.log_momentum:
96+
def _check_no_key(key):
97+
return any(
98+
key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers
99+
)
100+
101+
if _check_no_key('momentum') and _check_no_key('betas'):
102+
rank_zero_warn(
103+
"You have set log_momentum=True, but some optimizers do not"
104+
" have momentum. This will log a value 0 for the momentum.", RuntimeWarning
105+
)
106+
93107
# Find names for schedulers
94108
names = self._find_names(trainer.lr_schedulers)
95109

@@ -105,35 +119,33 @@ def on_train_batch_start(self, trainer, *args, **kwargs):
105119
interval = 'step' if self.logging_interval is None else 'any'
106120
latest_stat = self._extract_stats(trainer, interval)
107121

108-
if trainer.logger is not None and latest_stat:
122+
if latest_stat:
109123
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
110124

111125
def on_train_epoch_start(self, trainer, *args, **kwargs):
112126
if self.logging_interval != 'step':
113127
interval = 'epoch' if self.logging_interval is None else 'any'
114128
latest_stat = self._extract_stats(trainer, interval)
115129

116-
if trainer.logger is not None and latest_stat:
130+
if latest_stat:
117131
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
118132

119133
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
120134
latest_stat = {}
121135

122136
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
123137
if scheduler['interval'] == interval or interval == 'any':
124-
param_groups = scheduler['scheduler'].optimizer.param_groups
125-
if len(param_groups) != 1:
126-
for i, pg in enumerate(param_groups):
127-
lr = self._extract_lr(param_group=pg, name=f'{name}/pg{i + 1}')
128-
latest_stat.update(lr)
129-
momentum = self._extract_momentum(param_group=pg, name=f'{name}-momentum/pg{i + 1}')
130-
latest_stat.update(momentum)
131-
132-
else:
133-
pg = param_groups[0]
134-
lr = self._extract_lr(param_group=pg, name=name)
138+
opt = scheduler['scheduler'].optimizer
139+
param_groups = opt.param_groups
140+
use_betas = 'betas' in opt.defaults
141+
142+
for i, pg in enumerate(param_groups):
143+
suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
144+
lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
135145
latest_stat.update(lr)
136-
momentum = self._extract_momentum(param_group=pg, name=f'{name}-momentum')
146+
momentum = self._extract_momentum(
147+
param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
148+
)
137149
latest_stat.update(momentum)
138150

139151
return latest_stat
@@ -143,11 +155,11 @@ def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
143155
self.lrs[name].append(lr)
144156
return {name: lr}
145157

146-
def _extract_momentum(self, param_group, name: str) -> Dict[str, float]:
158+
def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
147159
if not self.log_momentum:
148160
return {}
149161

150-
momentum = param_group.get('momentum')
162+
momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
151163
self.last_momentum_values[name] = momentum
152164
return {name: momentum}
153165

@@ -190,5 +202,4 @@ def _should_log(trainer) -> bool:
190202
or trainer.should_stop
191203
)
192204

193-
should_log = should_log and not trainer.fast_dev_run
194205
return should_log

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
2121
"""
2222

23-
from copy import deepcopy
2423
import numbers
2524
import os
26-
from pathlib import Path
2725
import re
26+
from copy import deepcopy
27+
from pathlib import Path
2828
from typing import Any, Dict, Optional, Union
2929

3030
import numpy as np
@@ -224,7 +224,8 @@ def save_checkpoint(self, trainer, pl_module):
224224
global_step = trainer.global_step
225225

226226
if (
227-
self.save_top_k == 0 # no models are saved
227+
trainer.fast_dev_run # disable checkpointing with fast_dev_run
228+
or self.save_top_k == 0 # no models are saved
228229
or self.period < 1 # no models are saved
229230
or (epoch + 1) % self.period # skip epoch
230231
or trainer.running_sanity_check # don't save anything during sanity check
@@ -478,14 +479,14 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
478479
version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))
479480

480481
ckpt_path = os.path.join(
481-
save_dir, name, version, "checkpoints"
482+
save_dir, str(name), version, "checkpoints"
482483
)
483484
else:
484485
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
485486

486487
self.dirpath = ckpt_path
487488

488-
if trainer.is_global_zero:
489+
if not trainer.fast_dev_run and trainer.is_global_zero:
489490
self._fs.makedirs(self.dirpath, exist_ok=True)
490491

491492
def _add_backward_monitor_support(self, trainer):

pytorch_lightning/callbacks/progress.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import importlib
2323
import sys
2424

25-
2625
# check if ipywidgets is installed before importing tqdm.auto
2726
# to ensure it won't fail and a progress bar is displayed
2827
if importlib.util.find_spec('ipywidgets') is not None:
@@ -323,7 +322,7 @@ def on_epoch_start(self, trainer, pl_module):
323322
super().on_epoch_start(trainer, pl_module)
324323
total_train_batches = self.total_train_batches
325324
total_val_batches = self.total_val_batches
326-
if total_train_batches != float('inf') and not trainer.fast_dev_run:
325+
if total_train_batches != float('inf'):
327326
# val can be checked multiple times per epoch
328327
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
329328
total_val_batches = total_val_batches * val_checks_per_epoch

0 commit comments

Comments
 (0)