2525
2626import tempfile
2727from abc import ABC , abstractmethod
28- from typing import List , Union
28+ from typing import List , Union , Dict
2929
3030from sagemaker import image_uris , s3 , utils
3131from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
@@ -173,7 +173,11 @@ def __init__(
173173 _set (joinsource , "joinsource_name_or_index" , self .analysis_config )
174174 _set (facet_dataset_uri , "facet_dataset_uri" , self .analysis_config )
175175 _set (facet_headers , "facet_headers" , self .analysis_config )
176- _set (predicted_label_dataset_uri , "predicted_label_dataset_uri" , self .analysis_config )
176+ _set (
177+ predicted_label_dataset_uri ,
178+ "predicted_label_dataset_uri" ,
179+ self .analysis_config ,
180+ )
177181 _set (predicted_label_headers , "predicted_label_headers" , self .analysis_config )
178182 _set (predicted_label , "predicted_label" , self .analysis_config )
179183 _set (excluded_columns , "excluded_columns" , self .analysis_config )
@@ -239,7 +243,8 @@ def __init__(
239243 assert len (facet_name ) > 0 , "Please provide at least one facet"
240244 if facet_values_or_threshold is None :
241245 facet_list = [
242- {"name_or_index" : single_facet_name } for single_facet_name in facet_name
246+ {"name_or_index" : single_facet_name }
247+ for single_facet_name in facet_name
243248 ]
244249 elif len (facet_values_or_threshold ) == len (facet_name ):
245250 facet_list = []
@@ -492,7 +497,10 @@ def __init__(self, features=None, grid_resolution=15, top_k_features=10):
492497 top_k_features (int): Sets the number of top SHAP attributes used to compute
493498 partial dependence plots.
494499 """ # noqa E501
495- self .pdp_config = {"grid_resolution" : grid_resolution , "top_k_features" : top_k_features }
500+ self .pdp_config = {
501+ "grid_resolution" : grid_resolution ,
502+ "top_k_features" : top_k_features ,
503+ }
496504 if features is not None :
497505 self .pdp_config ["features" ] = features
498506
@@ -825,9 +833,14 @@ def __init__(
825833 image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
826834 features. Default is None.
827835 """ # noqa E501 # pylint: disable=c0301
828- if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
836+ if agg_method is not None and agg_method not in [
837+ "mean_abs" ,
838+ "median" ,
839+ "mean_sq" ,
840+ ]:
829841 raise ValueError (
830- f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
842+ f"Invalid agg_method { agg_method } ."
843+ f" Please choose mean_abs, median, or mean_sq."
831844 )
832845 if num_clusters is not None and baseline is not None :
833846 raise ValueError (
@@ -923,7 +936,9 @@ def __init__(
923936 job_name_prefix (str): Processing job name prefix.
924937 version (str): Clarify version to use.
925938 """ # noqa E501 # pylint: disable=c0301
926- container_uri = image_uris .retrieve ("clarify" , sagemaker_session .boto_region_name , version )
939+ container_uri = image_uris .retrieve (
940+ "clarify" , sagemaker_session .boto_region_name , version
941+ )
927942 self ._last_analysis_config = None
928943 self .job_name_prefix = job_name_prefix
929944 super (SageMakerClarifyProcessor , self ).__init__ (
@@ -996,7 +1011,8 @@ def _run(
9961011 json .dump (analysis_config , f )
9971012 s3_analysis_config_file = _upload_analysis_config (
9981013 analysis_config_file ,
999- data_config .s3_analysis_config_output_path or data_config .s3_output_path ,
1014+ data_config .s3_analysis_config_output_path
1015+ or data_config .s3_output_path ,
10001016 self .sagemaker_session ,
10011017 kms_key ,
10021018 )
@@ -1168,7 +1184,11 @@ def run_post_training_bias(
11681184 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
11691185 """ # noqa E501 # pylint: disable=c0301
11701186 analysis_config = _AnalysisConfigGenerator .bias_post_training (
1171- data_config , data_bias_config , model_predicted_label_config , methods , model_config
1187+ data_config ,
1188+ data_bias_config ,
1189+ model_predicted_label_config ,
1190+ methods ,
1191+ model_config ,
11721192 )
11731193 # when name is either not provided (is None) or an empty string ("")
11741194 job_name = job_name or utils .name_from_base (
@@ -1267,7 +1287,9 @@ def run_bias(
12671287 post_training_methods ,
12681288 )
12691289 # when name is either not provided (is None) or an empty string ("")
1270- job_name = job_name or utils .name_from_base (self .job_name_prefix or "Clarify-Bias" )
1290+ job_name = job_name or utils .name_from_base (
1291+ self .job_name_prefix or "Clarify-Bias"
1292+ )
12711293 return self ._run (
12721294 data_config ,
12731295 analysis_config ,
@@ -1450,8 +1472,8 @@ def run_bias_and_explainability(
14501472 "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
14511473 Defaults to str "all" to run all metrics if left unspecified.
14521474 model_predicted_label_config (
1453- int or
1454- str or
1475+ int or
1476+ str or
14551477 :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
14561478 ):
14571479 Index or JSONPath to locate the predicted scores in the model output. This is not
@@ -1552,11 +1574,16 @@ def explainability(
15521574
15531575 @classmethod
15541576 def bias_pre_training (
1555- cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [str , List [str ]]
1577+ cls ,
1578+ data_config : DataConfig ,
1579+ bias_config : BiasConfig ,
1580+ methods : Union [str , List [str ]],
15561581 ):
15571582 """Generates a config for Bias Pre Training"""
15581583 analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1559- analysis_config = cls ._add_methods (analysis_config , pre_training_methods = methods )
1584+ analysis_config = cls ._add_methods (
1585+ analysis_config , pre_training_methods = methods
1586+ )
15601587 return analysis_config
15611588
15621589 @classmethod
@@ -1570,7 +1597,9 @@ def bias_post_training(
15701597 ):
15711598 """Generates a config for Bias Post Training"""
15721599 analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1573- analysis_config = cls ._add_methods (analysis_config , post_training_methods = methods )
1600+ analysis_config = cls ._add_methods (
1601+ analysis_config , post_training_methods = methods
1602+ )
15741603 analysis_config = cls ._add_predictor (
15751604 analysis_config , model_config , model_predicted_label_config
15761605 )
@@ -1599,7 +1628,12 @@ def bias(
15991628 return analysis_config
16001629
16011630 @classmethod
1602- def _add_predictor (cls , analysis_config , model_config , model_predicted_label_config ):
1631+ def _add_predictor (
1632+ cls ,
1633+ analysis_config : Dict ,
1634+ model_config : ModelConfig ,
1635+ model_predicted_label_config : ModelPredictedLabelConfig ,
1636+ ):
16031637 """Extends analysis config with predictor."""
16041638 analysis_config = {** analysis_config }
16051639 analysis_config ["predictor" ] = model_config .get_predictor_config ()
@@ -1618,10 +1652,12 @@ def _add_predictor(cls, analysis_config, model_config, model_predicted_label_con
16181652 @classmethod
16191653 def _add_methods (
16201654 cls ,
1621- analysis_config ,
1622- pre_training_methods = None ,
1623- post_training_methods = None ,
1624- explainability_config = None ,
1655+ analysis_config : Dict ,
1656+ pre_training_methods : Union [str , List [str ]] = "all" ,
1657+ post_training_methods : Union [str , List [str ]] = "all" ,
1658+ explainability_config : Union [
1659+ ExplainabilityConfig , List [ExplainabilityConfig ]
1660+ ] = None ,
16251661 report = True ,
16261662 ):
16271663 """Extends analysis config with methods."""
@@ -1640,22 +1676,35 @@ def _add_methods(
16401676 analysis_config ["methods" ] = {}
16411677
16421678 if report :
1643- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
1679+ analysis_config ["methods" ]["report" ] = {
1680+ "name" : "report" ,
1681+ "title" : "Analysis Report" ,
1682+ }
16441683
16451684 if pre_training_methods :
1646- analysis_config ["methods" ]["pre_training_bias" ] = {"methods" : pre_training_methods }
1685+ analysis_config ["methods" ]["pre_training_bias" ] = {
1686+ "methods" : pre_training_methods
1687+ }
16471688
16481689 if post_training_methods :
1649- analysis_config ["methods" ]["post_training_bias" ] = {"methods" : post_training_methods }
1690+ analysis_config ["methods" ]["post_training_bias" ] = {
1691+ "methods" : post_training_methods
1692+ }
16501693
16511694 if explainability_config is not None :
1652- explainability_methods = cls ._merge_explainability_configs (explainability_config )
1653- analysis_config ["methods" ] = {** analysis_config ["methods" ], ** explainability_methods }
1695+ explainability_methods = cls ._merge_explainability_configs (
1696+ explainability_config
1697+ )
1698+ analysis_config ["methods" ] = {
1699+ ** analysis_config ["methods" ],
1700+ ** explainability_methods ,
1701+ }
16541702 return analysis_config
16551703
16561704 @classmethod
16571705 def _merge_explainability_configs (
1658- cls , explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]]
1706+ cls ,
1707+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
16591708 ):
16601709 """Merges explainability configs, when more than one."""
16611710 if isinstance (explainability_config , list ):
@@ -1671,17 +1720,24 @@ def _merge_explainability_configs(
16711720 "shap" not in explainability_methods
16721721 and "features" not in explainability_methods ["pdp" ]
16731722 ):
1674- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1723+ raise ValueError (
1724+ "PDP features must be provided when ShapConfig is not provided"
1725+ )
16751726 return explainability_methods
16761727 if (
16771728 isinstance (explainability_config , PDPConfig )
1678- and "features" not in explainability_config .get_explainability_config ()["pdp" ]
1729+ and "features"
1730+ not in explainability_config .get_explainability_config ()["pdp" ]
16791731 ):
1680- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1732+ raise ValueError (
1733+ "PDP features must be provided when ShapConfig is not provided"
1734+ )
16811735 return explainability_config .get_explainability_config ()
16821736
16831737
1684- def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
1738+ def _upload_analysis_config (
1739+ analysis_config_file , s3_output_path , sagemaker_session , kms_key
1740+ ):
16851741 """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
16861742
16871743 Args:
0 commit comments