@@ -379,13 +379,20 @@ def test_post_training_bias(
379379 )
380380
381381
382- def test_shap (clarify_processor , data_config , model_config , shap_config ):
382+ def _run_test_shap (
383+ clarify_processor ,
384+ data_config ,
385+ model_config ,
386+ shap_config ,
387+ model_scores ,
388+ expected_predictor_config ,
389+ ):
383390 with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
384391 clarify_processor .run_explainability (
385392 data_config ,
386393 model_config ,
387394 shap_config ,
388- model_scores = None ,
395+ model_scores = model_scores ,
389396 wait = True ,
390397 job_name = "test" ,
391398 experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -414,11 +421,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
414421 "save_local_shap_values" : True ,
415422 }
416423 },
417- "predictor" : {
418- "model_name" : "xgboost-model" ,
419- "instance_type" : "ml.c5.xlarge" ,
420- "initial_instance_count" : 1 ,
421- },
424+ "predictor" : expected_predictor_config ,
422425 }
423426 mock_method .assert_called_once_with (
424427 data_config ,
@@ -429,3 +432,44 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
429432 None ,
430433 {"ExperimentName" : "AnExperiment" },
431434 )
435+
436+
437+ def test_shap (clarify_processor , data_config , model_config , shap_config ):
438+ model_scores = None
439+ expected_predictor_config = {
440+ "model_name" : "xgboost-model" ,
441+ "instance_type" : "ml.c5.xlarge" ,
442+ "initial_instance_count" : 1 ,
443+ }
444+ _run_test_shap (
445+ clarify_processor ,
446+ data_config ,
447+ model_config ,
448+ shap_config ,
449+ model_scores ,
450+ expected_predictor_config ,
451+ )
452+
453+
454+ def test_shap_with_predicted_label (clarify_processor , data_config , model_config , shap_config ):
455+ probability = "pr"
456+ label_headers = ["success" ]
457+ model_scores = ModelPredictedLabelConfig (
458+ probability = probability ,
459+ label_headers = label_headers ,
460+ )
461+ expected_predictor_config = {
462+ "model_name" : "xgboost-model" ,
463+ "instance_type" : "ml.c5.xlarge" ,
464+ "initial_instance_count" : 1 ,
465+ "probability" : probability ,
466+ "label_headers" : label_headers ,
467+ }
468+ _run_test_shap (
469+ clarify_processor ,
470+ data_config ,
471+ model_config ,
472+ shap_config ,
473+ model_scores ,
474+ expected_predictor_config ,
475+ )
0 commit comments