
Commit ec6b0e3

Merge branch 'master' into fix/4237-auc-unstable-reorder
2 parents: d1f2880 + 3abfec8

File tree

4 files changed: +179 / -4 lines


pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 15 additions & 2 deletions

@@ -49,6 +49,9 @@ class CheckpointConnector:
     def __init__(self, trainer):
         self.trainer = trainer

+        # used to validate checkpointing logic
+        self.has_trained = False
+
     def restore_weights(self, model: LightningModule):
         """
         We attempt to restore weights in this order:

@@ -246,9 +249,19 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         Return:
             structured dictionary
         """
+
+        current_epoch = self.trainer.current_epoch
+        global_step = self.trainer.global_step
+        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step
+
+        global_step += 1
+        if self.has_trained:
+            if not has_reached_max_steps:
+                current_epoch += 1
+
         checkpoint = {
-            'epoch': self.trainer.current_epoch + 1,
-            'global_step': self.trainer.global_step + 1,
+            'epoch': current_epoch,
+            'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
         }
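The behavioral change is easier to see outside the diff: dump_checkpoint still bumps the stored global_step by one (so a restored run resumes on the step after the saved one), but it now only advances the epoch counter when training actually ran and max_steps has not been reached. Below is a minimal, self-contained sketch of that logic; the function and argument names are hypothetical stand-ins for the real CheckpointConnector state.

# Hypothetical standalone sketch of the new dump_checkpoint counter logic;
# a plain function stands in for the real CheckpointConnector state.
def checkpoint_counters(current_epoch, global_step, has_trained, max_steps=None):
    has_reached_max_steps = max_steps and max_steps <= global_step
    global_step += 1                       # resume on the step after the saved one
    if has_trained and not has_reached_max_steps:
        current_epoch += 1                 # only advance the epoch if a step really ran
    return {'epoch': current_epoch, 'global_step': global_step}

# Previously both counters were incremented unconditionally; now the epoch
# stays put when no training happened, e.g. a bare .test() after a restore:
assert checkpoint_counters(0, 2, has_trained=False) == {'epoch': 0, 'global_step': 3}
assert checkpoint_counters(0, 2, has_trained=True) == {'epoch': 1, 'global_step': 3}
assert checkpoint_counters(0, 2, has_trained=True, max_steps=2) == {'epoch': 0, 'global_step': 3}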

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 0 deletions

@@ -460,6 +460,8 @@ def fit(
     def train(self):
         self.run_sanity_check(self.get_model())

+        self.checkpoint_connector.has_trained = False
+
         # enable train mode
         model = self.get_model()
         model.train()

pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 0 deletions

@@ -535,6 +535,7 @@ def run_training_epoch(self):
         dataloader_idx = 0
         should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
+
             self.trainer.batch_idx = batch_idx

             # ------------------------------------

@@ -602,6 +603,8 @@ def run_training_epoch(self):
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

+        self.trainer.checkpoint_connector.has_trained = True
+
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
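Together with the trainer.py change above, this gives has_trained a simple lifecycle: cleared at the start of every train() call, set once a training batch has actually completed, and read by dump_checkpoint when a checkpoint is written. A self-contained sketch of that lifecycle; the class and function names here are hypothetical, not the real Trainer control flow:

# Hypothetical sketch of the has_trained lifecycle across one training run.
class ConnectorSketch:
    def __init__(self):
        self.has_trained = False

def run_training(connector, num_batches):
    connector.has_trained = False       # Trainer.train(): reset on entry
    for _ in range(num_batches):
        # ... training_step / optimizer step would happen here ...
        connector.has_trained = True    # run_training_epoch(): set once a step ran
    return connector.has_trained

connector = ConnectorSketch()
assert run_training(connector, num_batches=0) is False  # no step fired, flag stays down
assert run_training(connector, num_batches=2) is True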

tests/checkpointing/test_model_checkpoint.py

Lines changed: 159 additions & 2 deletions

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import os.path as osp
+import pytorch_lightning as pl
 from distutils.version import LooseVersion
 from unittest.mock import MagicMock, Mock

@@ -30,6 +32,7 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate, BoringModel
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -472,7 +475,8 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
     model.validation_step = None
     trainer = Trainer(
         default_root_dir=tmpdir,
-        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last),
+        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir,
+                                            save_top_k=0, save_last=save_last),
         max_epochs=max_epochs,
     )
     trainer.fit(model)

@@ -542,7 +546,7 @@ def validation_epoch_end(self, outputs):
     assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1


-def test_checkpoint_within_callbacks_list(tmpdir):
+def test_checkpoint_repeated_strategy(tmpdir):
     """
     This test validates that the checkpoint can be called when provided to the callbacks list
     """

@@ -572,6 +576,159 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)
     assert os.listdir(tmpdir) == ['epoch=00.ckpt']

+    def get_last_checkpoint():
+        ckpts = os.listdir(tmpdir)
+        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
+        num_ckpts = len(ckpts_map) - 1
+        return ckpts_map[num_ckpts]
+
+    for idx in range(1, 5):
+        # load from checkpoint
+        chk = get_last_checkpoint()
+        model = BoringModel.load_from_checkpoint(chk)
+        trainer = pl.Trainer(max_epochs=1,
+                             limit_train_batches=2,
+                             limit_val_batches=2,
+                             limit_test_batches=2,
+                             resume_from_checkpoint=chk)
+        trainer.fit(model)
+        trainer.test(model)
+
+        assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"
+
+
+def test_checkpoint_repeated_strategy_tmpdir(tmpdir):
+    """
+    This test validates that the checkpoint can be called when provided to the callbacks list
+    """
+
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))
+
+    class ExtendedBoringModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"val_loss": loss}
+
+    model = ExtendedBoringModel()
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        callbacks=[checkpoint_callback])
+
+    trainer.fit(model)
+    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
+    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
+    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])
+
+    def get_last_checkpoint():
+        ckpts = os.listdir(tmpdir)
+        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
+        num_ckpts = len(ckpts_map) - 1
+        return ckpts_map[num_ckpts]
+
+    for idx in range(1, 5):
+
+        # load from checkpoint
+        chk = get_last_checkpoint()
+        model = BoringModel.load_from_checkpoint(chk)
+        trainer = pl.Trainer(default_root_dir=tmpdir,
+                             max_epochs=1,
+                             limit_train_batches=2,
+                             limit_val_batches=2,
+                             limit_test_batches=2,
+                             resume_from_checkpoint=chk)
+
+        trainer.fit(model)
+        trainer.test(model)
+        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
+        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
+
+
+def test_checkpoint_repeated_strategy_extended(tmpdir):
+    """
+    This test validates that checkpointing can be called several times without
+    increasing its internal global step if no training has run.
+    """
+
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class ExtendedBoringModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"val_loss": loss}
+
+    model = ExtendedBoringModel()
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    trainer = pl.Trainer(default_root_dir=tmpdir,
+                         max_epochs=1,
+                         limit_train_batches=2,
+                         limit_val_batches=2,
+                         limit_test_batches=2,
+                         )
+
+    assert trainer.checkpoint_connector.has_trained is not True
+    assert trainer.current_epoch == 0
+    trainer.fit(model)
+    assert trainer.checkpoint_connector.has_trained is True
+    assert trainer.global_step == 2
+    assert trainer.current_epoch == 0
+    trainer.test(model)
+    assert trainer.current_epoch == 0
+    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"
+
+    def get_last_checkpoint():
+        logs_dir = osp.join(tmpdir, 'lightning_logs')
+        versions = os.listdir(logs_dir)
+        versions.sort()
+
+        last_version = versions[-1]
+        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
+
+        ckpts = os.listdir(ckpt_dir)
+        ckpts.sort()
+
+        return osp.join(ckpt_dir, ckpts[-1])
+
+    def assert_checkpoint_content():
+        chk = pl_load(get_last_checkpoint())
+        assert chk["epoch"] == 1
+        assert chk["global_step"] == 2
+
+    assert_checkpoint_content()
+
+    for idx in range(1, 5):
+        # load from checkpoint
+        chk = get_last_checkpoint()
+        assert_checkpoint_content()
+        model = BoringModel.load_from_checkpoint(chk)
+        trainer = pl.Trainer(default_root_dir=tmpdir,
+                             max_epochs=1,
+                             limit_train_batches=2,
+                             limit_val_batches=2,
+                             limit_test_batches=2,
+                             resume_from_checkpoint=chk)
+        assert trainer.checkpoint_connector.has_trained is not True
+        assert trainer.global_step == 0
+        trainer.test(model)
+        assert trainer.global_step == 2
+        trainer.fit(model)
+        assert trainer.global_step == 2
+        assert trainer.checkpoint_connector.has_trained is not True
+        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
+        assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)]


 @pytest.mark.parametrize(
     'filepath, dirpath, filename',
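The new tests inspect saved checkpoints through pl_load (pytorch_lightning.utilities.cloud_io.load), a small wrapper around torch.load that also accepts remote/fsspec paths. A minimal illustration of that inspection pattern; the file name here is arbitrary:

import torch
from pytorch_lightning.utilities.cloud_io import load as pl_load

# Save a toy checkpoint dict, then read the counters back the way
# assert_checkpoint_content() does in the test above.
torch.save({"epoch": 1, "global_step": 2}, "tiny.ckpt")
ckpt = pl_load("tiny.ckpt")
assert ckpt["epoch"] == 1 and ckpt["global_step"] == 2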
