@@ -116,7 +116,8 @@ def __init__(self,
         self._initialize_ml_nuisance_params()

     def _initialize_ml_nuisance_params(self):
-        self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols} for learner in ['ml_g', 'ml_m']}
+        self._params = {learner: {key: [None] * self.n_rep for key in self._dml_data.d_cols}
+                        for learner in ['ml_l', 'ml_g', 'ml_m']}

     def _check_score(self, score):
         if isinstance(score, str):
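For reference, a minimal sketch (not library code) of the nested dictionary this initializer now builds, assuming hypothetical values d_cols = ['d'] and n_rep = 2:

```python
# Sketch only: shape of self._params after initialization, assuming the
# hypothetical values d_cols = ['d'] and n_rep = 2. Each [None, None] slot
# later receives the per-repetition hyperparameters of that learner.
d_cols = ['d']
n_rep = 2

params = {learner: {key: [None] * n_rep for key in d_cols}
          for learner in ['ml_l', 'ml_g', 'ml_m']}

print(params)
# {'ml_l': {'d': [None, None]}, 'ml_g': {'d': [None, None]}, 'ml_m': {'d': [None, None]}}
```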
@@ -144,10 +145,10 @@ def _nuisance_est(self, smpls, n_jobs_cv):
         x, d = check_X_y(x, self._dml_data.d,
                          force_all_finite=False)

-        # nuisance g
-        g_hat = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls, n_jobs=n_jobs_cv,
+        # nuisance l
+        l_hat = _dml_cv_predict(self._learner['ml_g'], x, y, smpls=smpls, n_jobs=n_jobs_cv,
                                 est_params=self._get_params('ml_g'), method=self._predict_method['ml_g'])
-        _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls)
+        _check_finite_predictions(l_hat, self._learner['ml_g'], 'ml_g', smpls)

         # nuisance m
         m_hat = _dml_cv_predict(self._learner['ml_m'], x, d, smpls=smpls, n_jobs=n_jobs_cv,
@@ -163,28 +164,41 @@ def _nuisance_est(self, smpls, n_jobs_cv):
                                  'observed to be binary with values 0 and 1. Make sure that for classifiers '
                                  'probabilities and not labels are predicted.')

-        psi_a, psi_b = self._score_elements(y, d, g_hat, m_hat, smpls)
+        # an estimate of g is obtained for the IV-type score and callable scores
+        g_hat = None
+        if (isinstance(self.score, str) & (self.score == 'IV-type')) | callable(self.score):
+            # get an initial estimate for theta using the partialling out score
+            psi_a = -np.multiply(d - m_hat, d - m_hat)
+            psi_b = np.multiply(d - m_hat, y - l_hat)
+            theta_initial = -np.mean(psi_b) / np.mean(psi_a)
+            # nuisance g
+            g_hat = _dml_cv_predict(self._learner['ml_g'], x, y - theta_initial * d, smpls=smpls, n_jobs=n_jobs_cv,
+                                    est_params=self._get_params('ml_l'), method=self._predict_method['ml_g'])
+            _check_finite_predictions(g_hat, self._learner['ml_g'], 'ml_g', smpls)
+
+        psi_a, psi_b = self._score_elements(y, d, l_hat, g_hat, m_hat, smpls)
         preds = {'ml_g': g_hat,
+                 'ml_l': l_hat,
                  'ml_m': m_hat}

         return psi_a, psi_b, preds

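The added branch computes the additional nuisance g(X) = E[Y - D*theta | X] only for the IV-type and callable scores: it first forms a preliminary theta from the partialling-out moments and then regresses the shifted outcome y - theta*d on x. A rough stand-alone sketch of that two-stage idea, using plain scikit-learn cross_val_predict and a RandomForestRegressor as stand-ins for the internal _dml_cv_predict helper and the user-supplied learners:

```python
# Illustrative sketch only, not the package implementation.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold, cross_val_predict

rng = np.random.default_rng(42)
n, p = 500, 10
x = rng.normal(size=(n, p))
d = x[:, 0] + rng.normal(size=n)             # treatment depends on x
y = 0.5 * d + x[:, 0] + rng.normal(size=n)   # true theta = 0.5

cv = KFold(n_splits=5, shuffle=True, random_state=42)

# cross-fitted nuisance predictions l(X) = E[Y|X] and m(X) = E[D|X]
l_hat = cross_val_predict(RandomForestRegressor(), x, y, cv=cv)
m_hat = cross_val_predict(RandomForestRegressor(), x, d, cv=cv)

# preliminary theta from the partialling-out moment conditions
psi_a = -np.multiply(d - m_hat, d - m_hat)
psi_b = np.multiply(d - m_hat, y - l_hat)
theta_initial = -np.mean(psi_b) / np.mean(psi_a)

# second-stage nuisance g(X) = E[Y - D*theta | X], needed for the IV-type score
g_hat = cross_val_predict(RandomForestRegressor(), x, y - theta_initial * d, cv=cv)
```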
-    def _score_elements(self, y, d, g_hat, m_hat, smpls):
+    def _score_elements(self, y, d, l_hat, g_hat, m_hat, smpls):
         # compute residuals
-        u_hat = y - g_hat
+        u_hat = y - l_hat
         v_hat = d - m_hat
-        v_hatd = np.multiply(v_hat, d)

         if isinstance(self.score, str):
             if self.score == 'IV-type':
-                psi_a = -v_hatd
+                psi_a = -np.multiply(v_hat, d)
+                psi_b = np.multiply(v_hat, y - g_hat)
             else:
                 assert self.score == 'partialling out'
                 psi_a = -np.multiply(v_hat, v_hat)
-            psi_b = np.multiply(v_hat, u_hat)
+                psi_b = np.multiply(v_hat, u_hat)
         else:
             assert callable(self.score)
-            psi_a, psi_b = self.score(y, d, g_hat, m_hat, smpls)
+            psi_a, psi_b = self.score(y, d, l_hat, g_hat, m_hat, smpls)

         return psi_a, psi_b

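In both string-score branches the estimator is later obtained by solving the linear score E[psi_a]*theta + E[psi_b] = 0, i.e. theta_hat = -mean(psi_b) / mean(psi_a). A compact sketch of that final step, assuming cross-fitted predictions l_hat, g_hat and m_hat are already available as NumPy arrays (hypothetical helper, not part of the package):

```python
# Sketch only: solve the linear score E[psi_a]*theta + E[psi_b] = 0 for theta,
# given hypothetical cross-fitted predictions for the outcome y and treatment d.
import numpy as np

def plr_theta(y, d, l_hat, g_hat, m_hat, score='partialling out'):
    u_hat = y - l_hat          # residual of the outcome regression l(X) = E[Y|X]
    v_hat = d - m_hat          # residual of the treatment regression m(X) = E[D|X]
    if score == 'partialling out':
        psi_a = -np.multiply(v_hat, v_hat)
        psi_b = np.multiply(v_hat, u_hat)
    else:  # 'IV-type', uses g(X) = E[Y - D*theta | X]
        psi_a = -np.multiply(v_hat, d)
        psi_b = np.multiply(v_hat, y - g_hat)
    return -np.mean(psi_b) / np.mean(psi_a)
```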
@@ -200,21 +214,44 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
                                'ml_m': None}

         train_inds = [train_index for (train_index, _) in smpls]
-        g_tune_res = _dml_tune(y, x, train_inds,
+        l_tune_res = _dml_tune(y, x, train_inds,
                                self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
         m_tune_res = _dml_tune(d, x, train_inds,
                                self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)

-        g_best_params = [xx.best_params_ for xx in g_tune_res]
+        l_best_params = [xx.best_params_ for xx in l_tune_res]
         m_best_params = [xx.best_params_ for xx in m_tune_res]

-        params = {'ml_g': g_best_params,
-                  'ml_m': m_best_params}
-
-        tune_res = {'g_tune': g_tune_res,
-                    'm_tune': m_tune_res}
+        # an ML model for g is obtained for the IV-type score and callable scores
+        if (isinstance(self.score, str) & (self.score == 'IV-type')) | callable(self.score):
+            # construct an initial theta estimate from the tuned models using the partialling out score
+            l_hat = np.full_like(y, np.nan)
+            m_hat = np.full_like(d, np.nan)
+            for idx, (train_index, _) in enumerate(smpls):
+                l_hat[train_index] = l_tune_res[idx].predict(x[train_index, :])
+                m_hat[train_index] = m_tune_res[idx].predict(x[train_index, :])
+            psi_a = -np.multiply(d - m_hat, d - m_hat)
+            psi_b = np.multiply(d - m_hat, y - l_hat)
+            theta_initial = -np.mean(psi_b) / np.mean(psi_a)
+            g_tune_res = _dml_tune(y - theta_initial * d, x, train_inds,
                                   self._learner['ml_g'], param_grids['ml_g'], scoring_methods['ml_g'],
+                                   n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+
+            g_best_params = [xx.best_params_ for xx in g_tune_res]
+            params = {'ml_l': l_best_params,
+                      'ml_m': m_best_params,
+                      'ml_g': g_best_params}
+            tune_res = {'l_tune': l_tune_res,
+                        'm_tune': m_tune_res,
+                        'g_tune': g_tune_res}
+        else:
+            assert self.score == 'partialling out'
+            params = {'ml_l': l_best_params,
+                      'ml_m': m_best_params}
+            tune_res = {'g_tune': l_tune_res,
+                        'm_tune': m_tune_res}

         res = {'params': params,
                'tune_res': tune_res}
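The tuning path mirrors the estimation path: the per-fold search objects for l and m are reused to predict in-sample on their own training indices, a preliminary theta is derived from those predictions, and only then is the learner for g tuned on the shifted target y - theta*d. A self-contained sketch of that flow, assuming one GridSearchCV per fold in place of the internal _dml_tune helper and a Lasso learner chosen purely for illustration:

```python
# Sketch only: per-fold tuning of l and m, a preliminary theta from their
# in-sample predictions, then tuning of g on the shifted target y - theta*d.
import numpy as np
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV, KFold

rng = np.random.default_rng(0)
n, p = 400, 5
x = rng.normal(size=(n, p))
d = x[:, 0] + rng.normal(size=n)
y = 0.5 * d + x[:, 0] + rng.normal(size=n)

param_grid = {'alpha': [0.01, 0.1, 1.0]}
smpls = list(KFold(n_splits=2, shuffle=True, random_state=0).split(x))

def tune_per_fold(target):
    # one fitted GridSearchCV per fold, trained on that fold's training indices
    return [GridSearchCV(Lasso(), param_grid, cv=5).fit(x[train_index, :], target[train_index])
            for train_index, _ in smpls]

l_tune_res = tune_per_fold(y)
m_tune_res = tune_per_fold(d)

# in-sample predictions of the tuned models on their own training sets
l_hat = np.full_like(y, np.nan)
m_hat = np.full_like(d, np.nan)
for idx, (train_index, _) in enumerate(smpls):
    l_hat[train_index] = l_tune_res[idx].predict(x[train_index, :])
    m_hat[train_index] = m_tune_res[idx].predict(x[train_index, :])

# preliminary theta from the partialling-out moments, then tune the learner for g
psi_a = -np.multiply(d - m_hat, d - m_hat)
psi_b = np.multiply(d - m_hat, y - l_hat)
theta_initial = -np.mean(psi_b) / np.mean(psi_a)

g_tune_res = tune_per_fold(y - theta_initial * d)
g_best_params = [gs.best_params_ for gs in g_tune_res]
```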