10 | 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
11 | 11 | # ANY KIND, either express or implied. See the License for the specific |
12 | 12 | # language governing permissions and limitations under the License. |
13 | | -# from __future__ import absolute_import |
14 | | - |
15 | | -# import pytest |
16 | | -# import torch |
17 | | -# from PIL import Image |
18 | | -# import os |
19 | | - |
20 | | -# from sagemaker.serve.builder.model_builder import ModelBuilder, Mode |
21 | | -# from sagemaker.serve.builder.schema_builder import SchemaBuilder |
22 | | -# from sagemaker.serve.spec.inference_spec import InferenceSpec |
23 | | -# from torchvision.transforms import transforms |
24 | | -# from torchvision.models.squeezenet import squeezenet1_1 |
25 | | - |
26 | | -# from tests.integ.sagemaker.serve.constants import ( |
27 | | -# PYTORCH_SQUEEZENET_RESOURCE_DIR, |
28 | | -# SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, |
29 | | -# NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE, |
30 | | -# NOT_RUNNING_ON_PY310, |
31 | | -# ) |
32 | | -# from tests.integ.timeout import timeout |
33 | | -# from tests.integ.utils import cleanup_model_resources |
34 | | -# import logging |
35 | | - |
36 | | -# logger = logging.getLogger(__name__) |
37 | | - |
38 | | -# ROLE_NAME = "Admin" |
39 | | - |
40 | | -# GH_USER_NAME = os.getenv("GH_USER_NAME") |
41 | | -# GH_ACCESS_TOKEN = os.getenv("GH_ACCESS_TOKEN") |
42 | | - |
43 | | - |
44 | | -# @pytest.fixture |
45 | | -# def pt_dependencies(): |
46 | | -# return { |
47 | | -# "auto": True, |
48 | | -# "custom": [ |
49 | | -# "boto3==1.26.*", |
50 | | -# "botocore==1.29.*", |
51 | | -# "s3transfer==0.6.*", |
52 | | -# ( |
53 | | -# f"git+https://{GH_USER_NAME}:{GH_ACCESS_TOKEN}@github.com" |
54 | | -# "/aws/sagemaker-python-sdk-staging.git@inference-experience-dev" |
55 | | -# ), |
56 | | -# ], |
57 | | -# } |
58 | | - |
59 | | - |
60 | | -# @pytest.fixture |
61 | | -# def test_image(): |
62 | | -# return Image.open(str(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg"))) |
63 | | - |
64 | | - |
65 | | -# @pytest.fixture |
66 | | -# def squeezenet_inference_spec(): |
67 | | -# class MySqueezeNetModel(InferenceSpec): |
68 | | -# def __init__(self) -> None: |
69 | | -# super().__init__() |
70 | | -# self.transform = transforms.Compose( |
71 | | -# [ |
72 | | -# transforms.Resize(256), |
73 | | -# transforms.CenterCrop(224), |
74 | | -# transforms.ToTensor(), |
75 | | -# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
76 | | -# ] |
77 | | -# ) |
78 | | - |
79 | | -# def invoke(self, input_object: object, model: object): |
80 | | -# # transform |
81 | | -# image_tensor = self.transform(input_object) |
82 | | -# input_batch = image_tensor.unsqueeze(0) |
83 | | -# # invoke |
84 | | -# with torch.no_grad(): |
85 | | -# output = model(input_batch) |
86 | | -# return output |
87 | | - |
88 | | -# def load(self, model_dir: str): |
89 | | -# model = squeezenet1_1() |
90 | | -# model.load_state_dict(torch.load(model_dir + "/model.pth")) |
91 | | -# model.eval() |
92 | | -# return model |
93 | | - |
94 | | -# return MySqueezeNetModel() |
95 | | - |
96 | | - |
97 | | -# @pytest.fixture |
98 | | -# def squeezenet_schema(): |
99 | | -# input_image = Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")) |
100 | | -# output_tensor = torch.rand(3, 4) |
101 | | -# return SchemaBuilder(sample_input=input_image, sample_output=output_tensor) |
102 | | - |
103 | | - |
104 | | -# @pytest.fixture |
105 | | -# def model_builder_inference_spec_schema_builder( |
106 | | -# squeezenet_inference_spec, squeezenet_schema, pt_dependencies |
107 | | -# ): |
108 | | -# return ModelBuilder( |
109 | | -# model_path=PYTORCH_SQUEEZENET_RESOURCE_DIR, |
110 | | -# inference_spec=squeezenet_inference_spec, |
111 | | -# schema_builder=squeezenet_schema, |
112 | | -# dependencies=pt_dependencies, |
113 | | -# ) |
114 | | - |
115 | | - |
116 | | -# @pytest.fixture |
117 | | -# def model_builder(request): |
118 | | -# return request.getfixturevalue(request.param) |
| 13 | +from __future__ import absolute_import |
| 14 | + |
| 15 | +import pytest |
| 16 | +import torch |
| 17 | +from PIL import Image |
| 18 | +import os |
| 19 | +import io |
| 20 | +import numpy as np |
| 21 | + |
| 22 | +from sagemaker.serve.builder.model_builder import ModelBuilder, Mode |
| 23 | +from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator |
| 24 | +from sagemaker.serve.spec.inference_spec import InferenceSpec |
| 25 | +from torchvision.transforms import transforms |
| 26 | +from torchvision.models.squeezenet import squeezenet1_1 |
| 27 | + |
| 28 | +from tests.integ.sagemaker.serve.constants import ( |
| 29 | + PYTORCH_SQUEEZENET_RESOURCE_DIR, |
| 30 | + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, |
| 31 | + NOT_RUNNING_ON_PY310, |
| 32 | +) |
| 33 | +from tests.integ.timeout import timeout |
| 34 | +from tests.integ.utils import cleanup_model_resources |
| 35 | +import logging |
| 36 | + |
| 37 | +logger = logging.getLogger(__name__) |
| 38 | + |
| 39 | +ROLE_NAME = "SageMakerRole" |
| 40 | + |
| 41 | + |
| 42 | +@pytest.fixture |
| 43 | +def test_image(): |
| 44 | + return Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")) |
| 45 | + |
| 46 | + |
| 47 | +@pytest.fixture |
| 48 | +def squeezenet_inference_spec(): |
| 49 | + class MySqueezeNetModel(InferenceSpec): |
| 50 | + def __init__(self) -> None: |
| 51 | + super().__init__() |
| 52 | + |
| 53 | + def invoke(self, input_object: object, model: object): |
| 54 | + with torch.no_grad(): |
| 55 | + output = model(input_object) |
| 56 | + return output |
| 57 | + |
| 58 | + def load(self, model_dir: str): |
| 59 | + model = squeezenet1_1() |
| 60 | + model.load_state_dict(torch.load(model_dir + "/model.pth")) |
| 61 | + model.eval() |
| 62 | + return model |
| 63 | + |
| 64 | + return MySqueezeNetModel() |
| 65 | + |
| 66 | + |
| 67 | +@pytest.fixture |
| 68 | +def custom_request_translator(): |
| 69 | + # request translator |
| 70 | + class MyRequestTranslator(CustomPayloadTranslator): |
| 71 | + def __init__(self): |
| 72 | + super().__init__() |
| 73 | + # Define image transformation |
| 74 | + self.transform = transforms.Compose( |
| 75 | + [ |
| 76 | + transforms.Resize(256), |
| 77 | + transforms.CenterCrop(224), |
| 78 | + transforms.ToTensor(), |
| 79 | + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
| 80 | + ] |
| 81 | + ) |
| 82 | + |
| 83 | + # This function converts the payload to bytes - happens on client side |
| 84 | + def serialize_payload_to_bytes(self, payload: object) -> bytes: |
| 85 | + # converts an image to bytes |
| 86 | + image_tensor = self.transform(payload) |
| 87 | + input_batch = image_tensor.unsqueeze(0) |
| 88 | + input_ndarray = input_batch.numpy() |
| 89 | + return self._convert_numpy_to_bytes(input_ndarray) |
| 90 | + |
| 91 | + # This function converts the bytes to payload - happens on server side |
| 92 | + def deserialize_payload_from_stream(self, stream) -> torch.Tensor: |
| 93 | + # convert payload back to torch.Tensor |
| 94 | + np_array = np.load(io.BytesIO(stream.read())) |
| 95 | + return torch.from_numpy(np_array) |
| 96 | + |
| 97 | + def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes: |
| 98 | + buffer = io.BytesIO() |
| 99 | + np.save(buffer, np_array) |
| 100 | + return buffer.getvalue() |
| 101 | + |
| 102 | + return MyRequestTranslator() |
| 103 | + |
| 104 | + |
| 105 | +@pytest.fixture |
| 106 | +def custom_response_translator(): |
| 107 | + # response translator |
| 108 | + class MyResponseTranslator(CustomPayloadTranslator): |
| 109 | + # This function converts the payload to bytes - happens on server side |
| 110 | + def serialize_payload_to_bytes(self, payload: torch.Tensor) -> bytes: |
| 111 | + return self._convert_numpy_to_bytes(payload.numpy()) |
| 112 | + |
| 113 | + # This function converts the bytes to payload - happens on client side |
| 114 | + def deserialize_payload_from_stream(self, stream) -> object: |
| 115 | + return torch.from_numpy(np.load(io.BytesIO(stream.read()))) |
| 116 | + |
| 117 | + def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes: |
| 118 | + buffer = io.BytesIO() |
| 119 | + np.save(buffer, np_array) |
| 120 | + return buffer.getvalue() |
| 121 | + |
| 122 | + return MyResponseTranslator() |
| 123 | + |
| 124 | + |
| 125 | +@pytest.fixture |
| 126 | +def squeezenet_schema(custom_request_translator, custom_response_translator): |
| 127 | + input_image = Image.open(os.path.join(PYTORCH_SQUEEZENET_RESOURCE_DIR, "zidane.jpeg")) |
| 128 | + output_tensor = torch.rand(3, 4) |
| 129 | + return SchemaBuilder( |
| 130 | + sample_input=input_image, |
| 131 | + sample_output=output_tensor, |
| 132 | + input_translator=custom_request_translator, |
| 133 | + output_translator=custom_response_translator, |
| 134 | + ) |
| 135 | + |
| 136 | +@pytest.fixture |
| 137 | +def model_builder_inference_spec_schema_builder(squeezenet_inference_spec, squeezenet_schema): |
| 138 | + return ModelBuilder( |
| 139 | + model_path=PYTORCH_SQUEEZENET_RESOURCE_DIR, |
| 140 | + inference_spec=squeezenet_inference_spec, |
| 141 | + schema_builder=squeezenet_schema, |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +@pytest.fixture |
| 146 | +def model_builder(request): |
| 147 | + return request.getfixturevalue(request.param) |
119 | 148 |
120 | 149 |
121 | 150 | # @pytest.mark.skipif( |
149 | 178 | # ), f"{caught_ex} was thrown when running pytorch squeezenet local container test" |
150 | 179 |
151 | 180 |
152 | | -# @pytest.mark.skipif( |
153 | | -# NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE or NOT_RUNNING_ON_PY310, |
154 | | -# reason="The goal of these test are to test the serving components of our feature", |
155 | | -# ) |
156 | | -# @pytest.mark.parametrize( |
157 | | -# "model_builder", ["model_builder_inference_spec_schema_builder"], indirect=True |
158 | | -# ) |
159 | | -# def test_happy_pytorch_sagemaker_endpoint( |
160 | | -# sagemaker_session, model_builder, cpu_instance_type, test_image |
161 | | -# ): |
162 | | -# logger.info("Running in SAGEMAKER_ENDPOINT mode...") |
163 | | -# caught_ex = None |
164 | | - |
165 | | -# iam_client = sagemaker_session.boto_session.client("iam") |
166 | | -# role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] |
167 | | - |
168 | | -# model = model_builder.build( |
169 | | -# mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session |
170 | | -# ) |
171 | | - |
172 | | -# with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): |
173 | | -# try: |
174 | | -# logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") |
175 | | -# predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1) |
176 | | -# logger.info("Endpoint successfully deployed.") |
177 | | -# predictor.predict(test_image) |
178 | | -# except Exception as e: |
179 | | -# caught_ex = e |
180 | | -# finally: |
181 | | -# cleanup_model_resources( |
182 | | -# sagemaker_session=model_builder.sagemaker_session, |
183 | | -# model_name=model.name, |
184 | | -# endpoint_name=model.endpoint_name, |
185 | | -# ) |
186 | | -# if caught_ex: |
187 | | -# logger.exception(caught_ex) |
188 | | -# assert ( |
189 | | -# False |
190 | | -# ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test" |
| 181 | +@pytest.mark.skipif( |
| 182 | + NOT_RUNNING_ON_PY310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE, |
| 183 | + reason="The goal of these tests is to test the serving components of our feature", |
| 184 | +) |
| 185 | +@pytest.mark.parametrize( |
| 186 | + "model_builder", ["model_builder_inference_spec_schema_builder"], indirect=True |
| 187 | +) |
| 188 | +def test_happy_pytorch_sagemaker_endpoint( |
| 189 | + sagemaker_session, model_builder, cpu_instance_type, test_image |
| 190 | +): |
| 191 | + logger.info("Running in SAGEMAKER_ENDPOINT mode...") |
| 192 | + caught_ex = None |
| 193 | + |
| 194 | + iam_client = sagemaker_session.boto_session.client("iam") |
| 195 | + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] |
| 196 | + |
| 197 | + model = model_builder.build( |
| 198 | + mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session |
| 199 | + ) |
| 200 | + |
| 201 | + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): |
| 202 | + try: |
| 203 | + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") |
| 204 | + predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1) |
| 205 | + logger.info("Endpoint successfully deployed.") |
| 206 | + predictor.predict(test_image) |
| 207 | + except Exception as e: |
| 208 | + caught_ex = e |
| 209 | + finally: |
| 210 | + cleanup_model_resources( |
| 211 | + sagemaker_session=model_builder.sagemaker_session, |
| 212 | + model_name=model.name, |
| 213 | + endpoint_name=model.endpoint_name, |
| 214 | + ) |
| 215 | + if caught_ex: |
| 216 | + logger.exception(caught_ex) |
| 217 | + assert ( |
| 218 | + False |
| 219 | + ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test" |
191 | 220 |
192 | 221 |
193 | 222 | # @pytest.mark.skipif( |
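For reviewers who want to verify the payload translators added above without deploying an endpoint, a minimal standalone round-trip check of the `.npy` serialization they use might look like the sketch below (the helper names are illustrative, not part of the PR):

```python
# Standalone sanity check (not part of the diff above) that the .npy-based
# serialization used by MyRequestTranslator/MyResponseTranslator round-trips
# a tensor losslessly. Helper names here are illustrative, not from the PR.
import io

import numpy as np
import torch


def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    # Mirrors _convert_numpy_to_bytes: dump the ndarray in .npy format.
    buffer = io.BytesIO()
    np.save(buffer, tensor.numpy())
    return buffer.getvalue()


def bytes_to_tensor(raw: bytes) -> torch.Tensor:
    # Mirrors deserialize_payload_from_stream: reload the ndarray and wrap it.
    return torch.from_numpy(np.load(io.BytesIO(raw)))


batch = torch.rand(1, 3, 224, 224)  # shape the request translator produces
assert torch.equal(batch, bytes_to_tensor(tensor_to_bytes(batch)))
```

`np.save`/`np.load` is a reasonable choice for this payload format because it preserves dtype and shape exactly, which text-based serializers such as JSON would not.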