3838SLEEP_TIME_SECONDS = 1
3939STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17"
4040STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17"
41+ STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline17ModelPackageGroup"
4142
4243
4344@pytest .fixture
@@ -543,6 +544,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
543544 )
544545
545546
547+ @pytest .fixture
548+ def static_model_package_group_context (sagemaker_session , static_pipeline_execution_arn ):
549+
550+ model_package_group_arn = get_model_package_group_arn_from_static_pipeline (sagemaker_session )
551+
552+ contexts = sagemaker_session .sagemaker_client .list_contexts (SourceUri = model_package_group_arn )[
553+ "ContextSummaries"
554+ ]
555+ if len (contexts ) != 1 :
556+ raise (
557+ Exception (
558+ f"Got an unexpected number of Contexts for \
559+ model package group { STATIC_MODEL_PACKAGE_GROUP_NAME } from pipeline \
560+ execution { static_pipeline_execution_arn } . \
561+ Expected 1 but got { len (contexts )} "
562+ )
563+ )
564+
565+ yield context .ModelPackageGroup .load (
566+ contexts [0 ]["ContextName" ], sagemaker_session = sagemaker_session
567+ )
568+
569+
546570@pytest .fixture
547571def static_model_artifact (sagemaker_session , static_pipeline_execution_arn ):
548572 model_package_arn = get_model_package_arn_from_static_pipeline (
@@ -590,6 +614,31 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session):
590614 )
591615
592616
617+ @pytest .fixture
618+ def static_image_artifact (static_model_artifact , sagemaker_session ):
619+ dataset_associations = sagemaker_session .sagemaker_client .list_associations (
620+ DestinationArn = static_model_artifact .artifact_arn , SourceType = "Image"
621+ )
622+ if len (dataset_associations ["AssociationSummaries" ]) == 0 :
623+ # no directly associated dataset. work backwards from the model
624+ model_associations = sagemaker_session .sagemaker_client .list_associations (
625+ DestinationArn = static_model_artifact .artifact_arn , SourceType = "Model"
626+ )
627+ training_job_associations = sagemaker_session .sagemaker_client .list_associations (
628+ DestinationArn = model_associations ["AssociationSummaries" ][0 ]["SourceArn" ],
629+ SourceType = "SageMakerTrainingJob" ,
630+ )
631+ dataset_associations = sagemaker_session .sagemaker_client .list_associations (
632+ DestinationArn = training_job_associations ["AssociationSummaries" ][0 ]["SourceArn" ],
633+ SourceType = "Image" ,
634+ )
635+
636+ yield artifact .ImageArtifact .load (
637+ dataset_associations ["AssociationSummaries" ][0 ]["SourceArn" ],
638+ sagemaker_session = sagemaker_session ,
639+ )
640+
641+
593642def get_endpoint_arn_from_static_pipeline (sagemaker_session ):
594643 try :
595644 endpoint_arn = sagemaker_session .sagemaker_client .describe_endpoint (
@@ -604,6 +653,15 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session):
604653 raise e
605654
606655
656+ def get_model_package_group_arn_from_static_pipeline (sagemaker_session ):
657+ static_model_package_group_arn = (
658+ sagemaker_session .sagemaker_client .describe_model_package_group (
659+ ModelPackageGroupName = STATIC_MODEL_PACKAGE_GROUP_NAME
660+ )["ModelPackageGroupArn" ]
661+ )
662+ return static_model_package_group_arn
663+
664+
607665def get_model_package_arn_from_static_pipeline (pipeline_execution_arn , sagemaker_session ):
608666 # get the model package ARN from the pipeline
609667 pipeline_execution_steps = sagemaker_session .sagemaker_client .list_pipeline_execution_steps (
0 commit comments