Skip to content

Commit a8c2725

Browse files
rohitgr7awaelchli
andauthored
remove deprecated signature for transfer_batch_to_device (#10480)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent fabb364 commit a8c2725

File tree

5 files changed

+7
-23
lines changed

5 files changed

+7
-23
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
109109
* ([#10403](https://github.com/PyTorchLightning/pytorch-lightning/pull/10403))
110110
* ([#10448](https://github.com/PyTorchLightning/pytorch-lightning/pull/10448))
111111

112+
- Removed deprecated signature for `transfer_batch_to_device` hook. The new argument `dataloader_idx` is now required ([#10480](https://github.com/PyTorchLightning/pytorch-lightning/pull/10480))
113+
112114

113115
- Removed deprecated `utilities.distributed.rank_zero_{warn/deprecation}` ([#10451](https://github.com/PyTorchLightning/pytorch-lightning/pull/10451))
114116

@@ -119,6 +121,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
119121
- Removed deprecated `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#10482](https://github.com/PyTorchLightning/pytorch-lightning/pull/10482))
120122

121123

124+
- Removed deprecated `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#10482](https://github.com/PyTorchLightning/pytorch-lightning/pull/10482))
125+
122126
### Fixed
123127

124128
- Fixed an issue where class or init-only variables of dataclasses were passed to the dataclass constructor in `utilities.apply_to_collection` ([#9702](https://github.com/PyTorchLightning/pytorch-lightning/issues/9702))

pytorch_lightning/core/hooks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
693693
# skip device transfer for the first dataloader or anything you wish
694694
pass
695695
else:
696-
batch = super().transfer_batch_to_device(data, device)
696+
batch = super().transfer_batch_to_device(data, device, dataloader_idx)
697697
return batch
698698
699699
Raises:

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -262,17 +262,7 @@ def _apply_batch_transfer_handler(
262262
) -> Any:
263263
device = device or self.device
264264
batch = self.on_before_batch_transfer(batch, dataloader_idx)
265-
266-
if is_param_in_hook_signature(self.transfer_batch_to_device, "dataloader_idx"):
267-
batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
268-
else:
269-
warning_cache.deprecation(
270-
"`transfer_batch_to_device` hook signature has changed in v1.4."
271-
" `dataloader_idx` parameter has been added to it. Support for"
272-
" the old signature will be removed in v1.6"
273-
)
274-
batch = self.transfer_batch_to_device(batch, device)
275-
265+
batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
276266
batch = self.on_after_batch_transfer(batch, dataloader_idx)
277267
return batch
278268

tests/accelerators/test_dp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
143143
monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
144144

145145
class CustomModel(BoringModel):
146-
def transfer_batch_to_device(self, batch, device):
146+
def transfer_batch_to_device(self, batch, device, dataloader_idx):
147147
batch = batch.to(device)
148148
return batch
149149

tests/deprecated_api/test_remove_1-6.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,6 @@
2121
from tests.helpers import BoringModel
2222

2323

24-
def test_old_transfer_batch_to_device_hook(tmpdir):
25-
class OldModel(BoringModel):
26-
def transfer_batch_to_device(self, batch, device):
27-
return super().transfer_batch_to_device(batch, device, None)
28-
29-
trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=0, max_epochs=1)
30-
with pytest.deprecated_call(match="old signature will be removed in v1.6"):
31-
trainer.fit(OldModel())
32-
33-
3424
def test_v1_6_0_reload_dataloaders_every_epoch(tmpdir):
3525
model = BoringModel()
3626

0 commit comments

Comments
 (0)