@@ -735,6 +735,41 @@ def pdp_config():
735735 return PDPConfig (features = ["F1" , "F2" ], grid_resolution = 20 )
736736
737737
738+ def test_model_config_validations ():
739+ new_model_endpoint_definition = {
740+ "model_name" : "xgboost-model" ,
741+ "instance_type" : "ml.c5.xlarge" ,
742+ "instance_count" : 1 ,
743+ }
744+ existing_endpoint_definition = {"endpoint_name" : "existing_endpoint" }
745+
746+ with pytest .raises (AssertionError ):
747+ # should be one of them
748+ ModelConfig (
749+ ** new_model_endpoint_definition ,
750+ ** existing_endpoint_definition ,
751+ )
752+
753+ with pytest .raises (AssertionError ):
754+ # should be one of them
755+ ModelConfig (
756+ endpoint_name_prefix = "prefix" ,
757+ ** existing_endpoint_definition ,
758+ )
759+
760+ # success path for new model
761+ assert ModelConfig (** new_model_endpoint_definition ).predictor_config == {
762+ "initial_instance_count" : 1 ,
763+ "instance_type" : "ml.c5.xlarge" ,
764+ "model_name" : "xgboost-model" ,
765+ }
766+
767+ # success path for existing endpoint
768+ assert (
769+ ModelConfig (** existing_endpoint_definition ).predictor_config == existing_endpoint_definition
770+ )
771+
772+
738773@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
739774def test_pre_training_bias (
740775 name_from_base ,
@@ -1396,6 +1431,47 @@ def test_analysis_config_generator_for_bias_explainability(
13961431 assert actual == expected
13971432
13981433
1434+ def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint (
1435+ data_config , data_bias_config
1436+ ):
1437+ model_config = ModelConfig (endpoint_name = "existing_endpoint_name" )
1438+ model_predicted_label_config = ModelPredictedLabelConfig (
1439+ probability = "pr" ,
1440+ label_headers = ["success" ],
1441+ )
1442+ actual = _AnalysisConfigGenerator .bias_and_explainability (
1443+ data_config ,
1444+ model_config ,
1445+ model_predicted_label_config ,
1446+ [SHAPConfig (), PDPConfig ()],
1447+ data_bias_config ,
1448+ pre_training_methods = "all" ,
1449+ post_training_methods = "all" ,
1450+ )
1451+ expected = {
1452+ "dataset_type" : "text/csv" ,
1453+ "facet" : [{"name_or_index" : "F1" }],
1454+ "group_variable" : "F2" ,
1455+ "headers" : ["Label" , "F1" , "F2" , "F3" , "F4" ],
1456+ "joinsource_name_or_index" : "F4" ,
1457+ "label" : "Label" ,
1458+ "label_values_or_threshold" : [1 ],
1459+ "methods" : {
1460+ "pdp" : {"grid_resolution" : 15 , "top_k_features" : 10 },
1461+ "post_training_bias" : {"methods" : "all" },
1462+ "pre_training_bias" : {"methods" : "all" },
1463+ "report" : {"name" : "report" , "title" : "Analysis Report" },
1464+ "shap" : {"save_local_shap_values" : True , "use_logit" : False },
1465+ },
1466+ "predictor" : {
1467+ "label_headers" : ["success" ],
1468+ "endpoint_name" : "existing_endpoint_name" ,
1469+ "probability" : "pr" ,
1470+ },
1471+ }
1472+ assert actual == expected
1473+
1474+
13991475def test_analysis_config_generator_for_bias_pre_training (data_config , data_bias_config ):
14001476 actual = _AnalysisConfigGenerator .bias_pre_training (
14011477 data_config , data_bias_config , methods = "all"
0 commit comments