Skip to content

Commit 18c25b5

Browse files
author
Pranav Krishnan
committed
fixed shapconfig method signature to make backwards compatible
1 parent 0b03bc2 commit 18c25b5

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/sagemaker/clarify.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -300,40 +300,40 @@ class SHAPConfig(ExplainabilityConfig):
300300

301301
def __init__(
302302
self,
303-
num_samples,
304-
agg_method,
305303
baseline=None,
306-
num_clusters=None,
304+
num_samples=None,
305+
agg_method=None,
307306
use_logit=False,
308307
save_local_shap_values=True,
309308
seed=None,
309+
num_clusters=None
310310
):
311311
"""Initializes config for SHAP.
312312
313313
Args:
314-
num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
315-
This number determines the size of the generated synthetic dataset to compute the
316-
SHAP values.
317-
agg_method (str): Aggregation method for global SHAP values. Valid values are
318-
"mean_abs" (mean of absolute SHAP values for all instances),
319-
"median" (median of SHAP values for all instances) and
320-
"mean_sq" (mean of squared SHAP values for all instances).
321314
baseline (None or str or list): None or S3 object Uri or A list of rows (at least one)
322315
to be used asthe baseline dataset in the Kernel SHAP algorithm. The format should
323316
be the same as the dataset format. Each row should contain only the feature
324317
columns/values and omit the label column/values. If None a baseline will be
325318
calculated automatically by using K-means or K-prototypes in the input dataset.
326-
num_clusters (None or int): If a baseline is not provided, Clarify automatically
327-
computes a baseline dataset via a clustering algorithm. num_clusters is a parameter
328-
for K-means/K-prototypes. Default is None.
319+
num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
320+
This number determines the size of the generated synthetic dataset to compute the
321+
SHAP values.
322+
agg_method (None or str): Aggregation method for global SHAP values. Valid values are
323+
"mean_abs" (mean of absolute SHAP values for all instances),
324+
"median" (median of SHAP values for all instances) and
325+
"mean_sq" (mean of squared SHAP values for all instances).
329326
use_logit (bool): Indicator of whether the logit function is to be applied to the model
330327
predictions. Default is False. If "use_logit" is true then the SHAP values will
331328
have log-odds units.
332329
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
333330
in the output location. Default is True.
334331
seed (int): seed value to get deterministic SHAP values. Default is None.
332+
num_clusters (None or int): If a baseline is not provided, Clarify automatically
333+
computes a baseline dataset via a clustering algorithm. num_clusters is a parameter
334+
for K-means/K-prototypes. Default is None.
335335
"""
336-
if agg_method not in ["mean_abs", "median", "mean_sq"]:
336+
if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
337337
raise ValueError(
338338
f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq."
339339
)

0 commit comments

Comments
 (0)