 
 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
-    from accelerate import __version__ as accelerate_version
     from accelerate.state import AcceleratorState
     from accelerate.utils import (
         DataLoaderConfiguration,
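Note on the import removal above: `accelerate_version` is no longer needed because the version checks further down switch to `transformers.utils.is_accelerate_available`, which accepts a minimum-version string and also covers the case where accelerate is missing entirely. A minimal sketch of the old and new checks, assuming accelerate is installed (the "1.10.1" literal is the one used in this diff):

from packaging import version

from accelerate import __version__ as accelerate_version
from transformers.utils import is_accelerate_available

# old pattern: import accelerate's __version__ and compare it by hand
if version.parse(accelerate_version) > version.parse("1.10.1"):
    print("accelerate newer than 1.10.1 (manual comparison)")

# new pattern: a single helper call, no extra import of __version__ required
if is_accelerate_available("1.10.1"):
    print("accelerate available at or above 1.10.1")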
@@ -4967,7 +4966,18 @@ def create_accelerator_and_postprocess(self):
         # this would have been updated above, no need for it anymore
         accelerator_config.pop("gradient_accumulation_kwargs")
 
-        args = {"deepspeed_plugin": self.args.deepspeed_plugin, "dataloader_config": dataloader_config}
+        fsdp_plugin = None
+        if self.args.fsdp_plugin_args is not None:
+            from accelerate.utils import FullyShardedDataParallelPlugin
+
+            fsdp_plugin = FullyShardedDataParallelPlugin(**self.args.fsdp_plugin_args)
+
+        args = {
+            "mixed_precision": self.args.mixed_precision,
+            "dataloader_config": dataloader_config,
+            "fsdp_plugin": fsdp_plugin,
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+        }
 
         # We defer compatibility checks to accelerator
         if self.args.parallelism_config is not None:
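Note on the hunk above: the single-line `args` dict becomes an explicit plugin assembly, so FSDP settings from `TrainingArguments.fsdp_plugin_args` reach the `Accelerator` as a `FullyShardedDataParallelPlugin` object rather than only through environment variables. Below is a minimal standalone sketch of that hand-off, assuming accelerate is installed and the script is started with `accelerate launch`/`torchrun`; the kwargs dict and the `"bf16"` choice are illustrative stand-ins for what `TrainingArguments` would supply:

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, FullyShardedDataParallelPlugin

# illustrative stand-in for self.args.fsdp_plugin_args
fsdp_plugin_args = {"sharding_strategy": "FULL_SHARD", "state_dict_type": "SHARDED_STATE_DICT"}

fsdp_plugin = None
if fsdp_plugin_args is not None:
    # forward the user-supplied kwargs straight to accelerate's FSDP plugin
    fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_args)

args = {
    "mixed_precision": "bf16",                       # stand-in for self.args.mixed_precision
    "dataloader_config": DataLoaderConfiguration(),  # stand-in for the Trainer's dataloader_config
    "fsdp_plugin": fsdp_plugin,                       # None keeps accelerate's env-var fallback
    "deepspeed_plugin": None,                         # stand-in for self.args.deepspeed_plugin
}
accelerator = Accelerator(**args)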
@@ -4981,14 +4991,23 @@ def create_accelerator_and_postprocess(self):
         if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
             self.is_tp_enabled = True
             if self.args.parallelism_config is not None:
-                if version.parse(accelerate_version) > version.parse("1.10.1"):
+                if is_accelerate_available("1.10.1"):
                     if self.args.parallelism_config is not None:
                         from accelerate import ParallelismConfig
 
                         args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
                 else:
                     raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")
 
+        if is_accelerate_available("1.2.0"):
+            # if we don't have the correct version, we rely on the env vars that were set in TrainingArguments
+            from accelerate.utils import TorchDynamoPlugin
+
+            dynamo_plugin = TorchDynamoPlugin(
+                backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode
+            )
+            args["dynamo_plugin"] = dynamo_plugin
+
         # create accelerator object
         self.accelerator = Accelerator(**args)
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
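Note on the second added block: the same gating pattern covers `torch.compile`. With accelerate >= 1.2.0 the backend and mode travel explicitly through a `TorchDynamoPlugin`; on older versions the Trainer keeps relying on the environment variables that `TrainingArguments` already set. A standalone sketch of the gated construction, assuming accelerate >= 1.2.0 so that `Accelerator` accepts the `dynamo_plugin` argument (which is exactly what the gate checks); `"inductor"` and `"default"` stand in for `torch_compile_backend` and `torch_compile_mode`:

from transformers.utils import is_accelerate_available

args = {}
if is_accelerate_available("1.2.0"):
    # recent accelerate: pass the torch.compile settings as an explicit plugin
    from accelerate.utils import TorchDynamoPlugin

    args["dynamo_plugin"] = TorchDynamoPlugin(backend="inductor", mode="default")
# otherwise: fall back to the env-var path TrainingArguments already configured

from accelerate import Accelerator

accelerator = Accelerator(**args)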