Merged
Changes from all commits (80 commits)
a774766
update train step
awaelchli May 5, 2021
64af97d
test
awaelchli May 5, 2021
36d6a91
x
awaelchli May 5, 2021
cb6112e
limits
awaelchli May 5, 2021
16728d1
val
awaelchli May 5, 2021
de08636
typeo
awaelchli May 5, 2021
66934e4
x
awaelchli May 5, 2021
89202c9
x
awaelchli May 5, 2021
b5f8806
step
awaelchli May 5, 2021
0ff4f9e
min gpus
awaelchli May 6, 2021
0d3f47a
run all loops
awaelchli May 6, 2021
fb9ffed
x
awaelchli May 6, 2021
b7befc1
limit test
awaelchli May 6, 2021
fe065a5
profiler
awaelchli May 6, 2021
60d458e
clean up accelerator code
awaelchli May 6, 2021
642bcf6
move files
awaelchli May 6, 2021
2d1b7ca
rename
awaelchli May 6, 2021
ca0b7ba
Merge branch 'tests/group-loop-tests' into bugfix/batch-device
awaelchli May 6, 2021
50ac2a5
move tests
awaelchli May 6, 2021
1e38413
changelog
awaelchli May 6, 2021
b021641
reorder callbacks and model hooks
awaelchli May 6, 2021
cc79736
add test description
awaelchli May 6, 2021
2792376
Merge branch 'master' into bugfix/batch-device
awaelchli May 6, 2021
53e3e71
Merge branch 'master' into bugfix/batch-device
awaelchli May 7, 2021
917a06e
Merge branch 'master' into bugfix/batch-device
Borda May 11, 2021
5521304
Merge branch 'master' into bugfix/batch-device
awaelchli May 17, 2021
4ffc332
replace unneccessary method
awaelchli May 17, 2021
ad8d357
Merge branch 'master' into bugfix/batch-device
awaelchli May 24, 2021
41f83ef
fix chlog
awaelchli May 24, 2021
5c2b7b2
adjust batch_to_device for DP Plugin
awaelchli May 25, 2021
88bc2fc
update tests for dataloader idx
awaelchli May 25, 2021
9d3beba
unused imports
awaelchli May 25, 2021
4f80d5f
Merge branch 'master' into bugfix/batch-device
awaelchli Jun 1, 2021
dea8cb4
hook change
awaelchli Jun 1, 2021
bf99814
switch None
awaelchli Jun 3, 2021
7c4c38d
Merge branch 'master' into bugfix/batch-device
awaelchli Jun 3, 2021
b4a1348
clear memory
awaelchli Jun 3, 2021
023e619
change to None
awaelchli Jun 3, 2021
b71547e
None
awaelchli Jun 3, 2021
7eab3bb
None
awaelchli Jun 3, 2021
91f1387
memory savings
awaelchli Jun 3, 2021
01b7293
remove redundant todo
awaelchli Jun 6, 2021
43a6d1e
hack
awaelchli Jun 9, 2021
78438cd
Merge branch 'master'
awaelchli Jun 9, 2021
a8433bd
cheat
awaelchli Jun 9, 2021
d8cd2b1
Revert "cheat"
awaelchli Jun 9, 2021
9467689
Revert "hack"
awaelchli Jun 9, 2021
27ac26d
Merge branch 'master' into bugfix/batch-device
awaelchli Jun 15, 2021
5d08680
update new epoch loop
awaelchli Jun 15, 2021
032055c
remove from old loop code
awaelchli Jun 15, 2021
1b4f2af
update chlog
awaelchli Jun 15, 2021
792894c
Merge branch 'master' into bugfix/batch-device
awaelchli Jun 18, 2021
faaec6a
update hook test
awaelchli Jun 18, 2021
053f377
changelog
awaelchli Jun 18, 2021
14ea8a8
teardown
awaelchli Jun 18, 2021
d109174
Merge branch 'master' into bugfix/batch-device
awaelchli Jun 21, 2021
0733da2
integrate changes in new eval loop
awaelchli Jun 21, 2021
708cf0e
fix hook calls
awaelchli Jun 21, 2021
0fb1369
Merge branch 'master' into bugfix/batch-device
awaelchli Jun 27, 2021
e5de582
add prediction step
awaelchli Jun 27, 2021
4880808
bad merge
awaelchli Jun 27, 2021
3c3e87a
Revert "bad merge"
awaelchli Jun 27, 2021
da08269
fix train batch hook test
awaelchli Jun 27, 2021
ebe3ce3
rm -rf _notebooks
awaelchli Jun 27, 2021
2a0aedb
update chlog
awaelchli Jun 27, 2021
f4a2f8c
release memory
awaelchli Jun 27, 2021
8838b43
fix type
awaelchli Jun 27, 2021
6d5c61c
notebooks mess
awaelchli Jun 27, 2021
eec4ee2
debug
awaelchli Jun 27, 2021
968c967
Revert "debug"
awaelchli Jun 27, 2021
fc2d612
teardown
awaelchli Jun 28, 2021
856cd66
fix teardown bug
awaelchli Jul 1, 2021
a6e6101
debug
awaelchli Jul 1, 2021
cde1622
x
awaelchli Jul 1, 2021
5ddeaec
debug
awaelchli Jul 1, 2021
88ca10d
Merge branch 'master' into bugfix/batch-device
awaelchli Jul 1, 2021
c712b62
Revert "debug"
awaelchli Jul 1, 2021
3998873
Merge branch 'master' into bugfix/batch-device
awaelchli Jul 2, 2021
6b31973
Merge branch 'master' into bugfix/batch-device
carmocca Jul 2, 2021
542efbd
Fix changelog
carmocca Jul 2, 2021
17 changes: 9 additions & 8 deletions CHANGELOG.md
@@ -345,6 +345,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))


- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))


- Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378))


- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))


## [1.3.8] - 2021-07-01

@@ -361,13 +369,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed SWA to also work with `IterableDataset` ([#8172](https://github.com/PyTorchLightning/pytorch-lightning/pull/8172))



- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))


- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))


## [1.3.7] - 2021-06-22

### Fixed
@@ -377,6 +378,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed setting a `DistributedSampler` when using a distributed plugin in a custom accelerator ([#7814](https://github.com/PyTorchLightning/pytorch-lightning/pull/7814))
- Improved `PyTorchProfiler` chrome traces names ([#8009](https://github.com/PyTorchLightning/pytorch-lightning/pull/8009))
- Fixed moving the best score to device in `EarlyStopping` callback for TPU devices ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))


## [1.3.6] - 2021-06-15
@@ -387,7 +389,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))


## [1.3.5] - 2021-06-08
4 changes: 2 additions & 2 deletions benchmarks/test_basic_parity.py
@@ -45,8 +45,8 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f
@pytest.mark.parametrize(
'cls_model,max_diff_speed,max_diff_memory',
[
(ParityModuleRNN, 0.05, 0.0),
(ParityModuleMNIST, 0.25, 0.0), # todo: lower this thr
(ParityModuleRNN, 0.05, 0.001),
(ParityModuleMNIST, 0.25, 0.001), # todo: lower this thr
]
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
20 changes: 3 additions & 17 deletions pytorch_lightning/accelerators/accelerator.py
@@ -22,6 +22,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
@@ -173,8 +174,8 @@ def batch_to_device(
dataloader_idx: The index of the dataloader to which the batch belongs.
"""
model = self.lightning_module

if model is not None:
if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
# no need to transfer batch to device in DP mode
return model._apply_batch_transfer_handler(batch, device, dataloader_idx)

return move_data_to_device(batch, device)
@@ -195,8 +196,6 @@ def training_step(
- hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*step_kwargs.values())

@@ -215,8 +214,6 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
return self.training_type_plugin.validation_step(*step_kwargs.values())

@@ -232,8 +229,6 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
return self.training_type_plugin.test_step(*step_kwargs.values())

@@ -249,8 +244,6 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*step_kwargs.values())

@@ -371,13 +364,6 @@ def setup_precision_plugin(self) -> None:
self.optimizers = optimizers
self.schedulers = schedulers

def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]:
"""Pushes the batch to the root device"""
step_kwargs['batch'] = self.batch_to_device(
step_kwargs['batch'], self.root_device, dataloader_idx=step_kwargs.get('dataloader_idx', None)
)
return step_kwargs

@property
def amp_backend(self) -> Optional[LightningEnum]:
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/gpu.py
@@ -13,13 +13,11 @@
# limitations under the License.
import logging
import os
from typing import Any, Dict, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)
@@ -51,11 +49,3 @@ def set_nvidia_flags(local_rank: int) -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]:
# no need to transfer batch to device in DP mode
# TODO: Add support to allow batch transfer to device in Lightning for DP mode.
if not isinstance(self.training_type_plugin, DataParallelPlugin):
step_kwargs = super().to_device(step_kwargs)

return step_kwargs
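
Taken together, the accelerator.py and gpu.py hunks above collapse two code paths into one: the per-step `to_device` wrapper is gone, and the DataParallel special case formerly in `GPUAccelerator.to_device` now lives as a single check inside `Accelerator.batch_to_device`. A rough, self-contained sketch of the resulting transfer logic (illustrative stand-in code, not the library's implementation; `using_dp` replaces the `isinstance(self.training_type_plugin, DataParallelPlugin)` test):

    import torch

    def move_data_to_device(batch, device):
        # simplified stand-in for pytorch_lightning.utilities.move_data_to_device
        if isinstance(batch, torch.Tensor):
            return batch.to(device)
        if isinstance(batch, (list, tuple)):
            return type(batch)(move_data_to_device(b, device) for b in batch)
        return batch

    def batch_to_device(batch, device, model=None, using_dp=False, dataloader_idx=0):
        # mirrors the new Accelerator.batch_to_device: route the batch through the
        # LightningModule's transfer hooks unless DataParallel is active, because
        # DP scatters the batch itself inside forward()
        if model is not None and not using_dp:
            batch = model.on_before_batch_transfer(batch, dataloader_idx)
            batch = model.transfer_batch_to_device(batch, device, dataloader_idx)
            return model.on_after_batch_transfer(batch, dataloader_idx)
        return move_data_to_device(batch, device)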
4 changes: 4 additions & 0 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -139,6 +139,10 @@ def advance(self, batch, batch_idx, dataloader_idx):
if result:
self.batch_outputs[0].append(result.training_step_output)

def teardown(self) -> None:
# release memory
self._remaining_splits = None

def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""Gets the number of active optimizers based on their frequency"""
return len(self.get_active_optimizers(batch_idx))
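
The new `teardown` drops the loop's reference to the remaining batch splits so their tensors become unreachable once the batch is finished; otherwise the last splits, possibly living on the GPU, would stay referenced until the next batch overwrites them. A generic illustration of the idea (toy class, not the loop's actual code):

    import gc

    import torch

    class SplitHolder:
        """Toy version of a loop that keeps per-split state while a batch runs."""

        def __init__(self):
            self._remaining_splits = None

        def advance(self, batch):
            # hold on to the splits while the batch is being processed
            self._remaining_splits = list(torch.chunk(batch, 2))

        def teardown(self):
            # release the reference so the split tensors can be freed
            self._remaining_splits = None

    loop = SplitHolder()
    loop.advance(torch.randn(8, 4))
    loop.teardown()
    gc.collect()  # the chunks are unreachable now and their memory can be reclaimed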
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -100,6 +100,9 @@ def advance(
if batch is None:
raise StopIteration

with self.trainer.profiler.profile("evaluation_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

# hook
self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -83,6 +83,9 @@ def advance(
if batch is None:
raise StopIteration

with self.trainer.profiler.profile("predict_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

with self.trainer.profiler.profile("predict_step"):
self._predict_step(batch, batch_idx, dataloader_idx)

3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -104,6 +104,9 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
self.batches_seen += 1
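
The three epoch-loop hunks follow one pattern: the batch is moved to the root device exactly once, inside a named profiler block (`training_batch_to_device`, `evaluation_batch_to_device`, `predict_batch_to_device`), before any batch-level hook or step function sees it. A minimal runnable sketch of that ordering with a toy profiler (the real loops use `trainer.profiler.profile(...)` and `trainer.accelerator.batch_to_device(...)` exactly as shown above):

    import time
    from contextlib import contextmanager

    import torch

    @contextmanager
    def profile(action_name):
        # toy stand-in for trainer.profiler.profile(action_name)
        start = time.monotonic()
        yield
        print(f"{action_name}: {time.monotonic() - start:.6f}s")

    def advance(batch, device):
        # order after this PR: device transfer first, hooks and the step afterwards
        with profile("training_batch_to_device"):
            batch = batch.to(device)
        assert batch.device == device  # what on_train_batch_start now observes
        with profile("run_training_batch"):
            loss = (batch ** 2).mean()  # placeholder for the actual training step
        return loss

    advance(torch.randn(4, 8), torch.device("cpu"))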
24 changes: 12 additions & 12 deletions tests/models/test_hooks.py
@@ -287,13 +287,13 @@ def _train_batch(trainer, model, batches, current_epoch=0):
out = []
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `on_batch_{start,end}`
dict(name='Callback.on_batch_start', args=(trainer, model)),
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -338,12 +338,12 @@ def _eval_batch(fn, trainer, model, batches, key):
outputs = {key: ANY}
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name=f'on_{fn}_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name=f'{fn}_step', args=(ANY, i)),
dict(name=f'{fn}_step_end', args=(outputs, )),
@@ -358,11 +358,11 @@ def _predict_batch(trainer, model, batches):
for i in range(batches):
out.extend([
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name='predict_step', args=(ANY, i)),
# TODO: `predict_step_end`
@@ -777,9 +777,9 @@ def call(hook, fn, *args, **kwargs):
dm = HookedDataModule(called)
trainer.fit(model, datamodule=dm)
batch_transfer = [
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
]
expected = [
dict(name='prepare_data'),
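
The updated `test_hooks.py` expectations capture two user-visible changes: the batch-transfer hooks now run before `Callback.on_*_batch_start` / `on_*_batch_start`, and they receive `dataloader_idx=0` rather than `None` inside the fit/validate/test/predict loops. A hypothetical module illustrating the new contract (not part of the PR; assumes a pytorch-lightning version that contains this change):

    from pytorch_lightning import LightningModule

    class NormalizeOnTransfer(LightningModule):
        """Illustrative hook overrides showing the post-PR contract."""

        def on_after_batch_transfer(self, batch, dataloader_idx):
            # dataloader_idx is now an int (0 for the train dataloader), never None
            return (batch - batch.mean()) / (batch.std() + 1e-6)

        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            # the batch handed to this hook has already been moved to self.device
            assert batch.device == self.device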
93 changes: 93 additions & 0 deletions tests/trainer/loops/test_all.py
@@ -0,0 +1,93 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


class BatchHookObserverCallback(Callback):

def on_train_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device

def on_validation_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device

def on_test_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device

def on_predict_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device


class BatchHookObserverModel(BoringModel):

def on_train_batch_start(self, batch, *args):
assert batch.device == self.device

def on_train_batch_end(self, outputs, batch, *args):
assert batch.device == self.device

def on_validation_batch_start(self, batch, *args):
assert batch.device == self.device

def on_validation_batch_end(self, outputs, batch, *args):
assert batch.device == self.device

def on_test_batch_start(self, batch, *args):
assert batch.device == self.device

def on_test_batch_end(self, outputs, batch, *args):
assert batch.device == self.device

def on_predict_batch_start(self, batch, *args):
assert batch.device == self.device

def on_predict_batch_end(self, outputs, batch, *args):
assert batch.device == self.device


@RunIf(min_gpus=1)
def test_callback_batch_on_device(tmpdir):
""" Test that the batch object sent to the on_*_batch_start/end hooks is on the right device."""

batch_callback = BatchHookObserverCallback()

model = BatchHookObserverModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
gpus=1,
callbacks=[batch_callback],
)
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)
2 changes: 1 addition & 1 deletion tests/trainer/test_dataloaders.py
@@ -1533,7 +1533,7 @@ def __init__(self):

def assert_dataloader_idx_hook(self, dataloader_idx):
if self.trainer.training:
assert dataloader_idx is None
assert dataloader_idx == 0
elif self.trainer.validating:
assert dataloader_idx == (0 if self.val_call_count <= 5 else 1)
elif self.trainer.testing: