@@ -1304,6 +1304,9 @@ def run_explainability(
13041304
13051305
13061306class _AnalysisConfigGenerator :
1307+ """
1308+ Creates analysis_config objects for different type of runs.
1309+ """
13071310 @classmethod
13081311 def explainability (
13091312 cls ,
@@ -1334,15 +1337,15 @@ def explainability(
13341337 if not len (explainability_methods .keys ()) == len (explainability_config ):
13351338 raise ValueError ("Duplicate explainability configs are provided" )
13361339 if (
1337- "shap" not in explainability_methods
1338- and explainability_methods ["pdp" ].get ("features" , None ) is None
1340+ "shap" not in explainability_methods
1341+ and explainability_methods ["pdp" ].get ("features" , None ) is None
13391342 ):
13401343 raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
13411344 else :
13421345 if (
1343- isinstance (explainability_config , PDPConfig )
1344- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1345- is None
1346+ isinstance (explainability_config , PDPConfig )
1347+ and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1348+ is None
13461349 ):
13471350 raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
13481351 explainability_methods = explainability_config .get_explainability_config ()
@@ -1352,9 +1355,11 @@ def explainability(
13521355
13531356 @classmethod
13541357 def bias_pre_training (cls , data_config , bias_config , methods ):
1355- analysis_config = data_config .get_config ()
1356- analysis_config .update (bias_config .get_config ())
1357- analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1358+ analysis_config = {
1359+ ** data_config .get_config (),
1360+ ** bias_config .get_config (),
1361+ "methods" : {"pre_training_bias" : {"methods" : methods }}
1362+ }
13581363 return cls ._common (analysis_config )
13591364
13601365 @classmethod
@@ -1366,16 +1371,17 @@ def bias_post_training(
13661371 methods ,
13671372 model_config
13681373 ):
1369- analysis_config = data_config .get_config ()
1370- analysis_config .update (bias_config .get_config ())
1371- analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1372- (
1373- probability_threshold ,
1374- predictor_config ,
1375- ) = model_predicted_label_config .get_predictor_config ()
1376- predictor_config .update (model_config .get_predictor_config ())
1377- analysis_config ["predictor" ] = predictor_config
1378- _set (probability_threshold , "probability_threshold" , analysis_config )
1374+ analysis_config = {
1375+ ** data_config .get_config (),
1376+ ** bias_config .get_config (),
1377+ "predictor" : {** model_config .get_predictor_config ()},
1378+ "methods" : {"post_training_bias" : {"methods" : methods }},
1379+ }
1380+ if model_predicted_label_config :
1381+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1382+ if predictor_config :
1383+ analysis_config ["predictor" ].update (predictor_config )
1384+ _set (probability_threshold , "probability_threshold" , analysis_config )
13791385 return cls ._common (analysis_config )
13801386
13811387 @classmethod
@@ -1388,23 +1394,20 @@ def bias(
13881394 pre_training_methods = "all" ,
13891395 post_training_methods = "all" ,
13901396 ):
1391- analysis_config = data_config .get_config ()
1392- analysis_config .update (bias_config .get_config ())
1393- analysis_config ["predictor" ] = model_config .get_predictor_config ()
1397+ analysis_config = {
1398+ ** data_config .get_config (),
1399+ ** bias_config .get_config (),
1400+ "predictor" : model_config .get_predictor_config (),
1401+ "methods" : {
1402+ "pre_training_bias" : {"methods" : pre_training_methods },
1403+ "post_training_bias" : {"methods" : post_training_methods },
1404+ }
1405+ }
13941406 if model_predicted_label_config :
1395- (
1396- probability_threshold ,
1397- predictor_config ,
1398- ) = model_predicted_label_config .get_predictor_config ()
1407+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
13991408 if predictor_config :
14001409 analysis_config ["predictor" ].update (predictor_config )
1401- if probability_threshold is not None :
1402- analysis_config ["probability_threshold" ] = probability_threshold
1403-
1404- analysis_config ["methods" ] = {
1405- "pre_training_bias" : {"methods" : pre_training_methods },
1406- "post_training_bias" : {"methods" : post_training_methods },
1407- }
1410+ _set (probability_threshold , "probability_threshold" , analysis_config )
14081411 return cls ._common (analysis_config )
14091412
14101413 @staticmethod
0 commit comments