|
14 | 14 |
|
15 | 15 | """Various hooks to be used in the Lightning code.""" |
16 | 16 |
|
17 | | -from typing import Any, Dict, List, Union |
| 17 | +from typing import Any, Dict, List, Optional, Union |
18 | 18 |
|
19 | 19 | import torch |
20 | 20 | from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn |
@@ -300,7 +300,7 @@ def on_after_backward(self): |
300 | 300 |
|
301 | 301 |
|
302 | 302 | class DataHooks: |
303 | | - """Hooks to be used with LightningDataModule.""" |
| 303 | + """Hooks to be used for data related stuff.""" |
304 | 304 | def prepare_data(self) -> None: |
305 | 305 | """ |
306 | 306 | Use this to download and prepare data. |
@@ -508,7 +508,7 @@ def val_dataloader(self): |
508 | 508 | will have an argument ``dataloader_idx`` which matches the order here. |
509 | 509 | """ |
510 | 510 |
|
511 | | - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: |
| 511 | + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: |
512 | 512 | """ |
513 | 513 | Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors |
514 | 514 | wrapped in a custom data structure. |
@@ -556,8 +556,21 @@ def transfer_batch_to_device(self, batch, device) |
556 | 556 | - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` |
557 | 557 | - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` |
558 | 558 | """ |
| 559 | + device = device or self.device |
559 | 560 | return move_data_to_device(batch, device) |
560 | 561 |
|
| 562 | + def on_before_batch_transfer(self, batch): |
| 563 | + return batch |
| 564 | + |
| 565 | + def on_after_batch_transfer(self, batch): |
| 566 | + return batch |
| 567 | + |
| 568 | + def prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None): |
| 569 | + batch = self.on_before_batch_transfer(batch) |
| 570 | + batch = self.transfer_batch_to_device(batch, device) |
| 571 | + batch = self.on_after_batch_transfer(batch) |
| 572 | + return batch |
| 573 | + |
561 | 574 |
|
562 | 575 | class CheckpointHooks: |
563 | 576 | """Hooks to be used with Checkpointing.""" |
|
0 commit comments