@@ -1369,15 +1369,139 @@ def run_explainability(
13691369 experiment_config ,
13701370 )
13711371
1372- def run_bias_and_explainability (self ):
1373- """
1374- TODO:
1375- - add doc string
1376- - add logic
1377- - add tests
1378- """
1379- raise NotImplementedError (
1380- "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
def run_bias_and_explainability(
    self,
    data_config: DataConfig,
    model_config: ModelConfig,
    explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
    bias_config: BiasConfig,
    pre_training_methods: Union[str, List[str]] = "all",
    post_training_methods: Union[str, List[str]] = "all",
    model_predicted_label_config: ModelPredictedLabelConfig = None,
    wait=True,
    logs=True,
    job_name=None,
    kms_key=None,
    experiment_config=None,
):
    """Runs a :class:`~sagemaker.processing.ProcessingJob` computing bias metrics and feature attributions.

    For bias:
    Computes metrics for both the pre-training and the post-training methods.
    To calculate post-training methods, it spins up a model endpoint and runs inference over the
    input examples in ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`)
    to obtain predicted labels.

    For explainability:
    Spins up a model endpoint.

    Currently, only SHAP and Partial Dependence Plots (PDP) are supported
    as explainability methods.
    You can request both methods or one at a time with the ``explainability_config`` parameter.

    When SHAP is requested in the ``explainability_config``,
    the SHAP algorithm calculates the feature importance for each input example
    in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
    by creating ``num_samples`` copies of the example with a subset of features
    replaced with values from the ``baseline``.
    It then runs model inference to see how the model's prediction changes with the replaced
    features. If the model output returns multiple scores, importance is computed for each score.
    Across examples, feature importance is aggregated using ``agg_method``.

    When PDP is requested in the ``explainability_config``,
    the PDP algorithm calculates the dependence of the target response
    on the input features and marginalizes over the values of all other input features.
    The Partial Dependence Plots are included in the output
    `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
    and the corresponding values are included in the analysis output.

    Args:
        data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
        model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
            endpoint to be created.
        explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
            Config of the specific explainability method or a list of
            :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
            Currently, SHAP and PDP are the two methods supported.
            You can request multiple methods at once by passing in a list of
            :class:`~sagemaker.clarify.ExplainabilityConfig`.
        bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
        pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
            ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
            "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
            "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
            "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
            "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
            "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
            "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
            "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
            Defaults to str "all" to run all metrics if left unspecified.
        post_training_methods (str or list[str]): Selector of a subset of potential metrics:
            ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
            "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
            "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
            "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
            "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
            "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
            "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
            "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
            "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
            "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
            "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
            Defaults to str "all" to run all metrics if left unspecified.
        model_predicted_label_config (
            int or
            str or
            :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
        ):
            Index or JSONPath to locate the predicted scores in the model output. This is not
            required if the model output is a single score. Alternatively, it can be an instance
            of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
            to provide more parameters like ``label_headers``.
        wait (bool): Whether the call should wait until the job completes (default: True).
        logs (bool): Whether to show the logs produced by the job.
            Only meaningful when ``wait`` is True (default: True).
        job_name (str): Processing job name. When ``job_name`` is not specified (``None``
            or an empty string), if ``job_name_prefix`` in
            :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is specified, the job name
            will be composed of ``job_name_prefix`` and current timestamp;
            otherwise use ``"Clarify-Bias-And-Explainability"`` as prefix.
        kms_key (str): The ARN of the KMS key that is used to encrypt the
            user code file (default: None).
        experiment_config (dict[str, str]): Experiment management configuration.
            Optionally, the dict can contain three keys:
            ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

            The behavior of setting these keys is as follows:

            * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
              automatically created and the job's Trial Component associated with the Trial.
            * If ``'TrialName'`` is supplied and the Trial already exists,
              the job's Trial Component will be associated with the Trial.
            * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
              the Trial Component will be unassociated.
            * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.

    Returns:
        Whatever the underlying ``self._run`` call returns for the submitted job.
    """  # noqa E501  # pylint: disable=c0301
    # Build one analysis config covering the bias methods and the explainability methods.
    analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
        data_config,
        model_config,
        model_predicted_label_config,
        explainability_config,
        bias_config,
        pre_training_methods,
        post_training_methods,
    )
    # when name is either not provided (is None) or an empty string ("")
    job_name = job_name or utils.name_from_base(
        self.job_name_prefix or "Clarify-Bias-And-Explainability"
    )
    return self._run(
        data_config,
        analysis_config,
        wait,
        logs,
        job_name,
        kms_key,
        experiment_config,
    )
13821506
13831507
@@ -1395,6 +1519,7 @@ def bias_and_explainability(
13951519 pre_training_methods : Union [str , List [str ]] = "all" ,
13961520 post_training_methods : Union [str , List [str ]] = "all" ,
13971521 ):
1522+ """Generates a config for Bias and Explainability"""
13981523 analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
13991524 analysis_config = cls ._add_methods (
14001525 analysis_config ,
@@ -1475,6 +1600,7 @@ def bias(
14751600
14761601 @classmethod
14771602 def _add_predictor (cls , analysis_config , model_config , model_predicted_label_config ):
1603+ """Extends analysis config with predictor."""
14781604 analysis_config = {** analysis_config }
14791605 analysis_config ["predictor" ] = model_config .get_predictor_config ()
14801606 if isinstance (model_predicted_label_config , ModelPredictedLabelConfig ):
@@ -1498,12 +1624,14 @@ def _add_methods(
14981624 explainability_config = None ,
14991625 report = True ,
15001626 ):
1627+ """Extends analysis config with methods."""
15011628 # validate
15021629 params = [pre_training_methods , post_training_methods , explainability_config ]
15031630 if all ([1 if p is None else 0 for p in params ]):
15041631 raise AttributeError (
15051632 "analysis_config must have at least one working method: "
1506- "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
1633+ "One of the "
1634+ "`pre_training_methods`, `post_training_methods`, `explainability_config`."
15071635 )
15081636
15091637 # main logic
@@ -1529,6 +1657,7 @@ def _add_methods(
15291657 def _merge_explainability_configs (
15301658 cls , explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]]
15311659 ):
1660+ """Merges explainability configs, when more than one."""
15321661 if isinstance (explainability_config , list ):
15331662 explainability_methods = {}
15341663 if len (explainability_config ) == 0 :
0 commit comments