|
25 | 25 |
|
26 | 26 | import tempfile |
27 | 27 | from abc import ABC, abstractmethod |
| 28 | +from typing import List, Union |
| 29 | + |
28 | 30 | from sagemaker import image_uris, s3, utils |
29 | 31 | from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor |
30 | 32 |
|
@@ -63,7 +65,6 @@ def __init__( |
63 | 65 | label (str): Target attribute of the model required by bias metrics. |
64 | 66 | Specified as column name or index for CSV dataset or as JSONPath for JSONLines. |
65 | 67 | *Required parameter* except for when the input dataset does not contain the label. |
66 | | - Cannot be used at the same time as ``predicted_label``. |
67 | 68 | features (str): JSONPath for locating the feature columns for bias metrics if the |
68 | 69 | dataset format is JSONLines. |
69 | 70 | dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, |
@@ -103,7 +104,7 @@ def __init__( |
103 | 104 | predicted_label (str or int): Predicted label of the target attribute of the model |
104 | 105 | required for running bias analysis. Specified as column name or index for CSV data. |
105 | 106 | Clarify uses the predicted labels directly instead of making model inference API |
106 | | - calls. Cannot be used at the same time as ``label``. |
| 107 | + calls. |
107 | 108 | excluded_columns (list[int] or list[str]): A list of names or indices of the columns |
108 | 109 | which are to be excluded from making model inference API calls. |
109 | 110 |
|
@@ -922,6 +923,7 @@ def __init__( |
922 | 923 | version (str): Clarify version to use. |
923 | 924 | """ # noqa E501 # pylint: disable=c0301 |
924 | 925 | container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version) |
| 926 | + self._last_analysis_config = None |
925 | 927 | self.job_name_prefix = job_name_prefix |
926 | 928 | super(SageMakerClarifyProcessor, self).__init__( |
927 | 929 | role, |
@@ -983,10 +985,10 @@ def _run( |
983 | 985 | the Trial Component will be unassociated. |
984 | 986 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. |
985 | 987 | """ |
986 | | - analysis_config["methods"]["report"] = { |
987 | | - "name": "report", |
988 | | - "title": "Analysis Report", |
989 | | - } |
| 988 | + # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket |
| 989 | + self._last_analysis_config = analysis_config |
| 990 | + logger.info("Analysis Config: %s", analysis_config) |
| 991 | + |
990 | 992 | with tempfile.TemporaryDirectory() as tmpdirname: |
991 | 993 | analysis_config_file = os.path.join(tmpdirname, "analysis_config.json") |
992 | 994 | with open(analysis_config_file, "w") as f: |
@@ -1083,14 +1085,13 @@ def run_pre_training_bias( |
1083 | 1085 | the Trial Component will be unassociated. |
1084 | 1086 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. |
1085 | 1087 | """ # noqa E501 # pylint: disable=c0301 |
1086 | | - analysis_config = data_config.get_config() |
1087 | | - analysis_config.update(data_bias_config.get_config()) |
1088 | | - analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} |
1089 | | - if job_name is None: |
1090 | | - if self.job_name_prefix: |
1091 | | - job_name = utils.name_from_base(self.job_name_prefix) |
1092 | | - else: |
1093 | | - job_name = utils.name_from_base("Clarify-Pretraining-Bias") |
| 1088 | + analysis_config = _AnalysisConfigGenerator.bias_pre_training( |
| 1089 | + data_config, data_bias_config, methods |
| 1090 | + ) |
| 1091 | + # when name is either not provided (is None) or an empty string ("") |
| 1092 | + job_name = job_name or utils.name_from_base( |
| 1093 | + self.job_name_prefix or "Clarify-Pretraining-Bias" |
| 1094 | + ) |
1094 | 1095 | return self._run( |
1095 | 1096 | data_config, |
1096 | 1097 | analysis_config, |
@@ -1165,21 +1166,13 @@ def run_post_training_bias( |
1165 | 1166 | the Trial Component will be unassociated. |
1166 | 1167 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. |
1167 | 1168 | """ # noqa E501 # pylint: disable=c0301 |
1168 | | - analysis_config = data_config.get_config() |
1169 | | - analysis_config.update(data_bias_config.get_config()) |
1170 | | - ( |
1171 | | - probability_threshold, |
1172 | | - predictor_config, |
1173 | | - ) = model_predicted_label_config.get_predictor_config() |
1174 | | - predictor_config.update(model_config.get_predictor_config()) |
1175 | | - analysis_config["methods"] = {"post_training_bias": {"methods": methods}} |
1176 | | - analysis_config["predictor"] = predictor_config |
1177 | | - _set(probability_threshold, "probability_threshold", analysis_config) |
1178 | | - if job_name is None: |
1179 | | - if self.job_name_prefix: |
1180 | | - job_name = utils.name_from_base(self.job_name_prefix) |
1181 | | - else: |
1182 | | - job_name = utils.name_from_base("Clarify-Posttraining-Bias") |
| 1169 | + analysis_config = _AnalysisConfigGenerator.bias_post_training( |
| 1170 | + data_config, data_bias_config, model_predicted_label_config, methods, model_config |
| 1171 | + ) |
| 1172 | + # when name is either not provided (is None) or an empty string ("") |
| 1173 | + job_name = job_name or utils.name_from_base( |
| 1174 | + self.job_name_prefix or "Clarify-Posttraining-Bias" |
| 1175 | + ) |
1183 | 1176 | return self._run( |
1184 | 1177 | data_config, |
1185 | 1178 | analysis_config, |
@@ -1264,28 +1257,16 @@ def run_bias( |
1264 | 1257 | the Trial Component will be unassociated. |
1265 | 1258 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. |
1266 | 1259 | """ # noqa E501 # pylint: disable=c0301 |
1267 | | - analysis_config = data_config.get_config() |
1268 | | - analysis_config.update(bias_config.get_config()) |
1269 | | - analysis_config["predictor"] = model_config.get_predictor_config() |
1270 | | - if model_predicted_label_config: |
1271 | | - ( |
1272 | | - probability_threshold, |
1273 | | - predictor_config, |
1274 | | - ) = model_predicted_label_config.get_predictor_config() |
1275 | | - if predictor_config: |
1276 | | - analysis_config["predictor"].update(predictor_config) |
1277 | | - if probability_threshold is not None: |
1278 | | - analysis_config["probability_threshold"] = probability_threshold |
1279 | | - |
1280 | | - analysis_config["methods"] = { |
1281 | | - "pre_training_bias": {"methods": pre_training_methods}, |
1282 | | - "post_training_bias": {"methods": post_training_methods}, |
1283 | | - } |
1284 | | - if job_name is None: |
1285 | | - if self.job_name_prefix: |
1286 | | - job_name = utils.name_from_base(self.job_name_prefix) |
1287 | | - else: |
1288 | | - job_name = utils.name_from_base("Clarify-Bias") |
| 1260 | + analysis_config = _AnalysisConfigGenerator.bias( |
| 1261 | + data_config, |
| 1262 | + bias_config, |
| 1263 | + model_config, |
| 1264 | + model_predicted_label_config, |
| 1265 | + pre_training_methods, |
| 1266 | + post_training_methods, |
| 1267 | + ) |
| 1268 | + # when name is either not provided (is None) or an empty string ("") |
| 1269 | + job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias") |
1289 | 1270 | return self._run( |
1290 | 1271 | data_config, |
1291 | 1272 | analysis_config, |
@@ -1370,6 +1351,36 @@ def run_explainability( |
1370 | 1351 | the Trial Component will be unassociated. |
1371 | 1352 | * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. |
1372 | 1353 | """ # noqa E501 # pylint: disable=c0301 |
| 1354 | + analysis_config = _AnalysisConfigGenerator.explainability( |
| 1355 | + data_config, model_config, model_scores, explainability_config |
| 1356 | + ) |
| 1357 | + # when name is either not provided (is None) or an empty string ("") |
| 1358 | + job_name = job_name or utils.name_from_base( |
| 1359 | + self.job_name_prefix or "Clarify-Explainability" |
| 1360 | + ) |
| 1361 | + return self._run( |
| 1362 | + data_config, |
| 1363 | + analysis_config, |
| 1364 | + wait, |
| 1365 | + logs, |
| 1366 | + job_name, |
| 1367 | + kms_key, |
| 1368 | + experiment_config, |
| 1369 | + ) |
| 1370 | + |
| 1371 | + |
| 1372 | +class _AnalysisConfigGenerator: |
| 1373 | + """Creates analysis_config objects for different type of runs.""" |
| 1374 | + |
| 1375 | + @classmethod |
| 1376 | + def explainability( |
| 1377 | + cls, |
| 1378 | + data_config: DataConfig, |
| 1379 | + model_config: ModelConfig, |
| 1380 | + model_scores: ModelPredictedLabelConfig, |
| 1381 | + explainability_config: ExplainabilityConfig, |
| 1382 | + ): |
| 1383 | + """Generates a config for Explainability""" |
1373 | 1384 | analysis_config = data_config.get_config() |
1374 | 1385 | predictor_config = model_config.get_predictor_config() |
1375 | 1386 | if isinstance(model_scores, ModelPredictedLabelConfig): |
@@ -1406,20 +1417,84 @@ def run_explainability( |
1406 | 1417 | explainability_methods = explainability_config.get_explainability_config() |
1407 | 1418 | analysis_config["methods"] = explainability_methods |
1408 | 1419 | analysis_config["predictor"] = predictor_config |
1409 | | - if job_name is None: |
1410 | | - if self.job_name_prefix: |
1411 | | - job_name = utils.name_from_base(self.job_name_prefix) |
1412 | | - else: |
1413 | | - job_name = utils.name_from_base("Clarify-Explainability") |
1414 | | - return self._run( |
1415 | | - data_config, |
1416 | | - analysis_config, |
1417 | | - wait, |
1418 | | - logs, |
1419 | | - job_name, |
1420 | | - kms_key, |
1421 | | - experiment_config, |
1422 | | - ) |
| 1420 | + return cls._common(analysis_config) |
| 1421 | + |
| 1422 | + @classmethod |
| 1423 | + def bias_pre_training( |
| 1424 | + cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] |
| 1425 | + ): |
| 1426 | + """Generates a config for Bias Pre Training""" |
| 1427 | + analysis_config = { |
| 1428 | + **data_config.get_config(), |
| 1429 | + **bias_config.get_config(), |
| 1430 | + "methods": {"pre_training_bias": {"methods": methods}}, |
| 1431 | + } |
| 1432 | + return cls._common(analysis_config) |
| 1433 | + |
| 1434 | + @classmethod |
| 1435 | + def bias_post_training( |
| 1436 | + cls, |
| 1437 | + data_config: DataConfig, |
| 1438 | + bias_config: BiasConfig, |
| 1439 | + model_predicted_label_config: ModelPredictedLabelConfig, |
| 1440 | + methods: Union[str, List[str]], |
| 1441 | + model_config: ModelConfig, |
| 1442 | + ): |
| 1443 | + """Generates a config for Bias Post Training""" |
| 1444 | + analysis_config = { |
| 1445 | + **data_config.get_config(), |
| 1446 | + **bias_config.get_config(), |
| 1447 | + "predictor": {**model_config.get_predictor_config()}, |
| 1448 | + "methods": {"post_training_bias": {"methods": methods}}, |
| 1449 | + } |
| 1450 | + if model_predicted_label_config: |
| 1451 | + ( |
| 1452 | + probability_threshold, |
| 1453 | + predictor_config, |
| 1454 | + ) = model_predicted_label_config.get_predictor_config() |
| 1455 | + if predictor_config: |
| 1456 | + analysis_config["predictor"].update(predictor_config) |
| 1457 | + _set(probability_threshold, "probability_threshold", analysis_config) |
| 1458 | + return cls._common(analysis_config) |
| 1459 | + |
| 1460 | + @classmethod |
| 1461 | + def bias( |
| 1462 | + cls, |
| 1463 | + data_config: DataConfig, |
| 1464 | + bias_config: BiasConfig, |
| 1465 | + model_config: ModelConfig, |
| 1466 | + model_predicted_label_config: ModelPredictedLabelConfig, |
| 1467 | + pre_training_methods: Union[str, List[str]] = "all", |
| 1468 | + post_training_methods: Union[str, List[str]] = "all", |
| 1469 | + ): |
| 1470 | + """Generates a config for Bias""" |
| 1471 | + analysis_config = { |
| 1472 | + **data_config.get_config(), |
| 1473 | + **bias_config.get_config(), |
| 1474 | + "predictor": model_config.get_predictor_config(), |
| 1475 | + "methods": { |
| 1476 | + "pre_training_bias": {"methods": pre_training_methods}, |
| 1477 | + "post_training_bias": {"methods": post_training_methods}, |
| 1478 | + }, |
| 1479 | + } |
| 1480 | + if model_predicted_label_config: |
| 1481 | + ( |
| 1482 | + probability_threshold, |
| 1483 | + predictor_config, |
| 1484 | + ) = model_predicted_label_config.get_predictor_config() |
| 1485 | + if predictor_config: |
| 1486 | + analysis_config["predictor"].update(predictor_config) |
| 1487 | + _set(probability_threshold, "probability_threshold", analysis_config) |
| 1488 | + return cls._common(analysis_config) |
| 1489 | + |
| 1490 | + @staticmethod |
| 1491 | + def _common(analysis_config): |
| 1492 | + """Extends analysis config with common values""" |
| 1493 | + analysis_config["methods"]["report"] = { |
| 1494 | + "name": "report", |
| 1495 | + "title": "Analysis Report", |
| 1496 | + } |
| 1497 | + return analysis_config |
1423 | 1498 |
|
1424 | 1499 |
|
1425 | 1500 | def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): |
|
0 commit comments