
Commit f930f27

Merge branch 'master' into feat_wandb_resume

2 parents 792a94a + b4d926b, commit f930f27
11 files changed: +120 −30 lines

CHANGELOG.md
Lines changed: 4 additions & 0 deletions

@@ -27,9 +27,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `LightningOptimizer` exposes optimizer attributes ([#5095](https://github.com/PyTorchLightning/pytorch-lightning/pull/5095))
 
 
+- Fixed the saved filename in `ModelCheckpoint` when it already exists ([#4861](https://github.com/PyTorchLightning/pytorch-lightning/pull/4861))
+
+
 - Do not warn when the `name` key is used in the `lr_scheduler` dict ([#5057](https://github.com/PyTorchLightning/pytorch-lightning/pull/5057))
 
 
+
 ## [1.1.0] - 2020-12-09
 
 ### Added

benchmarks/test_parity.py
Lines changed: 1 addition & 1 deletion

@@ -4,8 +4,8 @@
 import pytest
 import torch
 
+from pytorch_lightning import seed_everything, Trainer
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, seed_everything
 from tests.base.models import ParityModuleMNIST, ParityModuleRNN
 
 

benchmarks/test_sharded_parity.py
Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
 from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE

pyproject.toml
Lines changed: 3 additions & 1 deletion

@@ -16,7 +16,7 @@ exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|buil
 
 [tool.isort]
 known_first_party = [
-    "bencharmks",
+    "benchmarks",
     "docs",
     "pl_examples",
     "pytorch_lightning",
@@ -52,3 +52,5 @@ skip_glob = [
 ]
 profile = "black"
 line_length = 120
+force_sort_within_sections = "True"
+order_by_type = "False"
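
Note: these two isort options account for the import reshuffling in the other files of this commit. As I read isort's documentation, `force_sort_within_sections` interleaves `import x` and `from x import y` statements by module name instead of keeping the plain imports first, and `order_by_type = "False"` compares imported names case-insensitively instead of grouping constants, classes, and functions separately. A sketch of the effect, built from lines that appear in this commit's own diffs:

    # Default isort keeps plain "import x" statements ahead of "from x import y"
    # within a section:
    import sys
    import threading
    from functools import partial, wraps
    from http.server import SimpleHTTPRequestHandler

    # With force_sort_within_sections = "True", both statement forms are sorted
    # together by module name:
    from functools import partial, wraps
    from http.server import SimpleHTTPRequestHandler
    import sys
    import threading

    # With order_by_type = "False", names inside a from-import mix snake_case and
    # CamelCase in case-insensitive alphabetical order:
    from pytorch_lightning import seed_everything, Trainer  # rather than: Trainer, seed_everything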

pytorch_lightning/callbacks/model_checkpoint.py
Lines changed: 28 additions & 21 deletions

@@ -240,17 +240,14 @@ def save_checkpoint(self, trainer, pl_module):
         # what can be monitored
         monitor_candidates = self._monitor_candidates(trainer)
 
-        # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
-
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
+            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
 
         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -444,6 +441,7 @@ def format_checkpoint_name(
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
+
         ckpt_name = f"{filename}{self.FILE_EXTENSION}"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
 
@@ -515,13 +513,20 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)
 
-    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+    def _get_metric_interpolated_filepath_name(
+        self,
+        ckpt_name_metrics: Dict[str, Any],
+        epoch: int,
+        step: int,
+        del_filepath: Optional[str] = None
+    ) -> str:
         filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+
         version_cnt = 0
-        while self._fs.exists(filepath):
+        while self._fs.exists(filepath) and filepath != del_filepath:
             filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
-            # this epoch called before
             version_cnt += 1
+
         return filepath
 
     def _monitor_candidates(self, trainer):
@@ -531,13 +536,11 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics
 
-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
 
-        last_filepath = filepath
-
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
@@ -548,6 +551,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
                 prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
+        else:
+            last_filepath = self._get_metric_interpolated_filepath_name(
+                ckpt_name_metrics, trainer.current_epoch, trainer.global_step
+            )
 
         accelerator_backend = trainer.accelerator_backend
 
@@ -568,7 +575,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
         if self.monitor is None:
             self.best_model_path = self.last_model_path
 
-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
+    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         current = metrics.get(self.monitor)
         epoch = metrics.get("epoch")
         step = metrics.get("step")
@@ -577,7 +584,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
             current = torch.tensor(current, device=pl_module.device)
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
+            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
         elif self.verbose:
             rank_zero_info(
                 f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
@@ -588,25 +595,26 @@ def _is_valid_monitor_key(self, metrics):
 
     def _update_best_and_save(
         self,
-        filepath: str,
         current: torch.Tensor,
         epoch: int,
         step: int,
         trainer,
         pl_module,
+        ckpt_name_metrics
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
-        del_list = []
+        del_filepath = None
         if len(self.best_k_models) == k and k > 0:
-            delpath = self.kth_best_model_path
-            self.best_k_models.pop(self.kth_best_model_path)
-            del_list.append(delpath)
+            del_filepath = self.kth_best_model_path
+            self.best_k_models.pop(del_filepath)
 
         # do not save nan, replace with +/- inf
         if torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
 
+        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
+
         # save the current score
         self.current_score = current
         self.best_k_models[filepath] = current
@@ -630,9 +638,8 @@ def _update_best_and_save(
             )
             self._save_model(filepath, trainer, pl_module)
 
-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if del_filepath is not None and filepath != del_filepath:
+            self._del_model(del_filepath)
 
     def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
         """

pytorch_lightning/setup_tools.py
Lines changed: 2 additions & 2 deletions

@@ -14,12 +14,12 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from typing import Iterable, List
 from urllib.error import HTTPError, URLError
 from urllib.request import Request, urlopen
+import warnings
 
-from pytorch_lightning import PROJECT_ROOT, __homepage__, __version__
+from pytorch_lightning import __homepage__, __version__, PROJECT_ROOT
 
 _PATH_BADGES = os.path.join('.', 'docs', 'source', '_images', 'badges')
 # badge to download

pytorch_lightning/trainer/supporters.py
Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ def __init__(self, window_length: int):
 
     def reset(self) -> None:
         """Empty the accumulator."""
-        self = TensorRunningAccum(self.window_length)
+        self.__init__(self.window_length)
 
     def last(self):
         """Get the last added element."""

tests/checkpointing/test_model_checkpoint.py
Lines changed: 39 additions & 0 deletions

@@ -938,3 +938,42 @@ def __init__(self, hparams):
     else:
         # make sure it's not AttributeDict
         assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
+
+
+@pytest.mark.parametrize('max_epochs', [3, 4])
+@pytest.mark.parametrize(
+    'save_top_k, expected',
+    [
+        (1, ['curr_epoch.ckpt']),
+        (2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
+    ]
+)
+def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
+    """
+    Test that version is added to filename if required and it already exists in dirpath.
+    """
+    model_checkpoint = ModelCheckpoint(
+        dirpath=tmpdir,
+        filename='curr_epoch',
+        save_top_k=save_top_k,
+        monitor='epoch',
+        mode='max',
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[model_checkpoint],
+        max_epochs=max_epochs,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        logger=None,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+    )
+
+    model = BoringModel()
+    trainer.fit(model)
+    ckpt_files = os.listdir(tmpdir)
+    assert set(ckpt_files) == set(expected)
+
+    epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
+    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))

tests/conftest.py
Lines changed: 2 additions & 2 deletions

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
-import threading
 from functools import partial, wraps
 from http.server import SimpleHTTPRequestHandler
+import sys
+import threading
 
 import pytest
 import torch.multiprocessing as mp

tests/test_profiler.py
Lines changed: 1 addition & 1 deletion

@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import os
-import time
 from pathlib import Path
+import time
 
 import numpy as np
 import pytest
