Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tests/integ/test_pytorch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def test_deploy_model(pytorch_training_job, sagemaker_session, cpu_instance_type


@pytest.mark.skipif(PYTHON_VERSION == "py2", reason="PyTorch EIA does not support Python 2.")
@pytest.mark.skipif(
tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS,
reason="EI isn't supported in that specific region.",
)
def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp())
model_data = sagemaker_session.upload_data(path=EIA_MODEL)
Expand All @@ -134,7 +138,7 @@ def test_deploy_model_with_accelerator(sagemaker_session, cpu_instance_type):
predictor = pytorch.deploy(
initial_instance_count=1,
instance_type=cpu_instance_type,
accelerator_type="ml.eia2.medium",
accelerator_type="ml.eia1.medium",
endpoint_name=endpoint_name,
)

Expand Down