diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 3e55b15047..e495732501 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -60,7 +60,7 @@ ) SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = { "tensorflow": ["2.3.0", "2.3.1"], - "pytorch": ["1.6.0"], + "pytorch": ["1.6.0", "1.7.1"], } SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index d3a9e9faca..b9eef62678 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -54,7 +54,8 @@ "1.3": "1.3.1", "1.4": "1.4.0", "1.5": "1.5.0", - "1.6": "1.6.0" + "1.6": "1.6.0", + "1.7": "1.7.1" }, "versions": { "0.4.0": { @@ -318,6 +319,39 @@ "us-west-2": "763104351884" }, "repository": "pytorch-inference" + }, + "1.7.1": { + "py_versions": [ + "py3", + "py36" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" } } }, @@ -334,7 +368,8 @@ "1.3": "1.3.1", "1.4": "1.4.0", "1.5": "1.5.0", - "1.6": "1.6.0" + "1.6": "1.6.0", + "1.7": "1.7.1" }, "versions": { "0.4.0": { @@ -599,6 +634,39 @@ "us-west-2": "763104351884" }, "repository": "pytorch-training" + }, + "1.7.1": { + "py_versions": [ + "py3", + "py36" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" } } } diff --git a/tests/conftest.py b/tests/conftest.py index 6e9794fe07..e652e9ec14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -173,6 +173,8 @@ def mxnet_eia_latest_py_version(): def pytorch_training_py_version(pytorch_training_version, request): if Version(pytorch_training_version) < Version("1.5.0"): return request.param + elif Version(pytorch_training_version) == Version("1.7.1"): + return "py36" else: return "py3" @@ -181,6 +183,8 @@ def pytorch_training_py_version(pytorch_training_version, request): def pytorch_inference_py_version(pytorch_inference_version, request): if Version(pytorch_inference_version) < Version("1.4.0"): return request.param + elif Version(pytorch_inference_version) == Version("1.7.1"): + return "py36" else: return "py3" diff --git a/tests/data/smdistributed_dataparallel/mnist_pt.py b/tests/data/smdistributed_dataparallel/mnist_pt.py index 224a6a3882..1614b55d3d 100644 --- a/tests/data/smdistributed_dataparallel/mnist_pt.py +++ b/tests/data/smdistributed_dataparallel/mnist_pt.py @@ -13,6 +13,7 @@ from __future__ import print_function import argparse +import os import time import torch import torch.nn as nn @@ -150,8 +151,8 @@ def main(): parser.add_argument( "--data-path", type=str, - default="/tmp/data", - help="Path for downloading " "the MNIST dataset", + default=os.environ["SM_CHANNEL_TRAINING"], + help="Path for downloading the MNIST dataset", ) args = parser.parse_args() @@ -186,7 +187,7 @@ def main(): train_dataset = datasets.MNIST( data_path, train=True, - download=True, + download=False, # True sets a dependency on an external site for our tests. transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] ), diff --git a/tests/integ/test_smdataparallel_pt.py b/tests/integ/test_smdataparallel_pt.py index b7dbdd5c32..3dfc7b387b 100644 --- a/tests/integ/test_smdataparallel_pt.py +++ b/tests/integ/test_smdataparallel_pt.py @@ -21,7 +21,7 @@ from sagemaker.pytorch import PyTorch from tests.integ import timeout - +from tests.integ.test_pytorch import _upload_training_data smdataparallel_dir = os.path.join( os.path.dirname(__file__), "..", "data", "smdistributed_dataparallel" @@ -51,4 +51,4 @@ def test_smdataparallel_pt_mnist( ) with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): - estimator.fit(job_name=job_name) + estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 60cb117f06..653e078fd7 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -632,6 +632,7 @@ def test_validate_smdataparallel_args_not_raises(): (None, None, None, None, smdataparallel_disabled), ("ml.p3.16xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled), ] for instance_type, framework_name, framework_version, py_version, distribution in good_args: fw_utils._validate_smdataparallel_args(