from __future__ import print_function, absolute_import

import abc
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
+import logging

+from sagemaker.enums import EndpointType
from sagemaker.deprecations import (
    deprecated_class,
    deprecated_deserialize,
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME

from sagemaker.lineage.context import EndpointContext
+from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
+
+LOGGER = logging.getLogger("sagemaker")


class PredictorBase(abc.ABC):
@@ -92,6 +97,7 @@ def __init__(
        sagemaker_session=None,
        serializer=IdentitySerializer(),
        deserializer=BytesDeserializer(),
+        component_name=None,
        **kwargs,
    ):
        """Initialize a ``Predictor``.
@@ -115,11 +121,14 @@ def __init__(
            deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
                deserializer object, used to decode data from an inference
                endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
+            component_name (str): Name of the Amazon SageMaker inference component
+                corresponding to the predictor.
        """
        removed_kwargs("content_type", kwargs)
        removed_kwargs("accept", kwargs)
        endpoint_name = renamed_kwargs("endpoint", "endpoint_name", endpoint_name, kwargs)
        self.endpoint_name = endpoint_name
+        self.component_name = component_name
        self.sagemaker_session = sagemaker_session or Session()
        self.serializer = serializer
        self.deserializer = deserializer
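For illustration, a minimal sketch of binding a predictor to an inference component through the new ``component_name`` argument; the endpoint and component names are hypothetical, and AWS credentials/region are assumed to be configured for the default session.

from sagemaker.predictor import Predictor

# Hypothetical names; the inference component must already exist on the endpoint.
predictor = Predictor(
    endpoint_name="my-goldfinch-endpoint",
    component_name="my-inference-component",
)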
@@ -137,6 +146,7 @@ def predict(
        target_variant=None,
        inference_id=None,
        custom_attributes=None,
+        component_name: Optional[str] = None,
    ):
        """Return the inference from the specified endpoint.

@@ -169,22 +179,29 @@ def predict(
                value is returned. For example, if a custom attribute represents the trace ID, your
                model can prepend the custom attribute with Trace ID: in your post-processing
                function (Default: None).
+            component_name (str): Optional. Name of the Amazon SageMaker inference component
+                corresponding to the predictor.

        Returns:
            object: Inference for the given input. If a deserializer was specified when creating
                the Predictor, the result of the deserializer is
                returned. Otherwise the response returns the sequence of bytes
                as is.
        """
-
+        # [TODO]: clean up component_name in _create_request_args
        request_args = self._create_request_args(
-            data,
-            initial_args,
-            target_model,
-            target_variant,
-            inference_id,
-            custom_attributes,
+            data=data,
+            initial_args=initial_args,
+            target_model=target_model,
+            target_variant=target_variant,
+            inference_id=inference_id,
+            custom_attributes=custom_attributes,
        )
+
+        inference_component_name = component_name or self._get_component_name()
+        if inference_component_name:
+            request_args["InferenceComponentName"] = inference_component_name
+
        response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
        return self._handle_response(response)

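A hedged sketch of the per-request override added above: passing ``component_name`` to ``predict`` routes a single invocation to the named inference component instead of the one stored on the predictor. The payload and names are hypothetical and depend on the configured serializer and container.

# With the default IdentitySerializer, raw bytes are sent to the endpoint as-is.
result = predictor.predict(
    data=b'{"inputs": "hello"}',
    component_name="my-other-inference-component",
)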
@@ -260,6 +277,8 @@ def _create_request_args(
            if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data
            else self.serializer.serialize(data)
        )
+        if self._get_component_name():
+            args["InferenceComponentName"] = self.component_name

        args["Body"] = data
        return args
@@ -273,6 +292,8 @@ def update_endpoint(
        tags=None,
        kms_key=None,
        data_capture_config_dict=None,
+        max_instance_count=None,
+        min_instance_count=None,
        wait=True,
    ):
        """Update the existing endpoint with the provided attributes.
@@ -310,6 +331,8 @@ def update_endpoint(
            data_capture_config_dict (dict): The endpoint data capture configuration
                for use with Amazon SageMaker Model Monitoring. If not specified,
                the data capture configuration of the existing endpoint configuration is used.
+            max_instance_count (int): The maximum instance count used when scaling the
+                endpoint's instances.
+            min_instance_count (int): The minimum instance count used when scaling the
+                endpoint's instances.

        Raises:
            ValueError: If there is not enough information to create a new ``ProductionVariant``:
@@ -348,23 +371,45 @@ def update_endpoint(
        else:
            self._model_names = [model_name]

-        production_variant_config = production_variant(
-            model_name,
-            instance_type,
-            initial_instance_count=initial_instance_count,
-            accelerator_type=accelerator_type,
-        )
+        managed_instance_scaling = {}
+        if max_instance_count:
+            managed_instance_scaling["MaxInstanceCount"] = max_instance_count
+        if min_instance_count:
+            managed_instance_scaling["MinInstanceCount"] = min_instance_count
+
+        if managed_instance_scaling:
+            production_variant_config = production_variant(
+                model_name,
+                instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+                managed_instance_scaling=managed_instance_scaling,
+            )
+        else:
+            production_variant_config = production_variant(
+                model_name,
+                instance_type,
+                initial_instance_count=initial_instance_count,
+                accelerator_type=accelerator_type,
+            )
        production_variants = [production_variant_config]

        current_endpoint_config_name = self._get_endpoint_config_name()
        new_endpoint_config_name = name_from_base(current_endpoint_config_name)
+
+        if self._get_component_name():
+            endpoint_type = EndpointType.GOLDFINCH
+        else:
+            endpoint_type = EndpointType.OTHERS
+
        self.sagemaker_session.create_endpoint_config_from_existing(
            current_endpoint_config_name,
            new_endpoint_config_name,
            new_tags=tags,
            new_kms_key=kms_key,
            new_data_capture_config_dict=data_capture_config_dict,
            new_production_variants=production_variants,
+            endpoint_type=endpoint_type,
        )
        self.sagemaker_session.update_endpoint(
            self.endpoint_name, new_endpoint_config_name, wait=wait
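A minimal sketch of the new scaling arguments, assuming the endpoint hosts a single model so ``model_name`` can be omitted; the instance type and counts are illustrative only.

predictor.update_endpoint(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    min_instance_count=1,  # becomes ManagedInstanceScaling.MinInstanceCount
    max_instance_count=4,  # becomes ManagedInstanceScaling.MaxInstanceCount
    wait=True,
)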
@@ -393,10 +438,123 @@ def delete_endpoint(self, delete_endpoint_config=True):

        self.sagemaker_session.delete_endpoint(self.endpoint_name)

-    delete_predictor = delete_endpoint
+    def delete_predictor(self) -> None:
+        """Delete the Amazon SageMaker inference component or endpoint backing this predictor.
+
+        Delete the corresponding inference component if the endpoint is a Goldfinch endpoint.
+        Otherwise, delete the endpoint that this predictor is attached to.
+        """
+        # [TODO]: wait and describe inference component until not found to ensure
+        # it gets deleted successfully. Throw appropriate exception/error type.
+        if self.component_name:
+            self.sagemaker_session.delete_inference_component(self.component_name)
+        else:
+            self.delete_endpoint()
+
+    def update_predictor(
+        self,
+        image_uri: Optional[str] = None,
+        model_data: Optional[Union[str, dict]] = None,
+        env: Optional[Dict[str, str]] = None,
+        model_data_download_timeout: Optional[int] = None,
+        container_startup_health_check_timeout: Optional[int] = None,
+        resources: Optional[ResourceRequirements] = None,
+    ) -> str:
+        """Update the predictor to deploy a new model specification and apply new configurations.
+
+        This is done by updating the SageMaker InferenceComponent.
+
+        Args:
+            image_uri (Optional[str]): A Docker image URI. (Default: None).
+            model_data (Optional[Union[str, dict]]): Location
+                of SageMaker model data. (Default: None).
+            env (Optional[dict[str, str]]): Environment variables
+                to run with ``image_uri`` when hosted in SageMaker. (Default: None).
+            model_data_download_timeout (Optional[int]): The timeout value, in seconds, to download
+                and extract model data from Amazon S3 to the individual inference instance
+                associated with this production variant. (Default: None).
+            container_startup_health_check_timeout (Optional[int]): The timeout value, in seconds,
+                for your inference container to pass the health check by SageMaker Hosting. For
+                more information about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+                (Default: None).
+            resources (Optional[ResourceRequirements]): The compute resource requirements
+                for a model to be deployed to an endpoint. Only EndpointType.GOLDFINCH endpoints
+                support this feature. (Default: None).
+
+        Returns:
+            str: The updated Amazon SageMaker inference component name.
+        """
+        if self.component_name is None:
+            raise ValueError(
+                "No existing inference component; "
+                "please ensure an inference component was deployed first."
+            )
+        # [TODO]: Move to a module
+        request = {
+            "InferenceComponentName": self.component_name,
+            "Specification": {},
+        }
+
+        if resources:
+            request["Specification"][
+                "ComputeResourceRequirements"
+            ] = resources.get_compute_resource_requirements()
+
+        if image_uri:
+            request["Specification"].setdefault("Container", {})["Image"] = image_uri
+
+        if env:
+            request["Specification"].setdefault("Container", {})["Environment"] = env
+
+        if model_data:
+            request["Specification"].setdefault("Container", {})["ArtifactUrl"] = model_data
+
+        if resources and resources.copy_count:
+            request["RuntimeConfig"] = {"CopyCount": resources.copy_count}
+
+        if model_data_download_timeout:
+            request["Specification"].setdefault("StartupParameters", {})[
+                "ModelDataDownloadTimeoutInSeconds"
+            ] = model_data_download_timeout
+
+        if container_startup_health_check_timeout:
+            request["Specification"].setdefault("StartupParameters", {})[
+                "ContainerStartupHealthCheckTimeoutInSeconds"
+            ] = container_startup_health_check_timeout
+
+        # Drop any Specification entries that ended up empty.
+        empty_keys = []
+        for key, value in request["Specification"].items():
+            if not value:
+                empty_keys.append(key)
+        for key in empty_keys:
+            del request["Specification"][key]
+
+        self.sagemaker_session.update_inference_component(**request)
+        return self.component_name
+
+    # [TODO]: Check with doc writer for colocated vs collocated
+    def list_colocated_models(self):
+        """List the deployed models co-located with this predictor.
+
+        Calls SageMaker:ListInferenceComponents on the endpoint associated with the predictor.
+
+        Returns:
+            list: A list of Amazon SageMaker inference component summaries deployed
+                to the endpoint.
+        """
+
+        inference_component_dict = self.sagemaker_session.list_inference_components(
+            filters={"EndpointNameEquals": self.endpoint_name}
+        )
+
+        inference_components = inference_component_dict.get("InferenceComponents", [])
+        if len(inference_components) == 0:
+            LOGGER.info("No deployed models found for endpoint %s.", self.endpoint_name)
+            return []
+
+        return inference_components

    def delete_model(self):
-        """Deletes the Amazon SageMaker models backing this predictor."""
+        """Delete the Amazon SageMaker model backing this predictor."""
        request_failed = False
        failed_models = []
        current_model_names = self._get_model_names()
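A hedged end-to-end sketch of the new inference-component lifecycle helpers. The resource values are illustrative, the ``requests`` keys (``num_cpus``, ``memory``, ``copies``) are assumed from the ``ResourceRequirements`` class, and the returned field names follow the ListInferenceComponents API response.

from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

# Scale this predictor's component compute and copy count.
predictor.update_predictor(
    resources=ResourceRequirements(requests={"num_cpus": 2, "memory": 4096, "copies": 2}),
)

# Inspect the models co-located on the same endpoint, then delete only this
# predictor's inference component, leaving the shared endpoint running.
for component in predictor.list_colocated_models():
    print(component["InferenceComponentName"])
predictor.delete_predictor()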
@@ -594,9 +752,16 @@ def _get_model_names(self):
            EndpointConfigName=current_endpoint_config_name
        )
        production_variants = endpoint_config["ProductionVariants"]
-        self._model_names = [d["ModelName"] for d in production_variants]
+        self._model_names = []
+        for d in production_variants:
+            if "ModelName" in d:
+                self._model_names.append(d["ModelName"])
        return self._model_names

+    def _get_component_name(self) -> Optional[str]:
+        """Get the inference component name field if it exists in the Predictor object."""
+        return getattr(self, "component_name", None)
+
    @property
    def content_type(self):
        """The MIME type of the data sent to the inference endpoint."""