Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 9 additions & 19 deletions pytorch_lightning/trainer/connectors/optimizer_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from typing import Any, List, Optional
from weakref import proxy

import pytorch_lightning as pl
Expand Down Expand Up @@ -47,7 +47,7 @@ def update_learning_rates(
if opt_indices is None:
opt_indices = []

for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
for lr_scheduler in self.trainer.lr_schedulers:
if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
continue

Expand All @@ -59,11 +59,11 @@ def update_learning_rates(
# Take step if call to update_learning_rates matches the interval key and
# the current step modulo the schedulers frequency is zero
if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
# If instance of ReduceLROnPlateau, we need a monitor
monitor_key, monitor_val = None, None
monitor_val = None
if lr_scheduler["reduce_on_plateau"]:
# If instance of ReduceLROnPlateau, we need a monitor
monitor_key = lr_scheduler["monitor"]
monitor_val = self.trainer.callback_metrics.get(monitor_key)
monitor_val = self._get_monitor_value(monitor_key)
if monitor_val is None:
if lr_scheduler.get("strict", True):
avail_metrics = list(self.trainer.callback_metrics)
Expand All @@ -79,27 +79,17 @@ def update_learning_rates(
RuntimeWarning,
)
continue
# update LR
old_lr = lr_scheduler["scheduler"].optimizer.param_groups[0]["lr"]

self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()

# update LR
if lr_scheduler["reduce_on_plateau"]:
lr_scheduler["scheduler"].step(monitor_val)
else:
lr_scheduler["scheduler"].step()

new_lr = lr_scheduler["scheduler"].optimizer.param_groups[0]["lr"]

self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.fit_loop.batch_idx,
interval,
scheduler_idx,
old_lr,
new_lr,
monitor_key=monitor_key,
monitor_val=monitor_val,
)
def _get_monitor_value(self, key: str) -> Any:
# this is a separate method to aid in testing
return self.trainer.callback_metrics.get(key)
27 changes: 1 addition & 26 deletions pytorch_lightning/utilities/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import os
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
Expand All @@ -44,7 +43,6 @@ def __init__(self, trainer: "pl.Trainer") -> None:
self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
self.trainer = trainer
self.events: List[Dict[str, Any]] = []
self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
self.train_dataloader_calls: List[Dict[str, Any]] = []
self.val_dataloader_calls: List[Dict[str, Any]] = []
self.test_dataloader_calls: List[Dict[str, Any]] = []
Expand Down Expand Up @@ -100,26 +98,3 @@ def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -
self.val_dataloader_calls.append(values)
elif "test" in name:
self.test_dataloader_calls.append(values)

# NOTE(review): this method is shown as *deleted* by this PR (debugging.py hunk,
# "1 addition & 26 deletions") — it recorded one snapshot per LR-scheduler step
# into `self.saved_lr_scheduler_updates` for test inspection, and is superseded
# by patching `OptimizerConnector._get_monitor_value` in the tests.
# `@enabled_only` presumably gates on the PL_DEV_DEBUG flag set in __init__ —
# confirm against the decorator definition, which is outside this view.
@enabled_only
def track_lr_schedulers_update(
self,
batch_idx: int,
interval: int,
scheduler_idx: int,
old_lr: float,
new_lr: float,
monitor_key: Optional[str] = None,
monitor_val: Optional[torch.Tensor] = None,
) -> None:
# Snapshot everything relevant to this scheduler step; epoch is read
# from the live trainer at call time.
loss_dict = {
"batch_idx": batch_idx,
"interval": interval,
"scheduler_idx": scheduler_idx,
"epoch": self.trainer.current_epoch,
"monitor_key": monitor_key,
"monitor_val": monitor_val,
"old_lr": old_lr,
"new_lr": new_lr,
}
self.saved_lr_scheduler_updates.append(loss_dict)
52 changes: 27 additions & 25 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,20 @@ def validation_epoch_end(self, outputs):
self.log("val_acc", outs)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def mock_optimizer_connector(trainer):
    """Wrap ``trainer.optimizer_connector._get_monitor_value`` to record calls.

    Returns a dict mapping the trainer's current epoch at call time to
    ``{key: returned_value}``.  A hand-rolled wrapper is used instead of
    `unittest.Mock` because the *return value* must be captured as well.
    """
    recorded = {}
    original = trainer.optimizer_connector._get_monitor_value

    def wrapper(key):
        result = original(key)
        recorded[trainer.current_epoch] = {key: result}
        return result

    trainer.optimizer_connector._get_monitor_value = wrapper
    return recorded


@pytest.mark.parametrize(
"validation_step_none,val_dataloaders_none,monitor",
[(False, False, "val_log"), (True, False, "train_log_epoch"), (False, True, "val_log")],
Expand Down Expand Up @@ -137,13 +150,11 @@ def on_validation_epoch_end(self):
max_epochs=max_epochs,
progress_bar_refresh_rate=0,
)
calls = mock_optimizer_connector(trainer)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(model.scores) == max_epochs
assert len(lr_scheduler_debug) == max_epochs

for epoch in range(max_epochs):
score = model.scores[epoch]
Expand All @@ -169,12 +180,10 @@ def on_validation_epoch_end(self):
# checkpoint is saved after updating lr_scheduler states
assert actual_step_count == epoch + 2 # step_count starts at 1
assert actual_lr == lr * gamma ** (epoch + 1)

assert lr_scheduler_debug[epoch]["monitor_val"] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]["monitor_key"] == (monitor if reduce_lr_on_plateau else None)
else:
assert calls[epoch] == {monitor: score}


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
"val_check_interval,reduce_lr_on_plateau,epoch_aligned",
[(0.25, True, True), (0.25, False, True), (0.42, False, False)],
Expand Down Expand Up @@ -239,33 +248,23 @@ def configure_optimizers(self):
progress_bar_refresh_rate=0,
num_sanity_val_steps=0,
)
calls = mock_optimizer_connector(trainer)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates

assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs
assert len(lr_scheduler_debug) == max_epochs

def _make_assertions(epoch, ix, version=""):
def _make_assertions(epoch, ix):
global_ix = ix + per_epoch_val_checks * epoch
duplicated = bool(version)

# checkpoint saved at the end of training epoch will have updated lr_scheduler states
epoch_end_checkpoint = duplicated
if epoch_aligned:
epoch_end_checkpoint = ix == (per_epoch_val_checks - 1)
epoch_end_checkpoint = epoch_aligned and ix == (per_epoch_val_checks - 1)

score = model.scores[global_ix]
expected_score = getattr(model, f"{monitor}s")[global_ix].mean().item()
expected_filename = f"{monitor}={score:.4f}-epoch={epoch}{version}.ckpt"
expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
assert math.isclose(score, expected_score, rel_tol=1e-4)

chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk["epoch"] == epoch + 1
epoch_num = epoch + duplicated
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch)
assert chk["global_step"] == expected_global_step

mc_specific_data = chk["callbacks"][
Expand All @@ -284,12 +283,15 @@ def _make_assertions(epoch, ix, version=""):

return score

ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
assert len(ckpt_files) == len(model.scores) == per_epoch_val_checks * max_epochs

for epoch in range(max_epochs):
for i in range(per_epoch_val_checks):
score = _make_assertions(epoch, i)

assert lr_scheduler_debug[epoch]["monitor_val"] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]["monitor_key"] == (monitor if reduce_lr_on_plateau else None)
if reduce_lr_on_plateau:
assert calls[epoch] == {monitor: score}


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
Expand Down