diff --git a/tests/integ/test_pytorch_train.py b/tests/integ/test_pytorch_train.py index c1164e13e7..4066d5001b 100644 --- a/tests/integ/test_pytorch_train.py +++ b/tests/integ/test_pytorch_train.py @@ -12,18 +12,23 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import os - import numpy +import os import pytest -from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES -from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name - +from sagemaker.pytorch.defaults import LATEST_PY2_VERSION from sagemaker.pytorch.estimator import PyTorch from sagemaker.pytorch.model import PyTorchModel -from sagemaker.pytorch.defaults import LATEST_PY2_VERSION from sagemaker.utils import sagemaker_timestamp +from tests.integ import ( + test_region, + DATA_DIR, + PYTHON_VERSION, + TRAINING_DEFAULT_TIMEOUT_MINUTES, + EI_SUPPORTED_REGIONS, +) +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name + MNIST_DIR = os.path.join(DATA_DIR, "pytorch_mnist") MNIST_SCRIPT = os.path.join(MNIST_DIR, "mnist.py") @@ -120,6 +125,9 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type @pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch EIA does not support Python 2.") +@pytest.mark.skipif( + test_region() not in EI_SUPPORTED_REGIONS, reason="EI isn't supported in that specific region." +) def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type): endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp()) model_data = sagemaker_session.upload_data(path=EIA_MODEL) @@ -134,7 +142,7 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type): predictor = pytorch.deploy( initial_instance_count=1, instance_type=cpu_instance_type, - accelerator_type="ml.eia2.medium", + accelerator_type="ml.eia1.medium", endpoint_name=endpoint_name, )