
Commit 2c16f1d

awaelchli authored
remove dataloader patching on the LightningModule (#9764)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 6701526 commit 2c16f1d

File tree: 19 files changed (+198, -125 lines)

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -513,6 +513,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `process_idx` from the `{DDPSpawnPlugin,TPUSpawnPlugin}.new_process` methods ([#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
 
 
+- Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
+
+
 ### Fixed
 
 
@@ -594,6 +597,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 
+- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
+
+
 ## [1.4.9] - 2021-09-30
 
 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
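
This removal is user-visible: dataloaders attached during `fit` no longer linger on the `LightningModule`, so later `test`/`predict` calls must receive their data again, which is exactly what the example-script changes below do. A minimal sketch of the new calling convention (`model` and `dm` are hypothetical placeholders):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)

# Before #9764, `fit` patched `model.*_dataloader` with the datamodule's hooks,
# so a bare `trainer.test(ckpt_path="best")` could silently reuse them.
# With patching removed, pass the datamodule (or explicit loaders) again:
trainer.test(ckpt_path="best", datamodule=dm)
predictions = trainer.predict(ckpt_path="best", datamodule=dm)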

pl_examples/basic_examples/autoencoder.py

Lines changed: 2 additions & 2 deletions
@@ -181,8 +181,8 @@ def cli_main():
         trainer_defaults={"callbacks": ImageSampler(), "max_epochs": 10},
     )
     cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
-    predictions = cli.trainer.predict(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
+    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
     print(predictions[0])
 
 

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 2 additions & 2 deletions
@@ -124,8 +124,8 @@ def predict_dataloader(self):
 def cli_main():
     cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
     cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
-    predictions = cli.trainer.predict(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
+    predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
     print(predictions[0])
 
 

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ def cli_main():
 
     cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
     cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
 
 
 if __name__ == "__main__":

pl_examples/basic_examples/simple_image_classifier.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def cli_main():
         LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
     )
     cli.trainer.fit(cli.model, datamodule=cli.datamodule)
-    cli.trainer.test(ckpt_path="best")
+    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
 
 
 if __name__ == "__main__":

pytorch_lightning/plugins/plugins_registry.py

Lines changed: 1 addition & 5 deletions
@@ -127,11 +127,7 @@ def is_register_plugins_overridden(plugin: type) -> bool:
     else:
         return False
 
-    if hasattr(plugin_attr, "patch_loader_code"):
-        is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__)
-    else:
-        is_overridden = plugin_attr.__code__ is not super_attr.__code__
-    return is_overridden
+    return plugin_attr.__code__ is not super_attr.__code__
 
 
 def call_training_type_register_plugins(root: Path, base_module: str) -> None:

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 3 additions & 2 deletions
@@ -628,8 +628,9 @@ def _auto_select_batch_size(self):
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # by default we try to use the batch size of the loader
         batch_size = 1
-        if hasattr(self.lightning_module, "train_dataloader"):
-            train_dataloader = self.lightning_module.train_dataloader()
+        train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source
+        if train_dl_source.is_defined():
+            train_dataloader = train_dl_source.dataloader()
             if hasattr(train_dataloader, "batch_sampler"):
                 batch_size = train_dataloader.batch_sampler.batch_size
         return batch_size
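
A hedged note on why the check changed: every `LightningModule` exposes a `train_dataloader` attribute, so the old `hasattr` test was effectively always true, and once patching is gone, calling the hook directly would miss loaders passed straight to `Trainer.fit` or supplied by a datamodule. A sketch of the new resolution path (names as in the diff above; an attached trainer is assumed):

source = self.lightning_module.trainer.data_connector._train_dataloader_source
batch_size = 1
if source.is_defined():  # False only when the module hook was never overridden
    loader = source.dataloader()  # resolves a module hook, datamodule hook, or raw loader
    if hasattr(loader, "batch_sampler"):
        batch_size = loader.batch_sampler.batch_size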

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 12 additions & 14 deletions
@@ -20,14 +20,13 @@
 
 import torch
 import torch.multiprocessing as mp
-from torch.nn import Module
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
+from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import (
     _OMEGACONF_AVAILABLE,
@@ -96,19 +95,18 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
                 )
 
     @staticmethod
-    def _validate_patched_dataloaders(model: Module) -> None:
+    def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         """Validate and fail fast if the dataloaders were passed directly to fit."""
-        if hasattr(model, "train_dataloader") and isinstance(model.train_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)
-
-        if hasattr(model, "val_dataloader") and isinstance(model.val_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)
-
-        if hasattr(model, "test_dataloader") and isinstance(model.test_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)
-
-        if hasattr(model, "predict_dataloader") and isinstance(model.predict_dataloader, _PatchDataLoader):
-            TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)
+        connector: DataConnector = model.trainer.data_connector
+        sources = (
+            connector._train_dataloader_source,
+            connector._val_dataloader_source,
+            connector._test_dataloader_source,
+            connector._predict_dataloader_source,
+        )
+        for source in sources:
+            if not source.is_module():
+                TPUSpawnPlugin._validate_dataloader(source.instance)
 
     def connect(self, model: "pl.LightningModule") -> None:
         TPUSpawnPlugin._validate_patched_dataloaders(model)

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
     elif trainer.state.fn == TrainerFn.TESTING:
         __verify_eval_loop_configuration(model, "test")
     elif trainer.state.fn == TrainerFn.PREDICTING:
-        __verify_predict_loop_configuration(model)
+        __verify_predict_loop_configuration(trainer, model)
     __verify_dp_batch_transfer_support(trainer, model)
     _check_add_get_queue(model)
     # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
@@ -65,7 +65,7 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
     # -----------------------------------
     # verify model has a train dataloader
     # -----------------------------------
-    has_train_dataloader = is_overridden("train_dataloader", model)
+    has_train_dataloader = trainer.data_connector._train_dataloader_source.is_defined()
     if not has_train_dataloader:
         raise MisconfigurationException(
             "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
@@ -175,8 +175,8 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) ->
         )
 
 
-def __verify_predict_loop_configuration(model: "pl.LightningModule") -> None:
-    has_predict_dataloader = is_overridden("predict_dataloader", model)
+def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined()
     if not has_predict_dataloader:
         raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
     # ----------------------------------------------
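
The switch away from `is_overridden` matters because, without patching, a loader passed directly to `fit` leaves the model's hook untouched. A hedged illustration of the two checks (`model` and `train_dl` are hypothetical; the calls are conceptual, not a runnable script):

trainer.fit(model, train_dataloaders=train_dl)

is_overridden("train_dataloader", model)
# -> False: the hook is no longer patched, so the old check would
#    wrongly report a missing dataloader.

trainer.data_connector._train_dataloader_source.is_defined()
# -> True: `attach_dataloaders` stored `train_dl` in the source directly.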

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 74 additions & 56 deletions
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Iterable, Optional, Union
+from typing import Iterable, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_deprecation
@@ -47,6 +48,11 @@ def __init__(
         self.test_data_fetcher = test_data_fetcher
         self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None
 
+        self._train_dataloader_source = _DataLoaderSource(None, "")
+        self._val_dataloader_source = _DataLoaderSource(None, "")
+        self._test_dataloader_source = _DataLoaderSource(None, "")
+        self._predict_dataloader_source = _DataLoaderSource(None, "")
+
     @property
     def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
         if self.trainer.sanity_checking:
@@ -190,27 +196,23 @@ def attach_dataloaders(
         test_dataloaders: Optional[EVAL_DATALOADERS] = None,
         predict_dataloaders: Optional[EVAL_DATALOADERS] = None,
     ) -> None:
-        # when dataloader is passed via fit, patch the train_dataloader
-        # functions to overwrite with these implementations
-        if train_dataloaders is not None:
-            self.trainer.train_dataloader = None
-            train_dataloader = _PatchDataLoader(train_dataloaders, "train")
-            train_dataloader.patch(model)
-
-        if val_dataloaders is not None:
-            self.trainer.val_dataloaders = None
-            val_dataloader = _PatchDataLoader(val_dataloaders, "val")
-            val_dataloader.patch(model)
-
-        if test_dataloaders is not None:
-            self.trainer.test_dataloaders = None
-            test_dataloader = _PatchDataLoader(test_dataloaders, "test")
-            test_dataloader.patch(model)
-
-        if predict_dataloaders is not None:
-            self.trainer.predict_dataloaders = None
-            predict_dataloader = _PatchDataLoader(predict_dataloaders, "predict")
-            predict_dataloader.patch(model)
+        self.trainer.train_dataloader = None
+        self.trainer.val_dataloaders = None
+        self.trainer.test_dataloaders = None
+        self.trainer.predict_dataloaders = None
+
+        self._train_dataloader_source = _DataLoaderSource(
+            train_dataloaders if train_dataloaders is not None else model, "train_dataloader"
+        )
+        self._val_dataloader_source = _DataLoaderSource(
+            val_dataloaders if val_dataloaders is not None else model, "val_dataloader"
+        )
+        self._test_dataloader_source = _DataLoaderSource(
+            test_dataloaders if test_dataloaders is not None else model, "test_dataloader"
+        )
+        self._predict_dataloader_source = _DataLoaderSource(
+            predict_dataloaders if predict_dataloaders is not None else model, "predict_dataloader"
+        )
 
     def attach_datamodule(
         self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None
@@ -219,11 +221,10 @@ def attach_datamodule(
         if datamodule is None:
             return
 
-        # Override loader hooks
-        dl_methods = ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader")
-        for method in dl_methods:
-            if is_overridden(method, datamodule):
-                setattr(model, method, getattr(datamodule, method))
+        self._train_dataloader_source = _DataLoaderSource(datamodule, "train_dataloader")
+        self._val_dataloader_source = _DataLoaderSource(datamodule, "val_dataloader")
+        self._test_dataloader_source = _DataLoaderSource(datamodule, "test_dataloader")
+        self._predict_dataloader_source = _DataLoaderSource(datamodule, "predict_dataloader")
 
         # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
         batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
@@ -238,13 +239,6 @@
         if hasattr(datamodule, "data_pipeline"):
            model.data_pipeline = datamodule.data_pipeline
 
-    @staticmethod
-    def detach_data(model: "pl.LightningModule") -> None:
-        for stage in ("train", "val", "test", "predict"):
-            loader = getattr(model, f"{stage}_dataloader", None)
-            if isinstance(loader, _PatchDataLoader):
-                loader.unpatch(model)
-
     def teardown(self) -> None:
         if self.train_data_fetcher:
             self.train_data_fetcher.teardown()
@@ -260,32 +254,56 @@ def teardown(self) -> None:
         self.sanity_check_data_fetcher = None
 
 
-class _PatchDataLoader:
-    r"""
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
+@dataclass
+class _DataLoaderSource:
+    """Stores the information where the dataloaders come from.
+
+    The source can be
 
-    Args:
-        dataloader: Dataloader object to return when called.
+    1. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.lightning.LightningModule`,
+    2. from a ``*_dataloader()`` method on the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+    3. a direct instance of a :class:`~torch.utils.data.DataLoader` or supported collections thereof.
+
+    Arguments:
+        instance: A LightningModule, LightningDataModule, or (a collection of) dataloader(s).
+        name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook
+            that returns the desired dataloader(s).
     """
 
-    def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], stage: str) -> None:
-        self.dataloader = dataloader
+    instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]]
+    name: str
+
+    def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
+        """Returns the dataloader from the source.
+
+        If the source is a module, the method with the corresponding :attr:`name` gets called.
+        """
+        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
+
+        if not self.name:
+            return self.instance
+
+        if isinstance(self.instance, LightningModule):
+            return self.instance.trainer.call_hook(self.name, pl_module=self.instance)
+
+        if isinstance(self.instance, LightningDataModule):
+            method = getattr(self.instance, self.name)
+            return method()
+
+        return self.instance
+
+    def is_defined(self) -> bool:
+        """Returns whether the source dataloader can be retrieved or not.
 
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        self.patch_loader_code = str(self.__call__.__code__)
-        self._old_loader: Optional[Callable] = None
-        self.stage = stage
+        If the source is a module it checks that the method with given :attr:`name` is overridden.
+        """
+        return not self.is_module() or is_overridden(self.name, self.instance)
 
-    def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
-        return self.dataloader
+    def is_module(self) -> bool:
+        """Returns whether the DataLoader source is a LightningModule or a LightningDataModule.
 
-    def patch(self, model: "pl.LightningModule") -> None:
-        self._old_loader = getattr(model, self.stage + "_dataloader")
-        setattr(model, self.stage + "_dataloader", self)
+        It does not check whether ``*_dataloader`` methods are actually overridden.
+        """
+        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
 
-    def unpatch(self, model: "pl.LightningModule") -> None:
-        setattr(model, self.stage + "_dataloader", self._old_loader)
-        self._old_loader = None
+        return isinstance(self.instance, (LightningModule, LightningDataModule))
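
To make the new indirection concrete, a small usage sketch of `_DataLoaderSource` under the definitions above (`BoringDataModule` is a hypothetical stand-in):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource

loader = DataLoader(TensorDataset(torch.zeros(4, 1)))

# 1. A raw DataLoader: not a module, always considered "defined", returned as-is.
source = _DataLoaderSource(loader, "train_dataloader")
assert source.is_defined() and not source.is_module()
assert source.dataloader() is loader

# 2. A datamodule: `dataloader()` resolves lazily by calling the named hook.
class BoringDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return loader

source = _DataLoaderSource(BoringDataModule(), "train_dataloader")
assert source.is_module() and source.is_defined()
assert source.dataloader() is loader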
