From a24e3099d1645772fc606c7c9ffdf0b215f8f23f Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 15:31:12 +0200
Subject: [PATCH 01/12] Find last checkpoints on restart

---
 src/pytorch_lightning/callbacks/model_checkpoint.py  | 7 ++++++-
 .../trainer/connectors/checkpoint_connector.py       | 7 ++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index a789b95a4407d..c0cf270b63847 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -25,7 +25,7 @@
 import warnings
 from copy import deepcopy
 from datetime import timedelta
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 from weakref import proxy
 
 import numpy as np
@@ -604,6 +604,11 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
 
         self.dirpath = ckpt_path
 
+    def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
+        # find all checkpoints in the folder
+        self.__resolve_ckpt_dir(trainer)
+        return [str(p) for p in self._fs.ls(self.dirpath) if self.CHECKPOINT_NAME_LAST in str(p)]
+
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 0f77927c29915..fd987986d6cf5 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -161,9 +161,10 @@ def _set_ckpt_path(
             ckpt_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)
 
         if ckpt_path == "last":
-            candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [
-                getattr(cb, "last_model_path", None) for cb in self.trainer.checkpoint_callbacks
-            ]
+            candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints]
+            for callback in self.trainer.checkpoint_callbacks:
+                if hasattr(callback, "_find_last_checkpoints"):
+                    candidates += callback._find_last_checkpoints(self.trainer)
             candidates_fs = {path: get_filesystem(path) for path in candidates if path}
             candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
             if not candidates_ts:

From 490980347736079ccac9cf23c771d213c2ee8812 Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 15:39:02 +0200
Subject: [PATCH 02/12] changelog

---
 src/pytorch_lightning/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 408efc314709f..637ca12fcf257 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -264,6 +264,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
 
+- Fixed reloading of the last checkpoint on run restart ([#14907](https://github.com/Lightning-AI/lightning/pull/14907))
+
+
 ## [1.7.7] - 2022-09-22
 
 ### Fixed

From d9e0a70e8082b7a203cf9d5702c6a2c1206c39ed Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 16:33:34 +0200
Subject: [PATCH 03/12] check if dir exists

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index c0cf270b63847..4b117145b7582 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -607,7 +607,9 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
     def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
         # find all checkpoints in the folder
         self.__resolve_ckpt_dir(trainer)
-        return [str(p) for p in self._fs.ls(self.dirpath) if self.CHECKPOINT_NAME_LAST in str(p)]
+        if self._fs.exists(self.dirpath):
+            return [str(p) for p in self._fs.ls(self.dirpath) if self.CHECKPOINT_NAME_LAST in str(p)]
+        return []
 
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:

From e57f7aebef49d8c54355f25860eb001559e4b57c Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 16:53:27 +0200
Subject: [PATCH 04/12] add broadcast flag to resolve dir

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index 4b117145b7582..ca9fbe7f18209 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -571,7 +571,7 @@ def format_checkpoint_name(
             ckpt_name = f"{filename}{self.FILE_EXTENSION}"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
 
-    def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
+    def __resolve_ckpt_dir(self, trainer: "pl.Trainer", broadcast: bool = True) -> None:
         """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to
         determine where to save checkpoints.
         The path for saving weights is set in this priority:
@@ -583,6 +583,8 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
         """
         if self.dirpath is not None:
             # short circuit if dirpath was passed to ModelCheckpoint
+            if broadcast:
+                self.dirpath = trainer.strategy.broadcast(self.dirpath)
             return
 
         if len(trainer.loggers) > 0:
@@ -600,13 +602,14 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
             # if no loggers, use default_root_dir
             ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
 
-        ckpt_path = trainer.strategy.broadcast(ckpt_path)
+        if broadcast:
+            ckpt_path = trainer.strategy.broadcast(ckpt_path)
 
         self.dirpath = ckpt_path
 
     def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
         # find all checkpoints in the folder
-        self.__resolve_ckpt_dir(trainer)
+        self.__resolve_ckpt_dir(trainer, broadcast=False)
         if self._fs.exists(self.dirpath):
             return [str(p) for p in self._fs.ls(self.dirpath) if self.CHECKPOINT_NAME_LAST in str(p)]
         return []

From 25e0067687b364950f4a5902b34fa6640e4a714a Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 17:33:58 +0200
Subject: [PATCH 05/12] fix path on windows

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index ca9fbe7f18209..1c6c5a388cbfc 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -611,7 +611,13 @@ def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
         # find all checkpoints in the folder
         self.__resolve_ckpt_dir(trainer, broadcast=False)
         if self._fs.exists(self.dirpath):
-            return [str(p) for p in self._fs.ls(self.dirpath) if self.CHECKPOINT_NAME_LAST in str(p)]
+            # fsspec returns a list of files joined with posixpath.separator, which breaks on windows
+            # That's why we are using `detail=True` and then casting the results to strings manually.
+            return [
+                str(p["name"])
+                for p in self._fs.ls(self.dirpath, detail=True)
+                if self.CHECKPOINT_NAME_LAST in str(p["name"])
+            ]
         return []
 
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:

From 2ae4b644ae187e823072d00aa8d608b33408bc87 Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 18:03:39 +0200
Subject: [PATCH 06/12] fix path on windows second attempt

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index 1c6c5a388cbfc..efd2e9cb1d01e 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -611,12 +611,8 @@ def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
         # find all checkpoints in the folder
         self.__resolve_ckpt_dir(trainer, broadcast=False)
         if self._fs.exists(self.dirpath):
-            # fsspec returns a list of files joined with posixpath.separator, which breaks on windows
-            # That's why we are using `detail=True` and then casting the results to strings manually.
             return [
-                str(p["name"])
-                for p in self._fs.ls(self.dirpath, detail=True)
-                if self.CHECKPOINT_NAME_LAST in str(p["name"])
+                os.path.normpath(p) for p in self._fs.ls(self.dirpath, detail=False) if self.CHECKPOINT_NAME_LAST in p
             ]
         return []
 

From d3fb41fcf03b852df146e51a210b1d7b582f0014 Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 18:21:41 +0200
Subject: [PATCH 07/12] less dangerous broadcast

---
 .../callbacks/model_checkpoint.py | 23 +++++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index efd2e9cb1d01e..bcd048686ce65 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -255,7 +255,9 @@ def state_key(self) -> str:
         )
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        self.__resolve_ckpt_dir(trainer)
+        dirpath = self.__resolve_ckpt_dir(trainer)
+        dirpath = trainer.strategy.broadcast(dirpath)
+        self.dirpath = dirpath
         assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
@@ -571,7 +573,7 @@ def format_checkpoint_name(
             ckpt_name = f"{filename}{self.FILE_EXTENSION}"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
 
-    def __resolve_ckpt_dir(self, trainer: "pl.Trainer", broadcast: bool = True) -> None:
+    def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> str:
         """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to
         determine where to save checkpoints.
         The path for saving weights is set in this priority:
@@ -583,9 +585,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer", broadcast: bool = True) -> None:
         """
         if self.dirpath is not None:
             # short circuit if dirpath was passed to ModelCheckpoint
-            if broadcast:
-                self.dirpath = trainer.strategy.broadcast(self.dirpath)
-            return
+            return self.dirpath
 
         if len(trainer.loggers) > 0:
             if trainer.loggers[0].save_dir is not None:
@@ -602,18 +602,13 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer", broadcast: bool = True) -> None:
             # if no loggers, use default_root_dir
             ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
 
-        if broadcast:
-            ckpt_path = trainer.strategy.broadcast(ckpt_path)
-
-        self.dirpath = ckpt_path
+        return ckpt_path
 
     def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
         # find all checkpoints in the folder
-        self.__resolve_ckpt_dir(trainer, broadcast=False)
-        if self._fs.exists(self.dirpath):
-            return [
-                os.path.normpath(p) for p in self._fs.ls(self.dirpath, detail=False) if self.CHECKPOINT_NAME_LAST in p
-            ]
+        ckpt_path = self.__resolve_ckpt_dir(trainer)
+        if self._fs.exists(ckpt_path):
+            return [os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if self.CHECKPOINT_NAME_LAST in p]
         return []
 
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:

From a3ac6cd3cdea7f9f2e75dc07753f03a3993e02e4 Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 18:37:11 +0200
Subject: [PATCH 08/12] fix mypy

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index bcd048686ce65..81e9d1f2d757a 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,6 +238,7 @@ def __init__(
         self.last_model_path = ""
 
         self.kth_value: Tensor
+        self.dirpath: Optional[_PATH]
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
@@ -573,7 +574,7 @@ def format_checkpoint_name(
             ckpt_name = f"{filename}{self.FILE_EXTENSION}"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
 
-    def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> str:
+    def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
         """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to
         determine where to save checkpoints.
         The path for saving weights is set in this priority:

From edf29762f3ea83a0e0c7f67c45505c70521a9523 Mon Sep 17 00:00:00 2001
From: otaj <6065855+otaj@users.noreply.github.com>
Date: Tue, 27 Sep 2022 19:00:46 +0200
Subject: [PATCH 09/12] Update src/pytorch_lightning/callbacks/model_checkpoint.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index 81e9d1f2d757a..7e8aa0b599a2a 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -259,7 +259,6 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         dirpath = self.__resolve_ckpt_dir(trainer)
         dirpath = trainer.strategy.broadcast(dirpath)
         self.dirpath = dirpath
-        assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
 

From 7207946799a2d5c17f3614aaea96cee18b01901f Mon Sep 17 00:00:00 2001
From: otaj
Date: Tue, 27 Sep 2022 19:08:00 +0200
Subject: [PATCH 10/12] apply suggestions

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 8 ++++----
 .../trainer/connectors/checkpoint_connector.py      | 7 ++++---
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index 7e8aa0b599a2a..efc0b8f5d2990 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -25,7 +25,7 @@
 import warnings
 from copy import deepcopy
 from datetime import timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional, Set
 from weakref import proxy
 
 import numpy as np
@@ -604,12 +604,12 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
 
         return ckpt_path
 
-    def _find_last_checkpoints(self, trainer: "pl.Trainer") -> List[str]:
+    def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
         # find all checkpoints in the folder
         ckpt_path = self.__resolve_ckpt_dir(trainer)
         if self._fs.exists(ckpt_path):
-            return [os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if self.CHECKPOINT_NAME_LAST in p]
-        return []
+            return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if self.CHECKPOINT_NAME_LAST in p}
+        return set()
 
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index fd987986d6cf5..56c089ae24dd6 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -27,6 +27,7 @@
 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.types import _PATH
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
@@ -161,10 +162,10 @@ def _set_ckpt_path(
             ckpt_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)
 
         if ckpt_path == "last":
-            candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints]
+            candidates = {getattr(ft, "ckpt_path", None) for ft in ft_checkpoints}
             for callback in self.trainer.checkpoint_callbacks:
-                if hasattr(callback, "_find_last_checkpoints"):
-                    candidates += callback._find_last_checkpoints(self.trainer)
+                if isinstance(callback, ModelCheckpoint):
+                    candidates |= callback._find_last_checkpoints(self.trainer)
             candidates_fs = {path: get_filesystem(path) for path in candidates if path}
             candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
             if not candidates_ts:

From d3a3b35f3c4aea8543aeeb7bf3ca21371f07fce8 Mon Sep 17 00:00:00 2001
From: otaj
Date: Fri, 30 Sep 2022 14:35:09 +0200
Subject: [PATCH 11/12] added test + small fix

---
 .../callbacks/model_checkpoint.py           |  6 +++++-
 tests/tests_pytorch/trainer/test_trainer.py | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index efc0b8f5d2990..62001d50b1c85 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -608,7 +608,11 @@ def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
         # find all checkpoints in the folder
         ckpt_path = self.__resolve_ckpt_dir(trainer)
         if self._fs.exists(ckpt_path):
-            return {os.path.normpath(p) for p in self._fs.ls(ckpt_path, detail=False) if self.CHECKPOINT_NAME_LAST in p}
+            return {
+                os.path.normpath(p)
+                for p in self._fs.ls(ckpt_path, detail=False)
+                if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1]
+            }
         return set()
 
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 4c14e06bfcbce..ea11668ead979 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -776,6 +776,40 @@ def test_checkpoint_path_input_last(tmpdir, ckpt_path, save_last, fn):
         assert trainer.ckpt_path == final_path
 
 
+def test_checkpoint_find_last(tmpdir):
+    """Test that the last checkpoint is found correctly."""
+    model = BoringModel()
+    mc = ModelCheckpoint(save_last=True)
+    trainer = Trainer(
+        max_epochs=1,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+        callbacks=[mc],
+    )
+    assert trainer.ckpt_path is None
+    trainer.fit(model)
+
+    model = BoringModel()
+    mc = ModelCheckpoint()
+    trainer = Trainer(
+        max_epochs=1,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+        callbacks=[mc],
+    )
+    assert trainer.ckpt_path is None
+    trainer.fit(model, ckpt_path="last")
+    assert trainer.ckpt_path == str(tmpdir / "checkpoints" / "last.ckpt")
+
+
 @pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
 @pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
 @pytest.mark.parametrize("fn", ("validate", "test", "predict"))

From ba3a48c360bf9f8d3da4f8591a74a362e3c29733 Mon Sep 17 00:00:00 2001
From: otaj
Date: Fri, 30 Sep 2022 15:49:50 +0200
Subject: [PATCH 12/12] apply suggestions

---
 tests/tests_pytorch/trainer/test_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index ea11668ead979..4738f71fe2a97 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -783,7 +783,7 @@ def test_checkpoint_find_last(tmpdir):
     trainer = Trainer(
         max_epochs=1,
         limit_train_batches=1,
-        limit_val_batches=1,
+        limit_val_batches=0,
         enable_model_summary=False,
         enable_progress_bar=False,
         logger=False,
@@ -798,7 +798,7 @@ def test_checkpoint_find_last(tmpdir):
     trainer = Trainer(
         max_epochs=1,
         limit_train_batches=1,
-        limit_val_batches=1,
+        limit_val_batches=0,
         enable_model_summary=False,
         enable_progress_bar=False,
         logger=False,
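
Taken together, the series changes how ckpt_path="last" is resolved on restart:
each ModelCheckpoint callback scans its resolved checkpoint directory for files
whose basename contains its CHECKPOINT_NAME_LAST ("last" by default), and the
checkpoint connector resumes from the most recently modified candidate. What
follows is a minimal standalone sketch of that lookup, assuming a local
filesystem: it uses the plain os module instead of fsspec, and the helper names
(find_last_checkpoints, newest_checkpoint) are illustrative, not Lightning API.

    import os
    from typing import Set

    def find_last_checkpoints(ckpt_dir: str, last_name: str = "last") -> Set[str]:
        """Collect checkpoints in ``ckpt_dir`` whose basename contains ``last_name``."""
        if not os.path.isdir(ckpt_dir):
            # mirrors the early exit added in PATCH 03 for a missing directory
            return set()
        return {
            os.path.normpath(os.path.join(ckpt_dir, name))
            # os.listdir yields basenames, so a directory path that happens to
            # contain "last" cannot produce false positives (cf. PATCH 11)
            for name in os.listdir(ckpt_dir)
            if last_name in name
        }

    def newest_checkpoint(candidates: Set[str]) -> str:
        """Mimic the connector: resume from the newest candidate (assumes non-empty set)."""
        return max(candidates, key=os.path.getmtime)

Returning a set (PATCH 10) lets the connector merge these results with the
fault-tolerance candidates via "|=" without duplicates, and matching on the
basename only (PATCH 11) keeps "last" in a parent directory name from being
picked up by mistake.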