Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion neural_compressor/torch/utils/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,5 @@ def get_device(device_name="auto"):
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

runtime_accelerator = auto_detect_accelerator(device_name)
device = runtime_accelerator.name()
device = runtime_accelerator.current_device_name()
return device
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import torch

from neural_compressor.torch.utils import get_device
from neural_compressor.torch.utils.auto_accelerator import accelerator_registry, auto_detect_accelerator


Expand Down Expand Up @@ -52,6 +53,16 @@ def test_cuda_accelerator(self, force_use_cuda):
assert accelerator.synchronize() is None
assert accelerator.empty_cache() is None

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Only one GPU is available")
def test_get_device(self):
accelerator = auto_detect_accelerator()
assert accelerator.set_device(1) is None
assert accelerator.current_device_name() == "cuda:1"
cur_device = get_device()
assert cur_device == "cuda:1"
tmp_tensor = torch.tensor([1, 2], device=cur_device)
assert "cuda:1" == str(tmp_tensor.device)


class TestAutoAccelerator:

Expand Down