
DeepSpeedPlugin only works for Datasets with BatchSampler not IterableDataset #7345

Description

@leezu

https://github.com/PyTorchLightning/pytorch-lightning/blob/2a740ebe775c585f70d5f6fede38d29a5e26ab2f/pytorch_lightning/plugins/training_type/deepspeed.py#L411-L415

uses the DataLoader's BatchSampler to infer train_micro_batch_size_per_gpu. A BatchSampler may not be present, for example when training on an IterableDataset (see the snippet below). One option would be to let users specify the micro batch size explicitly when constructing the plugin, as in the patch that follows the snippet.
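A minimal sketch of the failure mode (the toy datasets are only illustrative): PyTorch's DataLoader sets batch_sampler to None for iterable-style datasets, so there is nothing for the plugin to read the batch size from.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


class ToyStream(IterableDataset):
    """Iterable-style dataset: no length and no (Batch)Sampler."""

    def __iter__(self):
        return iter(torch.randn(100, 4))


map_loader = DataLoader(TensorDataset(torch.randn(100, 4)), batch_size=8)
stream_loader = DataLoader(ToyStream(), batch_size=8)

print(map_loader.batch_sampler.batch_size)  # 8 -> the plugin can infer the micro batch size
print(stream_loader.batch_sampler)          # None -> the linked code raises AttributeError
```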

modified   pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,6 +88,7 @@ class DeepSpeedPlugin(DDPPlugin):
         allgather_bucket_size: int = 2e8,
         reduce_bucket_size: int = 2e8,
         zero_allow_untested_optimizer: bool = True,
+        train_micro_batch_size_per_gpu: Optional[int] = None,
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
         num_nodes: int = 1,
@@ -148,6 +149,8 @@ class DeepSpeedPlugin(DDPPlugin):
             zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                 DeepSpeed supported optimizer when using ZeRO (default: True)
 
+            train_micro_batch_size_per_gpu: Per-GPU micro batch size to use in the DeepSpeed config instead of inferring it from the train DataLoader's BatchSampler (Default: ``None``)
+
             config: Pass in a deepspeed formatted config dict,
                 or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
                 All defaults will be ignored if a config is passed in. (Default: ``None``)
@@ -197,6 +200,7 @@ class DeepSpeedPlugin(DDPPlugin):
             self.config = self._create_default_config(
                 zero_optimization,
                 zero_allow_untested_optimizer,
+                train_micro_batch_size_per_gpu,
                 partition_activations=partition_activations,
                 cpu_checkpointing=cpu_checkpointing,
                 contiguous_memory_optimization=contiguous_memory_optimization,
@@ -446,6 +450,7 @@ class DeepSpeedPlugin(DDPPlugin):
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
+        train_micro_batch_size_per_gpu: Optional[int],
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -466,6 +471,9 @@ class DeepSpeedPlugin(DDPPlugin):
                 "zero_optimization": zero_kwargs,
                 **cfg
             }
+        if train_micro_batch_size_per_gpu is not None:
+            cfg = {"train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
+                   **cfg}
         return cfg
 
     def _filepath_to_dir(self, filepath: str) -> str:
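
For context, usage of the proposed keyword would look roughly like this (the argument is hypothetical until such a patch lands, and the streaming dataset/model wiring is only a placeholder):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin


class StreamingDataset(IterableDataset):
    """Stream without a known length, hence no BatchSampler on the DataLoader."""

    def __iter__(self):
        while True:
            yield torch.randn(32), torch.randint(0, 2, (1,))


train_loader = DataLoader(StreamingDataset(), batch_size=8)  # batch_sampler is None

trainer = pl.Trainer(
    gpus=1,
    precision=16,
    plugins=[
        # Proposed keyword from the diff above: bypasses the BatchSampler
        # inference and writes the value straight into the DeepSpeed config.
        DeepSpeedPlugin(train_micro_batch_size_per_gpu=8)
    ],
)
# trainer.fit(model, train_dataloader=train_loader)
```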

@SeanNaren WDYT?

Labels: bug (Something isn't working), help wanted (Open to be worked on)
