diff --git a/README.rst b/README.rst index f1eb78c3fd..8575672754 100644 --- a/README.rst +++ b/README.rst @@ -222,7 +222,7 @@ PyTorch SageMaker Estimators With PyTorch SageMaker Estimators, you can train and host PyTorch models on Amazon SageMaker. -Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``. +Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``. We recommend that you use the latest supported version, because that's where we focus most of our development efforts. diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 99c49192d4..5e6338def3 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -67,6 +67,8 @@ "tensorflow-serving-eia": "tensorflow-inference-eia", "mxnet": "mxnet-training", "mxnet-serving": "mxnet-inference", + "pytorch": "pytorch-training", + "pytorch-serving": "pytorch-inference", "mxnet-serving-eia": "mxnet-inference-eia", } @@ -76,6 +78,8 @@ "tensorflow-serving-eia": [1, 14, 0], "mxnet": [1, 4, 1], "mxnet-serving": [1, 4, 1], + "pytorch": [1, 2, 0], + "pytorch-serving": [1, 2, 0], "mxnet-serving-eia": [1, 4, 1], } @@ -119,10 +123,15 @@ def _using_merged_images(region, framework, py_version, framework_version): is_gov_region = region in VALID_ACCOUNTS_BY_REGION is_py3 = py_version == "py3" or py_version is None is_merged_versions = _is_merged_versions(framework, framework_version) + return ( ((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION) and is_merged_versions - and (is_py3 or _is_tf_14_or_later(framework, framework_version)) + and ( + is_py3 + or _is_tf_14_or_later(framework, framework_version) + or _is_pt_12_or_later(framework, framework_version) + ) ) @@ -140,6 +149,18 @@ def _is_tf_14_or_later(framework, framework_version): ) +def _is_pt_12_or_later(framework, framework_version): + """ + Args: + framework: Name of the frameowork + framework_version: framework version + """ + asimov_lowest_pt = [1, 12, 0] + version = [int(s) for s in 
framework_version.split(".")] + is_pytorch = framework in ("pytorch", "pytorch-serving") + return is_pytorch and version >= asimov_lowest_pt[0 : len(version)] + + def _registry_id(region, framework, py_version, account, framework_version): """ Args: diff --git a/src/sagemaker/pytorch/README.rst b/src/sagemaker/pytorch/README.rst index cb8dc780f3..8806422f68 100644 --- a/src/sagemaker/pytorch/README.rst +++ b/src/sagemaker/pytorch/README.rst @@ -4,7 +4,7 @@ SageMaker PyTorch Estimators and Models With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker. -Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``. +Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``. We recommend that you use the latest supported version, because that's where we focus most of our development efforts. diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 668c2749c0..452eb9c2f1 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -34,7 +34,7 @@ class PyTorch(Framework): __framework_name__ = "pytorch" - LATEST_VERSION = "1.1" + LATEST_VERSION = "1.2.0" """The latest version of PyTorch included in the SageMaker pre-built Docker images.""" def __init__( diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index edea603501..8de0fa0576 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +import pkg_resources import sagemaker from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning @@ -53,6 +54,7 @@ class PyTorchModel(FrameworkModel): """ __framework_name__ = "pytorch" + _LOWEST_MMS_VERSION = "1.2" def __init__( self, @@ -122,19 +124,28 @@ def prepare_container_def(self, instance_type, accelerator_type=None): dict[str, str]: A container definition object usable with the CreateModel API. 
@pytest.mark.local_mode
def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
    """Fit a PyTorch estimator in local mode, deploy it locally, and check the output shape."""
    estimator = PyTorch(
        entry_point=MNIST_SCRIPT,
        role="SageMakerRole",
        framework_version=pytorch_full_version,
        py_version="py3",
        train_instance_count=1,
        train_instance_type="local",
        sagemaker_session=sagemaker_local_session,
    )

    training_uri = "file://" + os.path.join(MNIST_DIR, "training")
    estimator.fit({"training": training_uri})

    endpoint = estimator.deploy(1, "local")
    try:
        num_samples = 100
        images = numpy.random.rand(num_samples, 1, 28, 28).astype(numpy.float32)
        predictions = endpoint.predict(images)

        assert predictions.shape == (num_samples, 10)
    finally:
        # Always tear the endpoint down, even when prediction or the assertion fails.
        endpoint.delete_endpoint()
def test_create_image_uri_merged_pytorch():
    """Check the image URI built for PyTorch 1.12 and 1.11 on py2, for training and serving."""
    # (framework, instance type, framework version, expected image URI)
    cases = (
        (
            "pytorch",
            "ml.p3.2xlarge",
            "1.12",
            "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.12-gpu-py2",
        ),
        (
            "pytorch",
            "ml.p3.2xlarge",
            "1.11",
            "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:1.11-gpu-py2",
        ),
        (
            "pytorch-serving",
            "ml.c4.2xlarge",
            "1.12",
            "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12-cpu-py2",
        ),
        (
            "pytorch-serving",
            "ml.c4.2xlarge",
            "1.11",
            "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch-serving:1.11-cpu-py2",
        ),
    )

    for framework, instance_type, version, expected_uri in cases:
        actual_uri = fw_utils.create_image_uri(
            "us-west-2", framework, instance_type, version, "py2"
        )
        assert actual_uri == expected_uri
@patch("sagemaker.utils.create_tar_file", MagicMock())
@patch("sagemaker.utils.repack_model")
def test_non_mms_model(repack_model, sagemaker_session):
    """Deploying a model with framework version 1.1 must not call ``repack_model``."""
    model = PyTorchModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        sagemaker_session=sagemaker_session,
        framework_version="1.1",
    )
    model.deploy(1, GPU)

    repack_model.assert_not_called()