
Commit fdcecb9: Merge branch 'master' into bugfix/dummy-logger
2 parents 5d6937a + a6c98c4

File tree: 13 files changed (+181, -79 lines)

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
@@ -89,9 +89,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))
 
 
+- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
+- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
 
 
+- Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221))
+
+
 - Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
 
 
@@ -110,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
 
 
+- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
+
+
 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
 
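For reference, a minimal sketch of the `ModelCheckpoint` configuration the two #6136 entries describe; the callback arguments are real options from this commit, while the surrounding `Trainer` settings are arbitrary placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # previously this combination saved nothing at all; it now writes only the `last.ckpt` file
    last_only = ModelCheckpoint(monitor=None, save_top_k=0, save_last=True)

    trainer = Trainer(callbacks=[last_only], max_epochs=3)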

docs/source/extensions/logging.rst

Lines changed: 19 additions & 0 deletions
@@ -294,6 +294,25 @@ Some loggers also allow logging the hyperparams used in the experiment. For inst
 when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show
 in the `hparams tab <https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_hparams>`_.
 
+.. note::
+    If you want to track a metric in the tensorboard hparams tab, log scalars to the key ``hp_metric``. If tracking multiple metrics, initialize ``TensorBoardLogger`` with ``default_hp_metric=False`` and call ``log_hyperparams`` only once with your metric keys and initial values. Subsequent updates can simply be logged to the metric keys. Refer to the following for examples on how to setup proper hyperparams metrics tracking within :doc:`LightningModule <../common/lightning_module>`.
+
+    .. code-block:: python
+
+        # Using default_hp_metric
+        def validation_step(self, batch, batch_idx):
+            self.log("hp_metric", some_scalar)
+
+        # Using custom or multiple metrics (default_hp_metric=False)
+        def on_train_start(self):
+            self.logger.log_hyperparams(self.hparams, {"hp/metric_1": 0, "hp/metric_2": 0})
+
+        def validation_step(self, batch, batch_idx):
+            self.log("hp/metric_1", some_scalar_1)
+            self.log("hp/metric_2", some_scalar_2)
+
+    In the example, using `hp/` as a prefix allows for the metrics to be grouped under "hp" in the tensorboard scalar tab where you can collapse them.
+
 ----------
 
 *************

notebooks/06-mnist-tpu-training.ipynb

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@
     "id": "AYGWh10lRaF1"
    },
    "source": [
-    "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl"
+    "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl"
    ],
    "execution_count": null,
    "outputs": []

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 81 additions & 55 deletions
@@ -189,7 +189,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         checkpoints can be saved at the end of the val loop
         """
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer)
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
@@ -204,12 +204,18 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]
 
-    def save_checkpoint(self, trainer, pl_module):
+    def save_checkpoint(self, trainer, unused: Optional = None):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
         to handle correct behaviour in distributed training, i.e., saving only on rank 0.
         """
+        if unused is not None:
+            rank_zero_warn(
+                "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter"
+                " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
+            )
+
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
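The effect of this shim is that the old two-argument call keeps working but emits a `DeprecationWarning` until the removal in v1.5. A rough sketch of both call styles, assuming an already configured callback, trainer and module:

    checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/")

    checkpoint_cb.save_checkpoint(trainer)             # new signature
    checkpoint_cb.save_checkpoint(trainer, pl_module)  # old signature: still saves, but warns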
@@ -218,7 +224,6 @@ def save_checkpoint(self, trainer, pl_module):
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
             or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
-            or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or self._last_global_step_saved == global_step  # already saved at the last step
@@ -236,28 +241,33 @@ def save_checkpoint(self, trainer, pl_module):
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
-        # Mode 1: save all checkpoints OR only the top k
-        if self.save_top_k:
-            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
-
-        # Mode 2: save the last checkpoint
+        # Mode 1: save the top k checkpoints
+        self._save_top_k_checkpoint(trainer, monitor_candidates)
+        # Mode 2: save monitor=None checkpoints
+        self._save_none_monitor_checkpoint(trainer, monitor_candidates)
+        # Mode 3: save last checkpoints
         self._save_last_checkpoint(trainer, monitor_candidates)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
-            if self.save_top_k not in [None, -1, 0]:
+            if self.save_top_k not in (None, -1, 0):
                 raise MisconfigurationException(
                     f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                     ' configuration. No quantity for top_k to track.'
                 )
             if self.save_last:
                 rank_zero_warn(
-                    'ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration.'
+                    'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
                     ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
                 )
+            if self.save_top_k == -1 and self.save_last:
+                rank_zero_info(
+                    'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
+                    ' will duplicate the last checkpoint saved.'
+                )
 
     def __init_ckpt_dir(self, dirpath, filename, save_top_k):
 
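With the dispatch split into three helpers, each mode is gated only by its own guard (see `_save_top_k_checkpoint`, `_save_none_monitor_checkpoint` and `_save_last_checkpoint` below). A sketch of which configuration reaches which helper, for illustration only:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Mode 1: monitor set and save_top_k != 0 -> _save_top_k_checkpoint
    ModelCheckpoint(monitor="val_loss", save_top_k=3)

    # Mode 2: monitor=None and save_top_k != 0 -> _save_none_monitor_checkpoint
    ModelCheckpoint(monitor=None)

    # Mode 3: save_last=True additionally writes `last.ckpt` -> _save_last_checkpoint
    ModelCheckpoint(monitor="val_loss", save_top_k=3, save_last=True)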
@@ -293,7 +303,16 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
-    def _save_model(self, filepath: str, trainer):
+    def _save_model(self, trainer, filepath: str):
+        if trainer.training_type_plugin.rpc_enabled:
+            # RPCPlugin manages saving all model states
+            # TODO: the rpc plugin should wrap trainer.save_checkpoint
+            # instead of us having to do it here manually
+            trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath)
+        else:
+            self._do_save(trainer, filepath)
+
+    def _do_save(self, trainer, filepath: str):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
@@ -307,7 +326,7 @@ def _save_model(self, filepath: str, trainer):
         else:
             raise ValueError(".save_function() not set")
 
-    def check_monitor_top_k(self, current) -> bool:
+    def check_monitor_top_k(self, current: torch.Tensor) -> bool:
         if current is None:
             return False
 
@@ -462,17 +481,17 @@ def _validate_monitor_key(self, trainer):
 
     def _get_metric_interpolated_filepath_name(
         self,
-        ckpt_name_metrics: Dict[str, Any],
+        monitor_candidates: Dict[str, Any],
         epoch: int,
         step: int,
         trainer,
         del_filepath: Optional[str] = None,
     ) -> str:
-        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+        filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
 
         version_cnt = self.STARTING_VERSION
         while self.file_exists(filepath, trainer) and filepath != del_filepath:
-            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version_cnt)
             version_cnt += 1
 
         return filepath
@@ -482,47 +501,32 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer, ckpt_name_metrics):
-        should_save_last = self.monitor is None or self.save_last
-        if not should_save_last:
+    def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if not self.save_last:
             return
 
-        # when user ALSO asked for the 'last.ckpt' change the name
-        if self.save_last:
-            last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST,
-                trainer.current_epoch,
-                trainer.global_step,
-                ckpt_name_metrics,
-            )
-            last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
-        else:
-            last_filepath = self._get_metric_interpolated_filepath_name(
-                ckpt_name_metrics,
-                trainer.current_epoch,
-                trainer.global_step,
-                trainer,
-            )
+        filepath = self._format_checkpoint_name(
+            self.CHECKPOINT_NAME_LAST,
+            trainer.current_epoch,
+            trainer.global_step,
+            monitor_candidates,
+        )
+        filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")
 
-        if trainer.training_type_plugin.rpc_enabled:
-            # RPCPlugin manages saving all model states
-            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer)
-        else:
-            self._save_model(last_filepath, trainer)
-        if (
-            self.last_model_path and self.last_model_path != last_filepath
-            and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
-        ):
+        self._save_model(trainer, filepath)
+
+        if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
             self._del_model(self.last_model_path)
-        self.last_model_path = last_filepath
 
-        if self.monitor is None:
-            self.best_model_path = self.last_model_path
+        self.last_model_path = filepath
+
+    def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is None or self.save_top_k == 0:
+            return
 
-    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
-        current = metrics.get(self.monitor)
-        epoch = metrics.get("epoch")
-        step = metrics.get("step")
+        current = monitor_candidates.get(self.monitor)
+        epoch = monitor_candidates.get("epoch")
+        step = monitor_candidates.get("step")
 
         # when `val_loss` is being logged and no ModelCheckpoint is being provided
         # `val_loss` will be selected for monitor and need to be reduced to
@@ -533,15 +537,37 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
             current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
-        elif self.monitor is not None and self.verbose:
+            self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
+        elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
 
+    def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is not None or self.save_top_k == 0:
+            return
+
+        filepath = self._get_metric_interpolated_filepath_name(
+            monitor_candidates,
+            trainer.current_epoch,
+            trainer.global_step,
+            trainer,
+        )
+        self._save_model(trainer, filepath)
+
+        if (
+            self.save_top_k is None
+            and self.best_model_path
+            and self.best_model_path != filepath
+            and trainer.is_global_zero
+        ):
+            self._del_model(self.best_model_path)
+
+        self.best_model_path = filepath
+
     def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0
 
     def _update_best_and_save(
-        self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, ckpt_name_metrics
+        self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any]
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
@@ -554,7 +580,7 @@ def _update_best_and_save(
         if isinstance(current, torch.Tensor) and torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
 
-        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, step, trainer, del_filepath)
 
         # save the current score
         self.current_score = current
@@ -575,7 +601,7 @@ def _update_best_and_save(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
                 f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
-        self._save_model(filepath, trainer)
+        self._save_model(trainer, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
             self._del_model(del_filepath)

pytorch_lightning/loggers/tensorboard.py

Lines changed: 7 additions & 4 deletions
@@ -46,10 +46,13 @@ class TensorBoardLogger(LightningLoggerBase):
     preinstalled.
 
     Example:
-        >>> from pytorch_lightning import Trainer
-        >>> from pytorch_lightning.loggers import TensorBoardLogger
-        >>> logger = TensorBoardLogger("tb_logs", name="my_model")
-        >>> trainer = Trainer(logger=logger)
+
+    .. testcode::
+
+        from pytorch_lightning import Trainer
+        from pytorch_lightning.loggers import TensorBoardLogger
+        logger = TensorBoardLogger("tb_logs", name="my_model")
+        trainer = Trainer(logger=logger)
 
     Args:
         save_dir: Save directory
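The docstring example above is converted from a doctest into a `.. testcode::` block. Tying it to the logging docs change earlier in this commit, a short sketch of the `default_hp_metric=False` setup that the new note recommends; the directory and experiment names are placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    # disable the automatic hp_metric so multiple hparams metrics can be registered once
    logger = TensorBoardLogger("tb_logs", name="my_model", default_hp_metric=False)
    trainer = Trainer(logger=logger)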

pytorch_lightning/plugins/training_type/rpc.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import List, Optional
+from typing import List, Optional, Callable
 
 import torch
 
@@ -63,15 +63,15 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
         rpc._set_rpc_timeout(self.rpc_timeout_sec)
         self._is_rpc_initialized = True
 
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
+    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
         """
         Override to save model to disk.
         This is required as the main process will be required to handle aggregating model states from RPC processes.
 
         Args:
-            save_model_fn: The saving function to save final model.
-            last_filepath: The filepath to save the model to.
             trainer: The trainer object.
+            save_model_fn: The saving function to save final model.
+            filepath: The filepath to save the model to.
         """
         raise NotImplementedError
 
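Plugin authors overriding this hook need to switch the argument order from `(save_model_fn, last_filepath, trainer)` to `(trainer, save_model_fn, filepath)`. A minimal sketch of an override under the new signature; the subclass and its body are hypothetical, only the hook signature comes from this diff:

    from typing import Callable

    from pytorch_lightning.plugins.training_type.rpc import RPCPlugin

    class MyRPCPlugin(RPCPlugin):  # hypothetical subclass
        def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
            # aggregate worker state here if needed, then delegate to the checkpoint callback
            save_model_fn(trainer, filepath)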

pytorch_lightning/plugins/training_type/rpc_sequential.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 # limitations under the License
 import logging
 import os
-from typing import List, Optional
+from typing import List, Optional, Callable
 
 import torch
 import torch.distributed as torch_distrib
@@ -266,7 +266,7 @@ def configure_ddp(self):
         self._model.require_backward_grad_sync = False
 
     @rank_zero_only
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
+    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
         model = self.lightning_module
         if not hasattr(model.sequential_module, "foreach_worker"):
             return
@@ -275,7 +275,7 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer) -> None:
             save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
         )
         model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
-        save_model_fn(last_filepath, trainer)
+        save_model_fn(trainer, filepath)
         model.sequential_module = current_layers
 
     def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ def __load_weights_on_main_process(self) -> None:
         self._model = model
 
     def _close_logger(self, trainer) -> None:
-        if hasattr(trainer, "logger"):
+        if trainer.logger is not None:
             trainer.logger.finalize("success")
 
     @property
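The old `hasattr` guard always passed, because `Trainer` defines a `logger` attribute even when no logger is configured (it is then `None`), so the subsequent `finalize` call raised the `AttributeError` noted in the CHANGELOG. A tiny sketch of the difference, using a stand-in object rather than a real `Trainer`:

    class FakeTrainer:  # stand-in, not the real Trainer
        logger = None

    t = FakeTrainer()
    hasattr(t, "logger")   # True, so the old guard went on to call t.logger.finalize(...)
    t.logger is not None   # False, so the new guard skips finalize entirely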
