@@ -691,10 +691,16 @@ def select_precision_plugin(self) -> PrecisionPlugin:
691691
692692 def create_training_type_plugin (self ) -> TrainingTypePlugin :
693693 if self .use_ddp2 :
694- plugin = DDP2Plugin (accelerator = self .accelerator , parallel_devices = self .parallel_devices , cluster_environment = self .cluster_environment ,)
694+ plugin = DDP2Plugin (
695+ accelerator = self .accelerator ,
696+ parallel_devices = self .parallel_devices ,
697+ cluster_environment = self .cluster_environment ,
698+ )
695699 elif self .use_ddp and self .use_deepspeed :
696700 plugin = DeepSpeedPlugin (
697- accelerator = self .accelerator , cluster_environment = self .select_cluster_environment (), parallel_devices = self .parallel_devices ,
701+ accelerator = self .accelerator ,
702+ cluster_environment = self .select_cluster_environment (),
703+ parallel_devices = self .parallel_devices ,
698704 )
699705 elif self .use_ddp :
700706 use_slurm_ddp = self .use_ddp and self ._is_slurm_managing_tasks ()
@@ -733,7 +739,9 @@ def create_training_type_plugin(self) -> TrainingTypePlugin:
733739 ddp_plugin_cls = DDPPlugin
734740
735741 plugin = ddp_plugin_cls (
736- accelerator = self .accelerator , parallel_devices = self .parallel_devices , cluster_environment = self .cluster_environment ,
742+ accelerator = self .accelerator ,
743+ parallel_devices = self .parallel_devices ,
744+ cluster_environment = self .cluster_environment ,
737745 )
738746 elif self .use_dp :
739747 plugin = DataParallelPlugin (accelerator = self .accelerator , parallel_devices = self .parallel_devices )
@@ -745,7 +753,10 @@ def create_training_type_plugin(self) -> TrainingTypePlugin:
745753 plugin = IPUPlugin (accelerator = self .accelerator , parallel_devices = self .parallel_devices )
746754 else :
747755 single_gpu_ordinal = device_parser .determine_root_gpu_device (self .parallel_device_ids )
748- plugin = SingleDevicePlugin (accelerator = self .accelerator , device = (torch .device (f"cuda:{ single_gpu_ordinal } " if self .use_gpu else "cpu" )),)
756+ plugin = SingleDevicePlugin (
757+ accelerator = self .accelerator ,
758+ device = (torch .device (f"cuda:{ single_gpu_ordinal } " if self .use_gpu else "cpu" )),
759+ )
749760 return plugin
750761
751762 def resolve_training_type_plugin (self , training_type : TrainingTypePlugin ) -> TrainingTypePlugin :
0 commit comments