4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
@@ -19,8 +19,8 @@ class TPUAccelerator(Accelerator):
def setup(self, trainer, model):
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException(
"amp + tpu is not supported. "
"Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
"amp + tpu is not supported."
" Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
)

if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
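For readers comparing the removed and added lines above: Python concatenates adjacent string literals, so moving the separating space from the end of the first literal to the start of the second leaves the rendered error message unchanged. A minimal sketch of that equivalence:

# Adjacent string literals are concatenated at compile time,
# so both variants of the exception message render identically.
old = (
    "amp + tpu is not supported. "
    "Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
)
new = (
    "amp + tpu is not supported."
    " Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
)
assert old == new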
1 change: 0 additions & 1 deletion setup.cfg
@@ -53,7 +53,6 @@ omit =
pytorch_lightning/utilities/distributed.py
pytorch_lightning/tuner/auto_gpu_select.py
# TODO: temporary, until accelerator refactor is finished
- pytorch_lightning/accelerators/accelerator.py
pytorch_lightning/plugins/training_type/*.py
pytorch_lightning/plugins/precision/*.py
pytorch_lightning/plugins/base_plugin.py
14 changes: 13 additions & 1 deletion tests/accelerators/test_cpu.py
@@ -5,10 +5,22 @@

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
- from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+ from pytorch_lightning.plugins.precision import MixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


+ def test_invalid_root_device():
+     """ Test that CPU Accelerator has root device on CPU. """
+     trainer = Mock()
+     model = Mock()
+     accelerator = CPUAccelerator(
+         training_type_plugin=SingleDevicePlugin(torch.device("cuda", 1)),
+         precision_plugin=PrecisionPlugin()
+     )
+     with pytest.raises(MisconfigurationException, match="Device should be CPU"):
+         accelerator.setup(trainer=trainer, model=model)


def test_unsupported_precision_plugins():
""" Test error messages are raised for unsupported precision plugins with CPU. """
trainer = Mock()
58 changes: 58 additions & 0 deletions tests/accelerators/test_gpu.py
@@ -0,0 +1,58 @@
import logging
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_invalid_root_device():
""" Test that GPU Accelerator has root device on GPU. """
accelerator = GPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=PrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match="Device should be GPU"):
accelerator.setup(trainer=Mock(), model=Mock())


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires multi-GPU machine")
def test_root_device_set():
""" Test that GPU Accelerator sets the current device to the root device. """
accelerator = GPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cuda", 1)),
precision_plugin=PrecisionPlugin()
)
accelerator.setup(trainer=Mock(), model=Mock())
assert torch.cuda.current_device() == 1


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@mock.patch.dict(os.environ, {"CUDA_DEVICE_ORDER": ""})
def test_cuda_environment_variables_set():
""" Test that GPU Accelerator sets NVIDIA environment variables. """
accelerator = GPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cuda", 0)),
precision_plugin=PrecisionPlugin()
)
accelerator.setup(trainer=Mock(), model=Mock())
assert os.getenv("CUDA_DEVICE_ORDER") == "PCI_BUS_ID"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "1, 2", "LOCAL_RANK": "3"})
def test_cuda_visible_devices_logged(caplog):
""" Test that GPU Accelerator logs CUDA_VISIBLE_DEVICES env variable. """
accelerator = GPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cuda", 0)),
precision_plugin=PrecisionPlugin()
)
with caplog.at_level(logging.INFO):
accelerator.setup(trainer=Mock(), model=Mock())
assert "LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [1, 2]" in caplog.text
32 changes: 32 additions & 0 deletions tests/accelerators/test_tpu.py
@@ -0,0 +1,32 @@
from unittest.mock import Mock

import pytest

from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.plugins import SingleTPUPlugin, DDPPlugin, PrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_unsupported_precision_plugins():
""" Test error messages are raised for unsupported precision plugins with TPU. """
trainer = Mock()
model = Mock()
accelerator = TPUAccelerator(
training_type_plugin=SingleTPUPlugin(device=Mock()),
precision_plugin=MixedPrecisionPlugin(),
)
with pytest.raises(MisconfigurationException, match=r"amp \+ tpu is not supported."):
accelerator.setup(trainer=trainer, model=model)


def test_unsupported_training_type_plugins():
""" Test error messages are raised for unsupported training type with TPU. """
trainer = Mock()
model = Mock()
accelerator = TPUAccelerator(
training_type_plugin=DDPPlugin(),
precision_plugin=PrecisionPlugin(),
)
with pytest.raises(MisconfigurationException, match="TPUs only support a single tpu core or tpu spawn training"):
accelerator.setup(trainer=trainer, model=model)
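A note on the match arguments used throughout these tests: pytest.raises treats match as a regular expression applied with re.search against the exception message, which is why the literal "+" in the TPU message is escaped above. A small self-contained sketch, using a hypothetical stand-in for the accelerator call:

import pytest

def _raise_amp_tpu_error():
    # Hypothetical stand-in for TPUAccelerator.setup raising the error above.
    raise ValueError("amp + tpu is not supported. Only bfloats are supported on TPU.")

# match is searched as a regex, so "+" must be escaped (or wrapped in re.escape).
with pytest.raises(ValueError, match=r"amp \+ tpu is not supported."):
    _raise_amp_tpu_error()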