Skip to content

Commit 577a232

Browse files
author
Pranav Krishnan
committed
fixed shapconfig docstring, added shapconfig unit test
1 parent b12677d commit 577a232

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

src/sagemaker/clarify.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -318,20 +318,23 @@ def __init__(
318318
calculated automatically by using K-means or K-prototypes in the input dataset.
319319
num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
320320
This number determines the size of the generated synthetic dataset to compute the
321-
SHAP values.
321+
SHAP values. If not provided, the Clarify job will choose a proper value according
322+
to the count of features.
322323
agg_method (None or str): Aggregation method for global SHAP values. Valid values are
323324
"mean_abs" (mean of absolute SHAP values for all instances),
324325
"median" (median of SHAP values for all instances) and
325326
"mean_sq" (mean of squared SHAP values for all instances).
327+
If not provided, the Clarify job uses the "mean_abs" method.
326328
use_logit (bool): Indicator of whether the logit function is to be applied to the model
327329
predictions. Default is False. If "use_logit" is true then the SHAP values will
328330
have log-odds units.
329331
save_local_shap_values (bool): Indicator of whether to save the local SHAP values
330332
in the output location. Default is True.
331333
seed (int): seed value to get deterministic SHAP values. Default is None.
332334
num_clusters (None or int): If a baseline is not provided, Clarify automatically
333-
computes a baseline dataset via a clustering algorithm. num_clusters is a parameter
334-
for K-means/K-prototypes. Default is None.
335+
computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
336+
num_clusters is a parameter for this algorithm. num_clusters will be the resulting
337+
size of the baseline dataset. If not provided, the Clarify job will use a default value.
335338
"""
336339
if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
337340
raise ValueError(
@@ -342,17 +345,19 @@ def __init__(
342345
"Baseline and num_clusters cannot both be provided. Please specify one of the two."
343346
)
344347
self.shap_config = {
345-
"num_samples": num_samples,
346-
"agg_method": agg_method,
347348
"use_logit": use_logit,
348349
"save_local_shap_values": save_local_shap_values,
349350
}
350351
if baseline is not None:
351352
self.shap_config["baseline"] = baseline
352-
if num_clusters is not None:
353-
self.shap_config["num_clusters"] = num_clusters
353+
if num_samples is not None:
354+
self.shap_config["num_samples"] = num_samples
355+
if agg_method is not None:
356+
self.shap_config["agg_method"] = agg_method
354357
if seed is not None:
355358
self.shap_config["seed"] = seed
359+
if num_clusters is not None:
360+
self.shap_config["num_clusters"] = num_clusters
356361

357362
def get_explainability_config(self):
358363
"""Returns config."""

tests/unit/test_clarify.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,15 @@ def test_shap_config_no_baseline():
292292
}
293293
assert expected_config == shap_config.get_explainability_config()
294294

295+
def test_shap_config_no_parameters():
296+
shap_config = SHAPConfig()
297+
expected_config = {
298+
"shap": {
299+
"use_logit": False,
300+
"save_local_shap_values": True,
301+
}
302+
}
303+
assert expected_config == shap_config.get_explainability_config()
295304

296305
def test_invalid_shap_config():
297306
with pytest.raises(ValueError) as error:
@@ -303,6 +312,12 @@ def test_invalid_shap_config():
303312
assert "Invalid agg_method invalid. Please choose mean_abs, median, or mean_sq." in str(
304313
error.value
305314
)
315+
with pytest.raises(ValueError) as error:
316+
SHAPConfig(baseline=[[1]], num_samples=1, agg_method="mean_abs", num_clusters=2)
317+
assert (
318+
"Baseline and num_clusters cannot both be provided. Please specify one of the two."
319+
in str(error.value)
320+
)
306321

307322

308323
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)