From c909ac68fed49cbc5d0cd7a14104c88f37b6fb7e Mon Sep 17 00:00:00 2001 From: Karan Jariwala Date: Tue, 6 Apr 2021 09:34:48 -0700 Subject: [PATCH 1/3] feature: smdataparallel enable EFA RDMA flag --- src/sagemaker_training/smdataparallel.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index 9f9e70bb..0998884b 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -161,12 +161,15 @@ def _get_mpirun_command( mpirun_command.extend(additional_options) + instance_type = self._get_instance_type() + # Use EFA's RDMA functionality for one-sided and two-sided transfer + if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]: + mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"]) + if smdataparallel_server_addr and smdataparallel_server_port: # in case of multi-node [distributed] training, smdataparallel_server_addr, # smdataparallel_server_port and interconnect_bandwidth will need to be set - instance_type = self._get_instance_type() - mpirun_command.extend( [ "-x", From 0b90e09fa239f06b48dbffc69b40dd649aecced0 Mon Sep 17 00:00:00 2001 From: Karan Jariwala Date: Tue, 6 Apr 2021 10:13:10 -0700 Subject: [PATCH 2/3] added changes to unit test --- test/unit/test_smdataparallel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/unit/test_smdataparallel.py b/test/unit/test_smdataparallel.py index 7c429633..a6afa6d2 100644 --- a/test/unit/test_smdataparallel.py +++ b/test/unit/test_smdataparallel.py @@ -159,7 +159,9 @@ def test_smdataparallel_run_single_node_python( smdataparallel_runner = smdataparallel.SMDataParallelRunner( user_entry_point="train.py", args=["-v", "--lr", "35"], - env_vars={}, + env_vars={ + "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p3dn.24xlarge"}}' + }, master_hostname=master_hostname, hosts=hosts, custom_mpi_options="--verbose", @@ -219,6 +221,8 @@ def test_smdataparallel_run_single_node_python( "-x", "LD_PRELOAD=%s" % inspect.getfile(gethostname), "--verbose", + "-x", + "FI_EFA_USE_DEVICE_RDMA=1", "smddprun", "usr/bin/python3", "-m", From 07cb410433a4ef4cc3a534548c7ae07c83b54198 Mon Sep 17 00:00:00 2001 From: Karan Jariwala Date: Tue, 6 Apr 2021 10:52:52 -0700 Subject: [PATCH 3/3] updated the flag to use only for ml.p4d.24xlarge instance --- src/sagemaker_training/smdataparallel.py | 2 +- test/unit/test_smdataparallel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker_training/smdataparallel.py b/src/sagemaker_training/smdataparallel.py index 0998884b..ece1270e 100644 --- a/src/sagemaker_training/smdataparallel.py +++ b/src/sagemaker_training/smdataparallel.py @@ -163,7 +163,7 @@ def _get_mpirun_command( instance_type = self._get_instance_type() # Use EFA's RDMA functionality for one-sided and two-sided transfer - if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]: + if instance_type in ["ml.p4d.24xlarge"]: mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"]) if smdataparallel_server_addr and smdataparallel_server_port: diff --git a/test/unit/test_smdataparallel.py b/test/unit/test_smdataparallel.py index a6afa6d2..9917dacb 100644 --- a/test/unit/test_smdataparallel.py +++ b/test/unit/test_smdataparallel.py @@ -160,7 +160,7 @@ def test_smdataparallel_run_single_node_python( user_entry_point="train.py", args=["-v", "--lr", "35"], env_vars={ - "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p3dn.24xlarge"}}' + "SM_TRAINING_ENV": '{"additional_framework_parameters":{"sagemaker_instance_type":"ml.p4d.24xlarge"}}' }, master_hostname=master_hostname, hosts=hosts,