@@ -300,40 +300,40 @@ class SHAPConfig(ExplainabilityConfig):
300300
301301 def __init__ (
302302 self ,
303- num_samples ,
304- agg_method ,
305303 baseline = None ,
306- num_clusters = None ,
304+ num_samples = None ,
305+ agg_method = None ,
307306 use_logit = False ,
308307 save_local_shap_values = True ,
309308 seed = None ,
309+ num_clusters = None
310310 ):
311311 """Initializes config for SHAP.
312312
313313 Args:
314- num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
315- This number determines the size of the generated synthetic dataset to compute the
316- SHAP values.
317- agg_method (str): Aggregation method for global SHAP values. Valid values are
318- "mean_abs" (mean of absolute SHAP values for all instances),
319- "median" (median of SHAP values for all instances) and
320- "mean_sq" (mean of squared SHAP values for all instances).
321314 baseline (None or str or list): None or S3 object Uri or A list of rows (at least one)
322315 to be used as the baseline dataset in the Kernel SHAP algorithm. The format should
323316 be the same as the dataset format. Each row should contain only the feature
324317 columns/values and omit the label column/values. If None a baseline will be
325318 calculated automatically by using K-means or K-prototypes in the input dataset.
326- num_clusters (None or int): If a baseline is not provided, Clarify automatically
327- computes a baseline dataset via a clustering algorithm. num_clusters is a parameter
328- for K-means/K-prototypes. Default is None.
319+ num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
320+ This number determines the size of the generated synthetic dataset to compute the
321+ SHAP values.
322+ agg_method (None or str): Aggregation method for global SHAP values. Valid values are
323+ "mean_abs" (mean of absolute SHAP values for all instances),
324+ "median" (median of SHAP values for all instances) and
325+ "mean_sq" (mean of squared SHAP values for all instances).
329326 use_logit (bool): Indicator of whether the logit function is to be applied to the model
330327 predictions. Default is False. If "use_logit" is true then the SHAP values will
331328 have log-odds units.
332329 save_local_shap_values (bool): Indicator of whether to save the local SHAP values
333330 in the output location. Default is True.
334331 seed (int): seed value to get deterministic SHAP values. Default is None.
332+ num_clusters (None or int): If a baseline is not provided, Clarify automatically
333+ computes a baseline dataset via a clustering algorithm. num_clusters is a parameter
334+ for K-means/K-prototypes. Default is None.
335335 """
336- if agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
336+ if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
337337 raise ValueError (
338338 f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
339339 )
0 commit comments