
Commit 7e10f6d

Save the loop progress state by default (#10784)
1 parent fa6d17c commit 7e10f6d

File tree: 11 files changed (+36 / -15 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


+- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784))
+
+
 - Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))

pytorch_lightning/callbacks/timer.py

Lines changed: 7 additions & 0 deletions

@@ -142,6 +142,13 @@ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
     def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.TESTING] = time.monotonic()

+    def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        # this checks the time after the state is reloaded, regardless of the interval.
+        # this is necessary in case we load a state whose timer is already depleted
+        if self._duration is None:
+            return
+        self._check_time_remaining(trainer)
+
     def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if self._interval != Interval.step or self._duration is None:
             return
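
The new `on_fit_start` hook re-checks the elapsed time as soon as the callback state is reloaded, so resuming from a checkpoint whose `Timer` budget is already depleted stops fitting immediately instead of running one extra step (see the updated `test_timer_resume_training` below). A minimal sketch of that scenario, assuming `model` is any `LightningModule` and `best_model_path` points to a checkpoint from a run that exhausted its time budget:

    from datetime import timedelta

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Timer

    # `model` and `best_model_path` are assumed to come from an earlier run
    timer = Timer(duration=timedelta(milliseconds=200))
    trainer = Trainer(callbacks=timer)
    trainer.fit(model, ckpt_path=best_model_path)

    # the reloaded elapsed time already exceeds the budget, so training stops in
    # `on_fit_start` and the global step does not advance past the saved value
    assert timer._offset > 0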

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions

@@ -303,7 +303,7 @@ def log(
         add_dataloader_idx: bool = True,
         batch_size: Optional[int] = None,
         metric_attribute: Optional[str] = None,
-        rank_zero_only: Optional[bool] = None,
+        rank_zero_only: bool = False,
     ) -> None:
         """Log a key, value pair.

@@ -441,7 +441,7 @@ def log_dict(
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
         batch_size: Optional[int] = None,
-        rank_zero_only: Optional[bool] = None,
+        rank_zero_only: bool = False,
     ) -> None:
         """Log a dictionary of values at once.
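
`rank_zero_only` is now a plain `bool` defaulting to `False` rather than `Optional[bool]`, so callers that want a metric logged only on rank 0 must opt in explicitly. A short sketch of the opt-in inside a standard `training_step` (`compute_loss` is a hypothetical helper):

    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            # default behaviour: logged on every rank (rank_zero_only=False)
            self.log("train_loss", loss)
            # opt in to rank-zero-only logging; such metrics are also skipped
            # during cross-rank syncing (see the ResultCollection.sync change below)
            self.log("train_loss_rank0", loss, rank_zero_only=True)
            return loss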

pytorch_lightning/loops/base.py

Lines changed: 11 additions & 2 deletions

@@ -21,6 +21,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BaseProgress
+from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")  # the output type of `run`
@@ -273,9 +274,11 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di

         destination[prefix + "state_dict"] = self.on_save_checkpoint()

+        # do not get the mode from `self.trainer` because it might not have been attached yet
+        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         for k, v in self.__dict__.items():
             key = prefix + k
-            if isinstance(v, BaseProgress):
+            if ft_enabled and isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
             elif isinstance(v, Loop):
                 v.state_dict(destination, key + ".")
@@ -302,6 +305,10 @@ def load_state_dict(
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
+            if key not in state_dict:
+                # no state for this object, maybe we are loading an old checkpoint
+                continue
+
             if isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[key])
             elif (
@@ -330,4 +337,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
                 v.reset(metrics=False)

         self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-        self.restarting = True
+
+        if _FaultTolerantMode.detect_current_mode().is_enabled:
+            self.restarting = True
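
The net effect in `Loop.state_dict` and `_load_from_state_dict`: the loop hierarchy is now always serialized, while `BaseProgress` attributes are only written, and `restarting` is only set on reload, when fault-tolerant training is enabled. A minimal sketch of toggling that mode through the `PL_FAULT_TOLERANT_TRAINING` environment variable, mirroring how the updated tests patch it:

    import os

    # must be set before the mode is detected; the tests below do the same
    # thing with `mock.patch.dict(os.environ, ...)`
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

    from pytorch_lightning.utilities.enums import _FaultTolerantMode

    # with the flag enabled, progress objects are included in `Loop.state_dict()`
    # and `restarting` is set to True when that state is loaded back
    assert _FaultTolerantMode.detect_current_mode().is_enabled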

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 2 deletions

@@ -373,9 +373,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             "global_step": global_step,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
+            "loops": self._get_loops_state_dict(),
         }
-        if _fault_tolerant_training():
-            checkpoint["loops"] = self._get_loops_state_dict()

         if not weights_only:
             # dump callbacks
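
Since `dump_checkpoint` now writes the `loops` entry unconditionally, every checkpoint produced by this version carries the loop state regardless of the fault-tolerant setting (the hook tests below assert the new key). A quick verification sketch, with a hypothetical checkpoint path:

    import torch

    ckpt = torch.load("example.ckpt", map_location="cpu")  # hypothetical checkpoint file
    # "loops" now sits alongside "state_dict", "epoch", "global_step", etc.
    assert "loops" in ckpt and "state_dict" in ckpt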

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 2 deletions

@@ -616,8 +616,8 @@ def cpu(self) -> "ResultCollection":

     def sync(self) -> None:
         for result_metric in self.result_metrics:
-            if result_metric.is_tensor:
-                result_metric.sync()
+            if result_metric.is_tensor and not result_metric._is_synced:
+                result_metric.sync(should_sync=not result_metric.meta.sync.rank_zero_only)

     def unsync(self) -> None:
         for result_metric in self.result_metrics:

tests/callbacks/test_timer.py

Lines changed: 3 additions & 6 deletions

@@ -168,15 +168,12 @@ def test_timer_resume_training(tmpdir):
     assert trainer.current_epoch < 99
     saved_global_step = trainer.global_step

-    # resume training (with depleted timer
+    # resume training (with depleted timer)
     timer = Timer(duration=timedelta(milliseconds=200))
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[timer, checkpoint_callback],
-    )
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=timer)
     trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path)
     assert timer._offset > 0
-    assert trainer.global_step == saved_global_step + 1
+    assert trainer.global_step == saved_global_step


 @RunIf(skip_windows=True)

tests/loops/test_loop_state_dict.py

Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+from unittest import mock
 from unittest.mock import Mock

 import pytest
@@ -37,6 +39,7 @@ def test_loops_state_dict():
     assert fit_loop.state_dict() == new_fit_loop.state_dict()


+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 def test_loops_state_dict_structure():
     trainer = Trainer()
     trainer.train_dataloader = Mock()

tests/loops/test_loops.py

Lines changed: 1 addition & 0 deletions

@@ -213,6 +213,7 @@ def load_state_dict(self, state_dict: Dict) -> None:
     assert loop.outputs == list(range(10))


+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 def test_loop_hierarchy():
     @dataclass
     class SimpleProgress(BaseProgress):

tests/models/test_hooks.py

Lines changed: 2 additions & 0 deletions

@@ -497,6 +497,7 @@ def training_step(self, batch, batch_idx):
         "optimizer_states": ANY,
         "pytorch-lightning_version": __version__,
         "state_dict": ANY,
+        "loops": ANY,
     }
     if kwargs.get("amp_backend") == "native":
         saved_ckpt["native_amp_scaling_state"] = ANY
@@ -624,6 +625,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         "optimizer_states": ANY,
         "pytorch-lightning_version": __version__,
         "state_dict": ANY,
+        "loops": ANY,
     }
     # TODO: wrong saved epoch, should be 0
     saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2}
