diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2ef34071fe..a5415bf18c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -168,16 +168,23 @@ def upload_data(self, path, bucket=None, key_prefix='data'): s3_uri = '{}/{}'.format(s3_uri, key_suffix) return s3_uri - def default_bucket(self): + def default_bucket(self, sts_endpoint_url=None): """Return the name of the default bucket to use in relevant Amazon SageMaker interactions. + Args: + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). Should be in the format of + sts..amazonaws.com + Returns: str: The name of the default bucket, which is of the form: ``sagemaker-{region}-{AWS account ID}``. """ if self._default_bucket: return self._default_bucket - - account = self.boto_session.client('sts').get_caller_identity()['Account'] + if not sts_endpoint_url: + account = self.boto_session.client('sts').get_caller_identity()['Account'] + else: + account = self.boto_session.client('sts', endpoint_url=sts_endpoint_url).get_caller_identity()['Account'] region = self.boto_session.region_name default_bucket = 'sagemaker-{}-{}'.format(region, account) @@ -1084,12 +1091,21 @@ def expand_role(self, role): else: return self.boto_session.resource('iam').Role(role).arn - def get_caller_identity_arn(self): + def get_caller_identity_arn(self, sts_endpoint_url=None): """Returns the ARN user or role whose credentials are used to call the API. + + Args: + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). + Returns: (str): The ARN user or role """ - assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn'] + + if not sts_endpoint_url: + assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn'] + else: + assumed_role = self.boto_session.client('sts', endpoint_url=sts_endpoint_url).get_caller_identity()['Arn'] if 'AmazonSageMaker-ExecutionRole' in assumed_role: role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/service-role/\3', assumed_role) @@ -1294,17 +1310,20 @@ def production_variant(model_name, instance_type, initial_instance_count=1, vari return production_variant_configuration -def get_execution_role(sagemaker_session=None): +def get_execution_role(sagemaker_session=None, sts_endpoint_url=None): """Return the role ARN whose credentials are used to call the API. Throws an exception if Args: sagemaker_session(Session): Current sagemaker session + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). + Returns: (str): The role ARN """ if not sagemaker_session: sagemaker_session = Session() - arn = sagemaker_session.get_caller_identity_arn() + arn = sagemaker_session.get_caller_identity_arn(sts_endpoint_url=sts_endpoint_url) if ':role/' in arn: return arn @@ -1383,6 +1402,7 @@ class ShuffleConfig(object): Used to configure channel shuffling using a seed. See SageMaker documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ + def __init__(self, seed): """ Create a ShuffleConfig. diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4f34bde068..d8b4327393 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -53,6 +53,14 @@ def test_get_execution_role(): assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole' +def test_get_execution_role_with_sts_endpoint(): + session = Mock() + session.get_caller_identity_arn.return_value = 'arn:aws:iam::369233609183:role/SageMakerRole' + + actual = get_execution_role(session, sts_endpoint_url='https://sts.us-west-2.amazonaws.com') + assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole' + + def test_get_execution_role_works_with_service_role(): session = Mock() session.get_caller_identity_arn.return_value = \