Skip to content

Commit 329f9d0

Browse files
committed
Renamed transfer_batch_to_device to prepare_batch_for_transfer
1 parent b73c708 commit 329f9d0

File tree

5 files changed

+7
-6
lines changed

5 files changed

+7
-6
lines changed

pytorch_lightning/core/decorators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def auto_transfer_args(self, *args, **kwargs):
5757
if not isinstance(self, LightningModule):
5858
return fn(self, *args, **kwargs)
5959

60-
args = self.transfer_batch_to_device(args, self.device)
61-
kwargs = self.transfer_batch_to_device(kwargs, self.device)
60+
args = self.prepare_batch_for_transfer(args, self.device)
61+
kwargs = self.prepare_batch_for_transfer(kwargs, self.device)
6262
return fn(self, *args, **kwargs)
6363

6464
return auto_transfer_args

pytorch_lightning/core/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def _forward_example_input(self) -> None:
219219
trainer = self._model.trainer
220220

221221
input_ = model.example_input_array
222-
input_ = model.transfer_batch_to_device(input_, model.device)
222+
input_ = model.prepare_batch_for_transfer(input_, model.device)
223223

224224
if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
225225
model.forward = torch.cuda.amp.autocast()(model.forward)

pytorch_lightning/loggers/tensorboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def log_graph(self, model: LightningModule, input_array=None):
197197
input_array = model.example_input_array
198198

199199
if input_array is not None:
200-
input_array = model.transfer_batch_to_device(input_array, model.device)
200+
input_array = model.prepare_batch_for_transfer(input_array, model.device)
201201
self.experiment.add_graph(model, input_array)
202202
else:
203203
rank_zero_warn('Could not log computational graph since the'

pytorch_lightning/loggers/test_tube.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def log_graph(self, model: LightningModule, input_array=None):
146146
if input_array is not None:
147147
self.experiment.add_graph(
148148
model,
149-
model.transfer_batch_to_device(
149+
model.prepare_batch_for_transfer(
150150
model.example_input_array, model.device)
151151
)
152152
else:

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def attach_datamodule(self, model, datamodule, stage):
117117
if is_overridden('test_dataloader', datamodule):
118118
model.test_dataloader = datamodule.test_dataloader
119119

120-
# Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
120+
# TODO check this before merge
121+
# Override prepare_batch_for_transfer if dataset-specific to_device logic has been defined in datamodule
121122
model.prepare_batch_for_transfer = datamodule.prepare_batch_for_transfer
122123

123124
# TODO remove this after all the changes are done

0 commit comments

Comments (0)