@@ -1368,68 +1368,70 @@ def run_explainability(
13681368 experiment_config ,
13691369 )
13701370
1371+ def run_bias_and_explainability (self ):
1372+ """
1373+ TODO:
1374+ - add doc string
1375+ - add logic
1376+ - add tests
1377+ """
1378+ raise NotImplementedError (
1379+ "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1380+ )
1381+
13711382
13721383class _AnalysisConfigGenerator :
13731384 """
13741385 Creates analysis_config objects for different type of runs.
13751386 """
13761387
13771388 @classmethod
1378- def explainability (
1389+ def bias_and_explainability (
13791390 cls ,
13801391 data_config : DataConfig ,
13811392 model_config : ModelConfig ,
1382- model_scores : ModelPredictedLabelConfig ,
1383- explainability_config : ExplainabilityConfig ,
1393+ model_predicted_label_config : ModelPredictedLabelConfig ,
1394+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1395+ bias_config : BiasConfig ,
1396+ pre_training_methods : Union [str , List [str ]] = "all" ,
1397+ post_training_methods : Union [str , List [str ]] = "all" ,
13841398 ):
1385- analysis_config = data_config .get_config ()
1386- predictor_config = model_config . get_predictor_config ()
1387- if isinstance ( model_scores , ModelPredictedLabelConfig ):
1388- (
1389- probability_threshold ,
1390- predicted_label_config ,
1391- ) = model_scores . get_predictor_config ( )
1392- _set ( probability_threshold , "probability_threshold" , analysis_config )
1393- predictor_config . update ( predicted_label_config )
1394- else :
1395- _set ( model_scores , "label" , predictor_config )
1399+ analysis_config = { ** data_config .get_config (), ** bias_config . get_config ()}
1400+ analysis_config = cls . _add_methods (
1401+ analysis_config ,
1402+ pre_training_methods = pre_training_methods ,
1403+ post_training_methods = post_training_methods ,
1404+ explainability_config = explainability_config ,
1405+ )
1406+ analysis_config = cls . _add_predictor (
1407+ analysis_config , model_config , model_predicted_label_config
1408+ )
1409+ return analysis_config
13961410
1397- explainability_methods = {}
1398- if isinstance (explainability_config , list ):
1399- if len (explainability_config ) == 0 :
1400- raise ValueError ("Please provide at least one explainability config." )
1401- for config in explainability_config :
1402- explain_config = config .get_explainability_config ()
1403- explainability_methods .update (explain_config )
1404- if not len (explainability_methods .keys ()) == len (explainability_config ):
1405- raise ValueError ("Duplicate explainability configs are provided" )
1406- if (
1407- "shap" not in explainability_methods
1408- and explainability_methods ["pdp" ].get ("features" , None ) is None
1409- ):
1410- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1411- else :
1412- if (
1413- isinstance (explainability_config , PDPConfig )
1414- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1415- is None
1416- ):
1417- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1418- explainability_methods = explainability_config .get_explainability_config ()
1419- analysis_config ["methods" ] = explainability_methods
1420- analysis_config ["predictor" ] = predictor_config
1421- return cls ._common (analysis_config )
1411+ @classmethod
1412+ def explainability (
1413+ cls ,
1414+ data_config : DataConfig ,
1415+ model_config : ModelConfig ,
1416+ model_predicted_label_config : ModelPredictedLabelConfig ,
1417+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1418+ ):
1419+ analysis_config = data_config .analysis_config
1420+ analysis_config = cls ._add_predictor (
1421+ analysis_config , model_config , model_predicted_label_config
1422+ )
1423+ analysis_config = cls ._add_methods (
1424+ analysis_config , explainability_config = explainability_config
1425+ )
1426+ return analysis_config
14221427
14231428 @classmethod
14241429 def bias_pre_training (
14251430 cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [str , List [str ]]
14261431 ):
1427- analysis_config = {
1428- ** data_config .get_config (),
1429- ** bias_config .get_config (),
1430- "methods" : {"pre_training_bias" : {"methods" : methods }},
1431- }
1432- return cls ._common (analysis_config )
1432+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1433+ analysis_config = cls ._add_methods (analysis_config , pre_training_methods = methods )
1434+ return analysis_config
14331435
14341436 @classmethod
14351437 def bias_post_training (
@@ -1440,21 +1442,12 @@ def bias_post_training(
14401442 methods : Union [str , List [str ]],
14411443 model_config : ModelConfig ,
14421444 ):
1443- analysis_config = {
1444- ** data_config .get_config (),
1445- ** bias_config .get_config (),
1446- "predictor" : {** model_config .get_predictor_config ()},
1447- "methods" : {"post_training_bias" : {"methods" : methods }},
1448- }
1449- if model_predicted_label_config :
1450- (
1451- probability_threshold ,
1452- predictor_config ,
1453- ) = model_predicted_label_config .get_predictor_config ()
1454- if predictor_config :
1455- analysis_config ["predictor" ].update (predictor_config )
1456- _set (probability_threshold , "probability_threshold" , analysis_config )
1457- return cls ._common (analysis_config )
1445+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1446+ analysis_config = cls ._add_methods (analysis_config , post_training_methods = methods )
1447+ analysis_config = cls ._add_predictor (
1448+ analysis_config , model_config , model_predicted_label_config
1449+ )
1450+ return analysis_config
14581451
14591452 @classmethod
14601453 def bias (
@@ -1466,33 +1459,96 @@ def bias(
14661459 pre_training_methods : Union [str , List [str ]] = "all" ,
14671460 post_training_methods : Union [str , List [str ]] = "all" ,
14681461 ):
1469- analysis_config = {
1470- ** data_config .get_config (),
1471- ** bias_config .get_config (),
1472- "predictor" : model_config .get_predictor_config (),
1473- "methods" : {
1474- "pre_training_bias" : {"methods" : pre_training_methods },
1475- "post_training_bias" : {"methods" : post_training_methods },
1476- },
1477- }
1478- if model_predicted_label_config :
1462+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1463+ analysis_config = cls ._add_methods (
1464+ analysis_config ,
1465+ pre_training_methods = pre_training_methods ,
1466+ post_training_methods = post_training_methods ,
1467+ )
1468+ analysis_config = cls ._add_predictor (
1469+ analysis_config , model_config , model_predicted_label_config
1470+ )
1471+ return analysis_config
1472+
1473+ @classmethod
1474+ def _add_predictor (cls , analysis_config , model_config , model_predicted_label_config ):
1475+ analysis_config = {** analysis_config }
1476+ analysis_config ["predictor" ] = model_config .get_predictor_config ()
1477+ if isinstance (model_predicted_label_config , ModelPredictedLabelConfig ):
14791478 (
14801479 probability_threshold ,
14811480 predictor_config ,
14821481 ) = model_predicted_label_config .get_predictor_config ()
14831482 if predictor_config :
14841483 analysis_config ["predictor" ].update (predictor_config )
14851484 _set (probability_threshold , "probability_threshold" , analysis_config )
1486- return cls ._common (analysis_config )
1485+ else :
1486+ _set (model_predicted_label_config , "label" , analysis_config ["predictor" ])
1487+ return analysis_config
14871488
1488- @staticmethod
1489- def _common (analysis_config ):
1490- analysis_config ["methods" ]["report" ] = {
1491- "name" : "report" ,
1492- "title" : "Analysis Report" ,
1493- }
1489+ @classmethod
1490+ def _add_methods (
1491+ cls ,
1492+ analysis_config ,
1493+ pre_training_methods = None ,
1494+ post_training_methods = None ,
1495+ explainability_config = None ,
1496+ report = True ,
1497+ ):
1498+ # validate
1499+ params = [pre_training_methods , post_training_methods , explainability_config ]
1500+ if all ([1 if p is None else 0 for p in params ]):
1501+ raise AttributeError (
1502+ "analysis_config must have at least one working method: "
1503+ "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
1504+ )
1505+
1506+ # main logic
1507+ analysis_config = {** analysis_config }
1508+ if "methods" not in analysis_config :
1509+ analysis_config ["methods" ] = {}
1510+
1511+ if report :
1512+ analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
1513+
1514+ if pre_training_methods :
1515+ analysis_config ["methods" ]["pre_training_bias" ] = {"methods" : pre_training_methods }
1516+
1517+ if post_training_methods :
1518+ analysis_config ["methods" ]["post_training_bias" ] = {"methods" : post_training_methods }
1519+
1520+ if explainability_config is not None :
1521+ explainability_methods = cls ._merge_explainability_configs (explainability_config )
1522+ analysis_config ["methods" ] = {** analysis_config ["methods" ], ** explainability_methods }
14941523 return analysis_config
14951524
1525+ @classmethod
1526+ def _merge_explainability_configs (
1527+ cls , explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]]
1528+ ):
1529+ if isinstance (explainability_config , list ):
1530+ explainability_methods = {}
1531+ if len (explainability_config ) == 0 :
1532+ raise ValueError ("Please provide at least one explainability config." )
1533+ for config in explainability_config :
1534+ explain_config = config .get_explainability_config ()
1535+ explainability_methods .update (explain_config )
1536+ if not len (explainability_methods ) == len (explainability_config ):
1537+ raise ValueError ("Duplicate explainability configs are provided" )
1538+ if (
1539+ "shap" not in explainability_methods
1540+ and "features" not in explainability_methods ["pdp" ]
1541+ ):
1542+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1543+ return explainability_methods
1544+ else :
1545+ if (
1546+ isinstance (explainability_config , PDPConfig )
1547+ and "features" not in explainability_config .get_explainability_config ()["pdp" ]
1548+ ):
1549+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1550+ return explainability_config .get_explainability_config ()
1551+
14961552
14971553def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
14981554 """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
0 commit comments