Skip to content

Commit bcb5556

Browse files
committed
Fixing syntax errors
1 parent 8f744fb commit bcb5556

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

test/unit/test_pytorch_xla.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,14 @@ def test_setup(self, cluster, cluster_size, master, instance_type, num_gpus, *pa
9999
def test_create_command_with_py_script(
100100
self, cluster, cluster_size, master, instance_type, num_gpus, *patches
101101
):
102+
training_args = ["-v", "--lr", "35"]
103+
training_script = "train.py"
102104
for current_host in cluster:
103105
rank = cluster.index(current_host)
104106
print(f"Testing as host {rank+1}/{cluster_size}")
105107
runner = PyTorchXLARunner(
106-
user_entry_point="train.py",
107-
args=["-v", "--lr", "35"],
108+
user_entry_point=training_script,
109+
args=training_args,
108110
env_vars={
109111
"SM_TRAINING_ENV": json.dumps(
110112
{
@@ -120,8 +122,17 @@ def test_create_command_with_py_script(
120122
hosts=cluster,
121123
num_gpus=num_gpus,
122124
)
123-
expected_command = []
124-
assert expected_command == runner._create_command()
125+
received_command = runner._create_command()
126+
expected_command = [
127+
"python",
128+
"-m",
129+
"torch_xla.distributed.xla_spawn",
130+
"--num_gpus",
131+
str(num_gpus),
132+
training_script,
133+
] + training_args
134+
assert received_command[0].split("/")[-1] == expected_command[0]
135+
assert received_command[1:] == expected_command[1:]
125136

126137
def test_create_command_with_shell_script(
127138
self, cluster, cluster_size, master, instance_type, num_gpus, *patches

0 commit comments

Comments
 (0)