@@ -88,6 +88,7 @@ def __init__(
         allgather_bucket_size: int = 2e8,
         reduce_bucket_size: int = 2e8,
         zero_allow_untested_optimizer: bool = True,
+        logging_batch_size_per_gpu: Union[str, int] = "auto",
         config: Optional[Union[Path, str, dict]] = None,
         logging_level: int = logging.WARN,
         num_nodes: int = 1,
@@ -148,6 +149,13 @@ def __init__(
             zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                 DeepSpeed supported optimizer when using ZeRO (default: True)
 
+            logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging
+                on a per sample per second basis (only displayed if logging=logging.INFO).
+                If set to "auto", the plugin tries to infer this from
+                the train DataLoader's BatchSampler, else defaults to 1.
+                To obtain accurate logs when using datasets that do not support batch samplers,
+                set this to the actual per gpu batch size (trainer.batch_size).
+
             config: Pass in a deepspeed formatted config dict,
                 or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
                 All defaults will be ignored if a config is passed in. (Default: ``None``)
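
For context, a minimal usage sketch of the new argument (the Trainer settings and the batch size of 32 are illustrative, not part of this commit): pass an explicit value when the train DataLoader has no BatchSampler, so DeepSpeed's throughput logs reflect the real per-GPU batch size instead of the fallback of 1.

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # Explicit per-GPU batch size, used only for DeepSpeed's verbose timing/throughput logs.
    trainer = Trainer(
        gpus=2,
        precision=16,
        plugins=[DeepSpeedPlugin(logging_batch_size_per_gpu=32)],
    )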
@@ -182,6 +190,7 @@ def __init__(
                 when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
                 rather than individual sharded weight files.
                 Disable to save sharded states individually. (Default: True)
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
@@ -197,6 +206,7 @@ def __init__(
         self.config = self._create_default_config(
             zero_optimization,
             zero_allow_untested_optimizer,
+            logging_batch_size_per_gpu,
             partition_activations=partition_activations,
             cpu_checkpointing=cpu_checkpointing,
             contiguous_memory_optimization=contiguous_memory_optimization,
@@ -409,14 +419,22 @@ def _format_batch_size_and_grad_accum_config(self):
                 " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
             )
         if "train_micro_batch_size_per_gpu" not in self.config:
-            # train_micro_batch_size_per_gpu is used for throughput logging purposes
-            # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
-            batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
+            batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
         if "gradient_clipping" not in self.config:
             self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
 
+    def _auto_select_batch_size(self):
+        # train_micro_batch_size_per_gpu is used for throughput logging purposes
+        # by default we try to use the batch size of the loader
+        batch_size = 1
+        if hasattr(self.lightning_module, 'train_dataloader'):
+            train_dataloader = self.lightning_module.train_dataloader()
+            if hasattr(train_dataloader, 'batch_sampler'):
+                batch_size = train_dataloader.batch_sampler.batch_size
+        return batch_size
+
     def _format_precision_config(self):
         amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
         amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
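
As a standalone illustration of the fallback above (plain PyTorch, outside the plugin; the dataset and batch size are made up): a map-style DataLoader exposes the value through its BatchSampler, and 1 is used when no BatchSampler is available.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.zeros(64, 3)), batch_size=8)
    batch_sampler = getattr(loader, "batch_sampler", None)
    # Read BatchSampler.batch_size when present, otherwise fall back to 1 (the plugin's default).
    batch_size = batch_sampler.batch_size if batch_sampler is not None else 1
    print(batch_size)  # -> 8 for this loader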
@@ -446,6 +464,7 @@ def _create_default_config(
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
+        logging_batch_size_per_gpu: Union[str, int],
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -466,6 +485,8 @@ def _create_default_config(
                 "zero_optimization": zero_kwargs,
                 **cfg
             }
+        if logging_batch_size_per_gpu != 'auto':
+            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
         return cfg
 
     def _filepath_to_dir(self, filepath: str) -> str:
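
For illustration, a sketch of the resulting default config when a concrete value is passed (the surrounding keys and ZeRO stage below are assumed, not taken from this commit): the logging batch size is simply prepended to the generated defaults.

    # Hypothetical defaults; only train_micro_batch_size_per_gpu comes from logging_batch_size_per_gpu=32.
    defaults = {"zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}}
    cfg = {"train_micro_batch_size_per_gpu": 32, **defaults}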