From 274802c6cffd2e7e7143e6e5d28e3f9a7f8ee871 Mon Sep 17 00:00:00 2001 From: yifeizhu Date: Mon, 10 Jan 2022 11:05:13 -0800 Subject: [PATCH] feature: Add support for SageMaker lineage queries in artifact, context and trial component --- src/sagemaker/lineage/action.py | 7 +- src/sagemaker/lineage/artifact.py | 133 ++++++++++++- src/sagemaker/lineage/context.py | 96 ++++++++- .../lineage/lineage_trial_component.py | 184 ++++++++++++++++++ src/sagemaker/lineage/query.py | 42 +++- tests/integ/sagemaker/lineage/conftest.py | 112 ++++++++++- .../integ/sagemaker/lineage/test_artifact.py | 19 ++ tests/integ/sagemaker/lineage/test_context.py | 11 ++ .../lineage/test_dataset_artifact.py | 16 ++ .../lineage/test_endpoint_context.py | 44 +++++ .../sagemaker/lineage/test_image_artifact.py | 26 +++ .../lineage/test_lineage_trial_component.py | 33 ++++ tests/unit/sagemaker/lineage/test_artifact.py | 140 +++++++++++++ tests/unit/sagemaker/lineage/test_context.py | 182 +++++++++++++++++ .../lineage/test_dataset_artifact.py | 86 ++++++++ .../sagemaker/lineage/test_image_artifact.py | 65 +++++++ .../lineage/test_lineage_trial_component.py | 153 +++++++++++++++ tests/unit/sagemaker/lineage/test_query.py | 80 +++++++- 18 files changed, 1400 insertions(+), 29 deletions(-) create mode 100644 src/sagemaker/lineage/lineage_trial_component.py create mode 100644 tests/integ/sagemaker/lineage/test_image_artifact.py create mode 100644 tests/integ/sagemaker/lineage/test_lineage_trial_component.py create mode 100644 tests/unit/sagemaker/lineage/test_image_artifact.py create mode 100644 tests/unit/sagemaker/lineage/test_lineage_trial_component.py diff --git a/src/sagemaker/lineage/action.py b/src/sagemaker/lineage/action.py index 1c8015a451..9046a3ccf2 100644 --- a/src/sagemaker/lineage/action.py +++ b/src/sagemaker/lineage/action.py @@ -16,12 +16,11 @@ from typing import Optional, Iterator, List from datetime import datetime -from sagemaker import Session +from sagemaker.session import Session from sagemaker.apiutils import _base_types from sagemaker.lineage import _api_types, _utils from sagemaker.lineage._api_types import ActionSource, ActionSummary from sagemaker.lineage.artifact import Artifact -from sagemaker.lineage.context import Context from sagemaker.lineage.query import ( LineageQuery, @@ -126,7 +125,7 @@ def delete(self, disassociate: bool = False): self._invoke_api(self._boto_delete_method, self._boto_delete_members) @classmethod - def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action": + def load(cls, action_name: str, sagemaker_session=None) -> "Action": """Load an existing action and return an ``Action`` object representing it. Args: @@ -324,7 +323,7 @@ def model_package(self): def endpoints( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS - ) -> List[Context]: + ) -> List: """Use a lineage query to retrieve downstream endpoint contexts that use this action. Args: diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index fc41808099..3921562beb 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact": return artifact def downstream_trials(self, sagemaker_session=None) -> list: - """Retrieve all trial runs which that use this artifact. + """Use the lineage API to retrieve all downstream trials that use this artifact. Args: - sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session + sagemaker_session (obj): Sagemaker Session to use. If not provided a default session will be created. Returns: @@ -159,6 +159,54 @@ def downstream_trials(self, sagemaker_session=None) -> list: ) trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations)) + return self._get_trial_from_trial_component(trial_component_arns) + + def downstream_trials_v2(self) -> list: + """Use a lineage query to retrieve all downstream trials that use this artifact. + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ + return self._trials(direction=LineageQueryDirectionEnum.DESCENDANTS) + + def upstream_trials(self) -> List: + """Use the lineage query to retrieve all upstream trials that use this artifact. + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ + return self._trials(direction=LineageQueryDirectionEnum.ASCENDANTS) + + def _trials( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH + ) -> List: + """Use the lineage query to retrieve all trials that use this artifact. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ + query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.artifact_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + trial_component_arns: list = list(map(lambda x: x.arn, query_result.vertices)) + return self._get_trial_from_trial_component(trial_component_arns) + + def _get_trial_from_trial_component(self, trial_component_arns: list) -> List: + """Retrieve all upstream trial runs which that use the trial component arns. + + Args: + trial_component_arns (list): list of trial component arns + + Returns: + [Trial]: A list of SageMaker `Trial` objects. + """ if not trial_component_arns: # no outgoing associations for this artifact return [] @@ -170,7 +218,7 @@ def downstream_trials(self, sagemaker_session=None) -> list: num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn) trial_components: list = [] - sagemaker_session = sagemaker_session or _utils.default_session() + sagemaker_session = self.sagemaker_session or _utils.default_session() sagemaker_client = sagemaker_session.sagemaker_client for i in range(num_search_batches): @@ -335,6 +383,17 @@ def list( sagemaker_session=sagemaker_session, ) + def s3_uri_artifacts(self, s3_uri: str) -> dict: + """Retrieve a list of artifacts that use provided s3 uri. + + Args: + s3_uri (str): A S3 URI. + + Returns: + A list of ``Artifacts`` + """ + return self.sagemaker_session.sagemaker_client.list_artifacts(SourceUri=s3_uri) + class ModelArtifact(Artifact): """A SageMaker lineage artifact representing a model. @@ -349,7 +408,7 @@ def endpoints(self) -> list: """Get association summaries for endpoints deployed with this model. Returns: - [AssociationSummary]: A list of associations repesenting the endpoints using the model. + [AssociationSummary]: A list of associations representing the endpoints using the model. """ endpoint_development_actions: Iterator = Association.list( source_arn=self.artifact_arn, @@ -522,3 +581,69 @@ def endpoint_contexts( for vertex in query_result.vertices: endpoint_contexts.append(vertex.to_lineage_object()) return endpoint_contexts + + def upstream_datasets(self) -> List[Artifact]: + """Use the lineage query to retrieve upstream artifacts that use this dataset artifact. + + Returns: + list of Artifacts: Artifacts representing an dataset. + """ + return self._datasets(direction=LineageQueryDirectionEnum.ASCENDANTS) + + def downstream_datasets(self) -> List[Artifact]: + """Use the lineage query to retrieve downstream artifacts that use this dataset. + + Returns: + list of Artifacts: Artifacts representing an dataset. + """ + return self._datasets(direction=LineageQueryDirectionEnum.DESCENDANTS) + + def _datasets( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH + ) -> List[Artifact]: + """Use the lineage query to retrieve all artifacts that use this dataset. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing an dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.artifact_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + +class ImageArtifact(Artifact): + """A SageMaker lineage artifact representing an image. + + Common model specific lineage traversals to discover how the image is connected + to other entities. + """ + + def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]: + """Use the lineage query to retrieve datasets that use this image artifact. + + Args: + direction (LineageQueryDirectionEnum): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.artifact_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] diff --git a/src/sagemaker/lineage/context.py b/src/sagemaker/lineage/context.py index 469b9aeb1a..57c0064eb2 100644 --- a/src/sagemaker/lineage/context.py +++ b/src/sagemaker/lineage/context.py @@ -31,6 +31,8 @@ LineageQueryDirectionEnum, ) from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.action import Action +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent class Context(_base_types.Record): @@ -256,12 +258,30 @@ def list( sagemaker_session=sagemaker_session, ) + def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]: + """Use the lineage query to retrieve actions that use this context. + + Args: + direction (LineageQueryDirectionEnum): The query direction. + + Returns: + list of Actions: Actions. + """ + query_filter = LineageFilter(entities=[LineageEntityEnum.ACTION]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + class EndpointContext(Context): """An Amazon SageMaker endpoint context, which is part of a SageMaker lineage.""" def models(self) -> List[association.Association]: - """Get all models deployed by all endpoint versions of the endpoint. + """Use Lineage API to get all models deployed by this endpoint. Returns: list of Associations: Associations that destination represents an endpoint's model. @@ -286,7 +306,7 @@ def models(self) -> List[association.Association]: def models_v2( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS ) -> List[Artifact]: - """Get artifacts representing models from the context lineage by querying lineage data. + """Use the lineage query to retrieve downstream model artifacts that use this endpoint. Args: direction (LineageQueryDirectionEnum, optional): The query direction. @@ -335,7 +355,7 @@ def models_v2( def dataset_artifacts( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS ) -> List[Artifact]: - """Get artifacts representing datasets from the endpoint's lineage. + """Use the lineage query to retrieve datasets that use this endpoint. Args: direction (LineageQueryDirectionEnum, optional): The query direction. @@ -360,6 +380,9 @@ def training_job_arns( ) -> List[str]: """Get ARNs for all training jobs that appear in the endpoint's lineage. + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + Returns: list of str: Training job ARNs. """ @@ -382,11 +405,78 @@ def training_job_arns( training_job_arns.append(trial_component["Source"]["SourceArn"]) return training_job_arns + def processing_jobs( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[LineageTrialComponent]: + """Use the lineage query to retrieve processing jobs that use this endpoint. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of LineageTrialComponent: Lineage trial component that represent Processing jobs. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def transform_jobs( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[LineageTrialComponent]: + """Use the lineage query to retrieve transform jobs that use this endpoint. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of LineageTrialComponent: Lineage trial component that represent Transform jobs. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def trial_components( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[LineageTrialComponent]: + """Use the lineage query to retrieve trial components that use this endpoint. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of LineageTrialComponent: Lineage trial component. + """ + query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT]) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.context_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] + def pipeline_execution_arn( self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS ) -> str: """Get the ARN for the pipeline execution associated with this endpoint (if any). + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + Returns: str: A pipeline execution ARN. """ diff --git a/src/sagemaker/lineage/lineage_trial_component.py b/src/sagemaker/lineage/lineage_trial_component.py new file mode 100644 index 0000000000..f8bc0e53b4 --- /dev/null +++ b/src/sagemaker/lineage/lineage_trial_component.py @@ -0,0 +1,184 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to create and manage SageMaker ``LineageTrialComponent``.""" +from __future__ import absolute_import + +import logging + +from typing import List + +from sagemaker.apiutils import _base_types +from sagemaker.lineage.query import ( + LineageQuery, + LineageFilter, + LineageSourceEnum, + LineageEntityEnum, + LineageQueryDirectionEnum, +) +from sagemaker.lineage.artifact import Artifact + + +LOGGER = logging.getLogger("sagemaker") + + +class LineageTrialComponent(_base_types.Record): + """An Amazon SageMaker, lineage trial component, which is part of a SageMaker lineage. + + A trial component is a stage in a trial. + Trial components are created automatically within the SageMaker runtime and also can be + created directly. To automatically associate trial components with a trial and experiment + supply an experiment config when creating a job. + For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html + + Attributes: + trial_component_name (str): The name of the trial component. Generated by SageMaker from the + name of the source job with a suffix specific to the type of source job. + trial_component_arn (str): The ARN of the trial component. + display_name (str): The name of the trial component that will appear in UI, + such as SageMaker Studio. + source (obj): A TrialComponentSource object with a source_arn attribute. + status (str): Status of the source job. + start_time (datetime): When the source job started. + end_time (datetime): When the source job ended. + creation_time (datetime): When the source job was created. + created_by (obj): Contextual info on which account created the trial component. + last_modified_time (datetime): When the trial component was last modified. + last_modified_by (obj): Contextual info on which account last modified the trial component. + parameters (dict): Dictionary of parameters to the source job. + input_artifacts (dict): Dictionary of input artifacts. + output_artifacts (dict): Dictionary of output artifacts. + metrics (obj): Aggregated metrics for the job. + parameters_to_remove (list): The hyperparameters to remove from the component. + input_artifacts_to_remove (list): The input artifacts to remove from the component. + output_artifacts_to_remove (list): The output artifacts to remove from the component. + tags (List[dict[str, str]]): A list of tags to associate with the trial component. + """ + + trial_component_name = None + trial_component_arn = None + display_name = None + source = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + parameters = None + input_artifacts = None + output_artifacts = None + metrics = None + parameters_to_remove = None + input_artifacts_to_remove = None + output_artifacts_to_remove = None + tags = None + + _boto_create_method: str = "create_trial_component" + _boto_load_method: str = "describe_trial_component" + _boto_update_method: str = "update_trial_component" + _boto_delete_method: str = "delete_trial_component" + + _boto_update_members = [ + "trial_component_name", + "display_name", + "status", + "start_time", + "end_time", + "parameters", + "input_artifacts", + "output_artifacts", + "parameters_to_remove", + "input_artifacts_to_remove", + "output_artifacts_to_remove", + ] + _boto_delete_members = ["trial_component_name"] + + @classmethod + def load(cls, trial_component_name: str, sagemaker_session=None) -> "LineageTrialComponent": + """Load an existing trial component and return an ``TrialComponent`` object representing it. + + Args: + trial_component_name (str): Name of the trial component + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + LineageTrialComponent: A SageMaker ``LineageTrialComponent`` object + """ + trial_component = cls._construct( + cls._boto_load_method, + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + return trial_component + + def pipeline_execution_arn(self) -> str: + """Get the ARN for the pipeline execution associated with this trial component (if any). + + Returns: + str: A pipeline execution ARN. + """ + tags = self.sagemaker_session.sagemaker_client.list_tags( + ResourceArn=self.trial_component_arn + )["Tags"] + for tag in tags: + if tag["Key"] == "sagemaker:pipeline-execution-arn": + return tag["Value"] + return None + + def dataset_artifacts( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS + ) -> List[Artifact]: + """Use the lineage query to retrieve datasets that use this trial component. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.trial_component_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + + return [vertex.to_lineage_object() for vertex in query_result.vertices] + + def models( + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS + ) -> List[Artifact]: + """Use the lineage query to retrieve models that use this trial component. + + Args: + direction (LineageQueryDirectionEnum, optional): The query direction. + + Returns: + list of Artifacts: Artifacts representing a dataset. + """ + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL] + ) + query_result = LineageQuery(self.sagemaker_session).query( + start_arns=[self.trial_component_arn], + query_filter=query_filter, + direction=direction, + include_edges=False, + ) + return [vertex.to_lineage_object() for vertex in query_result.vertices] diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index ecb48e3661..a54331c39a 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -23,6 +23,7 @@ class LineageEntityEnum(Enum): """Enum of lineage entities for use in a query filter.""" + TRIAL = "Trial" ACTION = "Action" ARTIFACT = "Artifact" CONTEXT = "Context" @@ -44,6 +45,8 @@ class LineageSourceEnum(Enum): TENSORBOARD = "TensorBoard" TRAINING_JOB = "TrainingJob" APPROVAL = "Approval" + PROCESSING_JOB = "ProcessingJob" + TRANSFORM_JOB = "TransformJob" class LineageQueryDirectionEnum(Enum): @@ -128,11 +131,15 @@ def __eq__(self, other): ) def to_lineage_object(self): - """Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object.""" - from sagemaker.lineage.artifact import Artifact, ModelArtifact + """Convert the ``Vertex`` object to its corresponding lineage object. + + Returns: + A ``Vertex`` object to its corresponding ``Artifact``,``Action``, ``Context`` + or ``TrialComponent`` object. + """ from sagemaker.lineage.context import Context, EndpointContext - from sagemaker.lineage.artifact import DatasetArtifact from sagemaker.lineage.action import Action + from sagemaker.lineage.lineage_trial_component import LineageTrialComponent if self.lineage_entity == LineageEntityEnum.CONTEXT.value: resource_name = get_resource_name_from_arn(self.arn) @@ -143,17 +150,31 @@ def to_lineage_object(self): return Context.load(context_name=resource_name, sagemaker_session=self._session) if self.lineage_entity == LineageEntityEnum.ARTIFACT.value: - if self.lineage_source == LineageSourceEnum.MODEL.value: - return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) - if self.lineage_source == LineageSourceEnum.DATASET.value: - return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) - return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + return self._artifact_to_lineage_object() if self.lineage_entity == LineageEntityEnum.ACTION.value: return Action.load(action_name=self.arn.split("/")[1], sagemaker_session=self._session) + if self.lineage_entity == LineageEntityEnum.TRIAL_COMPONENT.value: + trial_component_name = get_resource_name_from_arn(self.arn) + return LineageTrialComponent.load( + trial_component_name=trial_component_name, sagemaker_session=self._session + ) raise ValueError("Vertex cannot be converted to a lineage object.") + def _artifact_to_lineage_object(self): + """Convert the ``Vertex`` object to its corresponding ``Artifact``.""" + from sagemaker.lineage.artifact import Artifact, ModelArtifact, ImageArtifact + from sagemaker.lineage.artifact import DatasetArtifact + + if self.lineage_source == LineageSourceEnum.MODEL.value: + return ModelArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + if self.lineage_source == LineageSourceEnum.DATASET.value: + return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + if self.lineage_source == LineageSourceEnum.IMAGE.value: + return ImageArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + class LineageQueryResult(object): """A wrapper around the results of a lineage query.""" @@ -242,9 +263,12 @@ def _get_edge(self, edge): def _get_vertex(self, vertex): """Convert lineage query API response to a Vertex.""" + vertex_type = None + if "Type" in vertex: + vertex_type = vertex["Type"] return Vertex( arn=vertex["Arn"], - lineage_source=vertex["Type"], + lineage_source=vertex_type, lineage_entity=vertex["LineageType"], sagemaker_session=self._session, ) diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 863ab62183..007174be84 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -25,13 +25,6 @@ association, artifact, ) -from sagemaker.lineage.query import ( - LineageFilter, - LineageEntityEnum, - LineageSourceEnum, - LineageQuery, - LineageQueryDirectionEnum, -) from sagemaker.model import ModelPackage from tests.integ.test_workflow import test_end_to_end_pipeline_successful_execution from sagemaker.workflow.pipeline import _PipelineExecution @@ -39,6 +32,14 @@ from smexperiments import trial_component, trial, experiment from random import randint from botocore.exceptions import ClientError +from sagemaker.lineage.query import ( + LineageQuery, + LineageFilter, + LineageSourceEnum, + LineageEntityEnum, + LineageQueryDirectionEnum, +) +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from tests.integ.sagemaker.lineage.helpers import name, names @@ -46,6 +47,7 @@ SLEEP_TIME_TWO_SECONDS = 2 STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17" STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17" +STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline17ModelPackageGroup" @pytest.fixture @@ -214,6 +216,24 @@ def trial_associated_artifact(artifact_obj, trial_obj, trial_component_obj, sage sagemaker_session=sagemaker_session, ) trial_obj.add_trial_component(trial_component_obj) + time.sleep(4) + yield artifact_obj + trial_obj.remove_trial_component(trial_component_obj) + assntn.delete() + + +@pytest.fixture +def upstream_trial_associated_artifact( + artifact_obj, trial_obj, trial_component_obj, sagemaker_session +): + assntn = association.Association.create( + source_arn=trial_component_obj.trial_component_arn, + destination_arn=artifact_obj.artifact_arn, + association_type="ContributedTo", + sagemaker_session=sagemaker_session, + ) + trial_obj.add_trial_component(trial_component_obj) + time.sleep(3) yield artifact_obj trial_obj.remove_trial_component(trial_component_obj) assntn.delete() @@ -557,6 +577,67 @@ def static_model_deployment_action(sagemaker_session, static_endpoint_context): yield model_approval_actions[0] +@pytest.fixture +def static_processing_job_trial_component( + sagemaker_session, static_endpoint_context +) -> LineageTrialComponent: + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB] + ) + + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_endpoint_context.context_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + processing_jobs = [] + for vertex in query_result.vertices: + processing_jobs.append(vertex.to_lineage_object()) + + return processing_jobs[0] + + +@pytest.fixture +def static_training_job_trial_component( + sagemaker_session, static_endpoint_context +) -> LineageTrialComponent: + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB] + ) + + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_endpoint_context.context_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + training_jobs = [] + for vertex in query_result.vertices: + training_jobs.append(vertex.to_lineage_object()) + + return training_jobs[0] + + +@pytest.fixture +def static_transform_job_trial_component( + static_processing_job_trial_component, sagemaker_session, static_endpoint_context +) -> LineageTrialComponent: + query_filter = LineageFilter( + entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB] + ) + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_processing_job_trial_component.trial_component_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.DESCENDANTS, + include_edges=False, + ) + transform_jobs = [] + for vertex in query_result.vertices: + transform_jobs.append(vertex.to_lineage_object()) + yield transform_jobs[0] + + @pytest.fixture def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn): endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session) @@ -633,6 +714,23 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session): ) +@pytest.fixture +def static_image_artifact(static_model_artifact, sagemaker_session): + query_filter = LineageFilter( + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.IMAGE] + ) + query_result = LineageQuery(sagemaker_session).query( + start_arns=[static_model_artifact.artifact_arn], + query_filter=query_filter, + direction=LineageQueryDirectionEnum.ASCENDANTS, + include_edges=False, + ) + image_artifact = [] + for vertex in query_result.vertices: + image_artifact.append(vertex.to_lineage_object()) + return image_artifact[0] + + def get_endpoint_arn_from_static_pipeline(sagemaker_session): try: endpoint_arn = sagemaker_session.sagemaker_client.describe_endpoint( diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py index 4a0c6398b2..7ecbd0ac15 100644 --- a/tests/integ/sagemaker/lineage/test_artifact.py +++ b/tests/integ/sagemaker/lineage/test_artifact.py @@ -102,6 +102,13 @@ def test_list_by_type(artifact_objs, sagemaker_session): assert artifact_names_listed[0] == expected_name +def test_get_artifact(static_dataset_artifact): + s3_uri = static_dataset_artifact.source.source_uri + expected_artifact = static_dataset_artifact.s3_uri_artifacts(s3_uri=s3_uri) + for ar in expected_artifact["ArtifactSummaries"]: + assert ar.get("Source")["SourceUri"] == s3_uri + + def test_downstream_trials(trial_associated_artifact, trial_obj, sagemaker_session): # allow trial components to index, 30 seconds max def validate(): @@ -120,6 +127,18 @@ def validate(): retry(validate, num_attempts=3) +def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session): + trials = trial_associated_artifact.downstream_trials_v2() + assert len(trials) == 1 + assert trial_obj.trial_name in trials + + +def test_upstream_trials(upstream_trial_associated_artifact, trial_obj, sagemaker_session): + trials = upstream_trial_associated_artifact.upstream_trials() + assert len(trials) == 1 + assert trial_obj.trial_name in trials + + @pytest.mark.timeout(30) def test_tag(artifact_obj, sagemaker_session): tag = {"Key": "foo", "Value": "bar"} diff --git a/tests/integ/sagemaker/lineage/test_context.py b/tests/integ/sagemaker/lineage/test_context.py index 5b36cee746..bdc4cb34e3 100644 --- a/tests/integ/sagemaker/lineage/test_context.py +++ b/tests/integ/sagemaker/lineage/test_context.py @@ -20,6 +20,7 @@ import pytest from sagemaker.lineage import context +from sagemaker.lineage.query import LineageQueryDirectionEnum def test_create_delete(context_obj): @@ -32,6 +33,16 @@ def test_create_delete_with_association(context_obj_with_association): assert context_obj_with_association.context_arn +def test_action(static_endpoint_context, sagemaker_session): + actions_from_query = static_endpoint_context.actions( + direction=LineageQueryDirectionEnum.ASCENDANTS + ) + + assert len(actions_from_query) > 0 + for action in actions_from_query: + assert "action" in action.action_arn + + def test_save(context_obj, sagemaker_session): context_obj.description = "updated description" context_obj.properties = {"k3": "v3"} diff --git a/tests/integ/sagemaker/lineage/test_dataset_artifact.py b/tests/integ/sagemaker/lineage/test_dataset_artifact.py index be03a85e86..ee81b7e137 100644 --- a/tests/integ/sagemaker/lineage/test_dataset_artifact.py +++ b/tests/integ/sagemaker/lineage/test_dataset_artifact.py @@ -35,3 +35,19 @@ def test_endpoint_contexts( assert len(contexts_from_query) > 0 for context in contexts_from_query: assert context.context_type == "Endpoint" + + +def test_get_upstream_datasets(static_dataset_artifact, sagemaker_session): + artifacts_from_query = static_dataset_artifact.upstream_datasets() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + assert "artifact" in artifact.artifact_arn + + +def test_get_down_datasets(static_dataset_artifact, sagemaker_session): + artifacts_from_query = static_dataset_artifact.downstream_datasets() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + assert "artifact" in artifact.artifact_arn diff --git a/tests/integ/sagemaker/lineage/test_endpoint_context.py b/tests/integ/sagemaker/lineage/test_endpoint_context.py index 78a33e8ef9..2a797bd5cb 100644 --- a/tests/integ/sagemaker/lineage/test_endpoint_context.py +++ b/tests/integ/sagemaker/lineage/test_endpoint_context.py @@ -15,6 +15,7 @@ import time SLEEP_TIME_ONE_SECONDS = 1 +SLEEP_TIME_THREE_SECONDS = 3 def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj): @@ -59,3 +60,46 @@ def test_pipeline_execution_arn(static_endpoint_context, static_pipeline_executi pipeline_execution_arn = static_endpoint_context.pipeline_execution_arn() assert pipeline_execution_arn == static_pipeline_execution_arn + + +def test_transform_jobs( + sagemaker_session, static_transform_job_trial_component, static_endpoint_context +): + sagemaker_session.sagemaker_client.add_association( + SourceArn=static_transform_job_trial_component.trial_component_arn, + DestinationArn=static_endpoint_context.context_arn, + AssociationType="ContributedTo", + ) + time.sleep(SLEEP_TIME_THREE_SECONDS) + transform_jobs_from_query = static_endpoint_context.transform_jobs() + + assert len(transform_jobs_from_query) > 0 + for transform_job in transform_jobs_from_query: + assert "transform-job" in transform_job.trial_component_arn + assert "TransformJob" in transform_job.source.get("SourceType") + + sagemaker_session.sagemaker_client.delete_association( + SourceArn=static_transform_job_trial_component.trial_component_arn, + DestinationArn=static_endpoint_context.context_arn, + ) + + +def test_processing_jobs( + sagemaker_session, static_transform_job_trial_component, static_endpoint_context +): + processing_jobs_from_query = static_endpoint_context.processing_jobs() + assert len(processing_jobs_from_query) > 0 + for processing_job in processing_jobs_from_query: + assert "processing-job" in processing_job.trial_component_arn + assert "ProcessingJob" in processing_job.source.get("SourceType") + + +def test_trial_components( + sagemaker_session, static_transform_job_trial_component, static_endpoint_context +): + trial_components_from_query = static_endpoint_context.trial_components() + + assert len(trial_components_from_query) > 0 + for trial_component in trial_components_from_query: + assert "job" in trial_component.trial_component_arn + assert "Job" in trial_component.source.get("SourceType") diff --git a/tests/integ/sagemaker/lineage/test_image_artifact.py b/tests/integ/sagemaker/lineage/test_image_artifact.py new file mode 100644 index 0000000000..bd0f76445d --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_image_artifact.py @@ -0,0 +1,26 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``ImageArtifact``""" +from __future__ import absolute_import + +from sagemaker.lineage.query import LineageQueryDirectionEnum + + +def test_dataset(static_image_artifact, sagemaker_session): + artifacts_from_query = static_image_artifact.datasets( + direction=LineageQueryDirectionEnum.DESCENDANTS + ) + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + assert "artifact" in artifact.artifact_arn diff --git a/tests/integ/sagemaker/lineage/test_lineage_trial_component.py b/tests/integ/sagemaker/lineage/test_lineage_trial_component.py new file mode 100644 index 0000000000..d8a8a5d9c8 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_lineage_trial_component.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``Trial Component``""" +from __future__ import absolute_import + + +def test_dataset_artifacts(static_training_job_trial_component): + artifacts_from_query = static_training_job_trial_component.dataset_artifacts() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "DataSet" + + +def test_models(static_processing_job_trial_component): + artifacts_from_query = static_processing_job_trial_component.models() + assert len(artifacts_from_query) > 0 + for artifact in artifacts_from_query: + assert artifact.artifact_type == "Model" + + +def test_pipeline_execution_arn(static_training_job_trial_component, static_pipeline_execution_arn): + pipeline_execution_arn = static_training_job_trial_component.pipeline_execution_arn() + assert pipeline_execution_arn == static_pipeline_execution_arn diff --git a/tests/unit/sagemaker/lineage/test_artifact.py b/tests/unit/sagemaker/lineage/test_artifact.py index 72228ec964..218532c1b7 100644 --- a/tests/unit/sagemaker/lineage/test_artifact.py +++ b/tests/unit/sagemaker/lineage/test_artifact.py @@ -377,3 +377,143 @@ def test_downstream_trials(sagemaker_session): ), ] assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls + + +def test_downstream_trials_v2(sagemaker_session): + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "B" + str(i), "Type": "DataSet", "LineageType": "Artifact"} for i in range(10) + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [{"TrialName": "test-trial-name"}], + } + } + ] + } + + obj = artifact.Artifact( + sagemaker_session=sagemaker_session, + artifact_arn="test-arn", + artifact_name="foo", + properties={"k1": "v1", "k2": "v2"}, + properties_to_remove=["r1"], + ) + + result = obj.downstream_trials_v2() + + expected_trials = ["test-trial-name"] + + assert expected_trials == result + + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=["test-arn"], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + + +def test_upstream_trials(sagemaker_session): + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "B" + str(i), "Type": "DataSet", "LineageType": "Artifact"} for i in range(10) + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [{"TrialName": "test-trial-name"}], + } + } + ] + } + + obj = artifact.Artifact( + sagemaker_session=sagemaker_session, + artifact_arn="test-arn", + artifact_name="foo", + properties={"k1": "v1", "k2": "v2"}, + properties_to_remove=["r1"], + ) + + result = obj.upstream_trials() + + expected_trials = ["test-trial-name"] + + assert expected_trials == result + + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=["test-arn"], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + + +def test_s3_uri_artifacts(sagemaker_session): + obj = artifact.Artifact( + sagemaker_session=sagemaker_session, + artifact_arn="test-arn", + artifact_name="foo", + source_uri="s3://abced", + properties={"k1": "v1", "k2": "v2"}, + properties_to_remove=["r1"], + ) + sagemaker_session.sagemaker_client.list_artifacts.side_effect = [ + { + "ArtifactSummaries": [ + { + "ArtifactArn": "A", + "ArtifactName": "B", + "Source": { + "SourceUri": "D", + "source_types": [{"SourceIdType": "source_id_type", "Value": "value1"}], + }, + "ArtifactType": "test-type", + } + ], + "NextToken": "100", + }, + ] + result = obj.s3_uri_artifacts(s3_uri="s3://abced") + + expected_calls = [ + unittest.mock.call(SourceUri="s3://abced"), + ] + expected_result = { + "ArtifactSummaries": [ + { + "ArtifactArn": "A", + "ArtifactName": "B", + "Source": { + "SourceUri": "D", + "source_types": [{"SourceIdType": "source_id_type", "Value": "value1"}], + }, + "ArtifactType": "test-type", + } + ], + "NextToken": "100", + } + assert expected_calls == sagemaker_session.sagemaker_client.list_artifacts.mock_calls + assert result == expected_result diff --git a/tests/unit/sagemaker/lineage/test_context.py b/tests/unit/sagemaker/lineage/test_context.py index 5cf48dea67..d87120dde2 100644 --- a/tests/unit/sagemaker/lineage/test_context.py +++ b/tests/unit/sagemaker/lineage/test_context.py @@ -17,6 +17,9 @@ import pytest from sagemaker.lineage import context, _api_types +from sagemaker.lineage.action import Action +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent +from sagemaker.lineage.query import LineageQueryDirectionEnum @pytest.fixture @@ -328,3 +331,182 @@ def test_create_delete_with_association(sagemaker_session): delete_with_association_expected_calls == sagemaker_session.sagemaker_client.delete_association.mock_calls ) + + +def test_actions(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + action_arn = "arn:aws:sagemaker:us-west-2:123456789012:action/lineage-unit-3b05f017-0d87-4c37" + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": action_arn, "Type": "Approval", "LineageType": "Action"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + sagemaker_session.sagemaker_client.describe_action.return_value = { + "ActionName": "MyAction", + "ActionArn": action_arn, + } + + action_list = obj.actions(direction=LineageQueryDirectionEnum.DESCENDANTS) + + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"LineageTypes": ["Action"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + + expected_action_list = [ + Action( + action_arn=action_arn, + action_name="MyAction", + ) + ] + + assert expected_action_list[0].action_arn == action_list[0].action_arn + assert expected_action_list[0].action_name == action_list[0].action_name + + +def test_processing_jobs(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + processing_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": processing_job_arn, "Type": "ProcessingJob", "LineageType": "TrialComponent"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyProcessingJob", + "TrialComponentArn": processing_job_arn, + } + + trial_component_list = obj.processing_jobs(direction=LineageQueryDirectionEnum.ASCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["ProcessingJob"], "LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_trial_component_list = [ + LineageTrialComponent( + trial_component_name="MyProcessingJob", + trial_component_arn=processing_job_arn, + ) + ] + + assert ( + expected_trial_component_list[0].trial_component_arn + == trial_component_list[0].trial_component_arn + ) + assert ( + expected_trial_component_list[0].trial_component_name + == trial_component_list[0].trial_component_name + ) + + +def test_transform_jobs(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + transform_job_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": transform_job_arn, "Type": "TransformJob", "LineageType": "TrialComponent"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyTransformJob", + "TrialComponentArn": transform_job_arn, + } + + trial_component_list = obj.transform_jobs(direction=LineageQueryDirectionEnum.ASCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["TransformJob"], "LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_trial_component_list = [ + LineageTrialComponent( + trial_component_name="MyTransformJob", + trial_component_arn=transform_job_arn, + ) + ] + + assert ( + expected_trial_component_list[0].trial_component_arn + == trial_component_list[0].trial_component_arn + ) + assert ( + expected_trial_component_list[0].trial_component_name + == trial_component_list[0].trial_component_name + ) + + +def test_trial_components(sagemaker_session): + context_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-unit-3b05f017-0d87-4c37" + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=context_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": trial_component_arn, "Type": "TransformJob", "LineageType": "TrialComponent"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyTransformJob", + "TrialComponentArn": trial_component_arn, + } + + trial_component_list = obj.trial_components(direction=LineageQueryDirectionEnum.ASCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"LineageTypes": ["TrialComponent"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[context_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_trial_component_list = [ + LineageTrialComponent( + trial_component_name="MyTransformJob", + trial_component_arn=trial_component_arn, + ) + ] + + assert ( + expected_trial_component_list[0].trial_component_arn + == trial_component_list[0].trial_component_arn + ) + assert ( + expected_trial_component_list[0].trial_component_name + == trial_component_list[0].trial_component_name + ) diff --git a/tests/unit/sagemaker/lineage/test_dataset_artifact.py b/tests/unit/sagemaker/lineage/test_dataset_artifact.py index 6db5a215f6..074efb488c 100644 --- a/tests/unit/sagemaker/lineage/test_dataset_artifact.py +++ b/tests/unit/sagemaker/lineage/test_dataset_artifact.py @@ -83,3 +83,89 @@ def test_trained_models(sagemaker_session): ) ] assert expected_model_list == model_list + + +def test_upstream_datasets(sagemaker_session): + artifact_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:artifact/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = artifact.DatasetArtifact( + sagemaker_session, artifact_name="foo", artifact_arn=artifact_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.upstream_datasets() + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[artifact_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name + + +def test_downstream_datasets(sagemaker_session): + artifact_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:artifact/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = artifact.DatasetArtifact( + sagemaker_session, artifact_name="foo", artifact_arn=artifact_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.downstream_datasets() + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[artifact_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name diff --git a/tests/unit/sagemaker/lineage/test_image_artifact.py b/tests/unit/sagemaker/lineage/test_image_artifact.py new file mode 100644 index 0000000000..485d942db3 --- /dev/null +++ b/tests/unit/sagemaker/lineage/test_image_artifact.py @@ -0,0 +1,65 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest.mock + +import pytest +from sagemaker.lineage import artifact +from sagemaker.lineage.query import LineageQueryDirectionEnum + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +def test_datasets(sagemaker_session): + artifact_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:artifact/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = artifact.ImageArtifact(sagemaker_session, artifact_name="foo", artifact_arn=artifact_arn) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.datasets(direction=LineageQueryDirectionEnum.DESCENDANTS) + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[artifact_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name diff --git a/tests/unit/sagemaker/lineage/test_lineage_trial_component.py b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py new file mode 100644 index 0000000000..9b466832a1 --- /dev/null +++ b/tests/unit/sagemaker/lineage/test_lineage_trial_component.py @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest.mock + +import pytest +from sagemaker.lineage import artifact, lineage_trial_component + + +@pytest.fixture +def sagemaker_session(): + return unittest.mock.Mock() + + +def test_dataset_artifacts(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets" + artifact_dataset_name = "myDataset" + + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": artifact_dataset_arn, "Type": "DataSet", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": artifact_dataset_name, + "ArtifactArn": artifact_dataset_arn, + } + + dataset_list = obj.dataset_artifacts() + expected_calls = [ + unittest.mock.call( + Direction="Ascendants", + Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[trial_component_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_dataset_list = [ + artifact.DatasetArtifact( + artifact_name=artifact_dataset_name, + artifact_arn=artifact_dataset_arn, + ) + ] + assert expected_dataset_list[0].artifact_arn == dataset_list[0].artifact_arn + assert expected_dataset_list[0].artifact_name == dataset_list[0].artifact_name + + +def test_models(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + model_arn = "arn:aws:sagemaker:us-west-2:123456789012:context/models" + model_name = "myDataset" + + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": model_arn, "Type": "Model", "LineageType": "Artifact"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactName": model_name, + "ArtifactArn": model_arn, + } + + model_list = obj.models() + expected_calls = [ + unittest.mock.call( + Direction="Descendants", + Filters={"Types": ["Model"], "LineageTypes": ["Artifact"]}, + IncludeEdges=False, + MaxDepth=10, + StartArns=[trial_component_arn], + ), + ] + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls + expected_model_list = [ + artifact.DatasetArtifact( + artifact_name=model_name, + artifact_arn=model_arn, + ) + ] + assert expected_model_list[0].artifact_arn == model_list[0].artifact_arn + assert expected_model_list[0].artifact_name == model_list[0].artifact_name + + +def test_pipeline_execution_arn(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.list_tags.return_value = { + "Tags": [ + {"Key": "sagemaker:pipeline-execution-arn", "Value": "tag1"}, + ], + } + expected_calls = [ + unittest.mock.call(ResourceArn=trial_component_arn), + ] + pipeline_execution_arn_result = obj.pipeline_execution_arn() + assert pipeline_execution_arn_result == "tag1" + assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls + + +def test_no_pipeline_execution_arn(sagemaker_session): + trial_component_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37" + ) + obj = lineage_trial_component.LineageTrialComponent( + sagemaker_session, trial_component_name="foo", trial_component_arn=trial_component_arn + ) + + sagemaker_session.sagemaker_client.list_tags.return_value = { + "Tags": [ + {"Key": "abcd", "Value": "efg"}, + ], + } + expected_calls = [ + unittest.mock.call(ResourceArn=trial_component_arn), + ] + pipeline_execution_arn_result = obj.pipeline_execution_arn() + expected_result = None + assert pipeline_execution_arn_result == expected_result + assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 50bb14e6b1..ae76fd199c 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -11,9 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import unittest.mock from sagemaker.lineage.artifact import DatasetArtifact, ModelArtifact, Artifact from sagemaker.lineage.context import EndpointContext, Context from sagemaker.lineage.action import Action +from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest @@ -286,6 +288,49 @@ def test_vertex_to_object_context(sagemaker_session): assert isinstance(context, Context) +def test_vertex_to_object_trial_component(sagemaker_session): + + tc_arn = "arn:aws:sagemaker:us-west-2:963951943925:trial-component/abaloneprocess-ixyt08z3ru-aws-processing-job" + vertex = Vertex( + arn=tc_arn, + lineage_entity=LineageEntityEnum.TRIAL_COMPONENT.value, + lineage_source=LineageSourceEnum.TRANSFORM_JOB.value, + sagemaker_session=sagemaker_session, + ) + + sagemaker_session.sagemaker_client.describe_trial_component.return_value = { + "TrialComponentName": "MyTrialComponent", + "TrialComponentArn": tc_arn, + "Source": { + "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/my_trial_component", + "SourceType": "ARN", + "SourceId": "Thu Dec 17 17:16:24 UTC 2020", + }, + "TrialComponentType": "ModelDeployment", + "Properties": { + "PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\ + pipeline/mypipeline/execution/0irnteql64d0", + "PipelineStepName": "MyStep", + "Status": "Completed", + }, + "CreationTime": 1608225384.0, + "CreatedBy": {}, + "LastModifiedTime": 1608225384.0, + "LastModifiedBy": {}, + } + + trial_component = vertex.to_lineage_object() + + expected_calls = [ + unittest.mock.call(TrialComponentName="abaloneprocess-ixyt08z3ru-aws-processing-job"), + ] + assert expected_calls == sagemaker_session.sagemaker_client.describe_trial_component.mock_calls + + assert trial_component.trial_component_arn == tc_arn + assert trial_component.trial_component_name == "MyTrialComponent" + assert isinstance(trial_component, LineageTrialComponent) + + def test_vertex_to_object_model_artifact(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", @@ -317,6 +362,37 @@ def test_vertex_to_object_model_artifact(sagemaker_session): assert isinstance(artifact, ModelArtifact) +def test_vertex_to_object_artifact(sagemaker_session): + vertex = Vertex( + arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", + lineage_entity=LineageEntityEnum.ARTIFACT.value, + lineage_source=LineageSourceEnum.MODEL.value, + sagemaker_session=sagemaker_session, + ) + + sagemaker_session.sagemaker_client.describe_artifact.return_value = { + "ArtifactArn": "arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", + "Source": { + "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel", + "SourceTypes": [], + }, + "ArtifactType": None, + "Properties": {}, + "CreationTime": 1608224704.149, + "CreatedBy": {}, + "LastModifiedTime": 1608224704.149, + "LastModifiedBy": {}, + } + + artifact = vertex.to_lineage_object() + + assert ( + artifact.artifact_arn + == "arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f" + ) + assert isinstance(artifact, Artifact) + + def test_vertex_to_dataset_artifact(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", @@ -379,7 +455,7 @@ def test_vertex_to_model_artifact(sagemaker_session): assert isinstance(artifact, ModelArtifact) -def test_vertex_to_object_artifact(sagemaker_session): +def test_vertex_to_object_image_artifact(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", lineage_entity=LineageEntityEnum.ARTIFACT.value, @@ -441,7 +517,7 @@ def test_vertex_to_object_action(sagemaker_session): def test_vertex_to_object_unconvertable(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", - lineage_entity=LineageEntityEnum.TRIAL_COMPONENT.value, + lineage_entity=LineageEntityEnum.TRIAL.value, lineage_source=LineageSourceEnum.TENSORBOARD.value, sagemaker_session=sagemaker_session, )