@@ -1156,6 +1156,149 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
11561156 )
11571157 self .model_subscription_link = json_obj .get ("model_subscription_link" )
11581158
1159+ def from_describe_hub_content_response (self , response : DescribeHubContentResponse ) -> None :
1160+ """Sets fields in object based on values in HubContentDocument
1161+
1162+ Args:
1163+ hub_content_doc (Dict[str, any]): parsed HubContentDocument returned
1164+ from SageMaker:DescribeHubContent
1165+ """
1166+ self .model_id : str = response .hub_content_name
1167+ self .version : str = response .hub_content_version
1168+ hub_content_document : HubModelDocument = response .hub_content_document
1169+ self .url : str = hub_content_document .url
1170+ self .min_sdk_version : str = hub_content_document .min_sdk_version
1171+ self .training_supported : bool = hub_content_document .training_supported
1172+ self .incremental_training_supported : bool = bool (
1173+ hub_content_document ["IncrementalTrainingSupported" ]
1174+ )
1175+ self .hosting_ecr_uri : Optional [str ] = hub_content_document .hosting_ecr_uri
1176+ self ._non_serializable_slots .append ("hosting_ecr_specs" )
1177+
1178+ hosting_artifact_bucket , hosting_artifact_key = parse_s3_url (
1179+ hub_content_document .hosting_artifact_uri
1180+ )
1181+ self .hosting_artifact_key : str = hosting_artifact_key
1182+ hosting_script_bucket , hosting_script_key = parse_s3_url (
1183+ hub_content_document .hosting_script_uri
1184+ )
1185+ self .hosting_script_key : str = hosting_script_key
1186+ self .inference_environment_variables = hub_content_document .inference_environment_variables
1187+ self .inference_vulnerable : bool = False
1188+ self .inference_dependencies : List [str ] = hub_content_document .inference_dependencies
1189+ self .inference_vulnerabilities : List [str ] = []
1190+ self .training_vulnerable : bool = False
1191+ self .training_dependencies : List [str ] = hub_content_document .training_dependencies
1192+ self .training_vulnerabilities : List [str ] = []
1193+ self .deprecated : bool = False
1194+ self .deprecated_message : Optional [str ] = None
1195+ self .deprecate_warn_message : Optional [str ] = None
1196+ self .usage_info_message : Optional [str ] = None
1197+ self .default_inference_instance_type : Optional [
1198+ str
1199+ ] = hub_content_document .default_inference_instance_type
1200+ self .default_training_instance_type : Optional [
1201+ str
1202+ ] = hub_content_document .default_training_instance_type
1203+ self .supported_inference_instance_types : Optional [
1204+ List [str ]
1205+ ] = hub_content_document .supported_inference_instance_types
1206+ self .supported_training_instance_types : Optional [
1207+ List [str ]
1208+ ] = hub_content_document .supported_training_instance_types
1209+ self .dynamic_container_deployment_supported : Optional [
1210+ bool
1211+ ] = hub_content_document .dynamic_container_deployment_supported
1212+ self .hosting_resource_requirements : Optional [
1213+ Dict [str , int ]
1214+ ] = hub_content_document .hosting_resource_requirements
1215+ self .metrics : Optional [List [Dict [str , str ]]] = hub_content_document .training_metrics
1216+ self .training_prepacked_script_key : Optional [str ] = None
1217+ if hub_content_document .training_prepacked_script_uri is not None :
1218+ training_prepacked_script_bucket , training_prepacked_script_key = parse_s3_url (
1219+ hub_content_document .training_prepacked_script_uri
1220+ )
1221+ self .training_prepacked_script_key = training_prepacked_script_key
1222+
1223+ self .hosting_prepacked_artifact_key : Optional [str ] = None
1224+ if hub_content_document .hosting_prepacked_artifact_uri is not None :
1225+ hosting_prepacked_artifact_bucket , hosting_prepacked_artifact_key = parse_s3_url (
1226+ hub_content_document .hosting_prepacked_artifact_uri
1227+ )
1228+ self .hosting_prepacked_artifact_key = hosting_prepacked_artifact_key
1229+
1230+ self .fit_kwargs = get_model_spec_kwargs_from_hub_content_document (
1231+ ModelSpecKwargType .FIT , hub_content_document
1232+ )
1233+ self .model_kwargs = get_model_spec_kwargs_from_hub_content_document (
1234+ ModelSpecKwargType .MODEL , hub_content_document
1235+ )
1236+ self .deploy_kwargs = get_model_spec_kwargs_from_hub_content_document (
1237+ ModelSpecKwargType .DEPLOY , hub_content_document
1238+ )
1239+ self .estimator_kwargs = get_model_spec_kwargs_from_hub_content_document (
1240+ ModelSpecKwargType .ESTIMATOR , hub_content_document
1241+ )
1242+
1243+ self .predictor_specs : Optional [
1244+ JumpStartPredictorSpecs
1245+ ] = hub_content_document .sage_maker_sdk_predictor_specifications
1246+ self .default_payloads : Optional [
1247+ Dict [str , JumpStartSerializablePayload ]
1248+ ] = hub_content_document .default_payloads
1249+ self .gated_bucket = hub_content_document .gated_bucket
1250+ self .inference_volume_size : Optional [int ] = hub_content_document .inference_volume_size
1251+ self .inference_enable_network_isolation : bool = (
1252+ hub_content_document .inference_enable_network_isolation
1253+ )
1254+ self .resource_name_base : Optional [str ] = hub_content_document .resource_name_base
1255+
1256+ self .hosting_eula_key : Optional [str ] = None
1257+ if hub_content_document .hosting_eula_uri is not None :
1258+ hosting_eula_bucket , hosting_eula_key = parse_s3_url (
1259+ hub_content_document .hosting_eula_uri
1260+ )
1261+ self .hosting_eula_key = hosting_eula_key
1262+
1263+ self .hosting_model_package_arns : Optional [Dict ] = None # TODO: Missing from shcema?
1264+ self .hosting_use_script_uri : bool = hub_content_document .hosting_use_script_uri
1265+
1266+ self .hosting_instance_type_variants : Optional [JumpStartInstanceTypeVariants ] = (
1267+ JumpStartInstanceTypeVariants (hub_content_document .hosting_instance_type_variants )
1268+ if hub_content_document .hosting_instance_type_variants
1269+ else None
1270+ )
1271+
1272+ if self .training_supported :
1273+ self .training_ecr_uri : Optional [str ] = hub_content_document .training_ecr_uri
1274+ self ._non_serializable_slots .append ("training_ecr_specs" )
1275+ training_artifact_bucket , training_artifact_key = parse_s3_url (
1276+ hub_content_document .training_artifact_uri
1277+ )
1278+ self .training_artifact_key : str = training_artifact_key
1279+ training_script_bucket , training_script_key = parse_s3_url (
1280+ hub_content_document .training_script_uri
1281+ )
1282+ self .training_script_key : str = training_script_key
1283+
1284+ self .hyperparameters : List [
1285+ JumpStartHyperparameter
1286+ ] = hub_content_document .hyperparameters
1287+ self .training_volume_size : Optional [int ] = hub_content_document .training_volume_size
1288+ self .training_enable_network_isolation : bool = (
1289+ hub_content_document .training_enable_network_isolation
1290+ )
1291+ self .training_model_package_artifact_uris : Optional [
1292+ Dict
1293+ ] = hub_content_document .training_model_package_artifact_uri
1294+ self .training_instance_type_variants : Optional [
1295+ JumpStartInstanceTypeVariants
1296+ ] = JumpStartInstanceTypeVariants (
1297+ hub_content_document .training_instance_type_variants
1298+ if hub_content_document .training_instance_type_variants
1299+ else None
1300+ )
1301+
11591302 def supports_prepacked_inference (self ) -> bool :
11601303 """Returns True if the model has a prepacked inference artifact."""
11611304 return getattr (self , "hosting_prepacked_artifact_key" , None ) is not None
0 commit comments