diff --git a/R/double_ml.R b/R/double_ml.R index 0edba758..afdd4697 100644 --- a/R/double_ml.R +++ b/R/double_ml.R @@ -378,7 +378,7 @@ DoubleML = R6Class("DoubleML", } # ml estimation of nuisance models and computation of psi elements - res = private$ml_nuisance_and_score_elements(private$get__smpls()) + res = private$nuisance_est(private$get__smpls()) private$psi_a_[, private$i_rep, private$i_treat] = res$psi_a private$psi_b_[, private$i_rep, private$i_treat] = res$psi_b if (store_predictions) { @@ -814,7 +814,7 @@ DoubleML = R6Class("DoubleML", if (tune_on_folds) { for (i_rep in 1:self$n_rep) { private$i_rep = i_rep - param_tuning = private$ml_nuisance_tuning( + param_tuning = private$nuisance_tuning( private$get__smpls(), param_set, tune_settings, tune_on_folds) private$tuning_res_[[i_treat]][[i_rep]] = param_tuning @@ -833,7 +833,7 @@ DoubleML = R6Class("DoubleML", } } else { private$i_rep = 1 - param_tuning = private$ml_nuisance_tuning( + param_tuning = private$nuisance_tuning( private$get__smpls(), param_set, tune_settings, tune_on_folds) private$tuning_res_[[i_treat]] = param_tuning diff --git a/R/double_ml_iivm.R b/R/double_ml_iivm.R index c2507a94..067f2200 100644 --- a/R/double_ml_iivm.R +++ b/R/double_ml_iivm.R @@ -276,7 +276,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", "ml_r1" = nuisance) invisible(self) }, - ml_nuisance_and_score_elements = function(smpls, ...) { + nuisance_est = function(smpls, ...) { if (self$subgroups$always_takers == FALSE & self$subgroups$never_takers == FALSE) { @@ -396,7 +396,7 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", } return(psis) }, - ml_nuisance_tuning = function(smpls, param_set, tune_settings, + nuisance_tuning = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (!tune_on_folds) { diff --git a/R/double_ml_irm.R b/R/double_ml_irm.R index 6c0b1eb9..f78f7556 100644 --- a/R/double_ml_irm.R +++ b/R/double_ml_irm.R @@ -214,7 +214,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM", "ml_m" = nuisance) invisible(self) }, - ml_nuisance_and_score_elements = function(smpls, ...) { + nuisance_est = function(smpls, ...) { cond_smpls = get_cond_samples( smpls, @@ -303,7 +303,7 @@ DoubleMLIRM = R6Class("DoubleMLIRM", } return(psis) }, - ml_nuisance_tuning = function(smpls, param_set, tune_settings, + nuisance_tuning = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (!tune_on_folds) { diff --git a/R/double_ml_pliv.R b/R/double_ml_pliv.R index 8ae34e3a..5616b20e 100644 --- a/R/double_ml_pliv.R +++ b/R/double_ml_pliv.R @@ -444,22 +444,22 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", names(private$params_) = param_names invisible(self) }, - ml_nuisance_and_score_elements = function(smpls, ...) { + nuisance_est = function(smpls, ...) { if (self$partialX & !self$partialZ) { - res = private$ml_nuisance_and_score_elements_partialX(smpls, ...) + res = private$nuisance_est_partialX(smpls, ...) } else if (!self$partialX & self$partialZ) { - res = private$ml_nuisance_and_score_elements_partialZ(smpls, ...) + res = private$nuisance_est_partialZ(smpls, ...) } else if (self$partialX & self$partialZ) { - res = private$ml_nuisance_and_score_elements_partialXZ(smpls, ...) + res = private$nuisance_est_partialXZ(smpls, ...) } return(res) }, - ml_nuisance_and_score_elements_partialX = function(smpls, ...) { + nuisance_est_partialX = function(smpls, ...) { l_hat = dml_cv_predict(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), @@ -602,7 +602,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", } return(psis) }, - ml_nuisance_and_score_elements_partialXZ = function(smpls, ...) { + nuisance_est_partialXZ = function(smpls, ...) { l_hat = dml_cv_predict(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), @@ -674,7 +674,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return(res) }, - ml_nuisance_and_score_elements_partialZ = function(smpls, ...) { + nuisance_est_partialZ = function(smpls, ...) { # nuisance r @@ -712,22 +712,22 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", }, - ml_nuisance_tuning = function(smpls, param_set, tune_settings, + nuisance_tuning = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (self$partialX & !self$partialZ) { - res = private$ml_nuisance_tuning_partialX( + res = private$nuisance_tuning_partialX( smpls, param_set, tune_settings, tune_on_folds, ...) } else if (!self$partialX & self$partialZ) { - res = private$ml_nuisance_tuning_partialZ( + res = private$nuisance_tuning_partialZ( smpls, param_set, tune_settings, tune_on_folds, ...) } else if (self$partialX & self$partialZ) { - res = private$ml_nuisance_tuning_partialXZ( + res = private$nuisance_tuning_partialXZ( smpls, param_set, tune_settings, tune_on_folds, ...) @@ -736,7 +736,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return(res) }, - ml_nuisance_tuning_partialX = function(smpls, param_set, + nuisance_tuning_partialX = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (!tune_on_folds) { @@ -892,7 +892,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", }, - ml_nuisance_tuning_partialXZ = function(smpls, param_set, + nuisance_tuning_partialXZ = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (!tune_on_folds) { @@ -968,7 +968,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", return(tuning_result) }, - ml_nuisance_tuning_partialZ = function(smpls, param_set, + nuisance_tuning_partialZ = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (!tune_on_folds) { data_tune_list = list(self$data$data_model) diff --git a/R/double_ml_plr.R b/R/double_ml_plr.R index 8fb6ca43..e418d595 100644 --- a/R/double_ml_plr.R +++ b/R/double_ml_plr.R @@ -379,7 +379,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR", invisible(self) }, - ml_nuisance_and_score_elements = function(smpls, ...) { + nuisance_est = function(smpls, ...) { l_hat = dml_cv_predict(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), @@ -459,7 +459,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR", } return(psis) }, - ml_nuisance_tuning = function(smpls, param_set, tune_settings, + nuisance_tuning = function(smpls, param_set, tune_settings, tune_on_folds, ...) { if (!tune_on_folds) { diff --git a/man/figures/oop.svg b/man/figures/oop.svg index 4837b3eb..8d426d97 100644 --- a/man/figures/oop.svg +++ b/man/figures/oop.svg @@ -1,7 +1,8 @@ - + + @@ -40,7 +41,6 @@ - @@ -64,10 +64,10 @@ - - - - + + + + @@ -78,7 +78,7 @@ - + @@ -90,8 +90,8 @@ - - + + @@ -103,8 +103,8 @@ - - + + @@ -113,7 +113,7 @@ - + @@ -122,8 +122,8 @@ - - + + @@ -149,8 +149,8 @@ - - + + @@ -193,78 +193,49 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + - - + + @@ -278,7 +249,7 @@ - + @@ -290,8 +261,8 @@ - - + + @@ -309,7 +280,7 @@ - + @@ -327,78 +298,49 @@ - - - - + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + - - + + @@ -413,7 +355,7 @@ - + @@ -425,8 +367,8 @@ - - + + @@ -444,7 +386,7 @@ - + @@ -462,78 +404,49 @@ - - - - - - - - - - - - - - - + + + + + + + + + + - - - - - + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + - - + + @@ -547,7 +460,7 @@ - + @@ -559,8 +472,8 @@ - - + + @@ -578,7 +491,7 @@ - + @@ -596,78 +509,49 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + - - - - - - - - - - - - - - + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + - - + + @@ -682,7 +566,7 @@ - + @@ -694,8 +578,8 @@ - - + + @@ -713,7 +597,7 @@ - + @@ -731,72 +615,43 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + + + + + + - - - - - - - - - - + + + + + + + + + +