@@ -54,7 +54,7 @@ class LightningLite(ABC):
5454 - Multi-node support.
5555
5656 Args:
57- accelerator: The hardware to run on. Possible choices are: ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
57+        accelerator: The hardware to run on. Possible choices are: ``"cpu"``, ``"cuda"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
5858 strategy: Strategy for how to run across multiple devices. Possible choices are:
5959 ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
6060 devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
@@ -443,7 +443,7 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut
443443        return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
444444
445445    def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None:
446-        supported = [t.value.lower() for t in self._supported_device_types()] + ["auto"]
446+        supported = [t.value.lower() for t in self._supported_device_types()] + ["gpu", "auto"]
447447        valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported
448448        if not valid:
449449            raise MisconfigurationException(
@@ -464,7 +464,7 @@ def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> N
464464    def _supported_device_types() -> Sequence[_AcceleratorType]:
465465        return (
466466            _AcceleratorType.CPU,
467-            _AcceleratorType.GPU,
467+            _AcceleratorType.CUDA,
468468            _AcceleratorType.TPU,
469469            _AcceleratorType.MPS,
470470        )
0 commit comments