
Commit 353cf79

Merge branch 'master' into feat/gpu-validation
2 parents ca698e5 + 1203094

File tree: 5 files changed, +48 -12 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


+- Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
+
+
 - Added checks to `GPUAccelerator` to assert CUDA availability at initialization ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))


pytorch_lightning/core/datamodule.py

Lines changed: 17 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 """LightningDataModule for loading DataLoaders with ease."""
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 from torch.utils.data import DataLoader, Dataset, IterableDataset

@@ -246,3 +246,19 @@ def test_dataloader():
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate and save datamodule state.
+
+        Returns:
+            A dictionary containing datamodule state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
+
+        Args:
+            state_dict: the datamodule state returned by ``state_dict``.
+        """
+        pass
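
These two hooks give a datamodule a `state_dict`/`load_state_dict` pair mirroring the familiar `nn.Module` convention: return a dict to persist, and receive that same dict back on restore. A minimal sketch of a user-defined override; the `KFoldDataModule` class and its `current_fold` attribute are hypothetical illustrations, not part of this change:

from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class KFoldDataModule(LightningDataModule):
    """Illustrative datamodule; `current_fold` is a made-up attribute."""

    def __init__(self, current_fold: int = 0) -> None:
        super().__init__()
        self.current_fold = current_fold

    def state_dict(self) -> Dict[str, Any]:
        # Whatever is returned here is written into the checkpoint,
        # keyed by this class's `__qualname__` (see the connector diff below).
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Receives exactly the dict produced by `state_dict` above.
        self.current_fold = state_dict["current_fold"]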

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 13 additions & 4 deletions
@@ -157,6 +157,8 @@ def restore_datamodule(self) -> None:
         datamodule = self.trainer.datamodule
         if datamodule is not None:
             datamodule.on_load_checkpoint(self._loaded_checkpoint)
+            if datamodule.__class__.__qualname__ in self._loaded_checkpoint:
+                datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__qualname__])

     def restore_model(self) -> None:
         """Restores a model's weights from a PyTorch Lightning checkpoint.
@@ -324,7 +326,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 CHECKPOINT_HYPER_PARAMS_KEY:
                 CHECKPOINT_HYPER_PARAMS_TYPE:
                 something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
-                LightningDataModule.__class__.__name__: pl DataModule's state
+                LightningDataModule.__class__.__qualname__: pl DataModule's state
             }
         """

@@ -378,10 +380,17 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         else:
             checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

-        # give the model a chance to dump a few things
+        # dump stateful datamodule
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule_state_dict = datamodule.state_dict()
+            if datamodule_state_dict:
+                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict
+
+        # on_save_checkpoint hooks
         model.on_save_checkpoint(checkpoint)
-        if self.trainer.datamodule is not None:
-            self.trainer.datamodule.on_save_checkpoint(checkpoint)
+        if datamodule is not None:
+            datamodule.on_save_checkpoint(checkpoint)

         # TODO: remove this in v1.8.
         environment = self.trainer._accelerator_connector.cluster_environment
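
The net effect on the checkpoint layout: when `state_dict()` returns a non-empty dict, it is stored under the datamodule class's `__qualname__`, and `restore_datamodule` feeds it back through `load_state_dict` on restore. A rough sketch of inspecting such a checkpoint; the path and the "KFoldDataModule" key are placeholders, not names from this change:

import torch

# "example.ckpt" is a placeholder path; the key below stands in for
# your datamodule class's `__qualname__`.
checkpoint = torch.load("example.ckpt")
dm_state = checkpoint.get("KFoldDataModule", {})
print(dm_state)  # e.g. {"current_fold": 0}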

tests/core/test_datamodules.py

Lines changed: 14 additions & 7 deletions
@@ -196,11 +196,18 @@ def validation_step(self, batch, batch_idx):
             return out

     class CustomBoringDataModule(BoringDataModule):
+        def state_dict(self) -> Dict[str, Any]:
+            return {"my": "state_dict"}
+
+        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+            self.my_state_dict = state_dict
+
         def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            checkpoint[self.__class__.__name__] = self.__class__.__name__
+            checkpoint[self.__class__.__qualname__].update({"on_save": "update"})

         def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-            self.checkpoint_state = checkpoint.get(self.__class__.__name__)
+            self.checkpoint_state = checkpoint.get(self.__class__.__qualname__).copy()
+            checkpoint[self.__class__.__qualname__].pop("on_save")

     reset_seed()
     dm = CustomBoringDataModule()
@@ -220,14 +227,14 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
-    assert dm.__class__.__name__ in checkpoint
-    assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__
+    assert dm.__class__.__qualname__ in checkpoint
+    assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict", "on_save": "update"}

     for trainer_fn in TrainerFn:
         trainer.state.fn = trainer_fn
-        with mock.patch.object(dm, "on_load_checkpoint") as dm_mock:
-            trainer._restore_modules_and_callbacks(checkpoint_path)
-            dm_mock.assert_called_once()
+        trainer._restore_modules_and_callbacks(checkpoint_path)
+        assert dm.checkpoint_state == {"my": "state_dict", "on_save": "update"}
+        assert dm.my_state_dict == {"my": "state_dict"}


 def test_full_loop(tmpdir):

tests/models/test_hooks.py

Lines changed: 1 addition & 0 deletions
@@ -865,6 +865,7 @@ def call(hook, fn, *args, **kwargs):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="val_dataloader"),
         dict(name="train_dataloader"),
+        dict(name="state_dict"),
         dict(name="on_save_checkpoint", args=(ANY,)),
         dict(name="teardown", kwargs=dict(stage="fit")),
     ]
