|
24 | 24 | from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types |
25 | 25 |
|
26 | 26 |
|
27 | | -class _DataModuleWrapper(type): |
28 | | - |
29 | | - def __init__(self, *args: Any, **kwargs: Any) -> None: |
30 | | - super().__init__(*args, **kwargs) |
31 | | - self.__has_added_checks = False |
32 | | - |
33 | | - def __call__(cls, *args, **kwargs): |
34 | | - """A wrapper for LightningDataModule that: |
35 | | -
|
36 | | - 1. Runs user defined subclass's __init__ |
37 | | - 2. Assures prepare_data() runs on rank 0 |
38 | | - 3. Lets you check prepare_data and setup to see if they've been called |
39 | | - """ |
40 | | - if not cls.__has_added_checks: |
41 | | - cls.__has_added_checks = True |
42 | | - # Track prepare_data calls and make sure it runs on rank zero |
43 | | - cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) |
44 | | - # Track setup calls |
45 | | - cls.setup = track_data_hook_calls(cls.setup) |
46 | | - # Track teardown calls |
47 | | - cls.teardown = track_data_hook_calls(cls.teardown) |
48 | | - |
49 | | - # Get instance of LightningDataModule by mocking its __init__ via __call__ |
50 | | - obj = type.__call__(cls, *args, **kwargs) |
51 | | - |
52 | | - return obj |
53 | | - |
54 | | - |
55 | | -def track_data_hook_calls(fn): |
56 | | - """A decorator that checks if prepare_data/setup/teardown has been called. |
57 | | -
|
58 | | - - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True |
59 | | - - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True |
60 | | - - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. |
61 | | - Its corresponding `dm_has_setup_{stage}` attribute gets set to True |
62 | | - - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` |
63 | | -
|
64 | | - Args: |
65 | | - fn (function): Function that will be tracked to see if it has been called. |
66 | | -
|
67 | | - Returns: |
68 | | - function: Decorated function that tracks its call status and saves it to private attrs in its obj instance. |
69 | | - """ |
70 | | - |
71 | | - @functools.wraps(fn) |
72 | | - def wrapped_fn(*args, **kwargs): |
73 | | - |
74 | | - # The object instance from which setup or prepare_data was called |
75 | | - obj = args[0] |
76 | | - name = fn.__name__ |
77 | | - |
78 | | - # If calling setup, we check the stage and assign stage-specific bool args |
79 | | - if name in ("setup", "teardown"): |
80 | | - |
81 | | - # Get stage either by grabbing from args or checking kwargs. |
82 | | - # If not provided, set call status of 'fit', 'validate', and 'test' to True. |
83 | | - # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() |
84 | | - stage = args[1] if len(args) > 1 else kwargs.get("stage", None) |
85 | | - |
86 | | - if stage is None: |
87 | | - for s in ("fit", "validate", "test"): |
88 | | - setattr(obj, f"_has_{name}_{s}", True) |
89 | | - else: |
90 | | - setattr(obj, f"_has_{name}_{stage}", True) |
91 | | - |
92 | | - elif name == "prepare_data": |
93 | | - obj._has_prepared_data = True |
94 | | - |
95 | | - return fn(*args, **kwargs) |
96 | | - |
97 | | - return wrapped_fn |
98 | | - |
99 | | - |
100 | | -class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper): |
| 27 | +class LightningDataModule(CheckpointHooks, DataHooks): |
101 | 28 | """ |
102 | 29 | A DataModule standardizes the training, val, test splits, data preparation and transforms. |
103 | 30 | The main advantage is consistent data splits, data preparation and transforms across models. |
@@ -398,3 +325,62 @@ def test_dataloader(): |
398 | 325 | if test_dataset is not None: |
399 | 326 | datamodule.test_dataloader = test_dataloader |
400 | 327 | return datamodule |
| 328 | + |
| 329 | + def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule': |
| 330 | + obj = super().__new__(cls) |
| 331 | + # track `DataHooks` calls and run `prepare_data` only on rank zero |
| 332 | + obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data)) |
| 333 | + obj.setup = cls._track_data_hook_calls(obj, obj.setup) |
| 334 | + obj.teardown = cls._track_data_hook_calls(obj, obj.teardown) |
| 335 | + return obj |
| 336 | + |
| 337 | + @staticmethod |
| 338 | + def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable: |
| 339 | + """A decorator that checks if prepare_data/setup/teardown has been called. |
| 340 | +
|
| 341 | + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True |
| 342 | + - When ``dm.setup()`` is called, ``dm.has_setup_{fit,validate,test}`` get set to True |
| 343 | + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``, |
| 344 | + its corresponding ``dm.has_setup_{stage}`` attribute gets set to True |
| 345 | + - ``dm.teardown()`` and ``dm.teardown(stage)`` are tracked the same way as ``dm.setup``, but set the ``_has_teardown_{stage}`` flags |
| 346 | +
|
| 347 | + Args: |
| 348 | + obj: Object whose function will be tracked |
| 349 | + fn: Function that will be tracked to see if it has been called. |
| 350 | +
|
| 351 | + Returns: |
| 352 | + Decorated function that tracks its call status and saves it to private attrs in its obj instance. |
| 353 | + """ |
| 354 | + |
| 355 | + @functools.wraps(fn) |
| 356 | + def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: |
| 357 | + name = fn.__name__ |
| 358 | + |
| 359 | + # If calling setup, we check the stage and assign stage-specific bool args |
| 360 | + if name in ("setup", "teardown"): |
| 361 | + |
| 362 | + # Get stage either by grabbing from args or checking kwargs. |
| 363 | + # If not provided, set call status of 'fit', 'validate', and 'test' to True. |
| 364 | + # We do this so __attach_datamodule in trainer.py doesn't mistakenly call |
| 365 | + # setup('test') on trainer.test() |
| 366 | + stage = args[0] if len(args) else kwargs.get("stage", None) |
| 367 | + |
| 368 | + if stage is None: |
| 369 | + for s in ("fit", "validate", "test"): |
| 370 | + setattr(obj, f"_has_{name}_{s}", True) |
| 371 | + else: |
| 372 | + setattr(obj, f"_has_{name}_{stage}", True) |
| 373 | + |
| 374 | + elif name == "prepare_data": |
| 375 | + obj._has_prepared_data = True |
| 376 | + |
| 377 | + return fn(*args, **kwargs) |
| 378 | + |
| 379 | + return wrapped_fn |
| 380 | + |
| 381 | + def __getstate__(self) -> dict: |
| 382 | + # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...> |
| 383 | + d = self.__dict__.copy() |
| 384 | + for fn in ("prepare_data", "setup", "teardown"): |
| 385 | + del d[fn] |
| 386 | + return d |
0 commit comments