Merged
42 commits
021dae9
Add support for pluggable Accelerators
kaushikb11 Feb 21, 2022
7320c66
Add parse_devices method to Accelerators
kaushikb11 Feb 21, 2022
ac2f5c0
Refactor device parsing logic
kaushikb11 Feb 21, 2022
cfece86
Fix passing Accelerator instances
kaushikb11 Feb 21, 2022
c83717e
Fix devices auto
kaushikb11 Feb 21, 2022
e82feaf
Add tests
kaushikb11 Feb 21, 2022
24fa9a4
Update changelog
kaushikb11 Feb 21, 2022
ae08e2a
Update accelerators
kaushikb11 Feb 21, 2022
a6c8bc4
Update accelerator doc
kaushikb11 Feb 21, 2022
41de817
Merge branch 'master' into feat/pluggable_accelerators
kaushikb11 Feb 22, 2022
19f2e8d
Fix acc connector tests
kaushikb11 Feb 22, 2022
b373fa6
Fix parallel devices being passed to strategy
kaushikb11 Feb 23, 2022
af76a72
Fix gpu test
kaushikb11 Feb 23, 2022
cf4b8b3
Merge branch 'master' into feat/pluggable_accelerators
kaushikb11 Feb 23, 2022
0f4f387
Update tests
kaushikb11 Feb 23, 2022
0fc482b
Update tests
kaushikb11 Feb 24, 2022
10c3d4f
Update tests
kaushikb11 Feb 24, 2022
b97c88c
Update
kaushikb11 Feb 24, 2022
1c12097
Fix typing
kaushikb11 Feb 24, 2022
b75cf38
Update pytorch_lightning/accelerators/cpu.py
kaushikb11 Feb 25, 2022
d8ac58e
Update pytorch_lightning/accelerators/gpu.py
kaushikb11 Feb 25, 2022
8b21da3
Fix typing
kaushikb11 Feb 25, 2022
30ea678
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2022
6da57dd
Taking control over pre-commit
kaushikb11 Feb 25, 2022
9eb70eb
Update tests/accelerators/test_common.py
kaushikb11 Feb 25, 2022
b55af6b
Update tests/accelerators/test_accelerator_connector.py
kaushikb11 Feb 25, 2022
2b99d02
Address reviews
kaushikb11 Feb 25, 2022
258f587
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
kaushikb11 Feb 25, 2022
438769f
Mypy
carmocca Feb 25, 2022
6e35841
Update pytorch_lightning/accelerators/cpu.py
kaushikb11 Feb 25, 2022
4bdcaff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2022
1bdd7a8
Update tests
kaushikb11 Feb 25, 2022
85a4440
Merge branch 'feat/pluggable_accelerators' of https://github.com/PyTo…
kaushikb11 Feb 25, 2022
82fbc02
Fix typing
kaushikb11 Feb 25, 2022
ec845a5
fix typing
rohitgr7 Feb 25, 2022
a99e4ea
Fix typing
kaushikb11 Feb 25, 2022
9869db1
fix typing
rohitgr7 Feb 25, 2022
1fa5d1a
Simplify test
carmocca Feb 25, 2022
6e08017
Merge branch 'master' into feat/pluggable_accelerators
kaushikb11 Feb 28, 2022
e50d1d1
Fix tests
kaushikb11 Feb 28, 2022
5f1a1e4
Merge branch 'master' into feat/pluggable_accelerators
kaushikb11 Feb 28, 2022
190f7a8
Fix tests
kaushikb11 Feb 28, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `estimated_stepping_batches` property to `Trainer` ([#11599](https://github.com/PyTorchLightning/pytorch-lightning/pull/11599))


- Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030))


### Changed

- Make `benchmark` flag optional and set its value based on the deterministic flag ([#11944](https://github.com/PyTorchLightning/pytorch-lightning/pull/11944))
5 changes: 3 additions & 2 deletions docs/source/extensions/accelerator.rst
@@ -20,6 +20,7 @@ Each Accelerator gets two plugins upon initialization:
One to handle differences from the training routine and one to handle different precisions.

.. testcode::
:skipif: torch.cuda.device_count() < 2

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import GPUAccelerator
@@ -28,8 +29,8 @@ One to handle differences from the training routine and one to handle different

accelerator = GPUAccelerator()
precision_plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda")
training_type_plugin = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_type_plugin)
training_strategy = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_strategy, devices=2)


We expose Accelerators and Plugins mainly for expert users who want to extend Lightning to work with new
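For reference, a CPU-only sketch of the same wiring (not part of this diff; it assumes this release's `pytorch_lightning.strategies` and `pytorch_lightning.plugins` imports and only constructs the `Trainer`, since the GPU example above now needs two CUDA devices):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.strategies import DDPStrategy

accelerator = CPUAccelerator()
precision_plugin = PrecisionPlugin()  # full precision; NativeMixedPrecisionPlugin requires CUDA
training_strategy = DDPStrategy(accelerator=accelerator, precision_plugin=precision_plugin)
trainer = Trainer(strategy=training_strategy, devices=2)
```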
10 changes: 10 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -55,6 +55,16 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""
raise NotImplementedError

@staticmethod
@abstractmethod
def parse_devices(devices: Any) -> Any:
"""Accelerator device parsing logic."""

@staticmethod
@abstractmethod
def get_parallel_devices(devices: Any) -> Any:
"""Gets parallel devices for the Accelerator."""

@staticmethod
@abstractmethod
def auto_device_count() -> int:
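To make the new interface concrete, here is a minimal sketch of a third-party accelerator (a hypothetical class, modelled on the `Accel` test class further down in this diff; depending on the rest of the base class, methods such as `get_device_stats` or `is_available` may also need overriding):

```python
from typing import Any, Dict, List, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator


class MyAccelerator(Accelerator):
    """Hypothetical accelerator that maps each requested device to a CPU process."""

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        return {}

    @staticmethod
    def parse_devices(devices: Any) -> Any:
        # Accept the Trainer's `devices` flag unchanged.
        return devices

    @staticmethod
    def get_parallel_devices(devices: Any) -> List[torch.device]:
        # One device entry per requested process.
        return [torch.device("cpu")] * int(devices)

    @staticmethod
    def auto_device_count() -> int:
        # Used when `devices="auto"` is passed.
        return 1
```

With the connector changes below, passing an instance such as `Trainer(accelerator=MyAccelerator(), devices=1)` should route device parsing through these static methods.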
22 changes: 18 additions & 4 deletions pytorch_lightning/accelerators/cpu.py
@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any
from typing import Any, Dict, List, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _DEVICE


@@ -35,10 +34,25 @@ def setup_environment(self, root_device: torch.device) -> None:
if root_device.type != "cpu":
raise MisconfigurationException(f"Device should be CPU, got {root_device} instead.")

def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
"""CPU device stats aren't supported yet."""
return {}

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, str, List[int]]:
"""Accelerator device parsing logic."""
return devices

@staticmethod
def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
if isinstance(devices, int):
return [torch.device("cpu")] * devices
rank_zero_warn(
f"The flag `devices` must be an int with `accelerator='cpu'`, got `devices={devices!r}` instead."
)
return []

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
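A quick sketch (not part of the diff) of how the two new `CPUAccelerator` static methods behave per the hunk above:

```python
import torch

from pytorch_lightning.accelerators import CPUAccelerator

assert CPUAccelerator.parse_devices(3) == 3  # passthrough
assert CPUAccelerator.get_parallel_devices(3) == [torch.device("cpu")] * 3
assert CPUAccelerator.get_parallel_devices("3") == []  # emits the "must be an int" warning
```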
21 changes: 15 additions & 6 deletions pytorch_lightning/accelerators/gpu.py
@@ -11,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging
import os
import shutil
import subprocess
from typing import Any
from typing import Any, Dict, List, Optional, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.types import _DEVICE
@@ -44,7 +43,7 @@ def setup_environment(self, root_device: torch.device) -> None:
raise MisconfigurationException(f"Device should be GPU, got {root_device} instead")
torch.cuda.set_device(root_device)

def setup(self, trainer: pl.Trainer) -> None:
def setup(self, trainer: "pl.Trainer") -> None:
# TODO refactor input from trainer to local_rank @four4fish
self.set_nvidia_flags(trainer.local_rank)
# clear cache before training
@@ -58,7 +57,7 @@ def set_nvidia_flags(local_rank: int) -> None:
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
"""Gets stats for the given GPU device.

Args:
@@ -75,6 +74,16 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
return torch.cuda.memory_stats(device)
return get_nvidia_gpu_stats(device)

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
"""Accelerator device parsing logic."""
return device_parser.parse_gpu_ids(devices)

@staticmethod
def get_parallel_devices(devices: List[int]) -> List[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
@@ -85,7 +94,7 @@ def is_available() -> bool:
return torch.cuda.device_count() > 0


def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:
def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

Args:
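The GPU counterpart delegates to `device_parser.parse_gpu_ids`, so this sketch assumes a machine with at least two visible CUDA devices (otherwise parsing raises a `MisconfigurationException`):

```python
from pytorch_lightning.accelerators import GPUAccelerator

gpu_ids = GPUAccelerator.parse_devices(2)                # e.g. [0, 1]
parallel = GPUAccelerator.get_parallel_devices(gpu_ids)  # [device('cuda:0'), device('cuda:1')]
```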
12 changes: 11 additions & 1 deletion pytorch_lightning/accelerators/ipu.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

import torch

@@ -26,6 +26,16 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""IPU device stats aren't supported yet."""
return {}

@staticmethod
def parse_devices(devices: int) -> int:
"""Accelerator device parsing logic."""
return devices

@staticmethod
def get_parallel_devices(devices: int) -> List[int]:
"""Gets parallel devices for the Accelerator."""
return list(range(devices))

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
15 changes: 14 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
from typing import Any, Dict, List, Optional, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE, _XLA_AVAILABLE

if _XLA_AVAILABLE:
@@ -43,6 +44,18 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
}
return device_stats

@staticmethod
def parse_devices(devices: Union[int, str, List[int]]) -> Optional[Union[int, List[int]]]:
"""Accelerator device parsing logic."""
return device_parser.parse_tpu_cores(devices)

@staticmethod
def get_parallel_devices(devices: Union[int, List[int]]) -> List[int]:
"""Gets parallel devices for the Accelerator."""
if isinstance(devices, int):
return list(range(devices))
return devices

@staticmethod
def auto_device_count() -> int:
"""Get the devices when set to auto."""
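The IPU and TPU variants follow the same shape: the IPU methods are plain passthroughs, while `TPUAccelerator.parse_devices` routes through `device_parser.parse_tpu_cores` for validation. A sketch of the pure-Python parts (no accelerator hardware needed for these calls):

```python
from pytorch_lightning.accelerators import IPUAccelerator, TPUAccelerator

assert IPUAccelerator.parse_devices(4) == 4
assert IPUAccelerator.get_parallel_devices(4) == [0, 1, 2, 3]

assert TPUAccelerator.get_parallel_devices(8) == list(range(8))
assert TPUAccelerator.get_parallel_devices([1]) == [1]  # an explicit core list is returned as-is
```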
65 changes: 24 additions & 41 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -168,6 +168,7 @@ def __init__(
self._precision_flag: Optional[Union[int, str]] = None
self._precision_plugin_flag: Optional[PrecisionPlugin] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device]] = []
self.checkpoint_io: Optional[CheckpointIO] = None
self._amp_type_flag: Optional[LightningEnum] = None
self._amp_level_flag: Optional[str] = amp_level
@@ -361,6 +362,7 @@ def _check_config_and_set_final_flags(
self._accelerator_flag = "cpu"
if self._strategy_flag.parallel_devices[0].type == "cuda":
self._accelerator_flag = "gpu"
self._parallel_devices = self._strategy_flag.parallel_devices

amp_type = amp_type if isinstance(amp_type, str) else None
self._amp_type_flag = AMPType.from_str(amp_type)
@@ -387,7 +389,7 @@ def _check_device_config_and_set_final_flags(
devices, num_processes, gpus, ipus, tpu_cores
)

if self._devices_flag in ([], 0, "0", "0,"):
if self._devices_flag in ([], 0, "0"):
rank_zero_warn(f"You passed `devices={devices}`, switching to `cpu` accelerator")
self._accelerator_flag = "cpu"

@@ -408,10 +410,8 @@ def _map_deprecated_devices_specfic_info_to_accelerator_and_device_flag(
"""Sets the `devices_flag` and `accelerator_flag` based on num_processes, gpus, ipus, tpu_cores."""
self._gpus: Optional[Union[List[int], str, int]] = gpus
self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
gpus = device_parser.parse_gpu_ids(gpus)
tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
deprecated_devices_specific_flag = num_processes or gpus or ipus or tpu_cores
if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in (0, "0"):
if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
if devices:
# TODO: @awaelchli improve error message
rank_zero_warn(
@@ -456,51 +456,34 @@ def _choose_accelerator(self) -> str:

def _set_parallel_devices_and_init_accelerator(self) -> None:
# TODO add device availability check
self._parallel_devices: List[Union[int, torch.device]] = []

if isinstance(self._accelerator_flag, Accelerator):
self.accelerator: Accelerator = self._accelerator_flag
elif self._accelerator_flag == "tpu":
self.accelerator = TPUAccelerator()
self._set_devices_flag_if_auto_passed()
if isinstance(self._devices_flag, int):
self._parallel_devices = list(range(self._devices_flag))
else:
self._parallel_devices = self._devices_flag # type: ignore[assignment]

elif self._accelerator_flag == "ipu":
self.accelerator = IPUAccelerator()
self._set_devices_flag_if_auto_passed()
if isinstance(self._devices_flag, int):
self._parallel_devices = list(range(self._devices_flag))

elif self._accelerator_flag == "gpu":
self.accelerator = GPUAccelerator()
self._set_devices_flag_if_auto_passed()
if isinstance(self._devices_flag, int) or isinstance(self._devices_flag, str):
self._devices_flag = int(self._devices_flag)
self._parallel_devices = (
[torch.device("cuda", i) for i in device_parser.parse_gpu_ids(self._devices_flag)] # type: ignore
if self._devices_flag != 0
else []
else:
ACCELERATORS = {
"cpu": CPUAccelerator,
"gpu": GPUAccelerator,
"tpu": TPUAccelerator,
"ipu": IPUAccelerator,
}
assert self._accelerator_flag is not None
self._accelerator_flag = self._accelerator_flag.lower()
if self._accelerator_flag not in ACCELERATORS:
raise MisconfigurationException(
"When passing string value for the `accelerator` argument of `Trainer`,"
f" it can only be one of {list(ACCELERATORS)}."
)
else:
self._parallel_devices = [torch.device("cuda", i) for i in self._devices_flag] # type: ignore
accelerator_class = ACCELERATORS[self._accelerator_flag]
self.accelerator = accelerator_class() # type: ignore[abstract]

elif self._accelerator_flag == "cpu":
self.accelerator = CPUAccelerator()
self._set_devices_flag_if_auto_passed()
if isinstance(self._devices_flag, int):
self._parallel_devices = [torch.device("cpu")] * self._devices_flag
else:
rank_zero_warn(
"The flag `devices` must be an int with `accelerator='cpu'`,"
f" got `devices={self._devices_flag}` instead."
)
self._set_devices_flag_if_auto_passed()

self._gpus = self._devices_flag if not self._gpus else self._gpus
self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores

self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
if not self._parallel_devices:
self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)

def _set_devices_flag_if_auto_passed(self) -> None:
if self._devices_flag == "auto" or not self._devices_flag:
self._devices_flag = self.accelerator.auto_device_count()
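Condensed sketch of the flow this hunk converges on (paraphrased, not the actual connector code): the per-accelerator branches collapse into a lookup table, and device handling is delegated to the chosen `Accelerator` class.

```python
from pytorch_lightning.accelerators import CPUAccelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator

ACCELERATORS = {"cpu": CPUAccelerator, "gpu": GPUAccelerator, "tpu": TPUAccelerator, "ipu": IPUAccelerator}


def resolve(accelerator_flag, devices_flag):
    accelerator = ACCELERATORS[accelerator_flag]()
    if devices_flag == "auto" or not devices_flag:
        devices_flag = accelerator.auto_device_count()
    devices_flag = accelerator.parse_devices(devices_flag)
    parallel_devices = accelerator.get_parallel_devices(devices_flag)
    return accelerator, devices_flag, parallel_devices


# e.g. resolve("cpu", 2) -> (CPUAccelerator(), 2, [device(type='cpu'), device(type='cpu')])
```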
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
@@ -21,7 +21,7 @@
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
from pytorch_lightning.accelerators import IPUAccelerator
from pytorch_lightning.accelerators.ipu import IPUAccelerator
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
26 changes: 20 additions & 6 deletions tests/accelerators/test_accelerator_connector.py
@@ -341,6 +341,14 @@ def creates_processes_externally(self) -> bool:
@mock.patch("pytorch_lightning.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_custom_accelerator(device_count_mock, setup_distributed_mock):
class Accel(Accelerator):
@staticmethod
def parse_devices(devices):
return devices

@staticmethod
def get_parallel_devices(devices):
return [torch.device("cpu")] * devices

@staticmethod
def auto_device_count() -> int:
return 1
@@ -413,10 +421,17 @@ def test_ipython_incompatible_backend_error(_, monkeypatch):
Trainer(strategy="dp")


@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")])
def test_ipython_compatible_backend(trainer_kwargs, monkeypatch):
@mock.patch("torch.cuda.device_count", return_value=2)
def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
trainer = Trainer(strategy="dp", accelerator="gpu")
assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible


@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
def test_ipython_compatible_strategy_tpu(_, monkeypatch):
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
trainer = Trainer(**trainer_kwargs)
trainer = Trainer(accelerator="tpu")
assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible


@@ -883,10 +898,9 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):
assert trainer.strategy.local_rank == 0


def test_unsupported_tpu_choice(monkeypatch):
import pytorch_lightning.utilities.imports as imports
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
def test_unsupported_tpu_choice(mock_devices):

monkeypatch.setattr(imports, "_XLA_AVAILABLE", True)
with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"):
Trainer(accelerator="tpu", precision=64)
