From 8dbc25c24ae6ab738d8b072a8dd9d80353c8bcb7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Feb 2022 01:08:37 +0100 Subject: [PATCH 1/9] Integrate global step with progress tracking --- .../callbacks/device_stats_monitor.py | 4 +- .../callbacks/gpu_stats_monitor.py | 4 +- pytorch_lightning/callbacks/lr_monitor.py | 4 +- .../callbacks/model_checkpoint.py | 13 ++---- .../loops/dataloader/evaluation_loop.py | 4 +- .../loops/epoch/training_epoch_loop.py | 28 ++++++++----- pytorch_lightning/loops/fit_loop.py | 27 +++---------- .../connectors/checkpoint_connector.py | 14 ++++--- .../logger_connector/logger_connector.py | 12 ++---- pytorch_lightning/trainer/trainer.py | 2 +- pytorch_lightning/tuner/batch_size_scaling.py | 2 - pytorch_lightning/tuner/lr_finder.py | 2 - tests/callbacks/test_lr_monitor.py | 7 ++-- tests/callbacks/test_rich_progress_bar.py | 14 +++---- tests/callbacks/test_tqdm_progress_bar.py | 14 +++---- tests/checkpointing/test_model_checkpoint.py | 40 ++++++++++--------- .../checkpointing/test_trainer_checkpoint.py | 16 -------- tests/loggers/test_comet.py | 2 +- tests/loggers/test_mlflow.py | 4 +- tests/loggers/test_wandb.py | 4 +- tests/loops/test_loops.py | 7 ++-- tests/loops/test_training_loop.py | 2 +- tests/models/test_amp.py | 4 +- tests/models/test_restore.py | 13 +++--- tests/plugins/test_checkpoint_io_plugin.py | 4 +- tests/trainer/optimization/test_optimizers.py | 2 +- tests/trainer/test_trainer.py | 3 +- 27 files changed, 109 insertions(+), 143 deletions(-) diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py index f9cb3cf623c1b..75f2650361734 100644 --- a/pytorch_lightning/callbacks/device_stats_monitor.py +++ b/pytorch_lightning/callbacks/device_stats_monitor.py @@ -66,7 +66,7 @@ def on_train_batch_start( for logger in trainer.loggers: separator = logger.group_separator prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator) - logger.log_metrics(prefixed_device_stats, step=trainer.global_step) + logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def on_train_batch_end( self, @@ -88,7 +88,7 @@ def on_train_batch_end( for logger in trainer.loggers: separator = logger.group_separator prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator) - logger.log_metrics(prefixed_device_stats, step=trainer.global_step) + logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 2c2949e53bf7e..6be1efbd0933a 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -162,7 +162,7 @@ def on_train_batch_start( logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 for logger in trainer.loggers: - logger.log_metrics(logs, step=trainer.global_step) + logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped) @rank_zero_only def on_train_batch_end( @@ -187,7 +187,7 @@ def on_train_batch_end( logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 for logger in trainer.loggers: - logger.log_metrics(logs, step=trainer.global_step) + logger.log_metrics(logs, 
step=trainer.fit_loop.epoch_loop._batches_that_stepped) @staticmethod def _get_gpu_ids(device_ids: List[int]) -> List[str]: diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 00ff007af5e41..5f8f181b0728f 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -158,7 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) if latest_stat: for logger in trainer.loggers: - logger.log_metrics(latest_stat, step=trainer.global_step) + logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: if self.logging_interval != "step": @@ -167,7 +167,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) if latest_stat: for logger in trainer.loggers: - logger.log_metrics(latest_stat, step=trainer.global_step) + logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: latest_stat = {} diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1e231abab2661..dbf0f77bd4249 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -223,7 +223,7 @@ def __init__( self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end - self._last_global_step_saved = -1 + self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score = None self.best_k_models = {} @@ -274,8 +274,7 @@ def on_train_batch_end( """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`""" if self._should_skip_saving_checkpoint(trainer): return - step = trainer.global_step - skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) + skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0) train_time_interval = self._train_time_interval skip_time = True @@ -296,8 +295,6 @@ def on_train_batch_end( def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" - # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates - trainer.fit_loop.global_step -= 1 if ( not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end @@ -305,7 +302,6 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu and (trainer.current_epoch + 1) % self._every_n_epochs == 0 ): self.save_checkpoint(trainer) - trainer.fit_loop.global_step += 1 def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" @@ -328,11 +324,8 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - return if self.verbose: rank_zero_info("Saving latest checkpoint...") - # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates - monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1) - trainer.fit_loop.global_step -= 1 + 
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step) self._save_last_checkpoint(trainer, monitor_candidates) - trainer.fit_loop.global_step += 1 def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index cb0e79ae89448..5d7b0be1c8151 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -66,8 +66,6 @@ def num_dataloaders(self) -> int: # case where user does: # return dl1, dl2 dataloaders = self.dataloaders - if dataloaders is None: - return 0 length = len(dataloaders) if length > 0 and isinstance(dataloaders[0], (list, tuple)): length = len(dataloaders[0]) @@ -78,7 +76,7 @@ def dataloaders(self) -> Sequence[DataLoader]: """Returns the validation or test dataloaders.""" dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders if dataloaders is None: - raise RuntimeError("Dataloaders should be available.") + return [] return dataloaders def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 87216944893ef..bef86b1aecf7c 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -60,7 +60,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self.min_steps = min_steps self.max_steps = max_steps - self.global_step: int = 0 self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() @@ -73,6 +72,7 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self._dataloader_iter: Optional[Iterator] = None # caches the loaded dataloader state until dataloader objects are available self._dataloader_state_dict: Dict[str, Any] = {} + self._batches_that_stepped: int = 0 @property def total_batch_idx(self) -> int: @@ -88,6 +88,13 @@ def batch_idx(self) -> int: # but before the next `ready` increase return self.batch_progress.current.ready - 1 + @property + def global_step(self) -> int: + lightning_module = self.trainer.lightning_module + if lightning_module is None or lightning_module.automatic_optimization: + return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps + return self.batch_loop.manual_loop.optim_step_progress.total.completed + @property def _is_training_done(self) -> bool: max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps) @@ -246,17 +253,14 @@ def on_advance_end(self) -> None: self._run_validation() self.trainer.training = True - # ----------------------------------------- - # SAVE LOGGERS (ie: Tensorboard, etc...) 
- # ----------------------------------------- - self._save_loggers_on_train_batch_end() - # update plateau LR scheduler after metrics are logged self.update_lr_schedulers("step", update_plateau_schedulers=True) if not self._should_accumulate(): - # progress global step according to grads progress - self.global_step += 1 + # this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers + self._batches_that_stepped += 1 + # this will save based on the `batches_that_stepped` value + self._save_loggers_on_train_batch_end() # if training finished, defer exit to the parent. this assumes there will be enough time in between # which might not be the case depending on what's in the `*_epoch_end` hooks @@ -502,10 +506,12 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: def _save_loggers_on_train_batch_end(self) -> None: """Flushes loggers to disk.""" - # when loggers should save to disk - should_flush_logs = self.trainer.logger_connector.should_flush_logs # TODO: is_global_zero check should be moved to logger.save() implementation - if should_flush_logs and self.trainer.is_global_zero: + if not self.trainer.is_global_zero or self.trainer.logger is None: + return + # this assumes that `batches_that_stepped` was increased before + should_flush = self._batches_that_stepped % self.trainer.flush_logs_every_n_steps == 0 + if should_flush or self.trainer.should_stop: for logger in self.trainer.loggers: logger.save() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index a942f3bf75a99..b8f754f0e7e6e 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -68,19 +68,6 @@ def __init__( self._outputs: _EPOCH_OUTPUTS_TYPE = [] self._data_fetcher: Optional[AbstractDataFetcher] = None - @property - def global_step(self) -> int: - """Returns the global step.""" - lightning_module = self.trainer.lightning_module - if lightning_module is None or lightning_module.automatic_optimization: - return self.epoch_loop.global_step - return self.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed - - @global_step.setter - def global_step(self, value: int) -> None: - """Sets the global step (forwards to epoch_loop)""" - self.epoch_loop.global_step = value - @property def total_batch_idx(self) -> int: """Returns the current batch index (across epochs)""" @@ -171,7 +158,7 @@ def _results(self) -> _ResultCollection: def done(self) -> bool: """Evaluates when to leave the loop.""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop - stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) + stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps) # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved. 
# we use it here because the checkpoint data won't have `completed` increased yet stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs) @@ -180,7 +167,7 @@ def done(self) -> bool: if self.trainer.should_stop: # early stopping met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True - met_min_steps = self.global_step >= self.min_steps if self.min_steps else True + met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: should_stop = True else: @@ -312,14 +299,12 @@ def on_advance_end(self) -> None: self.epoch_progress.increment_completed() - # the global step is manually decreased here due to backwards compatibility with existing loggers - # as they expect that the same step is used when logging epoch end metrics even when the batch loop has - # finished. this means the attribute does not exactly track the number of optimizer steps applied. - # TODO(@carmocca): deprecate and rename so users don't get confused - self.global_step -= 1 + # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics + # even when the batch loop has finished + self.epoch_loop._batches_that_stepped -= 1 # log epoch metrics self.trainer.logger_connector.update_train_epoch_metrics() - self.global_step += 1 + self.epoch_loop._batches_that_stepped += 1 # if fault tolerant is enabled and process has been notified, exit. self.trainer._exit_gracefully_on_signal() diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2a68f38bb4a37..e93f4172b5c6d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -231,16 +231,20 @@ def restore_loops(self) -> None: if not self._loaded_checkpoint: return - self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"] + fit_loop = self.trainer.fit_loop + # set the `global_step` value for old checkpoints without the progress tracking state. + # it will be overwritten by the loop's state if it was also saved + optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop + optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"] # set the `current_epoch` value for old checkpoints without the progress tracking state. 
# it will be overwritten by the loop's state if it was also saved - self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] + fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): - self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) + fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) elif self.trainer.state.fn == TrainerFn.TESTING: @@ -329,9 +333,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model = self.trainer.lightning_module checkpoint = { - # the epoch is saved for compatibility but it's not relevant for restoration + # the epoch and global step are saved for compatibility but they are not relevant for restoration "epoch": self.trainer.current_epoch, - "global_step": self.trainer.global_step + model.automatic_optimization, + "global_step": self.trainer.global_step, "pytorch-lightning_version": pl.__version__, "state_dict": self._get_lightning_module_state_dict(), "loops": self._get_loops_state_dict(), diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 428713ff3347e..0e3a69bfc9d98 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -77,15 +77,11 @@ def on_trainer_init( ) break - @property - def should_flush_logs(self) -> bool: - should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0 - return should_flush or self.trainer.should_stop - @property def should_update_logs(self) -> bool: - should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - return should_log_every_n_steps or self.trainer.should_stop + # `+ 1` because it can be checked before a step is executed, for example, in `on_train_batch_start` + should_log = (self.trainer.fit_loop.epoch_loop._batches_that_stepped + 1) % self.trainer.log_every_n_steps == 0 + return should_log or self.trainer.should_stop def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None: if isinstance(logger, bool): @@ -123,7 +119,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: if step is None: # added metrics for convenience scalar_metrics.setdefault("epoch", self.trainer.current_epoch) - step = self.trainer.global_step + step = self.trainer.fit_loop.epoch_loop._batches_that_stepped # log actual metrics for logger in self.trainer.loggers: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 90c25569d6ac4..7a3c2a7c92d4f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2464,7 +2464,7 @@ def sanity_checking(self, val: bool) -> None: @property def global_step(self) -> int: - return self.fit_loop.global_step + return self.fit_loop.epoch_loop.global_step @property def current_epoch(self) -> int: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 3d5916e3f8bd9..6f4ac72bd7e8b 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py 
+++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -60,9 +60,7 @@ def scale_batch_size( # Save initial model, that is loaded after batch size is found ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") - trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.global_step += 1 params = __scale_batch_dump_params(trainer) # Set to values that are required by the algorithm diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index d929bbe2f87c7..36b09c130056c 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -204,9 +204,7 @@ def lr_find( # Save initial model, that is loaded after learning rate is found ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") - trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.global_step += 1 params = __lr_finder_dump_params(trainer) # Set to values that are required by the algorithm diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index 82a4a5b99894a..391e74bb10221 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -217,7 +217,6 @@ def configure_optimizers(self): optimizer2 = optim.Adam(self.parameters(), lr=1e-2) lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) - return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] model = CustomBoringModel() @@ -241,7 +240,8 @@ def configure_optimizers(self): assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": - expected_number_logged = trainer.global_step // log_every_n_steps + # divide by 2 because we have 2 optimizers + expected_number_logged = trainer.global_step // 2 // log_every_n_steps if logging_interval == "epoch": expected_number_logged = trainer.max_epochs @@ -284,7 +284,8 @@ def configure_optimizers(self): assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly" if logging_interval == "step": - expected_number_logged = trainer.global_step // log_every_n_steps + # divide by 2 because we have 2 optimizers + expected_number_logged = trainer.global_step // 2 // log_every_n_steps if logging_interval == "epoch": expected_number_logged = trainer.max_epochs diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index cfe32dc495f8d..29ef3aa98f89b 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -368,15 +368,15 @@ def test_step(self, batch, batch_idx): trainer.fit(model) assert pbar.calls["fit"] == [ ("sanity_check", 0, 0, {"b": 0}), - ("train", 0, 0, {}), ("train", 0, 1, {}), - ("validate", 0, 1, {"b": 1}), # validation end + ("train", 0, 2, {}), + ("validate", 0, 2, {"b": 2}), # validation end # epoch end over, `on_epoch=True` metrics are computed - ("train", 0, 2, {"a": 1, "b": 1}), # training epoch end - ("train", 1, 2, {"a": 1, "b": 1}), - ("train", 1, 3, {"a": 1, "b": 1}), - ("validate", 1, 3, {"a": 1, "b": 3}), # validation end - ("train", 1, 4, {"a": 3, "b": 3}), # training epoch end + ("train", 0, 2, {"a": 1, "b": 2}), # training epoch end + ("train", 1, 3, {"a": 1, "b": 2}), + ("train", 1, 4, {"a": 1, "b": 2}), + ("validate", 1, 4, {"a": 1, "b": 4}), # validation end + ("train", 1, 4, {"a": 3, "b": 
4}), # training epoch end ] trainer.validate(model, verbose=False) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 7897a1be798bb..3cfe54c992247 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -608,15 +608,15 @@ def test_step(self, batch, batch_idx): trainer.fit(model) assert pbar.calls["fit"] == [ ("sanity_check", 0, 0, {"b": 0}), - ("train", 0, 0, {}), ("train", 0, 1, {}), - ("validate", 0, 1, {"b": 1}), # validation end + ("train", 0, 2, {}), + ("validate", 0, 2, {"b": 2}), # validation end # epoch end over, `on_epoch=True` metrics are computed - ("train", 0, 2, {"a": 1, "b": 1}), # training epoch end - ("train", 1, 2, {"a": 1, "b": 1}), - ("train", 1, 3, {"a": 1, "b": 1}), - ("validate", 1, 3, {"a": 1, "b": 3}), # validation end - ("train", 1, 4, {"a": 3, "b": 3}), # training epoch end + ("train", 0, 2, {"a": 1, "b": 2}), # training epoch end + ("train", 1, 3, {"a": 1, "b": 2}), + ("train", 1, 4, {"a": 1, "b": 2}), + ("validate", 1, 4, {"a": 1, "b": 4}), # validation end + ("train", 1, 4, {"a": 3, "b": 4}), # training epoch end ] trainer.validate(model, verbose=False) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 2c65426534aff..4e4d8f03e42cd 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -469,7 +469,7 @@ def test_model_checkpoint_file_extension(tmpdir): trainer = Trainer(default_root_dir=tmpdir, callbacks=[model_checkpoint], max_steps=1, logger=False) trainer.fit(model) - expected = ["epoch=0-step=0.tpkc", "last.tpkc"] + expected = ["epoch=0-step=1.tpkc", "last.tpkc"] assert set(expected) == set(os.listdir(tmpdir)) @@ -490,12 +490,12 @@ def test_model_checkpoint_save_last(tmpdir): ) trainer.fit(model) last_filename = model_checkpoint._format_checkpoint_name( - ModelCheckpoint.CHECKPOINT_NAME_LAST, {"epoch": trainer.current_epoch} + ModelCheckpoint.CHECKPOINT_NAME_LAST, {"epoch": trainer.current_epoch - 1} ) last_filename = last_filename + ".ckpt" assert str(tmpdir / last_filename) == model_checkpoint.last_model_path assert set(os.listdir(tmpdir)) == set( - [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename] + [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20, 30])] + [last_filename] ) ModelCheckpoint.CHECKPOINT_NAME_LAST = "last" @@ -583,14 +583,14 @@ def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): # these should not be set if monitor is None assert checkpoint_callback.monitor is None - assert checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=19.ckpt" + assert checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=20.ckpt" assert checkpoint_callback.last_model_path == tmpdir / "last.ckpt" assert checkpoint_callback.best_model_score is None assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == "" # check that the correct ckpts were created - expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19])] + expected = [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [10, 20])] expected.append("last.ckpt") assert set(os.listdir(tmpdir)) == set(expected) @@ -642,7 +642,7 @@ def test_ckpt_every_n_train_steps(tmpdir): trainer.fit(model) expected = [ - f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) + f"step={i}.ckpt" for i in 
range(every_n_train_steps, max_epochs * epoch_length + 1, every_n_train_steps) ] assert set(os.listdir(tmpdir)) == set(expected) @@ -766,14 +766,14 @@ def test_default_checkpoint_behavior(tmpdir): save_weights_only = trainer.checkpoint_callback.save_weights_only save_mock.assert_has_calls( [ - call(save_dir / "epoch=0-step=4.ckpt", save_weights_only), - call(save_dir / "epoch=1-step=9.ckpt", save_weights_only), - call(save_dir / "epoch=2-step=14.ckpt", save_weights_only), + call(save_dir / "epoch=0-step=5.ckpt", save_weights_only), + call(save_dir / "epoch=1-step=10.ckpt", save_weights_only), + call(save_dir / "epoch=2-step=15.ckpt", save_weights_only), ] ) ckpts = os.listdir(save_dir) assert len(ckpts) == 1 - assert ckpts[0] == "epoch=2-step=14.ckpt" + assert ckpts[0] == "epoch=2-step=15.ckpt" @pytest.mark.parametrize("max_epochs", [1, 2]) @@ -784,13 +784,17 @@ def test_model_checkpoint_save_last_warning( tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool ): """Tests 'Saving latest checkpoint...' log.""" - model = LogInTwoMethods() - if not should_validate: - model.validation_step = None - ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose) + # set a high `every_n_epochs` to avoid saving in `on_train_epoch_end`. the message is only printed `on_train_end` + # but it would get skipped because it got already saved in `on_train_epoch_end` for the same global step + ckpt = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose, every_n_epochs=123) trainer = Trainer( - default_root_dir=tmpdir, callbacks=[ckpt], max_epochs=max_epochs, limit_train_batches=1, limit_val_batches=1 + default_root_dir=tmpdir, + callbacks=[ckpt], + max_epochs=max_epochs, + limit_train_batches=1, + limit_val_batches=int(should_validate), ) + model = BoringModel() with caplog.at_level(logging.INFO): trainer.fit(model) assert caplog.messages.count("Saving latest checkpoint...") == (verbose and save_last) @@ -821,9 +825,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) - # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so the - # `current_epoch` count has not been increased yet - assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] - 1 + assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] ckpt_id = ( @@ -1041,7 +1043,7 @@ def test_val_check_interval_checkpoint_files(tmpdir): ) trainer.fit(model) files = {p.basename for p in tmpdir.listdir()} - assert files == {f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]} + assert files == {f"epoch=0-step={s}.ckpt" for s in [2, 4, 6, 8, 10]} def test_current_score(tmpdir): diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 24268e3cfca84..5d129179c7c5d 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -71,19 +71,3 @@ def validation_step(self, batch, batch_idx): assert best_model_path.endswith(f"epoch=0{idx}.ckpt") else: assert f"epoch={idx + 1}" in best_model_path - - -def test_accumulated_gradient_batches_with_ckpt_path(tmpdir): - """This test validates that accumulated gradient is properly recomputed and reset on the trainer.""" - - ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) - model = BoringModel() - 
trainer_kwargs = dict( - max_epochs=1, accumulate_grad_batches={0: 2}, callbacks=ckpt, limit_train_batches=1, limit_val_batches=0 - ) - trainer = Trainer(**trainer_kwargs) - trainer.fit(model) - - trainer_kwargs["max_epochs"] = 2 - trainer = Trainer(**trainer_kwargs) - trainer.fit(model, ckpt_path=ckpt.last_model_path) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 37758e904256a..e09b954a61a6a 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -156,7 +156,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "test" / "1" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=2.ckpt"} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} assert trainer.log_dir == logger.save_dir diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 46c85f13e29e4..5ce5ceb75a0b1 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -136,7 +136,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir): assert trainer.log_dir == logger.save_dir trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=0.ckpt"} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=1.ckpt"} assert trainer.log_dir == logger.save_dir @@ -177,7 +177,7 @@ def training_epoch_end(self, *args, **kwargs): assert "epoch" in os.listdir(tmpdir / exp_id / run_id / "metrics") assert set(os.listdir(tmpdir / exp_id / run_id / "params")) == model.hparams.keys() assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / "checkpoints") - assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches - 1}.ckpt"] + assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"] @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 280303a3f7318..adb91aab6da32 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -156,7 +156,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir): trainer.fit(model) assert trainer.checkpoint_callback.dirpath == str(tmpdir / "project" / version / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=2.ckpt"} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} assert trainer.log_dir == logger.save_dir @@ -212,7 +212,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): type="model", metadata={ "score": None, - "original_filename": "epoch=1-step=5-v3.ckpt", + "original_filename": "epoch=1-step=6-v3.ckpt", "ModelCheckpoint": { "monitor": None, "mode": "min", diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index d578cecdab01e..b812644434230 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -747,7 +747,7 @@ def test_fit_loop_reset(tmpdir): trainer.fit(model) # reset state loaded from a checkpoint from mid-epoch - mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt")) + mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt")) fit_loop = trainer.fit_loop epoch_loop = fit_loop.epoch_loop optimizer_loop = epoch_loop.batch_loop.optimizer_loop @@ -780,7 +780,7 @@ 
def test_fit_loop_reset(tmpdir): assert optimizer_loop.optim_progress.optimizer_position == 1 # reset state loaded from a checkpoint from the end of an epoch - end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt")) + end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=4.ckpt")) fit_loop = trainer.fit_loop epoch_loop = fit_loop.epoch_loop fit_loop.restarting = False @@ -947,8 +947,7 @@ def val_dataloader(self): ) trainer.fit(model, ckpt_path=ckpt_path) - # TODO: -1 because there's a bug where global step is off by one on reload - assert trainer.global_step - 1 == expected_global_step + assert trainer.global_step == expected_global_step state_dict_after_restart = trainer.fit_loop.state_dict() diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 8329389a93944..87afdbcd70903 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -124,7 +124,7 @@ def validation_step(self, *args): # even though we stopped mid epoch, the fit loop finished normally and the current epoch was increased assert trainer.current_epoch == 1 assert trainer.global_step == 5 - assert model.validation_called_at == (0, 4) + assert model.validation_called_at == (0, 5) def test_warning_valid_train_step_end(tmpdir): diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 17135b98c16f5..917bb4d224194 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -199,7 +199,9 @@ def configure_optimizers(self): assert str(trainer.amp_backend) == "AMPType.APEX" trainer.fit(model) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert bwd_mock.call_count == 10 + # `max_steps` is fulfilled in the third batch first optimizer, but we don't check the loop + # `done` condition until all optimizers have run, so the number of backwards is higher than `max_steps` + assert bwd_mock.call_count == 6 assert isinstance(trainer.lr_scheduler_configs[0].scheduler.optimizer, optim.Adam) assert isinstance(trainer.lr_scheduler_configs[1].scheduler.optimizer, optim.SGD) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 53838691a2efb..0c95c96f16cee 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -199,7 +199,9 @@ def on_train_start(self): if self.trainer.state.fn == TrainerFn.TUNING: self._test_on_val_test_predict_tune_start() else: - assert self.trainer.current_epoch == state_dict["epoch"] + # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so + # the `current_epoch` count has not been increased yet + assert self.trainer.current_epoch - 1 == state_dict["epoch"] assert self.trainer.global_step == state_dict["global_step"] assert self._check_model_state_dict() assert self._check_optimizers() @@ -241,8 +243,7 @@ def test_correct_step_and_epoch(tmpdir): ckpt = torch.load(ckpt_path) assert ckpt["epoch"] == first_max_epochs - # TODO(@carmocca): should not need `+1` - assert ckpt["global_step"] == first_max_epochs * train_batches + 1 + assert ckpt["global_step"] == first_max_epochs * train_batches max_epochs = first_max_epochs + 2 trainer = Trainer( @@ -255,13 +256,11 @@ def test_correct_step_and_epoch(tmpdir): class TestModel(BoringModel): def on_pretrain_routine_end(self) -> None: assert self.trainer.current_epoch == first_max_epochs - # TODO(@carmocca): should not need `+1` - assert self.trainer.global_step == first_max_epochs * train_batches + 1 + assert self.trainer.global_step == first_max_epochs * 
train_batches trainer.fit(TestModel(), ckpt_path=ckpt_path) assert trainer.current_epoch == max_epochs - # TODO(@carmocca): should not need `+1` - assert trainer.global_step == max_epochs * train_batches + 1 + assert trainer.global_step == max_epochs * train_batches def test_fit_twice(tmpdir): diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 7a1352804ba3d..56aadad353b2e 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -52,7 +52,7 @@ def test_checkpoint_plugin_called(tmpdir): ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint.call_count == 5 + assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) @@ -71,7 +71,7 @@ def test_checkpoint_plugin_called(tmpdir): ) trainer.fit(model) - assert checkpoint_plugin.save_checkpoint.call_count == 5 + assert checkpoint_plugin.save_checkpoint.call_count == 4 assert checkpoint_plugin.remove_checkpoint.call_count == 1 trainer.test(model, ckpt_path=ck.last_model_path) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 99071ce3d8f8a..38c0a83cabb65 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -628,7 +628,7 @@ def configure_optimizers(self): def on_save_checkpoint(self, checkpoint): lr_scheduler_config = checkpoint["lr_schedulers"][0] # 2 batches ran. since the lr_scheduler_config interval is `step`, the step count should be 2 - assert self.trainer.global_step + 1 == batches # the global step hasn't been increased yet + assert self.trainer.global_step == batches compare_to = max_epochs if epoch_interval else batches assert lr_scheduler_config["_step_count"] - 1 == compare_to # step count starts at 1 assert lr_scheduler_config["_last_lr"] == [lr * gamma ** compare_to] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 064c02660aaeb..3cdb79eef49b3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -332,7 +332,8 @@ def mock_save_function(filepath, *args): # emulate callback's calls during the training for i, loss in enumerate(losses, 1): - trainer.fit_loop.global_step = i + # sets `trainer.global_step` + trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch` From bb87e052f442de963e665a1ffc62e345631becf7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Feb 2022 03:20:37 +0100 Subject: [PATCH 2/9] Remove `ModelCheckpoint.on_train_end` --- .../callbacks/model_checkpoint.py | 13 ---------- .../test_checkpoint_callback_frequency.py | 4 +-- tests/checkpointing/test_model_checkpoint.py | 25 ------------------- 3 files changed, 2 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index dbf0f77bd4249..6ce7e579a8279 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -314,19 +314,6 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul return self.save_checkpoint(trainer) - def 
on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Save a checkpoint when training stops. - - This will only save a checkpoint if `save_last` is also enabled as the monitor metrics logged during - training/validation steps or end of epochs are not guaranteed to be available at this stage. - """ - if self._should_skip_saving_checkpoint(trainer) or not self.save_last: - return - if self.verbose: - rank_zero_info("Saving latest checkpoint...") - monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step) - self._save_last_checkpoint(trainer, monitor_candidates) - def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> Dict[str, Any]: diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 90665a6db476e..eeec11c6ecd14 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -81,8 +81,8 @@ def training_step(self, batch, batch_idx): trainer.fit(model) if save_last: - # last epochs are saved every step (so double the save calls) and once `on_train_end` - expected = expected * 2 + 1 + # last epochs are saved every step (so double the save calls) + expected = expected * 2 assert save_mock.call_count == expected diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4e4d8f03e42cd..331a5d3a5edeb 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import math import os import pickle @@ -776,30 +775,6 @@ def test_default_checkpoint_behavior(tmpdir): assert ckpts[0] == "epoch=2-step=15.ckpt" -@pytest.mark.parametrize("max_epochs", [1, 2]) -@pytest.mark.parametrize("should_validate", [True, False]) -@pytest.mark.parametrize("save_last", [True, False]) -@pytest.mark.parametrize("verbose", [True, False]) -def test_model_checkpoint_save_last_warning( - tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool -): - """Tests 'Saving latest checkpoint...' log.""" - # set a high `every_n_epochs` to avoid saving in `on_train_epoch_end`. 
the message is only printed `on_train_end` - # but it would get skipped because it got already saved in `on_train_epoch_end` for the same global step - ckpt = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=save_last, verbose=verbose, every_n_epochs=123) - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[ckpt], - max_epochs=max_epochs, - limit_train_batches=1, - limit_val_batches=int(should_validate), - ) - model = BoringModel() - with caplog.at_level(logging.INFO): - trainer.fit(model) - assert caplog.messages.count("Saving latest checkpoint...") == (verbose and save_last) - - def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): """Tests that the save_last checkpoint contains the latest information.""" seed_everything(100) From 8f8ea989e79d16526461cae8b48d67fbdac45cf4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Feb 2022 05:06:46 +0100 Subject: [PATCH 3/9] Docs --- docs/source/common/lightning_module.rst | 13 +++++++------ docs/source/common/trainer.rst | 19 +++++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 18a0fbead4e7d..85db8784dcc2a 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -914,7 +914,7 @@ These are properties available in a LightningModule. current_epoch ~~~~~~~~~~~~~ -The current epoch +The number of epochs run. .. code-block:: python @@ -944,12 +944,13 @@ usually do not need to use this property, but it is useful to know how to access def training_step(self, batch, batch_idx): if self.global_rank == 0: # do something only once across all the nodes - self.log("global_step", self.trainer.global_step) + ... global_step ~~~~~~~~~~~ -The current step (does not reset each epoch) +The number of optimizer steps taken (does not reset each epoch). +This includes multiple optimizers and TBPTT steps (if enabled). .. code-block:: python @@ -1001,16 +1002,16 @@ The list of loggers currently being used by the Trainer. local_rank ~~~~~~~~~~~ -The ``global_rank`` is the index of the current process across all the devices for the current node. +The ``local_rank`` is the index of the current process across all the devices for the current node. You usually do not need to use this property, but it is useful to know how to access it if needed. For example, if using 10 machines (or nodes), the GPU at index 0 on each machine has local_rank = 0. .. code-block:: python def training_step(self, batch, batch_idx): - if self.global_rank == 0: + if self.local_rank == 0: # do something only once across each node - self.log("global_step", self.trainer.global_step) + ... precision ~~~~~~~~~ diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index ed5ec193db603..36dc2074129a8 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1725,16 +1725,23 @@ The metrics available to callbacks. These are automatically set when you log via current_epoch ************* -The current epoch +The number of epochs run. .. code-block:: python - def training_step(self, batch, batch_idx): - current_epoch = self.trainer.current_epoch - if current_epoch > 100: - # do something - pass + if trainer.current_epoch >= 10: + ... + +global_step +*********** + +The number of optimizer steps taken (does not reset each epoch). +This includes multiple optimizers and TBPTT steps (if enabled). +.. code-block:: python + + if trainer.global_step >= 100: + ... 
logger ******* From d2e9ced44293ab17b55ac89d052dc0a009ba0c7b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Feb 2022 05:22:54 +0100 Subject: [PATCH 4/9] CHANGELOG --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0dea9cd0f927..882841ace410a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -294,6 +294,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578)) +- The `trainer.global_step` value now accounts for multiple optimizers and TBPTT splits ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) + + +- The `trainer.global_step` value is now increased right after the `optimizer.step()` call which will impact users who access it during an intra-training validation hook ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) + + +- The `trainer.global_step` value is now increased by one when included as part of the checkpoint's filename ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) + + - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521)) From e8e620cb86e44f94bc4ac668b7255db9abbc24d0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Feb 2022 06:23:44 +0100 Subject: [PATCH 5/9] Fix `tests/strategies/test_deepspeed_strategy.py::test_deepspeed_multigpu_stage_2_accumulated_grad_batches[False]` --- pytorch_lightning/loops/optimization/optimizer_loop.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index f8d692d688035..bab025466789a 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -359,7 +359,11 @@ def _optimizer_step( else: optimizer = self.trainer.strategy._lightning_optimizers[opt_idx] - self.optim_progress.optimizer.step.increment_ready() + # if `strategy.handles_gradient_accumulation`, this method will be called to route into the strategy, but we + # need to check again if `should_accumulate` before increasing the counters + should_accumulate = self.trainer.fit_loop._should_accumulate() + if not should_accumulate: + self.optim_progress.optimizer.step.increment_ready() # model hook self.trainer._call_lightning_module_hook( @@ -374,7 +378,8 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) - self.optim_progress.optimizer.step.increment_completed() + if not should_accumulate: + self.optim_progress.optimizer.step.increment_completed() def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. From 34fe14b14f4b965fe615534e3bdb000ace2b4bc1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 24 Feb 2022 22:59:41 +0100 Subject: [PATCH 6/9] Adrian's review --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 882841ace410a..ffc460e94e88a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -300,7 +300,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- The `trainer.global_step` value is now increased right after the `optimizer.step()` call which will impact users who access it during an intra-training validation hook ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) -- The `trainer.global_step` value is now increased by one when included as part of the checkpoint's filename ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) +- The filename of checkpoints created with `ModelCheckpoint(filename='{step}')` is different compared to previous versions. A checkpoint saved after 1 step will be named `step=1.ckpt` instead of `step=0.ckpt` ([#11805](https://github.com/PyTorchLightning/pytorch-lightning/pull/11805)) - Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index e93f4172b5c6d..01cb5a37c91aa 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -232,11 +232,11 @@ def restore_loops(self) -> None: return fit_loop = self.trainer.fit_loop - # set the `global_step` value for old checkpoints without the progress tracking state. + # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. # it will be overwritten by the loop's state if it was also saved optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"] - # set the `current_epoch` value for old checkpoints without the progress tracking state. + # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. # it will be overwritten by the loop's state if it was also saved fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] From 40ba7e683eab578b1610bf8c7ea52ae00b69e58e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 28 Feb 2022 15:56:00 +0100 Subject: [PATCH 7/9] Fix test --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 331a5d3a5edeb..a51923b75fb44 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1253,4 +1253,4 @@ def test_last_global_step_saved(): trainer = MagicMock() trainer.callback_metrics = {"foo": 123} model_checkpoint.save_checkpoint(trainer) - assert model_checkpoint._last_global_step_saved == -1 + assert model_checkpoint._last_global_step_saved == 0 From 64c3c77a29fcd515cc112e4b648bf7504bcd98d4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Mar 2022 12:36:23 +0100 Subject: [PATCH 8/9] Docs --- docs/source/common/trainer.rst | 6 +++--- pytorch_lightning/trainer/trainer.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index c2c14ac2e6335..0c04920872b29 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -934,7 +934,7 @@ max_steps | -Stop training after this number of steps +Stop training after this number of :ref:`global steps `. Training will stop if max_steps or max_epochs have reached (earliest). .. 
testcode:: @@ -959,7 +959,7 @@ min_steps | -Force training for at least these number of steps. +Force training for at least these number of :ref:`global steps `. Trainer will train model for at least min_steps or min_epochs (latest). .. testcode:: @@ -1829,4 +1829,4 @@ The metrics sent to the progress bar. estimated_stepping_batches ************************** -Check out :paramref:`~pytorch_lightning.trainer.trainer.Trainer.estimated_stepping_batches`. +Check out :meth:`~pytorch_lightning.trainer.trainer.Trainer.estimated_stepping_batches`. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c9d149bd9287d..a62f0d0abf998 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2484,6 +2484,10 @@ def sanity_checking(self, val: bool) -> None: @property def global_step(self) -> int: + """The number of optimizer steps taken (does not reset each epoch). + + This includes multiple optimizers and TBPTT steps (if enabled). + """ return self.fit_loop.epoch_loop.global_step @property From 08168eb076676b99f040b1afed746d76315989c5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 2 Mar 2022 12:39:45 +0100 Subject: [PATCH 9/9] Typo --- docs/source/common/trainer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 0c04920872b29..5687369283263 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -959,7 +959,7 @@ min_steps | -Force training for at least these number of :ref:`global steps `. +Force training for at least this number of :ref:`global steps `. Trainer will train model for at least min_steps or min_epochs (latest). .. testcode::
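
Note on the behaviour introduced by this series (illustrative only, not part of the patches above). After these changes, `trainer.global_step` reports the number of `optimizer.step()` calls taken, incremented right after each step (covering multiple optimizers and TBPTT), while loggers receive the internal `fit_loop.epoch_loop._batches_that_stepped` counter as their `step` value. The sketch below assumes a single optimizer and no gradient accumulation; the `StepInspector` callback and its prints are hypothetical and only exercise the attributes touched by these patches:

    from pytorch_lightning.callbacks import Callback

    class StepInspector(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # with this series, global_step counts optimizer steps taken, so after the
            # first batch it is already 1 (it was 0 at this point before the change)
            print("optimizer steps:", trainer.global_step)
            # internal counter used as the `step` argument passed to logger.log_metrics,
            # increased once per batch that stepped regardless of optimizer count
            print("batches that stepped:", trainer.fit_loop.epoch_loop._batches_that_stepped)

A checkpoint saved after that first batch is therefore named "epoch=0-step=1.ckpt" rather than "epoch=0-step=0.ckpt", which matches the test expectations updated in PATCH 1/9 and the CHANGELOG entry added in PATCH 4/9.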