Skip to content

Commit 426c230

Browse files
author
Pranav Krishnan
committed
change: expose num_clusters parameter for clarify shap in shapconfig
1 parent a1f0aeb commit 426c230

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

src/sagemaker/clarify.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,28 +300,32 @@ class SHAPConfig(ExplainabilityConfig):
300300

301301
def __init__(
302302
self,
303-
baseline,
304303
num_samples,
305304
agg_method,
305+
baseline=None,
306+
num_clusters=None,
306307
use_logit=False,
307308
save_local_shap_values=True,
308309
seed=None,
309310
):
310311
"""Initializes config for SHAP.
311312
312313
Args:
313-
baseline (None or str or list): None or S3 object Uri or A list of rows (at least one)
314-
to be used as the baseline dataset in the Kernel SHAP algorithm. The format should
315-
be the same as the dataset format. Each row should contain only the feature
316-
columns/values and omit the label column/values. If None a baseline will be
317-
calculated automatically by using K-means or K-prototypes in the input dataset.
318314
num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
319315
This number determines the size of the generated synthetic dataset to compute the
320316
SHAP values.
321317
agg_method (str): Aggregation method for global SHAP values. Valid values are
322318
"mean_abs" (mean of absolute SHAP values for all instances),
323319
"median" (median of SHAP values for all instances) and
324320
"mean_sq" (mean of squared SHAP values for all instances).
321+
baseline (None or str or list): None or S3 object Uri or A list of rows (at least one)
322+
to be used as the baseline dataset in the Kernel SHAP algorithm. The format should
323+
be the same as the dataset format. Each row should contain only the feature
324+
columns/values and omit the label column/values. If None a baseline will be
325+
calculated automatically by using K-means or K-prototypes in the input dataset.
326+
num_clusters (None or int): If a baseline is not provided, Clarify automatically computes a
327+
baseline dataset via K-means/K-prototypes clustering. num_clusters is a parameter for that algorithm and cannot be provided together with baseline.
328+
Default is None.
325329
use_logit (bool): Indicator of whether the logit function is to be applied to the model
326330
predictions. Default is False. If "use_logit" is true then the SHAP values will
327331
have log-odds units.
@@ -333,14 +337,20 @@ def __init__(
333337
raise ValueError(
334338
f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq."
335339
)
336-
340+
if num_clusters is not None and baseline is not None:
341+
raise ValueError(
342+
"Baseline and num_clusters cannot both be provided. Please specify one of the two."
343+
)
337344
self.shap_config = {
338-
"baseline": baseline,
339345
"num_samples": num_samples,
340346
"agg_method": agg_method,
341347
"use_logit": use_logit,
342348
"save_local_shap_values": save_local_shap_values,
343349
}
350+
if baseline is not None:
351+
self.shap_config["baseline"] = baseline
352+
if num_clusters is not None:
353+
self.shap_config["num_clusters"] = num_clusters
344354
if seed is not None:
345355
self.shap_config["seed"] = seed
346356

tests/unit/test_clarify.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,31 @@ def test_shap_config():
268268
assert expected_config == shap_config.get_explainability_config()
269269

270270

271+
def test_shap_config_no_baseline():
272+
num_samples = 100
273+
agg_method = "mean_sq"
274+
use_logit = True
275+
seed = 123
276+
shap_config = SHAPConfig(
277+
num_samples=num_samples,
278+
agg_method=agg_method,
279+
num_clusters=2,
280+
use_logit=use_logit,
281+
seed=seed,
282+
)
283+
expected_config = {
284+
"shap": {
285+
"num_samples": num_samples,
286+
"agg_method": agg_method,
287+
"num_clusters": 2,
288+
"use_logit": use_logit,
289+
"save_local_shap_values": True,
290+
"seed": seed,
291+
}
292+
}
293+
assert expected_config == shap_config.get_explainability_config()
294+
295+
271296
def test_invalid_shap_config():
272297
with pytest.raises(ValueError) as error:
273298
SHAPConfig(

0 commit comments

Comments
 (0)