diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9eab2222df..854c431e87 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -342,40 +342,58 @@ def default_bucket(self): ).get_caller_identity()["Account"] default_bucket = "sagemaker-{}-{}".format(region, account) - s3 = self.boto_session.resource("s3") - try: - # 'us-east-1' cannot be specified because it is the default region: - # https://github.com/boto/boto3/issues/125 - if region == "us-east-1": - s3.create_bucket(Bucket=default_bucket) - else: - s3.create_bucket( - Bucket=default_bucket, CreateBucketConfiguration={"LocationConstraint": region} - ) - - LOGGER.info("Created S3 bucket: %s", default_bucket) - except ClientError as e: - error_code = e.response["Error"]["Code"] - message = e.response["Error"]["Message"] - - if error_code == "BucketAlreadyOwnedByYou": - pass - elif ( - error_code == "OperationAborted" and "conflicting conditional operation" in message - ): - # If this bucket is already being concurrently created, we don't need to create it - # again. - pass - elif error_code == "TooManyBuckets": - # Succeed if the default bucket exists - s3.meta.client.head_bucket(Bucket=default_bucket) - else: - raise + self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region) self._default_bucket = default_bucket return self._default_bucket + def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): + """Creates an S3 Bucket if it does not exist. + Also swallows a few common exceptions that indicate that the bucket already exists or + that it is being created. + + Args: + bucket_name (str): Name of the S3 bucket to be created. + region (str): The region in which to create the bucket. + + Raises: + botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket + creation. + If the exception is due to the bucket already existing or + already being created, no exception is raised. + + """ + bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name) + if bucket.creation_date is None: + try: + s3 = self.boto_session.resource("s3", region_name=region) + if region == "us-east-1": + # 'us-east-1' cannot be specified because it is the default region: + # https://github.com/boto/boto3/issues/125 + s3.create_bucket(Bucket=bucket_name) + else: + s3.create_bucket( + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} + ) + + LOGGER.info("Created S3 bucket: %s", bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + + if error_code == "BucketAlreadyOwnedByYou": + pass + elif ( + error_code == "OperationAborted" + and "conflicting conditional operation" in message + ): + # If this bucket is already being concurrently created, we don't need to create + # it again. + pass + else: + raise + def train( # noqa: C901 self, input_mode, diff --git a/tests/integ/kms_utils.py b/tests/integ/kms_utils.py index f41123c73c..95b0edc859 100644 --- a/tests/integ/kms_utils.py +++ b/tests/integ/kms_utils.py @@ -15,8 +15,6 @@ import contextlib import json -from botocore import exceptions - from sagemaker import utils PRINCIPAL_TEMPLATE = ( @@ -158,7 +156,8 @@ def get_or_create_kms_key( @contextlib.contextmanager -def bucket_with_encryption(boto_session, sagemaker_role): +def bucket_with_encryption(sagemaker_session, sagemaker_role): + boto_session = sagemaker_session.boto_session region = boto_session.region_name sts_client = boto_session.client( "sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region) @@ -173,22 +172,10 @@ def bucket_with_encryption(boto_session, sagemaker_role): region = boto_session.region_name bucket_name = "sagemaker-{}-{}-with-kms".format(region, account) - s3 = boto_session.client("s3") - try: - # 'us-east-1' cannot be specified because it is the default region: - # https://github.com/boto/boto3/issues/125 - if region == "us-east-1": - s3.create_bucket(Bucket=bucket_name) - else: - s3.create_bucket( - Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} - ) - - except exceptions.ClientError as e: - if e.response["Error"]["Code"] != "BucketAlreadyOwnedByYou": - raise - - s3.put_bucket_encryption( + sagemaker_session._create_s3_bucket_if_it_does_not_exist(bucket_name=bucket_name, region=region) + + s3_client = boto_session.client("s3", region_name=region) + s3_client.put_bucket_encryption( Bucket=bucket_name, ServerSideEncryptionConfiguration={ "Rules": [ @@ -202,7 +189,9 @@ def bucket_with_encryption(boto_session, sagemaker_role): }, ) - s3.put_bucket_policy(Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)) + s3_client.put_bucket_policy( + Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name) + ) yield "s3://" + bucket_name, kms_key_arn diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index d5facb371e..792f09b5bb 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -83,8 +83,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_ def test_server_side_encryption(sagemaker_session, tf_full_version): - boto_session = sagemaker_session.boto_session - with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key): + with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key): output_path = os.path.join( bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M") ) diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 6b72a21e90..50ceda7c56 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -26,8 +26,9 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} - ims = sagemaker.Session(boto_session=boto_mock) - return ims + sagemaker_session = sagemaker.Session(boto_session=boto_mock) + sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + return sagemaker_session def test_default_bucket_s3_create_call(sagemaker_session):