diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py
index 033e838137..78cfc700e6 100644
--- a/src/sagemaker/lineage/query.py
+++ b/src/sagemaker/lineage/query.py
@@ -12,9 +12,11 @@
 # language governing permissions and limitations under the License.
 """This module contains code to query SageMaker lineage."""
 from __future__ import absolute_import
+
 from datetime import datetime
 from enum import Enum
 from typing import Optional, Union, List, Dict
+
 from sagemaker.lineage._utils import get_resource_name_from_arn
 
 
@@ -65,6 +67,30 @@ def __init__(
         self.destination_arn = destination_arn
         self.association_type = association_type
 
+    def __hash__(self):
+        """Hash an ``Edge`` on the same fields that ``__eq__`` compares."""
+        return hash(
+            (
+                "source_arn",
+                self.source_arn,
+                "destination_arn",
+                self.destination_arn,
+                "association_type",
+                self.association_type,
+            )
+        )
+
+    def __eq__(self, other):
+        """Compare two ``Edge`` objects by ARNs and association type."""
+        if not isinstance(other, Edge):
+            # Foreign types (e.g. ``None``) must not raise AttributeError.
+            return NotImplemented
+        return (
+            self.association_type == other.association_type
+            and self.source_arn == other.source_arn
+            and self.destination_arn == other.destination_arn
+        )
+
 
 class Vertex:
     """A vertex for a lineage graph."""
@@ -82,6 +108,30 @@ def __init__(
         self.lineage_source = lineage_source
         self._session = sagemaker_session
 
+    def __hash__(self):
+        """Hash a ``Vertex`` on the same fields that ``__eq__`` compares."""
+        return hash(
+            (
+                "arn",
+                self.arn,
+                "lineage_entity",
+                self.lineage_entity,
+                "lineage_source",
+                self.lineage_source,
+            )
+        )
+
+    def __eq__(self, other):
+        """Compare two ``Vertex`` objects by ARN, lineage entity and lineage source."""
+        if not isinstance(other, Vertex):
+            # Foreign types (e.g. ``None``) must not raise AttributeError.
+            return NotImplemented
+        return (
+            self.arn == other.arn
+            and self.lineage_entity == other.lineage_entity
+            and self.lineage_source == other.lineage_source
+        )
+
     def to_lineage_object(self):
         """Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object."""
         from sagemaker.lineage.artifact import Artifact, ModelArtifact
@@ -210,6 +260,24 @@ def _convert_api_response(self, response) -> LineageQueryResult:
         converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
         converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
 
+        # Build de-duplicated lists instead of calling ``remove`` on the list
+        # being iterated, which skips elements and can leave duplicates behind.
+        seen_edges = set()
+        unique_edges = []
+        for edge in converted.edges:
+            if edge not in seen_edges:
+                seen_edges.add(edge)
+                unique_edges.append(edge)
+        converted.edges = unique_edges
+
+        seen_vertices = set()
+        unique_vertices = []
+        for vertex in converted.vertices:
+            if vertex not in seen_vertices:
+                seen_vertices.add(vertex)
+                unique_vertices.append(vertex)
+        converted.vertices = unique_vertices
+
         return converted
 
     def _collapse_cross_account_artifacts(self, query_response):
diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py
index 595e7e1d0f..50bb14e6b1 100644
--- a/tests/unit/sagemaker/lineage/test_query.py
+++ b/tests/unit/sagemaker/lineage/test_query.py
@@ -32,6 +32,40 @@ def test_lineage_query(sagemaker_session):
         start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
     )
 
+    assert len(response.edges) == 1
+    assert response.edges[0].source_arn == "arn1"
+    assert response.edges[0].destination_arn == "arn2"
+    assert response.edges[0].association_type == "Produced"
+    assert len(response.vertices) == 2
+
+    assert response.vertices[0].arn == "arn1"
+    assert response.vertices[0].lineage_source == "Endpoint"
+    assert response.vertices[0].lineage_entity == "Artifact"
+    assert response.vertices[1].arn == "arn2"
+    assert response.vertices[1].lineage_source == "Model"
+    assert response.vertices[1].lineage_entity == "Context"
+
+
+def test_lineage_query_duplication(sagemaker_session):
+    lineage_query = LineageQuery(sagemaker_session)
+    sagemaker_session.sagemaker_client.query_lineage.return_value = {
+        "Vertices": [
+            {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
+            {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
+            {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
+            {"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
+        ],
+        "Edges": [
+            {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
+            {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
+            {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
+        ],
+    }
+
+    response = lineage_query.query(
+        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
+    )
+
     assert len(response.edges) == 1
     assert response.edges[0].source_arn == "arn1"
     assert response.edges[0].destination_arn == "arn2"