1414
1515from __future__ import absolute_import
1616
17- from functools import lru_cache
1817from typing import Dict , List , Optional , Any , Union
1918import pandas as pd
2019from botocore .exceptions import ClientError
4847 get_jumpstart_configs ,
4948 get_metrics_from_deployment_configs ,
5049 add_instance_rate_stats_to_benchmark_metrics ,
50+ deployment_config_response_data ,
51+ _deployment_config_lru_cache ,
5152)
5253from sagemaker .jumpstart .constants import JUMPSTART_LOGGER
5354from sagemaker .jumpstart .enums import JumpStartModelType
@@ -449,10 +450,12 @@ def deployment_config(self) -> Optional[Dict[str, Any]]:
449450 Returns:
450451 Optional[Dict[str, Any]]: Deployment config.
451452 """
452- deployment_config = self ._retrieve_selected_deployment_config (
453- self .config_name , self .instance_type
454- )
455- return deployment_config .to_json () if deployment_config is not None else None
453+ if self .config_name is None :
454+ return None
455+ for config in self .list_deployment_configs ():
456+ if config .get ("DeploymentConfigName" ) == self .config_name :
457+ return config
458+ return None
456459
457460 @property
458461 def benchmark_metrics (self ) -> pd .DataFrame :
@@ -461,29 +464,24 @@ def benchmark_metrics(self) -> pd.DataFrame:
461464 Returns:
462465 Benchmark Metrics: Pandas DataFrame object.
463466 """
464- benchmark_metrics_data = self ._get_deployment_configs_benchmarks_data (
465- self .config_name , self .instance_type
466- )
467- keys = list (benchmark_metrics_data .keys ())
468- df = pd .DataFrame (benchmark_metrics_data ).sort_values (by = [keys [0 ], keys [1 ]])
469- return df
467+ df = pd .DataFrame (self ._get_deployment_configs_benchmarks_data ())
468+ default_mask = df .apply (lambda row : any ("Default" in str (val ) for val in row ), axis = 1 )
469+ sorted_df = pd .concat ([df [default_mask ], df [~ default_mask ]])
470+ return sorted_df
470471
471- def display_benchmark_metrics (self ) -> None :
472+ def display_benchmark_metrics (self , * args , ** kwargs ) -> None :
472473 """Display deployment configs benchmark metrics."""
473- print (self .benchmark_metrics .to_markdown (index = False ))
474+ print (self .benchmark_metrics .to_markdown (index = False ), * args , ** kwargs )
474475
475476 def list_deployment_configs (self ) -> List [Dict [str , Any ]]:
476477 """List deployment configs for ``This`` model.
477478
478479 Returns:
479480 List[Dict[str, Any]]: A list of deployment configs.
480481 """
481- return [
482- deployment_config .to_json ()
483- for deployment_config in self ._get_deployment_configs (
484- self .config_name , self .instance_type
485- )
486- ]
482+ return deployment_config_response_data (
483+ self ._get_deployment_configs (self .config_name , self .instance_type )
484+ )
487485
488486 def _create_sagemaker_model (
489487 self ,
@@ -873,71 +871,46 @@ def register_deploy_wrapper(*args, **kwargs):
873871
874872 return model_package
875873
876- @lru_cache
877- def _get_deployment_configs_benchmarks_data (
878- self , config_name : str , instance_type : str
879- ) -> Dict [str , Any ]:
874+ @_deployment_config_lru_cache
875+ def _get_deployment_configs_benchmarks_data (self ) -> Dict [str , Any ]:
880876 """Deployment configs benchmark metrics.
881877
882- Args:
883- config_name (str): Name of selected deployment config.
884- instance_type (str): The selected Instance type.
885878 Returns:
886879 Dict[str, List[str]]: Deployment config benchmark data.
887880 """
888881 return get_metrics_from_deployment_configs (
889- self ._get_deployment_configs (config_name , instance_type )
882+ self ._get_deployment_configs (None , None ),
890883 )
891884
892- @lru_cache
893- def _retrieve_selected_deployment_config (
894- self , config_name : str , instance_type : str
895- ) -> Optional [DeploymentConfigMetadata ]:
896- """Retrieve the deployment config to apply to `This` model.
897-
898- Args:
899- config_name (str): The name of the deployment config to retrieve.
900- instance_type (str): The instance type of the deployment config to retrieve.
901- Returns:
902- Optional[Dict[str, Any]]: The retrieved deployment config.
903- """
904- if config_name is None :
905- return None
906-
907- for deployment_config in self ._get_deployment_configs (config_name , instance_type ):
908- if deployment_config .deployment_config_name == config_name :
909- return deployment_config
910- return None
911-
912- @lru_cache
885+ @_deployment_config_lru_cache
913886 def _get_deployment_configs (
914- self , selected_config_name : str , selected_instance_type : str
887+ self , selected_config_name : Optional [ str ] , selected_instance_type : Optional [ str ]
915888 ) -> List [DeploymentConfigMetadata ]:
916889 """Retrieve deployment configs metadata.
917890
918891 Args:
919- selected_config_name (str): The name of the selected deployment config.
920- selected_instance_type (str): The selected instance type.
892+ selected_config_name (Optional[ str] ): The name of the selected deployment config.
893+ selected_instance_type (Optional[ str] ): The selected instance type.
921894 """
922895 deployment_configs = []
923- if self ._metadata_configs is None :
896+ if not self ._metadata_configs :
924897 return deployment_configs
925898
926899 err = None
927900 for config_name , metadata_config in self ._metadata_configs .items ():
928- if err is None or "is not authorized to perform: pricing:GetProducts" not in err :
929- err , metadata_config .benchmark_metrics = (
930- add_instance_rate_stats_to_benchmark_metrics (
931- self .region , metadata_config .benchmark_metrics
932- )
933- )
934-
935901 resolved_config = metadata_config .resolved_config
936902 if selected_config_name == config_name :
937903 instance_type_to_use = selected_instance_type
938904 else :
939905 instance_type_to_use = resolved_config .get ("default_inference_instance_type" )
940906
907+ if metadata_config .benchmark_metrics :
908+ err , metadata_config .benchmark_metrics = (
909+ add_instance_rate_stats_to_benchmark_metrics (
910+ self .region , metadata_config .benchmark_metrics
911+ )
912+ )
913+
941914 init_kwargs = get_init_kwargs (
942915 model_id = self .model_id ,
943916 instance_type = instance_type_to_use ,
@@ -957,9 +930,9 @@ def _get_deployment_configs(
957930 )
958931 deployment_configs .append (deployment_config_metadata )
959932
960- if err is not None and "is not authorized to perform: pricing:GetProducts" in err :
933+ if err and err [ "Code" ] == "AccessDeniedException" :
961934 error_message = "Instance rate metrics will be omitted. Reason: %s"
962- JUMPSTART_LOGGER .warning (error_message , err )
935+ JUMPSTART_LOGGER .warning (error_message , err [ "Message" ] )
963936
964937 return deployment_configs
965938
0 commit comments