2525
2626import tempfile
2727from abc import ABC , abstractmethod
28+ from typing import List , Union
29+
2830from sagemaker import image_uris , s3 , utils
2931from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
3032
@@ -971,7 +973,7 @@ def _run(
971973 def run_pre_training_bias (
972974 self ,
973975 data_config ,
974- bias_config ,
976+ data_bias_config ,
975977 methods = "all" ,
976978 wait = True ,
977979 logs = True ,
@@ -986,7 +988,7 @@ def run_pre_training_bias(
986988
987989 Args:
988990 data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
989- bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
991+ data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
990992 methods (str or list[str]): Selects a subset of potential metrics:
991993 ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
992994 "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1022,7 +1024,7 @@ def run_pre_training_bias(
10221024 """ # noqa E501 # pylint: disable=c0301
10231025 analysis_config = _AnalysisConfigGenerator .bias_pre_training (
10241026 data_config ,
1025- bias_config ,
1027+ data_bias_config ,
10261028 methods
10271029 )
10281030 # when name is either not provided (is None) or an empty string ("")
@@ -1040,7 +1042,7 @@ def run_pre_training_bias(
10401042 def run_post_training_bias (
10411043 self ,
10421044 data_config ,
1043- bias_config ,
1045+ data_bias_config ,
10441046 model_config ,
10451047 model_predicted_label_config ,
10461048 methods = "all" ,
@@ -1060,7 +1062,7 @@ def run_post_training_bias(
10601062
10611063 Args:
10621064 data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1063- bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1065+ data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
10641066 model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
10651067 endpoint to be created.
10661068 model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1103,7 +1105,7 @@ def run_post_training_bias(
11031105 """ # noqa E501 # pylint: disable=c0301
11041106 analysis_config = _AnalysisConfigGenerator .bias_post_training (
11051107 data_config ,
1106- bias_config ,
1108+ data_bias_config ,
11071109 model_predicted_label_config ,
11081110 methods ,
11091111 model_config
@@ -1314,10 +1316,10 @@ class _AnalysisConfigGenerator:
13141316 @classmethod
13151317 def explainability (
13161318 cls ,
1317- data_config ,
1318- model_config ,
1319- model_scores ,
1320- explainability_config
1319+ data_config : DataConfig ,
1320+ model_config : ModelConfig ,
1321+ model_scores : ModelPredictedLabelConfig ,
1322+ explainability_config : ExplainabilityConfig ,
13211323 ):
13221324 analysis_config = data_config .get_config ()
13231325 predictor_config = model_config .get_predictor_config ()
@@ -1358,7 +1360,7 @@ def explainability(
13581360 return cls ._common (analysis_config )
13591361
13601362 @classmethod
1361- def bias_pre_training (cls , data_config , bias_config , methods ):
1363+ def bias_pre_training (cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [ str , List [ str ]] ):
13621364 analysis_config = {
13631365 ** data_config .get_config (),
13641366 ** bias_config .get_config (),
@@ -1369,11 +1371,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
13691371 @classmethod
13701372 def bias_post_training (
13711373 cls ,
1372- data_config ,
1373- bias_config ,
1374- model_predicted_label_config ,
1375- methods ,
1376- model_config
1374+ data_config : DataConfig ,
1375+ bias_config : BiasConfig ,
1376+ model_predicted_label_config : ModelPredictedLabelConfig ,
1377+ methods : Union [ str , List [ str ]] ,
1378+ model_config : ModelConfig ,
13771379 ):
13781380 analysis_config = {
13791381 ** data_config .get_config (),
@@ -1391,12 +1393,12 @@ def bias_post_training(
13911393 @classmethod
13921394 def bias (
13931395 cls ,
1394- data_config ,
1395- bias_config ,
1396- model_config ,
1397- model_predicted_label_config ,
1398- pre_training_methods = "all" ,
1399- post_training_methods = "all" ,
1396+ data_config : DataConfig ,
1397+ bias_config : BiasConfig ,
1398+ model_config : ModelConfig ,
1399+ model_predicted_label_config : ModelPredictedLabelConfig ,
1400+ pre_training_methods : Union [ str , List [ str ]] = "all" ,
1401+ post_training_methods : Union [ str , List [ str ]] = "all" ,
14001402 ):
14011403 analysis_config = {
14021404 ** data_config .get_config (),
0 commit comments