
Commit ea5cfd2

awaelchli, Borda, and carmocca authored
move batch to device before sending it to hooks (#7378)
* update train step
* test
* x
* limits
* val
* typeo
* x
* x
* step
* min gpus
* run all loops
* x
* limit test
* profiler
* clean up accelerator code
* move files
* rename
* move tests
* changelog
* reorder callbacks and model hooks
* add test description
* replace unneccessary method
* fix chlog
* adjust batch_to_device for DP Plugin
* update tests for dataloader idx
* unused imports
* hook change
* switch None
* clear memory
* change to None
* None
* None
* memory savings
* remove redundant todo
* hack
* cheat
* Revert "cheat" This reverts commit a8433bd.
* Revert "hack" This reverts commit 43a6d1e.
* update new epoch loop
* remove from old loop code
* update chlog
* update hook test
* changelog
* teardown
* integrate changes in new eval loop
* fix hook calls
* add prediction step
* bad merge
* Revert "bad merge" This reverts commit 4880808.
* fix train batch hook test
* rm -rf _notebooks
* update chlog
* release memory
* fix type
* notebooks mess
* debug
* Revert "debug" This reverts commit eec4ee2.
* teardown
* fix teardown bug
* debug
* x
* debug
* Revert "debug" This reverts commit a6e6101. Revert "debug" This reverts commit 5ddeaec. debug debug Revert "debug" This reverts commit 605be74. Revert "Revert "debug"" This reverts commit a7612d5. debug x x x s tol x tol
* Fix changelog

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 8193bae commit ea5cfd2
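
In practice, this commit means the batch handed to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks already lives on the target device (except under DP, which moves data itself). A minimal sketch of a callback relying on this behaviour, mirroring the test added below in tests/trainer/loops/test_all.py (the callback name is illustrative, not part of the commit):

from pytorch_lightning import Callback


class DeviceCheckCallback(Callback):
    """Illustrative callback: after #7378 the batch is already on the module's device here."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # before this commit, the batch could still be on CPU at this point
        assert batch.device == pl_module.device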

File tree

11 files changed: +133 additions, -50 deletions


CHANGELOG.md

Lines changed: 9 additions & 8 deletions
@@ -345,6 +345,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
 
 
+- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
+
+
+- Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378))
+
+
+- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))
+
 
 ## [1.3.8] - 2021-07-01
 

@@ -361,13 +369,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed SWA to also work with `IterableDataset` ([#8172](https://github.com/PyTorchLightning/pytorch-lightning/pull/8172))
 
 
-
-- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
-
-
-- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))
-
-
 ## [1.3.7] - 2021-06-22
 
 ### Fixed

@@ -377,6 +378,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed setting a `DistributedSampler` when using a distributed plugin in a custom accelerator ([#7814](https://github.com/PyTorchLightning/pytorch-lightning/pull/7814))
 - Improved `PyTorchProfiler` chrome traces names ([#8009](https://github.com/PyTorchLightning/pytorch-lightning/pull/8009))
 - Fixed moving the best score to device in `EarlyStopping` callback for TPU devices ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
+- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))
 
 
 ## [1.3.6] - 2021-06-15

@@ -387,7 +389,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
 - Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
 - Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
-- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))
 
 
 ## [1.3.5] - 2021-06-08

benchmarks/test_basic_parity.py

Lines changed: 2 additions & 2 deletions
@@ -45,8 +45,8 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f
 @pytest.mark.parametrize(
     'cls_model,max_diff_speed,max_diff_memory',
     [
-        (ParityModuleRNN, 0.05, 0.0),
-        (ParityModuleMNIST, 0.25, 0.0),  # todo: lower this thr
+        (ParityModuleRNN, 0.05, 0.001),
+        (ParityModuleMNIST, 0.25, 0.001),  # todo: lower this thr
     ]
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 17 deletions
@@ -22,6 +22,7 @@
 from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
+from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn

@@ -173,8 +174,8 @@ def batch_to_device(
             dataloader_idx: The index of the dataloader to which the batch belongs.
         """
         model = self.lightning_module
-
-        if model is not None:
+        if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
+            # no need to transfer batch to device in DP mode
             return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
 
         return move_data_to_device(batch, device)

@@ -195,8 +196,6 @@ def training_step(
             - hiddens(:class:`~torch.Tensor`): Passed in if
               :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
         """
-        step_kwargs = self.to_device(step_kwargs)
-
         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())
 

@@ -215,8 +214,6 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S
             - dataloader_idx (int): The index of the dataloader that produced this batch
               (only if multiple val dataloaders used)
         """
-        step_kwargs = self.to_device(step_kwargs)
-
         with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())
 

@@ -232,8 +229,6 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU
             - dataloader_idx (int): The index of the dataloader that produced this batch
               (only if multiple test dataloaders used).
         """
-        step_kwargs = self.to_device(step_kwargs)
-
         with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())
 

@@ -249,8 +244,6 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
             - dataloader_idx (int): The index of the dataloader that produced this batch
               (only if multiple predict dataloaders used).
         """
-        step_kwargs = self.to_device(step_kwargs)
-
         with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())
 

@@ -371,13 +364,6 @@ def setup_precision_plugin(self) -> None:
         self.optimizers = optimizers
         self.schedulers = schedulers
 
-    def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]:
-        """Pushes the batch to the root device"""
-        step_kwargs['batch'] = self.batch_to_device(
-            step_kwargs['batch'], self.root_device, dataloader_idx=step_kwargs.get('dataloader_idx', None)
-        )
-        return step_kwargs
-
     @property
     def amp_backend(self) -> Optional[LightningEnum]:
         if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
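
With `Accelerator.to_device` removed here (and the GPU-specific override removed in the next file), `batch_to_device` is the single place the batch gets moved, deferring to the LightningModule's batch-transfer hooks except under DP. A simplified, free-standing restatement of that logic for reference (an assumed sketch, not the verbatim method shown above):

from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities import move_data_to_device


def batch_to_device_sketch(accelerator, batch, device, dataloader_idx):
    # DP scatters the batch itself, so only non-DP plugins go through the module's transfer hooks
    model = accelerator.lightning_module
    if model is not None and not isinstance(accelerator.training_type_plugin, DataParallelPlugin):
        # runs on_before_batch_transfer -> transfer_batch_to_device -> on_after_batch_transfer
        return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
    return move_data_to_device(batch, device)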

pytorch_lightning/accelerators/gpu.py

Lines changed: 0 additions & 10 deletions
@@ -13,13 +13,11 @@
 # limitations under the License.
 import logging
 import os
-from typing import Any, Dict, Union
 
 import torch
 
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.accelerator import Accelerator
-from pytorch_lightning.plugins import DataParallelPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _log = logging.getLogger(__name__)

@@ -51,11 +49,3 @@ def set_nvidia_flags(local_rank: int) -> None:
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
-
-    def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]:
-        # no need to transfer batch to device in DP mode
-        # TODO: Add support to allow batch transfer to device in Lightning for DP mode.
-        if not isinstance(self.training_type_plugin, DataParallelPlugin):
-            step_kwargs = super().to_device(step_kwargs)
-
-        return step_kwargs

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 4 additions & 0 deletions
@@ -139,6 +139,10 @@ def advance(self, batch, batch_idx, dataloader_idx):
         if result:
             self.batch_outputs[0].append(result.training_step_output)
 
+    def teardown(self) -> None:
+        # release memory
+        self._remaining_splits = None
+
     def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
         """Gets the number of active optimizers based on their frequency"""
         return len(self.get_active_optimizers(batch_idx))

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 3 additions & 0 deletions
@@ -100,6 +100,9 @@ def advance(
         if batch is None:
             raise StopIteration
 
+        with self.trainer.profiler.profile("evaluation_batch_to_device"):
+            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+
         # hook
         self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 3 additions & 0 deletions
@@ -83,6 +83,9 @@ def advance(
         if batch is None:
             raise StopIteration
 
+        with self.trainer.profiler.profile("predict_batch_to_device"):
+            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
+
         with self.trainer.profiler.profile("predict_step"):
             self._predict_step(batch, batch_idx, dataloader_idx)

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         # ------------------------------------
         # TRAINING_STEP + TRAINING_STEP_END
         # ------------------------------------
+        with self.trainer.profiler.profile("training_batch_to_device"):
+            batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)
+
         with self.trainer.profiler.profile("run_training_batch"):
             batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
         self.batches_seen += 1
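
Together, the three loop changes move the device transfer to the very start of each per-batch advance, under its own profiler action, so every later hook and step sees the device-resident batch. A rough sketch of the resulting training-path ordering (an assumed helper for illustration, not the actual loop code):

def run_one_training_batch(trainer, batch, batch_idx, dataloader_idx):
    # 1. move the batch first (profiled separately from the step itself)
    with trainer.profiler.profile("training_batch_to_device"):
        batch = trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
    # 2. Callback.on_train_batch_start / LightningModule.on_train_batch_start (batch already on device)
    # 3. training_step / training_step_end via the training-type plugin
    # 4. Callback.on_train_batch_end / LightningModule.on_train_batch_end
    return batch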

tests/models/test_hooks.py

Lines changed: 12 additions & 12 deletions
@@ -287,13 +287,13 @@ def _train_batch(trainer, model, batches, current_epoch=0):
     out = []
     for i in range(batches):
         out.extend([
+            dict(name='on_before_batch_transfer', args=(ANY, 0)),
+            dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+            dict(name='on_after_batch_transfer', args=(ANY, 0)),
             # TODO: `on_batch_{start,end}`
             dict(name='Callback.on_batch_start', args=(trainer, model)),
             dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
             dict(name='on_train_batch_start', args=(ANY, i, 0)),
-            dict(name='on_before_batch_transfer', args=(ANY, None)),
-            dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-            dict(name='on_after_batch_transfer', args=(ANY, None)),
             dict(name='forward', args=(ANY, )),
             dict(name='training_step', args=(ANY, i)),
             dict(name='training_step_end', args=(dict(loss=ANY), )),

@@ -338,12 +338,12 @@ def _eval_batch(fn, trainer, model, batches, key):
     outputs = {key: ANY}
     for i in range(batches):
         out.extend([
+            dict(name='on_before_batch_transfer', args=(ANY, 0)),
+            dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+            dict(name='on_after_batch_transfer', args=(ANY, 0)),
             # TODO: `{,Callback}.on_batch_{start,end}`
             dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
             dict(name=f'on_{fn}_batch_start', args=(ANY, i, 0)),
-            dict(name='on_before_batch_transfer', args=(ANY, None)),
-            dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-            dict(name='on_after_batch_transfer', args=(ANY, None)),
             dict(name='forward', args=(ANY, )),
             dict(name=f'{fn}_step', args=(ANY, i)),
             dict(name=f'{fn}_step_end', args=(outputs, )),

@@ -358,11 +358,11 @@ def _predict_batch(trainer, model, batches):
     for i in range(batches):
         out.extend([
             # TODO: `{,Callback}.on_batch_{start,end}`
+            dict(name='on_before_batch_transfer', args=(ANY, 0)),
+            dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+            dict(name='on_after_batch_transfer', args=(ANY, 0)),
             dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
             dict(name='on_predict_batch_start', args=(ANY, i, 0)),
-            dict(name='on_before_batch_transfer', args=(ANY, None)),
-            dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-            dict(name='on_after_batch_transfer', args=(ANY, None)),
             dict(name='forward', args=(ANY, )),
             dict(name='predict_step', args=(ANY, i)),
             # TODO: `predict_step_end`

@@ -777,9 +777,9 @@ def call(hook, fn, *args, **kwargs):
     dm = HookedDataModule(called)
     trainer.fit(model, datamodule=dm)
     batch_transfer = [
-        dict(name='on_before_batch_transfer', args=(ANY, None)),
-        dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
-        dict(name='on_after_batch_transfer', args=(ANY, None)),
+        dict(name='on_before_batch_transfer', args=(ANY, 0)),
+        dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
+        dict(name='on_after_batch_transfer', args=(ANY, 0)),
     ]
     expected = [
         dict(name='prepare_data'),
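
The updated expectations also show that the batch-transfer hooks now run before `on_*_batch_start` and receive the dataloader index as an integer (`0` for a single dataloader) rather than `None`. A user override of these hooks should therefore expect an int; a minimal sketch (hypothetical module, shown only to illustrate the signature):

from pytorch_lightning import LightningModule


class TransferHookModule(LightningModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # after #7378 this hook runs before on_*_batch_start, with dataloader_idx as an int
        assert isinstance(dataloader_idx, int)
        return batch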

tests/trainer/loops/test_all.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Callback, Trainer
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+
+class BatchHookObserverCallback(Callback):
+
+    def on_train_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_validation_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_test_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_predict_batch_start(self, trainer, pl_module, batch, *args):
+        assert batch.device == pl_module.device
+
+    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args):
+        assert batch.device == pl_module.device
+
+
+class BatchHookObserverModel(BoringModel):
+
+    def on_train_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_train_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+    def on_validation_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_validation_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+    def on_test_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_test_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+    def on_predict_batch_start(self, batch, *args):
+        assert batch.device == self.device
+
+    def on_predict_batch_end(self, outputs, batch, *args):
+        assert batch.device == self.device
+
+
+@RunIf(min_gpus=1)
+def test_callback_batch_on_device(tmpdir):
+    """ Test that the batch object sent to the on_*_batch_start/end hooks is on the right device."""
+
+    batch_callback = BatchHookObserverCallback()
+
+    model = BatchHookObserverModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        limit_predict_batches=1,
+        gpus=1,
+        callbacks=[batch_callback],
+    )
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
