@@ -209,7 +209,7 @@ def register(
209209 model_package_arn = model_package .get ("ModelPackageArn" ),
210210 )
211211
212- def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
212+ def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
213213 """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
214214
215215 The type of session object is determined by the instance type.
@@ -688,8 +688,8 @@ def compile(
688688
689689 def deploy (
690690 self ,
691- initial_instance_count ,
692- instance_type ,
691+ initial_instance_count = None ,
692+ instance_type = None ,
693693 serializer = None ,
694694 deserializer = None ,
695695 accelerator_type = None ,
@@ -698,6 +698,7 @@ def deploy(
698698 kms_key = None ,
699699 wait = True ,
700700 data_capture_config = None ,
701+ serverless_inference_config = None ,
701702 ** kwargs ,
702703 ):
703704 """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +716,13 @@ def deploy(
715716
716717 Args:
717718 initial_instance_count (int): The initial number of instances to run
718- in the ``Endpoint`` created from this ``Model``.
719+ in the ``Endpoint`` created from this ``Model``. If not using
720+ serverless inference, then it need to be a number larger or equals
721+ to 1 (default: None)
719722 instance_type (str): The EC2 instance type to deploy this Model to.
720- For example, 'ml.p2.xlarge', or 'local' for local mode.
723+ For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
724+ serverless inference, then it is required to deploy a model.
725+ (default: None)
721726 serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
722727 serializer object, used to encode data for an inference endpoint
723728 (default: None). If ``serializer`` is not None, then
@@ -746,7 +751,14 @@ def deploy(
746751 data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
747752 configuration related to Endpoint data capture for use with
748753 Amazon SageMaker Model Monitoring. Default: None.
749-
754+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
755+ Specifies configuration related to serverless endpoint. Use this configuration
756+ when trying to create serverless endpoint and make serverless inference. If
757+ empty config object passed through, we will use default config to deploy
758+ serverless endpoint (default: None)
759+ Raises:
760+ ValueError: If no role is specified or if serverless inference config is not
761+ specified and instance type and instance count are also not specified
750762 Returns:
751763 callable[string, sagemaker.session.Session] or None: Invocation of
752764 ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
@@ -757,28 +769,43 @@ def deploy(
757769
758770 if self .role is None :
759771 raise ValueError ("Role can not be null for deploying a model" )
772+ is_serverless = bool (serverless_inference_config )
773+ if not is_serverless and not (instance_type and initial_instance_count ):
774+ raise ValueError (
775+ "Must specify instance type and instance count unless using serverless inference"
776+ )
760777
761- if instance_type .startswith ("ml.inf" ) and not self ._is_compiled_model :
778+ if instance_type and instance_type .startswith ("ml.inf" ) and not self ._is_compiled_model :
762779 LOGGER .warning (
763780 "Your model is not compiled. Please compile your model before using Inferentia."
764781 )
765782
766- compiled_model_suffix = "-" . join ( instance_type . split ( "." )[: - 1 ])
767- if self . _is_compiled_model :
783+ if self . _is_compiled_model and not is_serverless :
784+ compiled_model_suffix = "-" . join ( instance_type . split ( "." )[: - 1 ])
768785 self ._ensure_base_name_if_needed (self .image_uri )
769786 if self ._base_name is not None :
770787 self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
771788
772789 self ._create_sagemaker_model (instance_type , accelerator_type , tags )
790+
791+ serverless_inference_config_dict = (
792+ serverless_inference_config ._to_request_dict () if is_serverless else None
793+ )
773794 production_variant = sagemaker .production_variant (
774- self .name , instance_type , initial_instance_count , accelerator_type = accelerator_type
795+ self .name ,
796+ instance_type ,
797+ initial_instance_count ,
798+ accelerator_type = accelerator_type ,
799+ serverless_inference_config = serverless_inference_config_dict ,
775800 )
776801 if endpoint_name :
777802 self .endpoint_name = endpoint_name
778803 else :
779804 base_endpoint_name = self ._base_name or utils .base_from_name (self .name )
780- if self ._is_compiled_model and not base_endpoint_name .endswith (compiled_model_suffix ):
781- base_endpoint_name = "-" .join ((base_endpoint_name , compiled_model_suffix ))
805+ if self ._is_compiled_model and not is_serverless :
806+ compiled_model_suffix = "-" .join (instance_type .split ("." )[:- 1 ])
807+ if not base_endpoint_name .endswith (compiled_model_suffix ):
808+ base_endpoint_name = "-" .join ((base_endpoint_name , compiled_model_suffix ))
782809 self .endpoint_name = utils .name_from_base (base_endpoint_name )
783810
784811 data_capture_config_dict = None
0 commit comments