@@ -99,12 +99,14 @@ def test_setup(self, cluster, cluster_size, master, instance_type, num_gpus, *pa
     def test_create_command_with_py_script(
         self, cluster, cluster_size, master, instance_type, num_gpus, *patches
     ):
+        training_args = ["-v", "--lr", "35"]
+        training_script = "train.py"
         for current_host in cluster:
             rank = cluster.index(current_host)
             print(f"Testing as host {rank + 1}/{cluster_size}")
             runner = PyTorchXLARunner(
-                user_entry_point="train.py",
-                args=["-v", "--lr", "35"],
+                user_entry_point=training_script,
+                args=training_args,
                 env_vars={
                     "SM_TRAINING_ENV": json.dumps(
                         {
@@ -120,8 +122,17 @@ def test_create_command_with_py_script(
                 hosts=cluster,
                 num_gpus=num_gpus,
             )
-            expected_command = []
-            assert expected_command == runner._create_command()
+            received_command = runner._create_command()
+            expected_command = [
+                "python",
+                "-m",
+                "torch_xla.distributed.xla_spawn",
+                "--num_gpus",
+                str(num_gpus),
+                training_script,
+            ] + training_args
+            assert received_command[0].split("/")[-1] == expected_command[0]
+            assert received_command[1:] == expected_command[1:]
 
     def test_create_command_with_shell_script(
         self, cluster, cluster_size, master, instance_type, num_gpus, *patches
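
Note on the two-part assertion at the end of the new hunk: the first command element is compared by basename rather than by full path. A minimal sketch of the idea, assuming the runner returns the Python interpreter as an absolute path (the "/usr/local/bin/python" value below is hypothetical, used only for illustration):

    # Sketch only: "/usr/local/bin/python" is a hypothetical resolved path,
    # not a value taken from the runner under test.
    received = ["/usr/local/bin/python", "-m", "torch_xla.distributed.xla_spawn"]
    expected = ["python", "-m", "torch_xla.distributed.xla_spawn"]

    # Compare the interpreter by basename so the test does not depend on
    # where Python happens to be installed on the machine running the tests ...
    assert received[0].split("/")[-1] == expected[0]
    # ... and compare everything after the interpreter exactly, which still
    # pins the spawn module, the flags, the script, and the training args.
    assert received[1:] == expected[1:]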