@@ -31,6 +31,7 @@ def retrieve_default(
3131 region : Optional [str ] = None ,
3232 model_id : Optional [str ] = None ,
3333 model_version : Optional [str ] = None ,
34+ hub_arn : Optional [str ] = None ,
3435 instance_type : Optional [str ] = None ,
3536 include_container_hyperparameters : bool = False ,
3637 tolerate_vulnerable_model : bool = False ,
@@ -46,6 +47,8 @@ def retrieve_default(
4647 retrieve the default hyperparameters. (Default: None).
4748 model_version (str): The version of the model for which to retrieve the
4849 default hyperparameters. (Default: None).
50+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+ model details from. (default: None).
4952 instance_type (str): An instance type to optionally supply in order to get hyperparameters
5053 specific for the instance type.
5154 include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
8083 return artifacts ._retrieve_default_hyperparameters (
8184 model_id = model_id ,
8285 model_version = model_version ,
86+ hub_arn = hub_arn ,
8387 instance_type = instance_type ,
8488 region = region ,
8589 include_container_hyperparameters = include_container_hyperparameters ,
@@ -93,6 +97,7 @@ def validate(
9397 region : Optional [str ] = None ,
9498 model_id : Optional [str ] = None ,
9599 model_version : Optional [str ] = None ,
100+ hub_arn : Optional [str ] = None ,
96101 hyperparameters : Optional [dict ] = None ,
97102 validation_mode : HyperparameterValidationMode = HyperparameterValidationMode .VALIDATE_PROVIDED ,
98103 tolerate_vulnerable_model : bool = False ,
@@ -107,6 +112,8 @@ def validate(
107112 (Default: None).
108113 model_version (str): The version of the model for which to validate hyperparameters.
109114 (Default: None).
115+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve
116+ model details from. (default: None).
110117 hyperparameters (dict): Hyperparameters to validate.
111118 (Default: None).
112119 validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
148155 return validate_hyperparameters (
149156 model_id = model_id ,
150157 model_version = model_version ,
158+ hub_arn = hub_arn ,
151159 hyperparameters = hyperparameters ,
152160 validation_mode = validation_mode ,
153161 region = region ,
0 commit comments