
Resolve training type plugin when passed with Accelerator #10775

@kaushikb11

Description


🐛 Bug

To Reproduce

trainer = Trainer(accelerator=GPUAccelerator(precision_plugin=PrecisionPlugin(), training_type_plugin=DDPPlugin()), gpus=4)

At the moment, when a training type plugin is passed inside an Accelerator, attributes such as parallel_devices, cluster_environment, and sync_batchnorm are not set on the training type plugin, which leads to errors.

Expected behavior

trainer = Trainer(accelerator=GPUAccelerator(precision_plugin=PrecisionPlugin(), training_type_plugin=DDPPlugin()), gpus=4)

# should be equivalent to

training_type_plugin.parallel_devices == [torch.device("cuda", i) for i in self.parallel_device_ids]
training_type_plugin.cluster_environment == LightningEnvironment()
training_type_plugin.sync_batchnorm == False
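In other words, the accelerator connector should propagate these attributes onto a user-supplied training type plugin rather than only onto plugins it creates itself. A minimal, library-free sketch of that propagation logic (the stub classes and the resolve_training_type_plugin helper below are illustrative stand-ins, not Lightning's actual internals):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ClusterEnvironment:
    """Illustrative stand-in for Lightning's cluster environment."""


@dataclass
class DDPPluginStub:
    """Illustrative stand-in for a user-supplied DDPPlugin."""
    parallel_devices: Optional[List[str]] = None
    cluster_environment: Optional[ClusterEnvironment] = None
    sync_batchnorm: Optional[bool] = None


def resolve_training_type_plugin(plugin, parallel_devices,
                                 cluster_environment, sync_batchnorm):
    """Fill in attributes the connector would normally set, without
    overwriting values the user configured explicitly."""
    if plugin.parallel_devices is None:
        plugin.parallel_devices = parallel_devices
    if plugin.cluster_environment is None:
        plugin.cluster_environment = cluster_environment
    if plugin.sync_batchnorm is None:
        plugin.sync_batchnorm = sync_batchnorm
    return plugin


# With gpus=4, the connector would derive four CUDA devices and the
# default LightningEnvironment, then hand them to the user's plugin.
plugin = resolve_training_type_plugin(
    DDPPluginStub(),
    parallel_devices=["cuda:0", "cuda:1", "cuda:2", "cuda:3"],
    cluster_environment=ClusterEnvironment(),
    sync_batchnorm=False,
)
print(plugin.parallel_devices)
print(plugin.sync_batchnorm)
```

Because the helper only fills attributes that are still None, a user who explicitly configured, say, sync_batchnorm=True on their plugin would keep that value.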
