diff --git a/src/sagemaker/lineage/context.py b/src/sagemaker/lineage/context.py index 57c0064eb2..aef919e876 100644 --- a/src/sagemaker/lineage/context.py +++ b/src/sagemaker/lineage/context.py @@ -490,3 +490,15 @@ def pipeline_execution_arn( return tag["Value"] return None + + +class ModelPackageGroup(Context): + """An Amazon SageMaker model package group context, which is part of a SageMaker lineage.""" + + def pipeline_execution_arn(self) -> str: + """Get the ARN for the pipeline execution associated with this model package group (if any). + + Returns: + str: A pipeline execution ARN. + """ + return self.properties.get("PipelineExecutionArn") diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 007174be84..672af41de9 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -667,6 +667,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn): ) +@pytest.fixture +def static_model_package_group_context(sagemaker_session, static_pipeline_execution_arn): + + model_package_group_arn = get_model_package_group_arn_from_static_pipeline(sagemaker_session) + + contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=model_package_group_arn)[ + "ContextSummaries" + ] + if len(contexts) != 1: + raise ( + Exception( + f"Got an unexpected number of Contexts for \ + model package group {STATIC_MODEL_PACKAGE_GROUP_NAME} from pipeline \ + execution {static_pipeline_execution_arn}. \ + Expected 1 but got {len(contexts)}" + ) + ) + + yield context.ModelPackageGroup.load( + contexts[0]["ContextName"], sagemaker_session=sagemaker_session + ) + + @pytest.fixture def static_model_artifact(sagemaker_session, static_pipeline_execution_arn): model_package_arn = get_model_package_arn_from_static_pipeline( @@ -745,6 +768,15 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session): raise e +def get_model_package_group_arn_from_static_pipeline(sagemaker_session): + static_model_package_group_arn = ( + sagemaker_session.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=STATIC_MODEL_PACKAGE_GROUP_NAME + )["ModelPackageGroupArn"] + ) + return static_model_package_group_arn + + def get_model_package_arn_from_static_pipeline(pipeline_execution_arn, sagemaker_session): # get the model package ARN from the pipeline pipeline_execution_steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps( diff --git a/tests/integ/sagemaker/lineage/test_model_package_group_context.py b/tests/integ/sagemaker/lineage/test_model_package_group_context.py new file mode 100644 index 0000000000..8f6cd85e77 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_model_package_group_context.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``ModelPackageGroup``""" +from __future__ import absolute_import + + +def test_pipeline_execution_arn(static_model_package_group_context, static_pipeline_execution_arn): + pipeline_execution_arn = static_model_package_group_context.pipeline_execution_arn() + + assert pipeline_execution_arn == static_pipeline_execution_arn diff --git a/tests/unit/sagemaker/lineage/test_model_package_group_context.py b/tests/unit/sagemaker/lineage/test_model_package_group_context.py new file mode 100644 index 0000000000..8c14773df7 --- /dev/null +++ b/tests/unit/sagemaker/lineage/test_model_package_group_context.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``ModelPackageGroup``""" +from __future__ import absolute_import + +import unittest.mock +import pytest +from sagemaker.lineage import context + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +def test_pipeline_execution_arn(sagemaker_session): + obj = context.ModelPackageGroup( + sagemaker_session, + context_name="foo", + description="test-description", + properties={"PipelineExecutionArn": "abcd", "k2": "v2"}, + properties_to_remove=["E"], + ) + actual_result = obj.pipeline_execution_arn() + expected_result = "abcd" + assert expected_result == actual_result