diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 968dd8dc0f..0446a0b46c 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -154,7 +154,6 @@ def _get_function_arn(self): Method creates a lambda function and returns it's arn. If the lambda is already present, it will build it's arn and return that. """ - account_id = self.lambda_func.session.account_id() region = self.lambda_func.session.boto_region_name if region.lower() == "cn-north-1" or region.lower() == "cn-northwest-1": partition = "aws-cn" @@ -163,6 +162,7 @@ def _get_function_arn(self): if self.lambda_func.function_arn is None: try: + account_id = self.lambda_func.session.account_id() response = self.lambda_func.create() return response["FunctionArn"] except ValueError as error: diff --git a/tests/unit/sagemaker/workflow/test_lambda_step.py b/tests/unit/sagemaker/workflow/test_lambda_step.py index 4149f210da..0566e39318 100644 --- a/tests/unit/sagemaker/workflow/test_lambda_step.py +++ b/tests/unit/sagemaker/workflow/test_lambda_step.py @@ -16,7 +16,7 @@ import pytest -from mock import Mock +from mock import Mock, MagicMock from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -27,12 +27,13 @@ @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name="us-west-2") - session_mock = Mock( + session_mock = MagicMock( name="sagemaker_session", boto_session=boto_mock, boto_region_name="us-west-2", config=None, local_mode=False, + account_id=Mock(), ) return session_mock @@ -173,3 +174,36 @@ def test_lambda_step_no_inputs_outputs(sagemaker_session): "OutputParameters": [], "Arguments": {}, } + + +def test_lambda_step_with_function_arn(sagemaker_session): + lambda_step = LambdaStep( + name="MyLambdaStep", + depends_on=["TestStep"], + lambda_func=Lambda( + function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", + session=sagemaker_session, + ), + inputs={}, + outputs=[], + ) + lambda_step._get_function_arn() + sagemaker_session.account_id.assert_not_called() + + +def test_lambda_step_without_function_arn(sagemaker_session): + lambda_step = LambdaStep( + name="MyLambdaStep", + depends_on=["TestStep"], + lambda_func=Lambda( + function_name="name", + execution_role_arn="arn:aws:lambda:us-west-2:123456789012:execution_role", + zipped_code_dir="", + handler="", + session=sagemaker_session, + ), + inputs={}, + outputs=[], + ) + lambda_step._get_function_arn() + sagemaker_session.account_id.assert_called_once()