From 7c76ba0a0e20eb5be48fd073a6157e20d8dfbd64 Mon Sep 17 00:00:00 2001
From: Leonard Lausen
Date: Mon, 3 May 2021 21:29:08 +0000
Subject: [PATCH 01/10] deepspeed add train_micro_batch_size_per_gpu argument

---
 pytorch_lightning/plugins/training_type/deepspeed.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 54974739c1746..4524685aa8e61 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,6 +88,7 @@ def __init__(
         allgather_bucket_size: int = 2e8,
         reduce_bucket_size: int = 2e8,
         zero_allow_untested_optimizer: bool = True,
+        train_micro_batch_size_per_gpu: Optional[int] = None,
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
         num_nodes: int = 1,
@@ -148,6 +149,8 @@ def __init__(
             zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                 DeepSpeed supported optimizer when using ZeRO (default: True)

+            train_micro_batch_size_per_gpu
+
             config: Pass in a deepspeed formatted config dict, or path to a deepspeed config:
                 https://www.deepspeed.ai/docs/config-json.
                 All defaults will be ignored if a config is passed in. (Default: ``None``)
@@ -197,6 +200,7 @@ def __init__(
         self.config = self._create_default_config(
             zero_optimization,
             zero_allow_untested_optimizer,
+            train_micro_batch_size_per_gpu,
             partition_activations=partition_activations,
             cpu_checkpointing=cpu_checkpointing,
             contiguous_memory_optimization=contiguous_memory_optimization,
@@ -446,6 +450,7 @@ def _create_default_config(
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
+        train_micro_batch_size_per_gpu: Optional[int],
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -466,6 +471,9 @@ def _create_default_config(
             "zero_optimization": zero_kwargs,
             **cfg
         }
+        if train_micro_batch_size_per_gpu is not None:
+            cfg = {"train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
+                   **cfg}
         return cfg

     def _filepath_to_dir(self, filepath: str) -> str:
From 0507f6e970c254429a7d7b726a9d712b4ff78dfa Mon Sep 17 00:00:00 2001
From: Leonard Lausen
Date: Tue, 4 May 2021 23:19:20 +0000
Subject: [PATCH 02/10] Update naming and doc

---
 .../plugins/training_type/deepspeed.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 4524685aa8e61..c84ab87e526d9 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,7 +88,7 @@ def __init__(
         allgather_bucket_size: int = 2e8,
         reduce_bucket_size: int = 2e8,
         zero_allow_untested_optimizer: bool = True,
-        train_micro_batch_size_per_gpu: Optional[int] = None,
+        logging_batch_size_per_gpu: Optional[int] = 1,
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
         num_nodes: int = 1,
@@ -149,7 +149,10 @@ def __init__(
             zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                 DeepSpeed supported optimizer when using ZeRO (default: True)

-            train_micro_batch_size_per_gpu
+            logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging
+                on a per sample per second basis (only displayed if logging=logging.INFO).
+                To obtain accurate logs, set this to the actual per gpu batch size (trainer.batch_size).
+                If set to None, the logging_batch_size_per_gpu is inferred from the train DataLoader's BatchSampler

             config: Pass in a deepspeed formatted config dict, or path to a deepspeed config:
                 https://www.deepspeed.ai/docs/config-json.
@@ -185,6 +188,7 @@ def __init__(
                 when using ZeRO Stage 3. This allows a single weight file to
                 contain the entire model, rather than individual sharded weight files.
                 Disable to save sharded states individually. (Default: True)
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
@@ -200,7 +204,7 @@ def __init__(
         self.config = self._create_default_config(
             zero_optimization,
             zero_allow_untested_optimizer,
-            train_micro_batch_size_per_gpu,
+            logging_batch_size_per_gpu,
             partition_activations=partition_activations,
             cpu_checkpointing=cpu_checkpointing,
             contiguous_memory_optimization=contiguous_memory_optimization,
@@ -450,7 +454,7 @@ def _create_default_config(
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
-        train_micro_batch_size_per_gpu: Optional[int],
+        logging_batch_size_per_gpu: Optional[int],
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -471,8 +475,8 @@ def _create_default_config(
             "zero_optimization": zero_kwargs,
             **cfg
         }
-        if train_micro_batch_size_per_gpu is not None:
-            cfg = {"train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
+        if logging_batch_size_per_gpu is not None:
+            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu,
                    **cfg}
         return cfg

     def _filepath_to_dir(self, filepath: str) -> str:
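For context on the argument documented above, a minimal construction sketch (illustrative only; `LitModel` is a stand-in for any LightningModule whose train DataLoader uses 32 samples per GPU, and the Trainer/plugin wiring mirrors the test added later in this series):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # `LitModel` is a placeholder LightningModule defined elsewhere.
    # Match the value to the real per-GPU batch size so DeepSpeed's
    # per-sample-per-second timing reflects actual throughput.
    model = LitModel()
    trainer = Trainer(gpus=1, plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=32))
    trainer.fit(model)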
From b51f330e261db7215be3b2588845c487872fff1d Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 11:20:06 +0100
Subject: [PATCH 03/10] Modify to use auto naming convention, add test

---
 .../plugins/training_type/deepspeed.py | 29 +++++++++++++++++++-----------
 tests/plugins/test_deepspeed_plugin.py | 27 +++++++++++++++++
 2 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index c84ab87e526d9..afad2e50936ad 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,7 +88,7 @@ def __init__(
         allgather_bucket_size: int = 2e8,
         reduce_bucket_size: int = 2e8,
         zero_allow_untested_optimizer: bool = True,
-        logging_batch_size_per_gpu: Optional[int] = 1,
+        logging_batch_size_per_gpu: Union[str, int] = "auto",
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
         num_nodes: int = 1,
@@ -151,8 +151,10 @@ def __init__(
             logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging
                 on a per sample per second basis (only displayed if logging=logging.INFO).
-                To obtain accurate logs, set this to the actual per gpu batch size (trainer.batch_size).
-                If set to None, the logging_batch_size_per_gpu is inferred from the train DataLoader's BatchSampler
+                If set to "auto", the plugin tries to infer this from
+                the train DataLoader's BatchSampler, else defaults to 1.
+                To obtain accurate logs when using datasets that do not support batch samplers,
+                set this to the actual per gpu batch size (trainer.batch_size).

             config: Pass in a deepspeed formatted config dict, or path to a deepspeed config:
                 https://www.deepspeed.ai/docs/config-json.
@@ -417,14 +419,22 @@ def _format_batch_size_and_grad_accum_config(self):
                 " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
             )
         if "train_micro_batch_size_per_gpu" not in self.config:
-            # train_micro_batch_size_per_gpu is used for throughput logging purposes
-            # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
-            batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
+            batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
         if "gradient_clipping" not in self.config:
             self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

+    def _auto_select_batch_size(self):
+        # train_micro_batch_size_per_gpu is used for throughput logging purposes
+        # by default we try to use the batch size of the loader
+        batch_size = 1
+        if hasattr(self.lightning_module, 'train_dataloader'):
+            train_dataloader = self.lightning_module.train_dataloader()
+            if hasattr(train_dataloader, 'batch_sampler'):
+                batch_size = train_dataloader.batch_sampler.batch_size
+        return batch_size
+
     def _format_precision_config(self):
         amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
         amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
@@ -454,7 +464,7 @@ def _create_default_config(
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
-        logging_batch_size_per_gpu: Optional[int],
+        logging_batch_size_per_gpu: Union[str, int],
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -475,9 +485,8 @@ def _create_default_config(
             "zero_optimization": zero_kwargs,
             **cfg
         }
-        if logging_batch_size_per_gpu is not None:
-            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu,
-                   **cfg}
+        if logging_batch_size_per_gpu is not 'auto':
+            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
         return cfg

     def _filepath_to_dir(self, filepath: str) -> str:
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index c768a9aabf8fb..28e8d444c5b43 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -234,6 +234,33 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
     trainer.fit(model)


+@RunIf(min_gpus=1, deepspeed=True)
+@pytest.mark.parametrize('value', ["auto", 10])
+def test_deepspeed_auto_batch_size_config_select(tmpdir, value):
+    """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""
+    model = BoringModel()
+
+    class AssertCallback(Callback):
+
+        def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+            assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
+            config = trainer.accelerator.training_type_plugin.config
+            expected_value = pl_module.train_dataloader().batch_size if value is "auto" else value
+            assert config['train_micro_batch_size_per_gpu'] == expected_value
+            raise SystemExit
+
+    ck = AssertCallback()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        callbacks=ck,
+        gpus=1,
+        plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=value, zero_optimization=False),
+    )
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
 @RunIf(min_gpus=1, deepspeed=True, special=True)
 def test_deepspeed_run_configure_optimizers(tmpdir):
     """
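The "auto" mode introduced above boils down to the fallback sketched below; this is a standalone illustration of the same idea, not the plugin method itself (which additionally guards against a missing `train_dataloader` hook on the LightningModule):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def infer_logging_batch_size(train_dataloader, fallback=1):
        # Prefer the BatchSampler's batch size when one is available; otherwise
        # fall back to 1, mirroring _auto_select_batch_size above.
        batch_sampler = getattr(train_dataloader, "batch_sampler", None)
        return batch_sampler.batch_size if batch_sampler is not None else fallback

    loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    print(infer_logging_batch_size(loader))           # 16
    print(infer_logging_batch_size(iter(range(10))))  # 1: a bare iterator exposes no batch sampler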
From 937efd5c866b396c886de9fda61261dbc0c2d7c2 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 11:28:28 +0100
Subject: [PATCH 04/10] Add iterable tests

---
 tests/plugins/test_deepspeed_plugin.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 28e8d444c5b43..3cf3acc65b43b 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 from torch import nn, Tensor
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader

 from pytorch_lightning import LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
@@ -14,7 +15,7 @@
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf

@@ -235,21 +236,32 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args


 @RunIf(min_gpus=1, deepspeed=True)
-@pytest.mark.parametrize('value', ["auto", 10])
-def test_deepspeed_auto_batch_size_config_select(tmpdir, value):
+@pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10),
+                                                    (RandomIterableDataset, "auto"), (RandomIterableDataset, 10)])
+def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""
-    model = BoringModel()
+
+    class TestModel(BoringModel):
+
+        def train_dataloader(self):
+            return DataLoader(dataset_cls(32, 64))

     class AssertCallback(Callback):

         def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
             assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
             config = trainer.accelerator.training_type_plugin.config
-            expected_value = pl_module.train_dataloader().batch_size if value is "auto" else value
+
+            # int value overrides auto mode
+            expected_value = value if isinstance(value, int) else 1
+            if dataset_cls is RandomDataset:
+                expected_value = pl_module.train_dataloader().batch_size if value is "auto" else value
+
             assert config['train_micro_batch_size_per_gpu'] == expected_value
             raise SystemExit

     ck = AssertCallback()
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         fast_dev_run=True,
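`RandomDataset` and `RandomIterableDataset` come from `tests.helpers.boring_model` and are not shown in this series; a minimal iterable-style helper of the same shape might look like the sketch below (an assumption about the helper, included only to illustrate the iterable case the new parametrization exercises):

    import torch
    from torch.utils.data import IterableDataset

    class RandomIterableDataset(IterableDataset):
        # Hypothetical stand-in: yields `count` random vectors of length `size`,
        # one sample at a time, with no __len__ and no map-style indexing.
        def __init__(self, size: int, count: int):
            self.size = size
            self.count = count

        def __iter__(self):
            for _ in range(self.count):
                yield torch.randn(self.size)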
From 99e9b123486a06b3b0b25c5c496702945bf38cda Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 14:53:29 +0100
Subject: [PATCH 05/10] Fix tests, attempt by mocking

---
 pytorch_lightning/plugins/training_type/deepspeed.py | 2 +-
 tests/plugins/test_deepspeed_plugin.py | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index afad2e50936ad..fe3f51fa99390 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -485,7 +485,7 @@ def _create_default_config(
             "zero_optimization": zero_kwargs,
             **cfg
         }
-        if logging_batch_size_per_gpu is not 'auto':
+        if logging_batch_size_per_gpu != 'auto':
             cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
         return cfg

diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 3cf3acc65b43b..d720efbda9823 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -2,6 +2,7 @@
 import os
 from typing import Any, Dict

+import mock
 import pytest
 import torch
 import torch.nn.functional as F
@@ -238,7 +239,8 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
 @RunIf(min_gpus=1, deepspeed=True)
 @pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10),
                                                     (RandomIterableDataset, "auto"), (RandomIterableDataset, 10)])
-def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value):
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
+def test_deepspeed_auto_batch_size_config_select(setup_distributed_mock, tmpdir, dataset_cls, value):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""

     class TestModel(BoringModel):
@@ -250,13 +250,13 @@ def train_dataloader(self):

     class AssertCallback(Callback):

-        def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+        def on_train_start(self, trainer, pl_module) -> None:
             assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
             config = trainer.accelerator.training_type_plugin.config

             # int value overrides auto mode
             expected_value = value if isinstance(value, int) else 1
-            if dataset_cls is RandomDataset:
+            if dataset_cls == RandomDataset:
                 expected_value = pl_module.train_dataloader().batch_size if value is "auto" else value

             assert config['train_micro_batch_size_per_gpu'] == expected_value
             raise SystemExit

From 5cf6de7d45f04d301f1c1985405d7f1ec062f3bb Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 15:17:01 +0100
Subject: [PATCH 06/10] Import correct package

---
 tests/plugins/test_deepspeed_plugin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index d720efbda9823..6084c4b304083 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -1,8 +1,8 @@
 import json
 import os
 from typing import Any, Dict
+from unittest import mock

-import mock
 import pytest
 import torch
 import torch.nn.functional as F

From bb9c86d3fc9190a5357a3aa90ab61283147d9499 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 15:31:25 +0100
Subject: [PATCH 07/10] Fix comparison

---
 tests/plugins/test_deepspeed_plugin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 6084c4b304083..7d8f91d300348 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -257,7 +257,7 @@ def on_train_start(self, trainer, pl_module) -> None:
             # int value overrides auto mode
             expected_value = value if isinstance(value, int) else 1
             if dataset_cls == RandomDataset:
-                expected_value = pl_module.train_dataloader().batch_size if value is "auto" else value
+                expected_value = pl_module.train_dataloader().batch_size if value == "auto" else value

             assert config['train_micro_batch_size_per_gpu'] == expected_value
             raise SystemExit
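Patches 05-07 lean on the standard `unittest.mock` decorator contract (each `mock.patch` hands the created mock to the decorated function as an extra positional argument, which is why the test gained a `setup_distributed_mock` parameter), and patch 08 below drops that mocking again in favour of running the test as a "special" multi-GPU test. For reference, the contract in isolation, independent of Lightning:

    from unittest import mock

    @mock.patch('os.getcwd', autospec=True)
    def check(getcwd_mock):
        # The patched callable is injected as the first argument of the
        # decorated function and is restored once the call returns.
        getcwd_mock.return_value = '/tmp'
        import os
        assert os.getcwd() == '/tmp'

    check()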
From c31c8e02fc699fb0377e640a48169d447e33002a Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 19:01:19 +0100
Subject: [PATCH 08/10] Set as special test

---
 tests/plugins/test_deepspeed_plugin.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 7d8f91d300348..e3aa045afa162 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -236,11 +236,10 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
     trainer.fit(model)


-@RunIf(min_gpus=1, deepspeed=True)
+@RunIf(min_gpus=1, deepspeed=True, special=True)
 @pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10),
                                                     (RandomIterableDataset, "auto"), (RandomIterableDataset, 10)])
-@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
-def test_deepspeed_auto_batch_size_config_select(setup_distributed_mock, tmpdir, dataset_cls, value):
+def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""

     class TestModel(BoringModel):

From 68850dbb8703617a3002d22541a0ae11f5cde92d Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 6 May 2021 19:06:45 +0100
Subject: [PATCH 09/10] Remove import

---
 tests/plugins/test_deepspeed_plugin.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index e3aa045afa162..056c28ffa2309 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -1,7 +1,6 @@
 import json
 import os
 from typing import Any, Dict
-from unittest import mock

 import pytest
 import torch

From e91c4be490054353b73fd2cea3698a4fe7354c40 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Fri, 7 May 2021 10:09:20 +0100
Subject: [PATCH 10/10] Add Changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 193ca633f2fe6..af142fdba3414 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362))
+

 ## [1.3.0] - 2021-05-06
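Putting the series together, an end-to-end sketch of the behaviour the changelog entry refers to (illustrative only; the module and dataset here are hypothetical, and DeepSpeed plus at least one GPU are assumed to be available):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, IterableDataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin


    class RandomStream(IterableDataset):
        # Iterable-style data: samples are produced on the fly, one at a time.
        def __iter__(self):
            for _ in range(64):
                yield torch.randn(32)


    class StreamingModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            return DataLoader(RandomStream(), batch_size=4)


    # With logging_batch_size_per_gpu left at "auto" the plugin infers the logged
    # batch size where it can and otherwise falls back to 1; passing an int pins
    # the value used for DeepSpeed's throughput logging.
    trainer = Trainer(gpus=1, precision=16, max_epochs=1,
                      plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=4))
    trainer.fit(StreamingModel())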