Skip to content

Commit a46deac

Browse files
author
Ruban Hussain
committed
feature: SDK Defaults Config - Support for Session default_bucket
1 parent 9317e08 commit a46deac

File tree

10 files changed

+188
-15
lines changed

10 files changed

+188
-15
lines changed

src/sagemaker/config/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
AUTO_ML_VOLUME_KMS_KEY_ID_PATH,
9494
AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
9595
ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH,
96+
SESSION_S3_BUCKET_PATH,
9697
MONITORING_SCHEDULE_CONFIG,
9798
MONITORING_JOB_DEFINITION,
9899
MONITORING_OUTPUT_CONFIG,
@@ -131,4 +132,8 @@
131132
EXECUTION_ROLE_ARN,
132133
ASYNC_INFERENCE_CONFIG,
133134
SCHEMA_VERSION,
135+
PYTHON_SDK,
136+
MODULES,
137+
S3_BUCKET,
138+
SESSION,
134139
)

src/sagemaker/config/config_schema.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,16 @@
8989
OBJECT = "object"
9090
ADDITIONAL_PROPERTIES = "additionalProperties"
9191
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption"
92+
SESSION = "Session"
93+
S3_BUCKET = "S3Bucket"
9294

9395

9496
def _simple_path(*args: str):
9597
"""Joins an arbitrary number of strings with "." to form a config path constant"""
9698
return ".".join(args)
9799

98100

101+
# Paths for reference elsewhere in the SDK.
99102
COMPILATION_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, VPC_CONFIG)
100103
COMPILATION_JOB_KMS_KEY_ID_PATH = _simple_path(
101104
SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG, KMS_KEY_ID
@@ -231,7 +234,6 @@ def _simple_path(*args: str):
231234
MODEL_PACKAGE_VALIDATION_PROFILES_PATH = _simple_path(
232235
SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES
233236
)
234-
235237
REMOTE_FUNCTION_DEPENDENCIES = _simple_path(
236238
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES
237239
)
@@ -274,9 +276,6 @@ def _simple_path(*args: str):
274276
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
275277
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
276278
)
277-
278-
# Paths for reference elsewhere in the SDK.
279-
# Names include the schema version since the paths could change with other schema versions
280279
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
281280
SAGEMAKER,
282281
MONITORING_SCHEDULE,
@@ -298,6 +297,7 @@ def _simple_path(*args: str):
298297
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
299298
SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
300299
)
300+
SESSION_S3_BUCKET_PATH = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, SESSION, S3_BUCKET)
301301

302302
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
303303
"$schema": "https://json-schema.org/draft/2020-12/schema",
@@ -447,6 +447,16 @@ def _simple_path(*args: str):
447447
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
448448
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
449449
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
450+
451+
# Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html
452+
# except with an additional ^ and $ anchoring the beginning and the end, to align more closely with
453+
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
454+
"s3Bucket": {
455+
TYPE: "string",
456+
"pattern": r"^[a-z0-9][\.\-a-z0-9]{1,61}[a-z0-9]$",
457+
"minLength": 3,
458+
"maxLength": 63,
459+
},
450460
},
451461
PROPERTIES: {
452462
SCHEMA_VERSION: {
@@ -477,6 +487,16 @@ def _simple_path(*args: str):
477487
TYPE: OBJECT,
478488
ADDITIONAL_PROPERTIES: False,
479489
PROPERTIES: {
490+
SESSION: {
491+
TYPE: OBJECT,
492+
ADDITIONAL_PROPERTIES: False,
493+
PROPERTIES: {
494+
S3_BUCKET: {
495+
"description": "Used as `default_bucket` of Session",
496+
"$ref": "#/definitions/s3Bucket",
497+
},
498+
},
499+
},
480500
REMOTE_FUNCTION: {
481501
TYPE: OBJECT,
482502
ADDITIONAL_PROPERTIES: False,
@@ -504,9 +524,9 @@ def _simple_path(*args: str):
504524
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
505525
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
506526
},
507-
}
527+
},
508528
},
509-
}
529+
},
510530
},
511531
},
512532
# Feature Group

src/sagemaker/local/local_session.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
import boto3
2222
from botocore.exceptions import ClientError
2323

24-
from sagemaker.config import load_sagemaker_config, validate_sagemaker_config
24+
from sagemaker.config import (
25+
load_sagemaker_config,
26+
validate_sagemaker_config,
27+
SESSION_S3_BUCKET_PATH,
28+
)
2529
from sagemaker.local.image import _SageMakerContainer
2630
from sagemaker.local.utils import get_docker_host
2731
from sagemaker.local.entities import (
@@ -34,7 +38,7 @@
3438
_LocalPipeline,
3539
)
3640
from sagemaker.session import Session
37-
from sagemaker.utils import get_config_value, _module_import_error
41+
from sagemaker.utils import get_config_value, _module_import_error, resolve_value_from_config
3842

3943
logger = logging.getLogger(__name__)
4044

@@ -700,15 +704,22 @@ def _initialize(
700704
# create a default S3 resource, but only if it needs to fetch from S3
701705
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
702706

703-
sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
704-
if os.path.exists(sagemaker_config_file):
707+
# after sagemaker_config initialization, update self._default_bucket_name_override if needed
708+
self._default_bucket_name_override = resolve_value_from_config(
709+
direct_input=self._default_bucket_name_override,
710+
config_path=SESSION_S3_BUCKET_PATH,
711+
sagemaker_session=self,
712+
)
713+
714+
local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
715+
if os.path.exists(local_mode_config_file):
705716
try:
706717
import yaml
707718
except ImportError as e:
708719
logger.error(_module_import_error("yaml", "Local mode", "local"))
709720
raise e
710721

711-
self.config = yaml.safe_load(open(sagemaker_config_file, "r"))
722+
self.config = yaml.safe_load(open(local_mode_config_file, "r"))
712723
if self._disable_local_code and "local" in self.config:
713724
self.config["local"]["local_code"] = False
714725

src/sagemaker/session.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
FEATURE_GROUP_ROLE_ARN_PATH,
9393
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH,
9494
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH,
95+
SESSION_S3_BUCKET_PATH,
9596
)
9697
from sagemaker.deprecations import deprecated_class
9798
from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig
@@ -180,7 +181,8 @@ def __init__(
180181
default_bucket (str): The default Amazon S3 bucket to be used by this session.
181182
This will be created the next time an Amazon S3 bucket is needed (by calling
182183
:func:`default_bucket`).
183-
If not provided, a default bucket will be created based on the following format:
184+
If not provided, it will be fetched from the sagemaker_config. If not configured
185+
there either, a default bucket will be created based on the following format:
184186
"sagemaker-{region}-{aws-account-id}".
185187
Example: "sagemaker-my-custom-bucket".
186188
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
@@ -200,8 +202,13 @@ def __init__(
200202
:func:`~sagemaker.config.load_sagemaker_config` and then be provided to the
201203
Session.
202204
"""
205+
206+
# sagemaker_config is validated and initialized inside :func:`_initialize`,
207+
# so if default_bucket is None and the sagemaker_config has a default S3 bucket configured,
208+
# _default_bucket_name_override will be set again inside :func:`_initialize`.
203209
self._default_bucket = None
204210
self._default_bucket_name_override = default_bucket
211+
205212
self.s3_resource = None
206213
self.s3_client = None
207214
self.resource_groups_client = None
@@ -280,6 +287,13 @@ def _initialize(
280287
# create a default S3 resource, but only if it needs to fetch from S3
281288
self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource)
282289

290+
# after sagemaker_config initialization, update self._default_bucket_name_override if needed
291+
self._default_bucket_name_override = resolve_value_from_config(
292+
direct_input=self._default_bucket_name_override,
293+
config_path=SESSION_S3_BUCKET_PATH,
294+
sagemaker_session=self,
295+
)
296+
283297
@property
284298
def boto_region_name(self):
285299
"""Placeholder docstring"""
@@ -484,7 +498,8 @@ def default_bucket(self):
484498
This function will create the s3 bucket if it does not exist.
485499
486500
Returns:
487-
str: The name of the default bucket, which is of the form:
501+
str: The name of the default bucket. If the name was not explicitly specified through
502+
the Session or sagemaker_config, the bucket will take the form:
488503
``sagemaker-{region}-{AWS account ID}``.
489504
"""
490505

tests/data/config/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
SchemaVersion: '1.0'
22
SageMaker:
3+
PythonSDK:
4+
Modules:
5+
Session:
6+
S3Bucket: 'sagemaker-python-sdk-test-bucket'
37
FeatureGroup:
48
OnlineStoreConfig:
59
SecurityConfig:

tests/integ/test_local_mode.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import stopit
2424

2525
import tests.integ.lock as lock
26+
from sagemaker.config import SESSION_S3_BUCKET_PATH
27+
from sagemaker.utils import resolve_value_from_config
2628
from tests.integ import DATA_DIR
2729
from mock import Mock, ANY
2830

@@ -70,6 +72,13 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
7072

7173
self.sagemaker_config = kwargs.get("sagemaker_config", None)
7274

75+
# after sagemaker_config initialization, update self._default_bucket_name_override if needed
76+
self._default_bucket_name_override = resolve_value_from_config(
77+
direct_input=self._default_bucket_name_override,
78+
config_path=SESSION_S3_BUCKET_PATH,
79+
sagemaker_session=self,
80+
)
81+
7382

7483
class LocalPipelineNoS3Session(LocalPipelineSession):
7584
"""
@@ -91,6 +100,13 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
91100

92101
self.sagemaker_config = kwargs.get("sagemaker_config", None)
93102

103+
# after sagemaker_config initialization, update self._default_bucket_name_override if needed
104+
self._default_bucket_name_override = resolve_value_from_config(
105+
direct_input=self._default_bucket_name_override,
106+
config_path=SESSION_S3_BUCKET_PATH,
107+
sagemaker_session=self,
108+
)
109+
94110

95111
@pytest.fixture(scope="module")
96112
def sagemaker_local_session_no_local_code(boto_session):

tests/unit/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,27 @@
7070
MODEL,
7171
ASYNC_INFERENCE_CONFIG,
7272
SCHEMA_VERSION,
73+
PYTHON_SDK,
74+
MODULES,
75+
S3_BUCKET,
76+
SESSION,
7377
)
7478

7579
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
7680
PY_VERSION = "py3"
7781

82+
SAGEMAKER_CONFIG_SESSION = {
83+
SCHEMA_VERSION: "1.0",
84+
SAGEMAKER: {
85+
PYTHON_SDK: {
86+
MODULES: {
87+
SESSION: {
88+
S3_BUCKET: "sagemaker-config-session-s3-bucket",
89+
},
90+
},
91+
},
92+
},
93+
}
7894

7995
SAGEMAKER_CONFIG_MONITORING_SCHEDULE = {
8096
SCHEMA_VERSION: "1.0",

tests/unit/sagemaker/config/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ def valid_tags():
3737
return [{"Key": "tag1", "Value": "tagValue1"}]
3838

3939

40+
@pytest.fixture()
41+
def valid_session_config():
42+
return {
43+
"S3Bucket": "sagemaker-python-sdk-test-bucket",
44+
}
45+
46+
4047
@pytest.fixture()
4148
def valid_feature_group_config(valid_iam_role_arn):
4249
security_storage_config = {"KmsKeyId": "kmskeyid1"}
@@ -191,6 +198,7 @@ def valid_remote_function_config(valid_iam_role_arn, valid_tags, valid_vpc_confi
191198

192199
@pytest.fixture()
193200
def valid_config_with_all_the_scopes(
201+
valid_session_config,
194202
valid_feature_group_config,
195203
valid_monitoring_schedule_config,
196204
valid_endpointconfig_config,
@@ -206,6 +214,11 @@ def valid_config_with_all_the_scopes(
206214
valid_remote_function_config,
207215
):
208216
return {
217+
"PythonSDK": {
218+
"Modules": {
219+
"Session": valid_session_config,
220+
}
221+
},
209222
"FeatureGroup": valid_feature_group_config,
210223
"MonitoringSchedule": valid_monitoring_schedule_config,
211224
"EndpointConfig": valid_endpointconfig_config,

tests/unit/sagemaker/config/test_config_schema.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,37 @@ def test_invalid_s3uri_schema(base_config_with_schema):
199199
config["SageMaker"] = {"PythonSDK": {"Modules": {"RemoteFunction": {"S3RootUri": "bad_regex"}}}}
200200
with pytest.raises(exceptions.ValidationError):
201201
validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
202+
203+
@pytest.mark.parametrize(
204+
"bucket_name",
205+
[
206+
"docexamplebucket1",
207+
"log-delivery-march-2020",
208+
"my-hosted-content",
209+
"docexamplewebsite.com",
210+
"www.docexamplewebsite.com",
211+
"my.example.s3.bucket",
212+
],
213+
)
214+
def test_session_s3_bucket_schema(base_config_with_schema, bucket_name):
215+
config = {"PythonSDK": {"Modules": {"Session": {"S3Bucket": bucket_name}}}}
216+
_validate_config(base_config_with_schema, config)
217+
218+
219+
@pytest.mark.parametrize(
220+
"invalid_bucket_name",
221+
[
222+
"ab",
223+
"this-is-sixty-four-characters-total-which-is-one-above-the-limit",
224+
"UPPERCASE-LETTERS",
225+
"special_characters",
226+
"special-characters@",
227+
".dot-at-the-beginning",
228+
"-dash-at-the-beginning",
229+
"dot-at-the-end.",
230+
"dash-at-the-end-",
231+
],
232+
)
233+
def test_invalid_session_s3_bucket_schema(base_config_with_schema, invalid_bucket_name):
234+
with pytest.raises(exceptions.ValidationError):
235+
test_session_s3_bucket_schema(base_config_with_schema, invalid_bucket_name)

0 commit comments

Comments
 (0)