
Commit 2e5223d

Author: Ruban Hussain (committed)
feature: SDK Defaults Config - Support for Session default_bucket_prefix & ability to set an SDK-wide S3 Key Prefix through Session
1 parent a46deac · commit 2e5223d

31 files changed: +468 additions, -140 deletions
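
Taken together, the changes below let an SDK-wide S3 key prefix flow from either a defaults-config file or a Session argument into every default S3 location the SDK builds. A rough usage sketch (not part of the diff; the outer config key names and the non-local Session keyword are inferred from the schema constants and the LocalSession change further down):

import sagemaker

# Assumed shape of a defaults config exercising the new schema entry.
sagemaker_config = {
    "SchemaVersion": "1.0",
    "SageMaker": {
        "PythonSDK": {
            "Modules": {
                "Session": {
                    "S3Bucket": "my-default-bucket",
                    "S3ObjectKeyPrefix": "team-a/experiments",
                }
            }
        }
    },
}

# Option 1: let the config drive the defaults.
session = sagemaker.Session(sagemaker_config=sagemaker_config)

# Option 2: pass the prefix directly (keyword assumed to mirror the
# default_bucket_prefix parameter added to LocalSession in this commit).
session = sagemaker.Session(default_bucket_prefix="team-a/experiments")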

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 10 additions & 3 deletions
@@ -20,7 +20,7 @@
 
 from six.moves.urllib.parse import urlparse
 
-from sagemaker import image_uris
+from sagemaker import image_uris, s3_utils
 from sagemaker.amazon import validation
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.amazon.common import write_numpy_to_dense_tensor
@@ -93,8 +93,15 @@ def __init__(
             enable_network_isolation=enable_network_isolation,
             **kwargs
         )
-        data_location = data_location or "s3://{}/sagemaker-record-sets/".format(
-            self.sagemaker_session.default_bucket()
+
+        data_location = data_location or (
+            s3_utils.s3_path_join(
+                "s3://",
+                self.sagemaker_session.default_bucket(),
+                self.sagemaker_session.default_bucket_prefix,
+                "sagemaker-record-sets",
+            )
+            + "/"
         )
         self._data_location = data_location
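
The join helper has to tolerate a None or empty default_bucket_prefix so the same call works whether or not an SDK-wide prefix is configured. A minimal sketch of the joining behavior this call relies on (illustration only, not the actual sagemaker.s3_utils implementation):

def s3_path_join_sketch(*args):
    """Join URI components with "/", skipping None/empty parts (simplified)."""
    parts = [str(part).strip("/") for part in args if part]
    if parts and parts[0] == "s3:":
        # "s3://" with its slashes stripped becomes "s3:"; restore the scheme.
        return "s3://" + "/".join(parts[1:])
    return "/".join(parts)

# default_bucket_prefix is None  -> s3://my-bucket/sagemaker-record-sets/
print(s3_path_join_sketch("s3://", "my-bucket", None, "sagemaker-record-sets") + "/")
# default_bucket_prefix="team-a" -> s3://my-bucket/team-a/sagemaker-record-sets/
print(s3_path_join_sketch("s3://", "my-bucket", "team-a", "sagemaker-record-sets") + "/")

With no prefix configured, the record-set location therefore stays identical to the old string-format result.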

src/sagemaker/automl/automl.py

Lines changed: 9 additions & 2 deletions
@@ -17,7 +17,7 @@
 from typing import Optional, List, Dict
 from six import string_types
 
-from sagemaker import Model, PipelineModel
+from sagemaker import Model, PipelineModel, s3
 from sagemaker.automl.candidate_estimator import CandidateEstimator
 from sagemaker.config import (
     AUTO_ML_ROLE_ARN_PATH,
@@ -663,7 +663,14 @@ def _prepare_for_auto_ml_job(self, job_name=None):
             self.current_job_name = name_from_base(base_name, max_length=32)
 
         if self.output_path is None:
-            self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket())
+            self.output_path = (
+                s3.s3_path_join(
+                    "s3://",
+                    self.sagemaker_session.default_bucket(),
+                    self.sagemaker_session.default_bucket_prefix,
+                )
+                + "/"
+            )
 
     @classmethod
     def _get_supported_inference_keys(cls, container, default=None):

src/sagemaker/config/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@
     AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
     ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH,
     SESSION_S3_BUCKET_PATH,
+    SESSION_S3_OBJECT_KEY_PREFIX_PATH,
     MONITORING_SCHEDULE_CONFIG,
     MONITORING_JOB_DEFINITION,
     MONITORING_OUTPUT_CONFIG,

src/sagemaker/config/config_schema.py

Lines changed: 23 additions & 2 deletions
@@ -91,6 +91,7 @@
 ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption"
 SESSION = "Session"
 S3_BUCKET = "S3Bucket"
+S3_OBJECT_KEY_PREFIX = "S3ObjectKeyPrefix"
 
 
 def _simple_path(*args: str):
@@ -298,6 +299,10 @@ def _simple_path(*args: str):
     SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
 )
 SESSION_S3_BUCKET_PATH = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, SESSION, S3_BUCKET)
+SESSION_S3_OBJECT_KEY_PREFIX_PATH = _simple_path(
+    SAGEMAKER, PYTHON_SDK, MODULES, SESSION, S3_OBJECT_KEY_PREFIX
+)
+
 
 SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
     "$schema": "https://json-schema.org/draft/2020-12/schema",
@@ -447,7 +452,6 @@ def _simple_path(*args: str):
     "s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
     # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
     "preExecutionCommand": {TYPE: "string", "pattern": r".*"},
-
     # Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html
     # except with an additional ^ and $ for the beginning and the end to closer align to
     # https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
@@ -492,9 +496,26 @@ def _simple_path(*args: str):
                 ADDITIONAL_PROPERTIES: False,
                 PROPERTIES: {
                     S3_BUCKET: {
-                        "description": "Used as `default_bucket` of Session",
+                        "description": "sets `default_bucket` of Session",
                         "$ref": "#/definitions/s3Bucket",
                     },
+                    S3_OBJECT_KEY_PREFIX: {
+                        "description": (
+                            "sets `default_s3_object_key_prefix` of Session"
+                        ),
+                        TYPE: "string",
+                        # Regex based on
+                        # https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html
+                        # For now, the regex only allows the "safe characters"
+                        # specified by S3. If needed, the regex can be loosened
+                        # (but not tightened) in the future without
+                        # introducing backward incompatibility.
+                        "pattern": (
+                            r"^[a-zA-Z0-9!_.*'()-]+(/[a-zA-Z0-9!_.*'()-]+)*$"
+                        ),
+                        "minLength": 0,
+                        "maxLength": 1024,
+                    },
                 },
             },
             REMOTE_FUNCTION: {
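
To make the new schema entry concrete, the check below (illustration only) runs a few candidate prefixes against the pattern added above; segments may only use S3's "safe characters" and must be separated by single slashes, with no leading or trailing slash:

import re

S3_OBJECT_KEY_PREFIX_PATTERN = r"^[a-zA-Z0-9!_.*'()-]+(/[a-zA-Z0-9!_.*'()-]+)*$"

for candidate in [
    "team-a/experiments",  # accepted
    "models_v1.2",         # accepted
    "/leading-slash",      # rejected: empty first segment
    "has space",           # rejected: space is not a "safe character"
]:
    print(candidate, bool(re.fullmatch(S3_OBJECT_KEY_PREFIX_PATTERN, candidate)))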

src/sagemaker/djl_inference/model.py

Lines changed: 4 additions & 2 deletions
@@ -23,7 +23,7 @@
 from typing import Optional, Union, Dict, Any
 
 import sagemaker
-from sagemaker import s3, Predictor, image_uris, fw_utils
+from sagemaker import s3, Predictor, image_uris, fw_utils, s3_utils
 from sagemaker.deserializers import JSONDeserializer, BaseDeserializer
 from sagemaker.djl_inference import defaults
 from sagemaker.model import FrameworkModel
@@ -527,7 +527,9 @@ def prepare_container_def(
         deploy_key_prefix = fw_utils.model_code_key_prefix(
             self.key_prefix, self.name, self.image_uri
         )
-        bucket = self.bucket or self.sagemaker_session.default_bucket()
+        bucket, deploy_key_prefix = s3_utils.calculate_bucket_and_prefix(
+            self.bucket, deploy_key_prefix, self.sagemaker_session
+        )
         uploaded_code = fw_utils.tar_and_upload_dir(
             self.sagemaker_session.boto_session,
             bucket,
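
calculate_bucket_and_prefix replaces the old `bucket = self.bucket or default_bucket()` one-liner wherever a key prefix is also in play. A simplified sketch of the contract the call sites in this commit appear to assume (not the actual sagemaker.s3_utils implementation):

def calculate_bucket_and_prefix_sketch(bucket, key_prefix, sagemaker_session):
    """Return (bucket, key_prefix), falling back to the session defaults (simplified)."""
    if bucket:
        # An explicitly supplied bucket passes through untouched, along with the prefix.
        return bucket, key_prefix

    # Otherwise use the session's default bucket and fold the SDK-wide prefix in front.
    default_prefix = sagemaker_session.default_bucket_prefix
    combined = "/".join(part for part in (default_prefix, key_prefix) if part) or None
    return sagemaker_session.default_bucket(), combined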

src/sagemaker/estimator.py

Lines changed: 38 additions & 8 deletions
@@ -27,7 +27,7 @@
 from six.moves.urllib.parse import urlparse
 
 import sagemaker
-from sagemaker import git_utils, image_uris, vpc_utils
+from sagemaker import git_utils, image_uris, vpc_utils, s3
 from sagemaker.analytics import TrainingJobAnalytics
 from sagemaker.config import (
     TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
@@ -672,6 +672,9 @@ def __init__(
             enable_network_isolation=self._enable_network_isolation,
         )
 
+        # Internal flag
+        self._is_output_path_set_from_default_bucket_and_prefix = False
+
     @abstractmethod
     def training_image_uri(self):
         """Return the Docker image to use for training.
@@ -772,7 +775,12 @@ def _prepare_for_training(self, job_name=None):
            if self.sagemaker_session.local_mode and local_code:
                self.output_path = ""
            else:
-                self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket())
+                self.output_path = s3.s3_path_join(
+                    "s3://",
+                    self.sagemaker_session.default_bucket(),
+                    self.sagemaker_session.default_bucket_prefix,
+                )
+                self._is_output_path_set_from_default_bucket_and_prefix = True
 
         if self.git_config:
             updated_paths = git_utils.git_clone_repo(
@@ -847,7 +855,8 @@ def _stage_user_code_in_s3(self) -> str:
         if is_pipeline_variable(self.output_path):
             if self.code_location is None:
                 code_bucket = self.sagemaker_session.default_bucket()
-                code_s3_prefix = self._assign_s3_prefix()
+                key_prefix = self.sagemaker_session.default_bucket_prefix
+                code_s3_prefix = self._assign_s3_prefix(key_prefix)
                 kms_key = None
             else:
                 code_bucket, key_prefix = parse_s3_url(self.code_location)
@@ -860,16 +869,33 @@ def _stage_user_code_in_s3(self) -> str:
         if local_mode:
             if self.code_location is None:
                 code_bucket = self.sagemaker_session.default_bucket()
-                code_s3_prefix = self._assign_s3_prefix()
+                key_prefix = self.sagemaker_session.default_bucket_prefix
+                code_s3_prefix = self._assign_s3_prefix(key_prefix)
                 kms_key = None
             else:
                 code_bucket, key_prefix = parse_s3_url(self.code_location)
                 code_s3_prefix = self._assign_s3_prefix(key_prefix)
                 kms_key = None
         else:
             if self.code_location is None:
-                code_bucket, _ = parse_s3_url(self.output_path)
-                code_s3_prefix = self._assign_s3_prefix()
+                # TODO: if output_path was set from Session, include prefix. If it was set by
+                # the user, do not set the prefix (which would change behavior of existing
+                # notebooks)
+                code_bucket, possible_key_prefix = parse_s3_url(self.output_path)
+
+                if self._is_output_path_set_from_default_bucket_and_prefix:
+                    # Only include possible_key_prefix if the output_path was created from the
+                    # Session's default bucket and prefix. In that scenario, possible_key_prefix
+                    # will either be "" or Session.default_bucket_prefix.
+                    # Note: We cannot do `if (code_bucket == session.default_bucket() and
+                    # key_prefix == session.default_bucket_prefix)` instead because the user
+                    # could have passed in equivalent values themselves to output_path. And
+                    # including the prefix in that case could result in a potentially backwards
+                    # incompatible behavior change for the end user.
+                    code_s3_prefix = self._assign_s3_prefix(possible_key_prefix)
+                else:
+                    code_s3_prefix = self._assign_s3_prefix()
+
                 kms_key = self.output_kms_key
             else:
                 code_bucket, key_prefix = parse_s3_url(self.code_location)
@@ -1060,8 +1086,12 @@ def _set_source_s3_uri(self, rule):
         if "source_s3_uri" in (rule.rule_parameters or {}):
             parse_result = urlparse(rule.rule_parameters["source_s3_uri"])
             if parse_result.scheme != "s3":
-                desired_s3_uri = os.path.join(
-                    "s3://", self.sagemaker_session.default_bucket(), rule.name, str(uuid.uuid4())
+                desired_s3_uri = s3.s3_path_join(
+                    "s3://",
+                    self.sagemaker_session.default_bucket(),
+                    self.sagemaker_session.default_bucket_prefix,
+                    rule.name,
+                    str(uuid.uuid4()),
                 )
                 s3_uri = S3Uploader.upload(
                     local_path=rule.rule_parameters["source_s3_uri"],
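
The comment block above is the heart of this file's change: the session prefix is folded into the code-staging location only when the SDK itself generated output_path. A condensed illustration with hypothetical values (bucket "my-bucket", prefix "team-a"; the non-local Session keyword is assumed):

import sagemaker
from sagemaker.estimator import Estimator

session = sagemaker.Session(default_bucket_prefix="team-a")

# Case 1: output_path omitted. _prepare_for_training() builds it from the session
# defaults ("s3://my-bucket/team-a"), sets the internal flag, and the staged
# training code is uploaded under the "team-a" prefix as well.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=session,
)

# Case 2: the user passes the equivalent path explicitly. The flag stays False, so
# code staging behaves exactly as it did before this commit and existing notebooks
# keep their S3 layout.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/team-a",
    sagemaker_session=session,
)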

src/sagemaker/experiments/_helper.py

Lines changed: 19 additions & 4 deletions
@@ -19,6 +19,7 @@
 
 import botocore
 
+from sagemaker import s3_utils
 from sagemaker.experiments._utils import is_already_exist_error
 
 logger = logging.getLogger(__name__)
@@ -75,8 +76,15 @@ def upload_artifact(self, file_path):
            raise ValueError(
                "{} does not exist or is not a file. Please supply a file path.".format(file_path)
            )
-        if not self.artifact_bucket:
-            self.artifact_bucket = self.sagemaker_session.default_bucket()
+
+        # If self.artifact_bucket is falsy, it will be set to sagemaker_session.default_bucket.
+        # In that case, and if sagemaker_session.default_bucket_prefix exists, self.artifact_prefix
+        # needs to be updated too (because not updating self.artifact_prefix would result in
+        # different behavior the 1st time this method is called vs the 2nd).
+        self.artifact_bucket, self.artifact_prefix = s3_utils.calculate_bucket_and_prefix(
+            self.artifact_bucket, self.artifact_prefix, self.sagemaker_session
+        )
+
         artifact_name = os.path.basename(file_path)
         artifact_s3_key = "{}/{}/{}".format(
             self.artifact_prefix, self.trial_component_name, artifact_name
@@ -96,8 +104,15 @@ def upload_object_artifact(self, artifact_name, artifact_object, file_extension=
         Returns:
             str: The s3 URI of the uploaded file and the version of the file.
         """
-        if not self.artifact_bucket:
-            self.artifact_bucket = self.sagemaker_session.default_bucket()
+
+        # If self.artifact_bucket is falsy, it will be set to sagemaker_session.default_bucket.
+        # In that case, and if sagemaker_session.default_bucket_prefix exists, self.artifact_prefix
+        # needs to be updated too (because not updating self.artifact_prefix would result in
+        # different behavior the 1st time this method is called vs the 2nd).
+        self.artifact_bucket, self.artifact_prefix = s3_utils.calculate_bucket_and_prefix(
+            self.artifact_bucket, self.artifact_prefix, self.sagemaker_session
+        )
+
         if file_extension:
             artifact_name = (
                 artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension
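
Net effect on the artifact layout, with hypothetical values (the class-level default of artifact_prefix is assumed, since it is not shown in this diff):

# After calculate_bucket_and_prefix() with default_bucket_prefix="team-a",
# artifact_prefix (assumed default "trial-component-artifacts") becomes prefixed:
artifact_prefix = "team-a/trial-component-artifacts"
artifact_s3_key = "{}/{}/{}".format(artifact_prefix, "my-trial-component", "model.txt")
print(artifact_s3_key)  # team-a/trial-component-artifacts/my-trial-component/model.txt

Updating the prefix together with the bucket also keeps the key stable across repeated calls, which is what the comment about the first versus second invocation guards against.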

src/sagemaker/lambda_helper.py

Lines changed: 19 additions & 4 deletions
@@ -17,6 +17,8 @@
 import zipfile
 import time
 from botocore.exceptions import ClientError
+
+from sagemaker import s3, s3_utils
 from sagemaker.session import Session
 
 
@@ -109,12 +111,15 @@ def create(self):
         if self.script is not None:
             code = {"ZipFile": _zip_lambda_code(self.script)}
         else:
-            bucket = self.s3_bucket or self.session.default_bucket()
+            bucket, key_prefix = s3_utils.calculate_bucket_and_prefix(
+                self.s3_bucket, None, self.session
+            )
             key = _upload_to_s3(
                 s3_client=_get_s3_client(self.session),
                 function_name=self.function_name,
                 zipped_code_dir=self.zipped_code_dir,
                 s3_bucket=bucket,
+                s3_key_prefix=key_prefix,
             )
             code = {"S3Bucket": bucket, "S3Key": key}
 
@@ -148,7 +153,10 @@ def update(self):
                 ZipFile=_zip_lambda_code(self.script),
             )
         else:
-            bucket = self.s3_bucket or self.session.default_bucket()
+            bucket, key_prefix = s3_utils.calculate_bucket_and_prefix(
+                self.s3_bucket, None, self.session
+            )
+
             # get function name to be used in S3 upload path
             if self.function_arn:
                 versioned_function_name = self.function_arn.split("funtion:")[-1]
@@ -167,6 +175,7 @@ def update(self):
                     function_name=function_name_for_s3,
                     zipped_code_dir=self.zipped_code_dir,
                     s3_bucket=bucket,
+                    s3_key_prefix=key_prefix,
                 ),
             )
         return response
@@ -255,7 +264,7 @@ def _get_lambda_client(session):
     return lambda_client
 
 
-def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket):
+def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_prefix=None):
     """Upload the zipped code to S3 bucket provided in the Lambda instance.
 
     Lambda instance must have a path to the zipped code folder and a S3 bucket to upload
@@ -264,7 +273,13 @@ def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket):
 
     Returns: the S3 key where the code is uploaded.
     """
-    key = "{}/{}/{}".format("lambda", function_name, "code")
+
+    key = s3.s3_path_join(
+        s3_key_prefix,
+        "lambda",
+        function_name,
+        "code",
+    )
     s3_client.upload_file(zipped_code_dir, s3_bucket, key)
     return key
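
Because s3_path_join drops empty parts (which the call above relies on when key_prefix is None), _upload_to_s3 keeps its old key layout whenever no prefix is in effect. A quick illustration with a hypothetical function name:

from sagemaker import s3

# No SDK-wide prefix configured: s3_key_prefix is None and is simply skipped.
print(s3.s3_path_join(None, "lambda", "my-function", "code"))
# expected -> lambda/my-function/code  (same key as before this commit)

# With default_bucket_prefix="team-a" flowing through create()/update():
print(s3.s3_path_join("team-a", "lambda", "my-function", "code"))
# expected -> team-a/lambda/my-function/code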

src/sagemaker/local/local_session.py

Lines changed: 13 additions & 0 deletions
@@ -25,6 +25,7 @@
     load_sagemaker_config,
     validate_sagemaker_config,
     SESSION_S3_BUCKET_PATH,
+    SESSION_S3_OBJECT_KEY_PREFIX_PATH,
 )
 from sagemaker.local.image import _SageMakerContainer
 from sagemaker.local.utils import get_docker_host
@@ -610,6 +611,7 @@
         s3_endpoint_url=None,
         disable_local_code=False,
         sagemaker_config: dict = None,
+        default_bucket_prefix=None,
     ):
         """Create a Local SageMaker Session.
 
@@ -632,6 +634,10 @@
                 this dictionary can be generated by calling
                 :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the
                 Session.
+            default_bucket_prefix (str): The default prefix to use for S3 Object Keys. When
+                objects are saved to the Session's default_bucket, the Object Key used will
+                start with the default_bucket_prefix. If not provided here or within
+                sagemaker_config, no additional prefix will be added.
         """
         self.s3_endpoint_url = s3_endpoint_url
         # We use this local variable to avoid disrupting the __init__->_initialize API of the
@@ -643,6 +649,7 @@
             boto_session=boto_session,
             default_bucket=default_bucket,
             sagemaker_config=sagemaker_config,
+            default_bucket_prefix=default_bucket_prefix,
         )
 
         if platform.system() == "Windows":
@@ -710,6 +717,12 @@
             config_path=SESSION_S3_BUCKET_PATH,
             sagemaker_session=self,
         )
+        # after sagemaker_config initialization, update self.default_bucket_prefix if needed
+        self.default_bucket_prefix = resolve_value_from_config(
+            direct_input=self.default_bucket_prefix,
+            config_path=SESSION_S3_OBJECT_KEY_PREFIX_PATH,
+            sagemaker_session=self,
+        )
 
         local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
         if os.path.exists(local_mode_config_file):
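
Finally, a brief usage sketch of the new LocalSession parameter and its resolution order; because the constructor value is passed as direct_input to resolve_value_from_config, it takes precedence over the config file:

from sagemaker.local import LocalSession

# Explicit argument: used as-is, overriding any S3ObjectKeyPrefix in the defaults config.
local_session = LocalSession(default_bucket_prefix="team-a/local-runs")

# No argument: the prefix is resolved from the defaults config via
# SESSION_S3_OBJECT_KEY_PREFIX_PATH if present; otherwise it stays None and no
# extra prefix is added to S3 object keys.
local_session = LocalSession()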
