@@ -61,6 +61,7 @@ def num_gpus(instance_type):
6161def is_trcomp_env ():
6262 try :
6363 import torch_xla .distributed .xla_spawn # pylint: disable=unused-import # noqa: F401
64+
6465 return True
6566 except ModuleNotFoundError :
6667 return False
@@ -69,6 +70,7 @@ def is_trcomp_env():
6970def is_oss_pt_xla_env ():
7071 try :
7172 import torch_xla # pylint: disable=unused-import # noqa: F401
73+
7274 return not is_trcomp_env ()
7375 except ModuleNotFoundError :
7476 return False
@@ -182,7 +184,10 @@ def test_create_command_with_shell_script(
182184 runner ._create_command ()
183185 assert "Please use a python script" in str (err )
184186
185- @pytest .mark .skipif (not is_trcomp_env (), reason = "Processor compatibility check follows environment compatibility check" )
187+ @pytest .mark .skipif (
188+ not is_trcomp_env (),
189+ reason = "Processor compatibility check follows environment compatibility check" ,
190+ )
186191 def test_check_compatibility_with_gpu (
187192 self , cluster , cluster_size , master , instance_type , num_gpus , * patches
188193 ):
@@ -208,7 +213,9 @@ def test_check_compatibility_with_gpu(
208213 )
209214 runner ._check_compatibility ()
210215
211- @pytest .mark .skipif (not is_oss_pt_xla_env (), reason = "This test expects an OSS PT-XLA environment" )
216+ @pytest .mark .skipif (
217+ not is_oss_pt_xla_env (), reason = "This test expects an OSS PT-XLA environment"
218+ )
212219 def test_check_compatibility_with_oss_pt_xla (
213220 self , cluster , cluster_size , master , instance_type , num_gpus , * patches
214221 ):
@@ -234,7 +241,7 @@ def test_check_compatibility_with_oss_pt_xla(
234241 )
235242 with pytest .raises (ModuleNotFoundError ) as err :
236243 runner ._check_compatibility ()
237- assert ' Unable to find SageMaker integration code' in err
244+ assert " Unable to find SageMaker integration code" in err
238245
239246 @pytest .mark .skipif (is_trcomp_env () or is_oss_pt_xla_env ())
240247 def test_check_compatibility_with_pt (
@@ -262,14 +269,15 @@ def test_check_compatibility_with_pt(
262269 )
263270 with pytest .raises (ModuleNotFoundError ) as err :
264271 runner ._check_compatibility ()
265- assert ' requires PT-XLA to be available' in err
272+ assert " requires PT-XLA to be available" in err
266273
267274
268- @pytest .mark .skipif (not is_trcomp_env (), reason = "Processor compatibility check follows environment compatibility check" )
275+ @pytest .mark .skipif (
276+ not is_trcomp_env (),
277+ reason = "Processor compatibility check follows environment compatibility check" ,
278+ )
269279@pytest .mark .parametrize ("cluster_size" , [1 , 4 ])
270- def test_check_compatibility_with_cpu (
271- cluster , cluster_size , master , * patches
272- ):
280+ def test_check_compatibility_with_cpu (cluster , cluster_size , master , * patches ):
273281 for rank , current_host in enumerate (cluster ):
274282 print (f"Testing as host { rank + 1 } /{ cluster_size } " )
275283 runner = PyTorchXLARunner (
@@ -279,7 +287,7 @@ def test_check_compatibility_with_cpu(
279287 "SM_TRAINING_ENV" : json .dumps (
280288 {
281289 "additional_framework_parameters" : {
282- "sagemaker_instance_type" : ' ml.c5.4xlarge'
290+ "sagemaker_instance_type" : " ml.c5.4xlarge"
283291 }
284292 }
285293 ),
0 commit comments