From 00406cfe29cf21dc76ba79071eb0bc99cbde8ef2 Mon Sep 17 00:00:00 2001
From: "Malte S. Kurz"
Date: Fri, 3 Dec 2021 14:13:01 +0100
Subject: [PATCH 1/2] rename abstract methods

---
 doc/oop.svg                          | 1129 +++++++++++---------------
 doubleml/double_ml.py                |   28 +-
 doubleml/double_ml_iivm.py           |    6 +-
 doubleml/double_ml_irm.py            |    6 +-
 doubleml/double_ml_model_template.py |  167 ++++
 doubleml/double_ml_pliv.py           |   30 +-
 doubleml/double_ml_plr.py            |    6 +-
 7 files changed, 697 insertions(+), 675 deletions(-)
 create mode 100644 doubleml/double_ml_model_template.py

diff --git a/doc/oop.svg b/doc/oop.svg
index c8dde319..c63dc840 100644
--- a/doc/oop.svg
+++ b/doc/oop.svg
[SVG markup hunks not reproduced: doc/oop.svg is the regenerated OOP class diagram, updated for the renamed abstract methods (1129 lines changed).]
diff --git a/doubleml/double_ml.py b/doubleml/double_ml.py index 77117ab7..0d200c2c 100644 --- a/doubleml/double_ml.py +++ b/doubleml/double_ml.py @@ -474,7 +474,7 @@ def fit(self, n_jobs_cv=None, keep_scores=True, store_predictions=False): # ml estimation of nuisance models and computation of score elements self._psi_a[:, self._i_rep, self._i_treat], self._psi_b[:, self._i_rep, self._i_treat], preds =\ - self._ml_nuisance_and_score_elements(self.__smpls, n_jobs_cv) + self._nuisance_est(self.__smpls, n_jobs_cv) if store_predictions: self._store_predictions(preds) @@ -799,11 +799,11 @@ def tune(self, self._i_rep = i_rep # tune hyperparameters - res = self._ml_nuisance_tuning(self.__smpls, - param_grids, scoring_methods, - n_folds_tune, - n_jobs_cv, - search_mode, n_iter_randomized_search) + res = self._nuisance_tuning(self.__smpls, + param_grids, scoring_methods, + n_folds_tune, + n_jobs_cv, + search_mode, n_iter_randomized_search) tuning_res[i_rep][i_d] = res nuisance_params.append(res['params']) @@ -816,11 +816,11 @@ def tune(self, else: smpls = [(np.arange(self._dml_data.n_obs), np.arange(self._dml_data.n_obs))] # tune hyperparameters - res = self._ml_nuisance_tuning(smpls, - param_grids, scoring_methods, - n_folds_tune, - n_jobs_cv, - search_mode, n_iter_randomized_search) + res = self._nuisance_tuning(smpls, + param_grids, scoring_methods, + n_folds_tune, + n_jobs_cv, + search_mode, n_iter_randomized_search) tuning_res[i_d] = res if set_as_params: @@ -887,12 +887,12 @@ def _initialize_ml_nuisance_params(self): pass @abstractmethod - def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv): + def _nuisance_est(self, smpls, n_jobs_cv): pass @abstractmethod - def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, - search_mode, n_iter_randomized_search): + def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + search_mode, n_iter_randomized_search): pass @staticmethod diff --git a/doubleml/double_ml_iivm.py b/doubleml/double_ml_iivm.py index 47c49500..8731ad65 100644 --- a/doubleml/double_ml_iivm.py +++ b/doubleml/double_ml_iivm.py @@ -210,7 +210,7 @@ def _check_data(self, obj_dml_data): raise ValueError(err_msg) return - def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv): + def _nuisance_est(self, smpls, n_jobs_cv): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, z = check_X_y(x, np.ravel(self._dml_data.z), @@ -283,8 +283,8 @@ def _score_elements(self, y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls) return psi_a, psi_b - def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, - search_mode, n_iter_randomized_search): + def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + search_mode, n_iter_randomized_search): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, z = check_X_y(x, np.ravel(self._dml_data.z), diff --git a/doubleml/double_ml_irm.py b/doubleml/double_ml_irm.py index 0e5807de..1d210b1a 100644 --- a/doubleml/double_ml_irm.py +++
b/doubleml/double_ml_irm.py @@ -163,7 +163,7 @@ def _check_data(self, obj_dml_data): 'needs to be specified as treatment variable.') return - def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv): + def _nuisance_est(self, smpls, n_jobs_cv): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, d = check_X_y(x, self._dml_data.d, @@ -230,8 +230,8 @@ def _score_elements(self, y, d, g_hat0, g_hat1, m_hat, smpls): return psi_a, psi_b - def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, - search_mode, n_iter_randomized_search): + def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + search_mode, n_iter_randomized_search): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, d = check_X_y(x, self._dml_data.d, diff --git a/doubleml/double_ml_model_template.py b/doubleml/double_ml_model_template.py new file mode 100644 index 00000000..c5e7d5cc --- /dev/null +++ b/doubleml/double_ml_model_template.py @@ -0,0 +1,167 @@ +from sklearn.utils import check_X_y + +from .double_ml import DoubleML +from ._utils import _dml_cv_predict, _dml_tune, _check_finite_predictions + + +class DoubleMLNewModel(DoubleML): # TODO change DoubleMLNewModel to your model name + """Double machine learning for ??? TODO add your model description + + Parameters + ---------- + obj_dml_data : :class:`DoubleMLData` object + The :class:`DoubleMLData` object providing the data and specifying the variables for the causal model. + + # TODO add a description for each nuisance function (ml_g is a regression example; ml_m a classification example) + ml_g : estimator implementing ``fit()`` and ``predict()`` + A machine learner implementing ``fit()`` and ``predict()`` methods (e.g. + :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`g_0(X) = E[Y|X]`. + + ml_m : classifier implementing ``fit()`` and ``predict()`` + A machine learner implementing ``fit()`` and ``predict()`` methods (e.g. + :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D|X]`. + + n_folds : int + Number of folds. + Default is ``5``. + + n_rep : int + Number of repetitons for the sample splitting. + Default is ``1``. + + # TODO give a name for your orthogonal score function + score : str or callable + A str (``'my_orthogonal_score'``) specifying the score function. + Default is ``'my_orthogonal_score'``. + + dml_procedure : str + A str (``'dml1'`` or ``'dml2'``) specifying the double machine learning algorithm. + Default is ``'dml2'``. + + draw_sample_splitting : bool + Indicates whether the sample splitting should be drawn during initialization of the object. + Default is ``True``. + + apply_cross_fitting : bool + Indicates whether cross-fitting should be applied. + Default is ``True``. 
+ + Examples + -------- + # TODO add an example + + Notes + ----- + # TODO add an description of the model + """ + def __init__(self, + obj_dml_data, + ml_g, # TODO add a entry for each nuisance function + ml_m, # TODO add a entry for each nuisance function + n_folds=5, + n_rep=1, + score='my_orthogonal_score', # TODO give a name for your orthogonal score function + dml_procedure='dml2', + draw_sample_splitting=True, + apply_cross_fitting=True): + super().__init__(obj_dml_data, + n_folds, + n_rep, + score, + dml_procedure, + draw_sample_splitting, + apply_cross_fitting) + + self._check_data(self._dml_data) + self._check_score(self.score) + _ = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=False) # TODO may needs adaption + _ = self._check_learner(ml_g, 'ml_m', regressor=False, classifier=True) # TODO may needs adaption + self._learner = {'ml_g': ml_g, 'ml_m': ml_m} # TODO may needs adaption + self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict_proba'} # TODO may needs adaption + + self._initialize_ml_nuisance_params() + + def _initialize_ml_nuisance_params(self): + self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in + ['ml_g', 'ml_m']} # TODO may needs adaption + + def _check_score(self, score): + if isinstance(score, str): + valid_score = ['my_orthogonal_score'] # TODO give a name for your orthogonal score function + if score not in valid_score: + raise ValueError('Invalid score ' + score + '. ' + + 'Valid score ' + ' or '.join(valid_score) + '.') + else: + if not callable(score): + raise TypeError('score should be either a string or a callable. ' + '%r was passed.' % score) + return + + def _check_data(self, obj_dml_data): + # TODO model specific data requirements can be checked here + return + + def _nuisance_est(self, smpls, n_jobs_cv): + # TODO data checks may need adaptions + x, y = check_X_y(self._dml_data.x, self._dml_data.y, + force_all_finite=False) + x, d = check_X_y(x, self._dml_data.d, + force_all_finite=False) + + # TODO add a entry for each nuisance function + # nuisance g + g_hat = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls, n_jobs=n_jobs_cv, + est_params=self._get_params('ml_g'), method=self._predict_method['ml_g']) + _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls) + + # TODO add a entry for each nuisance function + # nuisance m + m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv, + est_params=self._get_params('ml_m'), method=self._predict_method['ml_m']) + _check_finite_predictions(m_hat, self._learner['ml_m'], 'ml_m', smpls) + + psi_a, psi_b = self._score_elements(y, d, g_hat, m_hat, smpls) # TODO may needs adaption + preds = {'ml_g': g_hat, + 'ml_m': m_hat} + + return psi_a, psi_b, preds + + def _score_elements(self, y, d, g_hat, m_hat, smpls): # TODO may needs adaption + # TODO here the score elements psi_a and psi_b should be computed + # return psi_a, psi_b + pass + + def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + search_mode, n_iter_randomized_search): + # TODO data checks may need adaptions + x, y = check_X_y(self._dml_data.x, self._dml_data.y, + force_all_finite=False) + x, d = check_X_y(x, self._dml_data.d, + force_all_finite=False) + + if scoring_methods is None: + scoring_methods = {'ml_g': None, + 'ml_m': None} # TODO may needs adaption + + train_inds = [train_index for (train_index, _) in smpls] + # TODO add a entry for each nuisance function + g_tune_res = _dml_tune(y, x, 
train_inds, + self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'], + n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) + m_tune_res = _dml_tune(d, x, train_inds, + self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'], + n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) + + g_best_params = [xx.best_params_ for xx in g_tune_res] + m_best_params = [xx.best_params_ for xx in m_tune_res] + + params = {'ml_g': g_best_params, + 'ml_m': m_best_params} # TODO may needs adaption + + tune_res = {'g_tune': g_tune_res, + 'm_tune': m_tune_res} # TODO may needs adaption + + res = {'params': params, + 'tune_res': tune_res} + + return res diff --git a/doubleml/double_ml_pliv.py b/doubleml/double_ml_pliv.py index a59e2dbb..77a76678 100644 --- a/doubleml/double_ml_pliv.py +++ b/doubleml/double_ml_pliv.py @@ -255,33 +255,33 @@ def _check_data(self, obj_dml_data): 'use DoubleMLPLR instead of DoubleMLPLIV.') return - def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv): + def _nuisance_est(self, smpls, n_jobs_cv): if self.partialX & (not self.partialZ): - psi_a, psi_b, preds = self._ml_nuisance_and_score_elements_partial_x(smpls, n_jobs_cv) + psi_a, psi_b, preds = self._nuisance_est_partial_x(smpls, n_jobs_cv) elif (not self.partialX) & self.partialZ: - psi_a, psi_b, preds = self._ml_nuisance_and_score_elements_partial_z(smpls, n_jobs_cv) + psi_a, psi_b, preds = self._nuisance_est_partial_z(smpls, n_jobs_cv) else: assert (self.partialX & self.partialZ) - psi_a, psi_b, preds = self._ml_nuisance_and_score_elements_partial_xz(smpls, n_jobs_cv) + psi_a, psi_b, preds = self._nuisance_est_partial_xz(smpls, n_jobs_cv) return psi_a, psi_b, preds - def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, - search_mode, n_iter_randomized_search): + def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + search_mode, n_iter_randomized_search): if self.partialX & (not self.partialZ): - res = self._ml_nuisance_tuning_partial_x(smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + res = self._nuisance_tuning_partial_x(smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) elif (not self.partialX) & self.partialZ: - res = self._ml_nuisance_tuning_partial_z(smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + res = self._nuisance_tuning_partial_z(smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) else: assert (self.partialX & self.partialZ) - res = self._ml_nuisance_tuning_partial_xz(smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + res = self._nuisance_tuning_partial_xz(smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) return res - def _ml_nuisance_and_score_elements_partial_x(self, smpls, n_jobs_cv): + def _nuisance_est_partial_x(self, smpls, n_jobs_cv): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, d = check_X_y(x, self._dml_data.d, @@ -357,7 +357,7 @@ def _score_elements(self, y, z, d, g_hat, m_hat, r_hat, smpls): return psi_a, psi_b - def _ml_nuisance_and_score_elements_partial_z(self, smpls, n_jobs_cv): + def _nuisance_est_partial_z(self, smpls, n_jobs_cv): y = self._dml_data.y xz, d = check_X_y(np.hstack((self._dml_data.x, self._dml_data.z)), self._dml_data.d, @@ -380,7 +380,7 @@ def _ml_nuisance_and_score_elements_partial_z(self, smpls, n_jobs_cv): return psi_a, psi_b, 
preds - def _ml_nuisance_and_score_elements_partial_xz(self, smpls, n_jobs_cv): + def _nuisance_est_partial_xz(self, smpls, n_jobs_cv): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) xz, d = check_X_y(np.hstack((self._dml_data.x, self._dml_data.z)), @@ -423,7 +423,7 @@ def _ml_nuisance_and_score_elements_partial_xz(self, smpls, n_jobs_cv): return psi_a, psi_b, preds - def _ml_nuisance_tuning_partial_x(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + def _nuisance_tuning_partial_x(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) @@ -486,7 +486,7 @@ def _ml_nuisance_tuning_partial_x(self, smpls, param_grids, scoring_methods, n_f return res - def _ml_nuisance_tuning_partial_z(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + def _nuisance_tuning_partial_z(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search): xz, d = check_X_y(np.hstack((self._dml_data.x, self._dml_data.z)), self._dml_data.d, @@ -511,7 +511,7 @@ def _ml_nuisance_tuning_partial_z(self, smpls, param_grids, scoring_methods, n_f return res - def _ml_nuisance_tuning_partial_xz(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + def _nuisance_tuning_partial_xz(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) diff --git a/doubleml/double_ml_plr.py b/doubleml/double_ml_plr.py index 5f510227..c5d4432c 100644 --- a/doubleml/double_ml_plr.py +++ b/doubleml/double_ml_plr.py @@ -138,7 +138,7 @@ def _check_data(self, obj_dml_data): 'To fit a partially linear IV regression model use DoubleMLPLIV instead of DoubleMLPLR.') return - def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv): + def _nuisance_est(self, smpls, n_jobs_cv): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, d = check_X_y(x, self._dml_data.d, @@ -188,8 +188,8 @@ def _score_elements(self, y, d, g_hat, m_hat, smpls): return psi_a, psi_b - def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, - search_mode, n_iter_randomized_search): + def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, + search_mode, n_iter_randomized_search): x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False) x, d = check_X_y(x, self._dml_data.d, From f7f2a72b008a18401e60c7e67a199b794bdd92a0 Mon Sep 17 00:00:00 2001 From: "Malte S. Kurz" Date: Fri, 3 Dec 2021 14:14:43 +0100 Subject: [PATCH 2/2] remove accidentally commited template file --- doubleml/double_ml_model_template.py | 167 --------------------------- 1 file changed, 167 deletions(-) delete mode 100644 doubleml/double_ml_model_template.py diff --git a/doubleml/double_ml_model_template.py b/doubleml/double_ml_model_template.py deleted file mode 100644 index c5e7d5cc..00000000 --- a/doubleml/double_ml_model_template.py +++ /dev/null @@ -1,167 +0,0 @@ -from sklearn.utils import check_X_y - -from .double_ml import DoubleML -from ._utils import _dml_cv_predict, _dml_tune, _check_finite_predictions - - -class DoubleMLNewModel(DoubleML): # TODO change DoubleMLNewModel to your model name - """Double machine learning for ??? 
TODO add your model description - - Parameters - ---------- - obj_dml_data : :class:`DoubleMLData` object - The :class:`DoubleMLData` object providing the data and specifying the variables for the causal model. - - # TODO add a description for each nuisance function (ml_g is a regression example; ml_m a classification example) - ml_g : estimator implementing ``fit()`` and ``predict()`` - A machine learner implementing ``fit()`` and ``predict()`` methods (e.g. - :py:class:`sklearn.ensemble.RandomForestRegressor`) for the nuisance function :math:`g_0(X) = E[Y|X]`. - - ml_m : classifier implementing ``fit()`` and ``predict()`` - A machine learner implementing ``fit()`` and ``predict()`` methods (e.g. - :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance function :math:`m_0(X) = E[D|X]`. - - n_folds : int - Number of folds. - Default is ``5``. - - n_rep : int - Number of repetitons for the sample splitting. - Default is ``1``. - - # TODO give a name for your orthogonal score function - score : str or callable - A str (``'my_orthogonal_score'``) specifying the score function. - Default is ``'my_orthogonal_score'``. - - dml_procedure : str - A str (``'dml1'`` or ``'dml2'``) specifying the double machine learning algorithm. - Default is ``'dml2'``. - - draw_sample_splitting : bool - Indicates whether the sample splitting should be drawn during initialization of the object. - Default is ``True``. - - apply_cross_fitting : bool - Indicates whether cross-fitting should be applied. - Default is ``True``. - - Examples - -------- - # TODO add an example - - Notes - ----- - # TODO add an description of the model - """ - def __init__(self, - obj_dml_data, - ml_g, # TODO add a entry for each nuisance function - ml_m, # TODO add a entry for each nuisance function - n_folds=5, - n_rep=1, - score='my_orthogonal_score', # TODO give a name for your orthogonal score function - dml_procedure='dml2', - draw_sample_splitting=True, - apply_cross_fitting=True): - super().__init__(obj_dml_data, - n_folds, - n_rep, - score, - dml_procedure, - draw_sample_splitting, - apply_cross_fitting) - - self._check_data(self._dml_data) - self._check_score(self.score) - _ = self._check_learner(ml_g, 'ml_g', regressor=True, classifier=False) # TODO may needs adaption - _ = self._check_learner(ml_g, 'ml_m', regressor=False, classifier=True) # TODO may needs adaption - self._learner = {'ml_g': ml_g, 'ml_m': ml_m} # TODO may needs adaption - self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict_proba'} # TODO may needs adaption - - self._initialize_ml_nuisance_params() - - def _initialize_ml_nuisance_params(self): - self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in - ['ml_g', 'ml_m']} # TODO may needs adaption - - def _check_score(self, score): - if isinstance(score, str): - valid_score = ['my_orthogonal_score'] # TODO give a name for your orthogonal score function - if score not in valid_score: - raise ValueError('Invalid score ' + score + '. ' + - 'Valid score ' + ' or '.join(valid_score) + '.') - else: - if not callable(score): - raise TypeError('score should be either a string or a callable. ' - '%r was passed.' 
% score) - return - - def _check_data(self, obj_dml_data): - # TODO model specific data requirements can be checked here - return - - def _nuisance_est(self, smpls, n_jobs_cv): - # TODO data checks may need adaptions - x, y = check_X_y(self._dml_data.x, self._dml_data.y, - force_all_finite=False) - x, d = check_X_y(x, self._dml_data.d, - force_all_finite=False) - - # TODO add a entry for each nuisance function - # nuisance g - g_hat = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls, n_jobs=n_jobs_cv, - est_params=self._get_params('ml_g'), method=self._predict_method['ml_g']) - _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls) - - # TODO add a entry for each nuisance function - # nuisance m - m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv, - est_params=self._get_params('ml_m'), method=self._predict_method['ml_m']) - _check_finite_predictions(m_hat, self._learner['ml_m'], 'ml_m', smpls) - - psi_a, psi_b = self._score_elements(y, d, g_hat, m_hat, smpls) # TODO may needs adaption - preds = {'ml_g': g_hat, - 'ml_m': m_hat} - - return psi_a, psi_b, preds - - def _score_elements(self, y, d, g_hat, m_hat, smpls): # TODO may needs adaption - # TODO here the score elements psi_a and psi_b should be computed - # return psi_a, psi_b - pass - - def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv, - search_mode, n_iter_randomized_search): - # TODO data checks may need adaptions - x, y = check_X_y(self._dml_data.x, self._dml_data.y, - force_all_finite=False) - x, d = check_X_y(x, self._dml_data.d, - force_all_finite=False) - - if scoring_methods is None: - scoring_methods = {'ml_g': None, - 'ml_m': None} # TODO may needs adaption - - train_inds = [train_index for (train_index, _) in smpls] - # TODO add a entry for each nuisance function - g_tune_res = _dml_tune(y, x, train_inds, - self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'], - n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) - m_tune_res = _dml_tune(d, x, train_inds, - self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'], - n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search) - - g_best_params = [xx.best_params_ for xx in g_tune_res] - m_best_params = [xx.best_params_ for xx in m_tune_res] - - params = {'ml_g': g_best_params, - 'ml_m': m_best_params} # TODO may needs adaption - - tune_res = {'g_tune': g_tune_res, - 'm_tune': m_tune_res} # TODO may needs adaption - - res = {'params': params, - 'tune_res': tune_res} - - return res
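
After both patches, model classes built on DoubleML implement the renamed hooks _nuisance_est and _nuisance_tuning instead of _ml_nuisance_and_score_elements and _ml_nuisance_tuning; the public fit() and tune() methods are untouched by the rename. The sketch below is a minimal, condensed version of the (subsequently removed) template, written as a module could look inside the package after the rename. The class name DoubleMLMyModel, the score name 'my_orthogonal_score' and the partialling-out style score elements are illustrative placeholders and not part of these patches; the hook signatures, the relative imports and the helpers _dml_cv_predict, _dml_tune and _check_finite_predictions are taken from the template itself.

import numpy as np
from sklearn.utils import check_X_y

from .double_ml import DoubleML
from ._utils import _dml_cv_predict, _dml_tune, _check_finite_predictions


class DoubleMLMyModel(DoubleML):
    """Illustrative model class using the renamed abstract methods (placeholder names)."""

    def __init__(self, obj_dml_data, ml_g, ml_m, n_folds=5, n_rep=1,
                 score='my_orthogonal_score', dml_procedure='dml2',
                 draw_sample_splitting=True, apply_cross_fitting=True):
        super().__init__(obj_dml_data, n_folds, n_rep, score,
                         dml_procedure, draw_sample_splitting, apply_cross_fitting)
        self._learner = {'ml_g': ml_g, 'ml_m': ml_m}
        self._predict_method = {'ml_g': 'predict', 'ml_m': 'predict'}
        self._initialize_ml_nuisance_params()

    def _initialize_ml_nuisance_params(self):
        self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols}
                        for learner in ['ml_g', 'ml_m']}

    def _check_score(self, score):
        return

    def _check_data(self, obj_dml_data):
        return

    # renamed from _ml_nuisance_and_score_elements
    def _nuisance_est(self, smpls, n_jobs_cv):
        x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
        x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)

        # cross-fitted nuisance predictions g(X) ~ E[Y|X] and m(X) ~ E[D|X]
        g_hat = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls, n_jobs=n_jobs_cv,
                                est_params=self._get_params('ml_g'),
                                method=self._predict_method['ml_g'])
        _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls)
        m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
                                est_params=self._get_params('ml_m'),
                                method=self._predict_method['ml_m'])
        _check_finite_predictions(m_hat, self._learner['ml_m'], 'ml_m', smpls)

        # placeholder score elements in the partialling-out style of DoubleMLPLR;
        # replace with the psi_a / psi_b of your own orthogonal score
        u_hat = y - g_hat
        v_hat = d - m_hat
        psi_a = -np.multiply(v_hat, v_hat)
        psi_b = np.multiply(v_hat, u_hat)
        preds = {'ml_g': g_hat, 'ml_m': m_hat}
        return psi_a, psi_b, preds

    # renamed from _ml_nuisance_tuning
    def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_jobs_cv,
                         search_mode, n_iter_randomized_search):
        x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
        x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
        if scoring_methods is None:
            scoring_methods = {'ml_g': None, 'ml_m': None}

        train_inds = [train_index for (train_index, _) in smpls]
        g_tune_res = _dml_tune(y, x, train_inds,
                               self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                               n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
        m_tune_res = _dml_tune(d, x, train_inds,
                               self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
                               n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)

        params = {'ml_g': [xx.best_params_ for xx in g_tune_res],
                  'ml_m': [xx.best_params_ for xx in m_tune_res]}
        tune_res = {'g_tune': g_tune_res, 'm_tune': m_tune_res}
        return {'params': params, 'tune_res': tune_res}

Because only the protected hooks are renamed, user code that merely instantiates and fits the shipped model classes (DoubleMLPLR, DoubleMLPLIV, DoubleMLIRM, DoubleMLIIVM) is unaffected; only subclasses that override the nuisance estimation or tuning methods need to adopt the new names.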