
Commit 91b5a68

[Trainer] remove env vars (#41697)
* remove env var
* style
* fix value
* update
* fix
* style
* fix
* maybe this time
* rm tests
* fix
1 parent d4562bb commit 91b5a68

File tree

3 files changed: +172 additions, -191 deletions


src/transformers/trainer.py

Lines changed: 22 additions & 3 deletions
@@ -209,7 +209,6 @@

 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
-    from accelerate import __version__ as accelerate_version
     from accelerate.state import AcceleratorState
     from accelerate.utils import (
         DataLoaderConfiguration,
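The import removed here was only needed for manual `version.parse(accelerate_version)` comparisons; the later hunks gate on `is_accelerate_available(min_version)` instead. Below is a minimal sketch of that pattern, assuming a helper roughly shaped like the one in `transformers.utils` (the real implementation caches the lookup and has its own default minimum version):

```python
# Hedged sketch: version-gated feature check without importing accelerate.__version__
# at module level. Approximates transformers.utils.is_accelerate_available.
import importlib.metadata

from packaging import version


def is_accelerate_available(min_version: str = "0.26.0") -> bool:
    """Return True if `accelerate` is installed and at least `min_version`."""
    try:
        installed = importlib.metadata.version("accelerate")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(min_version)


# Old pattern (needs the `accelerate_version` import removed in this hunk):
#     if version.parse(accelerate_version) > version.parse("1.10.1"): ...
# New pattern:
#     if is_accelerate_available("1.10.1"): ...
```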
@@ -4967,7 +4966,18 @@ def create_accelerator_and_postprocess(self):
         # this would have been updated above, no need for it anymore
         accelerator_config.pop("gradient_accumulation_kwargs")

-        args = {"deepspeed_plugin": self.args.deepspeed_plugin, "dataloader_config": dataloader_config}
+        fsdp_plugin = None
+        if self.args.fsdp_plugin_args is not None:
+            from accelerate.utils import FullyShardedDataParallelPlugin
+
+            fsdp_plugin = FullyShardedDataParallelPlugin(**self.args.fsdp_plugin_args)
+
+        args = {
+            "mixed_precision": self.args.mixed_precision,
+            "dataloader_config": dataloader_config,
+            "fsdp_plugin": fsdp_plugin,
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+        }

         # We defer compatibility checks to accelerator
         if self.args.parallelism_config is not None:
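This hunk is where the Trainer stops relying on environment variables for mixed precision and FSDP and hands everything to `Accelerator` explicitly. A hedged, standalone sketch of the same idea follows; the `fsdp_plugin_args` dict and the concrete values are illustrative placeholders, not the Trainer's own defaults:

```python
# Standalone sketch: build the Accelerator kwargs explicitly instead of exporting
# ACCELERATE_* environment variables beforehand. Values here are placeholders.
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

mixed_precision = "no"      # e.g. "bf16" in real training; formerly driven by an env var
fsdp_plugin_args = None     # e.g. {"sharding_strategy": "FULL_SHARD"} when using FSDP

fsdp_plugin = None
if fsdp_plugin_args is not None:
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Only construct the plugin when FSDP args were actually provided.
    fsdp_plugin = FullyShardedDataParallelPlugin(**fsdp_plugin_args)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    dataloader_config=DataLoaderConfiguration(),
    fsdp_plugin=fsdp_plugin,     # None simply means "no FSDP plugin"
    deepspeed_plugin=None,       # set when training with DeepSpeed
)
```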
@@ -4981,14 +4991,23 @@ def create_accelerator_and_postprocess(self):
         if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1:
             self.is_tp_enabled = True
             if self.args.parallelism_config is not None:
-                if version.parse(accelerate_version) > version.parse("1.10.1"):
+                if is_accelerate_available("1.10.1"):
                     if self.args.parallelism_config is not None:
                         from accelerate import ParallelismConfig

                         args["parallelism_config"] = ParallelismConfig(tp_size=self.model.tp_size)
                 else:
                     raise ValueError("Requires accelerate>1.10.1 to use Tensor Parallelism.")

+        if is_accelerate_available("1.2.0"):
+            # if we don't have the correct version, we rely on the env vars that were set in TrainingArguments
+            from accelerate.utils import TorchDynamoPlugin
+
+            dynamo_plugin = TorchDynamoPlugin(
+                backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode
+            )
+            args["dynamo_plugin"] = dynamo_plugin
+
         # create accelerator object
         self.accelerator = Accelerator(**args)
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
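The added `TorchDynamoPlugin` block is the torch.compile counterpart of the same clean-up: the backend and mode travel on the `Accelerator` call rather than through `ACCELERATE_DYNAMO_*` environment variables exported earlier by `TrainingArguments`. A small sketch under that assumption (values are illustrative; per the diff this path needs accelerate >= 1.2.0):

```python
# Sketch: pass torch.compile settings to Accelerator via TorchDynamoPlugin instead of
# relying on ACCELERATE_DYNAMO_BACKEND / ACCELERATE_DYNAMO_MODE being set beforehand.
from accelerate import Accelerator
from accelerate.utils import TorchDynamoPlugin

torch_compile_backend = "inductor"   # stand-in for TrainingArguments.torch_compile_backend
torch_compile_mode = "default"       # stand-in for TrainingArguments.torch_compile_mode

dynamo_plugin = TorchDynamoPlugin(backend=torch_compile_backend, mode=torch_compile_mode)

# Requires accelerate >= 1.2.0, which is why the diff gates on is_accelerate_available("1.2.0").
accelerator = Accelerator(dynamo_plugin=dynamo_plugin)
```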
