diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py
index 4d0b26763ecfe..e894d9df535d2 100644
--- a/pytorch_lightning/trainer/connectors/optimizer_connector.py
+++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Any, List, Optional
 from weakref import proxy
 
 import pytorch_lightning as pl
@@ -47,7 +47,7 @@ def update_learning_rates(
         if opt_indices is None:
             opt_indices = []
 
-        for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
+        for lr_scheduler in self.trainer.lr_schedulers:
             if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
                 continue
 
@@ -59,11 +59,11 @@ def update_learning_rates(
             # Take step if call to update_learning_rates matches the interval key and
             # the current step modulo the schedulers frequency is zero
             if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
-                # If instance of ReduceLROnPlateau, we need a monitor
-                monitor_key, monitor_val = None, None
+                monitor_val = None
                 if lr_scheduler["reduce_on_plateau"]:
+                    # If instance of ReduceLROnPlateau, we need a monitor
                     monitor_key = lr_scheduler["monitor"]
-                    monitor_val = self.trainer.callback_metrics.get(monitor_key)
+                    monitor_val = self._get_monitor_value(monitor_key)
                     if monitor_val is None:
                         if lr_scheduler.get("strict", True):
                             avail_metrics = list(self.trainer.callback_metrics)
@@ -79,27 +79,17 @@ def update_learning_rates(
                             RuntimeWarning,
                         )
                         continue
 
-                # update LR
-                old_lr = lr_scheduler["scheduler"].optimizer.param_groups[0]["lr"]
-
                 self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
 
+                # update LR
                 if lr_scheduler["reduce_on_plateau"]:
                     lr_scheduler["scheduler"].step(monitor_val)
                 else:
                     lr_scheduler["scheduler"].step()
-                new_lr = lr_scheduler["scheduler"].optimizer.param_groups[0]["lr"]
 
                 self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
 
-                if self.trainer.dev_debugger.enabled:
-                    self.trainer.dev_debugger.track_lr_schedulers_update(
-                        self.trainer.fit_loop.batch_idx,
-                        interval,
-                        scheduler_idx,
-                        old_lr,
-                        new_lr,
-                        monitor_key=monitor_key,
-                        monitor_val=monitor_val,
-                    )
+    def _get_monitor_value(self, key: str) -> Any:
+        # this is a separate method to aid in testing
+        return self.trainer.callback_metrics.get(key)
diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py
index 3860ff0ac005f..978300d44f398 100644
--- a/pytorch_lightning/utilities/debugging.py
+++ b/pytorch_lightning/utilities/debugging.py
@@ -15,9 +15,8 @@
 import os
 import time
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional
 
-import torch
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
@@ -44,7 +43,6 @@ def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
         self.events: List[Dict[str, Any]] = []
-        self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
         self.train_dataloader_calls: List[Dict[str, Any]] = []
         self.val_dataloader_calls: List[Dict[str, Any]] = []
         self.test_dataloader_calls: List[Dict[str, Any]] = []
@@ -100,26 +98,3 @@ def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None:
             self.val_dataloader_calls.append(values)
         elif "test" in name:
             self.test_dataloader_calls.append(values)
-
-    @enabled_only
-    def track_lr_schedulers_update(
-        self,
-        batch_idx: int,
-        interval: int,
-        scheduler_idx: int,
-        old_lr: float,
-        new_lr: float,
-        monitor_key: Optional[str] = None,
-        monitor_val: Optional[torch.Tensor] = None,
-    ) -> None:
-        loss_dict = {
-            "batch_idx": batch_idx,
-            "interval": interval,
-            "scheduler_idx": scheduler_idx,
-            "epoch": self.trainer.current_epoch,
-            "monitor_key": monitor_key,
-            "monitor_val": monitor_val,
-            "old_lr": old_lr,
-            "new_lr": new_lr,
-        }
-        self.saved_lr_scheduler_updates.append(loss_dict)
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 0c1a6fbd51268..ea2845ff8ce7e 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -63,7 +63,20 @@ def validation_epoch_end(self, outputs):
         self.log("val_acc", outs)
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+def mock_optimizer_connector(trainer):
+    # do not use `unittest.Mock` because we need to store the return value
+    calls = {}
+    old_get_monitor_value = trainer.optimizer_connector._get_monitor_value
+
+    def mock(key):
+        value = old_get_monitor_value(key)
+        calls[trainer.current_epoch] = {key: value}
+        return value
+
+    trainer.optimizer_connector._get_monitor_value = mock
+    return calls
+
+
 @pytest.mark.parametrize(
     "validation_step_none,val_dataloaders_none,monitor",
     [(False, False, "val_log"), (True, False, "train_log_epoch"), (False, True, "val_log")],
@@ -137,13 +150,11 @@ def on_validation_epoch_end(self):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
+    calls = mock_optimizer_connector(trainer)
     trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
 
     ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
-    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(model.scores) == max_epochs
-    assert len(lr_scheduler_debug) == max_epochs
 
     for epoch in range(max_epochs):
         score = model.scores[epoch]
@@ -169,12 +180,10 @@ def on_validation_epoch_end(self):
             # checkpoint is saved after updating lr_scheduler states
             assert actual_step_count == epoch + 2  # step_count starts at 1
             assert actual_lr == lr * gamma ** (epoch + 1)
-
-        assert lr_scheduler_debug[epoch]["monitor_val"] == (score if reduce_lr_on_plateau else None)
-        assert lr_scheduler_debug[epoch]["monitor_key"] == (monitor if reduce_lr_on_plateau else None)
+        else:
+            assert calls[epoch] == {monitor: score}
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize(
     "val_check_interval,reduce_lr_on_plateau,epoch_aligned",
     [(0.25, True, True), (0.25, False, True), (0.42, False, False)],
@@ -239,33 +248,23 @@ def configure_optimizers(self):
         progress_bar_refresh_rate=0,
         num_sanity_val_steps=0,
     )
+    calls = mock_optimizer_connector(trainer)
     trainer.fit(model)
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
 
-    ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
-    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
-
-    assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs
-    assert len(lr_scheduler_debug) == max_epochs
-
-    def _make_assertions(epoch, ix, version=""):
+    def _make_assertions(epoch, ix):
         global_ix = ix + per_epoch_val_checks * epoch
-        duplicated = bool(version)
 
         # checkpoint saved at the end of training epoch will have updated lr_scheduler states
-        epoch_end_checkpoint = duplicated
-        if epoch_aligned:
-            epoch_end_checkpoint = ix == (per_epoch_val_checks - 1)
+        epoch_end_checkpoint = epoch_aligned and ix == (per_epoch_val_checks - 1)
 
         score = model.scores[global_ix]
         expected_score = getattr(model, f"{monitor}s")[global_ix].mean().item()
-        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}{version}.ckpt"
+        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
         assert math.isclose(score, expected_score, rel_tol=1e-4)
 
         chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
         assert chk["epoch"] == epoch + 1
-        epoch_num = epoch + duplicated
-        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
+        expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch)
         assert chk["global_step"] == expected_global_step
 
         mc_specific_data = chk["callbacks"][
@@ -284,12 +283,15 @@ def _make_assertions(epoch, ix, version=""):
 
         return score
 
+    ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
+    assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs
+
     for epoch in range(max_epochs):
         for i in range(per_epoch_val_checks):
             score = _make_assertions(epoch, i)
 
-        assert lr_scheduler_debug[epoch]["monitor_val"] == (score if reduce_lr_on_plateau else None)
-        assert lr_scheduler_debug[epoch]["monitor_key"] == (monitor if reduce_lr_on_plateau else None)
+        if reduce_lr_on_plateau:
+            assert calls[epoch] == {monitor: score}
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
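Note on the connector change: the `step(monitor_val)` branch exists because `ReduceLROnPlateau` is the torch scheduler whose `step()` consumes a monitored metric, and the new `_get_monitor_value` method only isolates that metric lookup so tests can wrap it, as `mock_optimizer_connector` does. A minimal sketch of the call pattern in plain torch, outside Lightning (the model, optimizer settings, and fake `val_loss` here are illustrative, not part of the patch):

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # ReduceLROnPlateau is stepped with a metric; the other built-in schedulers take no argument
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

    for epoch in range(10):
        val_loss = 1.0 / (epoch + 1)  # stand-in for the monitored callback metric
        scheduler.step(val_loss)      # mirrors lr_scheduler["scheduler"].step(monitor_val)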