2626from sagemaker import image_uris , s3
2727from sagemaker .session import Session
2828from sagemaker .utils import name_from_base
29- from sagemaker .clarify import SageMakerClarifyProcessor , ModelPredictedLabelConfig
29+ from sagemaker .clarify import SageMakerClarifyProcessor
3030
3131_LOGGER = logging .getLogger (__name__ )
3232
@@ -833,10 +833,9 @@ def suggest_baseline(
833833 specific explainability method. Currently, only SHAP is supported.
834834 model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
835835 endpoint to be created.
836- model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
837- Index or JSONPath to locate the predicted scores in the model output. This is not
838- required if the model output is a single score. Alternatively, it can be an instance
839- of ModelPredictedLabelConfig to provide more parameters like label_headers.
836+ model_scores (int or str): Index or JSONPath location in the model output for the
837+ predicted scores to be explained. This is not required if the model output is
838+ a single score.
840839 wait (bool): Whether the call should wait until the job completes (default: False).
841840 logs (bool): Whether to show the logs produced by the job.
842841 Only meaningful when wait is True (default: False).
@@ -866,24 +865,14 @@ def suggest_baseline(
866865 headers = copy .deepcopy (data_config .headers )
867866 if headers and data_config .label in headers :
868867 headers .remove (data_config .label )
869- if model_scores is None :
870- inference_attribute = None
871- label_headers = None
872- elif isinstance (model_scores , ModelPredictedLabelConfig ):
873- inference_attribute = str (model_scores .label )
874- label_headers = model_scores .label_headers
875- else :
876- inference_attribute = str (model_scores )
877- label_headers = None
878868 self .latest_baselining_job_config = ClarifyBaseliningConfig (
879869 analysis_config = ExplainabilityAnalysisConfig (
880870 explainability_config = explainability_config ,
881871 model_config = model_config ,
882872 headers = headers ,
883- label_headers = label_headers ,
884873 ),
885874 features_attribute = data_config .features ,
886- inference_attribute = inference_attribute ,
875+ inference_attribute = model_scores if model_scores is None else str ( model_scores ) ,
887876 )
888877 self .latest_baselining_job_name = baselining_job_name
889878 self .latest_baselining_job = ClarifyBaseliningJob (
@@ -1177,7 +1166,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
11771166class ExplainabilityAnalysisConfig :
11781167 """Analysis configuration for ModelExplainabilityMonitor."""
11791168
1180- def __init__ (self , explainability_config , model_config , headers = None , label_headers = None ):
1169+ def __init__ (self , explainability_config , model_config , headers = None ):
11811170 """Creates an analysis config dictionary.
11821171
11831172 Args:
@@ -1186,19 +1175,13 @@ def __init__(self, explainability_config, model_config, headers=None, label_head
11861175 model_config (sagemaker.clarify.ModelConfig): Config object related to bias
11871176 configurations.
11881177 headers (list[str]): A list of feature names (without label) of model/endpint input.
1189- label_headers (list[str]): List of headers, each for a predicted score in model output.
1190- It is used to beautify the analysis report by replacing placeholders like "label0".
1191-
11921178 """
1193- predictor_config = model_config .get_predictor_config ()
11941179 self .analysis_config = {
11951180 "methods" : explainability_config .get_explainability_config (),
1196- "predictor" : predictor_config ,
1181+ "predictor" : model_config . get_predictor_config () ,
11971182 }
11981183 if headers is not None :
11991184 self .analysis_config ["headers" ] = headers
1200- if label_headers is not None :
1201- predictor_config ["label_headers" ] = label_headers
12021185
12031186 def _to_dict (self ):
12041187 """Generates a request dictionary using the parameters provided to the class."""
0 commit comments