diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 6a6177fc81..b4fc9d1f6d 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -20,7 +20,7 @@ from six.moves.urllib.parse import urlparse -from sagemaker import image_uris +from sagemaker import image_uris, s3_utils from sagemaker.amazon import validation from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.common import write_numpy_to_dense_tensor @@ -93,8 +93,15 @@ def __init__( enable_network_isolation=enable_network_isolation, **kwargs ) - data_location = data_location or "s3://{}/sagemaker-record-sets/".format( - self.sagemaker_session.default_bucket() + + data_location = data_location or ( + s3_utils.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + "sagemaker-record-sets", + with_end_slash=True, + ) ) self._data_location = data_location diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index e2189a2083..334c1d5c88 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -17,7 +17,7 @@ from typing import Optional, List, Dict from six import string_types -from sagemaker import Model, PipelineModel +from sagemaker import Model, PipelineModel, s3 from sagemaker.automl.candidate_estimator import CandidateEstimator from sagemaker.config import ( AUTO_ML_ROLE_ARN_PATH, @@ -676,7 +676,12 @@ def _prepare_for_auto_ml_job(self, job_name=None): self.current_job_name = name_from_base(base_name, max_length=32) if self.output_path is None: - self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket()) + self.output_path = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + with_end_slash=True, + ) @classmethod def _get_supported_inference_keys(cls, container, default=None): diff --git a/src/sagemaker/config/__init__.py b/src/sagemaker/config/__init__.py index a7c685fa08..ecd0d73cab 100644 --- a/src/sagemaker/config/__init__.py +++ b/src/sagemaker/config/__init__.py @@ -93,6 +93,8 @@ AUTO_ML_VOLUME_KMS_KEY_ID_PATH, AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH, + SESSION_DEFAULT_S3_BUCKET_PATH, + SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION, MONITORING_OUTPUT_CONFIG, @@ -131,4 +133,9 @@ EXECUTION_ROLE_ARN, ASYNC_INFERENCE_CONFIG, SCHEMA_VERSION, + PYTHON_SDK, + MODULES, + DEFAULT_S3_BUCKET, + DEFAULT_S3_OBJECT_KEY_PREFIX, + SESSION, ) diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 033742603a..5d36c1b076 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -89,6 +89,9 @@ OBJECT = "object" ADDITIONAL_PROPERTIES = "additionalProperties" ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = "EnableInterContainerTrafficEncryption" +SESSION = "Session" +DEFAULT_S3_BUCKET = "DefaultS3Bucket" +DEFAULT_S3_OBJECT_KEY_PREFIX = "DefaultS3ObjectKeyPrefix" def _simple_path(*args: str): @@ -96,6 +99,7 @@ def _simple_path(*args: str): return ".".join(args) +# Paths for reference elsewhere in the code. 
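+# For example (illustrative, using the key constants defined above):
+# _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, SESSION, DEFAULT_S3_BUCKET)
+# returns "SageMaker.PythonSDK.Modules.Session.DefaultS3Bucket".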
COMPILATION_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, COMPILATION_JOB, VPC_CONFIG) COMPILATION_JOB_KMS_KEY_ID_PATH = _simple_path( SAGEMAKER, COMPILATION_JOB, OUTPUT_CONFIG, KMS_KEY_ID @@ -231,7 +235,6 @@ def _simple_path(*args: str): MODEL_PACKAGE_VALIDATION_PROFILES_PATH = _simple_path( SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES ) - REMOTE_FUNCTION_DEPENDENCIES = _simple_path( SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES ) @@ -274,9 +277,6 @@ def _simple_path(*args: str): REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path( SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION ) - -# Paths for reference elsewhere in the SDK. -# Names include the schema version since the paths could change with other schema versions MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( SAGEMAKER, MONITORING_SCHEDULE, @@ -298,6 +298,13 @@ def _simple_path(*args: str): TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path( SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION ) +SESSION_DEFAULT_S3_BUCKET_PATH = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, SESSION, DEFAULT_S3_BUCKET +) +SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH = _simple_path( + SAGEMAKER, PYTHON_SDK, MODULES, SESSION, DEFAULT_S3_OBJECT_KEY_PREFIX +) + SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = { "$schema": "https://json-schema.org/draft/2020-12/schema", @@ -447,6 +454,15 @@ def _simple_path(*args: str): "s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024}, # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint "preExecutionCommand": {TYPE: "string", "pattern": r".*"}, + # Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html + # except with an additional ^ and $ for the beginning and the end to closer align to + # https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html + "s3Bucket": { + TYPE: "string", + "pattern": r"^[a-z0-9][\.\-a-z0-9]{1,61}[a-z0-9]$", + "minLength": 3, + "maxLength": 63, + }, }, PROPERTIES: { SCHEMA_VERSION: { @@ -477,6 +493,29 @@ def _simple_path(*args: str): TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, PROPERTIES: { + SESSION: { + TYPE: OBJECT, + ADDITIONAL_PROPERTIES: False, + PROPERTIES: { + DEFAULT_S3_BUCKET: { + "description": "sets `default_bucket` of Session", + "$ref": "#/definitions/s3Bucket", + }, + DEFAULT_S3_OBJECT_KEY_PREFIX: { + "description": ( + "sets `default_bucket_prefix` of Session" + ), + TYPE: "string", + # S3 guidelines: + # https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html + # Note that the PythonSDK at the time of writing + # tends to collapse multiple "/" in a row to one "/" + # (even though S3 allows multiple "/" in a row) + "minLength": 1, + "maxLength": 1024, + }, + }, + }, REMOTE_FUNCTION: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -504,9 +543,9 @@ def _simple_path(*args: str): VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}, VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"}, }, - } + }, }, - } + }, }, }, # Feature Group diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 632681ffbb..a181ce45e9 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -27,6 +27,7 @@ from sagemaker.deserializers import JSONDeserializer, 
BaseDeserializer from sagemaker.djl_inference import defaults from sagemaker.model import FrameworkModel +from sagemaker.s3_utils import s3_path_join from sagemaker.serializers import JSONSerializer, BaseSerializer from sagemaker.session import Session from sagemaker.utils import _tmpdir, _create_or_update_code_dir @@ -502,12 +503,16 @@ def partition( self.key_prefix, self.name, self.image_uri ) if s3_output_uri is None: - bucket = self.bucket or self.sagemaker_session.default_bucket() - s3_output_uri = f"s3://{bucket}/{deploy_key_prefix}" + bucket, deploy_key_prefix = s3.determine_bucket_and_prefix( + bucket=self.bucket, + key_prefix=deploy_key_prefix, + sagemaker_session=self.sagemaker_session, + ) + s3_output_uri = s3_path_join("s3://", bucket, deploy_key_prefix) else: - s3_output_uri = f"{s3_output_uri}/{deploy_key_prefix}" + s3_output_uri = s3_path_join(s3_output_uri, deploy_key_prefix) - self.save_mp_checkpoint_path = f"{s3_output_uri}/aot-partitioned-checkpoints" + self.save_mp_checkpoint_path = s3_path_join(s3_output_uri, "aot-partitioned-checkpoints") container_def = self._upload_model_to_s3(upload_as_tar=False) estimator = _create_estimator( @@ -673,7 +678,11 @@ def _upload_model_to_s3(self, upload_as_tar: bool = True): deploy_key_prefix = fw_utils.model_code_key_prefix( self.key_prefix, self.name, self.image_uri ) - bucket = self.bucket or self.sagemaker_session.default_bucket() + bucket, deploy_key_prefix = s3.determine_bucket_and_prefix( + bucket=self.bucket, + key_prefix=deploy_key_prefix, + sagemaker_session=self.sagemaker_session, + ) if upload_as_tar: uploaded_code = fw_utils.tar_and_upload_dir( self.sagemaker_session.boto_session, @@ -686,10 +695,9 @@ def _upload_model_to_s3(self, upload_as_tar: bool = True): ) model_data_url = uploaded_code.s3_prefix else: - key_prefix = f"{deploy_key_prefix}/aot-model" model_data_url = S3Uploader.upload( tmp_code_dir, - "s3://%s/%s" % (bucket, key_prefix), + s3_path_join("s3://", bucket, deploy_key_prefix, "aot-model"), self.model_kms_key, self.sagemaker_session, ) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 4f91b73972..f01935e6ca 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -27,7 +27,7 @@ from six.moves.urllib.parse import urlparse import sagemaker -from sagemaker import git_utils, image_uris, vpc_utils +from sagemaker import git_utils, image_uris, vpc_utils, s3 from sagemaker.analytics import TrainingJobAnalytics from sagemaker.config import ( TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH, @@ -672,6 +672,9 @@ def __init__( enable_network_isolation=self._enable_network_isolation, ) + # Internal flag + self._is_output_path_set_from_default_bucket_and_prefix = False + @abstractmethod def training_image_uri(self): """Return the Docker image to use for training. 
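Note: the hunk below replaces the old "s3://{bucket}/" format string with
s3_path_join. A quick sketch of the resulting default output path, assuming a
hypothetical bucket "sagemaker-us-west-2-111122223333" and a hypothetical
default prefix "team-a/project-1":

    from sagemaker.s3_utils import s3_path_join

    # with_end_slash=True keeps the trailing "/" that the old format string produced
    s3_path_join(
        "s3://", "sagemaker-us-west-2-111122223333", "team-a/project-1", with_end_slash=True
    )
    # returns "s3://sagemaker-us-west-2-111122223333/team-a/project-1/"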
@@ -772,7 +775,13 @@ def _prepare_for_training(self, job_name=None): if self.sagemaker_session.local_mode and local_code: self.output_path = "" else: - self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket()) + self.output_path = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + with_end_slash=True, + ) + self._is_output_path_set_from_default_bucket_and_prefix = True if self.git_config: updated_paths = git_utils.git_clone_repo( @@ -847,7 +856,8 @@ def _stage_user_code_in_s3(self) -> UploadedCode: if is_pipeline_variable(self.output_path): if self.code_location is None: code_bucket = self.sagemaker_session.default_bucket() - code_s3_prefix = self._assign_s3_prefix() + key_prefix = self.sagemaker_session.default_bucket_prefix + code_s3_prefix = self._assign_s3_prefix(key_prefix) kms_key = None else: code_bucket, key_prefix = parse_s3_url(self.code_location) @@ -860,7 +870,8 @@ def _stage_user_code_in_s3(self) -> UploadedCode: if local_mode: if self.code_location is None: code_bucket = self.sagemaker_session.default_bucket() - code_s3_prefix = self._assign_s3_prefix() + key_prefix = self.sagemaker_session.default_bucket_prefix + code_s3_prefix = self._assign_s3_prefix(key_prefix) kms_key = None else: code_bucket, key_prefix = parse_s3_url(self.code_location) @@ -868,8 +879,21 @@ def _stage_user_code_in_s3(self) -> UploadedCode: kms_key = None else: if self.code_location is None: - code_bucket, _ = parse_s3_url(self.output_path) - code_s3_prefix = self._assign_s3_prefix() + code_bucket, possible_key_prefix = parse_s3_url(self.output_path) + + if self._is_output_path_set_from_default_bucket_and_prefix: + # Only include possible_key_prefix if the output_path was created from the + # Session's default bucket and prefix. In that scenario, possible_key_prefix + # will either be "" or Session.default_bucket_prefix. + # Note: We cannot do `if (code_bucket == session.default_bucket() and + # key_prefix == session.default_bucket_prefix)` instead because the user + # could have passed in equivalent values themselves to output_path. And + # including the prefix in that case could result in a potentially backwards + # incompatible behavior change for the end user. 
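+                    # Illustrative example (hypothetical values): if output_path was
+                    # defaulted to "s3://my-default-bucket/my-default-prefix/", then
+                    # possible_key_prefix is "my-default-prefix", and the staged code
+                    # ends up under "my-default-prefix/<job_name>/source".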
+ code_s3_prefix = self._assign_s3_prefix(possible_key_prefix) + else: + code_s3_prefix = self._assign_s3_prefix() + kms_key = self.output_kms_key else: code_bucket, key_prefix = parse_s3_url(self.code_location) @@ -905,18 +929,13 @@ def _assign_s3_prefix(self, key_prefix=""): """ from sagemaker.workflow.utilities import _pipeline_config - code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) + code_s3_prefix = s3.s3_path_join(key_prefix, self._current_job_name, "source") if _pipeline_config and _pipeline_config.code_hash: - code_s3_prefix = "/".join( - filter( - None, - [ - key_prefix, - _pipeline_config.pipeline_name, - "code", - _pipeline_config.code_hash, - ], - ) + code_s3_prefix = s3.s3_path_join( + key_prefix, + _pipeline_config.pipeline_name, + "code", + _pipeline_config.code_hash, ) return code_s3_prefix @@ -1060,8 +1079,12 @@ def _set_source_s3_uri(self, rule): if "source_s3_uri" in (rule.rule_parameters or {}): parse_result = urlparse(rule.rule_parameters["source_s3_uri"]) if parse_result.scheme != "s3": - desired_s3_uri = os.path.join( - "s3://", self.sagemaker_session.default_bucket(), rule.name, str(uuid.uuid4()) + desired_s3_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + rule.name, + str(uuid.uuid4()), ) s3_uri = S3Uploader.upload( local_path=rule.rule_parameters["source_s3_uri"], diff --git a/src/sagemaker/experiments/_helper.py b/src/sagemaker/experiments/_helper.py index 0c689b1125..c87bc66e42 100644 --- a/src/sagemaker/experiments/_helper.py +++ b/src/sagemaker/experiments/_helper.py @@ -19,6 +19,7 @@ import botocore +from sagemaker import s3 from sagemaker.experiments._utils import is_already_exist_error logger = logging.getLogger(__name__) @@ -75,8 +76,17 @@ def upload_artifact(self, file_path): raise ValueError( "{} does not exist or is not a file. Please supply a file path.".format(file_path) ) - if not self.artifact_bucket: - self.artifact_bucket = self.sagemaker_session.default_bucket() + + # If self.artifact_bucket is falsy, it will be set to sagemaker_session.default_bucket. + # In that case, and if sagemaker_session.default_bucket_prefix exists, self.artifact_prefix + # needs to be updated too (because not updating self.artifact_prefix would result in + # different behavior the 1st time this method is called vs the 2nd). + self.artifact_bucket, self.artifact_prefix = s3.determine_bucket_and_prefix( + bucket=self.artifact_bucket, + key_prefix=self.artifact_prefix, + sagemaker_session=self.sagemaker_session, + ) + artifact_name = os.path.basename(file_path) artifact_s3_key = "{}/{}/{}".format( self.artifact_prefix, self.trial_component_name, artifact_name @@ -96,8 +106,17 @@ def upload_object_artifact(self, artifact_name, artifact_object, file_extension= Returns: str: The s3 URI of the uploaded file and the version of the file. """ - if not self.artifact_bucket: - self.artifact_bucket = self.sagemaker_session.default_bucket() + + # If self.artifact_bucket is falsy, it will be set to sagemaker_session.default_bucket. + # In that case, and if sagemaker_session.default_bucket_prefix exists, self.artifact_prefix + # needs to be updated too (because not updating self.artifact_prefix would result in + # different behavior the 1st time this method is called vs the 2nd). 
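+        # For example (hypothetical values): with artifact_bucket=None,
+        # artifact_prefix="trial-component-artifacts", and a session whose
+        # default_bucket_prefix="team-a", this resolves to
+        # (session.default_bucket(), "team-a/trial-component-artifacts").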
+ self.artifact_bucket, self.artifact_prefix = s3.determine_bucket_and_prefix( + bucket=self.artifact_bucket, + key_prefix=self.artifact_prefix, + sagemaker_session=self.sagemaker_session, + ) + if file_extension: artifact_name = ( artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index c23f9016d8..6243d91592 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -25,6 +25,7 @@ from packaging import version import sagemaker.image_uris +from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings import sagemaker.utils from sagemaker.workflow import is_pipeline_variable @@ -598,7 +599,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image): name_from_image = f"/model_code/{int(time.time())}" if not is_pipeline_variable(image): name_from_image = sagemaker.utils.name_from_image(image) - return "/".join(filter(None, [code_location_key_prefix, model_name or name_from_image])) + return s3_path_join(code_location_key_prefix, model_name or name_from_image) def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution): diff --git a/src/sagemaker/lambda_helper.py b/src/sagemaker/lambda_helper.py index d4912d3e8a..844ff83ed0 100644 --- a/src/sagemaker/lambda_helper.py +++ b/src/sagemaker/lambda_helper.py @@ -17,6 +17,8 @@ import zipfile import time from botocore.exceptions import ClientError + +from sagemaker import s3 from sagemaker.session import Session @@ -118,12 +120,15 @@ def create(self): if self.script is not None: code = {"ZipFile": _zip_lambda_code(self.script)} else: - bucket = self.s3_bucket or self.session.default_bucket() + bucket, key_prefix = s3.determine_bucket_and_prefix( + bucket=self.s3_bucket, key_prefix=None, sagemaker_session=self.session + ) key = _upload_to_s3( s3_client=_get_s3_client(self.session), function_name=self.function_name, zipped_code_dir=self.zipped_code_dir, s3_bucket=bucket, + s3_key_prefix=key_prefix, ) code = {"S3Bucket": bucket, "S3Key": key} @@ -160,7 +165,10 @@ def update(self): ZipFile=_zip_lambda_code(self.script), ) else: - bucket = self.s3_bucket or self.session.default_bucket() + bucket, key_prefix = s3.determine_bucket_and_prefix( + bucket=self.s3_bucket, key_prefix=None, sagemaker_session=self.session + ) + # get function name to be used in S3 upload path if self.function_arn: versioned_function_name = self.function_arn.split("funtion:")[-1] @@ -179,6 +187,7 @@ def update(self): function_name=function_name_for_s3, zipped_code_dir=self.zipped_code_dir, s3_bucket=bucket, + s3_key_prefix=key_prefix, ), ) return response @@ -267,7 +276,7 @@ def _get_lambda_client(session): return lambda_client -def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket): +def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_prefix=None): """Upload the zipped code to S3 bucket provided in the Lambda instance. Lambda instance must have a path to the zipped code folder and a S3 bucket to upload @@ -276,7 +285,13 @@ def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket): Returns: the S3 key where the code is uploaded. 
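     For example (illustrative values): with s3_key_prefix="my-prefix" and
     function_name="my-fn", the key becomes "my-prefix/lambda/my-fn/code".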
""" - key = "{}/{}/{}".format("lambda", function_name, "code") + + key = s3.s3_path_join( + s3_key_prefix, + "lambda", + function_name, + "code", + ) s3_client.upload_file(zipped_code_dir, s3_bucket, key) return key diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index c94f695e0d..3dbb0e0464 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -21,7 +21,12 @@ import boto3 from botocore.exceptions import ClientError -from sagemaker.config import load_sagemaker_config, validate_sagemaker_config +from sagemaker.config import ( + load_sagemaker_config, + validate_sagemaker_config, + SESSION_DEFAULT_S3_BUCKET_PATH, + SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, +) from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import get_docker_host from sagemaker.local.entities import ( @@ -34,7 +39,7 @@ _LocalPipeline, ) from sagemaker.session import Session -from sagemaker.utils import get_config_value, _module_import_error +from sagemaker.utils import get_config_value, _module_import_error, resolve_value_from_config logger = logging.getLogger(__name__) @@ -606,6 +611,7 @@ def __init__( s3_endpoint_url=None, disable_local_code=False, sagemaker_config: dict = None, + default_bucket_prefix=None, ): """Create a Local SageMaker Session. @@ -628,6 +634,10 @@ def __init__( this dictionary can be generated by calling :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the Session. + default_bucket_prefix (str): The default prefix to use for S3 Object Keys. When + objects are saved to the Session's default_bucket, the Object Key used will + start with the default_bucket_prefix. If not provided here or within + sagemaker_config, no additional prefix will be added. 
""" self.s3_endpoint_url = s3_endpoint_url # We use this local variable to avoid disrupting the __init__->_initialize API of the @@ -639,6 +649,7 @@ def __init__( boto_session=boto_session, default_bucket=default_bucket, sagemaker_config=sagemaker_config, + default_bucket_prefix=default_bucket_prefix, ) if platform.system() == "Windows": @@ -700,15 +711,28 @@ def _initialize( # create a default S3 resource, but only if it needs to fetch from S3 self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource) - sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") - if os.path.exists(sagemaker_config_file): + # after sagemaker_config initialization, update self._default_bucket_name_override if needed + self._default_bucket_name_override = resolve_value_from_config( + direct_input=self._default_bucket_name_override, + config_path=SESSION_DEFAULT_S3_BUCKET_PATH, + sagemaker_session=self, + ) + # after sagemaker_config initialization, update self.default_bucket_prefix if needed + self.default_bucket_prefix = resolve_value_from_config( + direct_input=self.default_bucket_prefix, + config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, + sagemaker_session=self, + ) + + local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml") + if os.path.exists(local_mode_config_file): try: import yaml except ImportError as e: logger.error(_module_import_error("yaml", "Local mode", "local")) raise e - self.config = yaml.safe_load(open(sagemaker_config_file, "r")) + self.config = yaml.safe_load(open(local_mode_config_file, "r")) if self._disable_local_code and "local" in self.config: self.config["local"]["local_code"] = False diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 31cec28761..da9611ea31 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -563,7 +563,11 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: artifact should be repackaged into a new S3 object. (default: False). 
""" local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) - bucket = self.bucket or self.sagemaker_session.default_bucket() + + bucket, key_prefix = s3.determine_bucket_and_prefix( + bucket=self.bucket, key_prefix=key_prefix, sagemaker_session=self.sagemaker_session + ) + if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: self.uploaded_code = None elif not repack: @@ -1333,14 +1337,22 @@ def _build_default_async_inference_config(self, async_inference_config): """Build default async inference config and return ``AsyncInferenceConfig``""" unique_folder = unique_name_from_base(self.name) if async_inference_config.output_path is None: - async_output_s3uri = "s3://{}/async-endpoint-outputs/{}".format( - self.sagemaker_session.default_bucket(), unique_folder + async_output_s3uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + "async-endpoint-outputs", + unique_folder, ) async_inference_config.output_path = async_output_s3uri if async_inference_config.failure_path is None: - async_failure_s3uri = "s3://{}/async-endpoint-failures/{}".format( - self.sagemaker_session.default_bucket(), unique_folder + async_failure_s3uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + "async-endpoint-failures", + unique_folder, ) async_inference_config.failure_path = async_failure_s3uri diff --git a/src/sagemaker/model_monitor/data_capture_config.py b/src/sagemaker/model_monitor/data_capture_config.py index aa11d41aad..28bca43950 100644 --- a/src/sagemaker/model_monitor/data_capture_config.py +++ b/src/sagemaker/model_monitor/data_capture_config.py @@ -73,6 +73,7 @@ def __init__( self.destination_s3_uri = s3.s3_path_join( "s3://", sagemaker_session.default_bucket(), + sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, _DATA_CAPTURE_S3_PATH, ) diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index e865b4815f..94c0e456e7 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1287,6 +1287,7 @@ def _normalize_baseline_inputs(self, baseline_inputs=None): s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, self.latest_baselining_job_name, file_input.input_name, ) @@ -1312,6 +1313,7 @@ def _normalize_baseline_output(self, output_s3_uri=None): s3_uri = output_s3_uri or s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, _BASELINING_S3_PATH, self.latest_baselining_job_name, @@ -1338,6 +1340,7 @@ def _normalize_processing_output(self, output=None): s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, self.latest_baselining_job_name, "output", ) @@ -1361,6 +1364,7 @@ def _normalize_monitoring_output(self, monitoring_schedule_name, output_s3_uri=N s3_uri = output_s3_uri or s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, _MONITORING_S3_PATH, monitoring_schedule_name, @@ -1387,6 +1391,7 @@ def _normalize_monitoring_output_fields(self, output=None): output.destination = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + 
self.sagemaker_session.default_bucket_prefix, self.monitoring_schedule_name, "output", ) @@ -1408,6 +1413,7 @@ def _s3_uri_from_local_path(self, path): s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, _MONITORING_S3_PATH, self.monitoring_schedule_name, @@ -1496,6 +1502,7 @@ def _upload_and_convert_to_processing_input(self, source, destination, name): s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, _BASELINING_S3_PATH, self.latest_baselining_job_name, diff --git a/src/sagemaker/model_monitor/monitoring_files.py b/src/sagemaker/model_monitor/monitoring_files.py index 2dfdbd42f6..90ec627087 100644 --- a/src/sagemaker/model_monitor/monitoring_files.py +++ b/src/sagemaker/model_monitor/monitoring_files.py @@ -158,7 +158,12 @@ def from_string( sagemaker_session = sagemaker_session or Session() file_name = file_name or "statistics.json" desired_s3_uri = s3.s3_path_join( - "s3://", sagemaker_session.default_bucket(), "monitoring", str(uuid.uuid4()), file_name + "s3://", + sagemaker_session.default_bucket(), + sagemaker_session.default_bucket_prefix, + "monitoring", + str(uuid.uuid4()), + file_name, ) s3_uri = s3.S3Uploader.upload_string_as_file_body( body=statistics_file_string, @@ -286,7 +291,12 @@ def from_string( sagemaker_session = sagemaker_session or Session() file_name = file_name or "constraints.json" desired_s3_uri = s3.s3_path_join( - "s3://", sagemaker_session.default_bucket(), "monitoring", str(uuid.uuid4()), file_name + "s3://", + sagemaker_session.default_bucket(), + sagemaker_session.default_bucket_prefix, + "monitoring", + str(uuid.uuid4()), + file_name, ) s3_uri = s3.S3Uploader.upload_string_as_file_body( body=constraints_file_string, @@ -441,7 +451,12 @@ def from_string( sagemaker_session = sagemaker_session or Session() file_name = file_name or "constraint_violations.json" desired_s3_uri = s3.s3_path_join( - "s3://", sagemaker_session.default_bucket(), "monitoring", str(uuid.uuid4()), file_name + "s3://", + sagemaker_session.default_bucket(), + sagemaker_session.default_bucket_prefix, + "monitoring", + str(uuid.uuid4()), + file_name, ) s3_uri = s3.S3Uploader.upload_string_as_file_body( body=constraint_violations_file_string, diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index 4c6324a541..1adfce4c7c 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -16,6 +16,8 @@ import time import uuid from botocore.exceptions import WaiterError + +from sagemaker import s3 from sagemaker.exceptions import PollingTimeoutError, AsyncInferenceModelError from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse from sagemaker.s3 import parse_s3_url @@ -166,17 +168,18 @@ def _upload_data_to_s3( my_uuid = str(uuid.uuid4()) timestamp = sagemaker_timestamp() bucket = self.sagemaker_session.default_bucket() - key = "async-endpoint-inputs/{}/{}-{}".format( + key = s3.s3_path_join( + self.sagemaker_session.default_bucket_prefix, + "async-endpoint-inputs", name_from_base(self.name, short=True), - timestamp, - my_uuid, + "{}-{}".format(timestamp, my_uuid), ) data = self.serializer.serialize(data) self.s3_client.put_object( Body=data, Bucket=bucket, Key=key, ContentType=self.serializer.CONTENT_TYPE ) - input_path = input_path or "s3://{}/{}".format(self.sagemaker_session.default_bucket(), key) + input_path = input_path or 
"s3://{}/{}".format(bucket, key) return input_path diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 505772bcfe..0d7df9c517 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -402,6 +402,7 @@ def _normalize_inputs(self, inputs=None, kms_key=None): desired_s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _pipeline_config.pipeline_name, _pipeline_config.step_name, "input", @@ -411,6 +412,7 @@ def _normalize_inputs(self, inputs=None, kms_key=None): desired_s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, self._current_job_name, "input", file_input.input_name, @@ -465,6 +467,12 @@ def _normalize_outputs(self, outputs=None): values=[ "s3:/", self.sagemaker_session.default_bucket(), + *( + # don't include default_bucket_prefix if it is None or "" + [self.sagemaker_session.default_bucket_prefix] + if self.sagemaker_session.default_bucket_prefix + else [] + ), _pipeline_config.pipeline_name, ExecutionVariables.PIPELINE_EXECUTION_ID, _pipeline_config.step_name, @@ -476,6 +484,7 @@ def _normalize_outputs(self, outputs=None): s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, self._current_job_name, "output", output.output_name, @@ -780,6 +789,7 @@ def _upload_code(self, code, kms_key=None): desired_s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _pipeline_config.pipeline_name, self._CODE_CONTAINER_INPUT_NAME, _pipeline_config.code_hash, @@ -788,6 +798,7 @@ def _upload_code(self, code, kms_key=None): desired_s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, self._current_job_name, "input", self._CODE_CONTAINER_INPUT_NAME, @@ -1937,9 +1948,14 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): if _pipeline_config and _pipeline_config.pipeline_name: runproc_file_str = self._generate_framework_script(user_script) runproc_file_hash = hash_object(runproc_file_str) - s3_uri = ( - f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/" - f"code/{runproc_file_hash}/runproc.sh" + s3_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + _pipeline_config.pipeline_name, + "code", + runproc_file_hash, + "runproc.sh", ) s3_runproc_sh = S3Uploader.upload_string_as_file_body( runproc_file_str, diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index a96e6f7146..80863cb4aa 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -248,7 +248,11 @@ def __init__( self.s3_root_uri = resolve_value_from_config( direct_input=s3_root_uri, config_path=REMOTE_FUNCTION_S3_ROOT_URI, - default_value=os.path.join("s3://", self.sagemaker_session.default_bucket()), + default_value=s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + ), sagemaker_session=self.sagemaker_session, ) diff --git a/src/sagemaker/s3.py b/src/sagemaker/s3.py index 9817f83d37..509e97a35d 100644 --- a/src/sagemaker/s3.py +++ b/src/sagemaker/s3.py @@ -13,51 +13,20 @@ """This module contains Enums and helper methods related to S3.""" from __future__ import print_function, 
absolute_import -import pathlib import logging import io from typing import Union -from six.moves.urllib.parse import urlparse from sagemaker.session import Session -logger = logging.getLogger("sagemaker") - - -def parse_s3_url(url): - """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3 scheme. - - Args: - url (str): - - Returns: - tuple: A tuple containing: - - - str: S3 bucket name - - str: S3 key - """ - parsed_url = urlparse(url) - if parsed_url.scheme != "s3": - raise ValueError("Expecting 's3' scheme, got: {} in {}.".format(parsed_url.scheme, url)) - return parsed_url.netloc, parsed_url.path.lstrip("/") +# These were defined inside s3.py initially. Kept here for backward compatibility +from sagemaker.s3_utils import ( # pylint: disable=unused-import # noqa: F401 + parse_s3_url, + s3_path_join, + determine_bucket_and_prefix, +) - -def s3_path_join(*args): - """Returns the arguments joined by a slash ("/"), similarly to ``os.path.join()`` (on Unix). - - If the first argument is "s3://", then that is preserved. - - Args: - *args: The strings to join with a slash. - - Returns: - str: The joined string. - """ - if args[0].startswith("s3://"): - path = str(pathlib.PurePosixPath(*args[1:])).lstrip("/") - return str(pathlib.PurePosixPath(args[0], path)).replace("s3:/", "s3://") - - return str(pathlib.PurePosixPath(*args)).lstrip("/") +logger = logging.getLogger("sagemaker") class S3Uploader(object): diff --git a/src/sagemaker/s3_utils.py b/src/sagemaker/s3_utils.py new file mode 100644 index 0000000000..e53cdbe02a --- /dev/null +++ b/src/sagemaker/s3_utils.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains helper functions related to S3. You may want to use `s3.py` instead. + +This has a subset of the functions available through s3.py. This module was initially created with +functions that were originally in `s3.py` so that those functions could be imported inside +`session.py` without circular dependencies. (`s3.py` imports Session as a dependency.) +""" +from __future__ import print_function, absolute_import + +import logging +from functools import reduce +from typing import Optional + +from six.moves.urllib.parse import urlparse + +logger = logging.getLogger("sagemaker") + + +def parse_s3_url(url): + """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3 scheme. + + Args: + url (str): + + Returns: + tuple: A tuple containing: + + - str: S3 bucket name + - str: S3 key + """ + parsed_url = urlparse(url) + if parsed_url.scheme != "s3": + raise ValueError("Expecting 's3' scheme, got: {} in {}.".format(parsed_url.scheme, url)) + return parsed_url.netloc, parsed_url.path.lstrip("/") + + +def s3_path_join(*args, with_end_slash: bool = False): + """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix). + + Behavior of this function: + - If the first argument is "s3://", then that is preserved. + - The output by default will have no slashes at the beginning or end. 
There is one exception + (see `with_end_slash`). For example, `s3_path_join("/foo", "bar/")` will yield + `"foo/bar"` and `s3_path_join("foo", "bar", with_end_slash=True)` will yield `"foo/bar/"` + - Any repeat slashes will be removed in the output (except for "s3://" if provided at the + beginning). For example, `s3_path_join("s3://", "//foo/", "/bar///baz")` will yield + `"s3://foo/bar/baz"`. + - Empty or None arguments will be skipped. For example + `s3_path_join("foo", "", None, "bar")` will yield `"foo/bar"` + + Alternatives to this function that are NOT recommended for S3 paths: + - `os.path.join(...)` will have different behavior on Unix machines vs non-Unix machines + - `pathlib.PurePosixPath(...)` will apply potentially unintended simplification of single + dots (".") and root directories. (for example + `pathlib.PurePosixPath("foo", "/bar/./", "baz")` would yield `"/bar/baz"`) + - `"{}/{}/{}".format(...)` and similar may result in unintended repeat slashes + + Args: + *args: The strings to join with a slash. + with_end_slash (bool): (default: False) If true and if the path is not empty, appends a "/" + to the end of the path + + Returns: + str: The joined string, without a slash at the end unless with_end_slash is True. + """ + delimiter = "/" + + non_empty_args = list(filter(lambda item: item is not None and item != "", args)) + + merged_path = "" + for index, path in enumerate(non_empty_args): + if ( + index == 0 + or (merged_path and merged_path[-1] == delimiter) + or (path and path[0] == delimiter) + ): + # dont need to add an extra slash because either this is the beginning of the string, + # or one (or more) slash already exists + merged_path += path + else: + merged_path += delimiter + path + + if with_end_slash and merged_path and merged_path[-1] != delimiter: + merged_path += delimiter + + # At this point, merged_path may include slashes at the beginning and/or end. And some of the + # provided args may have had duplicate slashes inside or at the ends. + # For backwards compatibility reasons, these need to be filtered out (done below). In the + # future, if there is a desire to support multiple slashes for S3 paths throughout the SDK, + # one option is to create a new optional argument (or a new function) that only executes the + # logic above. + filtered_path = merged_path + + # remove duplicate slashes + if filtered_path: + + def duplicate_delimiter_remover(sequence, next_char): + if sequence[-1] == delimiter and next_char == delimiter: + return sequence + return sequence + next_char + + if filtered_path.startswith("s3://"): + filtered_path = reduce( + duplicate_delimiter_remover, filtered_path[5:], filtered_path[:5] + ) + else: + filtered_path = reduce(duplicate_delimiter_remover, filtered_path) + + # remove beginning slashes + filtered_path = filtered_path.lstrip(delimiter) + + # remove end slashes + if not with_end_slash and filtered_path != "s3://": + filtered_path = filtered_path.rstrip(delimiter) + + return filtered_path + + +def determine_bucket_and_prefix( + bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None +): + """Helper function that returns the correct S3 bucket and prefix to use depending on the inputs. + + Args: + bucket (Optional[str]): S3 Bucket to use (if it exists) + key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists) + sagemaker_session (sagemaker.session.Session): Session to fetch a default bucket and + prefix from, if bucket doesn't exist. 
Expected to exist + + Returns: The correct S3 Bucket and S3 Object Key Prefix that should be used + """ + if bucket: + final_bucket = bucket + final_key_prefix = key_prefix + else: + final_bucket = sagemaker_session.default_bucket() + + # default_bucket_prefix (if it exists) should be appended if (and only if) 'bucket' does not + # exist and we are using the Session's default_bucket. + final_key_prefix = s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix) + + # We should not append default_bucket_prefix even if the bucket exists but is equal to the + # default_bucket, because either: + # (1) the bucket was explicitly passed in by the user and just happens to be the same as the + # default_bucket (in which case we don't want to change the user's input), or + # (2) the default_bucket was fetched from Session earlier already (and the default prefix + # should have been fetched then as well), and then this function was + # called with it. If we appended the default prefix here, we would be appending it more than + # once in total. + + return final_bucket, final_key_prefix diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index b4d24d422f..c7fc6c1120 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -31,9 +31,9 @@ import six import sagemaker.logs -from sagemaker import vpc_utils +from sagemaker import vpc_utils, s3_utils from sagemaker._studio import _append_project_tags -from sagemaker.config import load_sagemaker_config, validate_sagemaker_config # noqa: F401 +from sagemaker.config import load_sagemaker_config, validate_sagemaker_config from sagemaker.config import ( KEY, TRAINING_JOB, @@ -92,6 +92,8 @@ FEATURE_GROUP_ROLE_ARN_PATH, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, + SESSION_DEFAULT_S3_BUCKET_PATH, + SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, ) from sagemaker.deprecations import deprecated_class from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig @@ -158,6 +160,7 @@ def __init__( settings=SessionSettings(), sagemaker_metrics_client=None, sagemaker_config: dict = None, + default_bucket_prefix: str = None, ): """Initialize a SageMaker ``Session``. @@ -180,7 +183,8 @@ def __init__( default_bucket (str): The default Amazon S3 bucket to be used by this session. This will be created the next time an Amazon S3 bucket is needed (by calling :func:`default_bucket`). - If not provided, a default bucket will be created based on the following format: + If not provided, it will be fetched from the sagemaker_config. If not configured + there either, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional @@ -199,9 +203,23 @@ def __init__( this dictionary can be generated by calling :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the Session. + default_bucket_prefix (str): The default prefix to use for S3 Object Keys. (default: + None). If provided and where applicable, it will be used by the SDK to construct + default S3 URIs, in the format: + `s3://{default_bucket}/{default_bucket_prefix}/` + This parameter can also be specified via `{sagemaker_config}` instead of here. 
If + not provided here or within `{sagemaker_config}`, default S3 URIs will have the + format: `s3://{default_bucket}/` """ + + # sagemaker_config is validated and initialized inside :func:`_initialize`, + # so if default_bucket is None and the sagemaker_config has a default S3 bucket configured, + # _default_bucket_name_override will be set again inside :func:`_initialize`. self._default_bucket = None self._default_bucket_name_override = default_bucket + # this may also be set again inside :func:`_initialize` if it is None + self.default_bucket_prefix = default_bucket_prefix + self.s3_resource = None self.s3_client = None self.resource_groups_client = None @@ -280,6 +298,19 @@ def _initialize( # create a default S3 resource, but only if it needs to fetch from S3 self.sagemaker_config = load_sagemaker_config(s3_resource=self.s3_resource) + # after sagemaker_config initialization, update self._default_bucket_name_override if needed + self._default_bucket_name_override = resolve_value_from_config( + direct_input=self._default_bucket_name_override, + config_path=SESSION_DEFAULT_S3_BUCKET_PATH, + sagemaker_session=self, + ) + # after sagemaker_config initialization, update self.default_bucket_prefix if needed + self.default_bucket_prefix = resolve_value_from_config( + direct_input=self.default_bucket_prefix, + config_path=SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, + sagemaker_session=self, + ) + @property def boto_region_name(self): """Placeholder docstring""" @@ -313,6 +344,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None): If a directory is specified in the path argument, the URI format is ``s3://{bucket name}/{key_prefix}``. """ + bucket, key_prefix = s3_utils.determine_bucket_and_prefix( + bucket=bucket, key_prefix=key_prefix, sagemaker_session=self + ) + # Generate a tuple for each file that we want to upload of the form (local_path, s3_key). files = [] key_suffix = None @@ -331,7 +366,6 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None): files.append((path, s3_key)) key_suffix = name - bucket = bucket or self.default_bucket() if self.s3_resource is None: s3 = self.boto_session.resource("s3", region_name=self.boto_region_name) else: @@ -484,7 +518,8 @@ def default_bucket(self): This function will create the s3 bucket if it does not exist. Returns: - str: The name of the default bucket, which is of the form: + str: The name of the default bucket. If the name was not explicitly specified through + the Session or sagemaker_config, the bucket will take the form: ``sagemaker-{region}-{AWS account ID}``. 
""" diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index c35d97a588..efe1b916d3 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -34,7 +34,7 @@ from typing import Union, List, Dict, Optional -from sagemaker import image_uris +from sagemaker import image_uris, s3 from sagemaker.local.image import _ecr_login_if_needed, _pull_image from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor from sagemaker.s3 import S3Uploader @@ -429,7 +429,11 @@ def _stage_configuration(self, configuration): else: s3_prefix_uri = self.configuration_location else: - s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}" + s3_prefix_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + ) serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8")) @@ -524,7 +528,11 @@ def _stage_submit_deps(self, submit_deps, input_channel_name): else: s3_prefix_uri = self.dependency_location else: - s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}" + s3_prefix_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + ) if _pipeline_config and _pipeline_config.code_hash: input_channel_s3_uri = ( diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 914a56dd3f..18f220094b 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -426,8 +426,13 @@ def _default_s3_path(self, directory, mpi=False): return "/opt/ml/model" if self._current_job_name: if is_pipeline_variable(self.output_path): - output_path = "s3://{}".format(self.sagemaker_session.default_bucket()) - return s3.s3_path_join(output_path, self._current_job_name, directory) + return s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + self._current_job_name, + directory, + ) return s3.s3_path_join(self.output_path, self._current_job_name, directory) return None diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 1f162cef60..d893cf0762 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -387,8 +387,14 @@ def prepare_container_def( instance_type, accelerator_type, serverless_inference_config=serverless_inference_config ) env = self._get_container_env() - key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image_uri) - bucket = self.bucket or self.sagemaker_session.default_bucket() + + bucket, key_prefix = s3.determine_bucket_and_prefix( + bucket=self.bucket, + key_prefix=sagemaker.fw_utils.model_code_key_prefix( + self.key_prefix, self.name, image_uri + ), + sagemaker_session=self.sagemaker_session, + ) if self.entry_point and not is_pipeline_variable(self.model_data): model_data = s3.s3_path_join("s3://", bucket, key_prefix, "model.tar.gz") diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 34685e786b..3e890bd37b 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -19,6 +19,8 @@ import time from botocore import exceptions + +from sagemaker import s3 from sagemaker.config import ( TRANSFORM_OUTPUT_KMS_KEY_ID_PATH, TRANSFORM_JOB_KMS_KEY_ID_PATH, @@ -260,14 +262,23 @@ def transform( values=[ "s3:/", self.sagemaker_session.default_bucket(), + *( + # don't include default_bucket_prefix if 
it is None or "" + [self.sagemaker_session.default_bucket_prefix] + if self.sagemaker_session.default_bucket_prefix + else [] + ), _pipeline_config.pipeline_name, ExecutionVariables.PIPELINE_EXECUTION_ID, _pipeline_config.step_name, ], ) else: - self.output_path = "s3://{}/{}".format( - self.sagemaker_session.default_bucket(), self._current_job_name + self.output_path = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, + self._current_job_name, ) self._reset_output_path = True diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index a3565ba9c1..cb4951d6e4 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -37,12 +37,17 @@ def prepare_framework(estimator, s3_operations): """ if estimator.code_location is not None: bucket, key = s3.parse_s3_url(estimator.code_location) - key = os.path.join(key, estimator._current_job_name, "source", "sourcedir.tar.gz") + key = s3.s3_path_join(key, estimator._current_job_name, "source", "sourcedir.tar.gz") elif estimator.uploaded_code is not None: bucket, key = s3.parse_s3_url(estimator.uploaded_code.s3_prefix) else: - bucket = estimator.sagemaker_session._default_bucket - key = os.path.join(estimator._current_job_name, "source", "sourcedir.tar.gz") + bucket = estimator.sagemaker_session.default_bucket + key = s3.s3_path_join( + estimator.sagemaker_session.default_bucket_prefix, + estimator._current_job_name, + "source", + "sourcedir.tar.gz", + ) script = os.path.basename(estimator.entry_point) @@ -159,8 +164,12 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size= estimator._current_job_name = utils.name_from_base(base_name) if estimator.output_path is None: - default_bucket = estimator.sagemaker_session.default_bucket() - estimator.output_path = "s3://{}/".format(default_bucket) + estimator.output_path = s3.s3_path_join( + "s3://", + estimator.sagemaker_session.default_bucket(), + estimator.sagemaker_session.default_bucket_prefix, + with_end_slash=True, + ) if isinstance(estimator, sagemaker.estimator.Framework): prepare_framework(estimator, s3_operations) @@ -543,7 +552,12 @@ def prepare_framework_container_def(model, instance_type, s3_operations): base_name = utils.base_name_from_image(deploy_image) model.name = model.name or utils.name_from_base(base_name) - bucket = model.bucket or model.sagemaker_session._default_bucket + bucket, key_prefix = s3.determine_bucket_and_prefix( + bucket=model.bucket, + key_prefix=None, + sagemaker_session=model.sagemaker_session, + ) + if model.entry_point is not None: script = os.path.basename(model.entry_point) key = "{}/source/sourcedir.tar.gz".format(model.name) @@ -552,7 +566,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations): code_dir = model.source_dir model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) else: - code_dir = "s3://{}/{}".format(bucket, key) + code_dir = s3.s3_path_join("s3://", bucket, key_prefix, key) model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script) s3_operations["S3Upload"] = [ {"Path": model.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True} @@ -757,8 +771,11 @@ def transform_config( ) if transformer.output_path is None: - transformer.output_path = "s3://{}/{}".format( - transformer.sagemaker_session.default_bucket(), transformer._current_job_name + transformer.output_path = s3.s3_path_join( + "s3://", + 
transformer.sagemaker_session.default_bucket(), + transformer.sagemaker_session.default_bucket_prefix, + transformer._current_job_name, ) job_config = sagemaker.transformer._TransformJob._load_config( diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index 32793de977..ceb13fa3d4 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -456,6 +456,7 @@ def _get_s3_base_uri_for_monitoring_analysis_config(self) -> str: return s3.s3_path_join( "s3://", self._model_monitor.sagemaker_session.default_bucket(), + self._model_monitor.sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, monitoring_cfg_base_name, ) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 73c115f84a..6869478dee 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -176,17 +176,19 @@ def _create_args( if len(pipeline_definition.encode("utf-8")) < 1024 * 100: kwargs["PipelineDefinition"] = pipeline_definition else: - desired_s3_uri = s3.s3_path_join( - "s3://", self.sagemaker_session.default_bucket(), self.name + bucket, object_key = s3.determine_bucket_and_prefix( + bucket=None, key_prefix=self.name, sagemaker_session=self.sagemaker_session ) + + desired_s3_uri = s3.s3_path_join("s3://", bucket, object_key) s3.S3Uploader.upload_string_as_file_body( body=pipeline_definition, desired_s3_uri=desired_s3_uri, sagemaker_session=self.sagemaker_session, ) kwargs["PipelineDefinitionS3Location"] = { - "Bucket": self.sagemaker_session.default_bucket(), - "ObjectKey": self.name, + "Bucket": bucket, + "ObjectKey": object_key, } update_args( diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index 2ca1e11484..e12ee04f18 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -113,6 +113,7 @@ def __init__( default_bucket=None, settings=SessionSettings(), sagemaker_config: dict = None, + default_bucket_prefix: str = None, ): """Initialize a ``PipelineSession``. @@ -142,6 +143,10 @@ def __init__( this dictionary can be generated by calling :func:`~sagemaker.config.load_sagemaker_config` and then be provided to the Session. + default_bucket_prefix (str): The default prefix to use for S3 Object Keys. When + objects are saved to the Session's default_bucket, the Object Key used will + start with the default_bucket_prefix. If not provided here or within + sagemaker_config, no additional prefix will be added. """ super().__init__( boto_session=boto_session, @@ -149,6 +154,7 @@ def __init__( default_bucket=default_bucket, settings=settings, sagemaker_config=sagemaker_config, + default_bucket_prefix=default_bucket_prefix, ) self._context = None @@ -197,7 +203,12 @@ class LocalPipelineSession(LocalSession, PipelineSession): """ def __init__( - self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False + self, + boto_session=None, + default_bucket=None, + s3_endpoint_url=None, + disable_local_code=False, + default_bucket_prefix=None, ): """Initialize a ``LocalPipelineSession``. @@ -216,6 +227,10 @@ def __init__( disable_local_code (bool): Set to True to override the default AWS configuration chain to disable the `local.local_code` setting, which may not be supported for some SDK features (default: False). + default_bucket_prefix (str): The default prefix to use for S3 Object Keys. 
When + objects are saved to the Session's default_bucket, the Object Key used will + start with the default_bucket_prefix. If not provided here or within + sagemaker_config, no additional prefix will be added. """ super().__init__( @@ -223,6 +238,7 @@ def __init__( default_bucket=default_bucket, s3_endpoint_url=s3_endpoint_url, disable_local_code=disable_local_code, + default_bucket_prefix=default_bucket_prefix, ) diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index 3a6c3ba627..44405b4150 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -338,6 +338,7 @@ def _generate_baseline_output(self): s3_uri = self.quality_check_config.output_s3_uri or s3.s3_path_join( "s3://", self._model_monitor.sagemaker_session.default_bucket(), + self._model_monitor.sagemaker_session.default_bucket_prefix, _MODEL_MONITOR_S3_PATH, _BASELINING_S3_PATH, self._model_monitor.latest_baselining_job_name, diff --git a/tests/conftest.py b/tests/conftest.py index 7db96893fe..f9429279a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ DEFAULT_REGION = "us-west-2" CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket" +CUSTOM_S3_OBJECT_KEY_PREFIX = "session-default-prefix" NO_M4_REGIONS = [ "eu-west-3", @@ -164,6 +165,7 @@ def sagemaker_session( sagemaker_runtime_client=runtime_client, sagemaker_metrics_client=metrics_client, sagemaker_config={}, + default_bucket_prefix=CUSTOM_S3_OBJECT_KEY_PREFIX, ) diff --git a/tests/data/config/config.yaml b/tests/data/config/config.yaml index fc052f2ddd..9764311d6b 100644 --- a/tests/data/config/config.yaml +++ b/tests/data/config/config.yaml @@ -1,5 +1,31 @@ SchemaVersion: '1.0' SageMaker: + PythonSDK: + Modules: + Session: + DefaultS3Bucket: 'sagemaker-python-sdk-test-bucket' + DefaultS3ObjectKeyPrefix: 'test-prefix' + RemoteFunction: + Dependencies: "./requirements.txt" + EnvironmentVariables: + "var1": "value1" + "var2": "value2" + ImageUri: "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest" + IncludeLocalWorkDir: true + InstanceType: "ml.m5.xlarge" + JobCondaEnvironment: "some_conda_env" + RoleArn: "arn:aws:iam::555555555555:role/IMRole" + S3KmsKeyId: "kmskeyid1" + S3RootUri: "s3://my-bucket/key" + Tags: + - Key: "tag1" + Value: "tagValue1" + VolumeKmsKeyId: "kmskeyid2" + VpcConfig: + SecurityGroupIds: + - 'sg123' + Subnets: + - 'subnet-1234' FeatureGroup: OnlineStoreConfig: SecurityConfig: @@ -117,27 +143,4 @@ SageMaker: EdgePackagingJob: OutputConfig: KmsKeyId: 'kmskeyid1' - RoleArn: 'arn:aws:iam::555555555555:role/IMRole' - PythonSDK: - Modules: - RemoteFunction: - Dependencies: "./requirements.txt" - EnvironmentVariables: - "var1": "value1" - "var2": "value2" - ImageUri: "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest" - IncludeLocalWorkDir: true - InstanceType: "ml.m5.xlarge" - JobCondaEnvironment: "some_conda_env" - RoleArn: "arn:aws:iam::555555555555:role/IMRole" - S3KmsKeyId: "kmskeyid1" - S3RootUri: "s3://my-bucket/key" - Tags: - - Key: "tag1" - Value: "tagValue1" - VolumeKmsKeyId: "kmskeyid2" - VpcConfig: - SecurityGroupIds: - - 'sg123' - Subnets: - - 'subnet-1234' + RoleArn: 'arn:aws:iam::555555555555:role/IMRole' \ No newline at end of file diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py index 96fc632ad7..ada64a4dd5 100644 --- a/tests/integ/sagemaker/experiments/test_run.py +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -19,6 +19,7 @@ 
import pytest +from tests.conftest import CUSTOM_S3_OBJECT_KEY_PREFIX from tests.integ.sagemaker.experiments.conftest import TAGS from sagemaker.experiments._api_types import _TrialComponentStatusType from sagemaker.experiments._utils import is_run_trial_component @@ -754,7 +755,7 @@ def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True if not is_complete_log: return - s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}" + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{CUSTOM_S3_OBJECT_KEY_PREFIX}/{_DEFAULT_ARTIFACT_PREFIX}" assert s3_prefix in tc.output_artifacts[file_artifact_name].value assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type assert "s3://Output" == tc.output_artifacts[artifact_name].value diff --git a/tests/integ/test_async_inference.py b/tests/integ/test_async_inference.py index 0f7b0c61ff..f60cd59923 100644 --- a/tests/integ/test_async_inference.py +++ b/tests/integ/test_async_inference.py @@ -64,7 +64,10 @@ def test_async_walkthrough(sagemaker_session, cpu_instance_type, training_set): "s3://" + sagemaker_session.default_bucket() ) assert result_no_wait_with_data.failure_path.startswith( - "s3://" + sagemaker_session.default_bucket() + "/async-endpoint-failures/" + "s3://" + + sagemaker_session.default_bucket() + + f"/{sagemaker_session.default_bucket_prefix}" + + "/async-endpoint-failures/" ) time.sleep(5) result_no_wait_with_data = result_no_wait_with_data.get_result() @@ -101,7 +104,10 @@ def test_async_walkthrough(sagemaker_session, cpu_instance_type, training_set): assert isinstance(result_not_wait, AsyncInferenceResponse) assert result_not_wait.output_path.startswith("s3://" + sagemaker_session.default_bucket()) assert result_not_wait.failure_path.startswith( - "s3://" + sagemaker_session.default_bucket() + "/async-endpoint-failures/" + "s3://" + + sagemaker_session.default_bucket() + + f"/{sagemaker_session.default_bucket_prefix}" + + "/async-endpoint-failures/" ) time.sleep(5) result_not_wait = result_not_wait.get_result() diff --git a/tests/integ/test_clarify_model_monitor.py b/tests/integ/test_clarify_model_monitor.py index 6011dbe271..d278fdef83 100644 --- a/tests/integ/test_clarify_model_monitor.py +++ b/tests/integ/test_clarify_model_monitor.py @@ -553,6 +553,7 @@ def _upload(s3_uri_base, input_file_name, target_time, file_name): capture_s3_uri_base = os.path.join( "s3://", sagemaker_session.default_bucket(), + sagemaker_session.default_bucket_prefix, "model-monitor", "data-capture", endpoint_name, diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index 0d07ed68c4..6e1e3fd444 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -23,6 +23,8 @@ import stopit import tests.integ.lock as lock +from sagemaker.config import SESSION_DEFAULT_S3_BUCKET_PATH +from sagemaker.utils import resolve_value_from_config from tests.integ import DATA_DIR from mock import Mock, ANY @@ -70,6 +72,13 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.sagemaker_config = kwargs.get("sagemaker_config", None) + # after sagemaker_config initialization, update self._default_bucket_name_override if needed + self._default_bucket_name_override = resolve_value_from_config( + direct_input=self._default_bucket_name_override, + config_path=SESSION_DEFAULT_S3_BUCKET_PATH, + sagemaker_session=self, + ) + class LocalPipelineNoS3Session(LocalPipelineSession): """ @@ -91,6 +100,13 @@ def _initialize(self, boto_session, 
sagemaker_client, sagemaker_runtime_client, self.sagemaker_config = kwargs.get("sagemaker_config", None) + # after sagemaker_config initialization, update self._default_bucket_name_override if needed + self._default_bucket_name_override = resolve_value_from_config( + direct_input=self._default_bucket_name_override, + config_path=SESSION_DEFAULT_S3_BUCKET_PATH, + sagemaker_session=self, + ) + @pytest.fixture(scope="module") def sagemaker_local_session_no_local_code(boto_session): diff --git a/tests/integ/test_model_quality_monitor.py b/tests/integ/test_model_quality_monitor.py index 2e4bc9539e..1fafa96cfb 100644 --- a/tests/integ/test_model_quality_monitor.py +++ b/tests/integ/test_model_quality_monitor.py @@ -458,6 +458,7 @@ def _upload(s3_uri_base, input_file_name, target_time, file_name): capture_s3_uri_base = os.path.join( "s3://", sagemaker_session.default_bucket(), + sagemaker_session.default_bucket_prefix, "model-monitor", "data-capture", endpoint_name, diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py index c202886d70..c2dbec0a8c 100644 --- a/tests/integ/test_sagemaker_config.py +++ b/tests/integ/test_sagemaker_config.py @@ -96,6 +96,16 @@ def sagemaker_session_with_dynamically_generated_sagemaker_config( config_as_dict = { "SchemaVersion": "1.0", "SageMaker": { + "PythonSDK": { + "Modules": { + "Session": { + "DefaultS3ObjectKeyPrefix": S3_KEY_PREFIX, + # S3Bucket is omitted for now, because the tests support one S3 bucket at + # the moment and it would be hard to validate injection of this parameter + # if we use the same bucket that the rest of the tests use. + }, + }, + }, "EndpointConfig": { "AsyncInferenceConfig": {"OutputConfig": {"KmsKeyId": kms_key_arn}}, "DataCaptureConfig": {"KmsKeyId": kms_key_arn}, @@ -216,11 +226,11 @@ xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model") sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), - key_prefix=S3_KEY_PREFIX + "/sparkml/model", + key_prefix="sparkml/model", ) xgb_model_data = sagemaker_session.upload_data( path=os.path.join(xgboost_data_path, "xgb_model.tar.gz"), - key_prefix=S3_KEY_PREFIX + "/xgboost/model", + key_prefix="xgboost/model", ) with timeout_and_delete_endpoint_by_name(name, sagemaker_session): @@ -260,7 +270,8 @@ def test_sagemaker_config_cross_context_injection( sparkml_model.enable_network_isolation(), xgb_model.enable_network_isolation(), pipeline_model.enable_network_isolation, # This is not a function in PipelineModel - ] == [role_arn, role_arn, role_arn, True, True, True] + sagemaker_session.default_bucket_prefix, + ] == [role_arn, role_arn, role_arn, True, True, True, S3_KEY_PREFIX] # First mutating API call where sagemaker_config values should be injected in predictor = pipeline_model.deploy( @@ -269,7 +280,6 @@ initial_instance_count=1, instance_type=cpu_instance_type, endpoint_name=name, data_capture_config=DataCaptureConfig( True, - destination_s3_uri=data_capture_s3_uri, sagemaker_session=sagemaker_session, ), tags=test_tags, diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 1de333b987..c1fa2f15d4 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -508,13 +508,13 @@ def test_single_transformer_multiple_jobs( with timeout_and_delete_model_with_transformer( transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES ): - assert transformer.output_path == "s3://{}/{}".format( -
sagemaker_session.default_bucket(), job_name + assert transformer.output_path == "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), sagemaker_session.default_bucket_prefix, job_name ) job_name = unique_name_from_base("test-mxnet-transform") transformer.transform(mxnet_transform_input, content_type="text/csv", job_name=job_name) - assert transformer.output_path == "s3://{}/{}".format( - sagemaker_session.default_bucket(), job_name + assert transformer.output_path == "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), sagemaker_session.default_bucket_prefix, job_name ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index df465fd31f..8e1bae3703 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -14,6 +14,8 @@ import os +from mock.mock import Mock + from sagemaker.config import ( SAGEMAKER, MONITORING_SCHEDULE, @@ -70,11 +72,31 @@ MODEL, ASYNC_INFERENCE_CONFIG, SCHEMA_VERSION, + PYTHON_SDK, + MODULES, + DEFAULT_S3_BUCKET, + DEFAULT_S3_OBJECT_KEY_PREFIX, + SESSION, ) DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") PY_VERSION = "py3" +DEFAULT_S3_BUCKET_NAME = "sagemaker-config-session-s3-bucket" +DEFAULT_S3_OBJECT_KEY_PREFIX_NAME = "test-prefix" +SAGEMAKER_CONFIG_SESSION = { + SCHEMA_VERSION: "1.0", + SAGEMAKER: { + PYTHON_SDK: { + MODULES: { + SESSION: { + DEFAULT_S3_BUCKET: "sagemaker-config-session-s3-bucket", + DEFAULT_S3_OBJECT_KEY_PREFIX: "test-prefix", + }, + }, + }, + }, +} SAGEMAKER_CONFIG_MONITORING_SCHEDULE = { SCHEMA_VERSION: "1.0", @@ -277,3 +299,65 @@ }, }, } + + +def _test_default_bucket_and_prefix_combinations( + function_with_user_input=None, + function_without_user_input=None, + expected__without_user_input__with_default_bucket_and_default_prefix=None, + expected__without_user_input__with_default_bucket_only=None, + expected__with_user_input__with_default_bucket_and_prefix=None, + expected__with_user_input__with_default_bucket_only=None, + session_with_bucket_and_prefix=Mock( + name="sagemaker_session", + sagemaker_config={}, + default_bucket=Mock(name="default_bucket", return_value=DEFAULT_S3_BUCKET_NAME), + default_bucket_prefix=DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, + config=None, + settings=None, + ), + session_with_bucket_and_no_prefix=Mock( + name="sagemaker_session", + sagemaker_config={}, + default_bucket_prefix=None, + default_bucket=Mock(name="default_bucket", return_value=DEFAULT_S3_BUCKET_NAME), + config=None, + settings=None, + ), +): + """ + Helper to test the different possible scenarios of how S3 params will be generated. + + Possible scenarios: + 1. User provided their own input, so (in most cases) there is no need to use default params + 2. User did not provide input. Session has a default_bucket_prefix set + 3. User did not provide input. Session does NOT have a default_bucket_prefix set + """ + + actual_values = [] + expected_values = [] + + # With Default Bucket and Default Prefix + if expected__without_user_input__with_default_bucket_and_default_prefix: + actual_values.append(function_without_user_input(session_with_bucket_and_prefix)) + expected_values.append(expected__without_user_input__with_default_bucket_and_default_prefix) + + # With Default Bucket and no Default Prefix + if expected__without_user_input__with_default_bucket_only: + actual_values.append(function_without_user_input(session_with_bucket_and_no_prefix)) + expected_values.append(expected__without_user_input__with_default_bucket_only) + + # With user input & With Default Bucket and Default Prefix + if expected__with_user_input__with_default_bucket_and_prefix: + actual_values.append(function_with_user_input(session_with_bucket_and_prefix)) + expected_values.append(expected__with_user_input__with_default_bucket_and_prefix) + + # With user input & With Default Bucket and no Default Prefix + if expected__with_user_input__with_default_bucket_only: + actual_values.append(function_with_user_input(session_with_bucket_and_no_prefix)) + expected_values.append(expected__with_user_input__with_default_bucket_only) + + # Assertions are intentionally left to the caller. (If asserts lived inside this helper, + # a failure would report only that the assert failed, without showing the values that + # actually differ.) + return actual_values, expected_values diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index ddf89d6b35..ab3aa1ed70 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -21,7 +21,13 @@ from sagemaker.predictor import Predictor from sagemaker.session_settings import SessionSettings from sagemaker.workflow.functions import Join -from tests.unit import SAGEMAKER_CONFIG_AUTO_ML, SAGEMAKER_CONFIG_TRAINING_JOB +from tests.unit import ( + SAGEMAKER_CONFIG_AUTO_ML, + SAGEMAKER_CONFIG_TRAINING_JOB, + _test_default_bucket_and_prefix_combinations, + DEFAULT_S3_BUCKET_NAME, + DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, +) MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -261,6 +267,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.upload_data = Mock(name="upload_data", return_value=DEFAULT_S3_INPUT_DATA) @@ -984,3 +991,41 @@ def test_attach(sagemaker_session): assert aml.mode == "ENSEMBLING" assert aml.auto_generate_endpoint_name is False assert aml.endpoint_name == "EndpointName" + + +@patch("sagemaker.automl.automl.AutoMLJob.start_new") +def test_output_path_default_bucket_and_prefix_combinations(start_new): + def with_user_input(sess): + auto_ml = AutoML( + role=ROLE, + target_attribute_name=TARGET_ATTRIBUTE_NAME, + sagemaker_session=sess, + output_path="s3://test", + ) + inputs = DEFAULT_S3_INPUT_DATA + auto_ml.fit(inputs, job_name=JOB_NAME, wait=False, logs=True) + start_new.assert_called() # just to make sure
this is patched with a mock + return auto_ml.output_path + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/" + ), + expected__without_user_input__with_default_bucket_only=f"s3://{DEFAULT_S3_BUCKET_NAME}/", + expected__with_user_input__with_default_bucket_and_prefix="s3://test", + expected__with_user_input__with_default_bucket_only="s3://test", + ) + assert actual == expected diff --git a/tests/unit/sagemaker/config/conftest.py b/tests/unit/sagemaker/config/conftest.py index 3e1feb4adc..473c437882 100644 --- a/tests/unit/sagemaker/config/conftest.py +++ b/tests/unit/sagemaker/config/conftest.py @@ -37,6 +37,14 @@ def valid_tags(): return [{"Key": "tag1", "Value": "tagValue1"}] +@pytest.fixture() +def valid_session_config(): + return { + "DefaultS3Bucket": "sagemaker-python-sdk-test-bucket", + "DefaultS3ObjectKeyPrefix": "test-prefix", + } + + @pytest.fixture() def valid_feature_group_config(valid_iam_role_arn): security_storage_config = {"KmsKeyId": "kmskeyid1"} @@ -172,25 +180,24 @@ def valid_monitoring_schedule_config(valid_iam_role_arn, valid_vpc_config): @pytest.fixture() def valid_remote_function_config(valid_iam_role_arn, valid_tags, valid_vpc_config): return { - "RemoteFunction": { - "Dependencies": "./requirements.txt", - "EnvironmentVariables": {"var1": "value1", "var2": "value2"}, - "ImageUri": "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest", - "IncludeLocalWorkDir": True, - "InstanceType": "ml.m5.xlarge", - "JobCondaEnvironment": "some_conda_env", - "RoleArn": valid_iam_role_arn, - "S3KmsKeyId": "kmskeyid1", - "S3RootUri": "s3://my-bucket/key", - "Tags": valid_tags, - "VolumeKmsKeyId": "kmskeyid2", - "VpcConfig": valid_vpc_config, - } + "Dependencies": "./requirements.txt", + "EnvironmentVariables": {"var1": "value1", "var2": "value2"}, + "ImageUri": "123456789012.dkr.ecr.us-west-2.amazonaws.com/myimage:latest", + "IncludeLocalWorkDir": True, + "InstanceType": "ml.m5.xlarge", + "JobCondaEnvironment": "some_conda_env", + "RoleArn": valid_iam_role_arn, + "S3KmsKeyId": "kmskeyid1", + "S3RootUri": "s3://my-bucket/key", + "Tags": valid_tags, + "VolumeKmsKeyId": "kmskeyid2", + "VpcConfig": valid_vpc_config, } @pytest.fixture() def valid_config_with_all_the_scopes( + valid_session_config, valid_feature_group_config, valid_monitoring_schedule_config, valid_endpointconfig_config, @@ -206,6 +213,12 @@ def valid_config_with_all_the_scopes( valid_remote_function_config, ): return { + "PythonSDK": { + "Modules": { + "RemoteFunction": valid_remote_function_config, + "Session": valid_session_config, + } + }, "FeatureGroup": valid_feature_group_config, "MonitoringSchedule": valid_monitoring_schedule_config, "EndpointConfig": valid_endpointconfig_config, @@ -218,7 +231,6 @@ def valid_config_with_all_the_scopes( "ProcessingJob": valid_processing_job_config, "TrainingJob": valid_training_job_config, "EdgePackagingJob": valid_edge_packaging_config, - "PythonSDK": {"Modules": valid_remote_function_config}, } diff --git a/tests/unit/sagemaker/config/test_config_schema.py b/tests/unit/sagemaker/config/test_config_schema.py index 2efe47b8a2..e5e1e54b71 100644 --- a/tests/unit/sagemaker/config/test_config_schema.py +++ b/tests/unit/sagemaker/config/test_config_schema.py @@ -97,7 +97,8 @@ def test_valid_monitoring_schedule_schema( def 
test_valid_remote_function_schema(base_config_with_schema, valid_remote_function_config): _validate_config( - base_config_with_schema, {"PythonSDK": {"Modules": valid_remote_function_config}} + base_config_with_schema, + {"PythonSDK": {"Modules": {"RemoteFunction": valid_remote_function_config}}}, ) @@ -199,3 +200,75 @@ def test_invalid_s3uri_schema(base_config_with_schema): config["SageMaker"] = {"PythonSDK": {"Modules": {"RemoteFunction": {"S3RootUri": "bad_regex"}}}} with pytest.raises(exceptions.ValidationError): validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA) + + +@pytest.mark.parametrize( + "bucket_name", + [ + "docexamplebucket1", + "log-delivery-march-2020", + "my-hosted-content", + "docexamplewebsite.com", + "www.docexamplewebsite.com", + "my.example.s3.bucket", + ], +) +def test_session_s3_bucket_schema(base_config_with_schema, bucket_name): + config = {"PythonSDK": {"Modules": {"Session": {"DefaultS3Bucket": bucket_name}}}} + _validate_config(base_config_with_schema, config) + + +@pytest.mark.parametrize( + "invalid_bucket_name", + [ + "ab", + "this-is-sixty-four-characters-total-which-is-one-above-the-limit", + "UPPERCASE-LETTERS", + "special_characters", + "special-characters@", + ".dot-at-the-beginning", + "-dash-at-the-beginning", + "dot-at-the-end.", + "dash-at-the-end-", + ], +) +def test_invalid_session_s3_bucket_schema(base_config_with_schema, invalid_bucket_name): + with pytest.raises(exceptions.ValidationError): + test_session_s3_bucket_schema(base_config_with_schema, invalid_bucket_name) + + +@pytest.mark.parametrize( + "prefix_name", + [ + "S3suggested/chars/0123/abc/ABC/!/-/_/./*/'/(/)", + "/slash/at/the/beginning", + "multiple/slashes//////in///the///middle/", + "Other/chars/&/$/@/=/;/:/+ /,/?", + "a", + # samples from https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html + "Development/Projects.xls", + "Finance/statement1.pdf", + "Private/taxdocument.pdf", + "s3-dg.pdf", + "4my-organization", + "my.great_photos-2014/jan/myvacation.jpg", + "videos/2014/birthday/video1.wmv", + ], +) +def test_session_s3_object_key_prefix_schema(base_config_with_schema, prefix_name): + config = {"PythonSDK": {"Modules": {"Session": {"DefaultS3ObjectKeyPrefix": prefix_name}}}} + _validate_config(base_config_with_schema, config) + + +@pytest.mark.parametrize( + "invalid_prefix_name", + [ + "", + "too_many_chars_above_1024_" + ("a" * 1000), + 1000, + True, + ], +) +def test_invalid_session_s3_object_key_prefix_schema(base_config_with_schema, invalid_prefix_name): + with pytest.raises(exceptions.ValidationError): + test_session_s3_object_key_prefix_schema(base_config_with_schema, invalid_prefix_name) diff --git a/tests/unit/sagemaker/experiments/test_helper.py b/tests/unit/sagemaker/experiments/test_helper.py index a11f67389b..7fb49d4feb 100644 --- a/tests/unit/sagemaker/experiments/test_helper.py +++ b/tests/unit/sagemaker/experiments/test_helper.py @@ -26,6 +26,11 @@ ) from src.sagemaker.experiments._utils import resolve_artifact_name from src.sagemaker.session import Session +from tests.unit import ( + _test_default_bucket_and_prefix_combinations, + DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, + DEFAULT_S3_BUCKET_NAME, +) @pytest.fixture @@ -193,3 +198,119 @@ def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader): expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) assert expected_uri == s3_uri + + +def test_upload_artifact__default_bucket_and_prefix_combinations(tempdir): + path = os.path.join(tempdir, "exists") + 
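# Write a small throwaway file so upload_artifact has a real local file to upload. +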
with open(path, "a") as f: + f.write("boo") + + def with_user_input(sess): + artifact_uploader = _ArtifactUploader( + trial_component_name="trial_component_name", + artifact_bucket="artifact_bucket", + artifact_prefix="artifact_prefix", + sagemaker_session=sess, + ) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + s3_uri, etag = artifact_uploader.upload_artifact(path) + s3_uri_2, etag_2 = artifact_uploader.upload_artifact(path) + return s3_uri, s3_uri_2 + + def without_user_input(sess): + artifact_uploader = _ArtifactUploader( + trial_component_name="trial_component_name", + sagemaker_session=sess, + ) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + s3_uri, etag = artifact_uploader.upload_artifact(path) + s3_uri_2, etag_2 = artifact_uploader.upload_artifact(path) + return s3_uri, s3_uri_2 + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/" + + "trial-component-artifacts/trial_component_name/exists", + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/" + + "trial-component-artifacts/trial_component_name/exists", + ), + expected__without_user_input__with_default_bucket_only=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/trial-component-artifacts/trial_component_name/exists", + f"s3://{DEFAULT_S3_BUCKET_NAME}/trial-component-artifacts/trial_component_name/exists", + ), + expected__with_user_input__with_default_bucket_and_prefix=( + "s3://artifact_bucket/artifact_prefix/trial_component_name/exists", + "s3://artifact_bucket/artifact_prefix/trial_component_name/exists", + ), + expected__with_user_input__with_default_bucket_only=( + "s3://artifact_bucket/artifact_prefix/trial_component_name/exists", + "s3://artifact_bucket/artifact_prefix/trial_component_name/exists", + ), + ) + assert actual == expected + + +def test_upload_object_artifact__default_bucket_and_prefix_combinations(tempdir): + path = os.path.join(tempdir, "exists") + with open(path, "a") as f: + f.write("boo") + + artifact_name = "my-artifact" + artifact_object = {"key": "value"} + file_extension = ".csv" + + def with_user_input(sess): + artifact_uploader = _ArtifactUploader( + trial_component_name="trial_component_name", + artifact_bucket="artifact_bucket", + artifact_prefix="artifact_prefix", + sagemaker_session=sess, + ) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + s3_uri, etag = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + s3_uri_2, etag_2 = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + return s3_uri, s3_uri_2 + + def without_user_input(sess): + artifact_uploader = _ArtifactUploader( + trial_component_name="trial_component_name", + sagemaker_session=sess, + ) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + s3_uri, etag = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + s3_uri_2, etag_2 = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + return s3_uri, s3_uri_2 + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + 
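# Expected values below: the session's default bucket/prefix should apply only when no explicit artifact_bucket/artifact_prefix is supplied. +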
expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/" + + "trial-component-artifacts/trial_component_name/my-artifact.csv", + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/" + + "trial-component-artifacts/trial_component_name/my-artifact.csv", + ), + expected__without_user_input__with_default_bucket_only=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/trial-component-artifacts/trial_component_name/my-artifact.csv", + f"s3://{DEFAULT_S3_BUCKET_NAME}/trial-component-artifacts/trial_component_name/my-artifact.csv", + ), + expected__with_user_input__with_default_bucket_and_prefix=( + "s3://artifact_bucket/artifact_prefix/trial_component_name/my-artifact.csv", + "s3://artifact_bucket/artifact_prefix/trial_component_name/my-artifact.csv", + ), + expected__with_user_input__with_default_bucket_only=( + "s3://artifact_bucket/artifact_prefix/trial_component_name/my-artifact.csv", + "s3://artifact_bucket/artifact_prefix/trial_component_name/my-artifact.csv", + ), + ) + assert actual == expected diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 666a142543..90aa81a841 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -66,6 +66,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index f5ea031143..491f4ab5df 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -44,6 +44,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py index fc68abc072..41bf2829fd 100644 --- a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py +++ b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py @@ -180,7 +180,10 @@ @pytest.fixture() def sagemaker_session(): - session = MagicMock(boto_region_name=REGION) + session = MagicMock( + boto_region_name=REGION, + default_bucket_prefix=None, + ) session.create_inference_recommendations_job.return_value = IR_JOB_NAME session.wait_for_inference_recommendations_job.return_value = IR_SAMPLE_INFERENCE_RESPONSE diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index 728c7e0c06..dc8cf393df 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ -17,7 +17,7 @@ import os from botocore.exceptions import ClientError from mock import Mock, patch -from tests.unit import DATA_DIR +from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION import sagemaker from sagemaker.workflow.parameters import ParameterString @@ -956,3 +956,78 @@ def test_start_undefined_pipeline(): with pytest.raises(ClientError) as e: LocalSession().sagemaker_client.start_pipeline_execution("UndefinedPipeline") assert "Pipeline UndefinedPipeline does not 
exist" in str(e.value) + + +def test_default_bucket_with_sagemaker_config(boto_session, client): + # common kwargs for Session objects + session_kwargs = { + "boto_session": boto_session, + } + + # Case 1: Use bucket from sagemaker_config + session_with_config_bucket = LocalSession( + default_bucket=None, + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert ( + session_with_config_bucket.default_bucket() + == SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"][ + "DefaultS3Bucket" + ] + ) + + # Case 2: Use bucket from user input to Session (even if sagemaker_config has a bucket) + session_with_user_bucket = LocalSession( + default_bucket="default-bucket", + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert session_with_user_bucket.default_bucket() == "default-bucket" + + # Case 3: Use default bucket of SDK + session_with_sdk_bucket = LocalSession( + default_bucket=None, + sagemaker_config=None, + **session_kwargs, + ) + session_with_sdk_bucket.boto_session.client.return_value = Mock( + get_caller_identity=Mock(return_value={"Account": "111111111"}) + ) + assert session_with_sdk_bucket.default_bucket() == "sagemaker-us-west-2-111111111" + + +def test_default_bucket_prefix_with_sagemaker_config(boto_session, client): + # common kwargs for Session objects + session_kwargs = { + "boto_session": boto_session, + } + + # Case 1: Use prefix from sagemaker_config + session_with_config_prefix = LocalSession( + default_bucket_prefix=None, + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert ( + session_with_config_prefix.default_bucket_prefix + == SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"][ + "DefaultS3ObjectKeyPrefix" + ] + ) + + # Case 2: Use prefix from user input to Session (even if sagemaker_config has a prefix) + session_with_user_prefix = LocalSession( + default_bucket_prefix="default-prefix", + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert session_with_user_prefix.default_bucket_prefix == "default-prefix" + + # Case 3: Neither the user input or config has the prefix + session_with_no_prefix = LocalSession( + default_bucket_prefix=None, + sagemaker_config=None, + **session_kwargs, + ) + assert session_with_no_prefix.default_bucket_prefix is None diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 3d0689fc73..e0ea24bacd 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -67,11 +67,15 @@ SHAP_BASELINE = '1,2,3,"good product"' CSV_MIME_TYPE = "text/csv" +BUCKET_NAME = "mybucket" @pytest.fixture def sagemaker_session(): - session = Mock() + session = Mock( + default_bucket_prefix=None, + ) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/model/test_edge.py b/tests/unit/sagemaker/model/test_edge.py index 9e2c10c586..dc2a2da42f 100644 --- a/tests/unit/sagemaker/model/test_edge.py +++ b/tests/unit/sagemaker/model/test_edge.py @@ -30,7 +30,10 @@ @pytest.fixture def sagemaker_session(): - session = Mock(boto_region_name=REGION) + session = Mock( + boto_region_name=REGION, + default_bucket_prefix=None, + ) # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} return session diff --git 
a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index 613ebefd64..f1c30083e8 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -92,6 +92,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index daa9d46763..f241221c00 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -17,6 +17,7 @@ from mock import Mock, patch import sagemaker +from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.model import FrameworkModel, Model from sagemaker.huggingface.model import HuggingFaceModel from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME @@ -27,7 +28,11 @@ from sagemaker.tensorflow.model import TensorFlowModel from sagemaker.xgboost.model import XGBoostModel from sagemaker.workflow.properties import Properties - +from tests.unit import ( + _test_default_bucket_and_prefix_combinations, + DEFAULT_S3_BUCKET_NAME, + DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, +) MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -107,6 +112,7 @@ def sagemaker_session(): local_mode=False, s3_client=None, s3_resource=None, + default_bucket_prefix=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config @@ -805,3 +811,113 @@ def test_model_local_download_dir(repack_model, sagemaker_session): repack_model.call_args_list[0][1]["sagemaker_session"].settings.local_download_dir == local_download_dir ) + + +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test__upload_code__default_bucket_and_prefix_combinations( + tar_and_upload_dir, +): + def with_user_input(sess): + model = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sess, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + code_location="s3://test-bucket/test-prefix/test-prefix-2", + ) + model._upload_code("upload-prefix/upload-prefix-2", repack=False) + kwargs = tar_and_upload_dir.call_args.kwargs + return kwargs["bucket"], kwargs["s3_key_prefix"] + + def without_user_input(sess): + model = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sess, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + model._upload_code("upload-prefix/upload-prefix-2", repack=False) + kwargs = tar_and_upload_dir.call_args.kwargs + return kwargs["bucket"], kwargs["s3_key_prefix"] + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + DEFAULT_S3_BUCKET_NAME, + f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/upload-prefix/upload-prefix-2", + ), + expected__without_user_input__with_default_bucket_only=( + DEFAULT_S3_BUCKET_NAME, + "upload-prefix/upload-prefix-2", + ), + expected__with_user_input__with_default_bucket_and_prefix=( + "test-bucket", + "upload-prefix/upload-prefix-2", + ), + expected__with_user_input__with_default_bucket_only=( + "test-bucket", + "upload-prefix/upload-prefix-2", + 
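# Note: an explicit code_location contributes only the bucket; the key prefix passed to _upload_code is used as-is. +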
), + ) + assert actual == expected + + +@patch("sagemaker.model.unique_name_from_base") +def test__build_default_async_inference_config__default_bucket_and_prefix_combinations( + unique_name_from_base, +): + unique_name_from_base.return_value = "unique-name" + + def with_user_input(sess): + model = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sess, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + code_location="s3://test-bucket/test-prefix/test-prefix-2", + ) + async_config = AsyncInferenceConfig( + output_path="s3://output-bucket/output-prefix/output-prefix-2", + failure_path="s3://failure-bucket/failure-prefix/failure-prefix-2", + ) + model._build_default_async_inference_config(async_config) + return async_config.output_path, async_config.failure_path + + def without_user_input(sess): + model = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sess, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + code_location="s3://test-bucket/test-prefix/test-prefix-2", + ) + async_config = AsyncInferenceConfig() + model._build_default_async_inference_config(async_config) + return async_config.output_path, async_config.failure_path + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/async-endpoint-outputs/unique-name", + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/async-endpoint-failures/unique-name", + ), + expected__without_user_input__with_default_bucket_only=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/async-endpoint-outputs/unique-name", + f"s3://{DEFAULT_S3_BUCKET_NAME}/async-endpoint-failures/unique-name", + ), + expected__with_user_input__with_default_bucket_and_prefix=( + "s3://output-bucket/output-prefix/output-prefix-2", + "s3://failure-bucket/failure-prefix/failure-prefix-2", + ), + expected__with_user_input__with_default_bucket_only=( + "s3://output-bucket/output-prefix/output-prefix-2", + "s3://failure-bucket/failure-prefix/failure-prefix-2", + ), + ) + assert actual == expected diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index a87a2b74f4..fb45866481 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -55,7 +55,9 @@ @pytest.fixture def sagemaker_session(): - session = Mock() + session = Mock( + default_bucket_prefix=None, + ) session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE ) diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index 5aa1468fd6..3cc73e0ee5 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -34,7 +34,10 @@ @pytest.fixture def sagemaker_session(): - session = Mock(boto_region_name=REGION) + session = Mock( + boto_region_name=REGION, + default_bucket_prefix=None, + ) # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} return session diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 59d3c9b727..7f3a122da0 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ 
b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -408,6 +408,7 @@ def sagemaker_session(sagemaker_client): boto_region_name="us-west-2", config=None, local_mode=False, + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value="mybucket") session_mock.upload_data = Mock( diff --git a/tests/unit/sagemaker/monitor/test_data_capture_config.py b/tests/unit/sagemaker/monitor/test_data_capture_config.py index 474c63f09a..f25b8782e0 100644 --- a/tests/unit/sagemaker/monitor/test_data_capture_config.py +++ b/tests/unit/sagemaker/monitor/test_data_capture_config.py @@ -54,7 +54,7 @@ def test_init_when_non_defaults_provided(): def test_init_when_optionals_not_provided(): - sagemaker_session = Mock() + sagemaker_session = Mock(default_bucket_prefix=None) sagemaker_session.default_bucket.return_value = DEFAULT_BUCKET_NAME sagemaker_session.sagemaker_config = {} diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 0a00a1b7cc..d9bc072d7d 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -434,6 +434,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session_mock.upload_data = Mock( diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index 686862bcc7..bd030963c2 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -104,6 +104,7 @@ def mock_session(): session.boto_region_name = TEST_REGION session.sagemaker_config = None session._append_sagemaker_config_tags.return_value = [] + session.default_bucket_prefix = None return session @@ -266,7 +267,7 @@ def test_start( assert job.job_name.startswith("job-function") - assert mock_stored_function.called_once_with( + mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, @@ -386,11 +387,11 @@ def test_start_with_complete_job_settings( assert job.job_name.startswith("job-function") - assert mock_stored_function.called_once_with( + mock_stored_function.assert_called_once_with( sagemaker_session=session(), s3_base_uri=f"{S3_URI}/{job.job_name}", hmac_key=HMAC_KEY, - s3_kms_key=None, + s3_kms_key=KMS_KEY_ARN, ) local_dependencies_path = mock_runtime_manager().snapshot() diff --git a/tests/unit/sagemaker/spark/test_processing.py b/tests/unit/sagemaker/spark/test_processing.py index 16583d33ae..21371d4594 100644 --- a/tests/unit/sagemaker/spark/test_processing.py +++ b/tests/unit/sagemaker/spark/test_processing.py @@ -59,6 +59,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session_mock.sagemaker_config = {} diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index e384f21f92..6654a04202 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -74,6 +74,7 @@ def sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) session.default_bucket = Mock(name="default_bucket", 
return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py index 1a8762ea5d..b8ec3af69d 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py @@ -36,6 +36,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + default_bucket_prefix=None, ) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_init.py b/tests/unit/sagemaker/tensorflow/test_estimator_init.py index 9f4ee47034..3ea09d5b10 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_init.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_init.py @@ -25,7 +25,11 @@ @pytest.fixture() def sagemaker_session(): - session_mock = Mock(name="sagemaker_session", boto_region_name=REGION) + session_mock = Mock( + name="sagemaker_session", + boto_region_name=REGION, + default_bucket_prefix=None, + ) session_mock.sagemaker_config = {} return session_mock diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index 2c8b9f3ff4..c9a50d161b 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -60,6 +60,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + default_bucket_prefix=None, ) session.default_bucket = Mock(name="default_bucket", return_value="my_bucket") session.expand_role = Mock(name="expand_role", return_value=ROLE) diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index cc8e2af0d2..96f6998af6 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -74,6 +74,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 852d7ee372..a650379dfd 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -72,6 +72,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 4d46fba62c..9a7ba698f3 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -73,6 +73,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py 
b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index a94229654e..67530bc288 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -80,6 +80,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index 32afc6d7b5..742320cfb8 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -38,6 +38,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + default_bucket_prefix=None, ) session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session._default_bucket = BUCKET_NAME diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index a9a6fb41c5..b6c17033ed 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -28,32 +28,13 @@ from sagemaker.workflow.functions import Join, JsonGet from tests.unit.sagemaker.workflow.helpers import CustomStep -from botocore.config import Config - -from tests.unit import DATA_DIR +from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION _REGION = "us-west-2" _ROLE = "DummyRole" _BUCKET = "my-bucket" -def test_pipeline_session_init(sagemaker_client_config, boto_session): - sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10))) - sagemaker_client = ( - boto_session.client("sagemaker", **sagemaker_client_config) - if sagemaker_client_config - else None - ) - - sess = PipelineSession( - boto_session=boto_session, - sagemaker_client=sagemaker_client, - ) - assert sess.sagemaker_client is not None - assert sess.default_bucket is not None - assert sess.context is None - - @pytest.fixture def client_mock(): """Mock client. @@ -94,6 +75,16 @@ def pipeline_session_mock(boto_session_mock, client_mock): ) +def test_pipeline_session_init(boto_session_mock, client_mock): + sess = PipelineSession( + boto_session=boto_session_mock, + sagemaker_client=client_mock, + ) + assert sess.sagemaker_client is not None + assert sess.default_bucket is not None + assert sess.context is None + + def test_pipeline_session_context_for_model_step(pipeline_session_mock): model = Model( name="MyModel", @@ -325,3 +316,80 @@ def test_pipeline_session_context_for_model_step_without_model_package_group_nam "inference_inferences and transform_instances " "must be provided if model_package_group_name is not present." 
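# The message above is compared verbatim with str(error), so the "inference_inferences" spelling must match whatever the SDK raises.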
== str(error) ) + + +def test_default_bucket_with_sagemaker_config(boto_session_mock, client_mock): + # common kwargs for Session objects + session_kwargs = { + "boto_session": boto_session_mock, + "sagemaker_client": client_mock, + } + + # Case 1: Use bucket from sagemaker_config + session_with_config_bucket = PipelineSession( + default_bucket=None, + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert ( + session_with_config_bucket.default_bucket() + == SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"][ + "DefaultS3Bucket" + ] + ) + + # Case 2: Use bucket from user input to Session (even if sagemaker_config has a bucket) + session_with_user_bucket = PipelineSession( + default_bucket="default-bucket", + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert session_with_user_bucket.default_bucket() == "default-bucket" + + # Case 3: Use default bucket of SDK + session_with_sdk_bucket = PipelineSession( + default_bucket=None, + sagemaker_config=None, + **session_kwargs, + ) + session_with_sdk_bucket.boto_session.client.return_value = Mock( + get_caller_identity=Mock(return_value={"Account": "111111111"}) + ) + assert session_with_sdk_bucket.default_bucket() == "sagemaker-us-west-2-111111111" + + +def test_default_bucket_prefix_with_sagemaker_config(boto_session_mock, client_mock): + # common kwargs for Session objects + session_kwargs = { + "boto_session": boto_session_mock, + "sagemaker_client": client_mock, + } + + # Case 1: Use prefix from sagemaker_config + session_with_config_prefix = PipelineSession( + default_bucket_prefix=None, + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert ( + session_with_config_prefix.default_bucket_prefix + == SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"][ + "DefaultS3ObjectKeyPrefix" + ] + ) + + # Case 2: Use prefix from user input to Session (even if sagemaker_config has a prefix) + session_with_user_prefix = PipelineSession( + default_bucket_prefix="default-prefix", + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert session_with_user_prefix.default_bucket_prefix == "default-prefix" + + # Case 3: Neither the user input nor the config has the prefix + session_with_no_prefix = PipelineSession( + default_bucket_prefix=None, + sagemaker_config=None, + **session_kwargs, + ) + assert session_with_no_prefix.default_bucket_prefix is None diff --git a/tests/unit/sagemaker/wrangler/test_processing.py b/tests/unit/sagemaker/wrangler/test_processing.py index 01cad17b76..37e13aff6e 100644 --- a/tests/unit/sagemaker/wrangler/test_processing.py +++ b/tests/unit/sagemaker/wrangler/test_processing.py @@ -38,6 +38,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) session_mock.expand_role.return_value = ROLE diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index f571f4cbc2..8b00b68dd9 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -24,11 +24,17 @@ FileSystemRecordSet, ) from sagemaker.session_settings import SessionSettings +from tests.unit import ( + DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, + DEFAULT_S3_BUCKET_NAME, + _test_default_bucket_and_prefix_combinations, +) COMMON_ARGS = {"role": "myrole", "instance_count": 1, "instance_type": "ml.c4.xlarge"} REGION = "us-west-2" BUCKET_NAME = "Some-Bucket" +DEFAULT_PREFIX_NAME = "Some-Prefix" TIMESTAMP = "2017-11-06-14:14:15.671" @@ -42,6
+48,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) @@ -143,6 +150,29 @@ def test_data_location_does_not_call_default_bucket(sagemaker_session): assert not sagemaker_session.default_bucket.called +def test_data_location_default_bucket_and_prefix_combinations(): + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=( + lambda sess: PCA( + num_components=2, + sagemaker_session=sess, + data_location="s3://test", + **COMMON_ARGS, + ).data_location + ), + function_without_user_input=( + lambda sess: PCA(num_components=2, sagemaker_session=sess, **COMMON_ARGS).data_location + ), + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/sagemaker-record-sets/" + ), + expected__without_user_input__with_default_bucket_only=f"s3://{DEFAULT_S3_BUCKET_NAME}/sagemaker-record-sets/", + expected__with_user_input__with_default_bucket_and_prefix="s3://test", + expected__with_user_input__with_default_bucket_only="s3://test", + ) + assert actual == expected + + def test_prepare_for_training(sagemaker_session): pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS) diff --git a/tests/unit/test_analytics.py b/tests/unit/test_analytics.py index 8cb90dbf46..2f29b29bf7 100644 --- a/tests/unit/test_analytics.py +++ b/tests/unit/test_analytics.py @@ -49,6 +49,7 @@ def create_sagemaker_session( config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index fdf29fef2e..8ae318cb83 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -64,6 +64,7 @@ def sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 4fe1ecb5f2..c2be0e3db2 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -785,6 +785,7 @@ def sagemaker_session(): boto_region_name="us-west-2", config=None, local_mode=False, + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value="mybucket") session_mock.upload_data = Mock( diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index f62c141958..cbfaf9cc15 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -59,6 +59,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + default_bucket_prefix=None, ) session_mock.create_group = Mock( diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index d2f12cc7a1..787a61768c 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -29,7 +29,13 @@ DeepSpeedModel, ) from sagemaker.djl_inference.model import DJLServingEngineEntryPointDefaults +from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings +from tests.unit import ( + _test_default_bucket_and_prefix_combinations, + DEFAULT_S3_BUCKET_NAME, + DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, +) 
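# A rough sketch of the s3_path_join behavior the expected URIs in these tests rely on
# (assuming, as in sagemaker.s3_utils, that None/empty segments are simply dropped):
#   s3_path_join("s3://", "bucket", None, "key")     -> "s3://bucket/key"
#   s3_path_join("s3://", "bucket", "prefix", "key") -> "s3://bucket/prefix/key"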
VALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model" INVALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz" @@ -58,8 +64,9 @@ def sagemaker_session(): settings=SessionSettings(), create_model=Mock(name="create_model"), endpoint_from_production_variants=Mock(name="endpoint_from_production_variants"), + default_bucket_prefix=None, ) - session.default_bucket = Mock(name="default_bucket", return_valie=BUCKET) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET) # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} @@ -592,4 +599,219 @@ def test_partition( IMAGE_URI, model_data_url="s3prefix", env=expected_env ) - assert model.model_id == f"{s3_output_uri}/s3prefix/aot-partitioned-checkpoints" + assert model.model_id == f"{s3_output_uri}s3prefix/aot-partitioned-checkpoints" + + +@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") +@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") +@patch("sagemaker.djl_inference.model.fw_utils.tar_and_upload_dir") +def test__upload_model_to_s3__with_upload_as_tar__default_bucket_and_prefix_combinations( + tar_and_upload_dir, + _get_model_config_properties_from_s3, + model_code_key_prefix, +): + # Skip appending of timestamps that this normally does + model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) + + def with_user_input(sess): + model = DJLModel( + VALID_UNCOMPRESSED_MODEL_DATA, + ROLE, + sagemaker_session=sess, + number_of_partitions=4, + data_type="fp16", + container_log_level=logging.DEBUG, + env=ENV, + code_location="s3://test-bucket/test-prefix/test-prefix-2", + image_uri="image_uri", + ) + model._upload_model_to_s3(upload_as_tar=True) + args = tar_and_upload_dir.call_args.args + return "s3://%s/%s" % (args[1], args[2]) + + def without_user_input(sess): + model = DJLModel( + VALID_UNCOMPRESSED_MODEL_DATA, + ROLE, + sagemaker_session=sess, + number_of_partitions=4, + data_type="fp16", + container_log_level=logging.DEBUG, + env=ENV, + image_uri="image_uri", + ) + model._upload_model_to_s3(upload_as_tar=True) + args = tar_and_upload_dir.call_args.args + return "s3://%s/%s" % (args[1], args[2]) + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri" + ), + expected__without_user_input__with_default_bucket_only=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri" + ), + expected__with_user_input__with_default_bucket_and_prefix=( + "s3://test-bucket/test-prefix/test-prefix-2/image_uri" + ), + expected__with_user_input__with_default_bucket_only=( + "s3://test-bucket/test-prefix/test-prefix-2/image_uri" + ), + ) + assert actual == expected + + +@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") +@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") +@patch("sagemaker.djl_inference.model.S3Uploader.upload") +def test__upload_model_to_s3__without_upload_as_tar__default_bucket_and_prefix_combinations( + upload, + _get_model_config_properties_from_s3, + model_code_key_prefix, +): + """This test is similar to test__upload_model_to_s3__with_upload_as_tar__default_bucket_and_prefix_combinations + + except upload_as_tar is False and S3Uploader.upload is checked + """ + + # Skip appending of timestamps that this normally 
does + model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) + + def with_user_input(sess): + model = DJLModel( + VALID_UNCOMPRESSED_MODEL_DATA, + ROLE, + sagemaker_session=sess, + number_of_partitions=4, + data_type="fp16", + container_log_level=logging.DEBUG, + env=ENV, + code_location="s3://test-bucket/test-prefix/test-prefix-2", + image_uri="image_uri", + ) + model._upload_model_to_s3(upload_as_tar=False) + args = upload.call_args.args + return args[1] + + def without_user_input(sess): + model = DJLModel( + VALID_UNCOMPRESSED_MODEL_DATA, + ROLE, + sagemaker_session=sess, + number_of_partitions=4, + data_type="fp16", + container_log_level=logging.DEBUG, + env=ENV, + image_uri="image_uri", + ) + model._upload_model_to_s3(upload_as_tar=False) + args = upload.call_args.args + return args[1] + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri/aot-model" + ), + expected__without_user_input__with_default_bucket_only=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri/aot-model" + ), + expected__with_user_input__with_default_bucket_and_prefix=( + "s3://test-bucket/test-prefix/test-prefix-2/image_uri/aot-model" + ), + expected__with_user_input__with_default_bucket_only=( + "s3://test-bucket/test-prefix/test-prefix-2/image_uri/aot-model" + ), + ) + assert actual == expected + + +@pytest.mark.parametrize( + ( + "code_location," + "expected__without_user_input__with_default_bucket_and_default_prefix, " + "expected__without_user_input__with_default_bucket_only, " + "expected__with_user_input__with_default_bucket_and_prefix, " + "expected__with_user_input__with_default_bucket_only" + ), + [ + ( + "s3://code-test-bucket/code-test-prefix/code-test-prefix-2", + "s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri", + "s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri", + "s3://test-bucket/test-prefix/test-prefix-2/code-test-prefix/code-test-prefix-2/image_uri", + "s3://test-bucket/test-prefix/test-prefix-2/code-test-prefix/code-test-prefix-2/image_uri", + ), + ( + None, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri", + f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri", + "s3://test-bucket/test-prefix/test-prefix-2/image_uri", + "s3://test-bucket/test-prefix/test-prefix-2/image_uri", + ), + ], +) +@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") +@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") +@patch("sagemaker.djl_inference.model.fw_utils.tar_and_upload_dir") +@patch("sagemaker.djl_inference.model._create_estimator") +def test_partition_default_bucket_and_prefix_combinations( + _create_estimator, + tar_and_upload_dir, + _get_model_config_properties_from_s3, + model_code_key_prefix, + code_location, + expected__without_user_input__with_default_bucket_and_default_prefix, + expected__without_user_input__with_default_bucket_only, + expected__with_user_input__with_default_bucket_and_prefix, + expected__with_user_input__with_default_bucket_only, +): + # Skip appending of timestamps that this normally does + model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) + + def with_user_input(sess): + model = DeepSpeedModel( + VALID_UNCOMPRESSED_MODEL_DATA, + ROLE, + sagemaker_session=sess, + data_type="fp16", + 
container_log_level=logging.DEBUG, + env=ENV, + code_location=code_location, + image_uri="image_uri", + ) + model.partition(GPU_INSTANCE, s3_output_uri="s3://test-bucket/test-prefix/test-prefix-2") + kwargs = _create_estimator.call_args.kwargs + return kwargs["s3_output_uri"] + + def without_user_input(sess): + model = DeepSpeedModel( + VALID_UNCOMPRESSED_MODEL_DATA, + ROLE, + sagemaker_session=sess, + data_type="fp16", + container_log_level=logging.DEBUG, + env=ENV, + code_location=code_location, + image_uri="image_uri", + ) + model.partition(GPU_INSTANCE) + kwargs = _create_estimator.call_args.kwargs + return kwargs["s3_output_uri"] + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + expected__without_user_input__with_default_bucket_and_default_prefix + ), + expected__without_user_input__with_default_bucket_only=expected__without_user_input__with_default_bucket_only, + expected__with_user_input__with_default_bucket_and_prefix=( + expected__with_user_input__with_default_bucket_and_prefix + ), + expected__with_user_input__with_default_bucket_only=expected__with_user_input__with_default_bucket_only, + ) + assert actual == expected diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index fa51ef6497..cd5eef371f 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -55,10 +55,16 @@ from sagemaker.tensorflow.estimator import TensorFlow from sagemaker.predictor_async import AsyncPredictor from sagemaker.transformer import Transformer +from sagemaker.workflow.execution_variables import ExecutionVariable from sagemaker.workflow.parameters import ParameterString, ParameterBoolean from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig from sagemaker.xgboost.estimator import XGBoost -from tests.unit import SAGEMAKER_CONFIG_TRAINING_JOB +from tests.unit import ( + SAGEMAKER_CONFIG_TRAINING_JOB, + _test_default_bucket_and_prefix_combinations, + DEFAULT_S3_BUCKET_NAME, + DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, +) MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -228,6 +234,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) sms.sagemaker_client.describe_training_job = Mock( @@ -4399,7 +4406,6 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( ] -@patch("time.time", return_value=TIME) @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") @@ -4539,7 +4545,6 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( ) -@patch("time.time", return_value=TIME) @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") @@ -4777,3 +4782,215 @@ def test_estimator_local_download_dir( patched_tar_and_upload_dir.call_args_list[0][1]["settings"].local_download_dir == local_download_dir ) + + +@pytest.mark.parametrize( + "input_key_prefix, input_current_job_name, input_pipeline_config, output_code_s3_prefix", + [ + ( + "my/prefix", + "job-name", + MOCKED_PIPELINE_CONFIG, + "my/prefix/test-pipeline/code/code-hash-0123456789", + ), + ("my/prefix", "job-name", None, "my/prefix/job-name/source"), + ("", "job-name", 
MOCKED_PIPELINE_CONFIG, "test-pipeline/code/code-hash-0123456789"), + ("", "job-name", None, "job-name/source"), + (None, "job-name", MOCKED_PIPELINE_CONFIG, "test-pipeline/code/code-hash-0123456789"), + (None, "job-name", None, "job-name/source"), + (None, None, MOCKED_PIPELINE_CONFIG, "test-pipeline/code/code-hash-0123456789"), + (None, None, None, "source"), + ], +) +def test_assign_s3_prefix( + sagemaker_session, + input_key_prefix, + input_current_job_name, + input_pipeline_config, + output_code_s3_prefix, +): + + with patch("sagemaker.workflow.utilities._pipeline_config", input_pipeline_config): + framework = DummyFramework( + "my_script.py", + role="DummyRole", + sagemaker_session=sagemaker_session, + ) + framework._current_job_name = input_current_job_name + assert framework._assign_s3_prefix(input_key_prefix) == output_code_s3_prefix + + +@patch("sagemaker.estimator._TrainingJob.start_new") +@patch("sagemaker.estimator.tar_and_upload_dir") +def test_output_path_default_bucket_and_prefix_combinations(start_new, tar_and_upload_dir): + def with_user_input(sess): + framework = DummyFramework( + "my_script.py", + role="DummyRole", + sagemaker_session=sess, + output_path="s3://test", + ) + framework.fit(None, job_name=JOB_NAME, wait=False, logs=True) + start_new.assert_called() # just to make sure this is patched with a mock + tar_and_upload_dir.assert_called() # just to make sure this is patched with a mock + return framework.output_path + + def without_user_input(sess): + framework = DummyFramework( + "my_script.py", + role="DummyRole", + sagemaker_session=sess, + ) + framework.fit(None, job_name=JOB_NAME, wait=False, logs=True) + start_new.assert_called() # just to make sure this is patched with a mock + tar_and_upload_dir.assert_called() # just to make sure this is patched with a mock + return framework.output_path + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/" + ), + expected__without_user_input__with_default_bucket_only=f"s3://{DEFAULT_S3_BUCKET_NAME}/", + expected__with_user_input__with_default_bucket_and_prefix="s3://test", + expected__with_user_input__with_default_bucket_only="s3://test", + ) + assert actual == expected + + +@patch("sagemaker.estimator.tar_and_upload_dir") +@pytest.mark.parametrize( + ( + "output_path, code_location," + "expected__without_user_input__with_default_bucket_and_default_prefix, " + "expected__without_user_input__with_default_bucket_only, " + "expected__with_user_input__with_default_bucket_and_prefix, " + "expected__with_user_input__with_default_bucket_only" + ), + [ + # Group of not-None output_bucket + ( + "s3://output-bucket/output-prefix/output-prefix2", + "s3://code-bucket/code-prefix/code-prefix2", + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ), + ( + "s3://output-bucket/output-prefix/output-prefix2", + None, + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ("output-bucket", f"{JOB_NAME}/source"), + ("output-bucket", f"{JOB_NAME}/source"), + ), + # Group of None output_bucket + ( 
+ None, + "s3://code-bucket/code-prefix/code-prefix2", + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ), + ( + None, + None, + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ), + # Group of PipelineVariable output_bucket + ( + ExecutionVariable("output_path"), + "s3://code-bucket/code-prefix/code-prefix2", + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ), + ( + ExecutionVariable("output_path"), + None, + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ), + # Group of file output_bucket + ( + "file://output-bucket/output-prefix/output-prefix2", + "s3://code-bucket/code-prefix/code-prefix2", + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ("code-bucket", f"code-prefix/code-prefix2/{JOB_NAME}/source"), + ), + ( + "file://output-bucket/output-prefix/output-prefix2", + None, + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/{JOB_NAME}/source"), + (DEFAULT_S3_BUCKET_NAME, f"{JOB_NAME}/source"), + ), + ], +) +def test_stage_user_code_in_s3_default_bucket_and_prefix_combinations( + tar_and_upload_dir, + output_path, + code_location, + expected__without_user_input__with_default_bucket_and_default_prefix, + expected__without_user_input__with_default_bucket_only, + expected__with_user_input__with_default_bucket_and_prefix, + expected__with_user_input__with_default_bucket_only, +): + def with_user_input(sess): + framework = DummyFramework( + "my_script.py", + role="DummyRole", + sagemaker_session=sess, + ) + + if output_path is not None: + framework.output_path = output_path + if code_location is not None: + framework.code_location = code_location + + # this method calls _stage_user_code_in_s3() + framework._prepare_for_training(job_name=JOB_NAME) + kwargs = tar_and_upload_dir.call_args.kwargs + return kwargs["bucket"], kwargs["s3_key_prefix"] + + def without_user_input(sess): + framework = DummyFramework( + "my_script.py", + role="DummyRole", + sagemaker_session=sess, + ) + + # this method calls _stage_user_code_in_s3() + framework._prepare_for_training(job_name=JOB_NAME) + kwargs = tar_and_upload_dir.call_args.kwargs + return kwargs["bucket"], kwargs["s3_key_prefix"] + + actual, expected = _test_default_bucket_and_prefix_combinations( + function_with_user_input=with_user_input, + function_without_user_input=without_user_input, + expected__without_user_input__with_default_bucket_and_default_prefix=( + 
expected__without_user_input__with_default_bucket_and_default_prefix + ), + expected__without_user_input__with_default_bucket_only=( + expected__without_user_input__with_default_bucket_only + ), + expected__with_user_input__with_default_bucket_and_prefix=( + expected__with_user_input__with_default_bucket_and_prefix + ), + expected__with_user_input__with_default_bucket_only=( + expected__with_user_input__with_default_bucket_only + ), + ) + assert actual == expected diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index 61f8079396..ebac0dfbb9 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -60,6 +60,7 @@ def sagemaker_session(): s3_client=False, s3_resource=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 6e311a8fdd..b0c699a7f9 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -48,6 +48,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value="my-bucket") session_mock.expand_role = Mock(name="expand_role", return_value="my-role") diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index 5fe3882ed3..da4b8a9477 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -57,6 +57,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index c151bc8174..c378901b84 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -77,7 +77,11 @@ def estimator(sagemaker_session): def sagemaker_session(): boto_mock = Mock(name="boto_session") mock_session = Mock( - name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None + name="sagemaker_session", + boto_session=boto_mock, + s3_client=None, + s3_resource=None, + default_bucket_prefix=None, ) mock_session.expand_role = Mock(name="expand_role", return_value=ROLE) # For tests which doesn't verify config file injection, operate with empty config diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index e966f4024c..3d91726478 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -54,6 +54,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 704d7a665f..0480d1891c 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -60,6 +60,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_lambda_helper.py b/tests/unit/test_lambda_helper.py index d6ad11beb7..48337c2205 100644 --- a/tests/unit/test_lambda_helper.py +++ b/tests/unit/test_lambda_helper.py @@ -42,6 +42,7 @@ def sagemaker_session(): config=None, local_mode=False, # default_bucket=S3_BUCKET, + default_bucket_prefix=None, ) return 
session_mock diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index f0adbc4d0a..f39df24d75 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -49,6 +49,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index 1d0d5d08dc..3e45d76784 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -55,6 +55,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_multidatamodel.py b/tests/unit/test_multidatamodel.py index 45216f0a62..e354c0240e 100644 --- a/tests/unit/test_multidatamodel.py +++ b/tests/unit/test_multidatamodel.py @@ -70,6 +70,7 @@ def sagemaker_session(): local_mode=False, s3_resource=None, s3_client=None, + default_bucket_prefix=None, ) session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index d5dd02ba05..0fc9e822fa 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -85,6 +85,7 @@ def sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 6b0339bf74..8db3688f84 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -54,6 +54,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index 9a5eac7931..d26faa4bb2 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -62,6 +62,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 4a00f1ea0d..1f9460293d 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -54,6 +54,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index 57a909a73d..b546d4e9e8 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -72,6 +72,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) # For tests which doesn't verify config file injection, operate with empty config diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index a9a7312f34..2a3616575d 100644 --- 
a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -39,7 +39,10 @@ def empty_sagemaker_session(): - ims = Mock(name="sagemaker_session") + ims = Mock( + name="sagemaker_session", + default_bucket_prefix=None, + ) ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) @@ -147,7 +150,10 @@ def test_multi_model_predict_call(): def json_sagemaker_session(): - ims = Mock(name="sagemaker_session") + ims = Mock( + name="sagemaker_session", + default_bucket_prefix=None, + ) ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) @@ -188,7 +194,10 @@ def test_predict_call_with_json(): def ret_csv_sagemaker_session(): - ims = Mock(name="sagemaker_session") + ims = Mock( + name="sagemaker_session", + default_bucket_prefix=None, + ) ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index 6754506680..1af21a36ff 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -43,7 +43,10 @@ def empty_sagemaker_session(): - ims = Mock(name="sagemaker_session") + ims = Mock( + name="sagemaker_session", + default_bucket_prefix=None, + ) ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) ims.sagemaker_runtime_client = Mock(name="sagemaker_runtime") ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index d6265fa0ac..55bae583fd 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -74,6 +74,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) @@ -99,6 +100,7 @@ def pipeline_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index e0c49ea328..d450238854 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -73,6 +73,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index daa6f8cacc..d9884f9cde 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -54,6 +54,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 2d74d3919e..06e6387dd1 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -66,6 +66,7 @@ def fixture_sagemaker_session(): s3_resource=None, s3_client=None, 
settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index 296ae21306..8ec8973b9a 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -24,6 +24,13 @@ SOURCE_NAME = "source" KMS_KEY = "kmskey" +SESSION_MOCK_WITH_PREFIX = Mock( + default_bucket=Mock(return_value="session_bucket"), default_bucket_prefix="session_prefix" +) +SESSION_MOCK_WITHOUT_PREFIX = Mock( + default_bucket=Mock(return_value="session_bucket"), default_bucket_prefix=None +) + @pytest.fixture() def sagemaker_session(): @@ -34,6 +41,7 @@ def sagemaker_session(): boto_region_name=REGION, config=None, local_mode=False, + default_bucket_prefix=None, ) session_mock.upload_data = Mock(name="upload_data", return_value="s3_uri_to_uploaded_data") session_mock.download_data = Mock(name="download_data") @@ -100,10 +108,20 @@ def test_download_with_kms_key(sagemaker_session): ) -def test_parse_s3_url(): - bucket, key_prefix = s3.parse_s3_url("s3://bucket/code_location") - assert "bucket" == bucket - assert "code_location" == key_prefix +@pytest.mark.parametrize( + "input_url, expected_bucket, expected_prefix", + [ + ("s3://bucket/code_location", "bucket", "code_location"), + ("s3://bucket/code_location/sub_location", "bucket", "code_location/sub_location"), + ("s3://bucket/code_location/sub_location/", "bucket", "code_location/sub_location/"), + ("s3://bucket/", "bucket", ""), + ("s3://bucket", "bucket", ""), + ], +) +def test_parse_s3_url(input_url, expected_bucket, expected_prefix): + bucket, key_prefix = s3.parse_s3_url(input_url) + assert bucket == expected_bucket + assert key_prefix == expected_prefix def test_parse_s3_url_fail(): @@ -112,15 +130,147 @@ def test_parse_s3_url_fail(): assert "Expecting 's3' scheme" in str(error) -def test_path_join(): - test_cases = ( +@pytest.mark.parametrize( + "expected_output, input_args", + [ + # simple cases + ("foo", ["foo"]), ("foo/bar", ("foo", "bar")), ("foo/bar", ("foo/", "bar")), ("foo/bar", ("/foo/", "bar")), + # ---------------- + # cases with s3:// ("s3://foo/bar", ("s3://", "foo", "bar")), ("s3://foo/bar", ("s3://", "/foo", "bar")), ("s3://foo/bar", ("s3://foo", "bar")), + ("s3://foo/bar/baz", ("s3://", "foo/bar/", "baz/")), + ("s3:", ["s3:"]), + ("s3:", ["s3:/"]), + ("s3://", ["s3://"]), + ("s3://", (["s3:////"])), + ("s3:", ("/", "s3://")), + ("s3://", ("s3://", "/")), + ("s/3/:", ("s", "3", ":", "/", "/")), + # ---------------- + # cases with empty or None + ("", []), + ("s3://foo/bar", ("s3://", "", "foo", "", "bar", "")), + ("s3://foo/bar", ("s3://", None, "foo", None, "bar", None)), + ("foo", (None, "foo")), + ("", ("", "", "")), + ("", ("")), + ("", ([None])), + ("", (None, None, None)), + # ---------------- + # cases with trailing slash + ("", ["/"]), + ("", ["/////"]), + ("foo", ["foo/"]), + ("foo", ["foo/////"]), + ("foo/bar", ("foo", "bar/")), + ("foo/bar", ("foo/", "bar/")), + ("foo/bar", ("/foo/", "bar/")), + # ---------------- + # cases with leading slashes + # (os.path.join and pathlib.PurePosixPath discard anything before the last leading slash) + ("foo/bar", ("/foo", "bar/")), + ("foo/bar", ("/////foo/", "bar/")), + ("foo", ("/", "foo")), + ("s3://foo/bar/baz", ("s3://", "foo", "/bar", "baz")), + ("s3://foo/bar/baz", ("s3://", "foo", "/bar", "/baz")), + # ---------------- + # cases with multiple slashes (note: multiple slashes are allowed by S3) + # (pathlib.PurePosixPath collapses multiple slashes to one) + 
("s3://foo/bar/baz", ("s3://", "foo////bar/////", "baz/")), + ("s3://foo/bar/baz", ("s3://", "foo////bar/", "/////baz/")), + # ---------------- + # cases with a dot + # (pathlib.PurePosixPath collapses some single dots) + ("f.oo/bar", ("f.oo", "bar")), + ("foo/.bar", ("foo", ".bar")), + ("foo/.bar", ("foo", "/.bar")), + ("foo./bar", ("foo.", "bar")), + ("foo/./bar", ("foo/.", "bar")), + ("foo/./bar", ("foo/./", "bar")), + ("foo/./bar", ["foo/./bar"]), + ( + "s3://foo/..././bar/..../.././baz", + ("s3://", "foo//..././bar/", "..../.././/baz/"), + ), + # ---------------- + # cases with 2 dots + ("f..oo/bar", ("f..oo", "bar")), + ("foo/..bar", ("foo", "..bar")), + ("foo/..bar", ("foo", "/..bar")), + ("foo../bar", ("foo..", "bar")), + ("foo/../bar", ("foo/..", "bar")), + ("foo/../bar", ("foo/../", "bar")), + ("foo/../bar", ["foo/../bar"]), + ], +) +def test_path_join(expected_output, input_args): + assert s3.s3_path_join(*input_args) == expected_output + + +@pytest.mark.parametrize( + "expected_output, input_args", + [ + ("foo/", ["foo"]), + ("foo/", ["foo///"]), + ("foo/bar/", ("foo", "bar")), + ("foo/bar/", ("foo/", "bar")), + ("foo/bar/", ("/foo/", "bar")), + ("s3://foo/bar/", ("s3://", "foo", "bar")), + ("s3://foo/bar/", ("s3://", "/foo", "bar")), + ("s3://foo/bar/", ("s3://foo", "bar")), + ("s3://foo/bar/baz/", ("s3://", "foo/bar/", "baz/")), + ("s3://foo/bar/", ("s3://", "", "foo", "", "bar", "")), + ("s3://foo/bar/", ("s3://", None, "foo", None, "bar", None)), + ("foo/", (None, "foo")), + ("", ("", "", "")), + ("", ("")), + ("", ("/")), + ("", ("///")), + ("", ([None])), + ("", (None, None, None)), + ("s3:/", ["s3:"]), + ("s3:/", ["s3:/"]), + ("s3://", ["s3://"]), + ("s3://", (["s3:////"])), + ("s3:/", ("/", "s3://")), + ("s3://", ("s3://", "/")), + ("s/3/:/", ("s", "3", ":", "/", "/")), + ], +) +def test_s3_path_join_with_end_slash(expected_output, input_args): + assert s3.s3_path_join(*input_args, with_end_slash=True) == expected_output + + +@pytest.mark.parametrize( + "input_bucket, input_prefix, input_session, expected_bucket, expected_prefix", + [ + ("input-bucket", None, None, "input-bucket", None), + ("input-bucket", "input-prefix", None, "input-bucket", "input-prefix"), + ("input-bucket", None, SESSION_MOCK_WITH_PREFIX, "input-bucket", None), + ("input-bucket", "input-prefix", SESSION_MOCK_WITH_PREFIX, "input-bucket", "input-prefix"), + (None, None, SESSION_MOCK_WITH_PREFIX, "session_bucket", "session_prefix"), + (None, None, SESSION_MOCK_WITHOUT_PREFIX, "session_bucket", ""), + ( + None, + "input-prefix", + SESSION_MOCK_WITH_PREFIX, + "session_bucket", + "session_prefix/input-prefix", + ), + (None, "input-prefix", SESSION_MOCK_WITHOUT_PREFIX, "session_bucket", "input-prefix"), + ], +) +def test_determine_bucket_and_prefix( + input_bucket, input_prefix, input_session, expected_bucket, expected_prefix +): + + actual_bucket, actual_prefix = s3.determine_bucket_and_prefix( + bucket=input_bucket, key_prefix=input_prefix, sagemaker_session=input_session ) - for expected, args in test_cases: - assert expected == s3.s3_path_join(*args) + assert (actual_bucket == expected_bucket) and (actual_prefix == expected_prefix) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6932088020..2a56e5f8df 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -50,6 +50,10 @@ SAGEMAKER_CONFIG_TRAINING_JOB, SAGEMAKER_CONFIG_TRANSFORM_JOB, SAGEMAKER_CONFIG_MODEL, + SAGEMAKER_CONFIG_SESSION, + _test_default_bucket_and_prefix_combinations, + 
DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, + DEFAULT_S3_BUCKET_NAME, ) STATIC_HPs = {"feature_dim": "784"} @@ -374,6 +378,87 @@ def test_create_process_with_sagemaker_config_injection(sagemaker_session): sagemaker_session.sagemaker_client.create_processing_job.assert_called_with(**expected_request) +def test_default_bucket_with_sagemaker_config(boto_session, client): + # common kwargs for Session objects + session_kwargs = { + "boto_session": boto_session, + "sagemaker_client": client, + "sagemaker_runtime_client": client, + "sagemaker_metrics_client": client, + } + + # Case 1: Use bucket from sagemaker_config + session_with_config_bucket = Session( + default_bucket=None, + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert ( + session_with_config_bucket.default_bucket() + == SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"][ + "DefaultS3Bucket" + ] + ) + + # Case 2: Use bucket from user input to Session (even if sagemaker_config has a bucket) + session_with_user_bucket = Session( + default_bucket="default-bucket", + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert session_with_user_bucket.default_bucket() == "default-bucket" + + # Case 3: Use default bucket of SDK + session_with_sdk_bucket = Session( + default_bucket=None, + sagemaker_config=None, + **session_kwargs, + ) + session_with_sdk_bucket.boto_session.client.return_value = Mock( + get_caller_identity=Mock(return_value={"Account": "111111111"}) + ) + assert session_with_sdk_bucket.default_bucket() == "sagemaker-us-west-2-111111111" + + +def test_default_bucket_prefix_with_sagemaker_config(boto_session, client): + # common kwargs for Session objects + session_kwargs = { + "boto_session": boto_session, + "sagemaker_client": client, + "sagemaker_runtime_client": client, + "sagemaker_metrics_client": client, + } + + # Case 1: Use prefix from sagemaker_config + session_with_config_prefix = Session( + default_bucket_prefix=None, + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert ( + session_with_config_prefix.default_bucket_prefix + == SAGEMAKER_CONFIG_SESSION["SageMaker"]["PythonSDK"]["Modules"]["Session"][ + "DefaultS3ObjectKeyPrefix" + ] + ) + + # Case 2: Use prefix from user input to Session (even if sagemaker_config has a prefix) + session_with_user_prefix = Session( + default_bucket_prefix="default-prefix", + sagemaker_config=SAGEMAKER_CONFIG_SESSION, + **session_kwargs, + ) + assert session_with_user_prefix.default_bucket_prefix == "default-prefix" + + # Case 3: Neither the user input or config has the prefix + session_with_no_prefix = Session( + default_bucket_prefix=None, + sagemaker_config=None, + **session_kwargs, + ) + assert session_with_no_prefix.default_bucket_prefix is None + + def mock_exists(filepath_to_mock, exists_result): unmocked_exists = os.path.exists @@ -2159,7 +2244,6 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ def test_create_model_with_sagemaker_config_injection(sagemaker_session): - sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_MODEL sagemaker_session.expand_role = Mock( @@ -3859,7 +3943,6 @@ def feature_group_dummy_definitions(): def test_feature_group_create_with_sagemaker_config_injection( sagemaker_session, feature_group_dummy_definitions ): - sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_FEATURE_GROUP sagemaker_session.create_feature_group( @@ -4723,3 +4806,108 @@ def sort(tags): {"Key": "tagkey5", "Value": "000"}, ] ) + + +@pytest.mark.parametrize( + ( + 
"file_path, user_input_params, " + "expected__without_user_input__with_default_bucket_and_default_prefix, " + "expected__without_user_input__with_default_bucket_only, " + "expected__with_user_input__with_default_bucket_and_prefix, " + "expected__with_user_input__with_default_bucket_only" + ), + [ + # Group with just bucket as user input + ( + "some/local/path/model.gz", + {"bucket": "input-bucket"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/model.gz", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/model.gz", + "s3://input-bucket/data/model.gz", + "s3://input-bucket/data/model.gz", + ), + ( + "some/local/path/dir", + {"bucket": "input-bucket"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/dir", + "s3://input-bucket/data/dir", + "s3://input-bucket/data/dir", + ), + # Group with both bucket and prefix as user input + ( + "some/local/path/model.gz", + {"bucket": "input-bucket", "key_prefix": "input-prefix"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/model.gz", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/model.gz", + "s3://input-bucket/input-prefix/model.gz", + "s3://input-bucket/input-prefix/model.gz", + ), + ( + "some/local/path/dir", + {"bucket": "input-bucket", "key_prefix": "input-prefix"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/dir", + "s3://input-bucket/input-prefix/dir", + "s3://input-bucket/input-prefix/dir", + ), + # Group with just prefix as user input + ( + "some/local/path/model.gz", + {"key_prefix": "input-prefix"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/model.gz", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/model.gz", + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/input-prefix/model.gz", + f"s3://{DEFAULT_S3_BUCKET_NAME}/input-prefix/model.gz", + ), + ( + "some/local/path/dir", + {"key_prefix": "input-prefix"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/input-prefix/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/input-prefix/dir", + ), + ( + "some/local/path/dir", + {"key_prefix": "input-prefix/longer/path/"}, + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/data/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/data/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/input-prefix/longer/path/dir", + f"s3://{DEFAULT_S3_BUCKET_NAME}/input-prefix/longer/path/dir", + ), + ], +) +def test_upload_data_default_bucket_and_prefix_combinations( + sagemaker_session, + file_path, + user_input_params, + expected__without_user_input__with_default_bucket_and_default_prefix, + expected__without_user_input__with_default_bucket_only, + expected__with_user_input__with_default_bucket_and_prefix, + expected__with_user_input__with_default_bucket_only, +): + sagemaker_session.s3_resource = Mock() + sagemaker_session._default_bucket = DEFAULT_S3_BUCKET_NAME + + session_with_bucket_and_prefix = copy.deepcopy(sagemaker_session) + session_with_bucket_and_prefix.default_bucket_prefix = DEFAULT_S3_OBJECT_KEY_PREFIX_NAME + + session_with_bucket_and_no_prefix = copy.deepcopy(sagemaker_session) + session_with_bucket_and_no_prefix.default_bucket_prefix = None + + actual, expected = _test_default_bucket_and_prefix_combinations( + 
session_with_bucket_and_prefix=session_with_bucket_and_prefix, + session_with_bucket_and_no_prefix=session_with_bucket_and_no_prefix, + function_with_user_input=(lambda sess: sess.upload_data(file_path, **user_input_params)), + function_without_user_input=(lambda sess: sess.upload_data(file_path)), + expected__without_user_input__with_default_bucket_and_default_prefix=( + expected__without_user_input__with_default_bucket_and_default_prefix + ), + expected__without_user_input__with_default_bucket_only=expected__without_user_input__with_default_bucket_only, + expected__with_user_input__with_default_bucket_and_prefix=( + expected__with_user_input__with_default_bucket_and_prefix + ), + expected__with_user_input__with_default_bucket_only=expected__with_user_input__with_default_bucket_only, + ) + assert actual == expected diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 6f0ce35319..9745c4ea26 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -68,6 +68,7 @@ def sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index 1e178dfc57..280703d321 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -42,6 +42,7 @@ def sagemaker_session(): config=None, local_mode=False, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 3611ebd4a2..4136f0aec2 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -65,7 +65,12 @@ def mock_create_tar_file(): @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session") - session = Mock(name="sagemaker_session", boto_session=boto_mock, local_mode=False) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + local_mode=False, + default_bucket_prefix=None, + ) # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} return session diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 8bd5127dd6..d8aa891a7e 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -58,6 +58,7 @@ def sagemaker_session(): s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.boto_region_name = REGION sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 4476b3b5ff..e54da2d862 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -71,6 +71,7 @@ def sagemaker_session(): s3_resource=None, s3_client=None, settings=SessionSettings(), + default_bucket_prefix=None, ) describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index b9fa4f2ff3..c7b1abcbb2 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -74,6 +74,9 @@ SAGEMAKER_SESSION = Mock() # For tests which doesn't verify config file injection, operate with empty config SAGEMAKER_SESSION.sagemaker_config = {} +SAGEMAKER_SESSION.default_bucket = Mock(return_value=BUCKET_NAME) +SAGEMAKER_SESSION.default_bucket_prefix = None 
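+
+# The `_test_default_bucket_and_prefix_combinations` helper that the tests in
+# this diff import from tests.unit is defined elsewhere; below is a plausible
+# sketch of its core, not the actual helper. It assumes the caller supplies the
+# two pre-built sessions (as tests/unit/test_session.py does above; the real
+# helper may construct defaults when they are omitted), and the function name
+# is hypothetical.
+def sketch_default_bucket_and_prefix_combinations(
+    function_with_user_input,
+    function_without_user_input,
+    session_with_bucket_and_prefix,
+    session_with_bucket_and_no_prefix,
+    expected__without_user_input__with_default_bucket_and_default_prefix,
+    expected__without_user_input__with_default_bucket_only,
+    expected__with_user_input__with_default_bucket_and_prefix,
+    expected__with_user_input__with_default_bucket_only,
+):
+    # Exercise all four combinations and hand back (actual, expected) tuples,
+    # so each caller can finish with a single `assert actual == expected`.
+    actual = (
+        function_without_user_input(session_with_bucket_and_prefix),
+        function_without_user_input(session_with_bucket_and_no_prefix),
+        function_with_user_input(session_with_bucket_and_prefix),
+        function_with_user_input(session_with_bucket_and_no_prefix),
+    )
+    expected = (
+        expected__without_user_input__with_default_bucket_and_default_prefix,
+        expected__without_user_input__with_default_bucket_only,
+        expected__with_user_input__with_default_bucket_and_prefix,
+        expected__with_user_input__with_default_bucket_only,
+    )
+    return actual, expected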
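+
+# For reference, the precedence encoded by the test_determine_bucket_and_prefix
+# cases in tests/unit/test_s3.py above can be written out as plain Python. This
+# is an illustration of the behavior those cases document, not the SDK's
+# implementation, and the function name is hypothetical.
+def sketch_determine_bucket_and_prefix(bucket, key_prefix, sagemaker_session):
+    from sagemaker import s3
+
+    if bucket:
+        # A caller-supplied bucket wins outright: the session-level prefix is
+        # ignored and the caller's key_prefix passes through unchanged.
+        return bucket, key_prefix
+    # Otherwise fall back to the session's default bucket, folding the
+    # session-level prefix (if any) in front of any caller-supplied key prefix;
+    # s3_path_join drops None/empty segments, so a session without a prefix
+    # yields "" when no key_prefix is given.
+    return (
+        sagemaker_session.default_bucket(),
+        s3.s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix),
+    )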
+ ESTIMATOR = Estimator( IMAGE_NAME,