From 5f6d8caed28690203a177e9b359b745b943c6302 Mon Sep 17 00:00:00 2001 From: Jay Goyani Date: Tue, 23 Jan 2024 09:10:34 -0800 Subject: [PATCH] fix: update get_execution_role_arn from metadata file if present --- src/sagemaker/session.py | 7 ++++++- tests/unit/test_session.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8f2753a7cf..b1342eb381 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -5412,6 +5412,7 @@ def get_caller_identity_arn(self): domain_id = metadata.get("DomainId") user_profile_name = metadata.get("UserProfileName") space_name = metadata.get("SpaceName") + execution_role_arn = metadata.get("ExecutionRoleArn") try: if domain_id is None: instance_desc = self.sagemaker_client.describe_notebook_instance( @@ -5419,7 +5420,11 @@ def get_caller_identity_arn(self): ) return instance_desc["RoleArn"] - # In Space app, find execution role from DefaultSpaceSettings on domain level + # find execution role from the metadata file if present + if execution_role_arn is not None: + return execution_role_arn + + # In Shared Space app, find execution role from DefaultSpaceSettings on domain level if space_name is not None: domain_desc = self.sagemaker_client.describe_domain(DomainId=domain_id) return domain_desc["DefaultSpaceSettings"]["ExecutionRole"] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6ee2cc9af5..93828d882f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -698,6 +698,25 @@ def test_fallback_to_domain_if_role_unavailable_in_user_settings(boto_session): sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j") +@patch( + "six.moves.builtins.open", + mock_open( + read_data='{"ResourceName": "SageMakerInstance", ' + '"DomainId": "d-kbnw5yk6tg8j", ' + '"ExecutionRoleArn": "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388", ' + '"SpaceName": "space_name"}' + ), +) +@patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True)) +def test_get_caller_identity_arn_from_metadata_file_for_space(boto_session): + sess = Session(boto_session) + expected_role = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388" + + actual = sess.get_caller_identity_arn() + + assert actual == expected_role + + @patch( "six.moves.builtins.open", mock_open(