File tree Expand file tree Collapse file tree 2 files changed +7
-13
lines changed Expand file tree Collapse file tree 2 files changed +7
-13
lines changed Original file line number Diff line number Diff line change 1111def test_get_torch_gpu_stats (tmpdir ):
1212 """Test GPU get_device_stats with Pytorch >= 1.8.0."""
1313 current_device = torch .device (f"cuda:{ torch .cuda .current_device ()} " )
14- GPUAccel = GPUAccelerator (
15- training_type_plugin = DataParallelPlugin (parallel_devices = [current_device ]), precision_plugin = PrecisionPlugin ()
16- )
17- gpu_stats = GPUAccel .get_device_stats (current_device )
14+ gpu_stats = GPUAccelerator ().get_device_stats (current_device )
1815 fields = ["allocated_bytes.all.freed" , "inactive_split.all.peak" , "reserved_bytes.large_pool.peak" ]
1916
2017 for f in fields :
@@ -26,10 +23,7 @@ def test_get_torch_gpu_stats(tmpdir):
2623def test_get_nvidia_gpu_stats (tmpdir ):
2724 """Test GPU get_device_stats with Pytorch < 1.8.0."""
2825 current_device = torch .device (f"cuda:{ torch .cuda .current_device ()} " )
29- GPUAccel = GPUAccelerator (
30- training_type_plugin = DataParallelPlugin (parallel_devices = [current_device ]), precision_plugin = PrecisionPlugin ()
31- )
32- gpu_stats = GPUAccel .get_device_stats (current_device )
26+ gpu_stats = GPUAccelerator ().get_device_stats (current_device )
3327 fields = ["utilization.gpu" , "memory.used" , "memory.free" , "utilization.memory" ]
3428
3529 for f in fields :
Original file line number Diff line number Diff line change 1313# limitations under the License
1414import collections
1515from copy import deepcopy
16- from unittest .mock import patch
16+ from unittest .mock import patch , Mock
1717
1818import pytest
1919import torch
@@ -288,13 +288,13 @@ def forward(self, x):
288288
289289
290290def test_tpu_invalid_raises ():
291- accelerator = TPUAccelerator (object ( ), TPUSpawnPlugin ())
291+ training_type_plugin = TPUSpawnPlugin ( accelerator = TPUAccelerator (), precision_plugin = Mock ())
292292 with pytest .raises (ValueError , match = "TPUAccelerator` can only be used with a `TPUPrecisionPlugin" ):
293- training_type_plugin .setup (object ())
293+ training_type_plugin .setup (Mock ())
294294
295- accelerator = TPUAccelerator (TPUPrecisionPlugin ( ), DDPPlugin ())
295+ training_type_plugin = DDPPlugin ( accelerator = TPUAccelerator (), precision_plugin = TPUPrecisionPlugin ())
296296 with pytest .raises (ValueError , match = "TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi" ):
297- training_type_plugin .setup (object ())
297+ training_type_plugin .setup (Mock ())
298298
299299
300300def test_tpu_invalid_raises_set_precision_with_strategy ():
You can’t perform that action at this time.
0 commit comments