From 74e508e701bb25a8a9d373b75227a048e82700da Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 21 Dec 2020 15:30:02 +0000 Subject: [PATCH 01/26] resolve bug --- pytorch_lightning/core/lightning.py | 13 +++++-- pytorch_lightning/utilities/distributed.py | 6 +++- tests/utilities/test_all_gather_grad.py | 42 ++++++++++++++++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index edad1be868cad..24bf231ebe30a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -24,11 +24,13 @@ from argparse import Namespace from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from functools import partial import torch from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation @@ -41,6 +43,7 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available class LightningModule( @@ -363,7 +366,7 @@ def __auto_choose_log_on_epoch(self, on_epoch): return on_epoch - def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): + def all_gather(self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False): r""" Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ```all_gather``` operation accelerator agnostic. @@ -379,8 +382,14 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s Return: A tensor of shape (world_size, batch, ...) """ - return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads) + if self.trainer.accelerator_backend is not None: + all_gather = self.trainer.accelerator_backend.all_gather + else: + all_gather = all_gather_ddp_if_available + all_gather = partial(all_gather, group=group, sync_grads=sync_grads) + return apply_to_collection(data, torch.Tensor, all_gather) + def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 9724f05247c00..2d4e9ea939e6c 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -23,6 +23,7 @@ if torch.distributed.is_available(): from torch.distributed import ReduceOp from torch.distributed import group + WORLD = group.WORLD else: class ReduceOp: SUM = None @@ -203,10 +204,13 @@ def all_gather_ddp_if_available( Return: A tensor of shape (world_size, batch, ...) 
""" + if group is None: + group = torch.distributed.group.WORLD + if torch.distributed.is_available() and torch.distributed.is_initialized(): if sync_grads: return AllGatherGrad.apply(tensor, group) else: - with torch.no_grad: + with torch.no_grad(): return AllGatherGrad.apply(tensor, group) return tensor diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index faba88236afd0..60810f9524c24 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -3,7 +3,9 @@ import sys import torch +from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.utilities import AllGatherGrad +from tests.base.boring_model import BoringModel def setup_ddp(rank, world_size): @@ -41,3 +43,43 @@ def _test_all_gather_ddp(rank, world_size): def test_all_gather_ddp(): world_size = 3 torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +def test_all_gather_properly_works(tmpdir): + + class TestModel(BoringModel): + + training_epoch_end_called = False + + def training_epoch_end(self, outputs) -> None: + self.training_epoch_end_called = True + losses = torch.stack([x["loss"] for x in outputs]) + gathered_loss = self.all_gather( + {"losses": losses, + "losses_list": [losses, losses] + }) + assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) + assert gathered_loss["losses"].numel() == 2 * len(losses) + + seed_everything(42) + + model = TestModel() + + limit_train_batches = 8 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + accumulate_grad_batches=2, + enable_pl_optimizer=True, + gpus=2, + accelerator="ddp", + ) + + trainer.fit(model) + assert model.training_epoch_end_called \ No newline at end of file From a7cdde872c7ebe39944851c10457472583270479 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 21 Dec 2020 16:33:12 +0100 Subject: [PATCH 02/26] add tests --- pytorch_lightning/core/lightning.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 24bf231ebe30a..dadaa2adc29a6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -22,15 +22,14 @@ import tempfile from abc import ABC from argparse import Namespace +from functools import partial from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from functools import partial import torch from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation @@ -40,10 +39,11 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available 
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available class LightningModule( @@ -366,7 +366,10 @@ def __auto_choose_log_on_epoch(self, on_epoch): return on_epoch - def all_gather(self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False): + def all_gather(self, + data: Union[torch.Tensor, Dict, List, Tuple], + group: Optional[Any] = None, + sync_grads: bool = False): r""" Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ```all_gather``` operation accelerator agnostic. @@ -389,7 +392,7 @@ def all_gather(self, data: Union[torch.Tensor, Dict, List, Tuple], group: Option all_gather = partial(all_gather, group=group, sync_grads=sync_grads) return apply_to_collection(data, torch.Tensor, all_gather) - + def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define From 34f5d3403e29049bd3bd357189a167183549ce41 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 21 Dec 2020 16:34:32 +0100 Subject: [PATCH 03/26] add tests --- tests/special_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index f7cb581951783..e0ef6d47d1fb9 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -19,4 +19,5 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic +python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_properly_works # python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance From b850968155cf8a41e854331625820152a9e137c2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 21 Dec 2020 16:46:51 +0100 Subject: [PATCH 04/26] resolve flake8 --- tests/utilities/test_all_gather_grad.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 60810f9524c24..96b23bc3f6092 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -1,6 +1,7 @@ import os -import pytest import sys + +import pytest import torch from pytorch_lightning import Trainer, seed_everything @@ -57,8 +58,8 @@ class TestModel(BoringModel): def training_epoch_end(self, outputs) -> None: self.training_epoch_end_called = True losses = torch.stack([x["loss"] for x in outputs]) - gathered_loss = self.all_gather( - {"losses": losses, + gathered_loss = self.all_gather({ + "losses": losses, "losses_list": [losses, losses] }) assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) @@ -82,4 +83,4 @@ def training_epoch_end(self, outputs) -> None: ) trainer.fit(model) - assert model.training_epoch_end_called \ No newline at end of file + assert model.training_epoch_end_called From 00075665d51749a569f649d185fb25f3fcba2b19 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 22 Dec 2020 08:47:35 +0000 Subject: [PATCH 05/26] update --- pytorch_lightning/core/lightning.py 
| 5 ++++- pytorch_lightning/utilities/distributed.py | 10 ++++++++-- tests/utilities/test_all_gather_grad.py | 1 + 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 24bf231ebe30a..bb82c5fb216a3 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -366,7 +366,10 @@ def __auto_choose_log_on_epoch(self, on_epoch): return on_epoch - def all_gather(self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False): + def all_gather(self, + data: Union[torch.Tensor, Dict, List, Tuple], + group: Optional[Any] = None, + sync_grads: bool = False): r""" Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ```all_gather``` operation accelerator agnostic. diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 2d4e9ea939e6c..db5f6574c70af 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -19,11 +19,12 @@ import torch from pytorch_lightning import _logger as log from typing import Union, Optional, Any +from pytorch_lightning.utilities.exceptions import MisconfigurationException + if torch.distributed.is_available(): from torch.distributed import ReduceOp from torch.distributed import group - WORLD = group.WORLD else: class ReduceOp: SUM = None @@ -205,7 +206,12 @@ def all_gather_ddp_if_available( A tensor of shape (world_size, batch, ...) """ if group is None: - group = torch.distributed.group.WORLD + group = globals()["group"].WORLD + if group is None: + raise MisConfigurationException( + "The provided group was None and `torch.distributed.group` isn't available. " + "Gathering tensor accross processes won't be possible. 
" + ) if torch.distributed.is_available() and torch.distributed.is_initialized(): if sync_grads: diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 60810f9524c24..51cdf5b65a767 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -45,6 +45,7 @@ def test_all_gather_ddp(): torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size) +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") From 9205af6100ef6e627134ff81c4132c67da671e6d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 22 Dec 2020 09:51:06 +0100 Subject: [PATCH 06/26] update --- pytorch_lightning/utilities/distributed.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index db5f6574c70af..ce44f2747649a 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -15,16 +15,15 @@ import os import warnings from functools import wraps +from typing import Any, Optional, Union import torch + from pytorch_lightning import _logger as log -from typing import Union, Optional, Any from pytorch_lightning.utilities.exceptions import MisconfigurationException - if torch.distributed.is_available(): - from torch.distributed import ReduceOp - from torch.distributed import group + from torch.distributed import ReduceOp, group else: class ReduceOp: SUM = None @@ -208,7 +207,7 @@ def all_gather_ddp_if_available( if group is None: group = globals()["group"].WORLD if group is None: - raise MisConfigurationException( + raise MisconfigurationException( "The provided group was None and `torch.distributed.group` isn't available. " "Gathering tensor accross processes won't be possible. " ) From b7c5df912d36227937f05febd36a05b49f859e23 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 28 Dec 2020 11:32:41 +0000 Subject: [PATCH 07/26] remove globals --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index db5f6574c70af..7e4427605537a 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -206,7 +206,7 @@ def all_gather_ddp_if_available( A tensor of shape (world_size, batch, ...) """ if group is None: - group = globals()["group"].WORLD + group = torch.distributed.group.WORLD if group is None: raise MisConfigurationException( "The provided group was None and `torch.distributed.group` isn't available. 
" From c42d4cfbda986b7f26dc9fdf33efacc884a0c6a7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 28 Dec 2020 11:34:03 +0000 Subject: [PATCH 08/26] typo --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index ab6127240623c..b43875904d983 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -209,7 +209,7 @@ def all_gather_ddp_if_available( if group is None: raise MisconfigurationException( "The provided group was None and `torch.distributed.group` isn't available. " - "Gathering tensor accross processes won't be possible. " + "Gathering tensor across processes won't be possible. " ) if torch.distributed.is_available() and torch.distributed.is_initialized(): From a109502ba989da62cbcb2f5d98d3947f811c8074 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 28 Dec 2020 16:24:57 +0100 Subject: [PATCH 09/26] Update pytorch_lightning/utilities/distributed.py Co-authored-by: Jirka Borovec --- pytorch_lightning/utilities/distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index b43875904d983..115861928bc91 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -208,8 +208,8 @@ def all_gather_ddp_if_available( group = torch.distributed.group.WORLD if group is None: raise MisconfigurationException( - "The provided group was None and `torch.distributed.group` isn't available. " - "Gathering tensor across processes won't be possible. " + "The provided group was None and `torch.distributed.group` isn't available." + " Gathering tensor across processes won't be possible." ) if torch.distributed.is_available() and torch.distributed.is_initialized(): From 03b53d801b6e76a603180dcd55ef5c067ef3c4b1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 17:06:52 +0100 Subject: [PATCH 10/26] update --- pytorch_lightning/core/lightning.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dadaa2adc29a6..d9c3d9a296ff8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -366,10 +366,12 @@ def __auto_choose_log_on_epoch(self, on_epoch): return on_epoch - def all_gather(self, - data: Union[torch.Tensor, Dict, List, Tuple], - group: Optional[Any] = None, - sync_grads: bool = False): + def all_gather( + self, + data: Union[torch.Tensor, Dict, List, Tuple], + group: Optional[Any] = None, + sync_grads: bool = False + ): r""" Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ```all_gather``` operation accelerator agnostic. 
From ee299a63b23ae94bf087593839c44977f455a181 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 11:22:00 +0100 Subject: [PATCH 11/26] update --- pytorch_lightning/core/lightning.py | 10 ++++++++-- pytorch_lightning/utilities/apply_func.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d9c3d9a296ff8..75574b05bc4f6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -39,7 +39,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, flatten_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -392,8 +392,14 @@ def all_gather( else: all_gather = all_gather_ddp_if_available + def to_dtype_tensor(value, dtype=None): + return torch.tensor(value, dtype=dtype, device=self.device) + + data = apply_to_collection(data, float, partial(to_dtype_tensor, dtype=torch.float)) + data = apply_to_collection(data, int, partial(to_dtype_tensor, dtype=torch.int)) all_gather = partial(all_gather, group=group, sync_grads=sync_grads) - return apply_to_collection(data, torch.Tensor, all_gather) + data = apply_to_collection(data, torch.Tensor, all_gather) + return flatten_collection(data) def forward(self, *args, **kwargs): r""" diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 775c22dbbfa0a..0f7ec1e58645a 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -61,6 +61,20 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return data +def flatten_collection(data: Any): + + if isinstance(data, dict): + for key, value in data.items(): + data[key] = flatten_collection(value) + + elif isinstance(data, (list, tuple)): + if all([torch.is_tensor(value) for value in data]) and len(data) > 0: + value = torch.cat(data, dim=0) + return [value] if isinstance(data, list) else (value, ) + + return data + + class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. 
From 710c2effafe065402f17d92739b72f29e35fc106 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 4 Jan 2021 10:46:26 +0000 Subject: [PATCH 12/26] add suport int, float --- pytorch_lightning/core/lightning.py | 5 ++--- pytorch_lightning/utilities/apply_func.py | 15 --------------- tests/utilities/test_all_gather_grad.py | 4 ++++ 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 75574b05bc4f6..41b596825ccfe 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -370,7 +370,7 @@ def all_gather( self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, - sync_grads: bool = False + sync_grads: bool = False, ): r""" Allows users to call ``self.all_gather()`` from the LightningModule, thus making @@ -398,8 +398,7 @@ def to_dtype_tensor(value, dtype=None): data = apply_to_collection(data, float, partial(to_dtype_tensor, dtype=torch.float)) data = apply_to_collection(data, int, partial(to_dtype_tensor, dtype=torch.int)) all_gather = partial(all_gather, group=group, sync_grads=sync_grads) - data = apply_to_collection(data, torch.Tensor, all_gather) - return flatten_collection(data) + return apply_to_collection(data, torch.Tensor, all_gather) def forward(self, *args, **kwargs): r""" diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 0f7ec1e58645a..dd8fc079a85d9 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import importlib from abc import ABC from collections.abc import Mapping, Sequence @@ -61,20 +60,6 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return data -def flatten_collection(data: Any): - - if isinstance(data, dict): - for key, value in data.items(): - data[key] = flatten_collection(value) - - elif isinstance(data, (list, tuple)): - if all([torch.is_tensor(value) for value in data]) and len(data) > 0: - value = torch.cat(data, dim=0) - return [value] if isinstance(data, list) else (value, ) - - return data - - class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. 
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 56af9431d464f..60123aea770c9 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -60,9 +60,13 @@ def training_epoch_end(self, outputs) -> None: self.training_epoch_end_called = True losses = torch.stack([x["loss"] for x in outputs]) gathered_loss = self.all_gather({ + "losses_float": [0., 1., 2.], + "losses_int": [0, 1, 2], "losses": losses, "losses_list": [losses, losses] }) + assert gathered_loss["losses_float"][0].dtype == torch.float + assert gathered_loss["losses_int"][0].dtype == torch.int assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) assert gathered_loss["losses"].numel() == 2 * len(losses) From 053c0046fe3660c59265f8b43a42436fff6f85c6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 11:48:07 +0100 Subject: [PATCH 13/26] update --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 41b596825ccfe..fac3daf4d3b7a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -39,7 +39,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection, flatten_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException From 04924ae9ca0c5e551ef230509031e2672a0099e0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 18:59:28 +0100 Subject: [PATCH 14/26] resolve pep8 --- pytorch_lightning/utilities/apply_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b6eb449142ac8..6d02a2454fc02 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib from abc import ABC from collections.abc import Mapping, Sequence from copy import copy From cccce283ee698c4b35a4cb7c622ce9d900aebc18 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 4 Jan 2021 21:45:18 +0100 Subject: [PATCH 15/26] Update pytorch_lightning/core/lightning.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5fcaeacb0f30d..d9a6475805462 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -385,7 +385,7 @@ def all_gather( sync_grads: flag that allows users to synchronize gradients for all_gather op Return: - A tensor of shape (world_size, batch, ...) 
+ A tensor of shape (world_size, batch, ...), or if the input was a collection the output will also be a collection with tensors of this shape. """ if self.trainer.accelerator_backend is not None: all_gather = self.trainer.accelerator_backend.all_gather From f764832b09813945fda622b715bd9e8c2c420543 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 4 Jan 2021 21:45:32 +0100 Subject: [PATCH 16/26] Update tests/utilities/test_all_gather_grad.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- tests/utilities/test_all_gather_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 60123aea770c9..036c916aa361c 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -50,7 +50,7 @@ def test_all_gather_ddp(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") -def test_all_gather_properly_works(tmpdir): +def test_all_gather_collection(tmpdir): class TestModel(BoringModel): From 2df387b49037a7412a1fa473b6d5cbd1eb89d245 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 5 Jan 2021 14:09:18 +0100 Subject: [PATCH 17/26] update doc --- pytorch_lightning/core/lightning.py | 6 ++- tests/trainer/test_dataloaders.py | 81 ++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d9a6475805462..39e7c382e773f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -380,12 +380,14 @@ def all_gather( distributed processes Args: - tensor: tensor of shape (batch, ...) + tensor: int, float, tensor of shape (batch, ...), or a collection of + int, float, tensor of shape (batch, ...) group: the process group to gather results from. Defaults to all processes (world) sync_grads: flag that allows users to synchronize gradients for all_gather op Return: - A tensor of shape (world_size, batch, ...), or if the input was a collection the output will also be a collection with tensors of this shape. + A tensor of shape (world_size, batch, ...), or if the input was a collection + the output will also be a collection with tensors of this shape. 
""" if self.trainer.accelerator_backend is not None: all_gather = self.trainer.accelerator_backend.all_gather diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 50c426c174349..9b42aa98c9dd0 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -33,7 +33,6 @@ def test_fit_train_loader_only(tmpdir): - model = EvalModelTemplate() train_dataloader = model.train_dataloader() @@ -52,7 +51,6 @@ def test_fit_train_loader_only(tmpdir): def test_fit_val_loader_only(tmpdir): - model = EvalModelTemplate() train_dataloader = model.train_dataloader() val_dataloader = model.val_dataloader() @@ -658,6 +656,62 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): trainer.test(**test_options) +@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.') +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4) +def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path): + """ Test that error is raised if dataloader with only a few workers is used """ + + model = EvalModelTemplate() + model.training_step = model.training_step__multiple_dataloaders + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + model.test_step = model.test_step__multiple_dataloaders + model.test_epoch_end = model.test_epoch_end__multiple_dataloaders + + # logger file to get meta + train_dl = model.dataloader(train=True) + train_dl.num_workers = 0 + + val_dl = model.dataloader(train=False) + val_dl.num_workers = 0 + + train_dl = model.dataloader(train=False) + train_dl.num_workers = 0 + + train_multi_dl = {'a': train_dl, 'b': train_dl} + val_multi_dl = [val_dl, val_dl] + test_multi_dl = [train_dl, train_dl] + + fit_options = dict(train_dataloader=train_multi_dl, + val_dataloaders=val_multi_dl) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + + # fit model + with pytest.warns( + UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.' + ): + trainer.fit(model, **fit_options) + + with pytest.warns( + UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.' + ): + trainer.fit(model, **fit_options) + + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + test_options = dict(test_dataloaders=test_multi_dl, ckpt_path=ckpt_path) + with pytest.warns( + UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.' 
+ ): + trainer.test(**test_options) + + @pytest.mark.xfail( LooseVersion(torch.__version__) < LooseVersion("1.4.0"), reason="IterableDataset with __len__ before 1.4 raises", @@ -857,6 +911,29 @@ def train_dataloader(self): assert 1 == result +@pytest.mark.parametrize(['multiple_trainloader_mode', 'num_training_batches'], [ + pytest.param("min_size", 5), + pytest.param("max_size_cycle", 10), +]) +def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches): + """Integration test for multple train loaders""" + model = EvalModelTemplate() + + model.train_dataloader = model.train_dataloader__multiple_mapping + # todo: add also `train_dataloader__multiple_sequence` + model.training_step = model.training_step__multiple_dataloaders + + trainer = Trainer( + max_epochs=1, + default_root_dir=tmpdir, + multiple_trainloader_mode=multiple_trainloader_mode, + ) + + assert 1 == trainer.fit(model) + # verify the num_training_batches according to the multiple_trainloader_mode + assert num_training_batches == trainer.num_training_batches + + @pytest.mark.parametrize('check_interval', [1.0]) def test_val_dataloader_not_implemented_error(tmpdir, check_interval): """Test not_implemented_error data loader (e.g. IterableDataset)""" From 0ad2f3a3b6a78b60ef07ed0623be284f1dbddbbe Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 7 Jan 2021 11:30:02 +0000 Subject: [PATCH 18/26] add bool and np.ndarray --- pytorch_lightning/core/lightning.py | 9 +++---- pytorch_lightning/utilities/apply_func.py | 29 ++++++++++++++++++++++- tests/utilities/test_all_gather_grad.py | 5 ++++ 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 41b596825ccfe..1d97f445839d1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -20,6 +20,7 @@ import os import re import tempfile +import numpy as np from abc import ABC from argparse import Namespace from functools import partial @@ -39,7 +40,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection, flatten_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -392,11 +393,7 @@ def all_gather( else: all_gather = all_gather_ddp_if_available - def to_dtype_tensor(value, dtype=None): - return torch.tensor(value, dtype=dtype, device=self.device) - - data = apply_to_collection(data, float, partial(to_dtype_tensor, dtype=torch.float)) - data = apply_to_collection(data, int, partial(to_dtype_tensor, dtype=torch.int)) + data = convert_to_tensors(data, device=self.device) all_gather = partial(all_gather, group=group, sync_grads=sync_grads) return apply_to_collection(data, torch.Tensor, all_gather) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index dd8fc079a85d9..ad6a936a694a1 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -15,8 +15,10 @@ from abc import ABC from collections.abc import Mapping, Sequence from copy import copy +from 
functools import partial +from pytorch_lightning.utilities.exceptions import MisconfigurationException from typing import Any, Callable, Union - +import numpy as np import torch TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None @@ -26,6 +28,15 @@ Batch = type(None) +CONVERSION_DTYPES = [ + (bool, torch.bool), + (int, torch.int), + (float, torch.float), + (np.ndarray, None), + +] + + def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: """ Recursively applies a function to all elements of a certain dtype. @@ -60,6 +71,22 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return data +def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): + if isinstance(value, np.ndarray): + return torch.from_numpy(value).to(device) + return torch.tensor(value, dtype=dtype, device=device) + + +def convert_to_tensors(data, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." + ) + for src_dtype, dst_dtype in CONVERSION_DTYPES: + data = apply_to_collection(data, src_dtype, partial(to_dtype_tensor, dtype=dst_dtype, device=device)) + return data + + class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 60123aea770c9..cc122c22dddba 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -3,6 +3,7 @@ import pytest import torch +import numpy as np from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.utilities import AllGatherGrad @@ -60,11 +61,15 @@ def training_epoch_end(self, outputs) -> None: self.training_epoch_end_called = True losses = torch.stack([x["loss"] for x in outputs]) gathered_loss = self.all_gather({ + "losses_np_ndarray": np.array([1, 2, 3]), + "losses_bool": [True, False], "losses_float": [0., 1., 2.], "losses_int": [0, 1, 2], "losses": losses, "losses_list": [losses, losses] }) + assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64 + assert gathered_loss["losses_bool"][0].dtype == torch.bool assert gathered_loss["losses_float"][0].dtype == torch.float assert gathered_loss["losses_int"][0].dtype == torch.int assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) From 00813ef681c9c063052212ab7e8645e0a02d0f7a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 7 Jan 2021 11:32:45 +0000 Subject: [PATCH 19/26] resolve conflicts --- pytorch_lightning/core/lightning.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e0a74325e1b3f..3d90dd52478aa 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -40,11 +40,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn -<<<<<<< HEAD from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors -======= -from pytorch_lightning.utilities.apply_func import apply_to_collection ->>>>>>> 8a906d3f0167314ac2b5b5552724c4699122dd93 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from 
pytorch_lightning.utilities.exceptions import MisconfigurationException From 832bb98773b6e43e156e97e308fcb1b2d2d8bad7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 7 Jan 2021 11:37:19 +0000 Subject: [PATCH 20/26] resolve conflicts --- pytorch_lightning/utilities/apply_func.py | 54 +++++++++++------------ 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index ece718c3f37c0..912023763a652 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from abc import ABC +import numpy as np +from functools import partial from collections.abc import Mapping, Sequence from copy import copy -from functools import partial +from typing import Any, Callable, Union, Optional from pytorch_lightning.utilities.exceptions import MisconfigurationException -from typing import Any, Callable, Union -import numpy as np import torch from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE @@ -36,10 +37,10 @@ ] -def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: +def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, + wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any: """ Recursively applies a function to all elements of a certain dtype. - Args: data: the collection to apply the function to dtype: the given function will be applied to all elements of this dtype @@ -48,10 +49,8 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable wrong_dtype: the given function won't be applied if this type is specified and the given collections is of the :attr:`wrong_type` even if it is of type :attr`dtype` **kwargs: keyword arguments (will be forwarded to calls of ``function``) - Returns: the resulting collection - """ elem_type = type(data) @@ -72,28 +71,10 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return data -def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): - if isinstance(value, np.ndarray): - return torch.from_numpy(value).to(device) - return torch.tensor(value, dtype=dtype, device=device) - - -def convert_to_tensors(data, device: torch.device = None): - if device is None: - raise MisconfigurationException( - "device (torch.device) should be provided." - ) - for src_dtype, dst_dtype in CONVERSION_DTYPES: - data = apply_to_collection(data, src_dtype, partial(to_dtype_tensor, dtype=dst_dtype, device=device)) - return data - - class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. - Example: - >>> isinstance(dict, TransferableDataType) False >>> isinstance(torch.rand(2, 3), TransferableDataType) @@ -120,15 +101,12 @@ def move_data_to_device(batch: Any, device: torch.device): """ Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. - Args: batch: A tensor or collection of tensors or anything that has a method `.to(...)`. See :func:`apply_to_collection` for a list of supported collection types. 
device: The device to which the data should be moved - Return: the same collection but with all contained tensors residing on the new device. - See Also: - :meth:`torch.Tensor.to` - :class:`torch.device` @@ -152,3 +130,23 @@ def batch_to(data): dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType return apply_to_collection(batch, dtype=dtype, function=batch_to) + + +def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." + ) + if isinstance(value, np.ndarray): + return torch.from_numpy(value).to(device) + return torch.tensor(value, dtype=dtype, device=device) + + +def convert_to_tensors(data, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." + ) + for src_dtype, dst_dtype in CONVERSION_DTYPES: + data = apply_to_collection(data, src_dtype, partial(to_dtype_tensor, dtype=dst_dtype, device=device)) + return data \ No newline at end of file From a00a51ca4cfe3b5965fc1b0e3cc5015af6fb0bdb Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 13:08:42 +0100 Subject: [PATCH 21/26] resolve pep8 --- pytorch_lightning/core/lightning.py | 11 +++++------ pytorch_lightning/utilities/apply_func.py | 11 ++++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3d90dd52478aa..7018bd12f3cbd 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -14,17 +14,16 @@ """nn.Module with additional great features.""" +from abc import ABC +from argparse import Namespace import collections import copy +from functools import partial import inspect import os +from pathlib import Path import re import tempfile -import numpy as np -from abc import ABC -from argparse import Namespace -from functools import partial -from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -37,7 +36,7 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO +from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 912023763a652..fedc8a6b87a27 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -13,14 +13,15 @@ # limitations under the License. 
from abc import ABC -import numpy as np -from functools import partial from collections.abc import Mapping, Sequence from copy import copy -from typing import Any, Callable, Union, Optional -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE if _TORCHTEXT_AVAILABLE: @@ -149,4 +150,4 @@ def convert_to_tensors(data, device: torch.device = None): ) for src_dtype, dst_dtype in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(to_dtype_tensor, dtype=dst_dtype, device=device)) - return data \ No newline at end of file + return data From 96b3d98c9b34a9cec3a2011ad9143530365bef2b Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 13:13:07 +0100 Subject: [PATCH 22/26] add changelog --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/distributed.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 928144320394a..bde9727822115 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241)) +- Accelerator `all_gather` supports collection ([#5221](https://github.com/PyTorchLightning/pytorch-lightning/pull/5221)) + + ### Changed - `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 9289eccc40244..3460d85c64131 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -20,11 +20,9 @@ import torch from pytorch_lightning import _logger as log -from pytorch_lightning.utilities.exceptions import MisconfigurationException - if torch.distributed.is_available(): - from torch.distributed import ReduceOp, group + from torch.distributed import group, ReduceOp else: class ReduceOp: SUM = None From 5112f4b38f512a503de661c5e44fe06a85c75d63 Mon Sep 17 00:00:00 2001 From: chaton Date: Fri, 8 Jan 2021 09:48:23 +0100 Subject: [PATCH 23/26] Update pytorch_lightning/core/lightning.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/core/lightning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 473e7e7bbb703..42ef431aaf096 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -381,8 +381,7 @@ def all_gather( distributed processes Args: - tensor: int, float, tensor of shape (batch, ...), or a collection of - int, float, tensor of shape (batch, ...) + tensor: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. group: the process group to gather results from. 
Defaults to all processes (world) sync_grads: flag that allows users to synchronize gradients for all_gather op From 6e9141efa86818a9f2c6ae5b81a0abcc9025af88 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 10:58:04 +0100 Subject: [PATCH 24/26] update --- tests/special_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 67ed07a3162a3..675f05cf787ff 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -20,6 +20,6 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic -python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_properly_works +python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection # python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp From 46126a249f4317845c745bd75c1fc6e7265bc5fe Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 9 Jan 2021 11:54:39 +0000 Subject: [PATCH 25/26] resolve bug --- pytorch_lightning/core/lightning.py | 1 + pytorch_lightning/utilities/apply_func.py | 40 +++++++++++++---------- tests/utilities/test_all_gather_grad.py | 3 +- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 42ef431aaf096..c2ec67819912e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -389,6 +389,7 @@ def all_gather( A tensor of shape (world_size, batch, ...), or if the input was a collection the output will also be a collection with tensors of this shape. """ + group = group if group is not None else torch.distributed.group.WORLD if self.trainer.accelerator_backend is not None: all_gather = self.trainer.accelerator_backend.all_gather else: diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index fedc8a6b87a27..6f1512dd5204f 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -30,11 +30,27 @@ Batch = type(None) +def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." + ) + return torch.tensor(value, dtype=dtype, device=device) + + +def from_numpy(value, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." 
+ ) + return torch.from_numpy(value).to(device) + CONVERSION_DTYPES = [ - (bool, torch.bool), - (int, torch.int), - (float, torch.float), - (np.ndarray, None), + # bool -> int as torch.bool: RuntimeError: Unsupported data type for NCCL process group + (bool, partial(to_dtype_tensor, dtype=torch.int)), + (int, partial(to_dtype_tensor, dtype=torch.int)), + (float, partial(to_dtype_tensor, dtype=torch.float)), + (np.ndarray, from_numpy), ] @@ -131,23 +147,11 @@ def batch_to(data): dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType return apply_to_collection(batch, dtype=dtype, function=batch_to) - - -def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): - if device is None: - raise MisconfigurationException( - "device (torch.device) should be provided." - ) - if isinstance(value, np.ndarray): - return torch.from_numpy(value).to(device) - return torch.tensor(value, dtype=dtype, device=device) - - def convert_to_tensors(data, device: torch.device = None): if device is None: raise MisconfigurationException( "device (torch.device) should be provided." ) - for src_dtype, dst_dtype in CONVERSION_DTYPES: - data = apply_to_collection(data, src_dtype, partial(to_dtype_tensor, dtype=dst_dtype, device=device)) + for src_dtype, conversion_func in CONVERSION_DTYPES: + data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) return data diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index f38dae2e2e095..a03604fb4b8f2 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -69,7 +69,8 @@ def training_epoch_end(self, outputs) -> None: "losses_list": [losses, losses] }) assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64 - assert gathered_loss["losses_bool"][0].dtype == torch.bool + # torch.bool can't be all_gathered + assert gathered_loss["losses_bool"][0].dtype == torch.int32 assert gathered_loss["losses_float"][0].dtype == torch.float assert gathered_loss["losses_int"][0].dtype == torch.int assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) From 0df1a9895e66078b13c02e81b6563dc6169a1856 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:57:31 +0100 Subject: [PATCH 26/26] resolve flake8 --- pytorch_lightning/utilities/apply_func.py | 7 +++++-- tests/utilities/test_all_gather_grad.py | 6 +++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 6f1512dd5204f..fd610ebcb0c8d 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -45,9 +45,10 @@ def from_numpy(value, device: torch.device = None): ) return torch.from_numpy(value).to(device) + CONVERSION_DTYPES = [ - # bool -> int as torch.bool: RuntimeError: Unsupported data type for NCCL process group - (bool, partial(to_dtype_tensor, dtype=torch.int)), + # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group + (bool, partial(to_dtype_tensor, dtype=torch.uint8)), (int, partial(to_dtype_tensor, dtype=torch.int)), (float, partial(to_dtype_tensor, dtype=torch.float)), (np.ndarray, from_numpy), @@ -147,6 +148,8 @@ def batch_to(data): dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType return apply_to_collection(batch, dtype=dtype, function=batch_to) + + def convert_to_tensors(data, device: torch.device = 
None): if device is None: raise MisconfigurationException( diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index a03604fb4b8f2..9d0dc5cbc9481 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -1,11 +1,11 @@ import os import sys +import numpy as np import pytest import torch -import numpy as np -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.utilities import AllGatherGrad from tests.base.boring_model import BoringModel @@ -70,7 +70,7 @@ def training_epoch_end(self, outputs) -> None: }) assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64 # torch.bool can't be all_gathered - assert gathered_loss["losses_bool"][0].dtype == torch.int32 + assert gathered_loss["losses_bool"][0].dtype == torch.uint8 assert gathered_loss["losses_float"][0].dtype == torch.float assert gathered_loss["losses_int"][0].dtype == torch.int assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)