Skip to content

Commit 04c4a56

Browse files
author
EC2 Default User
committed
unit testing for tgi
1 parent 27b1eb8 commit 04c4a56

File tree

2 files changed

+166
-1
lines changed

2 files changed

+166
-1
lines changed

src/sagemaker/serve/model_server/tgi/prepare.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bo
4040
resources.extractall(path=code_dir)
4141
else:
4242
logger.info("Copying uncompressed JumpStart artifacts...")
43-
raise Exception(s3_downloader)
4443
s3_downloader.download(model_data, code_dir)
4544
elif isinstance(model_data, dict): # if dict assume that it is uncompressed
4645
logger.info("Copying uncompressed JumpStart artifacts...")
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from unittest import TestCase
16+
from unittest.mock import Mock, PropertyMock, patch
17+
18+
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources
19+
20+
MOCK_MODEL_PATH = "/path/to/mock/model/dir"
21+
MOCK_CODE_DIR = "/path/to/mock/model/dir/code"
22+
MOCK_JUMPSTART_ID = "mock_llm_js_id"
23+
MOCK_TMP_DIR = "tmp123456"
24+
MOCK_COMPRESSED_MODEL_DATA_STR = "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-llm-falcon-7b-bf16.tar.gz"
25+
MOCK_UNCOMPRESSED_MODEL_DATA_STR = "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-falcon-7b-bf16/artifacts/inference-prepack/v1.0.1/"
26+
MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT = "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-falcon-7b-bf16/artifacts/inference-prepack/v1.0.1/dict/"
27+
MOCK_UNCOMPRESSED_MODEL_DATA_DICT = {
28+
"S3DataSource": {
29+
"S3Uri": MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT,
30+
"S3DataType": "S3Prefix",
31+
"CompressionType": "None",
32+
}
33+
}
34+
35+
36+
class TgiPrepareTests(TestCase):
37+
@patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space")
38+
@patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage")
39+
@patch("sagemaker.serve.model_server.tgi.prepare.Path")
40+
@patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader")
41+
def test_prepare_tgi_js_resources_for_jumpstart_uncompressed_str(
42+
self, mock_s3downloader, mock_path, mock_disk_usage, mock_disk_space
43+
):
44+
# mock actions
45+
(
46+
mock_model_path,
47+
mock_code_dir,
48+
mocked_s3_downloader_obj,
49+
) = self.populate_standard_resource_mocks(mock_path, mock_s3downloader)
50+
51+
# invoke prepare
52+
prepare_tgi_js_resources(
53+
model_path=MOCK_MODEL_PATH,
54+
js_id=MOCK_JUMPSTART_ID,
55+
model_data=MOCK_UNCOMPRESSED_MODEL_DATA_STR,
56+
)
57+
58+
# validate call chain
59+
self.validate_standard_resource_mocks(
60+
mock_model_path, mock_code_dir, mock_disk_space, mock_disk_usage
61+
)
62+
mocked_s3_downloader_obj.download.assert_called_once_with(
63+
MOCK_UNCOMPRESSED_MODEL_DATA_STR, mock_code_dir
64+
)
65+
66+
@patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space")
67+
@patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage")
68+
@patch("sagemaker.serve.model_server.tgi.prepare.Path")
69+
@patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader")
70+
def test_prepare_tgi_js_resources_for_jumpstart_uncompresssed_dict(
71+
self, mock_s3downloader, mock_path, mock_disk_usage, mock_disk_space
72+
):
73+
# mock actions
74+
(
75+
mock_model_path,
76+
mock_code_dir,
77+
mocked_s3_downloader_obj,
78+
) = self.populate_standard_resource_mocks(mock_path, mock_s3downloader)
79+
80+
# invoke prepare
81+
prepare_tgi_js_resources(
82+
model_path=MOCK_MODEL_PATH,
83+
js_id=MOCK_JUMPSTART_ID,
84+
model_data=MOCK_UNCOMPRESSED_MODEL_DATA_DICT,
85+
)
86+
87+
# validate call chain
88+
self.validate_standard_resource_mocks(
89+
mock_model_path, mock_code_dir, mock_disk_space, mock_disk_usage
90+
)
91+
mocked_s3_downloader_obj.download.assert_called_once_with(
92+
MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, mock_code_dir
93+
)
94+
95+
@patch("sagemaker.serve.model_server.tgi.prepare._check_disk_space")
96+
@patch("sagemaker.serve.model_server.tgi.prepare._check_docker_disk_usage")
97+
@patch("sagemaker.serve.model_server.tgi.prepare.Path")
98+
@patch("sagemaker.serve.model_server.tgi.prepare.tarfile")
99+
@patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader")
100+
@patch("sagemaker.serve.model_server.tgi.prepare._tmpdir")
101+
def test_prepare_tgi_js_resources_for_jumpstart_compressed_str(
102+
self,
103+
mock_tmpdir,
104+
mock_s3downloader,
105+
mock_tarfile,
106+
mock_path,
107+
mock_disk_usage,
108+
mock_disk_space,
109+
):
110+
# mock actions
111+
mock_model_path = Mock()
112+
mock_model_path.exists.return_value = False
113+
mock_code_dir = Mock()
114+
mock_model_path.joinpath.return_value = mock_code_dir
115+
116+
mock_tmp_js_dir = Mock()
117+
mock_tmp_sourcedir = Mock()
118+
mock_tmp_js_dir.joinpath.return_value = mock_tmp_sourcedir
119+
mock_path.side_effect = [mock_model_path, mock_tmp_js_dir]
120+
121+
mocked_s3_downloader_obj = Mock()
122+
mock_s3downloader.return_value = mocked_s3_downloader_obj
123+
124+
mock_tmpdir_obj = Mock()
125+
type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=Mock())
126+
type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock())
127+
mock_tmpdir.return_value = mock_tmpdir_obj
128+
129+
mock_resources = Mock()
130+
mock_tarfile.open.return_value = mock_resources
131+
132+
# invoke prepare
133+
prepare_tgi_js_resources(
134+
model_path=MOCK_MODEL_PATH,
135+
js_id=MOCK_JUMPSTART_ID,
136+
model_data=MOCK_COMPRESSED_MODEL_DATA_STR,
137+
)
138+
139+
# validate call chain
140+
self.validate_standard_resource_mocks(
141+
mock_model_path, mock_code_dir, mock_disk_space, mock_disk_usage
142+
)
143+
mocked_s3_downloader_obj.download.assert_called_once_with(
144+
MOCK_COMPRESSED_MODEL_DATA_STR, mock_tmpdir_obj
145+
)
146+
147+
def populate_standard_resource_mocks(self, mock_path, mock_s3downloader):
148+
mock_model_path = Mock()
149+
mock_model_path.exists.return_value = False
150+
mock_code_dir = Mock()
151+
mock_model_path.joinpath.return_value = mock_code_dir
152+
mock_path.return_value = mock_model_path
153+
154+
mocked_s3_downloader_obj = Mock()
155+
mock_s3downloader.return_value = mocked_s3_downloader_obj
156+
157+
return mock_model_path, mock_code_dir, mocked_s3_downloader_obj
158+
159+
def validate_standard_resource_mocks(
160+
self, mock_model_path, mock_code_dir, mock_disk_space, mock_disk_usage
161+
):
162+
mock_model_path.mkdir.assert_called_once_with(parents=True)
163+
mock_model_path.joinpath.assert_called_once_with("code")
164+
mock_code_dir.mkdir.assert_called_once_with(exist_ok=True, parents=True)
165+
mock_disk_space.assert_called_once_with(mock_model_path)
166+
mock_disk_usage.assert_called_once()

0 commit comments

Comments
 (0)