Commit 0f28e24

make it private

1 parent caf1784 commit 0f28e24

File tree

6 files changed: 17 additions, 11 deletions


pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def train_or_test(self):
     def batch_to_device(self, batch: Any, device: torch.device):
         model = self.trainer.get_model()
         if model is not None:
-            return model.prepare_batch_for_transfer(batch, device)
+            return model._prepare_batch_for_transfer(batch, device)
         return move_data_to_device(batch, device)

     def training_step_end(self, output):
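A side note on the unchanged fallback branch above: when no LightningModule is attached, the batch is handed to the plain move_data_to_device utility, which walks tensors nested in lists and dicts and leaves other values untouched. A minimal sketch, assuming the import path used around this release (it may differ in other versions):

import torch

# Import path as of roughly this release; newer versions may expose it elsewhere.
from pytorch_lightning.utilities.apply_func import move_data_to_device

# A nested batch: tensors inside the dict and the list are moved recursively,
# while the plain string is returned as-is.
batch = {
    "images": torch.randn(4, 3, 32, 32),
    "targets": [torch.tensor([1, 0, 2, 1]), "metadata-string"],
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_on_device = move_data_to_device(batch, device)
print(batch_on_device["images"].device)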

pytorch_lightning/core/hooks.py

Lines changed: 6 additions & 6 deletions
@@ -560,15 +560,15 @@ def transfer_batch_to_device(self, batch, device):
         return move_data_to_device(batch, device)

     def on_before_batch_transfer(self, batch):
+        """
+        Called before batch is transfered to the device
+        """
         return batch

     def on_after_batch_transfer(self, batch):
-        return batch
-
-    def prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None):
-        batch = self.on_before_batch_transfer(batch)
-        batch = self.transfer_batch_to_device(batch, device)
-        batch = self.on_after_batch_transfer(batch)
+        """
+        Called after batch is transfered to the device
+        """
         return batch

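The two hooks documented above bracket the actual device transfer; the chaining itself moves into LightningModule._prepare_batch_for_transfer (next file). A hedged sketch of how a user module might override them — AugmentingModule and the specific preprocessing are illustrative, not part of this commit:

import torch
import pytorch_lightning as pl


class AugmentingModule(pl.LightningModule):  # illustrative name
    def on_before_batch_transfer(self, batch):
        # Still on the CPU: cheap dtype conversion / normalisation before the copy.
        batch["x"] = batch["x"].float() / 255.0
        return batch

    def on_after_batch_transfer(self, batch):
        # Already on the target device: device-side augmentation.
        batch["x"] = batch["x"] + 0.01 * torch.randn_like(batch["x"])
        return batch


model = AugmentingModule()
batch = {"x": torch.randint(0, 255, (4, 3, 8, 8), dtype=torch.uint8)}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The private helper introduced by this commit runs the hooks in order:
# on_before_batch_transfer -> transfer_batch_to_device -> on_after_batch_transfer.
batch = model._prepare_batch_for_transfer(batch, device)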

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 1 deletion
@@ -171,6 +171,12 @@ def automatic_optimization(self) -> bool:
         """
         return True

+    def _prepare_batch_for_transfer(self, batch: Any, device: Optional[torch.device] = None):
+        batch = self.on_before_batch_transfer(batch)
+        batch = self.transfer_batch_to_device(batch, device)
+        batch = self.on_after_batch_transfer(batch)
+        return batch
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.

@@ -1787,7 +1793,7 @@ def to_torchscript(
             if example_inputs is None:
                 example_inputs = self.example_input_array
             # automatically send example inputs to the right device and use trace
-            example_inputs = self.prepare_batch_for_transfer(example_inputs)
+            example_inputs = self._prepare_batch_for_transfer(example_inputs)
             torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
         else:
             raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
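The second hunk shows one caller of the now-private helper: to_torchscript(method='trace') routes example_input_array through it before tracing, so the transfer hooks run on the example inputs as well. A small usage sketch under that assumption — TinyModel is an illustrative module, not from the diff:

import torch
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):  # illustrative module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)
        # to_torchscript(method='trace') falls back to this when no
        # example_inputs are passed explicitly.
        self.example_input_array = torch.randn(1, 8)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()
# The example inputs are moved via _prepare_batch_for_transfer (default hooks
# are no-ops) and then handed to torch.jit.trace.
traced = model.to_torchscript(method="trace")
print(type(traced))  # a torch.jit.ScriptModule subclass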

pytorch_lightning/core/memory.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ def _forward_example_input(self) -> None:
         trainer = self._model.trainer

         input_ = model.example_input_array
-        input_ = model.prepare_batch_for_transfer(input_, model.device)
+        input_ = model._prepare_batch_for_transfer(input_, model.device)

         if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
             model.forward = torch.cuda.amp.autocast()(model.forward)
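This call site only runs when a module defines example_input_array; the summary moves it through the renamed helper and performs a forward pass to record input/output shapes. A hedged setup sketch — SummarizedModel and the layer sizes are illustrative:

import torch
import pytorch_lightning as pl


class SummarizedModel(pl.LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 4),
        )
        # With this set, _forward_example_input can move the tensor to the
        # model's device and record in/out sizes for the summary table.
        self.example_input_array = torch.randn(2, 16)

    def forward(self, x):
        return self.net(x)


model = SummarizedModel()
# The exact summary entry point varies a little across versions; something like
# the following would then include the in/out sizes (shown as an assumption):
# print(model.summarize(mode="full"))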

pytorch_lightning/loggers/tensorboard.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def log_graph(self, model: LightningModule, input_array=None):
                 input_array = model.example_input_array

             if input_array is not None:
-                input_array = model.prepare_batch_for_transfer(input_array)
+                input_array = model._prepare_batch_for_transfer(input_array)
                 self.experiment.add_graph(model, input_array)
             else:
                 rank_zero_warn('Could not log computational graph since the'
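The same rename applies to the logger-side graph logging here and in test_tube.py below: log_graph fetches example_input_array and moves it through the private helper before calling add_graph. A hedged usage sketch — GraphedModel and the log_graph constructor flag are assumptions, not shown in this diff:

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


class GraphedModel(pl.LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)
        self.example_input_array = torch.randn(1, 8)  # what log_graph falls back to

    def forward(self, x):
        return self.layer(x)


# Depending on the version, graph logging may need to be enabled on the logger
# (assumed here via a log_graph flag).
logger = TensorBoardLogger("lightning_logs", log_graph=True)
model = GraphedModel()
logger.log_graph(model)  # moves the example input, then SummaryWriter.add_graph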

pytorch_lightning/loggers/test_tube.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def log_graph(self, model: LightningModule, input_array=None):
                 input_array = model.example_input_array

             if input_array is not None:
-                input_array = self.prepare_batch_for_transfer(input_array)
+                input_array = self._prepare_batch_for_transfer(input_array)
                 self.experiment.add_graph(model, input_array)
             else:
                 rank_zero_warn('Could not log computational graph since the'
