From 042de3171d580cb92e6a1a7d379767fade32fdf5 Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Mon, 20 May 2024 18:00:09 +0000 Subject: [PATCH 1/8] rewrite pytorchddp to smdistributed --- src/sagemaker/pytorch/estimator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index a4e24d1ff0..1b4e656440 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -276,6 +276,16 @@ def __init__( kwargs["entry_point"] = entry_point if distribution is not None: + # rewrite pytorchddp to smdistributed + if "pytorchddp" in distribution: + if "smdistributed" in distribution: + raise ValueError( + "Cannot use both pytorchddp and smdistributed " + "distribution options together.", + distribution + ) + distribution = {"smdistributed": {"dataparallel": distribution["pytorchddp"]}} + distribution = validate_distribution( distribution, self.instance_groups, From 1f5894cd30e0e8a1c053442c34a1574a931ace0a Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Mon, 20 May 2024 18:00:31 +0000 Subject: [PATCH 2/8] remove instance type check --- src/sagemaker/fw_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 33018becdd..6f6ab3705b 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -795,7 +795,6 @@ def _validate_smdataparallel_args( Raises: ValueError: if - (`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or `py_version` is not python3 or `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION """ @@ -806,18 +805,6 @@ def _validate_smdataparallel_args( if not smdataparallel_enabled: return - is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES - - err_msg = "" - - if not is_instance_type_supported: - # instance_type is required - err_msg += ( - f"Provided instance_type {instance_type} is not supported by smdataparallel.\n" - "Please specify one of the supported instance types:" - f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n" - ) - if not image_uri: # ignore framework_version & py_version if image_uri is set # in case image_uri is not set, then both are mandatory From 22a33983cbaf8b96f41ae9af12ba22f85f488ecf Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Tue, 21 May 2024 12:20:59 -0700 Subject: [PATCH 3/8] Update estimator.py --- src/sagemaker/pytorch/estimator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 1b4e656440..eb574af285 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -284,7 +284,11 @@ def __init__( "distribution options together.", distribution ) - distribution = {"smdistributed": {"dataparallel": distribution["pytorchddp"]}} + + # convert pytorchddp distribution into smdistributed distribution + distribution = distribution.copy() + distribution["smdistributed"] = {"dataparallel" : distribution["pytorchddp"]} + del distribution["pytorchddp"] distribution = validate_distribution( distribution, From 006f04e67160646e32ea386d9a78a1635ebc75d3 Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Tue, 21 May 2024 20:09:39 +0000 Subject: [PATCH 4/8] remove validate_pytorch_distribution --- src/sagemaker/fw_utils.py | 87 --------------------------------------- 1 file changed, 87 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 6f6ab3705b..9f588a08df 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -145,22 +145,6 @@ ], } -PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ - "1.10", - "1.10.0", - "1.10.2", - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - "1.13.1", - "2.0.0", - "2.0.1", - "2.1.0", - "2.2.0", -] - TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ "1.13.1", "2.0.0", @@ -915,13 +899,6 @@ def validate_distribution( ) if framework_name and framework_name == "pytorch": # We need to validate only for PyTorch framework - validate_pytorch_distribution( - distribution=validated_distribution, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, - image_uri=image_uri, - ) validate_torch_distributed_distribution( instance_type=instance_type, distribution=validated_distribution, @@ -955,13 +932,6 @@ def validate_distribution( ) if framework_name and framework_name == "pytorch": # We need to validate only for PyTorch framework - validate_pytorch_distribution( - distribution=validated_distribution, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, - image_uri=image_uri, - ) validate_torch_distributed_distribution( instance_type=instance_type, distribution=validated_distribution, @@ -1010,63 +980,6 @@ def validate_distribution_for_instance_type(instance_type, distribution): raise ValueError(err_msg) -def validate_pytorch_distribution( - distribution, framework_name, framework_version, py_version, image_uri -): - """Check if pytorch distribution strategy is correctly invoked by the user. - - Args: - distribution (dict): A dictionary with information to enable distributed training. - (Defaults to None if distributed training is not enabled.) For example: - - .. code:: python - - { - "pytorchddp": { - "enabled": True - } - } - framework_name (str): A string representing the name of framework selected. - framework_version (str): A string representing the framework version selected. - py_version (str): A string representing the python version selected. - image_uri (str): A string representing a Docker image URI. - - Raises: - ValueError: if - `py_version` is not python3 or - `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS - """ - if framework_name and framework_name != "pytorch": - # We need to validate only for PyTorch framework - return - - pytorch_ddp_enabled = False - if "pytorchddp" in distribution: - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) - if not pytorch_ddp_enabled: - # Distribution strategy other than pytorchddp is selected - return - - err_msg = "" - if not image_uri: - # ignore framework_version and py_version if image_uri is set - # in case image_uri is not set, then both are mandatory - if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: - err_msg += ( - f"Provided framework_version {framework_version} is not supported by" - " pytorchddp.\n" - "Please specify one of the supported framework versions:" - f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" - ) - if "py3" not in py_version: - err_msg += ( - f"Provided py_version {py_version} is not supported by pytorchddp.\n" - "Please specify py_version>=py3" - ) - if err_msg: - raise ValueError(err_msg) - - def validate_torch_distributed_distribution( instance_type, distribution, From b3a8c4e86d421ee00cc877066fdaeab1390a57d0 Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Tue, 21 May 2024 22:58:11 +0000 Subject: [PATCH 5/8] fix --- src/sagemaker/fw_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 9f588a08df..81be406656 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -789,6 +789,8 @@ def _validate_smdataparallel_args( if not smdataparallel_enabled: return + err_msg = "" + if not image_uri: # ignore framework_version & py_version if image_uri is set # in case image_uri is not set, then both are mandatory From 16e72d5ce1a6749ed7dcd893f73f8f817fc0425c Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Tue, 21 May 2024 22:59:09 +0000 Subject: [PATCH 6/8] fix unit tests --- tests/unit/test_fw_utils.py | 78 +------------------------------------ tests/unit/test_pytorch.py | 5 ++- 2 files changed, 5 insertions(+), 78 deletions(-) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index e955d68227..6575263778 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -853,18 +853,12 @@ def test_validate_smdataparallel_args_raises(): smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}} # Cases {PT|TF2} - # 1. None instance type - # 2. incorrect instance type - # 3. incorrect python version - # 4. incorrect framework version + # 1. incorrect python version + # 2. incorrect framework version bad_args = [ - (None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled), - ("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled), - (None, "pytorch", "1.6.0", "py3", smdataparallel_enabled), - ("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled), ] @@ -966,74 +960,6 @@ def test_validate_smdataparallel_args_not_raises(): ) -def test_validate_pytorchddp_not_raises(): - # Case 1: Framework is not PyTorch - fw_utils.validate_pytorch_distribution( - distribution=None, - framework_name="tensorflow", - framework_version="2.9.1", - py_version="py3", - image_uri="custom-container", - ) - # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP - pytorchddp_disabled = {"pytorchddp": {"enabled": False}} - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_disabled, - framework_name="pytorch", - framework_version="1.10", - py_version="py3", - image_uri="custom-container", - ) - # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions - pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - pytorchddp_supported_fw_versions = [ - "1.10", - "1.10.0", - "1.10.2", - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - "1.13.1", - "2.0.0", - "2.0.1", - "2.1.0", - "2.2.0", - ] - for framework_version in pytorchddp_supported_fw_versions: - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version=framework_version, - py_version="py3", - image_uri="custom-container", - ) - - -def test_validate_pytorchddp_raises(): - pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - # Case 1: Unsupported framework version - with pytest.raises(ValueError): - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version="1.8", - py_version="py3", - image_uri=None, - ) - - # Case 2: Unsupported Py version - with pytest.raises(ValueError): - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version="1.10", - py_version="py2", - image_uri=None, - ) - - def test_validate_torch_distributed_not_raises(): # Case 1: Framework is PyTorch, but torch_distributed is not enabled torch_distributed_disabled = {"torch_distributed": {"enabled": False}} diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5ada026ef8..618d0d7ea8 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -801,14 +801,15 @@ def test_pytorch_ddp_distribution_configuration( distribution=pytorch.distribution ) expected_torch_ddp = { - "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_distributed_dataparallel_enabled": True, + "sagemaker_distributed_dataparallel_custom_mpi_options": "", "sagemaker_instance_type": test_instance_type, } assert actual_pytorch_ddp == expected_torch_ddp def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): - unsupported_framework_version = "1.9.1" + unsupported_framework_version = "1.5.0" unsupported_py_version = "py2" with pytest.raises(ValueError) as error: _pytorch_estimator( From 722a25e54bba9681d1452fce950a6f6cb85a5123 Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Wed, 22 May 2024 19:10:40 +0000 Subject: [PATCH 7/8] fix formatting --- src/sagemaker/pytorch/estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index eb574af285..412926279c 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -282,12 +282,12 @@ def __init__( raise ValueError( "Cannot use both pytorchddp and smdistributed " "distribution options together.", - distribution + distribution, ) # convert pytorchddp distribution into smdistributed distribution distribution = distribution.copy() - distribution["smdistributed"] = {"dataparallel" : distribution["pytorchddp"]} + distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]} del distribution["pytorchddp"] distribution = validate_distribution( From 2e043f9c9d547ff55d532b4a2534cc4feeb92bf2 Mon Sep 17 00:00:00 2001 From: Tom Bousso Date: Wed, 22 May 2024 19:20:35 +0000 Subject: [PATCH 8/8] check instance type not None --- src/sagemaker/fw_utils.py | 3 +++ tests/unit/test_fw_utils.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 81be406656..be3658365a 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -791,6 +791,9 @@ def _validate_smdataparallel_args( err_msg = "" + if not instance_type: + err_msg += "Please specify an instance_type for smdataparallel.\n" + if not image_uri: # ignore framework_version & py_version if image_uri is set # in case image_uri is not set, then both are mandatory diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 6575263778..97d4e6ec2a 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -853,12 +853,15 @@ def test_validate_smdataparallel_args_raises(): smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}} # Cases {PT|TF2} - # 1. incorrect python version - # 2. incorrect framework version + # 1. None instance type + # 2. incorrect python version + # 3. incorrect framework version bad_args = [ + (None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled), + (None, "pytorch", "1.6.0", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled), ]