1212# language governing permissions and limitations under the License.
1313"""This module contains code to query SageMaker lineage."""
1414from __future__ import absolute_import
15+
1516from datetime import datetime
1617from enum import Enum
1718from typing import Optional , Union , List , Dict
19+
1820from sagemaker .lineage ._utils import get_resource_name_from_arn
1921
2022
@@ -65,6 +67,27 @@ def __init__(
6567 self .destination_arn = destination_arn
6668 self .association_type = association_type
6769
70+ def __hash__ (self ):
71+ """Define hash function for ``Edge``."""
72+ return hash (
73+ (
74+ "source_arn" ,
75+ self .source_arn ,
76+ "destination_arn" ,
77+ self .destination_arn ,
78+ "association_type" ,
79+ self .association_type ,
80+ )
81+ )
82+
83+ def __eq__ (self , other ):
84+ """Define equal function for ``Edge``."""
85+ return (
86+ self .association_type == other .association_type
87+ and self .source_arn == other .source_arn
88+ and self .destination_arn == other .destination_arn
89+ )
90+
6891
6992class Vertex :
7093 """A vertex for a lineage graph."""
@@ -82,6 +105,27 @@ def __init__(
82105 self .lineage_source = lineage_source
83106 self ._session = sagemaker_session
84107
108+ def __hash__ (self ):
109+ """Define hash function for ``Vertex``."""
110+ return hash (
111+ (
112+ "arn" ,
113+ self .arn ,
114+ "lineage_entity" ,
115+ self .lineage_entity ,
116+ "lineage_source" ,
117+ self .lineage_source ,
118+ )
119+ )
120+
121+ def __eq__ (self , other ):
122+ """Define equal function for ``Vertex``."""
123+ return (
124+ self .arn == other .arn
125+ and self .lineage_entity == other .lineage_entity
126+ and self .lineage_source == other .lineage_source
127+ )
128+
85129 def to_lineage_object (self ):
86130 """Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object."""
87131 from sagemaker .lineage .artifact import Artifact , ModelArtifact
@@ -210,6 +254,18 @@ def _convert_api_response(self, response) -> LineageQueryResult:
210254 converted .edges = [self ._get_edge (edge ) for edge in response ["Edges" ]]
211255 converted .vertices = [self ._get_vertex (vertex ) for vertex in response ["Vertices" ]]
212256
257+ edge_set = set ()
258+ for edge in converted .edges :
259+ if edge in edge_set :
260+ converted .edges .remove (edge )
261+ edge_set .add (edge )
262+
263+ vertex_set = set ()
264+ for vertex in converted .vertices :
265+ if vertex in vertex_set :
266+ converted .vertices .remove (vertex )
267+ vertex_set .add (vertex )
268+
213269 return converted
214270
215271 def _collapse_cross_account_artifacts (self , query_response ):
0 commit comments