Skip to content

Commit 811e8e1

Browse files
committed
look for trainer
1 parent 3add38b commit 811e8e1

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
3737
from pytorch_lightning.core.optimizer import LightningOptimizer
3838
from pytorch_lightning.core.saving import ModelIO
39+
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSource
3940
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
4041
from pytorch_lightning.utilities import (
4142
_IS_WINDOWS,
@@ -263,13 +264,13 @@ def _apply_batch_transfer_handler(
263264
self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
264265
) -> Any:
265266
device = device or self.device
266-
batch = self.trainer._data_connector._datahook_source.get_hook("on_before_batch_transfer")(
267-
batch, dataloader_idx
267+
datahook_source = (
268+
_DataHookSource(self) if self.trainer is None else self.trainer._data_connector._datahook_source
268269
)
269-
batch = self.trainer._data_connector._datahook_source.get_hook("transfer_batch_to_device")(
270-
batch, device, dataloader_idx
271-
)
272-
batch = self.trainer._data_connector._datahook_source.get_hook("on_after_batch_transfer")(batch, dataloader_idx)
270+
271+
batch = datahook_source.get_hook("on_before_batch_transfer")(batch, dataloader_idx)
272+
batch = datahook_source.get_hook("transfer_batch_to_device")(batch, device, dataloader_idx)
273+
batch = datahook_source.get_hook("on_after_batch_transfer")(batch, dataloader_idx)
273274
return batch
274275

275276
def print(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)