From ee4281b2fda74528088d31c82fd67bce5a0defa1 Mon Sep 17 00:00:00 2001 From: Gray Date: Tue, 7 May 2019 22:16:31 -0700 Subject: [PATCH 01/10] feature: support for STS regional endpoints --- src/sagemaker/session.py | 17 +++++++++++------ tests/unit/test_session.py | 8 ++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2ef34071fe..ae683a1cf7 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -168,7 +168,7 @@ def upload_data(self, path, bucket=None, key_prefix='data'): s3_uri = '{}/{}'.format(s3_uri, key_suffix) return s3_uri - def default_bucket(self): + def default_bucket(self, sts_endpoint_url=None): """Return the name of the default bucket to use in relevant Amazon SageMaker interactions. Returns: @@ -177,7 +177,7 @@ def default_bucket(self): if self._default_bucket: return self._default_bucket - account = self.boto_session.client('sts').get_caller_identity()['Account'] + account = self.boto_session.client('sts', endpoint_url=sts_endpoint_url).get_caller_identity()['Account'] region = self.boto_session.region_name default_bucket = 'sagemaker-{}-{}'.format(region, account) @@ -1084,12 +1084,16 @@ def expand_role(self, role): else: return self.boto_session.resource('iam').Role(role).arn - def get_caller_identity_arn(self): + def get_caller_identity_arn(self, sts_endpoint_url=None): """Returns the ARN user or role whose credentials are used to call the API. Returns: (str): The ARN user or role """ - assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn'] + + if not sts_endpoint_url: + assumed_role = self.boto_session.client('sts').get_caller_identity()['Arn'] + else: + assumed_role = self.boto_session.client('sts', endpoint_url=sts_endpoint_url).get_caller_identity()['Arn'] if 'AmazonSageMaker-ExecutionRole' in assumed_role: role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/service-role/\3', assumed_role) @@ -1294,7 +1298,7 @@ def production_variant(model_name, instance_type, initial_instance_count=1, vari return production_variant_configuration -def get_execution_role(sagemaker_session=None): +def get_execution_role(sagemaker_session=None, sts_endpoint_url=None): """Return the role ARN whose credentials are used to call the API. Throws an exception if Args: @@ -1304,7 +1308,7 @@ def get_execution_role(sagemaker_session=None): """ if not sagemaker_session: sagemaker_session = Session() - arn = sagemaker_session.get_caller_identity_arn() + arn = sagemaker_session.get_caller_identity_arn(sts_endpoint_url=sts_endpoint_url) if ':role/' in arn: return arn @@ -1383,6 +1387,7 @@ class ShuffleConfig(object): Used to configure channel shuffling using a seed. See SageMaker documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ + def __init__(self, seed): """ Create a ShuffleConfig. diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4f34bde068..dd932e73c1 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -53,6 +53,14 @@ def test_get_execution_role(): assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole' +def test_get_execution_role_with_sts_endpoint(): + session = Mock() + session.get_caller_identity_arn.return_value = 'arn:aws:iam::369233609183:role/SageMakerRole' + + actual = get_execution_role(session, sts_endpoint_url='https://sts.us-west-2.amazonaws.com') + assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole' + + def test_get_execution_role_works_with_service_role(): session = Mock() session.get_caller_identity_arn.return_value = \ From c9a6ba52bdbd62098258965a9a49b37e01ddc673 Mon Sep 17 00:00:00 2001 From: Gray Date: Tue, 7 May 2019 22:49:56 -0700 Subject: [PATCH 02/10] documentation: add STS endpoint documentation --- src/sagemaker/session.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ae683a1cf7..32045cd8f0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -171,6 +171,10 @@ def upload_data(self, path, bucket=None, key_prefix='data'): def default_bucket(self, sts_endpoint_url=None): """Return the name of the default bucket to use in relevant Amazon SageMaker interactions. + Args: + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). + Returns: str: The name of the default bucket, which is of the form: ``sagemaker-{region}-{AWS account ID}``. """ @@ -1086,6 +1090,11 @@ def expand_role(self, role): def get_caller_identity_arn(self, sts_endpoint_url=None): """Returns the ARN user or role whose credentials are used to call the API. + + Args: + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). + Returns: (str): The ARN user or role """ @@ -1303,6 +1312,9 @@ def get_execution_role(sagemaker_session=None, sts_endpoint_url=None): Throws an exception if Args: sagemaker_session(Session): Current sagemaker session + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). + Returns: (str): The role ARN """ From a3805c39340fe10cf0dcddcf8c95f6a92d8a518a Mon Sep 17 00:00:00 2001 From: Gray Date: Wed, 8 May 2019 11:48:27 -0700 Subject: [PATCH 03/10] change: fixed style for build --- src/sagemaker/session.py | 8 ++++---- tests/unit/test_session.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 32045cd8f0..cefac4ce89 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -172,7 +172,7 @@ def default_bucket(self, sts_endpoint_url=None): """Return the name of the default bucket to use in relevant Amazon SageMaker interactions. Args: - sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, the global endpoint will be used (sts.amazonaws.com). Returns: @@ -1092,9 +1092,9 @@ def get_caller_identity_arn(self, sts_endpoint_url=None): """Returns the ARN user or role whose credentials are used to call the API. Args: - sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, the global endpoint will be used (sts.amazonaws.com). - + Returns: (str): The ARN user or role """ @@ -1312,7 +1312,7 @@ def get_execution_role(sagemaker_session=None, sts_endpoint_url=None): Throws an exception if Args: sagemaker_session(Session): Current sagemaker session - sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, the global endpoint will be used (sts.amazonaws.com). Returns: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index dd932e73c1..d8b4327393 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -60,7 +60,7 @@ def test_get_execution_role_with_sts_endpoint(): actual = get_execution_role(session, sts_endpoint_url='https://sts.us-west-2.amazonaws.com') assert actual == 'arn:aws:iam::369233609183:role/SageMakerRole' - + def test_get_execution_role_works_with_service_role(): session = Mock() session.get_caller_identity_arn.return_value = \ From e760c22937ff23befbe3438fb22f8aab8c180fa9 Mon Sep 17 00:00:00 2001 From: Gray Date: Wed, 8 May 2019 15:40:28 -0700 Subject: [PATCH 04/10] change: session docstring fix --- src/sagemaker/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index cefac4ce89..1cd92c61e2 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -173,7 +173,7 @@ def default_bucket(self, sts_endpoint_url=None): Args: sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, - the global endpoint will be used (sts.amazonaws.com). + the global endpoint will be used (sts.amazonaws.com). Returns: str: The name of the default bucket, which is of the form: ``sagemaker-{region}-{AWS account ID}``. @@ -1093,7 +1093,7 @@ def get_caller_identity_arn(self, sts_endpoint_url=None): Args: sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, - the global endpoint will be used (sts.amazonaws.com). + the global endpoint will be used (sts.amazonaws.com). Returns: (str): The ARN user or role From 3a49ed46d3adcbd5e6fad2a6dffc8b1cf4b0738a Mon Sep 17 00:00:00 2001 From: Gray Date: Wed, 8 May 2019 15:45:26 -0700 Subject: [PATCH 05/10] change: session docstring fixes --- src/sagemaker/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 1cd92c61e2..3fc135c0f8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1312,8 +1312,8 @@ def get_execution_role(sagemaker_session=None, sts_endpoint_url=None): Throws an exception if Args: sagemaker_session(Session): Current sagemaker session - sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, - the global endpoint will be used (sts.amazonaws.com). + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + the global endpoint will be used (sts.amazonaws.com). Returns: (str): The role ARN From 8c0f99e96bd2de61849d0f88d4c1e188d758af5d Mon Sep 17 00:00:00 2001 From: Gray Date: Wed, 8 May 2019 15:49:26 -0700 Subject: [PATCH 06/10] change: docstring fixes --- src/sagemaker/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 3fc135c0f8..30ca727aa7 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1312,7 +1312,7 @@ def get_execution_role(sagemaker_session=None, sts_endpoint_url=None): Throws an exception if Args: sagemaker_session(Session): Current sagemaker session - sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, + sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, the global endpoint will be used (sts.amazonaws.com). Returns: From 6b636bb5f6c7d9af2fb42ecdf767b77e45624077 Mon Sep 17 00:00:00 2001 From: Gray Date: Thu, 9 May 2019 11:14:03 -0700 Subject: [PATCH 07/10] fix: fixed endpoint url for default_bucket --- src/sagemaker/session.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 30ca727aa7..372d9f587f 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -180,8 +180,10 @@ def default_bucket(self, sts_endpoint_url=None): """ if self._default_bucket: return self._default_bucket - - account = self.boto_session.client('sts', endpoint_url=sts_endpoint_url).get_caller_identity()['Account'] + if not sts_endpoint_url: + account = self.boto_session.client('sts').get_caller_identity()['Account'] + else: + account = self.boto_session.client('sts', endpoint_url=sts_endpoint_url).get_caller_identity()['Account'] region = self.boto_session.region_name default_bucket = 'sagemaker-{}-{}'.format(region, account) From ec846e989b9256b7d2b6b76178b7622779e3b892 Mon Sep 17 00:00:00 2001 From: Gray Date: Mon, 13 May 2019 16:29:36 -0700 Subject: [PATCH 08/10] change: docstring fixes --- src/sagemaker/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 372d9f587f..145e087bb2 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -173,7 +173,7 @@ def default_bucket(self, sts_endpoint_url=None): Args: sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, - the global endpoint will be used (sts.amazonaws.com). + the global endpoint will be used (sts.amazonaws.com). Should be in the format of sts..amazonaws.com Returns: str: The name of the default bucket, which is of the form: ``sagemaker-{region}-{AWS account ID}``. From 0978937df15dc68493c9cb118aaced8b0cd0e7f3 Mon Sep 17 00:00:00 2001 From: Gray Date: Mon, 13 May 2019 16:32:02 -0700 Subject: [PATCH 09/10] change: docstring linting --- src/sagemaker/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 145e087bb2..60be0a902e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -173,7 +173,8 @@ def default_bucket(self, sts_endpoint_url=None): Args: sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, - the global endpoint will be used (sts.amazonaws.com). Should be in the format of sts..amazonaws.com + the global endpoint will be used (sts.amazonaws.com). Should be in the format of + sts..amazonaws.com Returns: str: The name of the default bucket, which is of the form: ``sagemaker-{region}-{AWS account ID}``. From 963869e28a6d09d2f32f06172f34656a4c8521d9 Mon Sep 17 00:00:00 2001 From: Gray Date: Mon, 13 May 2019 20:47:58 -0700 Subject: [PATCH 10/10] change: docstring fix --- src/sagemaker/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 60be0a902e..a5415bf18c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -173,7 +173,7 @@ def default_bucket(self, sts_endpoint_url=None): Args: sts_endpoint_url (str): Optional. URL of STS endpoint to send requests to. If not specified, - the global endpoint will be used (sts.amazonaws.com). Should be in the format of + the global endpoint will be used (sts.amazonaws.com). Should be in the format of sts..amazonaws.com Returns: