
Support passing Accelerator objects to the accelerator flag with devices=x #10592

@kaushikb11


🐛 Bug

Expected behavior

As we work toward #10410, it is important to support this behavior:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator

# CustomStrategy is a user-defined strategy class
trainer = Trainer(accelerator=GPUAccelerator(CustomStrategy()), devices=4)

# should be equivalent to
trainer = Trainer(accelerator="gpu", devices=4, strategy=CustomStrategy())
```
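For context, one way the connector could normalize the two call styles into a single configuration is sketched below. This is only an illustration: `resolve_accelerator` and the stand-in `Accelerator`/`Strategy` classes are hypothetical, not part of the Lightning API.

```python
from typing import Optional, Tuple, Union

# Minimal stand-ins for illustration only; the real classes live in
# pytorch_lightning.accelerators and are more involved.
class Strategy: ...

class Accelerator:
    def __init__(self, strategy: Optional[Strategy] = None) -> None:
        self.strategy = strategy

def resolve_accelerator(
    accelerator: Union[str, Accelerator],
    devices: int,
    strategy: Optional[Strategy] = None,
) -> Tuple[Accelerator, Optional[Strategy], int]:
    """Normalize the two equivalent call styles into one configuration."""
    if isinstance(accelerator, Accelerator):
        # Object form: the strategy travels inside the accelerator, so also
        # passing `strategy=` would be ambiguous.
        if strategy is not None:
            raise ValueError(
                "Pass the strategy either inside the Accelerator or via "
                "`strategy=`, not both."
            )
        return accelerator, accelerator.strategy, devices
    # String form ("gpu", "cpu", ...): look up the accelerator by name and
    # pair it with the explicitly passed strategy.
    acc = Accelerator()  # placeholder for a registry lookup by name
    return acc, strategy, devices
```

With this kind of normalization, both constructor calls above would produce the same (accelerator, strategy, devices) configuration.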

Environment

  • PyTorch Lightning Version (e.g., 1.3.0):
  • PyTorch Version (e.g., 1.8):
  • Python version:
  • OS (e.g., Linux):
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source):
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

cc @tchaton @justusschock @kaushikb11 @awaelchli @Borda
