Skip to content

Commit e81a2dd

Browse files
authored
Fixed the auto device detection (#1674)
* update the device name to current device Signed-off-by: yiliu30 <[email protected]>
1 parent b0c2a82 commit e81a2dd

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

neural_compressor/torch/utils/environ.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,5 @@ def get_device(device_name="auto"):
6666
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
6767

6868
runtime_accelerator = auto_detect_accelerator(device_name)
69-
device = runtime_accelerator.name()
69+
device = runtime_accelerator.current_device_name()
7070
return device
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import torch
55

6+
from neural_compressor.torch.utils import get_device
67
from neural_compressor.torch.utils.auto_accelerator import accelerator_registry, auto_detect_accelerator
78

89

@@ -52,6 +53,16 @@ def test_cuda_accelerator(self, force_use_cuda):
5253
assert accelerator.synchronize() is None
5354
assert accelerator.empty_cache() is None
5455

56+
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Only one GPU is available")
57+
def test_get_device(self):
58+
accelerator = auto_detect_accelerator()
59+
assert accelerator.set_device(1) is None
60+
assert accelerator.current_device_name() == "cuda:1"
61+
cur_device = get_device()
62+
assert cur_device == "cuda:1"
63+
tmp_tensor = torch.tensor([1, 2], device=cur_device)
64+
assert "cuda:1" == str(tmp_tensor.device)
65+
5566

5667
class TestAutoAccelerator:
5768

0 commit comments

Comments
 (0)