Skip to content

Commit 4978670

Browse files
authored
Merge branch 'master' into mwfongAWS-SM-sdk
2 parents 34c9d1b + 4884d18 commit 4978670

File tree

5 files changed

+12
-6
lines changed

5 files changed

+12
-6
lines changed

src/sagemaker/djl_inference/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,17 @@ def _read_existing_serving_properties(directory: str):
135135
return properties
136136

137137

138-
def _get_model_config_properties_from_s3(model_s3_uri: str):
138+
def _get_model_config_properties_from_s3(model_s3_uri: str, sagemaker_session: Session):
139139
"""Placeholder docstring"""
140140

141-
s3_files = s3.S3Downloader.list(model_s3_uri)
141+
s3_files = s3.S3Downloader.list(model_s3_uri, sagemaker_session=sagemaker_session)
142142
model_config = None
143143
for config in defaults.VALID_MODEL_CONFIG_FILES:
144144
config_file = os.path.join(model_s3_uri, config)
145145
if config_file in s3_files:
146-
model_config = json.loads(s3.S3Downloader.read_file(config_file))
146+
model_config = json.loads(
147+
s3.S3Downloader.read_file(config_file, sagemaker_session=sagemaker_session)
148+
)
147149
break
148150
if not model_config:
149151
raise ValueError(
@@ -198,7 +200,8 @@ def __new__(
198200
"containing folder"
199201
)
200202
if model_id.startswith("s3://"):
201-
model_config = _get_model_config_properties_from_s3(model_id)
203+
sagemaker_session = kwargs.get("sagemaker_session")
204+
model_config = _get_model_config_properties_from_s3(model_id, sagemaker_session)
202205
else:
203206
model_config = _get_model_config_properties_from_hf(model_id)
204207
if model_config.get("_class_name") == "StableDiffusionPipeline":

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,7 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
574574
script=self.entry_point,
575575
directory=self.source_dir,
576576
dependencies=self.dependencies,
577+
kms_key=self.model_kms_key,
577578
settings=self.sagemaker_session.settings,
578579
)
579580

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"me-south-1",
5353
"sa-east-1",
5454
"us-west-1",
55+
"ap-south-1", # no p3 availability
5556
]
5657

5758
NO_T2_REGIONS = ["eu-north-1", "ap-east-1", "me-south-1"]

tests/integ/test_inference_recommender.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ def test_default_right_size_and_deploy_registered_model_sklearn(
301301
predictor.delete_endpoint()
302302

303303

304-
@pytest.mark.skip(reason="This test is currently failing. Skipping until fixed")
305304
@pytest.mark.slow_test
306305
def test_default_right_size_and_deploy_unregistered_model_sklearn(
307306
default_right_sized_unregistered_model, sagemaker_session
@@ -346,7 +345,6 @@ def test_default_right_size_and_deploy_unregistered_base_model(
346345
predictor.delete_endpoint()
347346

348347

349-
@pytest.mark.skip(reason="This test is currently failing. Skipping until fixed")
350348
@pytest.mark.slow_test
351349
def test_advanced_right_size_and_deploy_unregistered_model_sklearn(
352350
advanced_right_sized_unregistered_model, sagemaker_session

tests/unit/test_djl_inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def test_create_model_automatic_engine_selection(mock_s3_list, mock_read_file, s
174174
sagemaker_session=sagemaker_session,
175175
number_of_partitions=2,
176176
)
177+
mock_s3_list.assert_any_call(
178+
VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session
179+
)
177180
if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE:
178181
assert ds_model.engine == DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION
179182
else:

0 commit comments

Comments
 (0)