@@ -146,7 +146,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
146146 """Retrieve all trial runs which that use this artifact.
147147
148148 Args:
149- sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149+ sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150150 will be created.
151151
152152 Returns:
@@ -159,6 +159,57 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159159 )
160160 trial_component_arns : list = list (map (lambda x : x .destination_arn , outgoing_associations ))
161161
162+ return self ._get_trial_from_trial_component (trial_component_arns )
163+
164+ def downstream_trials_v2 (
165+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
166+ ) -> list :
167+ """Retrieve all downstream trial runs which that use this artifact by using lineage query.
168+
169+ Args:
170+ direction (LineageQueryDirectionEnum, optional): The query direction.
171+
172+ Returns:
173+ [Trial]: A list of SageMaker `Trial` objects.
174+ """
175+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
176+ query_result = LineageQuery (self .sagemaker_session ).query (
177+ start_arns = [self .artifact_arn ],
178+ query_filter = query_filter ,
179+ direction = direction ,
180+ include_edges = False ,
181+ )
182+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
183+
184+ return self ._get_trial_from_trial_component (trial_component_arns )
185+
186+ def get_upstream_trials (
187+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
188+ ) -> List [str ]:
189+ """Retrieve all upstream trial runs which that use this artifact by using lineage query.
190+
191+ Args:
192+ direction (LineageQueryDirectionEnum, optional): The query direction.
193+
194+ Returns:
195+ [Trial]: A list of SageMaker `Trial` objects.
196+ """
197+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
198+ query_result = LineageQuery (self .sagemaker_session ).query (
199+ start_arns = [self .artifact_arn ],
200+ query_filter = query_filter ,
201+ direction = direction ,
202+ include_edges = False ,
203+ )
204+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
205+ return self ._get_trial_from_trial_component (trial_component_arns )
206+
207+ def _get_trial_from_trial_component (self , trial_component_arns : list ):
208+ """Retrieve all upstream trial runs which that use the trial component arns.
209+
210+ Returns:
211+ [Trial]: A list of SageMaker `Trial` objects.
212+ """
162213 if not trial_component_arns :
163214 # no outgoing associations for this artifact
164215 return []
@@ -170,7 +221,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170221 num_search_batches = math .ceil (len (trial_component_arns ) % max_search_by_arn )
171222 trial_components : list = []
172223
173- sagemaker_session = sagemaker_session or _utils .default_session ()
224+ sagemaker_session = self . sagemaker_session or _utils .default_session ()
174225 sagemaker_client = sagemaker_session .sagemaker_client
175226
176227 for i in range (num_search_batches ):
@@ -335,6 +386,17 @@ def list(
335386 sagemaker_session = sagemaker_session ,
336387 )
337388
389+ def get_artifacts (self , s3_uri : str ):
390+ """Return a list of artifact that use provided s3 uri.
391+
392+ Args:
393+ s3_uri (str): A S3 URI.
394+
395+ Returns:
396+ A list of ``Artifacts``
397+ """
398+ return self .sagemaker_session .sagemaker_client .list_artifacts (SourceUri = s3_uri )
399+
338400
339401class ModelArtifact (Artifact ):
340402 """A SageMaker lineage artifact representing a model.
@@ -349,7 +411,7 @@ def endpoints(self) -> list:
349411 """Get association summaries for endpoints deployed with this model.
350412
351413 Returns:
352- [AssociationSummary]: A list of associations repesenting the endpoints using the model.
414+ [AssociationSummary]: A list of associations representing the endpoints using the model.
353415 """
354416 endpoint_development_actions : Iterator = Association .list (
355417 source_arn = self .artifact_arn ,
@@ -522,3 +584,75 @@ def endpoint_contexts(
522584 for vertex in query_result .vertices :
523585 endpoint_contexts .append (vertex .to_lineage_object ())
524586 return endpoint_contexts
587+
588+ def get_upstream_datasets (
589+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .ASCENDANTS
590+ ) -> List [Artifact ]:
591+ """Get upstream artifacts representing dataset from the dataset's lineage.
592+
593+ Args:
594+ direction (LineageQueryDirectionEnum, optional): The query direction.
595+
596+ Returns:
597+ list of Artifacts: Artifacts representing an dataset.
598+ """
599+ query_filter = LineageFilter (
600+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
601+ )
602+ query_result = LineageQuery (self .sagemaker_session ).query (
603+ start_arns = [self .artifact_arn ],
604+ query_filter = query_filter ,
605+ direction = direction ,
606+ include_edges = False ,
607+ )
608+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
609+
610+ def get_downstream_datasets (
611+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .DESCENDANTS
612+ ) -> List [Artifact ]:
613+ """Get downstream artifacts representing dataset from the dataset's lineage.
614+
615+ Args:
616+ direction (LineageQueryDirectionEnum, optional): The query direction.
617+
618+ Returns:
619+ list of Artifacts: Artifacts representing an dataset.
620+ """
621+ query_filter = LineageFilter (
622+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
623+ )
624+ query_result = LineageQuery (self .sagemaker_session ).query (
625+ start_arns = [self .artifact_arn ],
626+ query_filter = query_filter ,
627+ direction = direction ,
628+ include_edges = False ,
629+ )
630+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
631+
632+
633+ class ImageArtifact (Artifact ):
634+ """A SageMaker lineage artifact representing an image.
635+
636+ Common model specific lineage traversals to discover how the image is connected
637+ to other entities.
638+ """
639+
640+ def get_dataset (self , direction : LineageQueryDirectionEnum ) -> List [Artifact ]:
641+ """Get artifacts representing dataset from the image artifact's lineage.
642+
643+ Args:
644+ direction (LineageQueryDirectionEnum): The query direction.
645+
646+ Returns:
647+ list of Artifacts: Artifacts representing an dataset.
648+ """
649+ query_filter = LineageFilter (
650+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
651+ )
652+ query_result = LineageQuery (self .sagemaker_session ).query (
653+ start_arns = [self .artifact_arn ],
654+ query_filter = query_filter ,
655+ direction = direction ,
656+ include_edges = False ,
657+ )
658+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
0 commit comments