Skip to content

Commit db2feef

Browse files
awaelchlifour4fish
authored andcommitted
update tests
1 parent 88a42ac commit db2feef

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

tests/accelerators/test_gpu.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
def test_get_torch_gpu_stats(tmpdir):
1212
"""Test GPU get_device_stats with Pytorch >= 1.8.0."""
1313
current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
14-
GPUAccel = GPUAccelerator(
15-
training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
16-
)
17-
gpu_stats = GPUAccel.get_device_stats(current_device)
14+
gpu_stats = GPUAccelerator().get_device_stats(current_device)
1815
fields = ["allocated_bytes.all.freed", "inactive_split.all.peak", "reserved_bytes.large_pool.peak"]
1916

2017
for f in fields:
@@ -26,10 +23,7 @@ def test_get_torch_gpu_stats(tmpdir):
2623
def test_get_nvidia_gpu_stats(tmpdir):
2724
"""Test GPU get_device_stats with Pytorch < 1.8.0."""
2825
current_device = torch.device(f"cuda:{torch.cuda.current_device()}")
29-
GPUAccel = GPUAccelerator(
30-
training_type_plugin=DataParallelPlugin(parallel_devices=[current_device]), precision_plugin=PrecisionPlugin()
31-
)
32-
gpu_stats = GPUAccel.get_device_stats(current_device)
26+
gpu_stats = GPUAccelerator().get_device_stats(current_device)
3327
fields = ["utilization.gpu", "memory.used", "memory.free", "utilization.memory"]
3428

3529
for f in fields:

tests/accelerators/test_tpu.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License
1414
import collections
1515
from copy import deepcopy
16-
from unittest.mock import patch
16+
from unittest.mock import patch, Mock
1717

1818
import pytest
1919
import torch
@@ -288,13 +288,13 @@ def forward(self, x):
288288

289289

290290
def test_tpu_invalid_raises():
291-
accelerator = TPUAccelerator(object(), TPUSpawnPlugin())
291+
training_type_plugin = TPUSpawnPlugin(accelerator=TPUAccelerator(), precision_plugin=Mock())
292292
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
293-
training_type_plugin.setup(object())
293+
training_type_plugin.setup(Mock())
294294

295-
accelerator = TPUAccelerator(TPUPrecisionPlugin(), DDPPlugin())
295+
training_type_plugin = DDPPlugin(accelerator=TPUAccelerator(), precision_plugin=TPUPrecisionPlugin())
296296
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
297-
training_type_plugin.setup(object())
297+
training_type_plugin.setup(Mock())
298298

299299

300300
def test_tpu_invalid_raises_set_precision_with_strategy():

0 commit comments

Comments
 (0)