Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
[email protected] with any additional questions or comments.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def read_version():
"werkzeug>=0.15.5",
"paramiko>=2.4.2",
"psutil>=5.6.7",
"protobuf>=3.19,<3.20",
"protobuf>=3.9.2,<3.20",
"scipy>=1.2.2",
]

Expand Down
246 changes: 225 additions & 21 deletions src/sagemaker_training/environment.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/sagemaker_training/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,4 @@
"sagemaker_distributed_dataparallel_custom_mpi_options"
) # type: str
SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"
DISTRIBUTION_INSTANCE_GROUPS = "sagemaker_distribution_instance_groups" # type: list
4 changes: 2 additions & 2 deletions src/sagemaker_training/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_by_runner_type(
env_vars,
processes_per_host,
env.master_hostname,
env.hosts,
env.distribution_hosts,
custom_mpi_options,
env.network_interface_name,
)
Expand All @@ -94,7 +94,7 @@ def _get_by_runner_type(
env_vars,
processes_per_host,
env.master_hostname,
env.hosts,
env.distribution_hosts,
custom_mpi_options,
env.network_interface_name,
num_processes=num_processes,
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker_training/smdataparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ def _get_instance_type(self):
instance_type = sm_training_env.get("additional_framework_parameters").get(
"sagemaker_instance_type"
)
if not instance_type:
# Heterogeneous mode
instance_type = sm_training_env.get("current_instance_type", None)
logger.info("instance type: %s" % instance_type)
return instance_type

Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker_training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def train():
logging_config.configure_logger(env.log_level)

mpi_enabled = env.additional_framework_parameters.get(params.MPI_ENABLED)
runner_type = runner.RunnerType.MPI if mpi_enabled else runner.RunnerType.Process
runner_type = (
runner.RunnerType.MPI
if mpi_enabled and (env.current_instance_group in env.distribution_instance_groups)
else runner.RunnerType.Process
)

entry_point.run(
env.module_dir,
Expand Down
19 changes: 14 additions & 5 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,17 @@ def prepare(
current_host="algo-1",
hosts=None,
network_interface_name="ethwe",
current_instance_group="homogeneousCluster",
local=False,
):
# type: (UserModule, dict, list, str, list, str, bool) -> None
# type: (UserModule, dict, list, str, list, str, str, bool) -> None
hosts = hosts or ["algo-1"]

if not local:
user_module.upload()

create_hyperparameters_config(hyperparameters, user_module.url)
create_resource_config(current_host, hosts, network_interface_name)
create_resource_config(current_host, hosts, current_instance_group, network_interface_name)
create_input_data_config(channels)


Expand All @@ -98,21 +99,29 @@ def hyperparameters(**kwargs): # type: (...) -> dict


def create_resource_config(
    current_host="algo-1",
    hosts=None,
    current_instance_group="homogeneousCluster",
    network_interface_name="ethwe",
):  # type: (str, list, str, str) -> None
    """Write a fake SageMaker resourceconfig.json for tests.

    Serializes a minimal resource-config dict to
    ``environment.resource_config_file_dir`` via ``write_json`` so that code
    under test reads it as if SageMaker had provisioned the cluster.

    Args:
        current_host: hostname of the container under test.
        hosts: all hostnames in the cluster; defaults to ``["algo-1"]``.
        current_instance_group: instance-group name of the current host
            (``"homogeneousCluster"`` mimics a non-heterogeneous job).
        network_interface_name: NIC name to advertise; when falsy the key is
            omitted entirely so tests can exercise the "no NIC in config" path.
    """
    # NOTE: the source this was reconstructed from was a flattened diff that
    # contained both the old 3-argument and new 4-argument signatures; this is
    # the post-change 4-argument version.
    if network_interface_name:
        write_json(
            dict(
                current_host=current_host,
                hosts=hosts or ["algo-1"],
                current_instance_group=current_instance_group,
                network_interface_name=network_interface_name,
            ),
            environment.resource_config_file_dir,
        )
    else:
        # Same payload minus network_interface_name.
        write_json(
            dict(
                current_host=current_host,
                current_instance_group=current_instance_group,
                hosts=hosts or ["algo-1"],
            ),
            environment.resource_config_file_dir,
        )

Expand Down
24 changes: 23 additions & 1 deletion test/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,18 @@

builtins_open = "__builtin__.open" if six.PY2 else "builtins.open"

RESOURCE_CONFIG = dict(current_host="algo-1", hosts=["algo-1", "algo-2", "algo-3"])
# Fixture resource config modelling a *heterogeneous* cluster: three hosts
# split across two instance groups — "train1" (algo-1, algo-2 on
# ml.p3.16xlarge) and "train2" (algo-3 on ml.p3.8xlarge) — with the current
# host (algo-1) belonging to "train1".
# NOTE(review): field names (current_group_name, instance_groups, ...) are
# presumed to mirror SageMaker's resourceconfig.json schema for heterogeneous
# jobs — confirm against the parsing code in environment.py.
RESOURCE_CONFIG = dict(
    current_host="algo-1",
    hosts=["algo-1", "algo-2", "algo-3"],
    current_group_name="train1",
    current_instance_type="ml.p3.16xlarge",
    instance_groups=[
        dict(
            instance_group_name="train1", instance_type="ml.p3.16xlarge", hosts=["algo-1", "algo-2"]
        ),
        dict(instance_group_name="train2", instance_type="ml.p3.8xlarge", hosts=["algo-3"]),
    ],
)

INPUT_DATA_CONFIG = {
"train": {
Expand Down Expand Up @@ -184,6 +195,9 @@ def test_training_env(training_env):
assert training_env.network_interface_name == "eth0"
assert training_env.job_name == "training-job-42"
assert training_env.additional_framework_parameters == {"sagemaker_parameter_server_num": 2}
assert training_env.current_instance_group == "train1"
assert training_env.current_instance_type == "ml.p3.16xlarge"
assert training_env.instance_groups == ["train1", "train2"]


def test_env_mapping_properties(training_env):
Expand Down Expand Up @@ -213,6 +227,14 @@ def test_env_mapping_properties(training_env):
"is_master",
"master_hostname",
"is_modelparallel_enabled",
"instance_groups",
"instance_groups_dict",
"current_instance_type",
"current_instance_group",
"current_instance_group_hosts",
"distribution_hosts",
"distribution_instance_groups",
"is_hetero",
}


Expand Down
116 changes: 116 additions & 0 deletions test/unit/test_smdataparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,122 @@ def test_smdataparallel_run_single_node_python(
path_exists.assert_called_with("/usr/sbin/sshd")


# Mock decorators apply bottom-up, so the first test parameter (training_env)
# corresponds to the LAST decorator (sagemaker_training.environment.Environment)
# and the last parameter (async_gather) to the first (asyncio.gather).
@patch("asyncio.gather", new_callable=AsyncMock)
@patch("os.path.exists")
@patch("sagemaker_training.process.python_executable", return_value="usr/bin/python3")
@patch("paramiko.SSHClient", new_callable=MockSSHClient)
@patch("paramiko.AutoAddPolicy")
@patch("asyncio.create_subprocess_shell")
@patch("sagemaker_training.environment.Environment")
def test_hc_smdataparallel_run_single_node_python(
    training_env,
    async_shell,
    policy,
    ssh_client,
    python_executable,
    path_exists,
    async_gather,
    event_loop,
):
    """Single-node SMDataParallel launch in heterogeneous-cluster (HC) mode.

    Differs from the homogeneous single-node test in that SM_TRAINING_ENV
    carries no "sagemaker_instance_type" under additional_framework_parameters;
    instead the runner must fall back to the top-level "current_instance_type"
    key (here ml.p4d.24xlarge). Verifies the exact mpirun command line that
    SMDataParallelRunner.run() hands to asyncio.create_subprocess_shell.
    """
    # Clear the environment so the runner sees only what the test provides.
    with patch.dict(os.environ, clear=True):
        hosts = ["algo-1"]
        master_hostname = hosts[0]
        num_hosts = len(hosts)
        num_processes_per_host = 8
        num_processes = num_processes_per_host * num_hosts
        host_list = hosts
        network_interface_name = "ethw3"
        # Single-node mode is signalled to smddprun via this env flag.
        smdataparallel_flag = "SMDATAPARALLEL_USE_SINGLENODE=1"

        smdataparallel_runner = smdataparallel.SMDataParallelRunner(
            user_entry_point="train.py",
            args=["-v", "--lr", "35"],
            env_vars={
                # HC mode: instance type comes from the top-level
                # "current_instance_type" key, not from
                # additional_framework_parameters.
                "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_distributed_dataparallel_enabled":"true"},\
"current_instance_type": "ml.p4d.24xlarge"}'
            },
            processes_per_host=num_processes_per_host,
            master_hostname=master_hostname,
            hosts=hosts,
            custom_mpi_options="--verbose",
            network_interface_name=network_interface_name,
        )

        _, _, process = smdataparallel_runner.run(wait=False)
        # Expected mpirun invocation, token by token; order matters because the
        # assertion below compares the joined string exactly.
        cmd = [
            "mpirun",
            "--host",
            ",".join(host_list),
            "-np",
            str(num_processes),
            "--allow-run-as-root",
            "--tag-output",
            "--oversubscribe",
            "-mca",
            "btl_tcp_if_include",
            network_interface_name,
            "-mca",
            "oob_tcp_if_include",
            network_interface_name,
            "-mca",
            "plm_rsh_no_tree_spawn",
            "1",
            "-mca",
            "pml",
            "ob1",
            "-mca",
            "btl",
            "^openib",
            "-mca",
            "orte_abort_on_non_zero_status",
            "1",
            "-mca",
            "btl_vader_single_copy_mechanism",
            "none",
            "-mca",
            "plm_rsh_num_concurrent",
            str(num_hosts),
            "-x",
            "NCCL_SOCKET_IFNAME=%s" % network_interface_name,
            "-x",
            "NCCL_DEBUG=INFO",
            "-x",
            "LD_LIBRARY_PATH",
            "-x",
            "PATH",
            "-x",
            smdataparallel_flag,
            "-x",
            "FI_PROVIDER=efa",
            "-x",
            "RDMAV_FORK_SAFE=1",
            "-x",
            # gethostname shim is LD_PRELOADed so MPI resolves container names.
            "LD_PRELOAD=%s" % inspect.getfile(gethostname),
            "--verbose",
            "-x",
            "FI_EFA_USE_DEVICE_RDMA=1",
            "smddprun",
            "usr/bin/python3",
            "-m",
            "mpi4py",
            "train.py",
            "-v",
            "--lr",
            "35",
        ]
        # The runner must launch exactly one shell with exactly this command.
        async_shell.assert_called_with(
            " ".join(cmd),
            cwd=environment.code_dir,
            env=ANY,
            stdout=asyncio.subprocess.PIPE,
            stderr=None,
        )
        async_shell.assert_called_once()
        async_gather.assert_called_once()
        assert process == async_shell.return_value
        # run() probes for sshd even in single-node mode.
        path_exists.assert_called_with("/usr/sbin/sshd")


@patch("sagemaker_training.logging_config.log_script_invocation")
def test_connection(log):
with pytest.raises(Exception):
Expand Down
2 changes: 2 additions & 0 deletions test/unit/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def sagemaker_s3_output(self):

class ScriptEnvironment(Environment):
framework_module = None
current_instance_group = "Test1"
distribution_instance_groups = ["Test1"]

def sagemaker_s3_output(self):
return "s3://bucket"
Expand Down