|
36 | 36 | from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin |
37 | 37 | from pytorch_lightning.core.optimizer import LightningOptimizer |
38 | 38 | from pytorch_lightning.core.saving import ModelIO |
| 39 | +from pytorch_lightning.trainer.connectors.data_connector import _DataHookSource |
39 | 40 | from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator |
40 | 41 | from pytorch_lightning.utilities import ( |
41 | 42 | _IS_WINDOWS, |
@@ -263,13 +264,13 @@ def _apply_batch_transfer_handler( |
263 | 264 | self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0 |
264 | 265 | ) -> Any: |
265 | 266 | device = device or self.device |
266 | | - batch = self.trainer._data_connector._datahook_source.get_hook("on_before_batch_transfer")( |
267 | | - batch, dataloader_idx |
| 267 | + datahook_source = ( |
| 268 | + _DataHookSource(self) if self.trainer is None else self.trainer._data_connector._datahook_source |
268 | 269 | ) |
269 | | - batch = self.trainer._data_connector._datahook_source.get_hook("transfer_batch_to_device")( |
270 | | - batch, device, dataloader_idx |
271 | | - ) |
272 | | - batch = self.trainer._data_connector._datahook_source.get_hook("on_after_batch_transfer")(batch, dataloader_idx) |
| 270 | + |
| 271 | + batch = datahook_source.get_hook("on_before_batch_transfer")(batch, dataloader_idx) |
| 272 | + batch = datahook_source.get_hook("transfer_batch_to_device")(batch, device, dataloader_idx) |
| 273 | + batch = datahook_source.get_hook("on_after_batch_transfer")(batch, dataloader_idx) |
273 | 274 | return batch |
274 | 275 |
|
275 | 276 | def print(self, *args, **kwargs) -> None: |
|
0 commit comments