Skip to content

Commit 3ee8389

Browse files
committed
edits for black
1 parent 910d2f6 commit 3ee8389

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

test/unit/test_pytorch_xla.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def num_gpus(instance_type):
6161
def is_trcomp_env():
6262
try:
6363
import torch_xla.distributed.xla_spawn # pylint: disable=unused-import # noqa: F401
64+
6465
return True
6566
except ModuleNotFoundError:
6667
return False
@@ -69,6 +70,7 @@ def is_trcomp_env():
6970
def is_oss_pt_xla_env():
7071
try:
7172
import torch_xla # pylint: disable=unused-import # noqa: F401
73+
7274
return not is_trcomp_env()
7375
except ModuleNotFoundError:
7476
return False
@@ -182,7 +184,10 @@ def test_create_command_with_shell_script(
182184
runner._create_command()
183185
assert "Please use a python script" in str(err)
184186

185-
@pytest.mark.skipif(not is_trcomp_env(), reason="Processor compatibility check follows environment compatibility check")
187+
@pytest.mark.skipif(
188+
not is_trcomp_env(),
189+
reason="Processor compatibility check follows environment compatibility check",
190+
)
186191
def test_check_compatibility_with_gpu(
187192
self, cluster, cluster_size, master, instance_type, num_gpus, *patches
188193
):
@@ -208,7 +213,9 @@ def test_check_compatibility_with_gpu(
208213
)
209214
runner._check_compatibility()
210215

211-
@pytest.mark.skipif(not is_oss_pt_xla_env(), reason="This test expects an OSS PT-XLA environment")
216+
@pytest.mark.skipif(
217+
not is_oss_pt_xla_env(), reason="This test expects an OSS PT-XLA environment"
218+
)
212219
def test_check_compatibility_with_oss_pt_xla(
213220
self, cluster, cluster_size, master, instance_type, num_gpus, *patches
214221
):
@@ -234,7 +241,7 @@ def test_check_compatibility_with_oss_pt_xla(
234241
)
235242
with pytest.raises(ModuleNotFoundError) as err:
236243
runner._check_compatibility()
237-
assert 'Unable to find SageMaker integration code' in err
244+
assert "Unable to find SageMaker integration code" in err
238245

239246
@pytest.mark.skipif(is_trcomp_env() or is_oss_pt_xla_env())
240247
def test_check_compatibility_with_pt(
@@ -262,14 +269,15 @@ def test_check_compatibility_with_pt(
262269
)
263270
with pytest.raises(ModuleNotFoundError) as err:
264271
runner._check_compatibility()
265-
assert 'requires PT-XLA to be available' in err
272+
assert "requires PT-XLA to be available" in err
266273

267274

268-
@pytest.mark.skipif(not is_trcomp_env(), reason="Processor compatibility check follows environment compatibility check")
275+
@pytest.mark.skipif(
276+
not is_trcomp_env(),
277+
reason="Processor compatibility check follows environment compatibility check",
278+
)
269279
@pytest.mark.parametrize("cluster_size", [1, 4])
270-
def test_check_compatibility_with_cpu(
271-
cluster, cluster_size, master, *patches
272-
):
280+
def test_check_compatibility_with_cpu(cluster, cluster_size, master, *patches):
273281
for rank, current_host in enumerate(cluster):
274282
print(f"Testing as host {rank+1}/{cluster_size}")
275283
runner = PyTorchXLARunner(
@@ -279,7 +287,7 @@ def test_check_compatibility_with_cpu(
279287
"SM_TRAINING_ENV": json.dumps(
280288
{
281289
"additional_framework_parameters": {
282-
"sagemaker_instance_type": 'ml.c5.4xlarge'
290+
"sagemaker_instance_type": "ml.c5.4xlarge"
283291
}
284292
}
285293
),

0 commit comments

Comments
 (0)