diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json new file mode 100644 index 0000000000..96afc3cb1c --- /dev/null +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -0,0 +1,37 @@ +{ + "training": { + "processors": [ + "gpu" + ], + "version_aliases": { + "2.0": "2.0.1" + }, + "versions": { + "2.0.1": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" + } + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 267532cb1c..56e4bf346f 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -27,7 +27,10 @@ from sagemaker.jumpstart import artifacts from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.utilities import override_pipeline_parameter_var -from sagemaker.fw_utils import GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, GRAVITON_ALLOWED_FRAMEWORKS +from sagemaker.fw_utils import ( + GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY, + GRAVITON_ALLOWED_FRAMEWORKS, +) logger = logging.getLogger(__name__) @@ -343,7 +346,8 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non if image_scope not in ("eia", "inference"): logger.warning( - "Elastic inference is for inference only. Ignoring image scope: %s.", image_scope + "Elastic inference is for inference only. Ignoring image scope: %s.", + image_scope, ) image_scope = "eia" @@ -660,6 +664,17 @@ def get_training_image_uri( container_version = None base_framework_version = None + # Check for smp library + if distribution is not None: + if "torch_distributed" in distribution and "smdistributed" in distribution: + if "modelparallel" in distribution["smdistributed"]: + if distribution["smdistributed"]["modelparallel"].get("enabled", True): + framework = "pytorch-smp" + if "p5" in instance_type: + container_version = "cu121" + else: + container_version = "cu118" + return retrieve( framework, region, diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py new file mode 100644 index 0000000000..634c8a0f7f --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5d.24xlarge": "cu121"} + + +@pytest.mark.parametrize("load_config", ["pytorch-smp.json"], indirect=True) +def test_smp_v2(load_config): + VERSIONS = load_config["training"]["versions"] + PROCESSORS = load_config["training"]["processors"] + distribution = { + "torch_distributed": {"enabled": True}, + "smdistributed": {"modelparallel": {"enabled": True}}, + } + for processor in PROCESSORS: + for version in VERSIONS: + ACCOUNTS = load_config["training"]["versions"][version]["registries"] + PY_VERSIONS = load_config["training"]["versions"][version]["py_versions"] + for py_version in PY_VERSIONS: + for region in ACCOUNTS.keys(): + for instance_type in CONTAINER_VERSIONS.keys(): + uri = image_uris.get_training_image_uri( + region, + framework="pytorch", + framework_version=version, + py_version=py_version, + distribution=distribution, + instance_type=instance_type, + ) + expected = expected_uris.framework_uri( + repo="smdistributed-modelparallel", + fw_version=version, + py_version=f"{py_version}-{CONTAINER_VERSIONS[instance_type]}", + processor=processor, + region=region, + account=ACCOUNTS[region], + ) + assert expected == uri