5454from sagemaker .serve .validations .check_image_and_hardware_type import (
5555 validate_image_uri_and_hardware ,
5656)
57+ from sagemaker .workflow .entities import PipelineVariable
5758from sagemaker .huggingface .llm_utils import get_huggingface_model_metadata
5859
5960logger = logging .getLogger (__name__ )
@@ -81,7 +82,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
8182
8283 * ``Mode.SAGEMAKER_ENDPOINT``: Launch on a SageMaker endpoint
8384 * ``Mode.LOCAL_CONTAINER``: Launch locally with a container
84-
8585 shared_libs (List[str]): Any shared libraries you want to bring into
8686 the model packaging.
8787 dependencies (Optional[Dict[str, Any]]): The dependencies of the model
@@ -122,6 +122,15 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
122122 ``invoke`` and ``load`` functions.
123123 image_uri (Optional[str]): The container image uri (which is derived from a
124124 SageMaker-based container).
125+ image_config (dict[str, str] or dict[str, PipelineVariable]): Specifies
126+ whether the image of model container is pulled from ECR, or private
127+ registry in your VPC. By default it is set to pull model container
128+ image from ECR. (default: None).
129+ vpc_config (Optional[Dict[str, List[Union[str, PipelineVariable]]]]):
130+ The VpcConfig set on the model (default: None)
131+ * 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids.
132+ * 'SecurityGroupIds' (List[Union[str, PipelineVariable]]): List of security group
133+ ids.
125134 model_server (Optional[ModelServer]): The model server to which to deploy.
126135 You need to provide this argument when you specify an ``image_uri``
127136 in order for model builder to build the artifacts correctly (according
@@ -204,6 +213,23 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
204213 image_uri : Optional [str ] = field (
205214 default = None , metadata = {"help" : "Define the container image uri" }
206215 )
216+ image_config : Optional [Dict [str , Union [str , PipelineVariable ]]] = field (
217+ default = None ,
218+ metadata = {
219+ "help" : "Specifies whether the image of model container is pulled from ECR,"
220+ " or private registry in your VPC. By default it is set to pull model "
221+ "container image from ECR. (default: None)."
222+ },
223+ )
224+ vpc_config : Optional [Dict [str , List [Union [str , PipelineVariable ]]]] = field (
225+ default = None ,
226+ metadata = {
227+ "help" : "The VpcConfig set on the model (default: None)."
228+ "* 'Subnets' (List[Union[str, PipelineVariable]]): List of subnet ids."
229+ "* ''SecurityGroupIds'' (List[Union[str, PipelineVariable]]): List of"
230+ " security group ids."
231+ },
232+ )
207233 model_server : Optional [ModelServer ] = field (
208234 default = None , metadata = {"help" : "Define the model server to deploy to." }
209235 )
@@ -386,6 +412,8 @@ def _create_model(self):
386412 # TODO: we should create model as per the framework
387413 self .pysdk_model = Model (
388414 image_uri = self .image_uri ,
415+ image_config = self .image_config ,
416+ vpc_config = self .vpc_config ,
389417 model_data = self .s3_upload_path ,
390418 role = self .serve_settings .role_arn ,
391419 env = self .env_vars ,
@@ -543,15 +571,16 @@ def build(
543571 self ,
544572 mode : Type [Mode ] = None ,
545573 role_arn : str = None ,
546- sagemaker_session : str = None ,
574+ sagemaker_session : Optional [ Session ] = None ,
547575 ) -> Type [Model ]:
548576 """Create a deployable ``Model`` instance with ``ModelBuilder``.
549577
550578 Args:
551579 mode (Type[Mode], optional): The mode. Defaults to ``None``.
552580 role_arn (str, optional): The IAM role arn. Defaults to ``None``.
553- sagemaker_session (str, optional): The SageMaker session to use
554- for the execution. Defaults to ``None``.
581+ sagemaker_session (Optional[Session]): Session object which manages interactions
582+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
583+ function creates one using the default AWS configuration chain.
555584
556585 Returns:
557586 Type[Model]: A deployable ``Model`` object.
@@ -562,10 +591,7 @@ def build(
562591 self .mode = mode
563592 if role_arn :
564593 self .role_arn = role_arn
565- if sagemaker_session :
566- self .sagemaker_session = sagemaker_session
567- elif not self .sagemaker_session :
568- self .sagemaker_session = Session ()
594+ self .sagemaker_session = sagemaker_session or Session ()
569595
570596 self .sagemaker_session .settings ._local_download_dir = self .model_path
571597
@@ -607,7 +633,7 @@ def save(
607633 self ,
608634 save_path : Optional [str ] = None ,
609635 s3_path : Optional [str ] = None ,
610- sagemaker_session : Optional [str ] = None ,
636+ sagemaker_session : Optional [Session ] = None ,
611637 role_arn : Optional [str ] = None ,
612638 ) -> Type [Model ]:
613639 """WARNING: This function is experimental and not intended for production use.
@@ -618,7 +644,7 @@ def save(
618644 save_path (Optional[str]): The path where you want to save resources.
619645 s3_path (Optional[str]): The path where you want to upload resources.
620646 """
621- self .sagemaker_session = sagemaker_session if sagemaker_session else Session ()
647+ self .sagemaker_session = sagemaker_session or Session ()
622648
623649 if role_arn :
624650 self .role_arn = role_arn
0 commit comments