@@ -300,28 +300,32 @@ class SHAPConfig(ExplainabilityConfig):
300300
301301 def __init__ (
302302 self ,
303- baseline ,
304303 num_samples ,
305304 agg_method ,
305+ baseline = None ,
306+ num_clusters = None ,
306307 use_logit = False ,
307308 save_local_shap_values = True ,
308309 seed = None ,
309310 ):
310311 """Initializes config for SHAP.
311312
312313 Args:
313- baseline (None or str or list): None or S3 object Uri or A list of rows (at least one)
314- to be used asthe baseline dataset in the Kernel SHAP algorithm. The format should
315- be the same as the dataset format. Each row should contain only the feature
316- columns/values and omit the label column/values. If None a baseline will be
317- calculated automatically by using K-means or K-prototypes in the input dataset.
318314 num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
319315 This number determines the size of the generated synthetic dataset to compute the
320316 SHAP values.
321317 agg_method (str): Aggregation method for global SHAP values. Valid values are
322318 "mean_abs" (mean of absolute SHAP values for all instances),
323319 "median" (median of SHAP values for all instances) and
324320 "mean_sq" (mean of squared SHAP values for all instances).
321+ baseline (None or str or list): None or S3 object Uri or A list of rows (at least one)
322+ to be used asthe baseline dataset in the Kernel SHAP algorithm. The format should
323+ be the same as the dataset format. Each row should contain only the feature
324+ columns/values and omit the label column/values. If None a baseline will be
325+ calculated automatically by using K-means or K-prototypes in the input dataset.
326+ num_clusters (None or int): If a baseline is not provided, Clarify automatically computes a
327+ baseline dataset via a clustering algorithm. num_clusters is a parameter for K-means/K-prototypes.
328+ Default is None.
325329 use_logit (bool): Indicator of whether the logit function is to be applied to the model
326330 predictions. Default is False. If "use_logit" is true then the SHAP values will
327331 have log-odds units.
@@ -333,14 +337,20 @@ def __init__(
333337 raise ValueError (
334338 f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
335339 )
336-
340+ if num_clusters is not None and baseline is not None :
341+ raise ValueError (
342+ "Baseline and num_clusters cannot both be provided. Please specify one of the two."
343+ )
337344 self .shap_config = {
338- "baseline" : baseline ,
339345 "num_samples" : num_samples ,
340346 "agg_method" : agg_method ,
341347 "use_logit" : use_logit ,
342348 "save_local_shap_values" : save_local_shap_values ,
343349 }
350+ if baseline is not None :
351+ self .shap_config ["baseline" ] = baseline
352+ if num_clusters is not None :
353+ self .shap_config ["num_clusters" ] = num_clusters
344354 if seed is not None :
345355 self .shap_config ["seed" ] = seed
346356
0 commit comments