
Commit 7f9acd5

Create a new distribution mechanism for PT-XLA
1 parent 6de6135 commit 7f9acd5

File tree

src/sagemaker_training/params.py
src/sagemaker_training/pytorch_xla.py
src/sagemaker_training/runner.py

3 files changed: +139 -1 lines changed

src/sagemaker_training/params.py

Lines changed: 3 additions & 0 deletions

@@ -38,6 +38,9 @@
 MULTI_WORKER_MIRRORED_STRATEGY_ENABLED = (
     "sagemaker_multi_worker_mirrored_strategy_enabled"
 )  # type: str
+PTORCH_XLA_MULTI_WORKER_ENABLED = (
+    "sagemaker_pytorch_xla_multi_worker_enabled"
+)  # type: str
 REGION_NAME_PARAM = "sagemaker_region"  # type: str
 REGION_NAME_ENV = REGION_NAME_PARAM.upper()  # type: str
 DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT"  # type: str
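
This change only declares the new sagemaker_pytorch_xla_multi_worker_enabled key. As a hedged illustration (not part of this commit), a framework container could consume the flag roughly as sketched below, assuming it is surfaced through Environment().additional_framework_parameters the same way as the other sagemaker_* distribution flags; that lookup path is an assumption.

# Hypothetical consumer of the new flag -- not part of this commit.
# Assumption: the flag arrives as a sagemaker_* framework parameter,
# analogous to sagemaker_mpi_enabled.
from sagemaker_training import environment, params, runner

env = environment.Environment()
pt_xla_enabled = env.additional_framework_parameters.get(
    params.PTORCH_XLA_MULTI_WORKER_ENABLED, False
)
runner_type = runner.PyTorchXLARunnerType if pt_xla_enabled else runner.ProcessRunnerType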
src/sagemaker_training/pytorch_xla.py

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+# Copyright 2018-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module contains functionality related to distributed training using
+PT-XLA (PyTorch - Accelerated Linear Algebra)."""
+from __future__ import absolute_import
+
+import os
+
+from sagemaker_training import (
+    _entry_point_type,
+    environment,
+    errors,
+    logging_config,
+    process,
+)
+
+
+logger = logging_config.get_logger()
+
+
+class PyTorchXLARunner(process.ProcessRunner):
+    """Responsible for preparing PT-XLA distributed training."""
+
+    MESH_SERVICE_PORT = 53957
+    WORKER_PORT = 43857
+
+    def __init__(
+        self,
+        user_entry_point,
+        args,
+        env_vars,
+        processes_per_host,
+        master_hostname,
+        current_host,
+        hosts,
+        num_gpus,
+    ):
+        """Initialize a PyTorchXLARunner, which is responsible for preparing distributed
+        training with PT-XLA.
+
+        Args:
+            user_entry_point (str): The name of the user entry point.
+            args ([str]): A list of arguments to include when executing the entry point.
+            env_vars (dict(str,str)): A dictionary of environment variables.
+            processes_per_host (int): The number of processes to run on each host.
+            master_hostname (str): The master hostname.
+            current_host (str): The current hostname.
+            hosts ([str]): A list of hosts.
+            num_gpus (int): The number of GPUs available per host.
+        """
+
+        super(PyTorchXLARunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)
+
+        self._master_hostname = master_hostname
+        self._current_host = current_host
+        self._hosts = hosts
+        self._num_gpus = num_gpus
+
+        self._num_hosts = len(self._hosts)
+        self._rank = self._hosts.index(self._current_host)
+
+    def _setup(self):  # type: () -> None
+        logger.info("Starting distributed training through PT-XLA Runtime.")
+        self.__check_compatibility()
+
+        os.environ["XRT_HOST_ORDINAL"] = str(self._rank)
+        os.environ["XRT_SHARD_WORLD_SIZE"] = str(self._num_hosts)
+        address = "localservice:{};{}:" + str(self.WORKER_PORT)
+        os.environ["XRT_WORKERS"] = "|".join(
+            [address.format(i, host) for i, host in enumerate(self._hosts)]
+        )
+        os.environ["GPU_NUM_DEVICES"] = str(self._num_gpus)
+        if self._num_hosts > 1:
+            os.environ[
+                "XRT_MESH_SERVICE_ADDRESS"
+            ] = f"{self._master_hostname}:{self.MESH_SERVICE_PORT}"
+
+        logger.info("Completed environment setup for distributed training through PT-XLA Runtime.")
+
+    def _create_command(self):
+        entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point)
+
+        if entrypoint_type is _entry_point_type.PYTHON_PACKAGE:
+            raise errors.ClientError(
+                "Distributed training through PT-XLA is not supported for Python packages. "
+                "Please use a Python script as the entry point."
+            )
+        elif entrypoint_type is _entry_point_type.PYTHON_PROGRAM:
+            return self.__pytorch_xla_command() + [self._user_entry_point] + self._args
+        else:
+            raise errors.ClientError(
+                "Distributed training through PT-XLA is only supported for Python scripts. "
+                "Please use a Python script as the entry point."
+            )
+
+    def __pytorch_xla_command(self):  # pylint: disable=no-self-use
+        return [
+            self._python_command(),
+            "-m",
+            "torch_xla.distributed.xla_spawn",
+            "--num_gpus",
+            str(self._num_gpus),
+        ]
+
+    def __check_compatibility(self):
+        try:
+            import torch_xla  # noqa: F401  # pylint: disable=unused-import
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Unable to find PT-XLA in the execution environment. "
+                "This distribution mechanism requires PT-XLA to be available in the "
+                "execution environment. SageMaker Training Compiler provides ready-to-use "
+                "containers with PT-XLA. Please refer to "
+                "https://github.com/aws/deep-learning-containers/blob/master/available_images.md"
+            )
+
+        try:
+            import torch_xla.distributed.xla_spawn  # noqa: F401  # pylint: disable=unused-import
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Unable to find SageMaker integration code in PT-XLA. "
+                "AWS SageMaker adds custom code on top of open source PT-XLA to provide "
+                "platform-specific optimizations. These SageMaker-specific binaries are "
+                "shipped as part of our Deep Learning Containers. Please refer to "
+                "https://github.com/aws/deep-learning-containers/blob/master/available_images.md"
+            )
+
+        if not self._num_gpus > 1:
+            raise ValueError("Distributed training through PT-XLA is only supported for GPUs.")
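
For orientation, here is a small self-contained sketch of the XRT environment that _setup() produces and the launcher command that _create_command() builds. Host names, GPU count, and the entry point name are hypothetical; the snippet mirrors the logic above rather than invoking the runner.

# Hypothetical two-host cluster with four GPUs per host; names are illustrative only.
WORKER_PORT = 43857
MESH_SERVICE_PORT = 53957

hosts = ["algo-1", "algo-2"]
master_hostname = "algo-1"
current_host = "algo-2"
num_gpus = 4

xrt_env = {
    "XRT_HOST_ORDINAL": str(hosts.index(current_host)),  # "1" on algo-2
    "XRT_SHARD_WORLD_SIZE": str(len(hosts)),              # "2"
    "XRT_WORKERS": "|".join(
        "localservice:{};{}:{}".format(i, host, WORKER_PORT)
        for i, host in enumerate(hosts)
    ),  # "localservice:0;algo-1:43857|localservice:1;algo-2:43857"
    "GPU_NUM_DEVICES": str(num_gpus),                     # "4"
}
if len(hosts) > 1:  # mesh service is only needed when training spans multiple hosts
    xrt_env["XRT_MESH_SERVICE_ADDRESS"] = "{}:{}".format(master_hostname, MESH_SERVICE_PORT)

print(xrt_env)
# _create_command() would then prepend the xla_spawn launcher to the user script, e.g.:
#   <python> -m torch_xla.distributed.xla_spawn --num_gpus 4 train.py <user args...>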

src/sagemaker_training/runner.py

Lines changed: 14 additions & 1 deletion

@@ -17,7 +17,7 @@

 import enum

-from sagemaker_training import environment, mpi, params, process, smdataparallel
+from sagemaker_training import environment, mpi, params, process, smdataparallel, pytorch_xla


 class RunnerType(enum.Enum):

@@ -26,11 +26,13 @@ class RunnerType(enum.Enum):
     MPI = "MPI"
     Process = "Process"
     SMDataParallel = "SMDataParallel"
+    PyTorchXLA = "PyTorchXLA"


 ProcessRunnerType = RunnerType.Process
 MPIRunnerType = RunnerType.MPI
 SMDataParallelRunnerType = RunnerType.SMDataParallel
+PyTorchXLARunnerType = RunnerType.PyTorchXLA


 def get(identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None):

@@ -103,6 +105,17 @@ def _get_by_runner_type(
         return mpi.WorkerRunner(
             user_entry_point, args, env_vars, processes_per_host, env.master_hostname
         )
+    elif identifier is RunnerType.PyTorchXLA:
+        return pytorch_xla.PyTorchXLARunner(
+            user_entry_point,
+            args,
+            env_vars,
+            processes_per_host,
+            env.master_hostname,
+            env.current_host,
+            env.distribution_hosts,
+            env.num_gpus,
+        )
     elif identifier is RunnerType.Process:
         return process.ProcessRunner(user_entry_point, args, env_vars, processes_per_host)
     else:
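
As a usage sketch, the new runner type can be requested through the factory above. The entry point and arguments below are hypothetical, and the call assumes it runs inside a SageMaker training container where environment.Environment() can resolve the cluster configuration used by _get_by_runner_type.

# Hypothetical selection of the PT-XLA runner via the factory in runner.py.
from sagemaker_training import runner

xla_runner = runner.get(
    runner.PyTorchXLARunnerType,
    user_entry_point="train.py",   # illustrative entry point
    args=["--epochs", "10"],       # illustrative user arguments
    env_vars={},
)
xla_runner.run()  # assumption: run() is inherited from process.ProcessRunner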
