@@ -1368,15 +1368,135 @@ def run_explainability(
13681368 experiment_config ,
13691369 )
13701370
1371- def run_bias_and_explainability (self ):
1372- """
1373- TODO:
1374- - add doc string
1375- - add logic
1376- - add tests
1377- """
1378- raise NotImplementedError (
1379- "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1371+ def run_bias_and_explainability (
1372+ self ,
1373+ data_config : DataConfig ,
1374+ model_config : ModelConfig ,
1375+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1376+ bias_config : BiasConfig ,
1377+ pre_training_methods : Union [str , List [str ]] = "all" ,
1378+ post_training_methods : Union [str , List [str ]] = "all" ,
1379+ model_predicted_label_config : ModelPredictedLabelConfig = None ,
1380+ wait = True ,
1381+ logs = True ,
1382+ job_name = None ,
1383+ kms_key = None ,
1384+ experiment_config = None ,
1385+ ):
1386+ """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
1387+
1388+ For bias:
1389+ Computes metrics for both the pre-training and the post-training methods.
1390+ To calculate post-training methods, it spins up a model endpoint and runs inference over the
1391+ input examples in 's3_data_input_path' (from the :class:`~sagemaker.clarify.DataConfig`)
1392+ to obtain predicted labels.
1393+
1394+ For Explainability:
1395+ Spins up a model endpoint.
1396+
1397+ Currently, only SHAP and Partial Dependence Plots (PDP) are supported
1398+ as explainability methods.
1399+ You can request both methods or one at a time with the ``explainability_config`` parameter.
1400+
1401+ When SHAP is requested in the ``explainability_config``,
1402+ the SHAP algorithm calculates the feature importance for each input example
1403+ in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
1404+ by creating ``num_samples`` copies of the example with a subset of features
1405+ replaced with values from the ``baseline``.
1406+ It then runs model inference to see how the model's prediction changes with the replaced
1407+ features. If the model output returns multiple scores importance is computed for each score.
1408+ Across examples, feature importance is aggregated using ``agg_method``.
1409+
1410+ When PDP is requested in the ``explainability_config``,
1411+ the PDP algorithm calculates the dependence of the target response
1412+ on the input features and marginalizes over the values of all other input features.
1413+ The Partial Dependence Plots are included in the output
1414+ `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
1415+ and the corresponding values are included in the analysis output.
1416+
1417+ Args:
1418+ data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1419+ model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1420+ endpoint to be created.
1421+ explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
1422+ Config of the specific explainability method or a list of
1423+ :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
1424+ Currently, SHAP and PDP are the two methods supported.
1425+ You can request multiple methods at once by passing in a list of
1426+ `~sagemaker.clarify.ExplainabilityConfig`.
1427+ bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1428+ pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
1429+ ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
1430+ "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
1431+ "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
1432+ "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
1433+ "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
1434+ "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
1435+ "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
1436+ "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
1437+ Defaults to str "all" to run all metrics if left unspecified.
1438+ post_training_methods (str or list[str]): Selector of a subset of potential metrics:
1439+ ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
1440+ , "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
1441+ "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
1442+ "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
1443+ "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
1444+ "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
1445+ "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
1446+ "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
1447+ "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
1448+ ", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
1449+ "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
1450+ Defaults to str "all" to run all metrics if left unspecified.
1451+ model_predicted_label_config (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
1452+ Index or JSONPath to locate the predicted scores in the model output. This is not
1453+ required if the model output is a single score. Alternatively, it can be an instance
1454+ of :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
1455+ to provide more parameters like ``label_headers``.
1456+ wait (bool): Whether the call should wait until the job completes (default: True).
1457+ logs (bool): Whether to show the logs produced by the job.
1458+ Only meaningful when ``wait`` is True (default: True).
1459+ job_name (str): Processing job name. When ``job_name`` is not specified,
1460+ if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
1461+ is specified, the job name will be composed of ``job_name_prefix`` and current
1462+ timestamp; otherwise use ``"Clarify-Explainability"`` as prefix.
1463+ kms_key (str): The ARN of the KMS key that is used to encrypt the
1464+ user code file (default: None).
1465+ experiment_config (dict[str, str]): Experiment management configuration.
1466+ Optionally, the dict can contain three keys:
1467+ ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
1468+
1469+ The behavior of setting these keys is as follows:
1470+
1471+ * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
1472+ automatically created and the job's Trial Component associated with the Trial.
1473+ * If ``'TrialName'`` is supplied and the Trial already exists,
1474+ the job's Trial Component will be associated with the Trial.
1475+ * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
1476+ the Trial Component will be unassociated.
1477+ * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1478+ """ # noqa E501 # pylint: disable=c0301
1479+ analysis_config = _AnalysisConfigGenerator .bias_and_explainability (
1480+ data_config ,
1481+ model_config ,
1482+ model_predicted_label_config ,
1483+ explainability_config ,
1484+ bias_config ,
1485+ pre_training_methods ,
1486+ post_training_methods ,
1487+ )
1488+ # when name is either not provided (is None) or an empty string ("")
1489+ job_name = job_name or utils .name_from_base (
1490+ self .job_name_prefix or "Clarify-Bias-And-Explainability"
1491+ )
1492+ return self ._run (
1493+ data_config ,
1494+ analysis_config ,
1495+ wait ,
1496+ logs ,
1497+ job_name ,
1498+ kms_key ,
1499+ experiment_config ,
13801500 )
13811501
13821502
0 commit comments