4444 ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH ,
4545 load_sagemaker_config ,
4646)
47+ from sagemaker .jumpstart .enums import JumpStartModelType
4748from sagemaker .model_card import (
4849 ModelCard ,
4950 ModelPackageModelCard ,
@@ -448,6 +449,7 @@ def register(
448449 skip_model_validation : Optional [Union [str , PipelineVariable ]] = None ,
449450 source_uri : Optional [Union [str , PipelineVariable ]] = None ,
450451 model_card : Optional [Union [ModelPackageModelCard , ModelCard ]] = None ,
452+ accept_eula : Optional [bool ] = None ,
451453 ):
452454 """Creates a model package for creating SageMaker models or listing on Marketplace.
453455
@@ -515,23 +517,22 @@ def register(
515517
516518 if image_uri is not None :
517519 self .image_uri = image_uri
518-
519- if model_package_group_name is None and model_package_name is None :
520- # If model package group and model package name is not set
521- # then register to auto-generated model package group
522- model_package_group_name = utils .base_name_from_image (
523- self .image_uri , default_base_name = ModelPackage .__name__
524- )
525-
526- if model_package_group_name is not None :
527- container_def = self .prepare_container_def ()
528- container_def = update_container_with_inference_params (
529- framework = framework ,
530- framework_version = framework_version ,
531- nearest_model_name = nearest_model_name ,
532- data_input_configuration = data_input_configuration ,
533- container_def = container_def ,
534- )
520+ if self .model_type is not JumpStartModelType .PROPRIETARY :
521+ if model_package_group_name is None and model_package_name is None :
522+ # If model package group and model package name is not set
523+ # then register to auto-generated model package group
524+ model_package_group_name = utils .base_name_from_image (
525+ self .image_uri , default_base_name = ModelPackage .__name__
526+ )
527+ if model_package_group_name is not None :
528+ container_def = self .prepare_container_def (accept_eula = accept_eula )
529+ container_def = update_container_with_inference_params (
530+ framework = framework ,
531+ framework_version = framework_version ,
532+ nearest_model_name = nearest_model_name ,
533+ data_input_configuration = data_input_configuration ,
534+ container_def = container_def ,
535+ )
535536 else :
536537 container_def = {
537538 "Image" : self .image_uri ,
@@ -546,6 +547,10 @@ def register(
546547 if self .model_data is not None :
547548 container_def ["ModelDataUrl" ] = self .model_data
548549
550+ if self .model_type is JumpStartModelType .PROPRIETARY :
551+ source_uri = self .model_package_arn
552+ model_package_group_name = self .model_id
553+
549554 model_pkg_args = sagemaker .get_model_package_args (
550555 self .content_types ,
551556 self .response_types ,
0 commit comments