Skip to content

Commit 3b2e493

Browse files
author
Ruban Hussain
committed
feature: SDK Defaults Config - SDK-wide S3 Key Prefix - more unit tests
1 parent 5f20c73 commit 3b2e493

File tree

6 files changed: 613 additions and 4 deletions.

tests/unit/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -314,13 +314,15 @@ def _test_default_bucket_and_prefix_combinations(
314314
default_bucket=Mock(name="default_bucket", return_value=DEFAULT_S3_BUCKET_NAME),
315315
default_bucket_prefix=DEFAULT_S3_OBJECT_KEY_PREFIX_NAME,
316316
config=None,
317+
settings=None,
317318
),
318319
session_with_bucket_and_no_prefix=Mock(
319320
name="sagemaker_session",
320321
sagemaker_config={},
321322
default_bucket_prefix=None,
322323
default_bucket=Mock(name="default_bucket", return_value=DEFAULT_S3_BUCKET_NAME),
323324
config=None,
325+
settings=None,
324326
),
325327
):
326328
"""

tests/unit/sagemaker/experiments/test_helper.py

Lines changed: 121 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,11 @@
2626
)
2727
from src.sagemaker.experiments._utils import resolve_artifact_name
2828
from src.sagemaker.session import Session
29+
from tests.unit import (
30+
_test_default_bucket_and_prefix_combinations,
31+
DEFAULT_S3_OBJECT_KEY_PREFIX_NAME,
32+
DEFAULT_S3_BUCKET_NAME,
33+
)
2934

3035

3136
@pytest.fixture
@@ -193,3 +198,119 @@ def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader):
193198

194199
expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key)
195200
assert expected_uri == s3_uri
201+
202+
203+
def test_upload_artifact__default_bucket_and_prefix_combinations(tempdir):
    """Check the S3 URIs ``upload_artifact`` returns across bucket/prefix setups.

    A small local file is uploaded twice per scenario, once through an uploader
    given an explicit bucket and prefix and once through an uploader that only
    receives the session. Both returned URIs are handed to
    ``_test_default_bucket_and_prefix_combinations`` together with the expected
    values for each scenario.
    """
    local_file = os.path.join(tempdir, "exists")
    with open(local_file, "a") as handle:
        handle.write("boo")

    def upload_twice(uploader):
        # Stub the head_object response before uploading.
        uploader._s3_client.head_object.return_value = {"ETag": "etag_value"}
        # Only the URIs are compared; the second tuple element is discarded.
        first_uri, _ = uploader.upload_artifact(local_file)
        second_uri, _ = uploader.upload_artifact(local_file)
        return first_uri, second_uri

    def with_user_input(sess):
        return upload_twice(
            _ArtifactUploader(
                trial_component_name="trial_component_name",
                artifact_bucket="artifact_bucket",
                artifact_prefix="artifact_prefix",
                sagemaker_session=sess,
            )
        )

    def without_user_input(sess):
        return upload_twice(
            _ArtifactUploader(
                trial_component_name="trial_component_name",
                sagemaker_session=sess,
            )
        )

    object_key = "trial_component_name/exists"
    default_uri_with_prefix = (
        f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/"
        f"trial-component-artifacts/{object_key}"
    )
    default_uri_bucket_only = (
        f"s3://{DEFAULT_S3_BUCKET_NAME}/trial-component-artifacts/{object_key}"
    )
    user_uri = f"s3://artifact_bucket/artifact_prefix/{object_key}"

    actual, expected = _test_default_bucket_and_prefix_combinations(
        function_with_user_input=with_user_input,
        function_without_user_input=without_user_input,
        expected__without_user_input__with_default_bucket_and_default_prefix=(
            default_uri_with_prefix,
            default_uri_with_prefix,
        ),
        expected__without_user_input__with_default_bucket_only=(
            default_uri_bucket_only,
            default_uri_bucket_only,
        ),
        expected__with_user_input__with_default_bucket_and_prefix=(user_uri, user_uri),
        expected__with_user_input__with_default_bucket_only=(user_uri, user_uri),
    )
    assert actual == expected
253+
254+
255+
def test_upload_object_artifact__default_bucket_and_prefix_combinations(tempdir):
    """Check the S3 URIs ``upload_object_artifact`` returns across bucket/prefix setups.

    The same in-memory object is uploaded twice per scenario, once through an
    uploader given an explicit bucket and prefix and once through an uploader
    that only receives the session. Both returned URIs are handed to
    ``_test_default_bucket_and_prefix_combinations`` together with the expected
    values for each scenario.
    """
    local_file = os.path.join(tempdir, "exists")
    with open(local_file, "a") as handle:
        handle.write("boo")

    name = "my-artifact"
    payload = {"key": "value"}
    extension = ".csv"

    def upload_twice(uploader):
        # Stub the head_object response before uploading.
        uploader._s3_client.head_object.return_value = {"ETag": "etag_value"}
        # Only the URIs are compared; the second tuple element is discarded.
        first_uri, _ = uploader.upload_object_artifact(name, payload, extension)
        second_uri, _ = uploader.upload_object_artifact(name, payload, extension)
        return first_uri, second_uri

    def with_user_input(sess):
        return upload_twice(
            _ArtifactUploader(
                trial_component_name="trial_component_name",
                artifact_bucket="artifact_bucket",
                artifact_prefix="artifact_prefix",
                sagemaker_session=sess,
            )
        )

    def without_user_input(sess):
        return upload_twice(
            _ArtifactUploader(
                trial_component_name="trial_component_name",
                sagemaker_session=sess,
            )
        )

    object_key = "trial_component_name/my-artifact.csv"
    default_uri_with_prefix = (
        f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/"
        f"trial-component-artifacts/{object_key}"
    )
    default_uri_bucket_only = (
        f"s3://{DEFAULT_S3_BUCKET_NAME}/trial-component-artifacts/{object_key}"
    )
    user_uri = f"s3://artifact_bucket/artifact_prefix/{object_key}"

    actual, expected = _test_default_bucket_and_prefix_combinations(
        function_with_user_input=with_user_input,
        function_without_user_input=without_user_input,
        expected__without_user_input__with_default_bucket_and_default_prefix=(
            default_uri_with_prefix,
            default_uri_with_prefix,
        ),
        expected__without_user_input__with_default_bucket_only=(
            default_uri_bucket_only,
            default_uri_bucket_only,
        ),
        expected__with_user_input__with_default_bucket_and_prefix=(user_uri, user_uri),
        expected__with_user_input__with_default_bucket_only=(user_uri, user_uri),
    )
    assert actual == expected

tests/unit/sagemaker/local/test_local_session.py

Lines changed: 76 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import os
1818
from botocore.exceptions import ClientError
1919
from mock import Mock, patch
20-
from tests.unit import DATA_DIR
20+
from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION
2121

2222
import sagemaker
2323
from sagemaker.workflow.parameters import ParameterString
@@ -956,3 +956,78 @@ def test_start_undefined_pipeline():
956956
with pytest.raises(ClientError) as e:
957957
LocalSession().sagemaker_client.start_pipeline_execution("UndefinedPipeline")
958958
assert "Pipeline UndefinedPipeline does not exist" in str(e.value)
959+
960+
961+
def test_default_bucket_with_sagemaker_config(boto_session, client):
    """Check which source supplies the value of ``LocalSession.default_bucket()``.

    Three sessions are built on the same ``boto_session``: one that only has a
    sagemaker_config, one that also receives an explicit bucket, and one that
    has neither.
    """
    # Case 1: Use bucket from sagemaker_config
    config_session = LocalSession(
        boto_session=boto_session,
        default_bucket=None,
        sagemaker_config=SAGEMAKER_CONFIG_SESSION,
    )
    bucket_from_config = config_session.default_bucket()
    session_section = SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"]
    assert bucket_from_config == session_section["SessionDefaultS3Bucket"]

    # Case 2: Use bucket from user input to Session (even if sagemaker_config has a bucket)
    user_session = LocalSession(
        boto_session=boto_session,
        default_bucket="default-bucket",
        sagemaker_config=SAGEMAKER_CONFIG_SESSION,
    )
    assert user_session.default_bucket() == "default-bucket"

    # Case 3: Use default bucket of SDK
    sdk_session = LocalSession(
        boto_session=boto_session,
        default_bucket=None,
        sagemaker_config=None,
    )
    # Any client created from the boto session reports this account id.
    identity_client = Mock(get_caller_identity=Mock(return_value={"Account": "111111111"}))
    sdk_session.boto_session.client.return_value = identity_client
    assert sdk_session.default_bucket() == "sagemaker-us-west-2-111111111"
998+
999+
1000+
def test_default_bucket_prefix_with_sagemaker_config(boto_session, client):
    """Check which source supplies ``LocalSession.default_bucket_prefix``.

    Three sessions are built on the same ``boto_session``: one that only has a
    sagemaker_config, one that also receives an explicit prefix, and one that
    has neither.
    """
    # Case 1: Use prefix from sagemaker_config
    config_session = LocalSession(
        boto_session=boto_session,
        default_bucket_prefix=None,
        sagemaker_config=SAGEMAKER_CONFIG_SESSION,
    )
    prefix_from_config = config_session.default_bucket_prefix
    session_section = SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"]
    assert prefix_from_config == session_section["SessionDefaultS3ObjectKeyPrefix"]

    # Case 2: Use prefix from user input to Session (even if sagemaker_config has a prefix)
    user_session = LocalSession(
        boto_session=boto_session,
        default_bucket_prefix="default-prefix",
        sagemaker_config=SAGEMAKER_CONFIG_SESSION,
    )
    assert user_session.default_bucket_prefix == "default-prefix"

    # Case 3: Neither the user input nor the config has the prefix
    bare_session = LocalSession(
        boto_session=boto_session,
        default_bucket_prefix=None,
        sagemaker_config=None,
    )
    assert bare_session.default_bucket_prefix is None

tests/unit/sagemaker/model/test_model.py

Lines changed: 116 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from mock import Mock, patch
1818

1919
import sagemaker
20+
from sagemaker.async_inference import AsyncInferenceConfig
2021
from sagemaker.model import FrameworkModel, Model
2122
from sagemaker.huggingface.model import HuggingFaceModel
2223
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
@@ -27,7 +28,11 @@
2728
from sagemaker.tensorflow.model import TensorFlowModel
2829
from sagemaker.xgboost.model import XGBoostModel
2930
from sagemaker.workflow.properties import Properties
30-
31+
from tests.unit import (
32+
_test_default_bucket_and_prefix_combinations,
33+
DEFAULT_S3_BUCKET_NAME,
34+
DEFAULT_S3_OBJECT_KEY_PREFIX_NAME,
35+
)
3136

3237
MODEL_DATA = "s3://bucket/model.tar.gz"
3338
MODEL_IMAGE = "mi"
@@ -806,3 +811,113 @@ def test_model_local_download_dir(repack_model, sagemaker_session):
806811
repack_model.call_args_list[0][1]["sagemaker_session"].settings.local_download_dir
807812
== local_download_dir
808813
)
814+
815+
816+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
def test__upload_code__default_bucket_and_prefix_combinations(
    tar_and_upload_dir,
):
    """Check the bucket and key prefix ``Model._upload_code`` passes on.

    ``tar_and_upload_dir`` is patched, and the ``bucket`` and ``s3_key_prefix``
    keyword arguments of its latest call are compared for a model built with a
    ``code_location`` and for one built without it.
    """
    key_prefix = "upload-prefix/upload-prefix-2"

    def build_model(sess, **extra):
        return Model(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sess,
            image_uri=IMAGE_URI,
            model_data=MODEL_DATA,
            **extra,
        )

    def upload_and_capture(model):
        model._upload_code(key_prefix, repack=False)
        # Read the keyword arguments of the most recent patched call.
        call_kwargs = tar_and_upload_dir.call_args.kwargs
        return call_kwargs["bucket"], call_kwargs["s3_key_prefix"]

    def with_user_input(sess):
        return upload_and_capture(
            build_model(sess, code_location="s3://test-bucket/test-prefix/test-prefix-2")
        )

    def without_user_input(sess):
        return upload_and_capture(build_model(sess))

    user_expectation = ("test-bucket", "upload-prefix/upload-prefix-2")

    actual, expected = _test_default_bucket_and_prefix_combinations(
        function_with_user_input=with_user_input,
        function_without_user_input=without_user_input,
        expected__without_user_input__with_default_bucket_and_default_prefix=(
            DEFAULT_S3_BUCKET_NAME,
            f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/upload-prefix/upload-prefix-2",
        ),
        expected__without_user_input__with_default_bucket_only=(
            DEFAULT_S3_BUCKET_NAME,
            "upload-prefix/upload-prefix-2",
        ),
        expected__with_user_input__with_default_bucket_and_prefix=user_expectation,
        expected__with_user_input__with_default_bucket_only=user_expectation,
    )
    assert actual == expected
866+
867+
868+
@patch("sagemaker.model.unique_name_from_base")
def test__build_default_async_inference_config__default_bucket_and_prefix_combinations(
    unique_name_from_base,
):
    """Check the paths left on an ``AsyncInferenceConfig`` after defaults are built.

    ``unique_name_from_base`` is patched to return a fixed name. The config's
    ``output_path`` and ``failure_path`` are compared for a config created with
    both paths set and for one created with no arguments.
    """
    unique_name_from_base.return_value = "unique-name"

    def resolved_paths(sess, **async_kwargs):
        # The model is built first, then the config, then the defaults are filled in.
        model = Model(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sess,
            image_uri=IMAGE_URI,
            model_data=MODEL_DATA,
            code_location="s3://test-bucket/test-prefix/test-prefix-2",
        )
        async_config = AsyncInferenceConfig(**async_kwargs)
        model._build_default_async_inference_config(async_config)
        return async_config.output_path, async_config.failure_path

    def with_user_input(sess):
        return resolved_paths(
            sess,
            output_path="s3://output-bucket/output-prefix/output-prefix-2",
            failure_path="s3://failure-bucket/failure-prefix/failure-prefix-2",
        )

    def without_user_input(sess):
        return resolved_paths(sess)

    prefixed_root = f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}"
    bucket_root = f"s3://{DEFAULT_S3_BUCKET_NAME}"
    user_expectation = (
        "s3://output-bucket/output-prefix/output-prefix-2",
        "s3://failure-bucket/failure-prefix/failure-prefix-2",
    )

    actual, expected = _test_default_bucket_and_prefix_combinations(
        function_with_user_input=with_user_input,
        function_without_user_input=without_user_input,
        expected__without_user_input__with_default_bucket_and_default_prefix=(
            f"{prefixed_root}/async-endpoint-outputs/unique-name",
            f"{prefixed_root}/async-endpoint-failures/unique-name",
        ),
        expected__without_user_input__with_default_bucket_only=(
            f"{bucket_root}/async-endpoint-outputs/unique-name",
            f"{bucket_root}/async-endpoint-failures/unique-name",
        ),
        expected__with_user_input__with_default_bucket_and_prefix=user_expectation,
        expected__with_user_input__with_default_bucket_only=user_expectation,
    )
    assert actual == expected

0 commit comments

Comments
 (0)