Skip to content

Commit c9aa29b

Browse files
authored
feat: Integ tests for jumpstart model and estimator (#2865)
1 parent 63b0372 commit c9aa29b

File tree

12 files changed

+352
-129
lines changed

12 files changed

+352
-129
lines changed

src/sagemaker/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,10 @@ def prepare_container_def(
406406
self._upload_code(deploy_key_prefix, repack=is_repack)
407407
deploy_env.update(self._script_mode_env_vars())
408408
return sagemaker.container_def(
409-
self.image_uri, self.model_data, deploy_env, image_config=self.image_config
409+
self.image_uri,
410+
self.repacked_model_data or self.model_data,
411+
deploy_env,
412+
image_config=self.image_config,
410413
)
411414

412415
def _upload_code(self, key_prefix: str, repack: bool = False) -> None:

tests/integ/sagemaker/jumpstart/retrieve_uri/conftest.py renamed to tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@
1616
import boto3
1717
import pytest
1818
from botocore.config import Config
19+
from tests.integ.sagemaker.jumpstart.constants import (
20+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
21+
JUMPSTART_TAG,
22+
)
1923

2024

21-
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
25+
from tests.integ.sagemaker.jumpstart.utils import (
2226
get_test_artifact_bucket,
2327
get_test_suite_id,
2428
)
25-
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import (
26-
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
27-
JUMPSTART_TAG,
28-
)
2929

3030
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
3131

tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,21 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import json
1615
import time
17-
from typing import Any, Dict, List
1816
import boto3
1917
import os
2018
from botocore.config import Config
21-
import pandas as pd
2219

2320
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
24-
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
21+
from tests.integ.sagemaker.jumpstart.utils import (
2522
get_test_artifact_bucket,
2623
get_sm_session,
2724
)
2825

2926
from sagemaker.utils import repack_model
30-
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import (
27+
from tests.integ.sagemaker.jumpstart.constants import (
3128
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
3229
JUMPSTART_TAG,
33-
ContentType,
3430
)
3531

3632

@@ -189,43 +185,3 @@ def create_endpoint(self) -> None:
189185
}
190186
],
191187
)
192-
193-
194-
class EndpointInvoker:
195-
def __init__(
196-
self,
197-
endpoint_name,
198-
region=JUMPSTART_DEFAULT_REGION_NAME,
199-
boto_config=Config(retries={"max_attempts": 10, "mode": "standard"}),
200-
) -> None:
201-
self.endpoint_name = endpoint_name
202-
self.region = region
203-
self.config = boto_config
204-
self.sagemaker_runtime_client = self.get_sagemaker_runtime_client()
205-
206-
def _invoke_endpoint(
207-
self,
208-
body: Any,
209-
content_type: ContentType,
210-
) -> Dict[str, Any]:
211-
response = self.sagemaker_runtime_client.invoke_endpoint(
212-
EndpointName=self.endpoint_name, ContentType=content_type.value, Body=body
213-
)
214-
return json.loads(response["Body"].read())
215-
216-
def invoke_tabular_endpoint(self, data: pd.DataFrame) -> Dict[str, Any]:
217-
return self._invoke_endpoint(
218-
body=data.to_csv(header=False, index=False).encode("utf-8"),
219-
content_type=ContentType.TEXT_CSV,
220-
)
221-
222-
def invoke_spc_endpoint(self, text: List[str]) -> Dict[str, Any]:
223-
return self._invoke_endpoint(
224-
body=json.dumps(text).encode("utf-8"),
225-
content_type=ContentType.LIST_TEXT,
226-
)
227-
228-
def get_sagemaker_runtime_client(self) -> boto3.client:
229-
return boto3.client(
230-
service_name="runtime.sagemaker", config=self.config, region_name=self.region
231-
)

tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515

1616
from tests.integ.sagemaker.jumpstart.retrieve_uri.inference import (
17-
EndpointInvoker,
1817
InferenceJobLauncher,
1918
)
2019
from sagemaker import environment_variables, image_uris
2120
from sagemaker import script_uris
2221
from sagemaker import model_uris
2322

24-
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import InferenceTabularDataname
23+
from tests.integ.sagemaker.jumpstart.constants import InferenceTabularDataname
2524

26-
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
25+
from tests.integ.sagemaker.jumpstart.utils import (
2726
download_inference_assets,
2827
get_tabular_data,
28+
EndpointInvoker,
2929
)
3030

3131

tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import pandas as pd
16-
1715

1816
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
1917
get_model_tarball_full_uri_from_base_uri,
18+
)
19+
from tests.integ.sagemaker.jumpstart.utils import (
2020
get_training_dataset_for_model_and_version,
21+
EndpointInvoker,
2122
)
2223
from tests.integ.sagemaker.jumpstart.retrieve_uri.inference import (
23-
EndpointInvoker,
2424
InferenceJobLauncher,
2525
)
2626
from tests.integ.sagemaker.jumpstart.retrieve_uri.training import TrainingJobLauncher
@@ -59,6 +59,8 @@ def test_jumpstart_transfer_learning_retrieve_functions(setup):
5959
model_id=model_id, model_version=model_version, include_container_hyperparameters=True
6060
)
6161

62+
default_hyperparameters["epochs"] = "1"
63+
6264
training_job = TrainingJobLauncher(
6365
image_uri=image_uri,
6466
script_uri=script_uri,
@@ -110,10 +112,5 @@ def test_jumpstart_transfer_learning_retrieve_functions(setup):
110112
)
111113

112114
response = endpoint_invoker.invoke_spc_endpoint(["hello", "world"])
113-
entail, no_entail = response[0][0], response[0][1]
114-
115-
assert entail is not None
116-
assert no_entail is not None
117115

118-
assert pd.isna(entail) is False
119-
assert pd.isna(no_entail) is False
116+
assert response is not None

tests/integ/sagemaker/jumpstart/retrieve_uri/training.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@
1818
from botocore.config import Config
1919

2020
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
21-
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
22-
get_full_hyperparameters,
21+
from tests.integ.sagemaker.jumpstart.utils import (
2322
get_test_artifact_bucket,
2423
get_sm_session,
2524
)
25+
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
26+
get_full_hyperparameters,
27+
)
2628
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2729

28-
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import (
30+
from tests.integ.sagemaker.jumpstart.constants import (
2931
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
3032
)
3133

tests/integ/sagemaker/jumpstart/retrieve_uri/utils.py

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import uuid
16-
from typing import Tuple
17-
import boto3
18-
import pandas as pd
19-
import os
20-
21-
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import (
22-
TEST_ASSETS_SPECS,
23-
TMP_DIRECTORY_PATH,
24-
TRAINING_DATASET_MODEL_DICT,
25-
)
26-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
27-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
28-
2915
from sagemaker.s3 import parse_s3_url
30-
from sagemaker.session import Session
31-
32-
33-
def download_file(local_download_path, s3_bucket, s3_key, s3_client) -> None:
34-
s3_client.download_file(s3_bucket, s3_key, local_download_path)
3516

3617

3718
def get_model_tarball_full_uri_from_base_uri(base_uri: str, training_job_name: str) -> str:
@@ -56,46 +37,3 @@ def get_full_hyperparameters(
5637
"model-artifact-bucket": bucket,
5738
"model-artifact-key": key,
5839
}
59-
60-
61-
def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict:
62-
return TRAINING_DATASET_MODEL_DICT[(model_id, version)]
63-
64-
65-
def get_sm_session() -> Session:
66-
return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME))
67-
68-
69-
def get_test_artifact_bucket() -> str:
70-
bucket_name = get_sm_session().default_bucket()
71-
return bucket_name
72-
73-
74-
def download_inference_assets():
75-
76-
if not os.path.exists(TMP_DIRECTORY_PATH):
77-
os.makedirs(TMP_DIRECTORY_PATH)
78-
79-
for asset, s3_key in TEST_ASSETS_SPECS.items():
80-
file_path = os.path.join(TMP_DIRECTORY_PATH, str(asset.value))
81-
if not os.path.exists(file_path):
82-
download_file(
83-
file_path,
84-
get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME),
85-
s3_key,
86-
boto3.client("s3"),
87-
)
88-
89-
90-
def get_tabular_data(data_filename: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
91-
92-
asset_file_path = os.path.join(TMP_DIRECTORY_PATH, data_filename)
93-
94-
test_data = pd.read_csv(asset_file_path, header=None)
95-
label, features = test_data.iloc[:, :1], test_data.iloc[:, 1:]
96-
97-
return label, features
98-
99-
100-
def get_test_suite_id() -> str:
101-
return str(uuid.uuid4())

tests/integ/sagemaker/jumpstart/script_mode_class/__init__.py

Whitespace-only changes.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import os
15+
16+
from sagemaker import image_uris, model_uris, script_uris
17+
from sagemaker.jumpstart.constants import INFERENCE_ENTRYPOINT_SCRIPT_NAME
18+
from sagemaker.model import Model
19+
from tests.integ.sagemaker.jumpstart.constants import (
20+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
21+
JUMPSTART_TAG,
22+
InferenceTabularDataname,
23+
)
24+
from tests.integ.sagemaker.jumpstart.utils import (
25+
EndpointInvoker,
26+
download_inference_assets,
27+
get_sm_session,
28+
get_tabular_data,
29+
)
30+
31+
32+
def test_jumpstart_inference_model_class(setup):
33+
34+
model_id, model_version = "catboost-classification-model", "1.0.0"
35+
instance_type, instance_count = "ml.m5.xlarge", 1
36+
37+
print("Starting inference...")
38+
39+
image_uri = image_uris.retrieve(
40+
region=None,
41+
framework=None,
42+
image_scope="inference",
43+
model_id=model_id,
44+
model_version=model_version,
45+
instance_type=instance_type,
46+
)
47+
48+
script_uri = script_uris.retrieve(
49+
model_id=model_id, model_version=model_version, script_scope="inference"
50+
)
51+
52+
model_uri = model_uris.retrieve(
53+
model_id=model_id, model_version=model_version, model_scope="inference"
54+
)
55+
56+
model = Model(
57+
image_uri=image_uri,
58+
model_data=model_uri,
59+
source_dir=script_uri,
60+
entry_point=INFERENCE_ENTRYPOINT_SCRIPT_NAME,
61+
role=get_sm_session().get_caller_identity_arn(),
62+
sagemaker_session=get_sm_session(),
63+
enable_network_isolation=True,
64+
)
65+
66+
model.deploy(
67+
initial_instance_count=instance_count,
68+
instance_type=instance_type,
69+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
70+
)
71+
72+
endpoint_invoker = EndpointInvoker(
73+
endpoint_name=model.endpoint_name,
74+
)
75+
76+
download_inference_assets()
77+
ground_truth_label, features = get_tabular_data(InferenceTabularDataname.MULTICLASS)
78+
79+
response = endpoint_invoker.invoke_tabular_endpoint(features)
80+
81+
assert response is not None

0 commit comments

Comments
 (0)