@@ -318,20 +318,23 @@ def __init__(
318318 calculated automatically by using K-means or K-prototypes in the input dataset.
319319 num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
320320 This number determines the size of the generated synthetic dataset to compute the
321- SHAP values.
321+ SHAP values. If not provided then Clarify job will choose a proper value according
322+ to the count of features.
322323 agg_method (None or str): Aggregation method for global SHAP values. Valid values are
323324 "mean_abs" (mean of absolute SHAP values for all instances),
324325 "median" (median of SHAP values for all instances) and
325326 "mean_sq" (mean of squared SHAP values for all instances).
327+ If not provided then Clarify job uses method "mean_abs"
326328 use_logit (bool): Indicator of whether the logit function is to be applied to the model
327329 predictions. Default is False. If "use_logit" is true then the SHAP values will
328330 have log-odds units.
329331 save_local_shap_values (bool): Indicator of whether to save the local SHAP values
330332 in the output location. Default is True.
331333 seed (int): seed value to get deterministic SHAP values. Default is None.
332334 num_clusters (None or int): If a baseline is not provided, Clarify automatically
333- computes a baseline dataset via a clustering algorithm. num_clusters is a parameter
334- for K-means/K-prototypes. Default is None.
335+ computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
336+ num_clusters is a parameter for this algorithm. num_clusters will be the resulting
337+ size of the baseline dataset. If not provided, Clarify job will use a default value.
335338 """
336339 if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
337340 raise ValueError (
@@ -342,17 +345,19 @@ def __init__(
342345 "Baseline and num_clusters cannot both be provided. Please specify one of the two."
343346 )
344347 self .shap_config = {
345- "num_samples" : num_samples ,
346- "agg_method" : agg_method ,
347348 "use_logit" : use_logit ,
348349 "save_local_shap_values" : save_local_shap_values ,
349350 }
350351 if baseline is not None :
351352 self .shap_config ["baseline" ] = baseline
352- if num_clusters is not None :
353- self .shap_config ["num_clusters" ] = num_clusters
353+ if num_samples is not None :
354+ self .shap_config ["num_samples" ] = num_samples
355+ if agg_method is not None :
356+ self .shap_config ["agg_method" ] = agg_method
354357 if seed is not None :
355358 self .shap_config ["seed" ] = seed
359+ if num_clusters is not None :
360+ self .shap_config ["num_clusters" ] = num_clusters
356361
357362 def get_explainability_config (self ):
358363 """Returns config."""
0 commit comments