diff --git a/R/double_ml_iivm.R b/R/double_ml_iivm.R index c226aee9..c2507a94 100644 --- a/R/double_ml_iivm.R +++ b/R/double_ml_iivm.R @@ -366,8 +366,10 @@ DoubleMLIIVM = R6Class("DoubleMLIIVM", "ml_r1" = r1_hat) return(res) }, - score_elements = function(y, z, d, g0_hat, g1_hat, m_hat, r0_hat, - r1_hat, smpls) { + score_elements = function(y = y, z = z, d = d, + g0_hat = g0_hat, g1_hat = g1_hat, m_hat = m_hat, + r0_hat = r0_hat, r1_hat = r1_hat, + smpls = smpls) { u0_hat = y - g0_hat u1_hat = y - g1_hat diff --git a/R/double_ml_irm.R b/R/double_ml_irm.R index d7c347b5..6c0b1eb9 100644 --- a/R/double_ml_irm.R +++ b/R/double_ml_irm.R @@ -296,7 +296,10 @@ DoubleMLIRM = R6Class("DoubleMLIRM", } psis = list(psi_a = psi_a, psi_b = psi_b) } else if (is.function(self$score)) { - psis = self$score(y, d, g0_hat, g1_hat, m_hat, smpls) + psis = self$score( + y = y, d = d, + g0_hat = g0_hat, g1_hat = g1_hat, m_hat = m_hat, + smpls = smpls) } return(psis) }, diff --git a/R/double_ml_pliv.R b/R/double_ml_pliv.R index 90e1d04d..8ae34e3a 100644 --- a/R/double_ml_pliv.R +++ b/R/double_ml_pliv.R @@ -25,11 +25,11 @@ #' library(mlr3learners) #' library(data.table) #' set.seed(2) -#' ml_g = lrn("regr.ranger", num.trees = 100, mtry = 20, min.node.size = 2, max.depth = 5) -#' ml_m = ml_g$clone() -#' ml_r = ml_g$clone() +#' ml_l = lrn("regr.ranger", num.trees = 100, mtry = 20, min.node.size = 2, max.depth = 5) +#' ml_m = ml_l$clone() +#' ml_r = ml_l$clone() #' obj_dml_data = make_pliv_CHS2015(alpha = 1, n_obs = 500, dim_x = 20, dim_z = 1) -#' dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_g, ml_m, ml_r) +#' dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_l, ml_m, ml_r) #' dml_pliv_obj$fit() #' dml_pliv_obj$summary() #' } @@ -41,15 +41,15 @@ #' library(mlr3tuning) #' library(data.table) #' set.seed(2) -#' ml_g = lrn("regr.rpart") -#' ml_m = ml_g$clone() -#' ml_r = ml_g$clone() +#' ml_l = lrn("regr.rpart") +#' ml_m = ml_l$clone() +#' ml_r = ml_l$clone() #' obj_dml_data = make_pliv_CHS2015( #' alpha = 1, n_obs = 500, dim_x = 20, #' dim_z = 1) -#' dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_g, ml_m, ml_r) +#' dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_l, ml_m, ml_r) #' param_grid = list( -#' "ml_g" = paradox::ParamSet$new(list( +#' "ml_l" = paradox::ParamSet$new(list( #' paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), #' paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), #' "ml_m" = paradox::ParamSet$new(list( @@ -99,7 +99,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", #' The `DoubleMLData` object providing the data and specifying the variables #' of the causal model. #' - #' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], + #' @param ml_l ([`LearnerRegr`][mlr3::LearnerRegr], #' [`Learner`][mlr3::Learner], `character(1)`) \cr #' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its @@ -110,7 +110,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", #' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly #' be passed with specified parameters, for example #' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr - #' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y|X]}. + #' `ml_l` refers to the nuisance function \eqn{l_0(X) = E[Y|X]}. #' #' @param ml_m ([`LearnerRegr`][mlr3::LearnerRegr], #' [`Learner`][mlr3::Learner], `character(1)`) \cr @@ -138,6 +138,21 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", #' `lrn("regr.cv_glmnet", s = "lambda.min")`. 
\cr #' `ml_r` refers to the nuisance function \eqn{r_0(X) = E[D|X]}. #' + #' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], + #' [`Learner`][mlr3::Learner], `character(1)`) \cr + #' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is + #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its + #' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or + #' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/). + #' Alternatively, a [`Learner`][mlr3::Learner] object with public field + #' `task_type = "regr"` can be passed, for example of class + #' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly + #' be passed with specified parameters, for example + #' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr + #' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y - D\theta_0|X]}. + #' Note: The learner `ml_g` is only required for the score `'IV-type'`. + #' Optionally, it can be specified and estimated for callable scores. + #' #' @param partialX (`logical(1)`) \cr #' Indicates whether covariates \eqn{X} should be partialled out. #' Default is `TRUE`. @@ -153,10 +168,10 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", #' Number of repetitions for the sample splitting. Default is `1`. #' #' @param score (`character(1)`, `function()`) \cr - #' A `character(1)` (`"partialling out"` is the only choice) or a - #' `function()` specifying the score function. + #' A `character(1)` (`"partialling out"` or `"IV-type"`) or a `function()` + #' specifying the score function. #' If a `function()` is provided, it must be of the form - #' `function(y, z, d, g_hat, m_hat, r_hat, smpls)` and + #' `function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls)` and #' the returned output must be a named `list()` with elements #' `psi_a` and `psi_b`. Default is `"partialling out"`. #' @@ -171,9 +186,10 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", #' @param apply_cross_fitting (`logical(1)`) \cr #' Indicates whether cross-fitting should be applied. Default is `TRUE`. initialize = function(data, - ml_g, + ml_l, ml_m, ml_r, + ml_g = NULL, partialX = TRUE, partialZ = FALSE, n_folds = 5, @@ -183,6 +199,19 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", draw_sample_splitting = TRUE, apply_cross_fitting = TRUE) { + if (missing(ml_l)) { + if (!missing(ml_g)) { + warning(paste0( + "The argument ml_g was renamed to ml_l. ", + "Please adapt the argument name accordingly. ", + "ml_g is redirected to ml_l.\n", + "The redirection will be removed in a future version."), + call. 
= FALSE) + ml_l = ml_g + ml_g = NULL + } + } + + super$initialize_double_ml( data, n_folds, @@ -193,11 +222,11 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", apply_cross_fitting) private$check_data(self$data) - private$check_score(self$score) assert_logical(partialX, len = 1) assert_logical(partialZ, len = 1) private$partialX_ = partialX private$partialZ_ = partialZ + private$check_score(self$score) if (!self$partialX & self$partialZ) { ml_r = private$assert_learner(ml_r, "ml_r", @@ -205,7 +234,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", Classif = FALSE) private$learner_ = list("ml_r" = ml_r) } else { - ml_g = private$assert_learner(ml_g, "ml_g", + ml_l = private$assert_learner(ml_l, "ml_l", Regr = TRUE, Classif = FALSE) ml_m = private$assert_learner(ml_m, "ml_m", @@ -215,11 +244,186 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", Regr = TRUE, Classif = FALSE) private$learner_ = list( - "ml_g" = ml_g, + "ml_l" = ml_l, "ml_m" = ml_m, "ml_r" = ml_r) + + if (!is.null(ml_g)) { + assert( + check_character(ml_g, max.len = 1), + check_class(ml_g, "Learner")) + if ((is.character(self$score) && (self$score == "IV-type")) || + is.function(self$score)) { + ml_g = private$assert_learner(ml_g, "ml_g", + Regr = TRUE, Classif = FALSE) + private$learner_[["ml_g"]] = ml_g + } else if (is.character(self$score) && + (self$score == "partialling out")) { + warning(paste0( + "A learner ml_g has been provided for ", + "score = 'partialling out' but will be ignored. ", + "A learner ml_g is not required for estimation.")) + } + } else if (is.character(self$score) && (self$score == "IV-type")) { + stop(paste( + "For score = 'IV-type', learners", + "ml_l, ml_m, ml_r and ml_g need to be specified.")) + } } + private$initialize_ml_nuisance_params() + }, + # To be removed in version 0.6.0 + # + # Note: Ideally the following duplicate roxygen / docu parts should be taken + # from the base class DoubleML. However, this is an open issue in pkg + # roxygen2, see https://github.com/r-lib/roxygen2/issues/996 & + # https://github.com/r-lib/roxygen2/issues/1043 + # + #' @description + #' Set hyperparameters for the nuisance models of DoubleML models. + #' + #' Note that in the current implementation, either all parameters have to + #' be set globally or all parameters have to be provided fold-specific. + #' + #' @param learner (`character(1)`) \cr + #' The nuisance model/learner (see method `params_names`). + #' + #' @param treat_var (`character(1)`) \cr + #' The treatment variable (hyperparameters can be set treatment-variable + #' specific). + #' + #' @param params (named `list()`) \cr + #' A named `list()` with estimator parameters. Parameters are used for all + #' folds by default. Alternatively, parameters can be passed in a + #' fold-specific way if option `fold_specific` is `TRUE`. In this case, the + #' outer list needs to be of length `n_rep` and the inner list of length + #' `n_folds`. + #' + #' @param set_fold_specific (`logical(1)`) \cr + #' Indicates if the parameters passed in `params` should be passed in + #' fold-specific way. Default is `FALSE`. If `TRUE`, the outer list needs + #' to be of length `n_rep` and the inner list of length `n_folds`. + #' Note that in the current implementation, either all parameters have to + #' be set globally or all parameters have to be provided fold-specific.
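A minimal usage sketch of the new `'IV-type'` interface for `DoubleMLPLIV` introduced in this hunk (not part of the patch; learner settings are illustrative, the data generator is the one used in the package examples; all four learners are required for this score):
library(DoubleML)
library(mlr3)
library(mlr3learners)
set.seed(2)
# All of ml_l, ml_m, ml_r and ml_g must be supplied for score = 'IV-type';
# otherwise the initializer stops with the error added above.
ml_l = lrn("regr.ranger", num.trees = 100, min.node.size = 2, max.depth = 5)
ml_m = ml_l$clone()
ml_r = ml_l$clone()
ml_g = ml_l$clone()
obj_dml_data = make_pliv_CHS2015(alpha = 1, n_obs = 500, dim_x = 20, dim_z = 1)
dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_l, ml_m, ml_r,
  ml_g = ml_g, score = "IV-type")
dml_pliv_obj$fit()
dml_pliv_obj$summary()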
+ #' + #' @return self + set_ml_nuisance_params = function(learner = NULL, treat_var = NULL, params, + set_fold_specific = FALSE) { + assert_character(learner, len = 1) + if (is.character(self$score) && (self$score == "partialling out") && + (learner == "ml_g")) { + warning(paste0( + "Learner ml_g was renamed to ml_l. ", + "Please adapt the argument learner accordingly. ", + "The provided parameters are set for ml_l. ", + "The redirection will be removed in a future version."), + call. = FALSE) + learner = "ml_l" + } + super$set_ml_nuisance_params( + learner, treat_var, params, + set_fold_specific) + }, + # To be removed in version 0.6.0 + # + # Note: Ideally the following duplicate roxygen / docu parts should be taken + # from the base class DoubleML. However, this is an open issue in pkg + # roxygen2, see https://github.com/r-lib/roxygen2/issues/996 & + # https://github.com/r-lib/roxygen2/issues/1043 + # + #' @description + #' Hyperparameter-tuning for DoubleML models. + #' + #' The hyperparameter-tuning is performed using the tuning methods provided + #' in the [mlr3tuning](https://mlr3tuning.mlr-org.com/) package. For more + #' information on tuning in [mlr3](https://mlr3.mlr-org.com/), we refer to + #' the section on parameter tuning in the + #' [mlr3 book](https://mlr3book.mlr-org.com/optimization.html#tuning). + #' + #' @param param_set (named `list()`) \cr + #' A named `list` with a parameter grid for each nuisance model/learner + #' (see method `learner_names()`). The parameter grid must be an object of + #' class [ParamSet][paradox::ParamSet]. + #' + #' @param tune_settings (named `list()`) \cr + #' A named `list()` with arguments passed to the hyperparameter-tuning with + #' [mlr3tuning](https://mlr3tuning.mlr-org.com/) to set up + #' [TuningInstance][mlr3tuning::TuningInstanceSingleCrit] objects. + #' `tune_settings` has entries + #' * `terminator` ([Terminator][bbotk::Terminator]) \cr + #' A [Terminator][bbotk::Terminator] object. Specification of `terminator` + #' is required to perform tuning. + #' * `algorithm` ([Tuner][mlr3tuning::Tuner] or `character(1)`) \cr + #' A [Tuner][mlr3tuning::Tuner] object (recommended) or key passed to the + #' respective dictionary to specify the tuning algorithm used in + #' [tnr()][mlr3tuning::tnr()]. `algorithm` is passed as an argument to + #' [tnr()][mlr3tuning::tnr()]. If `algorithm` is not specified by the user, + #' default is set to `"grid_search"`. If set to `"grid_search"`, then + #' additional argument `"resolution"` is required. + #' * `rsmp_tune` ([Resampling][mlr3::Resampling] or `character(1)`)\cr + #' A [Resampling][mlr3::Resampling] object (recommended) or option passed + #' to [rsmp()][mlr3::mlr_sugar] to initialize a + #' [Resampling][mlr3::Resampling] for parameter tuning in `mlr3`. + #' If not specified by the user, default is set to `"cv"` + #' (cross-validation). + #' * `n_folds_tune` (`integer(1)`, optional) \cr + #' If `rsmp_tune = "cv"`, number of folds used for cross-validation. + #' If not specified by the user, default is set to `5`. + #' * `measure` (`NULL`, named `list()`, optional) \cr + #' Named list containing the measures used for parameter tuning. Entries in + #' list must either be [Measure][mlr3::Measure] objects or keys to be + #' passed to [msr()][mlr3::msr()]. The names of the entries must + #' match the learner names (see method `learner_names()`).
If set to `NULL`, + #' default measures are used, i.e., `"regr.mse"` for continuous outcome + #' variables and `"classif.ce"` for binary outcomes. + #' * `resolution` (`character(1)`) \cr The key passed to the respective + #' dictionary to specify the tuning algorithm used in + #' [tnr()][mlr3tuning::tnr()]. `resolution` is passed as an argument to + #' [tnr()][mlr3tuning::tnr()]. + #' + #' @param tune_on_folds (`logical(1)`) \cr + #' Indicates whether the tuning should be done fold-specific or globally. + #' Default is `FALSE`. + #' + #' @return self + tune = function(param_set, tune_settings = list( + n_folds_tune = 5, + rsmp_tune = mlr3::rsmp("cv", folds = 5), + measure = NULL, + terminator = mlr3tuning::trm("evals", n_evals = 20), + algorithm = mlr3tuning::tnr("grid_search"), + resolution = 5), + tune_on_folds = FALSE) { + + assert_list(param_set) + if (is.character(self$score) && (self$score == "partialling out")) { + if (exists("ml_g", where = param_set) && !exists("ml_l", where = param_set)) { + warning(paste0( + "Learner ml_g was renamed to ml_l. ", + "Please adapt the name in param_set accordingly. ", + "The provided param_set for ml_g is used for ml_l. ", + "The redirection will be removed in a future version."), + call. = FALSE) + names(param_set)[names(param_set) == "ml_g"] = "ml_l" + } + } + + assert_list(tune_settings) + if (test_names(names(tune_settings), must.include = "measure") && !is.null(tune_settings$measure)) { + assert_list(tune_settings$measure) + if (exists("ml_g", where = tune_settings$measure) && !exists("ml_l", where = tune_settings$measure)) { + warning(paste0( + "Learner ml_g was renamed to ml_l. ", + "Please adapt the name in tune_settings$measure accordingly. ", + "The provided tune_settings$measure for ml_g is used for ml_l. ", + "The redirection will be removed in a future version."), + call. = FALSE) + names(tune_settings$measure)[names(tune_settings$measure) == "ml_g"] = "ml_l" + } + } + + super$tune(param_set, tune_settings, tune_on_folds) } ), private = list( @@ -228,22 +432,16 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", n_nuisance = 3, i_instr = NULL, initialize_ml_nuisance_params = function() { - if (self$partialX & !self$partialZ) { - if (self$data$n_instr == 1) { - valid_learner = c("ml_g", "ml_m", "ml_r") - } else { - valid_learner = c("ml_g", "ml_r", paste0("ml_m_", self$data$z_cols)) - } - } else if (self$partialX & self$partialZ) { - valid_learner = c("ml_g", "ml_m", "ml_r") - } else if (!self$partialX & self$partialZ) { - valid_learner = c("ml_r") + if ((self$partialX && !self$partialZ) && (self$data$n_instr > 1)) { + param_names = c("ml_l", "ml_r", paste0("ml_m_", self$data$z_cols)) + } else { + param_names = names(private$learner_) } nuisance = vector("list", self$data$n_treat) names(nuisance) = self$data$d_cols - private$params_ = rep(list(nuisance), length(valid_learner)) - names(private$params_) = valid_learner + private$params_ = rep(list(nuisance), length(param_names)) + names(private$params_) = param_names invisible(self) }, ml_nuisance_and_score_elements = function(smpls, ...) { @@ -263,15 +461,15 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", ml_nuisance_and_score_elements_partialX = function(smpls, ...) 
{ - g_hat = dml_cv_predict(self$learner$ml_g, + l_hat = dml_cv_predict(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), self$data$y_col, self$data$data_model, - nuisance_id = "nuis_g", + nuisance_id = "nuis_l", smpls = smpls, - est_params = self$get_params("ml_g"), + est_params = self$get_params("ml_l"), return_train_preds = FALSE, - task_type = private$task_type$ml_g, + task_type = private$task_type$ml_l, fold_specific_params = private$fold_specific_params) r_hat = dml_cv_predict(self$learner$ml_r, @@ -320,29 +518,58 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", d = self$data$data_model[[self$data$treat_col]] y = self$data$data_model[[self$data$y_col]] - res = private$score_elements(y, z, d, g_hat, m_hat, r_hat, smpls) + g_hat = NULL + if (exists("ml_g", where = private$learner_)) { + # get an initial estimate for theta using the partialling out score + psi_a = -(d - r_hat) * (z - m_hat) + psi_b = (z - m_hat) * (y - l_hat) + theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE) + + data_aux = data.table(self$data$data_model, + "y_minus_theta_d" = y - theta_initial * d) + + g_hat = dml_cv_predict(self$learner$ml_g, + c(self$data$x_cols, self$data$other_treat_cols), + "y_minus_theta_d", + data_aux, + nuisance_id = "nuis_g", + smpls = smpls, + est_params = self$get_params("ml_g"), + return_train_preds = FALSE, + task_type = private$task_type$ml_g, + fold_specific_params = private$fold_specific_params) + } + + res = private$score_elements(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) res$preds = list( - "ml_g" = g_hat, + "ml_l" = l_hat, "ml_m" = m_hat, - "ml_r" = r_hat) + "ml_r" = r_hat, + "ml_g" = g_hat) return(res) }, - score_elements = function(y, z, d, g_hat, m_hat, r_hat, smpls) { - u_hat = y - g_hat + score_elements = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) { + u_hat = y - l_hat w_hat = d - r_hat v_hat = z - m_hat - if (self$data$n_instr == 1) { if (is.character(self$score)) { if (self$score == "partialling out") { psi_a = -w_hat * v_hat psi_b = v_hat * u_hat + } else if (self$score == "IV-type") { + psi_a = -d * v_hat + psi_b = v_hat * (y - g_hat) } psis = list( psi_a = psi_a, psi_b = psi_b) } else if (is.function(self$score)) { - psis = self$score(y, z, d, g_hat, m_hat, r_hat, smpls) + psis = self$score( + y = y, z = z, d = d, + l_hat = l_hat, m_hat = m_hat, + r_hat = r_hat, g_hat = g_hat, + smpls = smpls) } } else { stopifnot(self$apply_cross_fitting) @@ -377,15 +604,15 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", }, ml_nuisance_and_score_elements_partialXZ = function(smpls, ...) 
{ - g_hat = dml_cv_predict(self$learner$ml_g, + l_hat = dml_cv_predict(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), self$data$y_col, self$data$data_model, - nuisance_id = "nuis_g", + nuisance_id = "nuis_l", smpls = smpls, - est_params = self$get_params("ml_g"), + est_params = self$get_params("ml_l"), return_train_preds = FALSE, - task_type = private$task_type$ml_g, + task_type = private$task_type$ml_l, fold_specific_params = private$fold_specific_params) m_hat_list = dml_cv_predict(self$learner$ml_m, @@ -423,7 +650,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", d = self$data$data_model[[self$data$treat_col]] y = self$data$data_model[[self$data$y_col]] - u_hat = y - g_hat + u_hat = y - l_hat w_hat = d - m_hat_tilde if (is.character(self$score)) { @@ -441,7 +668,7 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", # res = self$score(y, d, g_hat, m_hat, m_hat_tilde) } res$preds = list( - "ml_g" = g_hat, + "ml_l" = l_hat, "ml_m" = m_hat, "ml_r" = m_hat_tilde) return(res) @@ -520,13 +747,13 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", function(x) extract_training_data(self$data$data_model, x)) } - tuning_result_g = dml_tune(self$learner$ml_g, + tuning_result_l = dml_tune(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), self$data$y_col, data_tune_list, - nuisance_id = "nuis_g", - param_set$ml_g, tune_settings, - tune_settings$measure$ml_g, - private$task_type$ml_g) + nuisance_id = "nuis_l", + param_set$ml_l, tune_settings, + tune_settings$measure$ml_l, + private$task_type$ml_l) tuning_result_r = dml_tune(self$learner$ml_r, c(self$data$x_cols, self$data$other_treat_cols), @@ -545,20 +772,101 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", tune_settings$measure$ml_m, private$task_type$ml_m) - tuning_result = list( - "ml_g" = list(tuning_result_g, - params = tuning_result_g$params), - "ml_m" = list(tuning_result_m, - params = tuning_result_m$params), - "ml_r" = list(tuning_result_r, - params = tuning_result_r$params)) + if (exists("ml_g", where = private$learner_)) { + if (tune_on_folds) { + params_l = tuning_result_l$params + params_r = tuning_result_r$params + params_m = tuning_result_m$params + } else { + params_l = tuning_result_l$params[[1]] + params_r = tuning_result_r$params[[1]] + params_m = tuning_result_m$params[[1]] + } + l_hat = dml_cv_predict(self$learner$ml_l, + c(self$data$x_cols, self$data$other_treat_cols), + self$data$y_col, + self$data$data_model, + nuisance_id = "nuis_l", + smpls = smpls, + est_params = params_l, + return_train_preds = FALSE, + task_type = private$task_type$ml_l, + fold_specific_params = private$fold_specific_params) + + r_hat = dml_cv_predict(self$learner$ml_r, + c(self$data$x_cols, self$data$other_treat_cols), + self$data$treat_col, + self$data$data_model, + nuisance_id = "nuis_r", + smpls = smpls, + est_params = params_r, + return_train_preds = FALSE, + task_type = private$task_type$ml_r, + fold_specific_params = private$fold_specific_params) + + m_hat = dml_cv_predict(self$learner$ml_m, + c(self$data$x_cols, self$data$other_treat_cols), + self$data$treat_col, + self$data$data_model, + nuisance_id = "nuis_m", + smpls = smpls, + est_params = params_m, + return_train_preds = FALSE, + task_type = private$task_type$ml_m, + fold_specific_params = private$fold_specific_params) + + d = self$data$data_model[[self$data$treat_col]] + y = self$data$data_model[[self$data$y_col]] + z = self$data$data_model[[self$data$z_cols]] + + psi_a = -(d - r_hat) * (z - m_hat) + psi_b = (z - m_hat) * (y - l_hat) + theta_initial = -mean(psi_b, na.rm = TRUE) / 
mean(psi_a, na.rm = TRUE) + + data_aux = data.table(self$data$data_model, + "y_minus_theta_d" = y - theta_initial * d) + + if (!tune_on_folds) { + data_aux_tune_list = list(data_aux) + } else { + data_aux_tune_list = lapply(smpls$train_ids, function(x) { + extract_training_data(data_aux, x) + }) + } + + tuning_result_g = dml_tune(self$learner$ml_g, + c(self$data$x_cols, self$data$other_treat_cols), + "y_minus_theta_d", data_aux_tune_list, + nuisance_id = "nuis_g", + param_set$ml_g, tune_settings, + tune_settings$measure$ml_g, + private$task_type$ml_g) + tuning_result = list( + "ml_l" = list(tuning_result_l, + params = tuning_result_l$params), + "ml_m" = list(tuning_result_m, + params = tuning_result_m$params), + "ml_r" = list(tuning_result_r, + params = tuning_result_r$params), + "ml_g" = list(tuning_result_g, + params = tuning_result_g$params)) + } else { + tuning_result = list( + "ml_l" = list(tuning_result_l, + params = tuning_result_l$params), + "ml_m" = list(tuning_result_m, + params = tuning_result_m$params), + "ml_r" = list(tuning_result_r, + params = tuning_result_r$params)) + } + } else { tuning_result = vector("list", length = self$data$n_instr + 2) names(tuning_result) = c( - "ml_g", "ml_r", + "ml_l", "ml_r", paste0("ml_m_", self$data$z_cols)) - tuning_result[["ml_g"]] = list(tuning_result_g, - params = tuning_result_g$params) + tuning_result[["ml_l"]] = list(tuning_result_l, + params = tuning_result_l$params) tuning_result[["ml_r"]] = list(tuning_result_r, params = tuning_result_r$params) @@ -595,13 +903,13 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", function(x) extract_training_data(self$data$data_model, x)) } - tuning_result_g = dml_tune(self$learner$ml_g, + tuning_result_l = dml_tune(self$learner$ml_l, c(self$data$x_cols), self$data$y_col, data_tune_list, - nuisance_id = "nuis_g", - param_set$ml_g, tune_settings, - tune_settings$measure$ml_g, - private$task_type$ml_g) + nuisance_id = "nuis_l", + param_set$ml_l, tune_settings, + tune_settings$measure$ml_l, + private$task_type$ml_l) tuning_result_m = dml_tune(self$learner$ml_m, c(self$data$x_cols, self$data$z_cols), @@ -651,8 +959,8 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", private$task_type$ml_r) tuning_result = list( - "ml_g" = list(tuning_result_g, - params = tuning_result_g$params), + "ml_l" = list(tuning_result_l, + params = tuning_result_l$params), "ml_m" = list(tuning_result_m, params = tuning_result_m$params), "ml_r" = list(tuning_result_r, @@ -690,7 +998,11 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", check_character(score), check_class(score, "function")) if (is.character(score)) { - valid_score = c("partialling out") + if ((self$partialX && !self$partialZ) && (self$data$n_instr == 1)) { + valid_score = c("partialling out", "IV-type") + } else { + valid_score = c("partialling out") + } assertChoice(score, valid_score) } return() @@ -710,9 +1022,10 @@ DoubleMLPLIV = R6Class("DoubleMLPLIV", # Initializer for partialX DoubleMLPLIV.partialX = function(data, - ml_g, + ml_l, ml_m, ml_r, + ml_g = NULL, n_folds = 5, n_rep = 1, score = "partialling out", @@ -720,18 +1033,20 @@ DoubleMLPLIV.partialX = function(data, draw_sample_splitting = TRUE, apply_cross_fitting = TRUE) { - obj = DoubleMLPLIV$new(data, - ml_g, - ml_m, - ml_r, + obj = DoubleMLPLIV$new( + data = data, + ml_l = ml_l, + ml_m = ml_m, + ml_r = ml_r, + ml_g = ml_g, partialX = TRUE, partialZ = FALSE, - n_folds, - n_rep, - score, - dml_procedure, - draw_sample_splitting, - apply_cross_fitting) + n_folds = n_folds, + n_rep = n_rep, + score = score, + dml_procedure = 
dml_procedure, + draw_sample_splitting = draw_sample_splitting, + apply_cross_fitting = apply_cross_fitting) return(obj) } @@ -746,25 +1061,27 @@ DoubleMLPLIV.partialZ = function(data, draw_sample_splitting = TRUE, apply_cross_fitting = TRUE) { - obj = DoubleMLPLIV$new(data, - ml_g = NULL, + obj = DoubleMLPLIV$new( + data = data, + ml_l = NULL, ml_m = NULL, - ml_r, + ml_r = ml_r, + ml_g = NULL, partialX = FALSE, partialZ = TRUE, - n_folds, - n_rep, - score, - dml_procedure, - draw_sample_splitting, - apply_cross_fitting) + n_folds = n_folds, + n_rep = n_rep, + score = score, + dml_procedure = dml_procedure, + draw_sample_splitting = draw_sample_splitting, + apply_cross_fitting = apply_cross_fitting) return(obj) } # Initializer for partialXZ DoubleMLPLIV.partialXZ = function(data, - ml_g, + ml_l, ml_m, ml_r, n_folds = 5, @@ -774,18 +1091,20 @@ DoubleMLPLIV.partialXZ = function(data, draw_sample_splitting = TRUE, apply_cross_fitting = TRUE) { - obj = DoubleMLPLIV$new(data, - ml_g, - ml_m, - ml_r, + obj = DoubleMLPLIV$new( + data = data, + ml_l = ml_l, + ml_m = ml_m, + ml_r = ml_r, + ml_g = NULL, partialX = TRUE, partialZ = TRUE, - n_folds, - n_rep, - score, - dml_procedure, - draw_sample_splitting, - apply_cross_fitting) + n_folds = n_folds, + n_rep = n_rep, + score = score, + dml_procedure = dml_procedure, + draw_sample_splitting = draw_sample_splitting, + apply_cross_fitting = apply_cross_fitting) return(obj) } diff --git a/R/double_ml_plr.R b/R/double_ml_plr.R index 792c9300..8fb6ca43 100644 --- a/R/double_ml_plr.R +++ b/R/double_ml_plr.R @@ -42,13 +42,13 @@ #' library(mlr3tuning) #' library(data.table) #' set.seed(2) -#' ml_g = lrn("regr.rpart") -#' ml_m = ml_g$clone() +#' ml_l = lrn("regr.rpart") +#' ml_m = ml_l$clone() #' obj_dml_data = make_plr_CCDDHNR2018(alpha = 0.5) -#' dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m) +#' dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_l, ml_m) #' #' param_grid = list( -#' "ml_g" = paradox::ParamSet$new(list( +#' "ml_l" = paradox::ParamSet$new(list( #' paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), #' paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), #' "ml_m" = paradox::ParamSet$new(list( @@ -73,7 +73,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR", #' The `DoubleMLData` object providing the data and specifying the #' variables of the causal model. #' - #' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], + #' @param ml_l ([`LearnerRegr`][mlr3::LearnerRegr], #' [`Learner`][mlr3::Learner], `character(1)`) \cr #' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its @@ -84,7 +84,7 @@ DoubleMLPLR = R6Class("DoubleMLPLR", #' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly #' be passed with specified parameters, for example #' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr - #' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y|X]}. + #' `ml_l` refers to the nuisance function \eqn{l_0(X) = E[Y|X]}. #' #' @param ml_m ([`LearnerRegr`][mlr3::LearnerRegr], #' [`LearnerClassif`][mlr3::LearnerClassif], [`Learner`][mlr3::Learner], @@ -102,6 +102,21 @@ DoubleMLPLR = R6Class("DoubleMLPLR", #' [`GraphLearner`][mlr3pipelines::GraphLearner]. \cr #' `ml_m` refers to the nuisance function \eqn{m_0(X) = E[D|X]}. 
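A short sketch of the renamed callable-score interface for `DoubleMLPLR` documented below (not part of the patch; it reuses the rpart learners and data generator from the package examples and simply reproduces the built-in `"partialling out"` score under the new argument names `l_hat`, `m_hat`, `g_hat`):
library(DoubleML)
library(mlr3)
set.seed(2)
ml_l = lrn("regr.rpart")
ml_m = ml_l$clone()
obj_dml_data = make_plr_CCDDHNR2018(alpha = 0.5)
# Custom score with the new signature; g_hat is NULL when no ml_g is supplied.
score_po = function(y, d, l_hat, m_hat, g_hat, smpls) {
  u_hat = y - l_hat  # residual w.r.t. l_0(X) = E[Y|X]
  v_hat = d - m_hat  # residual w.r.t. m_0(X) = E[D|X]
  list(psi_a = -v_hat * v_hat, psi_b = v_hat * u_hat)
}
dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_l, ml_m, score = score_po)
dml_plr_obj$fit()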
#' + #' @param ml_g ([`LearnerRegr`][mlr3::LearnerRegr], + #' [`Learner`][mlr3::Learner], `character(1)`) \cr + #' A learner of the class [`LearnerRegr`][mlr3::LearnerRegr], which is + #' available from [mlr3](https://mlr3.mlr-org.com/index.html) or its + #' extension packages [mlr3learners](https://mlr3learners.mlr-org.com/) or + #' [mlr3extralearners](https://mlr3extralearners.mlr-org.com/). + #' Alternatively, a [`Learner`][mlr3::Learner] object with public field + #' `task_type = "regr"` can be passed, for example of class + #' [`GraphLearner`][mlr3pipelines::GraphLearner]. The learner can possibly + #' be passed with specified parameters, for example + #' `lrn("regr.cv_glmnet", s = "lambda.min")`. \cr + #' `ml_g` refers to the nuisance function \eqn{g_0(X) = E[Y - D\theta_0|X]}. + #' Note: The learner `ml_g` is only required for the score `'IV-type'`. + #' Optionally, it can be specified and estimated for callable scores. + #' #' @param n_folds (`integer(1)`)\cr #' Number of folds. Default is `5`. #' @@ -109,10 +124,10 @@ DoubleMLPLR = R6Class("DoubleMLPLR", #' Number of repetitions for the sample splitting. Default is `1`. #' #' @param score (`character(1)`, `function()`) \cr - #' A `character(1)` (`"partialling out"` or `IV-type`) or a `function()` + #' A `character(1)` (`"partialling out"` or `"IV-type"`) or a `function()` #' specifying the score function. #' If a `function()` is provided, it must be of the form - #' `function(y, d, g_hat, m_hat, smpls)` and + #' `function(y, d, l_hat, m_hat, g_hat, smpls)` and #' the returned output must be a named `list()` with elements `psi_a` and #' `psi_b`. Default is `"partialling out"`. #' @@ -127,8 +142,9 @@ DoubleMLPLR = R6Class("DoubleMLPLR", #' @param apply_cross_fitting (`logical(1)`) \cr #' Indicates whether cross-fitting should be applied. Default is `TRUE`. initialize = function(data, - ml_g, + ml_l, ml_m, + ml_g = NULL, n_folds = 5, n_rep = 1, score = "partialling out", @@ -136,6 +152,19 @@ DoubleMLPLR = R6Class("DoubleMLPLR", draw_sample_splitting = TRUE, apply_cross_fitting = TRUE) { + if (missing(ml_l)) { + if (!missing(ml_g)) { + warning(paste0( + "The argument ml_g was renamed to ml_l. ", + "Please adapt the argument name accordingly. ", + "ml_g is redirected to ml_l.\n", + "The redirection will be removed in a future version."), + call. = FALSE) + ml_l = ml_g + ml_g = NULL + } + } + super$initialize_double_ml( data, n_folds, @@ -147,14 +176,193 @@ DoubleMLPLR = R6Class("DoubleMLPLR", private$check_data(self$data) private$check_score(self$score) - ml_g = private$assert_learner(ml_g, "ml_g", Regr = TRUE, Classif = FALSE) + ml_l = private$assert_learner(ml_l, "ml_l", Regr = TRUE, Classif = FALSE) ml_m = private$assert_learner(ml_m, "ml_m", Regr = TRUE, Classif = TRUE) private$learner_ = list( - "ml_g" = ml_g, + "ml_l" = ml_l, "ml_m" = ml_m) + + if (!is.null(ml_g)) { + assert( + check_character(ml_g, max.len = 1), + check_class(ml_g, "Learner")) + if ((is.character(self$score) && (self$score == "IV-type")) || + is.function(self$score)) { + ml_g = private$assert_learner(ml_g, "ml_g", + Regr = TRUE, Classif = FALSE) + private$learner_[["ml_g"]] = ml_g + } else if (is.character(self$score) && + (self$score == "partialling out")) { + warning(paste0( + "A learner ml_g has been provided for ", + "score = 'partialling out' but will be ignored. 
", + "A learner ml_g is not required for estimation.")) + } + } else if (is.character(self$score) && (self$score == "IV-type")) { + warning(paste0( + "For score = 'IV-type', learners ml_l and ml_g ", + "should be specified. ", + "Set ml_g = ml_l$clone()."), + call. = FALSE) + ml_g = private$assert_learner(ml_l$clone(), "ml_g", + Regr = TRUE, Classif = FALSE) + private$learner_[["ml_g"]] = ml_g + } + private$initialize_ml_nuisance_params() + }, + # To be removed in version 0.6.0 + # + # Note: Ideally the following duplicate roxygen / docu parts should be taken + # from the base class DoubleML. However, this is an open issue in pkg + # roxygen2, see https://github.com/r-lib/roxygen2/issues/996 & + # https://github.com/r-lib/roxygen2/issues/1043 + # + #' @description + #' Set hyperparameters for the nuisance models of DoubleML models. + #' + #' Note that in the current implementation, either all parameters have to + #' be set globally or all parameters have to be provided fold-specific. + #' + #' @param learner (`character(1)`) \cr + #' The nuisance model/learner (see method `params_names`). + #' + #' @param treat_var (`character(1)`) \cr + #' The treatment varaible (hyperparameters can be set treatment-variable + #' specific). + #' + #' @param params (named `list()`) \cr + #' A named `list()` with estimator parameters. Parameters are used for all + #' folds by default. Alternatively, parameters can be passed in a + #' fold-specific way if option `fold_specific`is `TRUE`. In this case, the + #' outer list needs to be of length `n_rep` and the inner list of length + #' `n_folds`. + #' + #' @param set_fold_specific (`logical(1)`) \cr + #' Indicates if the parameters passed in `params` should be passed in + #' fold-specific way. Default is `FALSE`. If `TRUE`, the outer list needs + #' to be of length `n_rep` and the inner list of length `n_folds`. + #' Note that in the current implementation, either all parameters have to + #' be set globally or all parameters have to be provided fold-specific. + #' + #' @return self + set_ml_nuisance_params = function(learner = NULL, treat_var = NULL, params, + set_fold_specific = FALSE) { + assert_character(learner, len = 1) + if (is.character(self$score) && (self$score == "partialling out") && + (learner == "ml_g")) { + warning(paste0( + "Learner ml_g was renamed to ml_l. ", + "Please adapt the argument learner accordingly. ", + "The provided parameters are set for ml_l. ", + "The redirection will be removed in a future version."), + call. = FALSE) + learner = "ml_l" + } + super$set_ml_nuisance_params( + learner, treat_var, params, + set_fold_specific) + }, + # To be removed in version 0.6.0 + # + # Note: Ideally the following duplicate roxygen / docu parts should be taken + # from the base class DoubleML. However, this is an open issue in pkg + # roxygen2, see https://github.com/r-lib/roxygen2/issues/996 & + # https://github.com/r-lib/roxygen2/issues/1043 + # + #' @description + #' Hyperparameter-tuning for DoubleML models. + #' + #' The hyperparameter-tuning is performed using the tuning methods provided + #' in the [mlr3tuning](https://mlr3tuning.mlr-org.com/) package. For more + #' information on tuning in [mlr3](https://mlr3.mlr-org.com/), we refer to + #' the section on parameter tuning in the + #' [mlr3 book](https://mlr3book.mlr-org.com/optimization.html#tuning). + #' + #' @param param_set (named `list()`) \cr + #' A named `list` with a parameter grid for each nuisance model/learner + #' (see method `learner_names()`). 
The parameter grid must be an object of + #' class [ParamSet][paradox::ParamSet]. + #' + #' @param tune_settings (named `list()`) \cr + #' A named `list()` with arguments passed to the hyperparameter-tuning with + #' [mlr3tuning](https://mlr3tuning.mlr-org.com/) to set up + #' [TuningInstance][mlr3tuning::TuningInstanceSingleCrit] objects. + #' `tune_settings` has entries + #' * `terminator` ([Terminator][bbotk::Terminator]) \cr + #' A [Terminator][bbotk::Terminator] object. Specification of `terminator` + #' is required to perform tuning. + #' * `algorithm` ([Tuner][mlr3tuning::Tuner] or `character(1)`) \cr + #' A [Tuner][mlr3tuning::Tuner] object (recommended) or key passed to the + #' respective dictionary to specify the tuning algorithm used in + #' [tnr()][mlr3tuning::tnr()]. `algorithm` is passed as an argument to + #' [tnr()][mlr3tuning::tnr()]. If `algorithm` is not specified by the user, + #' default is set to `"grid_search"`. If set to `"grid_search"`, then + #' additional argument `"resolution"` is required. + #' * `rsmp_tune` ([Resampling][mlr3::Resampling] or `character(1)`)\cr + #' A [Resampling][mlr3::Resampling] object (recommended) or option passed + #' to [rsmp()][mlr3::mlr_sugar] to initialize a + #' [Resampling][mlr3::Resampling] for parameter tuning in `mlr3`. + #' If not specified by the user, default is set to `"cv"` + #' (cross-validation). + #' * `n_folds_tune` (`integer(1)`, optional) \cr + #' If `rsmp_tune = "cv"`, number of folds used for cross-validation. + #' If not specified by the user, default is set to `5`. + #' * `measure` (`NULL`, named `list()`, optional) \cr + #' Named list containing the measures used for parameter tuning. Entries in + #' list must either be [Measure][mlr3::Measure] objects or keys to be + #' passed to [msr()][mlr3::msr()]. The names of the entries must + #' match the learner names (see method `learner_names()`). If set to `NULL`, + #' default measures are used, i.e., `"regr.mse"` for continuous outcome + #' variables and `"classif.ce"` for binary outcomes. + #' * `resolution` (`character(1)`) \cr The key passed to the respective + #' dictionary to specify the tuning algorithm used in + #' [tnr()][mlr3tuning::tnr()]. `resolution` is passed as an argument to + #' [tnr()][mlr3tuning::tnr()]. + #' + #' @param tune_on_folds (`logical(1)`) \cr + #' Indicates whether the tuning should be done fold-specific or globally. + #' Default is `FALSE`. + #' + #' @return self + tune = function(param_set, tune_settings = list( + n_folds_tune = 5, + rsmp_tune = mlr3::rsmp("cv", folds = 5), + measure = NULL, + terminator = mlr3tuning::trm("evals", n_evals = 20), + algorithm = mlr3tuning::tnr("grid_search"), + resolution = 5), + tune_on_folds = FALSE) { + assert_list(param_set) + if (is.character(self$score) && (self$score == "partialling out")) { + if (exists("ml_g", where = param_set) && !exists("ml_l", where = param_set)) { + warning(paste0( + "Learner ml_g was renamed to ml_l. ", + "Please adapt the name in param_set accordingly. ", + "The provided param_set for ml_g is used for ml_l. ", + "The redirection will be removed in a future version."), + call.
= FALSE) + names(param_set)[names(param_set) == "ml_g"] = "ml_l" + } + } + + assert_list(tune_settings) + if (test_names(names(tune_settings), must.include = "measure") && !is.null(tune_settings$measure)) { + assert_list(tune_settings$measure) + if (exists("ml_g", where = tune_settings$measure) && !exists("ml_l", where = tune_settings$measure)) { + warning(paste0( + "Learner ml_g was renamed to ml_l. ", + "Please adapt the name in tune_settings$measure accordingly. ", + "The provided tune_settings$measure for ml_g is used for ml_l. ", + "The redirection will be removed in a future version."), + call. = FALSE) + names(tune_settings$measure)[names(tune_settings$measure) == "ml_g"] = "ml_l" + } + } + + super$tune(param_set, tune_settings, tune_on_folds) } ), private = list( @@ -163,22 +371,25 @@ DoubleMLPLR = R6Class("DoubleMLPLR", nuisance = vector("list", self$data$n_treat) names(nuisance) = self$data$d_cols private$params_ = list( - "ml_g" = nuisance, + "ml_l" = nuisance, "ml_m" = nuisance) + if (exists("ml_g", where = private$learner_)) { + private$params_[["ml_g"]] = nuisance + } invisible(self) }, ml_nuisance_and_score_elements = function(smpls, ...) { - g_hat = dml_cv_predict(self$learner$ml_g, + l_hat = dml_cv_predict(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), self$data$y_col, self$data$data_model, - nuisance_id = "nuis_g", + nuisance_id = "nuis_l", smpls = smpls, - est_params = self$get_params("ml_g"), + est_params = self$get_params("ml_l"), return_train_preds = FALSE, - task_type = private$task_type$ml_g, + task_type = private$task_type$ml_l, fold_specific_params = private$fold_specific_params) m_hat = dml_cv_predict(self$learner$ml_m, @@ -195,34 +406,62 @@ DoubleMLPLR = R6Class("DoubleMLPLR", d = self$data$data_model[[self$data$treat_col]] y = self$data$data_model[[self$data$y_col]] - res = private$score_elements(y, d, g_hat, m_hat, smpls) + g_hat = NULL + if (exists("ml_g", where = private$learner_)) { + # get an initial estimate for theta using the partialling out score + psi_a = -(d - m_hat) * (d - m_hat) + psi_b = (d - m_hat) * (y - l_hat) + theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE) + + data_aux = data.table(self$data$data_model, + "y_minus_theta_d" = y - theta_initial * d) + + g_hat = dml_cv_predict(self$learner$ml_g, + c(self$data$x_cols, self$data$other_treat_cols), + "y_minus_theta_d", + data_aux, + nuisance_id = "nuis_g", + smpls = smpls, + est_params = self$get_params("ml_g"), + return_train_preds = FALSE, + task_type = private$task_type$ml_g, + fold_specific_params = private$fold_specific_params) + } + + res = private$score_elements(y, d, l_hat, m_hat, g_hat, smpls) res$preds = list( - "ml_g" = g_hat, - "ml_m" = m_hat) + "ml_l" = l_hat, + "ml_m" = m_hat, + "ml_g" = g_hat) return(res) }, - score_elements = function(y, d, g_hat, m_hat, smpls) { + score_elements = function(y, d, l_hat, m_hat, g_hat, smpls) { v_hat = d - m_hat - u_hat = y - g_hat + u_hat = y - l_hat v_hatd = v_hat * d if (is.character(self$score)) { if (self$score == "IV-type") { psi_a = -v_hatd + psi_b = v_hat * (y - g_hat) } else if (self$score == "partialling out") { psi_a = -v_hat * v_hat + psi_b = v_hat * u_hat } - psi_b = v_hat * u_hat psis = list( psi_a = psi_a, psi_b = psi_b) } else if (is.function(self$score)) { - psis = self$score(y, d, g_hat, m_hat, smpls) + psis = self$score( + y = y, d = d, + l_hat = l_hat, m_hat = m_hat, g_hat = g_hat, + smpls = smpls) } return(psis) }, ml_nuisance_tuning = function(smpls, param_set, tune_settings, 
tune_on_folds, ...) { + if (!tune_on_folds) { data_tune_list = list(self$data$data_model) } else { @@ -231,13 +470,13 @@ DoubleMLPLR = R6Class("DoubleMLPLR", }) } - tuning_result_g = dml_tune(self$learner$ml_g, + tuning_result_l = dml_tune(self$learner$ml_l, c(self$data$x_cols, self$data$other_treat_cols), self$data$y_col, data_tune_list, - nuisance_id = "nuis_g", - param_set$ml_g, tune_settings, - tune_settings$measure$ml_g, - private$task_type$ml_g) + nuisance_id = "nuis_l", + param_set$ml_l, tune_settings, + tune_settings$measure$ml_l, + private$task_type$ml_l) tuning_result_m = dml_tune(self$learner$ml_m, c(self$data$x_cols, self$data$other_treat_cols), @@ -247,9 +486,71 @@ DoubleMLPLR = R6Class("DoubleMLPLR", tune_settings$measure$ml_m, private$task_type$ml_m) - tuning_result = list( - "ml_g" = list(tuning_result_g, params = tuning_result_g$params), - "ml_m" = list(tuning_result_m, params = tuning_result_m$params)) + if (exists("ml_g", where = private$learner_)) { + if (tune_on_folds) { + params_l = tuning_result_l$params + params_m = tuning_result_m$params + } else { + params_l = tuning_result_l$params[[1]] + params_m = tuning_result_m$params[[1]] + } + l_hat = dml_cv_predict(self$learner$ml_l, + c(self$data$x_cols, self$data$other_treat_cols), + self$data$y_col, + self$data$data_model, + nuisance_id = "nuis_l", + smpls = smpls, + est_params = params_l, + return_train_preds = FALSE, + task_type = private$task_type$ml_l, + fold_specific_params = private$fold_specific_params) + + m_hat = dml_cv_predict(self$learner$ml_m, + c(self$data$x_cols, self$data$other_treat_cols), + self$data$treat_col, + self$data$data_model, + nuisance_id = "nuis_m", + smpls = smpls, + est_params = params_m, + return_train_preds = FALSE, + task_type = private$task_type$ml_m, + fold_specific_params = private$fold_specific_params) + + d = self$data$data_model[[self$data$treat_col]] + y = self$data$data_model[[self$data$y_col]] + + psi_a = -(d - m_hat) * (d - m_hat) + psi_b = (d - m_hat) * (y - l_hat) + theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE) + + data_aux = data.table(self$data$data_model, + "y_minus_theta_d" = y - theta_initial * d) + + if (!tune_on_folds) { + data_aux_tune_list = list(data_aux) + } else { + data_aux_tune_list = lapply(smpls$train_ids, function(x) { + extract_training_data(data_aux, x) + }) + } + + tuning_result_g = dml_tune(self$learner$ml_g, + c(self$data$x_cols, self$data$other_treat_cols), + "y_minus_theta_d", data_aux_tune_list, + nuisance_id = "nuis_g", + param_set$ml_g, tune_settings, + tune_settings$measure$ml_g, + private$task_type$ml_g) + tuning_result = list( + "ml_l" = list(tuning_result_l, params = tuning_result_l$params), + "ml_m" = list(tuning_result_m, params = tuning_result_m$params), + "ml_g" = list(tuning_result_g, params = tuning_result_g$params)) + } else { + tuning_result = list( + "ml_l" = list(tuning_result_l, params = tuning_result_l$params), + "ml_m" = list(tuning_result_m, params = tuning_result_m$params)) + } + return(tuning_result) }, check_score = function(score) { diff --git a/man/DoubleMLPLIV.Rd b/man/DoubleMLPLIV.Rd index ba3f31bc..faf6add0 100644 --- a/man/DoubleMLPLIV.Rd +++ b/man/DoubleMLPLIV.Rd @@ -25,11 +25,11 @@ library(mlr3) library(mlr3learners) library(data.table) set.seed(2) -ml_g = lrn("regr.ranger", num.trees = 100, mtry = 20, min.node.size = 2, max.depth = 5) -ml_m = ml_g$clone() -ml_r = ml_g$clone() +ml_l = lrn("regr.ranger", num.trees = 100, mtry = 20, min.node.size = 2, max.depth = 5) +ml_m = ml_l$clone() +ml_r = 
ml_l$clone() obj_dml_data = make_pliv_CHS2015(alpha = 1, n_obs = 500, dim_x = 20, dim_z = 1) -dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_g, ml_m, ml_r) +dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_l, ml_m, ml_r) dml_pliv_obj$fit() dml_pliv_obj$summary() } @@ -41,15 +41,15 @@ library(mlr3learners) library(mlr3tuning) library(data.table) set.seed(2) -ml_g = lrn("regr.rpart") -ml_m = ml_g$clone() -ml_r = ml_g$clone() +ml_l = lrn("regr.rpart") +ml_m = ml_l$clone() +ml_r = ml_l$clone() obj_dml_data = make_pliv_CHS2015( alpha = 1, n_obs = 500, dim_x = 20, dim_z = 1) -dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_g, ml_m, ml_r) +dml_pliv_obj = DoubleMLPLIV$new(obj_dml_data, ml_l, ml_m, ml_r) param_grid = list( - "ml_g" = paradox::ParamSet$new(list( + "ml_l" = paradox::ParamSet$new(list( paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), "ml_m" = paradox::ParamSet$new(list( @@ -94,6 +94,8 @@ Indicates whether instruments \eqn{Z} should be partialled out.} \subsection{Public methods}{ \itemize{ \item \href{#method-new}{\code{DoubleMLPLIV$new()}} +\item \href{#method-set_ml_nuisance_params}{\code{DoubleMLPLIV$set_ml_nuisance_params()}} +\item \href{#method-tune}{\code{DoubleMLPLIV$tune()}} \item \href{#method-clone}{\code{DoubleMLPLIV$clone()}} } } @@ -108,11 +110,9 @@ Indicates whether instruments \eqn{Z} should be partialled out.} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-p_adjust}{\code{DoubleML::DoubleML$p_adjust()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-params_names}{\code{DoubleML::DoubleML$params_names()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-print}{\code{DoubleML::DoubleML$print()}}\out{} -\item \out{}\href{../../DoubleML/html/DoubleML.html#method-set_ml_nuisance_params}{\code{DoubleML::DoubleML$set_ml_nuisance_params()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-set_sample_splitting}{\code{DoubleML::DoubleML$set_sample_splitting()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-split_samples}{\code{DoubleML::DoubleML$split_samples()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-summary}{\code{DoubleML::DoubleML$summary()}}\out{} -\item \out{}\href{../../DoubleML/html/DoubleML.html#method-tune}{\code{DoubleML::DoubleML$tune()}}\out{} } \out{} } @@ -124,9 +124,10 @@ Creates a new instance of this R6 class. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{DoubleMLPLIV$new( data, - ml_g, + ml_l, ml_m, ml_r, + ml_g = NULL, partialX = TRUE, partialZ = FALSE, n_folds = 5, @@ -145,7 +146,7 @@ Creates a new instance of this R6 class. The \code{DoubleMLData} object providing the data and specifying the variables of the causal model.} -\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, +\item{\code{ml_l}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, \code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its @@ -156,7 +157,7 @@ Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field \code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly be passed with specified parameters, for example \code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr -\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y|X]}.} +\code{ml_l} refers to the nuisance function \eqn{l_0(X) = E[Y|X]}.} \item{\code{ml_m}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, \code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr @@ -184,6 +185,21 @@ be passed with specified parameters, for example \code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr \code{ml_r} refers to the nuisance function \eqn{r_0(X) = E[D|X]}.} +\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, +\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr +A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is +available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its +extension packages \href{https://mlr3learners.mlr-org.com/}{mlr3learners} or +\href{https://mlr3extralearners.mlr-org.com/}{mlr3extralearners}. +Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field +\code{task_type = "regr"} can be passed, for example of class +\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly +be passed with specified parameters, for example +\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr +\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y - D\theta_0|X]}. +Note: The learner \code{ml_g} is only required for the score \code{'IV-type'}. +Optionally, it can be specified and estimated for callable scores.} + \item{\code{partialX}}{(\code{logical(1)}) \cr Indicates whether covariates \eqn{X} should be partialled out. Default is \code{TRUE}.} @@ -199,10 +215,10 @@ Number of folds. Default is \code{5}.} Number of repetitions for the sample splitting. Default is \code{1}.} \item{\code{score}}{(\code{character(1)}, \verb{function()}) \cr -A \code{character(1)} (\code{"partialling out"} is the only choice) or a -\verb{function()} specifying the score function. +A \code{character(1)} (\code{"partialling out"} or \code{"IV-type"}) or a \verb{function()} +specifying the score function. If a \verb{function()} is provided, it must be of the form -\verb{function(y, z, d, g_hat, m_hat, r_hat, smpls)} and +\verb{function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls)} and the returned output must be a named \code{list()} with elements \code{psi_a} and \code{psi_b}. Default is \code{"partialling out"}.} @@ -221,6 +237,130 @@ Indicates whether cross-fitting should be applied. Default is \code{TRUE}.} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-set_ml_nuisance_params}{}}} +\subsection{Method \code{set_ml_nuisance_params()}}{ +Set hyperparameters for the nuisance models of DoubleML models. + +Note that in the current implementation, either all parameters have to +be set globally or all parameters have to be provided fold-specific. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DoubleMLPLIV$set_ml_nuisance_params( + learner = NULL, + treat_var = NULL, + params, + set_fold_specific = FALSE +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{learner}}{(\code{character(1)}) \cr +The nuisance model/learner (see method \code{params_names}).} + +\item{\code{treat_var}}{(\code{character(1)}) \cr +The treatment variable (hyperparameters can be set treatment-variable +specific).} + +\item{\code{params}}{(named \code{list()}) \cr +A named \code{list()} with estimator parameters. Parameters are used for all +folds by default. Alternatively, parameters can be passed in a +fold-specific way if option \code{fold_specific} is \code{TRUE}. In this case, the +outer list needs to be of length \code{n_rep} and the inner list of length +\code{n_folds}.} + +\item{\code{set_fold_specific}}{(\code{logical(1)}) \cr +Indicates if the parameters passed in \code{params} should be passed in +fold-specific way. Default is \code{FALSE}. If \code{TRUE}, the outer list needs +to be of length \code{n_rep} and the inner list of length \code{n_folds}. +Note that in the current implementation, either all parameters have to +be set globally or all parameters have to be provided fold-specific.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +self +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-tune}{}}} +\subsection{Method \code{tune()}}{ +Hyperparameter-tuning for DoubleML models. + +The hyperparameter-tuning is performed using the tuning methods provided +in the \href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} package. For more +information on tuning in \href{https://mlr3.mlr-org.com/}{mlr3}, we refer to +the section on parameter tuning in the +\href{https://mlr3book.mlr-org.com/optimization.html#tuning}{mlr3 book}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DoubleMLPLIV$tune( + param_set, + tune_settings = list(n_folds_tune = 5, rsmp_tune = mlr3::rsmp("cv", folds = 5), + measure = NULL, terminator = mlr3tuning::trm("evals", n_evals = 20), algorithm = + mlr3tuning::tnr("grid_search"), resolution = 5), + tune_on_folds = FALSE +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{param_set}}{(named \code{list()}) \cr +A named \code{list} with a parameter grid for each nuisance model/learner +(see method \code{learner_names()}). The parameter grid must be an object of +class \link[paradox:ParamSet]{ParamSet}.} + +\item{\code{tune_settings}}{(named \code{list()}) \cr +A named \code{list()} with arguments passed to the hyperparameter-tuning with +\href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} to set up +\link[mlr3tuning:TuningInstanceSingleCrit]{TuningInstance} objects. +\code{tune_settings} has entries +\itemize{ +\item \code{terminator} (\link[bbotk:Terminator]{Terminator}) \cr +A \link[bbotk:Terminator]{Terminator} object. Specification of \code{terminator} +is required to perform tuning. +\item \code{algorithm} (\link[mlr3tuning:Tuner]{Tuner} or \code{character(1)}) \cr +A \link[mlr3tuning:Tuner]{Tuner} object (recommended) or key passed to the +respective dictionary to specify the tuning algorithm used in +\link[mlr3tuning:tnr]{tnr()}. \code{algorithm} is passed as an argument to +\link[mlr3tuning:tnr]{tnr()}. If \code{algorithm} is not specified by the user, +default is set to \code{"grid_search"}. If set to \code{"grid_search"}, then +additional argument \code{"resolution"} is required. +\item \code{rsmp_tune} (\link[mlr3:Resampling]{Resampling} or \code{character(1)})\cr +A \link[mlr3:Resampling]{Resampling} object (recommended) or option passed +to \link[mlr3:mlr_sugar]{rsmp()} to initialize a +\link[mlr3:Resampling]{Resampling} for parameter tuning in \code{mlr3}. +If not specified by the user, default is set to \code{"cv"} +(cross-validation). +\item \code{n_folds_tune} (\code{integer(1)}, optional) \cr +If \code{rsmp_tune = "cv"}, number of folds used for cross-validation. +If not specified by the user, default is set to \code{5}. +\item \code{measure} (\code{NULL}, named \code{list()}, optional) \cr +Named list containing the measures used for parameter tuning. Entries in +list must either be \link[mlr3:Measure]{Measure} objects or keys to be +passed to \link[mlr3:mlr_sugar]{msr()}. The names of the entries must +match the learner names (see method \code{learner_names()}). If set to \code{NULL}, +default measures are used, i.e., \code{"regr.mse"} for continuous outcome +variables and \code{"classif.ce"} for binary outcomes. +\item \code{resolution} (\code{character(1)}) \cr The key passed to the respective +dictionary to specify the tuning algorithm used in +\link[mlr3tuning:tnr]{tnr()}. \code{resolution} is passed as an argument to +\link[mlr3tuning:tnr]{tnr()}. +}} + +\item{\code{tune_on_folds}}{(\code{logical(1)}) \cr +Indicates whether the tuning should be done fold-specific or globally. +Default is \code{FALSE}.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +self +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/man/DoubleMLPLR.Rd b/man/DoubleMLPLR.Rd index 1fa63ff1..8aa46212 100644 --- a/man/DoubleMLPLR.Rd +++ b/man/DoubleMLPLR.Rd @@ -43,13 +43,13 @@ library(mlr3learners) library(mlr3tuning) library(data.table) set.seed(2) -ml_g = lrn("regr.rpart") -ml_m = ml_g$clone() +ml_l = lrn("regr.rpart") +ml_m = ml_l$clone() obj_dml_data = make_plr_CCDDHNR2018(alpha = 0.5) -dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m) +dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_l, ml_m) param_grid = list( - "ml_g" = paradox::ParamSet$new(list( + "ml_l" = paradox::ParamSet$new(list( paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), "ml_m" = paradox::ParamSet$new(list( @@ -80,6 +80,8 @@ Other DoubleML: \subsection{Public methods}{ \itemize{ \item \href{#method-new}{\code{DoubleMLPLR$new()}} +\item \href{#method-set_ml_nuisance_params}{\code{DoubleMLPLR$set_ml_nuisance_params()}} +\item \href{#method-tune}{\code{DoubleMLPLR$tune()}} \item \href{#method-clone}{\code{DoubleMLPLR$clone()}} } } @@ -94,11 +96,9 @@ Other DoubleML: \item \out{}\href{../../DoubleML/html/DoubleML.html#method-p_adjust}{\code{DoubleML::DoubleML$p_adjust()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-params_names}{\code{DoubleML::DoubleML$params_names()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-print}{\code{DoubleML::DoubleML$print()}}\out{} -\item \out{}\href{../../DoubleML/html/DoubleML.html#method-set_ml_nuisance_params}{\code{DoubleML::DoubleML$set_ml_nuisance_params()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-set_sample_splitting}{\code{DoubleML::DoubleML$set_sample_splitting()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-split_samples}{\code{DoubleML::DoubleML$split_samples()}}\out{} \item \out{}\href{../../DoubleML/html/DoubleML.html#method-summary}{\code{DoubleML::DoubleML$summary()}}\out{} -\item \out{}\href{../../DoubleML/html/DoubleML.html#method-tune}{\code{DoubleML::DoubleML$tune()}}\out{} } \out{} } @@ -110,8 +110,9 @@ Creates a new instance of this R6 class. \subsection{Usage}{ \if{html}{\out{
}}\preformatted{DoubleMLPLR$new( data, - ml_g, + ml_l, ml_m, + ml_g = NULL, n_folds = 5, n_rep = 1, score = "partialling out", @@ -128,7 +129,7 @@ Creates a new instance of this R6 class. The \code{DoubleMLData} object providing the data and specifying the variables of the causal model.} -\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, +\item{\code{ml_l}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, \code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its @@ -139,7 +140,7 @@ Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field \code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly be passed with specified parameters, for example \code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr -\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y|X]}.} +\code{ml_l} refers to the nuisance function \eqn{l_0(X) = E[Y|X]}.} \item{\code{ml_m}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, \code{\link[mlr3:LearnerClassif]{LearnerClassif}}, \code{\link[mlr3:Learner]{Learner}}, @@ -157,6 +158,21 @@ respectively, for example of class \code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. \cr \code{ml_m} refers to the nuisance function \eqn{m_0(X) = E[D|X]}.} +\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}}, +\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr +A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is +available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its +extension packages \href{https://mlr3learners.mlr-org.com/}{mlr3learners} or +\href{https://mlr3extralearners.mlr-org.com/}{mlr3extralearners}. +Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field +\code{task_type = "regr"} can be passed, for example of class +\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly +be passed with specified parameters, for example +\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr +\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y - D\theta_0|X]}. +Note: The learner \code{ml_g} is only required for the score \code{'IV-type'}. +Optionally, it can be specified and estimated for callable scores.} + \item{\code{n_folds}}{(\code{integer(1)})\cr Number of folds. Default is \code{5}.} @@ -164,10 +180,10 @@ Number of folds. Default is \code{5}.} Number of repetitions for the sample splitting. Default is \code{1}.} \item{\code{score}}{(\code{character(1)}, \verb{function()}) \cr -A \code{character(1)} (\code{"partialling out"} or \code{IV-type}) or a \verb{function()} +A \code{character(1)} (\code{"partialling out"} or \code{"IV-type"}) or a \verb{function()} specifying the score function. If a \verb{function()} is provided, it must be of the form -\verb{function(y, d, g_hat, m_hat, smpls)} and +\verb{function(y, d, l_hat, m_hat, g_hat, smpls)} and the returned output must be a named \code{list()} with elements \code{psi_a} and \code{psi_b}. Default is \code{"partialling out"}.} @@ -186,6 +202,130 @@ Indicates whether cross-fitting should be applied. Default is \code{TRUE}.} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-set_ml_nuisance_params}{}}} +\subsection{Method \code{set_ml_nuisance_params()}}{ +Set hyperparameters for the nuisance models of DoubleML models. + +Note that in the current implementation, either all parameters have to +be set globally or all parameters have to be provided fold-specific. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DoubleMLPLR$set_ml_nuisance_params( + learner = NULL, + treat_var = NULL, + params, + set_fold_specific = FALSE +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}}
+\describe{
+\item{\code{learner}}{(\code{character(1)}) \cr
+The nuisance model/learner (see method \code{params_names()}).}
+
+\item{\code{treat_var}}{(\code{character(1)}) \cr
+The treatment variable (hyperparameters can be set treatment-variable
+specific).}
+
+\item{\code{params}}{(named \code{list()}) \cr
+A named \code{list()} with estimator parameters. Parameters are used for all
+folds by default. Alternatively, parameters can be passed in a
+fold-specific way if option \code{set_fold_specific} is \code{TRUE}. In this
+case, the outer list needs to be of length \code{n_rep} and the inner list of
+length \code{n_folds}.}
+
+\item{\code{set_fold_specific}}{(\code{logical(1)}) \cr
+Indicates whether the parameters passed in \code{params} should be applied in
+a fold-specific way. Default is \code{FALSE}. If \code{TRUE}, the outer list
+needs to be of length \code{n_rep} and the inner list of length \code{n_folds}.
+Note that in the current implementation, either all parameters have to
+be set globally or all parameters have to be provided fold-specific.}
+}
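A short sketch of the two modes (the ranger learners and num.trees values are illustrative; note that global and fold-specific settings cannot be mixed on the same object):

    library(DoubleML)
    library(mlr3)
    library(mlr3learners)

    set.seed(3141)
    dml_data = make_plr_CCDDHNR2018(alpha = 0.5, n_obs = 500)
    n_folds = 3
    n_rep = 2

    # Global mode: one parameter set per learner, reused across all folds.
    dml_plr = DoubleMLPLR$new(dml_data,
      ml_l = lrn("regr.ranger"), ml_m = lrn("regr.ranger"),
      n_folds = n_folds, n_rep = n_rep)
    dml_plr$set_ml_nuisance_params("ml_l", "d", list(num.trees = 50))
    dml_plr$set_ml_nuisance_params("ml_m", "d", list(num.trees = 100))
    dml_plr$fit()

    # Fold-specific mode: outer list of length n_rep, inner lists of length
    # n_folds; all learners must then be set fold-specific.
    params_fold_wise = rep(list(rep(list(list(num.trees = 50)), n_folds)), n_rep)
    dml_plr2 = DoubleMLPLR$new(dml_data,
      ml_l = lrn("regr.ranger"), ml_m = lrn("regr.ranger"),
      n_folds = n_folds, n_rep = n_rep)
    dml_plr2$set_ml_nuisance_params("ml_l", "d", params_fold_wise,
      set_fold_specific = TRUE)
    dml_plr2$set_ml_nuisance_params("ml_m", "d", params_fold_wise,
      set_fold_specific = TRUE)
    dml_plr2$fit()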
}} +} +\subsection{Returns}{ +self +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-tune}{}}} +\subsection{Method \code{tune()}}{ +Hyperparameter-tuning for DoubleML models. + +The hyperparameter-tuning is performed using the tuning methods provided +in the \href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} package. For more +information on tuning in \href{https://mlr3.mlr-org.com/}{mlr3}, we refer to +the section on parameter tuning in the +\href{https://mlr3book.mlr-org.com/optimization.html#tuning}{mlr3 book}. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{DoubleMLPLR$tune( + param_set, + tune_settings = list(n_folds_tune = 5, rsmp_tune = mlr3::rsmp("cv", folds = 5), + measure = NULL, terminator = mlr3tuning::trm("evals", n_evals = 20), algorithm = + mlr3tuning::tnr("grid_search"), resolution = 5), + tune_on_folds = FALSE +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}}
+\describe{
+\item{\code{param_set}}{(named \code{list()}) \cr
+A named \code{list} with a parameter grid for each nuisance model/learner
+(see method \code{learner_names()}). The parameter grid must be an object of
+class \link[paradox:ParamSet]{ParamSet}.}
+
+\item{\code{tune_settings}}{(named \code{list()}) \cr
+A named \code{list()} with arguments passed to the hyperparameter-tuning with
+\href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} to set up
+\link[mlr3tuning:TuningInstanceSingleCrit]{TuningInstance} objects.
+\code{tune_settings} has entries
+\itemize{
+\item \code{terminator} (\link[bbotk:Terminator]{Terminator}) \cr
+A \link[bbotk:Terminator]{Terminator} object. Specification of \code{terminator}
+is required to perform tuning.
+\item \code{algorithm} (\link[mlr3tuning:Tuner]{Tuner} or \code{character(1)}) \cr
+A \link[mlr3tuning:Tuner]{Tuner} object (recommended) or key passed to the
+respective dictionary to specify the tuning algorithm used in
+\link[mlr3tuning:tnr]{tnr()}. \code{algorithm} is passed as an argument to
+\link[mlr3tuning:tnr]{tnr()}. If \code{algorithm} is not specified by the user,
+the default is \code{"grid_search"}. If set to \code{"grid_search"}, the
+additional argument \code{"resolution"} is required.
+\item \code{rsmp_tune} (\link[mlr3:Resampling]{Resampling} or \code{character(1)}) \cr
+A \link[mlr3:Resampling]{Resampling} object (recommended) or option passed
+to \link[mlr3:mlr_sugar]{rsmp()} to initialize a
+\link[mlr3:Resampling]{Resampling} for parameter tuning in \code{mlr3}.
+If not specified by the user, the default is \code{"cv"}
+(cross-validation).
+\item \code{n_folds_tune} (\code{integer(1)}, optional) \cr
+If \code{rsmp_tune = "cv"}, the number of folds used for cross-validation.
+If not specified by the user, the default is \code{5}.
+\item \code{measure} (\code{NULL}, named \code{list()}, optional) \cr
+A named list containing the measures used for parameter tuning. Entries in
+the list must either be \link[mlr3:Measure]{Measure} objects or keys
+passed to \link[mlr3:mlr_sugar]{msr()}. The names of the entries must
+match the learner names (see method \code{learner_names()}). If set to \code{NULL},
+default measures are used, i.e., \code{"regr.mse"} for continuous outcome
+variables and \code{"classif.ce"} for binary outcomes.
+\item \code{resolution} (\code{integer(1)}) \cr
+The resolution of the parameter grid (only relevant if
+\code{algorithm = "grid_search"}). \code{resolution} is passed as an argument
+to \link[mlr3tuning:tnr]{tnr()}.
+}}
+
+\item{\code{tune_on_folds}}{(\code{logical(1)}) \cr
+Indicates whether the tuning should be done fold-specific or globally.
+Default is \code{FALSE}.}
+}
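For the 'IV-type' score, ml_g is a further nuisance learner and, under the assumption that tune() expects one ParamSet per registered learner (see learner_names()), it needs its own grid entry. A sketch with illustrative rpart grids:

    library(DoubleML)
    library(mlr3)
    library(mlr3tuning)

    set.seed(2)
    dml_data = make_plr_CCDDHNR2018(alpha = 0.5)
    dml_plr = DoubleMLPLR$new(dml_data,
      ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"),
      ml_g = lrn("regr.rpart"), score = "IV-type")

    # With score = 'IV-type', ml_g is an additional nuisance learner and
    # therefore gets its own entry in the parameter grid.
    param_set = list(
      "ml_l" = paradox::ParamSet$new(list(
        paradox::ParamInt$new("minsplit", lower = 10, upper = 30))),
      "ml_m" = paradox::ParamSet$new(list(
        paradox::ParamInt$new("minsplit", lower = 10, upper = 30))),
      "ml_g" = paradox::ParamSet$new(list(
        paradox::ParamInt$new("minsplit", lower = 10, upper = 30))))

    dml_plr$tune(param_set = param_set)
    dml_plr$fit()
    dml_plr$summary()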
}} +} +\subsection{Returns}{ +self +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/tests/testthat/helper-05-ml-learner.R b/tests/testthat/helper-05-ml-learner.R index 53a64531..390e6a27 100644 --- a/tests/testthat/helper-05-ml-learner.R +++ b/tests/testthat/helper-05-ml-learner.R @@ -2,36 +2,42 @@ get_default_mlmethod_plr = function(learner, default = FALSE) { if (default == FALSE) { if (learner == "regr.lm") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, mlmethod_g = learner ) params = list( - params_g = list(), - params_m = list() + params_l = list(), + params_m = list(), + params_g = list() ) } else if (learner == "regr.ranger") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, mlmethod_g = learner ) params = list( - params_g = list(num.trees = 100), - params_m = list(num.trees = 120) + params_l = list(num.trees = 60), + params_m = list(num.trees = 120), + params_g = list(num.trees = 100) ) } else if (learner == "regr.rpart") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, mlmethod_g = learner ) params = list( - params_g = list(cp = 0.01, minsplit = 20), - params_m = list(cp = 0.01, minsplit = 20) + params_l = list(cp = 0.013, minsplit = 18), + params_m = list(cp = 0.01, minsplit = 20), + params_g = list(cp = 0.005, minsplit = 10) ) } @@ -48,11 +54,16 @@ get_default_mlmethod_plr = function(learner, default = FALSE) { # } else if (learner == "regr.cv_glmnet") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, mlmethod_g = learner ) params = list( + params_l = list( + s = "lambda.min", + family = "gaussian" + ), params_m = list( s = "lambda.min", family = "gaussian" @@ -65,12 +76,14 @@ get_default_mlmethod_plr = function(learner, default = FALSE) { else if (default == TRUE) { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, mlmethod_g = learner ) params = list( - params_g = list(), - params_m = list()) + params_l = list(), + params_m = list(), + params_g = list()) } if (learner == "graph_learner") { @@ -80,109 +93,129 @@ get_default_mlmethod_plr = function(learner, default = FALSE) { lambda = 0.01, family = "gaussian") mlmethod = list( + mlmethod_l = "graph_learner", mlmethod_m = "graph_learner", mlmethod_g = "graph_learner") params = list( params_g = list(), params_m = list()) - ml_g = mlr3::as_learner(pipe_learner) + ml_l = mlr3::as_learner(pipe_learner) ml_m = mlr3::as_learner(pipe_learner) + ml_g = mlr3::as_learner(pipe_learner) } else { - ml_g = mlr3::lrn(mlmethod$mlmethod_g) - ml_g$param_set$values = params$params_g + ml_l = mlr3::lrn(mlmethod$mlmethod_l) + ml_l$param_set$values = params$params_l ml_m = mlr3::lrn(mlmethod$mlmethod_m) ml_m$param_set$values = params$params_m + ml_g = mlr3::lrn(mlmethod$mlmethod_g) + ml_g$param_set$values = params$params_g } return(list( mlmethod = mlmethod, params = params, - ml_g = ml_g, ml_m = ml_m + ml_l = ml_l, ml_m = ml_m, ml_g = ml_g )) } get_default_mlmethod_pliv = function(learner) { if (learner == "regr.lm") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, - mlmethod_g = learner, - mlmethod_r = learner + mlmethod_r = learner, + mlmethod_g = learner ) params = list( - params_g = list(), + params_l = list(), params_m = list(), - params_r = list() + params_r = list(), + params_g = list() ) } else if (learner == "regr.ranger") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, - mlmethod_g = learner, - mlmethod_r = learner + mlmethod_r = learner, + mlmethod_g = learner ) params = list( - params_g = 
list(num.trees = 100), + params_l = list(num.trees = 100), params_m = list(num.trees = 120), - params_r = list(num.trees = 100) + params_r = list(num.trees = 100), + params_g = list(num.trees = 100) ) } else if (learner == "regr.rpart") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, - mlmethod_g = learner, - mlmethod_r = learner + mlmethod_r = learner, + mlmethod_g = learner ) params = list( - params_g = list(cp = 0.01, minsplit = 20), + params_l = list(cp = 0.01, minsplit = 20), params_m = list(cp = 0.01, minsplit = 20), - params_r = list(cp = 0.01, minsplit = 20) + params_r = list(cp = 0.01, minsplit = 20), + params_g = list(cp = 0.01, minsplit = 20) ) } else if (learner == "regr.cv_glmnet") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, - mlmethod_g = learner, - mlmethod_r = learner + mlmethod_r = learner, + mlmethod_g = learner ) params = list( - params_m = list( + params_l = list( s = "lambda.min", family = "gaussian" ), - params_g = list( + params_m = list( s = "lambda.min", family = "gaussian" ), params_r = list( s = "lambda.min", family = "gaussian" + ), + params_g = list( + s = "lambda.min", + family = "gaussian" ) ) } else if (learner == "regr.glmnet") { mlmethod = list( + mlmethod_l = learner, mlmethod_m = learner, - mlmethod_g = learner, - mlmethod_r = learner + mlmethod_r = learner, + mlmethod_g = learner ) params = list( - params_m = list( + params_l = list( lambda = 0.01, family = "gaussian" ), - params_g = list( + params_m = list( lambda = 0.01, family = "gaussian" ), params_r = list( lambda = 0.01, family = "gaussian" + ), + params_g = list( + lambda = 0.01, + family = "gaussian" ) ) @@ -195,27 +228,33 @@ get_default_mlmethod_pliv = function(learner) { lambda = 0.01, family = "gaussian") mlmethod = list( + mlmethod_l = "graph_learner", mlmethod_m = "graph_learner", - mlmethod_g = "graph_learner", - mlmethod_r = "graph_learner") + mlmethod_r = "graph_learner", + mlmethod_g = "graph_learner") params = list( - params_g = list(), - params_m = list()) - ml_g = mlr3::as_learner(pipe_learner) + params_l = list(), + params_m = list(), + params_r = list(), + params_g = list()) + ml_l = mlr3::as_learner(pipe_learner) ml_m = mlr3::as_learner(pipe_learner) ml_r = mlr3::as_learner(pipe_learner) + ml_g = mlr3::as_learner(pipe_learner) } else { - ml_g = mlr3::lrn(mlmethod$mlmethod_g) - ml_g$param_set$values = params$params_g + ml_l = mlr3::lrn(mlmethod$mlmethod_l) + ml_l$param_set$values = params$params_l ml_m = mlr3::lrn(mlmethod$mlmethod_m) ml_m$param_set$values = params$params_m ml_r = mlr3::lrn(mlmethod$mlmethod_r) ml_r$param_set$values = params$params_r + ml_g = mlr3::lrn(mlmethod$mlmethod_g) + ml_g$param_set$values = params$params_g } return(list( mlmethod = mlmethod, params = params, - ml_g = ml_g, ml_m = ml_m, ml_r = ml_r + ml_l = ml_l, ml_m = ml_m, ml_r = ml_r, ml_g = ml_g )) } diff --git a/tests/testthat/helper-08-dml_plr.R b/tests/testthat/helper-08-dml_plr.R index 6c1c94d6..305c4016 100644 --- a/tests/testthat/helper-08-dml_plr.R +++ b/tests/testthat/helper-08-dml_plr.R @@ -1,9 +1,9 @@ # Double Machine Learning for Partially Linear Regression. 
dml_plr = function(data, y, d, - n_folds, ml_g, ml_m, + n_folds, ml_l, ml_m, ml_g, dml_procedure, score, n_rep = 1, smpls = NULL, - params_g = NULL, params_m = NULL) { + params_l = NULL, params_m = NULL, params_g = NULL) { if (is.null(smpls)) { smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data)) @@ -21,10 +21,10 @@ dml_plr = function(data, y, d, res_single_split = fit_plr_single_split( data, y, d, - n_folds, ml_g, ml_m, + n_folds, ml_l, ml_m, ml_g, dml_procedure, score, this_smpl, - params_g, params_m) + params_l, params_m, params_g) all_preds[[i_rep]] = res_single_split$all_preds all_thetas[i_rep] = res_single_split$theta @@ -53,10 +53,10 @@ dml_plr = function(data, y, d, dml_plr_multitreat = function(data, y, d, - n_folds, ml_g, ml_m, + n_folds, ml_l, ml_m, ml_g, dml_procedure, score, n_rep = 1, smpls = NULL, - params_g = NULL, params_m = NULL) { + params_l = NULL, params_m = NULL, params_g = NULL) { if (is.null(smpls)) { smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data)) @@ -71,6 +71,11 @@ dml_plr_multitreat = function(data, y, d, all_preds_this_rep = list() for (i_d in seq(n_d)) { + if (!is.null(params_l)) { + this_params_l = params_l[[i_d]] + } else { + this_params_l = NULL + } if (!is.null(params_g)) { this_params_g = params_g[[i_d]] } else { @@ -83,10 +88,10 @@ dml_plr_multitreat = function(data, y, d, } res_single_split = fit_plr_single_split( data, y, d[i_d], - n_folds, ml_g, ml_m, + n_folds, ml_l, ml_m, ml_g, dml_procedure, score, this_smpl, - this_params_g, this_params_m) + this_params_l, this_params_m, this_params_g) all_preds_this_rep[[i_d]] = res_single_split$all_preds thetas_this_rep[i_d] = res_single_split$theta @@ -125,27 +130,28 @@ dml_plr_multitreat = function(data, y, d, fit_plr_single_split = function(data, y, d, - n_folds, ml_g, ml_m, + n_folds, ml_l, ml_m, ml_g, dml_procedure, score, smpl, - params_g, params_m) { + params_l, params_m, params_g) { train_ids = smpl$train_ids test_ids = smpl$test_ids + fit_g = (score == "IV-type") | is.function(score) all_preds = fit_nuisance_plr( data, y, d, - ml_g, ml_m, - smpl, - params_g, params_m) + ml_l, ml_m, ml_g, + n_folds, smpl, fit_g, + params_l, params_m, params_g) residuals = compute_plr_residuals( data, y, d, n_folds, smpl, all_preds) - u_hat = residuals$u_hat - v_hat = residuals$v_hat + y_minus_l_hat = residuals$y_minus_l_hat + d_minus_m_hat = residuals$d_minus_m_hat + y_minus_g_hat = residuals$y_minus_g_hat D = data[, d] Y = data[, y] - v_hatd = v_hat * D # DML 1 if (dml_procedure == "dml1") { @@ -154,30 +160,37 @@ fit_plr_single_split = function(data, y, d, test_index = test_ids[[i]] orth_est = orth_plr_dml( - u_hat = u_hat[test_index], v_hat = v_hat[test_index], - v_hatd = v_hatd[test_index], + y_minus_l_hat = y_minus_l_hat[test_index], + d_minus_m_hat = d_minus_m_hat[test_index], + y_minus_g_hat = y_minus_g_hat[test_index], + d = D[test_index], score = score) thetas[i] = orth_est$theta } theta = mean(thetas, na.rm = TRUE) if (length(train_ids) == 1) { D = D[test_index] - u_hat = u_hat[test_index] - v_hat = v_hat[test_index] - v_hatd = v_hatd[test_index] + y_minus_l_hat = y_minus_l_hat[test_index] + d_minus_m_hat = d_minus_m_hat[test_index] + y_minus_g_hat = y_minus_g_hat[test_index] } } if (dml_procedure == "dml2") { orth_est = orth_plr_dml( - u_hat = u_hat, v_hat = v_hat, - v_hatd = v_hatd, score = score) + y_minus_l_hat = y_minus_l_hat, + d_minus_m_hat = d_minus_m_hat, + y_minus_g_hat = y_minus_g_hat, + d = D, score = score) theta = orth_est$theta } se = sqrt(var_plr( - theta = 
theta, d = D, u_hat = u_hat, v_hat = v_hat, - v_hatd = v_hatd, score = score)) + theta = theta, d = D, + y_minus_l_hat = y_minus_l_hat, + d_minus_m_hat = d_minus_m_hat, + y_minus_g_hat = y_minus_g_hat, + score = score)) res = list( theta = theta, se = se, @@ -188,27 +201,29 @@ fit_plr_single_split = function(data, y, d, fit_nuisance_plr = function(data, y, d, - ml_g, ml_m, - smpls, - params_g, params_m) { + ml_l, ml_m, ml_g, + n_folds, smpls, fit_g, + params_l, params_m, params_g) { train_ids = smpls$train_ids test_ids = smpls$test_ids - # nuisance g - g_indx = names(data) != d - data_g = data[, g_indx, drop = FALSE] - task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y) + # nuisance l + l_indx = names(data) != d + data_l = data[, l_indx, drop = FALSE] + task_l = mlr3::TaskRegr$new( + id = paste0("nuis_l_", d), + backend = data_l, target = y) - resampling_g = mlr3::rsmp("custom") - resampling_g$instantiate(task_g, train_ids, test_ids) + resampling_l = mlr3::rsmp("custom") + resampling_l$instantiate(task_l, train_ids, test_ids) - if (!is.null(params_g)) { - ml_g$param_set$values = params_g + if (!is.null(params_l)) { + ml_l$param_set$values = params_l } - r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE) - g_hat_list = lapply(r_g$predictions(), function(x) x$response) + r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE) + l_hat_list = lapply(r_l$predictions(), function(x) x$response) # nuisance m if (!is.null(params_m)) { @@ -239,7 +254,45 @@ fit_nuisance_plr = function(data, y, d, m_hat_list = lapply(r_m$predictions(), function(x) as.data.table(x)$prob.1) } + if (fit_g) { + # nuisance g + residuals = compute_plr_residuals( + data, y, d, n_folds, + smpls, list( + l_hat_list = l_hat_list, + g_hat_list = NULL, + m_hat_list = m_hat_list)) + y_minus_l_hat = residuals$y_minus_l_hat + d_minus_m_hat = residuals$d_minus_m_hat + psi_a = -d_minus_m_hat * d_minus_m_hat + psi_b = d_minus_m_hat * y_minus_l_hat + theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE) + + D = data[, d] + Y = data[, y] + g_indx = names(data) != y & names(data) != d + y_minus_theta_d = Y - theta_initial * D + data_g = cbind(data[, g_indx, drop = FALSE], y_minus_theta_d) + + task_g = mlr3::TaskRegr$new( + id = paste0("nuis_g_", d), backend = data_g, + target = "y_minus_theta_d") + + resampling_g = mlr3::rsmp("custom") + resampling_g$instantiate(task_g, train_ids, test_ids) + + if (!is.null(params_g)) { + ml_g$param_set$values = params_g + } + + r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE) + g_hat_list = lapply(r_g$predictions(), function(x) x$response) + } else { + g_hat_list = NULL + } + all_preds = list( + l_hat_list = l_hat_list, m_hat_list = m_hat_list, g_hat_list = g_hat_list) @@ -250,42 +303,50 @@ compute_plr_residuals = function(data, y, d, n_folds, smpls, all_preds) { test_ids = smpls$test_ids - g_hat_list = all_preds$g_hat_list + l_hat_list = all_preds$l_hat_list m_hat_list = all_preds$m_hat_list + g_hat_list = all_preds$g_hat_list n = nrow(data) D = data[, d] Y = data[, y] - v_hat = u_hat = w_hat = rep(NA_real_, n) + y_minus_l_hat = d_minus_m_hat = y_minus_g_hat = rep(NA_real_, n) for (i in 1:n_folds) { test_index = test_ids[[i]] - g_hat = g_hat_list[[i]] + l_hat = l_hat_list[[i]] m_hat = m_hat_list[[i]] - u_hat[test_index] = Y[test_index] - g_hat - v_hat[test_index] = D[test_index] - m_hat + y_minus_l_hat[test_index] = Y[test_index] - l_hat + d_minus_m_hat[test_index] = D[test_index] - m_hat + + if 
(!is.null(g_hat_list)) { + g_hat = g_hat_list[[i]] + y_minus_g_hat[test_index] = Y[test_index] - g_hat + } } - residuals = list(u_hat = u_hat, v_hat = v_hat) + residuals = list( + y_minus_l_hat = y_minus_l_hat, + d_minus_m_hat = d_minus_m_hat, + y_minus_g_hat = y_minus_g_hat) return(residuals) } # Orthogonalized Estimation of Coefficient in PLR -orth_plr_dml = function(u_hat, v_hat, v_hatd, score) { +orth_plr_dml = function(y_minus_l_hat, d_minus_m_hat, y_minus_g_hat, d, score) { theta = NA_real_ if (score == "partialling out") { - res_fit = stats::lm(u_hat ~ 0 + v_hat) + res_fit = stats::lm(y_minus_l_hat ~ 0 + d_minus_m_hat) theta = stats::coef(res_fit) } else if (score == "IV-type") { - theta = mean(v_hat * u_hat) / mean(v_hatd) - # se = 1/(mean(u_hat)^2) * mean((v_hat - theta*u_hat)*u_hat)^2 + theta = mean(d_minus_m_hat * y_minus_g_hat) / mean(d_minus_m_hat * d) } else { @@ -298,15 +359,15 @@ orth_plr_dml = function(u_hat, v_hat, v_hatd, score) { # Variance estimation for DML estimator in the partially linear regression model -var_plr = function(theta, d, u_hat, v_hat, v_hatd, score) { +var_plr = function(theta, d, y_minus_l_hat, d_minus_m_hat, y_minus_g_hat, score) { n = length(d) if (score == "partialling out") { - var = 1 / n * 1 / (mean(v_hat^2))^2 * - mean(((u_hat - v_hat * theta) * v_hat)^2) + var = 1 / n * 1 / (mean(d_minus_m_hat^2))^2 * + mean(((y_minus_l_hat - d_minus_m_hat * theta) * d_minus_m_hat)^2) } else if (score == "IV-type") { - var = 1 / n * 1 / mean(v_hatd)^2 * - mean(((u_hat - d * theta) * v_hat)^2) + var = 1 / n * 1 / mean(d_minus_m_hat * d)^2 * + mean(((y_minus_g_hat - d * theta) * d_minus_m_hat)^2) } return(c(var)) } @@ -375,18 +436,19 @@ boot_plr_single_split = function(theta, se, data, y, d, residuals = compute_plr_residuals( data, y, d, n_folds, smpl, all_preds) - u_hat = residuals$u_hat - v_hat = residuals$v_hat + y_minus_l_hat = residuals$y_minus_l_hat + d_minus_m_hat = residuals$d_minus_m_hat + y_minus_g_hat = residuals$y_minus_g_hat + D = data[, d] - v_hatd = v_hat * D if (score == "partialling out") { - psi = (u_hat - v_hat * theta) * v_hat - psi_a = -v_hat * v_hat + psi = (y_minus_l_hat - d_minus_m_hat * theta) * d_minus_m_hat + psi_a = -d_minus_m_hat * d_minus_m_hat } else if (score == "IV-type") { - psi = (u_hat - D * theta) * v_hat - psi_a = -v_hatd + psi = (y_minus_g_hat - D * theta) * d_minus_m_hat + psi_a = -d_minus_m_hat * D } res = functional_bootstrap( diff --git a/tests/testthat/helper-09-dml_pliv.R b/tests/testthat/helper-09-dml_pliv.R index afe10667..2c2f8279 100644 --- a/tests/testthat/helper-09-dml_pliv.R +++ b/tests/testthat/helper-09-dml_pliv.R @@ -1,10 +1,10 @@ # Double Machine Learning for Partially Linear Instrumental Variable Regression. 
dml_pliv = function(data, y, d, z, n_folds, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, ml_g, params, dml_procedure, score, n_rep = 1, smpls = NULL, - params_g = NULL, params_m = NULL, params_r = NULL) { + params_l = NULL, params_m = NULL, params_r = NULL, params_g = NULL) { if (is.null(smpls)) { smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data)) @@ -18,18 +18,21 @@ dml_pliv = function(data, y, d, z, train_ids = this_smpl$train_ids test_ids = this_smpl$test_ids + fit_g = (score == "IV-type") | is.function(score) all_preds[[i_rep]] = fit_nuisance_pliv( data, y, d, z, - ml_g, ml_m, ml_r, - this_smpl, - params_g, params_m, params_r) + ml_l, ml_m, ml_r, ml_g, + n_folds, this_smpl, fit_g, + params_l, params_m, params_r, params_g) residuals = compute_pliv_residuals( data, y, d, z, n_folds, this_smpl, all_preds[[i_rep]]) - u_hat = residuals$u_hat - v_hat = residuals$v_hat - w_hat = residuals$w_hat + y_minus_l_hat = residuals$y_minus_l_hat + z_minus_m_hat = residuals$z_minus_m_hat + d_minus_r_hat = residuals$d_minus_r_hat + y_minus_g_hat = residuals$y_minus_g_hat + D = data[, d] # DML 1 if (dml_procedure == "dml1") { @@ -37,29 +40,35 @@ dml_pliv = function(data, y, d, z, for (i in 1:n_folds) { test_index = test_ids[[i]] orth_est = orth_pliv_dml( - u_hat = u_hat[test_index], - v_hat = v_hat[test_index], - w_hat = w_hat[test_index], + y_minus_l_hat = y_minus_l_hat[test_index], + z_minus_m_hat = z_minus_m_hat[test_index], + d_minus_r_hat = d_minus_r_hat[test_index], + y_minus_g_hat = y_minus_g_hat[test_index], + D = D[test_index], score = score) thetas[i] = orth_est$theta } all_thetas[i_rep] = mean(thetas, na.rm = TRUE) if (length(train_ids) == 1) { - u_hat = u_hat[test_index] - v_hat = v_hat[test_index] - w_hat = w_hat[test_index] + y_minus_l_hat = y_minus_l_hat[test_index] + z_minus_m_hat = z_minus_m_hat[test_index] + d_minus_r_hat = d_minus_r_hat[test_index] + y_minus_g_hat = y_minus_g_hat[test_index] } } if (dml_procedure == "dml2") { orth_est = orth_pliv_dml( - u_hat = u_hat, v_hat = v_hat, w_hat = w_hat, - score = score) + y_minus_l_hat = y_minus_l_hat, z_minus_m_hat = z_minus_m_hat, + d_minus_r_hat = d_minus_r_hat, y_minus_g_hat = y_minus_g_hat, + D = D, score = score) all_thetas[i_rep] = orth_est$theta } all_ses[i_rep] = sqrt(var_pliv( - theta = all_thetas[i_rep], u_hat = u_hat, v_hat = v_hat, - w_hat = w_hat, score = score)) + D = D, theta = all_thetas[i_rep], + y_minus_l_hat = y_minus_l_hat, z_minus_m_hat = z_minus_m_hat, + d_minus_r_hat = d_minus_r_hat, y_minus_g_hat = y_minus_g_hat, + score = score)) } theta = stats::median(all_thetas) @@ -83,27 +92,27 @@ dml_pliv = function(data, y, d, z, } fit_nuisance_pliv = function(data, y, d, z, - ml_g, ml_m, ml_r, - smpls, - params_g, params_m, params_r) { + ml_l, ml_m, ml_r, ml_g, + n_folds, smpls, fit_g, + params_l, params_m, params_r, params_g) { train_ids = smpls$train_ids test_ids = smpls$test_ids - # nuisance g: E[Y|X] - g_indx = names(data) != d & names(data) != z - data_g = data[, g_indx, drop = FALSE] - task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y) + # nuisance l: E[Y|X] + l_indx = names(data) != d & names(data) != z + data_l = data[, l_indx, drop = FALSE] + task_l = mlr3::TaskRegr$new(id = paste0("nuis_l_", d), backend = data_l, target = y) - resampling_g = mlr3::rsmp("custom") - resampling_g$instantiate(task_g, train_ids, test_ids) + resampling_l = mlr3::rsmp("custom") + resampling_l$instantiate(task_l, train_ids, test_ids) - if (!is.null(params_g)) { - ml_g$param_set$values = params_g 
+ if (!is.null(params_l)) { + ml_l$param_set$values = params_l } - r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE) - g_hat_list = lapply(r_g$predictions(), function(x) x$response) + r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE) + l_hat_list = lapply(r_l$predictions(), function(x) x$response) # nuisance m: E[Z|X] m_indx = names(data) != y & names(data) != d @@ -133,10 +142,50 @@ fit_nuisance_pliv = function(data, y, d, z, r_r = mlr3::resample(task_r, ml_r, resampling_r, store_models = TRUE) r_hat_list = lapply(r_r$predictions(), function(x) x$response) + if (fit_g) { + # nuisance g + residuals = compute_pliv_residuals( + data, y, d, z, n_folds, + smpls, list( + l_hat_list = l_hat_list, + m_hat_list = m_hat_list, + r_hat_list = r_hat_list, + g_hat_list = NULL)) + y_minus_l_hat = residuals$y_minus_l_hat + z_minus_m_hat = residuals$z_minus_m_hat + d_minus_r_hat = residuals$d_minus_r_hat + psi_a = -d_minus_r_hat * z_minus_m_hat + psi_b = z_minus_m_hat * y_minus_l_hat + theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE) + + D = data[, d] + Y = data[, y] + g_indx = names(data) != y & names(data) != d & names(data) != z + y_minus_theta_d = Y - theta_initial * D + data_g = cbind(data[, g_indx, drop = FALSE], y_minus_theta_d) + + task_g = mlr3::TaskRegr$new( + id = paste0("nuis_g_", d), backend = data_g, + target = "y_minus_theta_d") + + resampling_g = mlr3::rsmp("custom") + resampling_g$instantiate(task_g, train_ids, test_ids) + + if (!is.null(params_g)) { + ml_g$param_set$values = params_g + } + + r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE) + g_hat_list = lapply(r_g$predictions(), function(x) x$response) + } else { + g_hat_list = NULL + } + all_preds = list( + l_hat_list = l_hat_list, m_hat_list = m_hat_list, - g_hat_list = g_hat_list, - r_hat_list = r_hat_list) + r_hat_list = r_hat_list, + g_hat_list = g_hat_list) return(all_preds) } @@ -145,37 +194,50 @@ compute_pliv_residuals = function(data, y, d, z, n_folds, smpls, all_preds) { test_ids = smpls$test_ids + l_hat_list = all_preds$l_hat_list m_hat_list = all_preds$m_hat_list - g_hat_list = all_preds$g_hat_list r_hat_list = all_preds$r_hat_list + g_hat_list = all_preds$g_hat_list n = nrow(data) D = data[, d] Y = data[, y] Z = data[, z] - v_hat = u_hat = w_hat = rep(NA_real_, n) + y_minus_l_hat = z_minus_m_hat = d_minus_r_hat = y_minus_g_hat = rep(NA_real_, n) for (i in 1:n_folds) { test_index = test_ids[[i]] + l_hat = l_hat_list[[i]] m_hat = m_hat_list[[i]] - g_hat = g_hat_list[[i]] r_hat = r_hat_list[[i]] - v_hat[test_index] = D[test_index] - r_hat - u_hat[test_index] = Y[test_index] - g_hat - w_hat[test_index] = Z[test_index] - m_hat + y_minus_l_hat[test_index] = Y[test_index] - l_hat + z_minus_m_hat[test_index] = Z[test_index] - m_hat + d_minus_r_hat[test_index] = D[test_index] - r_hat + + if (!is.null(g_hat_list)) { + g_hat = g_hat_list[[i]] + y_minus_g_hat[test_index] = Y[test_index] - g_hat + } } - residuals = list(u_hat = u_hat, v_hat = v_hat, w_hat = w_hat) + residuals = list( + y_minus_l_hat = y_minus_l_hat, + z_minus_m_hat = z_minus_m_hat, + d_minus_r_hat = d_minus_r_hat, + y_minus_g_hat = y_minus_g_hat) return(residuals) } # Orthogonalized Estimation of Coefficient in PLR -orth_pliv_dml = function(u_hat, v_hat, w_hat, score) { +orth_pliv_dml = function(y_minus_l_hat, z_minus_m_hat, + d_minus_r_hat, y_minus_g_hat, D, score) { if (score == "partialling out") { - theta = mean(u_hat * w_hat) / mean(v_hat * w_hat) + theta = mean(y_minus_l_hat * 
z_minus_m_hat) / mean(d_minus_r_hat * z_minus_m_hat) + } else if (score == "IV-type") { + theta = mean(y_minus_g_hat * z_minus_m_hat) / mean(D * z_minus_m_hat) } else { stop("Inference framework for orthogonal estimation unknown") } @@ -184,10 +246,14 @@ orth_pliv_dml = function(u_hat, v_hat, w_hat, score) { } # Variance estimation for DML estimator in the partially linear regression model -var_pliv = function(theta, u_hat, v_hat, w_hat, score) { +var_pliv = function(theta, D, y_minus_l_hat, z_minus_m_hat, + d_minus_r_hat, y_minus_g_hat, score) { if (score == "partialling out") { - var = mean(1 / length(u_hat) * 1 / (mean(v_hat * w_hat))^2 * - mean(((u_hat - v_hat * theta) * w_hat)^2)) + var = mean(1 / length(y_minus_l_hat) * 1 / (mean(d_minus_r_hat * z_minus_m_hat))^2 * + mean(((y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat)^2)) + } else if (score == "IV-type") { + var = mean(1 / length(y_minus_l_hat) * 1 / (mean(D * z_minus_m_hat))^2 * + mean(((y_minus_g_hat - D * theta) * z_minus_m_hat)^2)) } else { stop("Inference framework for variance estimation unknown") } @@ -196,18 +262,25 @@ var_pliv = function(theta, u_hat, v_hat, w_hat, score) { # Bootstrap Implementation for Partially Linear Regression Model bootstrap_pliv = function(theta, se, data, y, d, z, n_folds, smpls, - all_preds, bootstrap, n_rep_boot, + all_preds, bootstrap, n_rep_boot, score, n_rep = 1) { for (i_rep in 1:n_rep) { residuals = compute_pliv_residuals( data, y, d, z, n_folds, smpls[[i_rep]], all_preds[[i_rep]]) - u_hat = residuals$u_hat - v_hat = residuals$v_hat - w_hat = residuals$w_hat - - psi = (u_hat - v_hat * theta[i_rep]) * w_hat - psi_a = -v_hat * w_hat + y_minus_l_hat = residuals$y_minus_l_hat + d_minus_r_hat = residuals$d_minus_r_hat + z_minus_m_hat = residuals$z_minus_m_hat + y_minus_g_hat = residuals$y_minus_g_hat + + if (score == "partialling out") { + psi = (y_minus_l_hat - d_minus_r_hat * theta[i_rep]) * z_minus_m_hat + psi_a = -d_minus_r_hat * z_minus_m_hat + } else if (score == "IV-type") { + D = data[, d] + psi = (y_minus_g_hat - D * theta[i_rep]) * z_minus_m_hat + psi_a = -D * z_minus_m_hat + } n = length(psi) weights = draw_bootstrap_weights(bootstrap, n_rep_boot, n) diff --git a/tests/testthat/helper-13-dml_pliv_partial_x.R b/tests/testthat/helper-13-dml_pliv_partial_x.R index daefdfd8..c4fb9f5d 100644 --- a/tests/testthat/helper-13-dml_pliv_partial_x.R +++ b/tests/testthat/helper-13-dml_pliv_partial_x.R @@ -1,9 +1,9 @@ dml_pliv_partial_x = function(data, y, d, z, n_folds, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, params, dml_procedure, score, n_rep = 1, smpls = NULL, - params_g = NULL, params_m = NULL, params_r = NULL) { + params_l = NULL, params_m = NULL, params_r = NULL) { stopifnot(length(z) > 1) if (is.null(smpls)) { @@ -18,9 +18,9 @@ dml_pliv_partial_x = function(data, y, d, z, all_preds[[i_rep]] = fit_nuisance_pliv_partial_x( data, y, d, z, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, this_smpl, - params_g, params_m, params_r) + params_l, params_m, params_r) residuals = compute_pliv_partial_x_residuals( data, y, d, z, n_folds, @@ -77,27 +77,27 @@ dml_pliv_partial_x = function(data, y, d, z, } fit_nuisance_pliv_partial_x = function(data, y, d, z, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, smpls, - params_g, params_m, params_r) { + params_l, params_m, params_r) { train_ids = smpls$train_ids test_ids = smpls$test_ids - # nuisance g: E[Y|X] - g_indx = names(data) != d & (names(data) %in% z == FALSE) - data_g = data[, g_indx, drop = FALSE] - task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), 
backend = data_g, target = y) + # nuisance l: E[Y|X] + l_indx = names(data) != d & (names(data) %in% z == FALSE) + data_l = data[, l_indx, drop = FALSE] + task_l = mlr3::TaskRegr$new(id = paste0("nuis_l_", d), backend = data_l, target = y) - resampling_g = mlr3::rsmp("custom") - resampling_g$instantiate(task_g, train_ids, test_ids) + resampling_l = mlr3::rsmp("custom") + resampling_l$instantiate(task_l, train_ids, test_ids) - if (!is.null(params_g)) { - ml_g$param_set$values = params_g + if (!is.null(params_l)) { + ml_l$param_set$values = params_l } - r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE) - g_hat_list = lapply(r_g$predictions(), function(x) x$response) + r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE) + l_hat_list = lapply(r_l$predictions(), function(x) x$response) # nuisance m: E[Z|X] n_z = length(z) @@ -150,7 +150,7 @@ fit_nuisance_pliv_partial_x = function(data, y, d, z, Z - m_hat_array) all_preds = list( - g_hat_list = g_hat_list, + l_hat_list = l_hat_list, r_hat_list = r_hat_list, r_hat_tilde = r_hat_tilde) @@ -162,7 +162,7 @@ compute_pliv_partial_x_residuals = function(data, y, d, z, n_folds, smpls, test_ids = smpls$test_ids - g_hat_list = all_preds$g_hat_list + l_hat_list = all_preds$l_hat_list r_hat_list = all_preds$r_hat_list r_hat_tilde = all_preds$r_hat_tilde @@ -175,10 +175,10 @@ compute_pliv_partial_x_residuals = function(data, y, d, z, n_folds, smpls, for (i in 1:n_folds) { test_index = test_ids[[i]] - g_hat = g_hat_list[[i]] + l_hat = l_hat_list[[i]] r_hat = r_hat_list[[i]] - u_hat[test_index] = Y[test_index] - g_hat + u_hat[test_index] = Y[test_index] - l_hat w_hat[test_index] = D[test_index] - r_hat } residuals = list(u_hat = u_hat, w_hat = w_hat, r_hat_tilde = r_hat_tilde) diff --git a/tests/testthat/helper-15-dml_pliv_partial_xz.R b/tests/testthat/helper-15-dml_pliv_partial_xz.R index a6dc29a9..58450cea 100644 --- a/tests/testthat/helper-15-dml_pliv_partial_xz.R +++ b/tests/testthat/helper-15-dml_pliv_partial_xz.R @@ -1,9 +1,9 @@ dml_pliv_partial_xz = function(data, y, d, z, n_folds, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, params, dml_procedure, score, n_rep = 1, smpls = NULL, - params_g = NULL, params_m = NULL, params_r = NULL) { + params_l = NULL, params_m = NULL, params_r = NULL) { if (is.null(smpls)) { smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data)) @@ -17,9 +17,9 @@ dml_pliv_partial_xz = function(data, y, d, z, all_preds[[i_rep]] = fit_nuisance_pliv_partial_xz( data, y, d, z, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, this_smpl, - params_g, params_m, params_r) + params_l, params_m, params_r) residuals = compute_pliv_partial_xz_residuals( data, y, d, z, n_folds, @@ -81,27 +81,27 @@ dml_pliv_partial_xz = function(data, y, d, z, } fit_nuisance_pliv_partial_xz = function(data, y, d, z, - ml_g, ml_m, ml_r, + ml_l, ml_m, ml_r, smpls, - params_g, params_m, params_r) { + params_l, params_m, params_r) { train_ids = smpls$train_ids test_ids = smpls$test_ids - # nuisance g: E[Y|X] - g_indx = names(data) != d & (names(data) %in% z == FALSE) - data_g = data[, g_indx, drop = FALSE] - task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y) + # nuisance l: E[Y|X] + l_indx = names(data) != d & (names(data) %in% z == FALSE) + data_l = data[, l_indx, drop = FALSE] + task_l = mlr3::TaskRegr$new(id = paste0("nuis_l_", d), backend = data_l, target = y) - resampling_g = mlr3::rsmp("custom") - resampling_g$instantiate(task_g, train_ids, test_ids) + resampling_l = mlr3::rsmp("custom") + 
resampling_l$instantiate(task_l, train_ids, test_ids) - if (!is.null(params_g)) { - ml_g$param_set$values = params_g + if (!is.null(params_l)) { + ml_l$param_set$values = params_l } - r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE) - g_hat_list = lapply(r_g$predictions(), function(x) x$response) + r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE) + l_hat_list = lapply(r_l$predictions(), function(x) x$response) # nuisance m: E[D|XZ] m_indx = (names(data) != y) @@ -141,7 +141,7 @@ fit_nuisance_pliv_partial_xz = function(data, y, d, z, } all_preds = list( - g_hat_list = g_hat_list, + l_hat_list = l_hat_list, m_hat_list = m_hat_list, r_hat_list = r_hat_list) @@ -153,7 +153,7 @@ compute_pliv_partial_xz_residuals = function(data, y, d, z, n_folds, smpls, test_ids = smpls$test_ids - g_hat_list = all_preds$g_hat_list + l_hat_list = all_preds$l_hat_list m_hat_list = all_preds$m_hat_list r_hat_list = all_preds$r_hat_list @@ -166,11 +166,11 @@ compute_pliv_partial_xz_residuals = function(data, y, d, z, n_folds, smpls, for (i in 1:n_folds) { test_index = test_ids[[i]] - g_hat = g_hat_list[[i]] + l_hat = l_hat_list[[i]] m_hat = m_hat_list[[i]] r_hat = r_hat_list[[i]] - u_hat[test_index] = Y[test_index] - g_hat + u_hat[test_index] = Y[test_index] - l_hat v_hat[test_index] = m_hat - r_hat w_hat[test_index] = D[test_index] - r_hat } diff --git a/tests/testthat/print_outputs/dml_pliv.txt b/tests/testthat/print_outputs/dml_pliv.txt index e0ec4b2a..59c434fa 100644 --- a/tests/testthat/print_outputs/dml_pliv.txt +++ b/tests/testthat/print_outputs/dml_pliv.txt @@ -15,7 +15,7 @@ Score function: partialling out DML algorithm: dml2 ------------------ Machine learner ------------------ -ml_g: regr.rpart +ml_l: regr.rpart ml_m: regr.rpart ml_r: regr.rpart diff --git a/tests/testthat/print_outputs/dml_plr.txt b/tests/testthat/print_outputs/dml_plr.txt index b1f0def8..5e982686 100644 --- a/tests/testthat/print_outputs/dml_plr.txt +++ b/tests/testthat/print_outputs/dml_plr.txt @@ -14,7 +14,7 @@ Score function: partialling out DML algorithm: dml2 ------------------ Machine learner ------------------ -ml_g: regr.rpart +ml_l: regr.rpart ml_m: regr.rpart ------------------ Resampling ------------------ diff --git a/tests/testthat/test-double_ml_data.R b/tests/testthat/test-double_ml_data.R index 4db052c3..cd295473 100644 --- a/tests/testthat/test-double_ml_data.R +++ b/tests/testthat/test-double_ml_data.R @@ -386,7 +386,7 @@ test_that("Unit tests for invalid data", { "DoubleMLPLIV instead of DoubleMLPLR.") expect_error(DoubleMLPLR$new( data = data_pliv$dml_data, - ml_g = mlr3::lrn("regr.rpart"), + ml_l = mlr3::lrn("regr.rpart"), ml_m = mlr3::lrn("regr.rpart")), regexp = msg) @@ -398,7 +398,7 @@ test_that("Unit tests for invalid data", { "variable\\(s\\) use DoubleMLPLR instead of DoubleMLPLIV.") expect_error(DoubleMLPLIV$new( data = data_plr$dml_data, - ml_g = mlr3::lrn("regr.rpart"), + ml_l = mlr3::lrn("regr.rpart"), ml_m = mlr3::lrn("regr.rpart"), ml_r = mlr3::lrn("regr.rpart")), regexp = msg) diff --git a/tests/testthat/test-double_ml_pliv.R b/tests/testthat/test-double_ml_pliv.R index 42e48a14..19cb2f31 100644 --- a/tests/testthat/test-double_ml_pliv.R +++ b/tests/testthat/test-double_ml_pliv.R @@ -15,7 +15,7 @@ if (on_cran) { test_cases = expand.grid( learner = c("regr.lm", "regr.glmnet", "graph_learner"), dml_procedure = c("dml1", "dml2"), - score = "partialling out", + score = c("partialling out", "IV-type"), stringsAsFactors = FALSE) } test_cases[".test_name"] = 
apply(test_cases, 1, paste, collapse = "_")
@@ -29,9 +29,10 @@ patrick::with_parameters_test_that("Unit tests for PLIV:",
   pliv_hat = dml_pliv(data_pliv$df,
     y = "y", d = "d", z = "z",
     n_folds = 5,
-    ml_g = learner_pars$ml_g$clone(),
+    ml_l = learner_pars$ml_l$clone(),
     ml_m = learner_pars$ml_m$clone(),
     ml_r = learner_pars$ml_r$clone(),
+    ml_g = learner_pars$ml_g$clone(),
     dml_procedure = dml_procedure, score = score)
   theta = pliv_hat$coef
   se = pliv_hat$se
@@ -41,19 +42,32 @@
     y = "y", d = "d", z = "z",
     n_folds = 5, smpls = pliv_hat$smpls,
     all_preds = pliv_hat$all_preds,
-    bootstrap = "normal", n_rep_boot = n_rep_boot)$boot_coef
+    bootstrap = "normal", n_rep_boot = n_rep_boot,
+    score = score)$boot_coef
   set.seed(3141)
-  double_mlpliv_obj = DoubleMLPLIV$new(
-    data = data_pliv$dml_data,
-    n_folds = 5,
-    ml_g = learner_pars$ml_g$clone(),
-    ml_m = learner_pars$ml_m$clone(),
-    ml_r = learner_pars$ml_r$clone(),
-    dml_procedure = dml_procedure,
-    score = score)
+  if (score == "partialling out") {
+    double_mlpliv_obj = DoubleMLPLIV$new(
+      data = data_pliv$dml_data,
+      n_folds = 5,
+      ml_l = learner_pars$ml_l$clone(),
+      ml_m = learner_pars$ml_m$clone(),
+      ml_r = learner_pars$ml_r$clone(),
+      dml_procedure = dml_procedure,
+      score = score)
+  } else {
+    double_mlpliv_obj = DoubleMLPLIV$new(
+      data = data_pliv$dml_data,
+      n_folds = 5,
+      ml_l = learner_pars$ml_l$clone(),
+      ml_m = learner_pars$ml_m$clone(),
+      ml_r = learner_pars$ml_r$clone(),
+      ml_g = learner_pars$ml_g$clone(),
+      dml_procedure = dml_procedure,
+      score = score)
+  }
-  double_mlpliv_obj$fit()
+  double_mlpliv_obj$fit(store_predictions = TRUE)
   theta_obj = double_mlpliv_obj$coef
   se_obj = double_mlpliv_obj$se
diff --git a/tests/testthat/test-double_ml_pliv_exception_handling.R b/tests/testthat/test-double_ml_pliv_exception_handling.R
new file mode 100644
index 00000000..b852d643
--- /dev/null
+++ b/tests/testthat/test-double_ml_pliv_exception_handling.R
@@ -0,0 +1,79 @@
+context("Unit tests for exception handling and deprecation warnings of PLIV")
+
+library("mlr3learners")
+
+logger = lgr::get_logger("bbotk")
+logger$set_threshold("warn")
+lgr::get_logger("mlr3")$set_threshold("warn")
+
+test_that("Unit tests for deprecation warnings of PLIV", {
+  set.seed(3141)
+  dml_data_pliv = make_pliv_CHS2015(n_obs = 51, dim_z = 1)
+  ml_l = lrn("regr.ranger")
+  ml_g = lrn("regr.ranger")
+  ml_m = lrn("regr.ranger")
+  ml_r = lrn("regr.ranger")
+  msg = paste0("The argument ml_g was renamed to ml_l.")
+  expect_warning(DoubleMLPLIV$new(dml_data_pliv,
+    ml_g = ml_g, ml_m = ml_m, ml_r = ml_r),
+  regexp = msg)
+
+  msg = paste(
+    "For score = 'IV-type', learners",
+    "ml_l, ml_m, ml_r and ml_g need to be specified.")
+  expect_error(DoubleMLPLIV$new(dml_data_pliv,
+    ml_l = ml_l, ml_m = ml_m, ml_r = ml_r,
+    score = "IV-type"),
+  regexp = msg)
+
+  dml_obj = DoubleMLPLIV$new(dml_data_pliv,
+    ml_l = ml_g, ml_m = ml_m, ml_r = ml_r)
+
+  msg = paste0("Learner ml_g was renamed to ml_l.")
+  expect_warning(dml_obj$set_ml_nuisance_params(
+    "ml_g", "d", list("num.trees" = 10)),
+  regexp = msg)
+
+  par_grids = list(
+    "ml_g" = paradox::ParamSet$new(list(
+      paradox::ParamInt$new("num.trees", lower = 9, upper = 10))),
+    "ml_m" = paradox::ParamSet$new(list(
+      paradox::ParamInt$new("num.trees", lower = 10, upper = 11))),
+    "ml_r" = paradox::ParamSet$new(list(
+      paradox::ParamInt$new("num.trees", lower = 10, upper = 11))))
+
+  msg = paste0("Learner ml_g was renamed to ml_l.")
+  expect_warning(dml_obj$tune(par_grids),
+  regexp = msg)
+ + tune_settings = list( + n_folds_tune = 5, + rsmp_tune = mlr3::rsmp("cv", folds = 5), + measure = list(ml_g = "regr.mse", ml_m = "regr.mae"), + terminator = mlr3tuning::trm("evals", n_evals = 20), + algorithm = mlr3tuning::tnr("grid_search"), + resolution = 5) + expect_warning(dml_obj$tune(par_grids, tune_settings = tune_settings), + regexp = msg) +} +) + +test_that("Unit tests of exception handling for DoubleMLPLIV", { + set.seed(3141) + dml_data_pliv = make_pliv_CHS2015(n_obs = 51, dim_z = 1) + ml_l = lrn("regr.ranger") + ml_m = lrn("regr.ranger") + ml_r = lrn("regr.ranger") + ml_g = lrn("regr.ranger") + + + msg = paste0( + "A learner ml_g has been provided for ", + "score = 'partialling out' but will be ignored.") + expect_warning(DoubleMLPLIV$new(dml_data_pliv, + ml_l = ml_l, ml_m = ml_m, ml_r = ml_r, + ml_g = ml_g, + score = "partialling out"), + regexp = msg) +} +) diff --git a/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R b/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R index 5065cd71..e6cf0fa4 100644 --- a/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R +++ b/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R @@ -34,10 +34,10 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par pliv_hat = dml_pliv_partial_x(df, y = "y", d = "d", z = c("z", "z2"), n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), - params_g = learner_pars$params$params_g, + params_l = learner_pars$params$params_l, params_m = learner_pars$params$params_m, params_r = learner_pars$params$params_r, dml_procedure = dml_procedure, score = score) @@ -62,7 +62,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj = DoubleMLPLIV.partialX( data = dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), dml_procedure = dml_procedure, @@ -70,8 +70,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_obj$set_ml_nuisance_params( learner = "ml_m_z", treat_var = "d", @@ -118,7 +118,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par set.seed(3141) dml_pliv_obj = DoubleMLPLIV.partialX(dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), dml_procedure = dml_procedure, @@ -126,8 +126,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_obj$set_ml_nuisance_params( learner = "ml_m_z", treat_var = "d", @@ -145,14 +145,14 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par theta = dml_pliv_obj$coef se = dml_pliv_obj$se - 
params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
+  params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep)
   params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep)
   params_r_fold_wise = rep(list(rep(list(learner_pars$params$params_r), n_folds)), n_rep)
   set.seed(3141)
   dml_pliv_obj_fold_wise = DoubleMLPLIV.partialX(dml_data,
     n_folds = n_folds, n_rep = n_rep,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
     ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
     ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
     dml_procedure = dml_procedure,
@@ -160,8 +160,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
   dml_pliv_obj_fold_wise$set_ml_nuisance_params(
     treat_var = "d",
-    learner = "ml_g",
-    params = params_g_fold_wise,
+    learner = "ml_l",
+    params = params_l_fold_wise,
     set_fold_specific = TRUE)
   dml_pliv_obj_fold_wise$set_ml_nuisance_params(
     treat_var = "d",
@@ -193,7 +193,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
   n_folds = 2
   n_rep = 3
-  params_g = list(cp = 0.01, minsplit = 20) # this are defaults
+  params_l = list(cp = 0.01, minsplit = 20) # these are defaults
   params_m = list(cp = 0.01, minsplit = 20) # this are defaults
   params_r = list(cp = 0.01, minsplit = 20) # this are defaults
@@ -206,7 +206,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
   set.seed(3141)
   dml_pliv_default = DoubleMLPLIV.partialX(dml_data,
     n_folds = n_folds, n_rep = n_rep,
-    ml_g = lrn("regr.rpart"),
+    ml_l = lrn("regr.rpart"),
     ml_m = lrn("regr.rpart"),
     ml_r = lrn("regr.rpart"),
     dml_procedure = dml_procedure,
@@ -219,7 +219,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
   set.seed(3141)
   dml_pliv_obj = DoubleMLPLIV.partialX(dml_data,
     n_folds = n_folds, n_rep = n_rep,
-    ml_g = lrn("regr.rpart"),
+    ml_l = lrn("regr.rpart"),
     ml_m = lrn("regr.rpart"),
     ml_r = lrn("regr.rpart"),
     dml_procedure = dml_procedure,
@@ -227,8 +227,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
   dml_pliv_obj$set_ml_nuisance_params(
     treat_var = "d",
-    learner = "ml_g",
-    params = params_g)
+    learner = "ml_l",
+    params = params_l)
   dml_pliv_obj$set_ml_nuisance_params(
     learner = "ml_m_z", treat_var = "d",
diff --git a/tests/testthat/test-double_ml_pliv_one_way_cluster.R b/tests/testthat/test-double_ml_pliv_one_way_cluster.R
index 73c589c5..4f6f70e4 100644
--- a/tests/testthat/test-double_ml_pliv_one_way_cluster.R
+++ b/tests/testthat/test-double_ml_pliv_one_way_cluster.R
@@ -13,7 +13,7 @@ if (on_cran) {
   test_cases = expand.grid(
     learner = c("regr.lm", "regr.glmnet"),
     dml_procedure = c("dml1", "dml2"),
-    score = "partialling out",
+    score = c("partialling out", "IV-type"),
     stringsAsFactors = FALSE)
 }
 test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
@@ -35,15 +35,20 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering:
   n_folds = 2
   n_rep = 2
-  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = learner_pars$ml_g$clone()
+  } else {
+    ml_g = NULL
+  }
   double_mlpliv_obj = DoubleMLPLIV$new(
     data = data_one_way,
     n_folds = n_folds,
     n_rep = n_rep,
-    ml_g = learner_pars$ml_g$clone(),
+    ml_l = learner_pars$ml_l$clone(),
     ml_m = learner_pars$ml_m$clone(),
     ml_r = learner_pars$ml_r$clone(),
+    ml_g = ml_g,
     dml_procedure = dml_procedure,
     score = score)
@@ -53,6 +58,11 @@
patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering: se_obj = double_mlpliv_obj$se set.seed(3141) + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } df = as.data.frame(data_one_way$data) cluster_var = df$cluster_var_i # need to drop variables as x is not explicitly set @@ -60,9 +70,10 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering: pliv_hat = dml_pliv(df, y = "Y", d = "D", z = "Z", n_folds = n_folds, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, score = score, smpls = double_mlpliv_obj$smpls, n_rep = n_rep) @@ -74,15 +85,19 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering: residuals = compute_pliv_residuals(df, y = "Y", d = "D", z = "Z", n_folds = n_folds, - this_smpl, - pliv_hat$all_preds[[i_rep]]) - u_hat = residuals$u_hat - v_hat = residuals$v_hat - w_hat = residuals$w_hat + smpls = this_smpl, + all_preds = pliv_hat$all_preds[[i_rep]]) + y_minus_l_hat = residuals$y_minus_l_hat + d_minus_r_hat = residuals$d_minus_r_hat + z_minus_m_hat = residuals$z_minus_m_hat + y_minus_g_hat = residuals$y_minus_g_hat + D = df[, "D"] - psi_a = -w_hat * v_hat + if (score == "partialling out") psi_a = -z_minus_m_hat * d_minus_r_hat + if (score == "IV-type") psi_a = -D * z_minus_m_hat if (dml_procedure == "dml2") { - psi_b = w_hat * u_hat + if (score == "partialling out") psi_b = z_minus_m_hat * y_minus_l_hat + if (score == "IV-type") psi_b = z_minus_m_hat * y_minus_g_hat theta = est_one_way_cluster_dml2( psi_a, psi_b, cluster_var, @@ -90,7 +105,8 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering: } else { theta = pliv_hat$thetas[i_rep] } - psi = (u_hat - v_hat * theta) * w_hat + if (score == "partialling out") psi = (y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat + if (score == "IV-type") psi = (y_minus_g_hat - D * theta) * z_minus_m_hat var = var_one_way_cluster( psi, psi_a, cluster_var, diff --git a/tests/testthat/test-double_ml_pliv_parameter_passing.R b/tests/testthat/test-double_ml_pliv_parameter_passing.R index 1105a4b1..bb6afd75 100644 --- a/tests/testthat/test-double_ml_pliv_parameter_passing.R +++ b/tests/testthat/test-double_ml_pliv_parameter_passing.R @@ -13,7 +13,7 @@ if (on_cran) { test_cases = expand.grid( learner = "regr.rpart", dml_procedure = c("dml1", "dml2"), - score = "partialling out", + score = c("partialling out", "IV-type"), stringsAsFactors = FALSE) } @@ -35,15 +35,22 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (oo learner_pars = get_default_mlmethod_pliv(learner) set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } pliv_hat = dml_pliv(data_pliv$df, y = "y", d = "d", z = "z", n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), - params_g = learner_pars$params$params_g, + ml_g = ml_g, + params_l = learner_pars$params$params_l, params_m = learner_pars$params$params_m, params_r = learner_pars$params$params_r, + params_g = learner_pars$params$params_g, dml_procedure = dml_procedure, score = score) theta = pliv_hat$coef se = pliv_hat$se @@ -54,22 +61,29 @@ 
patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (oo n_folds = n_folds, n_rep = n_rep, smpls = pliv_hat$smpls, all_preds = pliv_hat$all_preds, - bootstrap = "normal", n_rep_boot = n_rep_boot)$boot_coef + bootstrap = "normal", n_rep_boot = n_rep_boot, + score = score)$boot_coef set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } dml_pliv_obj = DoubleMLPLIV$new( data = data_pliv$dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -78,6 +92,13 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (oo treat_var = "d", learner = "ml_r", params = learner_pars$params$params_r) + if (score == "IV-type") { + dml_pliv_obj$set_ml_nuisance_params( + treat_var = "d", + learner = "ml_g", + params = learner_pars$params$params_g) + } + dml_pliv_obj$fit() @@ -109,35 +130,48 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (no test_ids = list(my_sampling$test_set(1)) smpls = list(list(train_ids = train_ids, test_ids = test_ids)) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } pliv_hat = dml_pliv(data_pliv$df, y = "y", d = "d", z = "z", n_folds = 1, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), - params_g = learner_pars$params$params_g, + ml_g = ml_g, + params_l = learner_pars$params$params_l, params_m = learner_pars$params$params_m, params_r = learner_pars$params$params_r, + params_g = learner_pars$params$params_g, dml_procedure = dml_procedure, score = score, smpls = smpls) theta = pliv_hat$coef se = pliv_hat$se set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } dml_pliv_nocf = DoubleMLPLIV$new( data = data_pliv$dml_data, n_folds = n_folds, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), + ml_g = ml_g, dml_procedure = dml_procedure, score = score, apply_cross_fitting = FALSE) dml_pliv_nocf$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_nocf$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -146,6 +180,12 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (no treat_var = "d", learner = "ml_r", params = learner_pars$params$params_r) + if (score == "IV-type") { + dml_pliv_nocf$set_ml_nuisance_params( + treat_var = "d", + learner = "ml_g", + params = learner_pars$params$params_g) + } dml_pliv_nocf$fit() theta_obj = dml_pliv_nocf$coef @@ -164,18 +204,24 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of 
PLIV (fo learner_pars = get_default_mlmethod_pliv(learner) set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } dml_pliv_obj = DoubleMLPLIV$new(data_pliv$dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -184,28 +230,41 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (fo treat_var = "d", learner = "ml_r", params = learner_pars$params$params_r) + if (score == "IV-type") { + dml_pliv_obj$set_ml_nuisance_params( + treat_var = "d", + learner = "ml_g", + params = learner_pars$params$params_g) + } dml_pliv_obj$fit() theta = dml_pliv_obj$coef se = dml_pliv_obj$se - params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep) + params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep) params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep) params_r_fold_wise = rep(list(rep(list(learner_pars$params$params_r), n_folds)), n_rep) + params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep) set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } dml_pliv_obj_fold_wise = DoubleMLPLIV$new(data_pliv$dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) dml_pliv_obj_fold_wise$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = params_g_fold_wise, + learner = "ml_l", + params = params_l_fold_wise, set_fold_specific = TRUE) dml_pliv_obj_fold_wise$set_ml_nuisance_params( treat_var = "d", @@ -217,6 +276,13 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (fo learner = "ml_r", params = params_r_fold_wise, set_fold_specific = TRUE) + if (score == "IV-type") { + dml_pliv_obj_fold_wise$set_ml_nuisance_params( + treat_var = "d", + learner = "ml_g", + params = params_g_fold_wise, + set_fold_specific = TRUE) + } dml_pliv_obj_fold_wise$fit() theta_fold_wise = dml_pliv_obj_fold_wise$coef @@ -232,16 +298,23 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (de n_folds = 2 n_rep = 3 - params_g = list(cp = 0.01, minsplit = 20) # this are defaults + params_l = list(cp = 0.01, minsplit = 20) # this are defaults params_m = list(cp = 0.01, minsplit = 20) # this are defaults params_r = list(cp = 0.01, minsplit = 20) # this are defaults + params_g = list(cp = 0.01, minsplit = 20) # this are defaults set.seed(3141) + if (score == "IV-type") { + ml_g = lrn("regr.rpart") + } else { + ml_g = NULL + } dml_pliv_default = DoubleMLPLIV$new(data_pliv$dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = lrn("regr.rpart"), + ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"), ml_r = lrn("regr.rpart"), + ml_g 
= ml_g, dml_procedure = dml_procedure, score = score) @@ -250,19 +323,25 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (de se_default = dml_pliv_default$se set.seed(3141) + if (score == "IV-type") { + ml_g = lrn("regr.rpart") + } else { + ml_g = NULL + } dml_pliv_obj = DoubleMLPLIV$new( data = data_pliv$dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = lrn("regr.rpart"), + ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"), ml_r = lrn("regr.rpart"), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = params_g) + learner = "ml_l", + params = params_l) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -271,6 +350,12 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (de treat_var = "d", learner = "ml_r", params = params_r) + if (score == "IV-type") { + dml_pliv_obj$set_ml_nuisance_params( + treat_var = "d", + learner = "ml_g", + params = params_g) + } dml_pliv_obj$fit() theta = dml_pliv_obj$coef diff --git a/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R b/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R index 253903e6..36902402 100644 --- a/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R +++ b/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R @@ -31,7 +31,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional ini set.seed(3141) double_mlpliv_obj = DoubleMLPLIV$new(data_ml, n_folds = 5, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, @@ -45,7 +45,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional ini set.seed(3141) double_mlpliv_partX = DoubleMLPLIV.partialX(data_ml, n_folds = 5, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, @@ -72,7 +72,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialZ functional ini set.seed(3141) double_mlpliv_partZ = DoubleMLPLIV$new(data_ml, n_folds = 5, - ml_g = NULL, + ml_l = NULL, ml_m = NULL, ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, @@ -111,7 +111,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialXZ functional in set.seed(3141) double_mlpliv_partXZ = DoubleMLPLIV$new(data_ml, n_folds = 5, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, @@ -125,7 +125,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialXZ functional in set.seed(3141) double_mlpliv_partXZ_fun = DoubleMLPLIV.partialXZ(data_ml, n_folds = 5, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, diff --git a/tests/testthat/test-double_ml_pliv_partial_functional_initializer_IVtype.R b/tests/testthat/test-double_ml_pliv_partial_functional_initializer_IVtype.R new file mode 100644 index 00000000..35a277a1 --- /dev/null +++ b/tests/testthat/test-double_ml_pliv_partial_functional_initializer_IVtype.R @@ -0,0 +1,63 @@ +context("Unit tests for PLIV, partialling out X, Z, XZ") + +lgr::get_logger("mlr3")$set_threshold("warn") + +on_cran = 
!identical(Sys.getenv("NOT_CRAN"), "true") +if (on_cran) { + test_cases = expand.grid( + learner = "regr.lm", + dml_procedure = "dml2", + score = "IV-type", + stringsAsFactors = FALSE) +} else { + test_cases = expand.grid( + learner = c("regr.lm", "regr.cv_glmnet"), + dml_procedure = c("dml1", "dml2"), + score = "IV-type", + stringsAsFactors = FALSE) +} +test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") + +patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional initialization):", + .cases = test_cases, { + learner_pars = get_default_mlmethod_pliv(learner) + df = data_pliv$df + Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE] + data_ml = double_ml_data_from_data_frame(df, + y_col = "y", + d_cols = "d", x_cols = Xnames, z_cols = "z") + + # Partial out X (default PLIV) + set.seed(3141) + double_mlpliv_obj = DoubleMLPLIV$new(data_ml, + n_folds = 5, + ml_l = learner_pars$ml_l$clone(), + ml_m = learner_pars$ml_m$clone(), + ml_r = learner_pars$ml_r$clone(), + ml_g = learner_pars$ml_g$clone(), + dml_procedure = dml_procedure, + score = score) + + double_mlpliv_obj$fit() + theta_obj = double_mlpliv_obj$coef + se_obj = double_mlpliv_obj$se + + # Partial out X + set.seed(3141) + double_mlpliv_partX = DoubleMLPLIV.partialX(data_ml, + n_folds = 5, + ml_l = learner_pars$ml_l$clone(), + ml_m = learner_pars$ml_m$clone(), + ml_r = learner_pars$ml_r$clone(), + ml_g = learner_pars$ml_g$clone(), + dml_procedure = dml_procedure, + score = score) + + double_mlpliv_partX$fit() + theta_partX = double_mlpliv_partX$coef + se_partX = double_mlpliv_partX$se + + expect_equal(theta_partX, theta_obj, tolerance = 1e-8) + expect_equal(se_partX, se_obj, tolerance = 1e-8) + } +) diff --git a/tests/testthat/test-double_ml_pliv_partial_x.R b/tests/testthat/test-double_ml_pliv_partial_x.R index 00c1df72..bbfc6be8 100644 --- a/tests/testthat/test-double_ml_pliv_partial_x.R +++ b/tests/testthat/test-double_ml_pliv_partial_x.R @@ -21,7 +21,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialX:", pliv_hat = dml_pliv_partial_x(data_pliv_partialX$df, y = "y", d = "d", z = paste0("Z", 1:dim_z), n_folds = 5, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, score = score) @@ -38,7 +38,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialX:", set.seed(3141) double_mlpliv_obj = DoubleMLPLIV.partialX(data_pliv_partialX$dml_data, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), n_folds = 5, @@ -67,7 +67,7 @@ test_that("Unit tests for PLIV.partialX invalid score", { "partialX=TRUE and partialZ=FALSE with several instruments.") double_mlplr_obj <- DoubleMLPLIV.partialX( data_pliv_partialX$dml_data, - ml_g = mlr3::lrn("regr.rpart"), + ml_l = mlr3::lrn("regr.rpart"), ml_m = mlr3::lrn("regr.rpart"), ml_r = mlr3::lrn("regr.rpart"), score = function(x) { diff --git a/tests/testthat/test-double_ml_pliv_partial_xz.R b/tests/testthat/test-double_ml_pliv_partial_xz.R index 2812c36e..befef206 100644 --- a/tests/testthat/test-double_ml_pliv_partial_xz.R +++ b/tests/testthat/test-double_ml_pliv_partial_xz.R @@ -23,7 +23,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialXZ:", pliv_hat = dml_pliv_partial_xz(data_pliv_partialXZ$df, y = "y", d = "d", z = paste0("Z", 1:dim_z), n_folds = 5, - ml_g = learner_pars$ml_g$clone(), + ml_l = 
learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), dml_procedure = dml_procedure, score = score) @@ -39,7 +39,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialXZ:", set.seed(3141) double_mlpliv_obj = DoubleMLPLIV.partialXZ(data_pliv_partialXZ$dml_data, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), n_folds = 5, @@ -67,7 +67,7 @@ test_that("Unit tests for PLIV.partialXZ invalid score", { "partialX=TRUE and partialZ=TRUE.") double_mlplr_obj <- DoubleMLPLIV.partialXZ( data_pliv_partialXZ$dml_data, - ml_g = mlr3::lrn("regr.rpart"), + ml_l = mlr3::lrn("regr.rpart"), ml_m = mlr3::lrn("regr.rpart"), ml_r = mlr3::lrn("regr.rpart"), score = function(x) { diff --git a/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R b/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R index 77798f29..a07ffef6 100644 --- a/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R +++ b/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R @@ -32,10 +32,10 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par pliv_hat = dml_pliv_partial_xz(df, y = "y", d = "d", z = c("z", "z2"), n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), - params_g = learner_pars$params$params_g, + params_l = learner_pars$params$params_l, params_m = learner_pars$params$params_m, params_r = learner_pars$params$params_r, dml_procedure = dml_procedure, score = score) @@ -59,7 +59,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj = DoubleMLPLIV.partialXZ( data = dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), dml_procedure = dml_procedure, @@ -67,8 +67,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -112,10 +112,10 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par pliv_hat = dml_pliv_partial_xz(df, y = "y", d = "d", z = c("z", "z2"), n_folds = 1, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), - params_g = learner_pars$params$params_g, + params_l = learner_pars$params$params_l, params_m = learner_pars$params$params_m, params_r = learner_pars$params$params_r, dml_procedure = dml_procedure, score = score, @@ -132,7 +132,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_nocf = DoubleMLPLIV.partialXZ( data = dml_data, n_folds = n_folds, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), 
dml_procedure = dml_procedure, @@ -141,8 +141,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_nocf$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_nocf$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -177,7 +177,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par set.seed(3141) dml_pliv_obj = DoubleMLPLIV.partialXZ(dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), dml_procedure = dml_procedure, @@ -185,8 +185,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = learner_pars$params$params_g) + learner = "ml_l", + params = learner_pars$params$params_l) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", @@ -200,14 +200,14 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par theta = dml_pliv_obj$coef se = dml_pliv_obj$se - params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep) + params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep) params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep) params_r_fold_wise = rep(list(rep(list(learner_pars$params$params_r), n_folds)), n_rep) set.seed(3141) dml_pliv_obj_fold_wise = DoubleMLPLIV.partialXZ(dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r), dml_procedure = dml_procedure, @@ -215,8 +215,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par dml_pliv_obj_fold_wise$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = params_g_fold_wise, + learner = "ml_l", + params = params_l_fold_wise, set_fold_specific = TRUE) dml_pliv_obj_fold_wise$set_ml_nuisance_params( treat_var = "d", @@ -243,7 +243,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par n_folds = 2 n_rep = 3 - params_g = list(cp = 0.01, minsplit = 20) # this are defaults + params_l = list(cp = 0.01, minsplit = 20) # this are defaults params_m = list(cp = 0.01, minsplit = 20) # this are defaults params_r = list(cp = 0.01, minsplit = 20) # this are defaults @@ -256,7 +256,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par set.seed(3141) dml_pliv_default = DoubleMLPLIV.partialXZ(dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = lrn("regr.rpart"), + ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"), ml_r = lrn("regr.rpart"), dml_procedure = dml_procedure, @@ -269,7 +269,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par set.seed(3141) dml_pliv_obj = DoubleMLPLIV.partialXZ(dml_data, n_folds = n_folds, n_rep = n_rep, - ml_g = lrn("regr.rpart"), + ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"), ml_r = lrn("regr.rpart"), dml_procedure = dml_procedure, @@ -277,8 +277,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par 
dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", - learner = "ml_g", - params = params_g) + learner = "ml_l", + params = params_l) dml_pliv_obj$set_ml_nuisance_params( treat_var = "d", learner = "ml_m", diff --git a/tests/testthat/test-double_ml_pliv_tuning.R b/tests/testthat/test-double_ml_pliv_tuning.R index 0687d250..7cee8b40 100644 --- a/tests/testthat/test-double_ml_pliv_tuning.R +++ b/tests/testthat/test-double_ml_pliv_tuning.R @@ -15,41 +15,141 @@ tune_settings = list( n_rep_tune = 1, rsmp_tune = "cv", measure = list( - "ml_g" = "regr.mse", + "ml_l" = "regr.mse", "ml_r" = "regr.mse", "ml_m" = "regr.mse"), terminator = mlr3tuning::trm("evals", n_evals = 2), algorithm = "grid_search", - tuning_instance_g = NULL, + tuning_instance_l = NULL, tuning_instance_m = NULL, tuner = "grid_search", resolution = 5) on_cran = !identical(Sys.getenv("NOT_CRAN"), "true") if (on_cran) { - test_cases = expand.grid( + test_cases_one_z = expand.grid( dml_procedure = "dml2", score = "partialling out", n_rep = c(1), tune_on_folds = FALSE, - z_indx = c(1), stringsAsFactors = FALSE) } else { - test_cases = expand.grid( + test_cases_one_z = expand.grid( dml_procedure = c("dml1", "dml2"), - score = "partialling out", + score = c("partialling out", "IV-type"), n_rep = c(1, 3), tune_on_folds = c(FALSE, TRUE), - z_indx = c(1, 2), stringsAsFactors = FALSE) } -test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") +test_cases_one_z[".test_name"] = apply(test_cases_one_z, 1, paste, collapse = "_") # skip('Skip tests for tuning') patrick::with_parameters_test_that("Unit tests for tuning of PLIV", - .cases = test_cases, { + .cases = test_cases_one_z, { + + # TBD: Functional Test Case + + set.seed(3141) + n_folds = 2 + n_rep_boot = 498 + + z_cols = "z" + set.seed(3141) + df = data_pliv$df + Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE] + data_ml = double_ml_data_from_data_frame(df, + y_col = "y", + d_cols = "d", x_cols = Xnames, z_cols = z_cols) + + if (score == "IV-type") { + ml_g = learner + } else { + ml_g = NULL + } + double_mlpliv_obj_tuned = DoubleMLPLIV$new(data_ml, + n_folds = n_folds, + ml_l = learner, + ml_m = learner, + ml_r = learner, + ml_g = ml_g, + dml_procedure = dml_procedure, + score = score, + n_rep = n_rep) + + param_grid = list( + "ml_l" = paradox::ParamSet$new(list( + paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), + paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), + "ml_m" = paradox::ParamSet$new(list( + paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), + paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), + "ml_r" = paradox::ParamSet$new(list( + paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), + paradox::ParamInt$new("minsplit", lower = 1, upper = 2)))) + if (score == "IV-type") { + param_grid[["ml_g"]] = paradox::ParamSet$new(list( + paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), + paradox::ParamInt$new("minsplit", lower = 1, upper = 2))) + tune_settings[["measure"]][["ml_g"]] = "regr.mse" + } + + double_mlpliv_obj_tuned$tune(param_set = param_grid, tune_settings = tune_settings, tune_on_folds = tune_on_folds) + double_mlpliv_obj_tuned$fit() + + theta_obj_tuned = double_mlpliv_obj_tuned$coef + se_obj_tuned = double_mlpliv_obj_tuned$se + + # bootstrap + # double_mlplr_obj_exact$bootstrap(method = 'normal', n_rep = n_rep_boot) + # boot_theta_obj_exact = double_mlplr_obj_exact$boot_coef + + expect_is(theta_obj_tuned, "numeric") + expect_is(se_obj_tuned, "numeric") + + # if (data_ml$n_instr() == 1) { + # 
double_mlpliv_obj_tuned_Z = DoubleMLPLIV.partialZ(data_ml, + # n_folds = n_folds, + # ml_r = learner, + # dml_procedure = dml_procedure, + # score = score, + # n_rep = n_rep) + # + # double_mlpliv_obj_tuned_Z$tune(param_set = param_grid, tune_on_folds = tune_on_folds) + # double_mlpliv_obj_tuned_Z$fit() + # + # theta_obj_tuned_Z = double_mlpliv_obj_tuned_Z$coef + # se_obj_tuned_Z = double_mlpliv_obj_tuned_Z$se + # + # expect_is(theta_obj_tuned_Z, "numeric") + # expect_is(se_obj_tuned_Z, "numeric") + # } + # + } +) + +on_cran = !identical(Sys.getenv("NOT_CRAN"), "true") +if (on_cran) { + test_cases_multiple_z = expand.grid( + dml_procedure = "dml2", + score = "partialling out", + n_rep = c(1), + tune_on_folds = FALSE, + stringsAsFactors = FALSE) +} else { + test_cases_multiple_z = expand.grid( + dml_procedure = c("dml1", "dml2"), + score = "partialling out", + n_rep = c(1, 3), + tune_on_folds = c(FALSE, TRUE), + stringsAsFactors = FALSE) +} + +test_cases_multiple_z[".test_name"] = apply(test_cases_multiple_z, 1, paste, collapse = "_") + +patrick::with_parameters_test_that("Unit tests for tuning of PLIV (multiple Z)", + .cases = test_cases_multiple_z, { # TBD: Functional Test Case @@ -57,17 +157,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV", n_folds = 2 n_rep_boot = 498 - # set.seed(3141) - # pliv_hat = dml_plriv(data_pliv, y = "y", d = "d", z = 'z', - # n_folds = n_folds, mlmethod = learner_list, - # params = learner_pars$params, - # dml_procedure = dml_procedure, score = score, - # bootstrap = "normal", n_rep_boot = n_rep_boot) - # theta = coef(pliv_hat) - # se = pliv_hat$se - - z_vars = list("z", c("z", "z2")) - z_cols = z_vars[[z_indx]] + z_cols = c("z", "z2") set.seed(3141) df = data_pliv$df Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE] @@ -77,7 +167,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV", double_mlpliv_obj_tuned = DoubleMLPLIV$new(data_ml, n_folds = n_folds, - ml_g = learner, + ml_l = learner, ml_m = learner, ml_r = learner, dml_procedure = dml_procedure, @@ -85,7 +175,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV", n_rep = n_rep) param_grid = list( - "ml_g" = paradox::ParamSet$new(list( + "ml_l" = paradox::ParamSet$new(list( paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), "ml_m" = paradox::ParamSet$new(list( @@ -137,7 +227,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV", param_grid_r = list("ml_r" = param_grid[["ml_r"]]) tune_settings_r = tune_settings - tune_settings_r$measure$ml_g = tune_settings_r$measure$ml_m = NULL + tune_settings_r$measure$ml_l = tune_settings_r$measure$ml_m = NULL double_mlpliv_obj_tuned_Z$tune( param_set = param_grid_r, tune_on_folds = tune_on_folds, tune_settings = tune_settings_r) @@ -152,7 +242,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV", set.seed(3141) double_mlpliv_obj_tuned_XZ = DoubleMLPLIV.partialXZ(data_ml, n_folds = n_folds, - ml_g = learner, + ml_l = learner, ml_m = learner, ml_r = learner, dml_procedure = dml_procedure, diff --git a/tests/testthat/test-double_ml_pliv_two_way_cluster.R b/tests/testthat/test-double_ml_pliv_two_way_cluster.R index 5353f372..4512a1d0 100644 --- a/tests/testthat/test-double_ml_pliv_two_way_cluster.R +++ b/tests/testthat/test-double_ml_pliv_two_way_cluster.R @@ -15,7 +15,7 @@ if (on_cran) { test_cases = expand.grid( learner = c("regr.lm", "regr.glmnet"), dml_procedure = c("dml1", "dml2"), - score = 
"partialling out", + score = c("partialling out", "IV-type"), stringsAsFactors = FALSE) } test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") @@ -31,12 +31,18 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering: learner_pars = get_default_mlmethod_pliv(learner) set.seed(3141) + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } double_mlpliv_obj = DoubleMLPLIV$new( data = data_two_way, n_folds = 2, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) @@ -51,12 +57,18 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering: cluster_var2 = df$cluster_var_j # need to drop variables as x is not explicitly set df = df[, !(names(df) %in% c("cluster_var_i", "cluster_var_j"))] + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } pliv_hat = dml_pliv(df, y = "Y", d = "D", z = "Z", n_folds = 4, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), ml_r = learner_pars$ml_r$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, score = score, smpls = double_mlpliv_obj$smpls) @@ -65,15 +77,19 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering: residuals = compute_pliv_residuals(df, y = "Y", d = "D", z = "Z", n_folds = 4, - this_smpl, - pliv_hat$all_preds[[1]]) - u_hat = residuals$u_hat - v_hat = residuals$v_hat - w_hat = residuals$w_hat + smpls = this_smpl, + all_preds = pliv_hat$all_preds[[1]]) + y_minus_l_hat = residuals$y_minus_l_hat + d_minus_r_hat = residuals$d_minus_r_hat + z_minus_m_hat = residuals$z_minus_m_hat + y_minus_g_hat = residuals$y_minus_g_hat + D = df[, "D"] - psi_a = -w_hat * v_hat + if (score == "partialling out") psi_a = -z_minus_m_hat * d_minus_r_hat + if (score == "IV-type") psi_a = -D * z_minus_m_hat if (dml_procedure == "dml2") { - psi_b = w_hat * u_hat + if (score == "partialling out") psi_b = z_minus_m_hat * y_minus_l_hat + if (score == "IV-type") psi_b = z_minus_m_hat * y_minus_g_hat theta = est_two_way_cluster_dml2( psi_a, psi_b, cluster_var1, @@ -82,7 +98,8 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering: } else { theta = pliv_hat$coef } - psi = (u_hat - v_hat * theta) * w_hat + if (score == "partialling out") psi = (y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat + if (score == "IV-type") psi = (y_minus_g_hat - D * theta) * z_minus_m_hat var = var_two_way_cluster( psi, psi_a, cluster_var1, diff --git a/tests/testthat/test-double_ml_pliv_user_score.R b/tests/testthat/test-double_ml_pliv_user_score.R index b9d53063..0bd000b1 100644 --- a/tests/testthat/test-double_ml_pliv_user_score.R +++ b/tests/testthat/test-double_ml_pliv_user_score.R @@ -4,8 +4,8 @@ library("mlr3learners") lgr::get_logger("mlr3")$set_threshold("warn") -score_fct = function(y, z, d, g_hat, m_hat, r_hat, smpls) { - u_hat = y - g_hat +score_fct_po = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) { + u_hat = y - l_hat w_hat = d - r_hat v_hat = z - m_hat psi_a = -w_hat * v_hat @@ -15,16 +15,27 @@ score_fct = function(y, z, d, g_hat, m_hat, r_hat, smpls) { psi_b = psi_b) } +score_fct_iv = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) { + v_hat = z - m_hat + psi_a = -d * v_hat + psi_b = v_hat * (y - g_hat) + psis = list( + psi_a = psi_a, + psi_b = psi_b) +} + on_cran = 
!identical(Sys.getenv("NOT_CRAN"), "true") if (on_cran) { test_cases = expand.grid( learner = "regr.lm", dml_procedure = "dml2", + score = "partialling out", stringsAsFactors = FALSE) } else { test_cases = expand.grid( learner = c("regr.lm", "regr.glmnet"), dml_procedure = c("dml1", "dml2"), + score = c("partialling out", "IV-type"), stringsAsFactors = FALSE) } test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") @@ -33,15 +44,24 @@ patrick::with_parameters_test_that("Unit tests for PLIV, callable score:", .cases = test_cases, { n_rep_boot = 498 + if (score == "partialling out") { + score_fct = score_fct_po + ml_g = NULL + } else if (score == "IV-type") { + score_fct = score_fct_iv + ml_g = lrn(learner) + } + set.seed(3141) double_mlpliv_obj = DoubleMLPLIV$new( data = data_pliv$dml_data, n_folds = 5, - ml_g = lrn(learner), + ml_l = lrn(learner), ml_m = lrn(learner), ml_r = lrn(learner), + ml_g = ml_g, dml_procedure = dml_procedure, - score = "partialling out") + score = score) double_mlpliv_obj$fit() theta_obj = double_mlpliv_obj$coef @@ -54,9 +74,10 @@ patrick::with_parameters_test_that("Unit tests for PLIV, callable score:", double_mlpliv_obj_score = DoubleMLPLIV$new( data = data_pliv$dml_data, n_folds = 5, - ml_g = lrn(learner), + ml_l = lrn(learner), ml_m = lrn(learner), ml_r = lrn(learner), + ml_g = ml_g, dml_procedure = dml_procedure, score = score_fct) diff --git a/tests/testthat/test-double_ml_plr.R b/tests/testthat/test-double_ml_plr.R index 9869b0d1..469ac338 100644 --- a/tests/testthat/test-double_ml_plr.R +++ b/tests/testthat/test-double_ml_plr.R @@ -30,7 +30,9 @@ patrick::with_parameters_test_that("Unit tests for PLR:", plr_hat = dml_plr(data_plr$df, y = "y", d = "d", n_folds = n_folds, - ml_g = learner_pars$ml_g$clone(), ml_m = learner_pars$ml_m$clone(), + ml_l = learner_pars$ml_l$clone(), + ml_m = learner_pars$ml_m$clone(), + ml_g = learner_pars$ml_g$clone(), dml_procedure = dml_procedure, score = score) theta = plr_hat$coef se = plr_hat$se @@ -47,13 +49,24 @@ patrick::with_parameters_test_that("Unit tests for PLR:", score = score)$boot_coef set.seed(3141) - double_mlplr_obj = DoubleMLPLR$new( - data = data_plr$dml_data, - ml_g = learner_pars$ml_g$clone(), - ml_m = learner_pars$ml_m$clone(), - dml_procedure = dml_procedure, - n_folds = n_folds, - score = score) + if (score == "partialling out") { + double_mlplr_obj = DoubleMLPLR$new( + data = data_plr$dml_data, + ml_l = learner_pars$ml_g$clone(), + ml_m = learner_pars$ml_m$clone(), + dml_procedure = dml_procedure, + n_folds = n_folds, + score = score) + } else { + double_mlplr_obj = DoubleMLPLR$new( + data = data_plr$dml_data, + ml_l = learner_pars$ml_l$clone(), + ml_m = learner_pars$ml_m$clone(), + ml_g = learner_pars$ml_g$clone(), + dml_procedure = dml_procedure, + n_folds = n_folds, + score = score) + } double_mlplr_obj$fit() theta_obj = double_mlplr_obj$coef diff --git a/tests/testthat/test-double_ml_plr_classifier.R b/tests/testthat/test-double_ml_plr_classifier.R index b453a846..5b29dad1 100644 --- a/tests/testthat/test-double_ml_plr_classifier.R +++ b/tests/testthat/test-double_ml_plr_classifier.R @@ -7,15 +7,17 @@ lgr::get_logger("mlr3")$set_threshold("warn") on_cran = !identical(Sys.getenv("NOT_CRAN"), "true") if (on_cran) { test_cases = expand.grid( - g_learner = c("regr.rpart", "classif.rpart"), + l_learner = c("regr.rpart", "classif.rpart"), m_learner = "classif.rpart", + g_learner = "regr.rpart", dml_procedure = "dml2", score = "partialling out", stringsAsFactors = FALSE) } else { test_cases = 
expand.grid( - g_learner = "regr.cv_glmnet", + l_learner = c("regr.rpart", "classif.rpart"), m_learner = "classif.cv_glmnet", + g_learner = "regr.cv_glmnet", dml_procedure = c("dml1", "dml2"), score = c("IV-type", "partialling out"), stringsAsFactors = FALSE) @@ -27,15 +29,24 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m: n_rep_boot = 498 n_folds = 3 - if (g_learner == "regr.cv_glmnet") { - ml_g = mlr3::lrn(g_learner) - ml_m = mlr3::lrn(m_learner) + ml_l = mlr3::lrn(l_learner) + ml_m = mlr3::lrn(m_learner) + ml_g = mlr3::lrn(g_learner) + + if (ml_l$task_type == "regr") { set.seed(3141) + if (score == "IV-type") { + ml_g = ml_g$clone() + } else { + ml_g = NULL + } plr_hat = dml_plr(data_irm$df, y = "y", d = "d", n_folds = n_folds, - ml_g = ml_g$clone(), ml_m = ml_m$clone(), + ml_l = ml_l$clone(), + ml_m = ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) theta = plr_hat$coef se = plr_hat$se @@ -52,10 +63,16 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m: pval = plr_hat$pval set.seed(3141) + if (score == "IV-type") { + ml_g = ml_g$clone() + } else { + ml_g = NULL + } double_mlplr_obj = DoubleMLPLR$new( data = data_irm$dml_data, - ml_g = ml_g$clone(), + ml_l = ml_l$clone(), ml_m = ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) @@ -76,12 +93,18 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m: expect_equal(pval, pval_obj, tolerance = 1e-8) # expect_equal(ci, ci_obj, tolerance = 1e-8) - } else if (g_learner == "classif.cv_glmnet") { - msg = "Invalid learner provided for ml_g: must be of class 'LearnerRegr'" + } else if (ml_l$task_type == "classif") { + msg = "Invalid learner provided for ml_l: 'learner\\$task_type' must be 'regr'" + if (score == "IV-type") { + ml_g = ml_g$clone() + } else { + ml_g = NULL + } expect_error(DoubleMLPLR$new( data = data_irm$dml_data, - ml_g = lrn(g_learner), - ml_m = lrn(m_learner), + ml_l = ml_l$clone(), + ml_m = ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score), @@ -99,7 +122,7 @@ test_that("Unit tests for exception handling of PLR with classifier for ml_m:", dml_data = double_ml_data_from_data_frame(df, y_col = "y", d_cols = "d") double_mlplr_obj = DoubleMLPLR$new( data = dml_data, - ml_g = mlr3::lrn("regr.rpart"), + ml_l = mlr3::lrn("regr.rpart"), ml_m = mlr3::lrn("classif.rpart")) msg = paste( "Assertion on 'levels\\(data\\[\\[target\\]\\])' failed: .* set \\{'0','1'\\}") @@ -112,7 +135,7 @@ test_that("Unit tests for exception handling of PLR with classifier for ml_m:", dml_data = double_ml_data_from_data_frame(df, y_col = "y", d_cols = "d") double_mlplr_obj = DoubleMLPLR$new( data = dml_data, - ml_g = mlr3::lrn("regr.rpart"), + ml_l = mlr3::lrn("regr.rpart"), ml_m = mlr3::lrn("classif.rpart")) msg = paste( "Assertion on 'levels\\(data\\[\\[target\\]\\])' failed: .* set \\{'0','1'\\}") diff --git a/tests/testthat/test-double_ml_plr_exception_handling.R b/tests/testthat/test-double_ml_plr_exception_handling.R index 2b6ed203..5de1cd25 100644 --- a/tests/testthat/test-double_ml_plr_exception_handling.R +++ b/tests/testthat/test-double_ml_plr_exception_handling.R @@ -2,7 +2,10 @@ context("Unit tests for exception handling if fit() or bootstrap() was not run y library("mlr3learners") +logger = lgr::get_logger("bbotk") +logger$set_threshold("warn") lgr::get_logger("mlr3")$set_threshold("warn") + on_cran = !identical(Sys.getenv("NOT_CRAN"), "true") 
if (on_cran) { test_cases = expand.grid( @@ -48,10 +51,16 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:", } else { msg = "Assertion on 'i' failed: Element 1 is not <= 1." } + if (score == "IV-type") { + ml_g = learner_pars$mlmethod$mlmethod_g + } else { + ml_g = NULL + } expect_error(DoubleMLPLR$new( data = data_ml, - ml_g = learner_pars$mlmethod$mlmethod_g, + ml_l = learner_pars$mlmethod$mlmethod_l, ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, n_rep = n_rep, @@ -59,10 +68,16 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:", apply_cross_fitting = apply_cross_fitting), regexp = msg) } else { + if (score == "IV-type") { + ml_g = learner_pars$mlmethod$mlmethod_g + } else { + ml_g = NULL + } double_mlplr_obj = DoubleMLPLR$new( data = data_ml, - ml_g = learner_pars$mlmethod$mlmethod_g, + ml_l = learner_pars$mlmethod$mlmethod_l, ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, n_rep = n_rep, @@ -76,11 +91,20 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:", treat_var = "d", params = learner_pars$params$params_m) - # set params for nuisance part g + # set params for nuisance part l double_mlplr_obj$set_ml_nuisance_params( - learner = "ml_g", + learner = "ml_l", treat_var = "d", - params = learner_pars$params$params_g) + params = learner_pars$params$params_l) + + if (score == "IV-type") { + # set params for nuisance part g + double_mlplr_obj$set_ml_nuisance_params( + learner = "ml_g", + treat_var = "d", + params = learner_pars$params$params_g) + } + } # currently, no warning or message printed @@ -100,6 +124,66 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:", msg = "Multiplier bootstrap has not yet been performed. First call bootstrap\\(\\) and then try confint\\(\\) again." 
expect_error(double_mlplr_obj$confint(joint = TRUE, level = 0.95), regexp = msg) + + set.seed(3141) + dml_data = make_plr_CCDDHNR2018(n_obs = 101) + ml_l = lrn("regr.ranger") + ml_m = ml_l$clone() + ml_g = ml_l$clone() + + if (score == "partialling out") { + msg = paste0( + "A learner ml_g has been provided for ", + "score = 'partialling out' but will be ignored.") + expect_warning(DoubleMLPLR$new(dml_data, + ml_l = ml_l, ml_m = ml_m, ml_g = ml_g, + score = score), + regexp = msg) + } } } ) + +test_that("Unit tests for deprecation warnings of PLR", { + set.seed(3141) + dml_data = make_plr_CCDDHNR2018(n_obs = 101) + ml_l = lrn("regr.ranger") + ml_m = ml_l$clone() + ml_g = ml_l$clone() + msg = paste0("The argument ml_g was renamed to ml_l.") + expect_warning(DoubleMLPLR$new(dml_data, ml_g = ml_g, ml_m = ml_m), + regexp = msg) + + msg = "learners ml_l and ml_g should be specified" + expect_warning(DoubleMLPLR$new(dml_data, ml_l, ml_m, + score = "IV-type"), + regexp = msg) + + dml_obj = DoubleMLPLR$new(dml_data, ml_l = ml_l, ml_m = ml_m) + + msg = paste0("Learner ml_g was renamed to ml_l.") + expect_warning(dml_obj$set_ml_nuisance_params( + "ml_g", "d", list("num.trees" = 10)), + regexp = msg) + + par_grids = list( + "ml_g" = paradox::ParamSet$new(list( + paradox::ParamInt$new("num.trees", lower = 9, upper = 10))), + "ml_m" = paradox::ParamSet$new(list( + paradox::ParamInt$new("num.trees", lower = 10, upper = 11)))) + + msg = paste0("Learner ml_g was renamed to ml_l.") + expect_warning(dml_obj$tune(par_grids), + regexp = msg) + + tune_settings = list( + n_folds_tune = 5, + rsmp_tune = mlr3::rsmp("cv", folds = 5), + measure = list(ml_g = "regr.mse", ml_m = "regr.mae"), + terminator = mlr3tuning::trm("evals", n_evals = 20), + algorithm = mlr3tuning::tnr("grid_search"), + resolution = 5) + expect_warning(dml_obj$tune(par_grids, tune_settings = tune_settings), + regexp = msg) +} +) diff --git a/tests/testthat/test-double_ml_plr_export_preds.R b/tests/testthat/test-double_ml_plr_export_preds.R index 555606c4..3929ecab 100644 --- a/tests/testthat/test-double_ml_plr_export_preds.R +++ b/tests/testthat/test-double_ml_plr_export_preds.R @@ -7,17 +7,19 @@ lgr::get_logger("mlr3")$set_threshold("warn") on_cran = !identical(Sys.getenv("NOT_CRAN"), "true") if (on_cran) { test_cases = expand.grid( - g_learner = "regr.rpart", + l_learner = "regr.rpart", m_learner = "regr.rpart", + g_learner = "regr.rpart", dml_procedure = "dml2", score = "partialling out", stringsAsFactors = FALSE) } else { test_cases = expand.grid( - g_learner = c("regr.rpart", "regr.lm"), + l_learner = c("regr.rpart", "regr.lm"), m_learner = c("regr.rpart", "regr.lm"), + g_learner = c("regr.rpart", "regr.lm"), dml_procedure = "dml2", - score = "partialling out", + score = c("partialling out", "IV-type"), stringsAsFactors = FALSE) } test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") @@ -30,10 +32,16 @@ patrick::with_parameters_test_that("Unit tests for for the export of predictions df = data_plr$df dml_data = data_plr$dml_data + if (score == "IV-type") { + ml_g = lrn(g_learner) + } else { + ml_g = NULL + } double_mlplr_obj = DoubleMLPLR$new( data = dml_data, - ml_g = lrn(g_learner), + ml_l = lrn(l_learner), ml_m = lrn(m_learner), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) @@ -44,14 +52,14 @@ patrick::with_parameters_test_that("Unit tests for for the export of predictions Xnames = names(df)[names(df) %in% c("y", "d", "z") == FALSE] indx = (names(df) %in% c(Xnames, "y")) data = df[, 
indx] - task = mlr3::TaskRegr$new(id = "ml_g", backend = data, target = "y") + task = mlr3::TaskRegr$new(id = "ml_l", backend = data, target = "y") resampling_smpls = rsmp("custom")$instantiate( task, double_mlplr_obj$smpls[[1]]$train_ids, double_mlplr_obj$smpls[[1]]$test_ids) - resampling_pred = resample(task, lrn(g_learner), resampling_smpls) - preds_g = as.data.table(resampling_pred$prediction()) - data.table::setorder(preds_g, "row_ids") + resampling_pred = resample(task, lrn(l_learner), resampling_smpls) + preds_l = as.data.table(resampling_pred$prediction()) + data.table::setorder(preds_l, "row_ids") Xnames = names(df)[names(df) %in% c("y", "d", "z") == FALSE] indx = (names(df) %in% c(Xnames, "d")) @@ -65,8 +73,36 @@ patrick::with_parameters_test_that("Unit tests for for the export of predictions preds_m = as.data.table(resampling_pred$prediction()) data.table::setorder(preds_m, "row_ids") - expect_equal(as.vector(double_mlplr_obj$predictions$ml_g), - as.vector(preds_g$response), + if (score == "IV-type") { + d = df[["d"]] + y = df[["y"]] + psi_a = -(d - preds_m$response) * (d - preds_m$response) + psi_b = (d - preds_m$response) * (y - preds_l$response) + theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE) + + data_aux = cbind(df, "y_minus_theta_d" = y - theta_initial * d) + Xnames = names(data_aux)[names(data_aux) %in% + c("y", "d", "z", "y_minus_theta_d") == FALSE] + indx = (names(data_aux) %in% c(Xnames, "y_minus_theta_d")) + data = data_aux[, indx] + task = mlr3::TaskRegr$new( + id = "ml_g", backend = data, + target = "y_minus_theta_d") + resampling_smpls = rsmp("custom")$instantiate( + task, + double_mlplr_obj$smpls[[1]]$train_ids, + double_mlplr_obj$smpls[[1]]$test_ids) + resampling_pred = resample(task, lrn(g_learner), resampling_smpls) + preds_g = as.data.table(resampling_pred$prediction()) + data.table::setorder(preds_g, "row_ids") + + expect_equal(as.vector(double_mlplr_obj$predictions$ml_g), + as.vector(preds_g$response), + tolerance = 1e-8) + } + + expect_equal(as.vector(double_mlplr_obj$predictions$ml_l), + as.vector(preds_l$response), tolerance = 1e-8) expect_equal(as.vector(double_mlplr_obj$predictions$ml_m), diff --git a/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R b/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R index 2c93b153..eff13312 100644 --- a/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R +++ b/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R @@ -29,10 +29,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:", params = list("cp" = 0.01, "minsplit" = 20) set.seed(123) + if (score == "IV-type") { + ml_g = learner_name + } else { + ml_g = NULL + } double_mlplr = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = learner_name, + ml_l = learner_name, ml_m = learner_name, + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) @@ -43,12 +49,20 @@ patrick::with_parameters_test_that("Unit tests for PLR:", treat_var = "d", params = params) - # set params for nuisance part g + # set params for nuisance part l double_mlplr$set_ml_nuisance_params( - learner = "ml_g", + learner = "ml_l", treat_var = "d", params = params) + if (score == "IV-type") { + # set params for nuisance part g + double_mlplr$set_ml_nuisance_params( + learner = "ml_g", + treat_var = "d", + params = params) + } + double_mlplr$fit() theta = double_mlplr$coef se = double_mlplr$se @@ -60,10 +74,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:", set.seed(123) loaded_learner = mlr3::lrn("regr.rpart", "cp" = 
0.01, "minsplit" = 20) + if (score == "IV-type") { + ml_g = loaded_learner + } else { + ml_g = NULL + } double_mlplr_loaded = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = loaded_learner, + ml_l = loaded_learner, ml_m = loaded_learner, + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) @@ -79,10 +99,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:", set.seed(123) semiloaded_learner = mlr3::lrn("regr.rpart") + if (score == "IV-type") { + ml_g = semiloaded_learner + } else { + ml_g = NULL + } double_mlplr_semiloaded = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = semiloaded_learner, + ml_l = semiloaded_learner, ml_m = semiloaded_learner, + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) @@ -92,12 +118,20 @@ patrick::with_parameters_test_that("Unit tests for PLR:", treat_var = "d", params = params) - # set params for nuisance part g + # set params for nuisance part l double_mlplr_semiloaded$set_ml_nuisance_params( - learner = "ml_g", + learner = "ml_l", treat_var = "d", params = params) + if (score == "IV-type") { + # set params for nuisance part g + double_mlplr_semiloaded$set_ml_nuisance_params( + learner = "ml_g", + treat_var = "d", + params = params) + } + double_mlplr_semiloaded$fit() theta_semiloaded = double_mlplr_semiloaded$coef se_semiloaded = double_mlplr_semiloaded$se diff --git a/tests/testthat/test-double_ml_plr_multitreat.R b/tests/testthat/test-double_ml_plr_multitreat.R index b5b96447..0a31d76f 100644 --- a/tests/testthat/test-double_ml_plr_multitreat.R +++ b/tests/testthat/test-double_ml_plr_multitreat.R @@ -29,11 +29,17 @@ patrick::with_parameters_test_that("Unit tests for PLR:", n_folds = 5 set.seed(3141) + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } plr_hat = dml_plr_multitreat(data_plr_multi, y = "y", d = c("d1", "d2", "d3"), n_folds = n_folds, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, score = score) theta = plr_hat$coef se = plr_hat$se @@ -56,9 +62,15 @@ patrick::with_parameters_test_that("Unit tests for PLR:", y_col = "y", d_cols = c("d1", "d2", "d3"), x_cols = Xnames) + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } double_mlplr_obj = DoubleMLPLR$new(data_ml, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) diff --git a/tests/testthat/test-double_ml_plr_nocrossfit.R b/tests/testthat/test-double_ml_plr_nocrossfit.R index c46393de..9279ca9b 100644 --- a/tests/testthat/test-double_ml_plr_nocrossfit.R +++ b/tests/testthat/test-double_ml_plr_nocrossfit.R @@ -43,11 +43,17 @@ patrick::with_parameters_test_that("Unit tests for PLR:", train_ids = list(seq(nrow(df))), test_ids = list(seq(nrow(df))))) } + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } plr_hat = dml_plr(df, y = "y", d = "d", n_folds = 1, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, score = score, smpls = smpls) theta = plr_hat$coef @@ -56,10 +62,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:", pval = plr_hat$pval set.seed(3141) + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } double_mlplr_obj = 
DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score, @@ -74,10 +86,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:", if (n_folds == 2) { + if (score == "IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } dml_plr_obj_external = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score, diff --git a/tests/testthat/test-double_ml_plr_nonorth.R b/tests/testthat/test-double_ml_plr_nonorth.R index 7d6da79d..cdada942 100644 --- a/tests/testthat/test-double_ml_plr_nonorth.R +++ b/tests/testthat/test-double_ml_plr_nonorth.R @@ -4,7 +4,21 @@ library("mlr3learners") lgr::get_logger("mlr3")$set_threshold("warn") -non_orth_score = function(y, d, g_hat, m_hat, smpls) { +non_orth_score_w_g = function(y, d, l_hat, m_hat, g_hat, smpls) { + u_hat = y - g_hat + psi_a = -1 * d * d + psi_b = d * u_hat + psis = list(psi_a = psi_a, psi_b = psi_b) + return(psis) +} + +non_orth_score_w_l = function(y, d, l_hat, m_hat, g_hat, smpls) { + + p_a = -(d - m_hat) * (d - m_hat) + p_b = (d - m_hat) * (y - l_hat) + theta_initial = -mean(p_b, na.rm = TRUE) / mean(p_a, na.rm = TRUE) + g_hat = l_hat - theta_initial * m_hat + u_hat = y - g_hat psi_a = -1 * d * d psi_b = d * u_hat @@ -17,7 +31,7 @@ if (on_cran) { test_cases = expand.grid( learner = "regr.lm", dml_procedure = "dml1", - score = c(non_orth_score), + which_score = c("non_orth_score_w_g"), n_folds = c(3), n_rep = c(2), stringsAsFactors = FALSE) @@ -25,7 +39,9 @@ if (on_cran) { test_cases = expand.grid( learner = c("regr.lm", "regr.cv_glmnet"), dml_procedure = c("dml1", "dml2"), - score = c(non_orth_score), + which_score = c( + "non_orth_score_w_g", + "non_orth_score_w_l"), n_folds = c(2, 3), n_rep = c(1, 2), stringsAsFactors = FALSE) @@ -35,12 +51,22 @@ test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_") patrick::with_parameters_test_that("Unit tests for PLR:", .cases = test_cases, { learner_pars = get_default_mlmethod_plr(learner) + + if (which_score == "non_orth_score_w_g") { + score = non_orth_score_w_g + ml_g = learner_pars$ml_g$clone() + } else if (which_score == "non_orth_score_w_l") { + score = non_orth_score_w_l + ml_g = NULL + } + n_rep_boot = 498 set.seed(3141) double_mlplr_obj = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score) @@ -63,8 +89,9 @@ patrick::with_parameters_test_that("Unit tests for PLR:", if (n_folds == 2 & n_rep == 1) { double_mlplr_nocf = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score, diff --git a/tests/testthat/test-double_ml_plr_p_adjust.R b/tests/testthat/test-double_ml_plr_p_adjust.R index e98211dc..ae021215 100644 --- a/tests/testthat/test-double_ml_plr_p_adjust.R +++ b/tests/testthat/test-double_ml_plr_p_adjust.R @@ -53,9 +53,15 @@ patrick::with_parameters_test_that("Unit tests for PLR:", x_cols = colnames(X)[(k + 1):p], y_col = "y", d_cols = colnames(X)[1:k]) + if (score == 
"IV-type") { + ml_g = learner_pars$ml_g$clone() + } else { + ml_g = NULL + } double_mlplr_obj = DoubleMLPLR$new(data_ml, - ml_g = learner_pars$ml_g$clone(), + ml_l = learner_pars$ml_l$clone(), ml_m = learner_pars$ml_m$clone(), + ml_g = ml_g, dml_procedure = dml_procedure, n_folds = n_folds, score = score, diff --git a/tests/testthat/test-double_ml_plr_parameter_passing.R b/tests/testthat/test-double_ml_plr_parameter_passing.R index b4fe9641..4c127050 100644 --- a/tests/testthat/test-double_ml_plr_parameter_passing.R +++ b/tests/testthat/test-double_ml_plr_parameter_passing.R @@ -35,17 +35,25 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (oop n_rep = 3 learner_pars = get_default_mlmethod_plr(learner) - params_g = rep(list(learner_pars$params$params_g), 2) + params_l = rep(list(learner_pars$params$params_l), 2) params_m = rep(list(learner_pars$params$params_m), 2) + params_g = rep(list(learner_pars$params$params_g), 2) set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } plr_hat = dml_plr_multitreat(data_plr_multi, y = "y", d = c("d1", "d2"), n_folds = n_folds, n_rep = n_rep, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), - params_g = params_g, + ml_g = ml_g, + params_l = params_l, params_m = params_m, + params_g = params_g, dml_procedure = dml_procedure, score = score) theta = plr_hat$coef se = plr_hat$se @@ -65,26 +73,40 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (oop d_cols = c("d1", "d2"), x_cols = Xnames) set.seed(3141) + if (score == "IV-type") { + ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g) + } else { + ml_g = NULL + } double_mlplr_obj = DoubleMLPLR$new(data_ml, n_folds = n_folds, - ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g), + ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l), ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m), + ml_g = ml_g, dml_procedure = dml_procedure, score = score, n_rep = n_rep) double_mlplr_obj$set_ml_nuisance_params( - treat_var = "d1", learner = "ml_g", - params = learner_pars$params$params_g) + treat_var = "d1", learner = "ml_l", + params = learner_pars$params$params_l) double_mlplr_obj$set_ml_nuisance_params( - treat_var = "d2", learner = "ml_g", - params = learner_pars$params$params_g) + treat_var = "d2", learner = "ml_l", + params = learner_pars$params$params_l) double_mlplr_obj$set_ml_nuisance_params( treat_var = "d1", learner = "ml_m", params = learner_pars$params$params_m) double_mlplr_obj$set_ml_nuisance_params( treat_var = "d2", learner = "ml_m", params = learner_pars$params$params_m) + if (score == "IV-type") { + double_mlplr_obj$set_ml_nuisance_params( + treat_var = "d1", learner = "ml_g", + params = learner_pars$params$params_g) + double_mlplr_obj$set_ml_nuisance_params( + treat_var = "d2", learner = "ml_g", + params = learner_pars$params$params_g) + } double_mlplr_obj$fit() @@ -108,8 +130,9 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no n_folds = 2 learner_pars = get_default_mlmethod_plr(learner) - params_g = rep(list(learner_pars$params$params_g), 2) + params_l = rep(list(learner_pars$params$params_l), 2) params_m = rep(list(learner_pars$params$params_m), 2) + params_g = rep(list(learner_pars$params$params_g), 2) # Passing for non-cross-fitting case set.seed(3141) @@ -119,13 +142,20 @@ patrick::with_parameters_test_that("Unit tests for 
diff --git a/tests/testthat/test-double_ml_plr_parameter_passing.R b/tests/testthat/test-double_ml_plr_parameter_passing.R
index b4fe9641..4c127050 100644
--- a/tests/testthat/test-double_ml_plr_parameter_passing.R
+++ b/tests/testthat/test-double_ml_plr_parameter_passing.R
@@ -35,17 +35,25 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (oop
  n_rep = 3
  learner_pars = get_default_mlmethod_plr(learner)
 
-  params_g = rep(list(learner_pars$params$params_g), 2)
+  params_l = rep(list(learner_pars$params$params_l), 2)
  params_m = rep(list(learner_pars$params$params_m), 2)
+  params_g = rep(list(learner_pars$params$params_g), 2)
 
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+  } else {
+    ml_g = NULL
+  }
  plr_hat = dml_plr_multitreat(data_plr_multi,
    y = "y", d = c("d1", "d2"),
    n_folds = n_folds, n_rep = n_rep,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
    ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
-    params_g = params_g,
+    ml_g = ml_g,
+    params_l = params_l,
    params_m = params_m,
+    params_g = params_g,
    dml_procedure = dml_procedure,
    score = score)
  theta = plr_hat$coef
  se = plr_hat$se
@@ -65,26 +73,40 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (oop
    d_cols = c("d1", "d2"),
    x_cols = Xnames)
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+  } else {
+    ml_g = NULL
+  }
  double_mlplr_obj = DoubleMLPLR$new(data_ml,
    n_folds = n_folds,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
    ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+    ml_g = ml_g,
    dml_procedure = dml_procedure,
    score = score,
    n_rep = n_rep)
 
  double_mlplr_obj$set_ml_nuisance_params(
-    treat_var = "d1", learner = "ml_g",
-    params = learner_pars$params$params_g)
+    treat_var = "d1", learner = "ml_l",
+    params = learner_pars$params$params_l)
  double_mlplr_obj$set_ml_nuisance_params(
-    treat_var = "d2", learner = "ml_g",
-    params = learner_pars$params$params_g)
+    treat_var = "d2", learner = "ml_l",
+    params = learner_pars$params$params_l)
  double_mlplr_obj$set_ml_nuisance_params(
    treat_var = "d1", learner = "ml_m",
    params = learner_pars$params$params_m)
  double_mlplr_obj$set_ml_nuisance_params(
    treat_var = "d2", learner = "ml_m",
    params = learner_pars$params$params_m)
+  if (score == "IV-type") {
+    double_mlplr_obj$set_ml_nuisance_params(
+      treat_var = "d1", learner = "ml_g",
+      params = learner_pars$params$params_g)
+    double_mlplr_obj$set_ml_nuisance_params(
+      treat_var = "d2", learner = "ml_g",
+      params = learner_pars$params$params_g)
+  }
 
  double_mlplr_obj$fit()
 
@@ -108,8 +130,9 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no
  n_folds = 2
  learner_pars = get_default_mlmethod_plr(learner)
 
-  params_g = rep(list(learner_pars$params$params_g), 2)
+  params_l = rep(list(learner_pars$params$params_l), 2)
  params_m = rep(list(learner_pars$params$params_m), 2)
+  params_g = rep(list(learner_pars$params$params_g), 2)
 
  # Passing for non-cross-fitting case
  set.seed(3141)
@@ -119,13 +142,20 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no
  test_ids = list(my_sampling$test_set(1))
  smpls = list(list(train_ids = train_ids, test_ids = test_ids))
 
+  if (score == "IV-type") {
+    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+  } else {
+    ml_g = NULL
+  }
  plr_hat = dml_plr_multitreat(data_plr_multi,
    y = "y", d = c("d1", "d2"),
    n_folds = 1,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
    ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
-    params_g = params_g,
+    ml_g = ml_g,
+    params_l = params_l,
    params_m = params_m,
+    params_g = params_g,
    dml_procedure = dml_procedure,
    score = score, smpls = smpls)
  theta = plr_hat$coef
@@ -137,26 +167,40 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no
    d_cols = c("d1", "d2"),
    x_cols = Xnames)
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+  } else {
+    ml_g = NULL
+  }
  double_mlplr_obj_nocf = DoubleMLPLR$new(data_ml,
    n_folds = n_folds,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
    ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+    ml_g = ml_g,
    dml_procedure = dml_procedure,
    score = score,
    apply_cross_fitting = FALSE)
 
  double_mlplr_obj_nocf$set_ml_nuisance_params(
-    treat_var = "d1", learner = "ml_g",
-    params = learner_pars$params$params_g)
+    treat_var = "d1", learner = "ml_l",
+    params = learner_pars$params$params_l)
  double_mlplr_obj_nocf$set_ml_nuisance_params(
-    treat_var = "d2", learner = "ml_g",
-    params = learner_pars$params$params_g)
+    treat_var = "d2", learner = "ml_l",
+    params = learner_pars$params$params_l)
  double_mlplr_obj_nocf$set_ml_nuisance_params(
    treat_var = "d1", learner = "ml_m",
    params = learner_pars$params$params_m)
  double_mlplr_obj_nocf$set_ml_nuisance_params(
    treat_var = "d2", learner = "ml_m",
    params = learner_pars$params$params_m)
+  if (score == "IV-type") {
+    double_mlplr_obj_nocf$set_ml_nuisance_params(
+      treat_var = "d1", learner = "ml_g",
+      params = learner_pars$params$params_g)
+    double_mlplr_obj_nocf$set_ml_nuisance_params(
+      treat_var = "d2", learner = "ml_g",
+      params = learner_pars$params$params_g)
+  }
 
  double_mlplr_obj_nocf$fit()
 
@@ -182,50 +226,71 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (fol
    d_cols = c("d1", "d2"),
    x_cols = Xnames)
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+  } else {
+    ml_g = NULL
+  }
  double_mlplr_obj = DoubleMLPLR$new(data_ml,
    n_folds = n_folds,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
    ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+    ml_g = ml_g,
    dml_procedure = dml_procedure,
    score = score,
    n_rep = n_rep)
 
  double_mlplr_obj$set_ml_nuisance_params(
-    treat_var = "d1", learner = "ml_g",
-    params = learner_pars$params$params_g)
+    treat_var = "d1", learner = "ml_l",
+    params = learner_pars$params$params_l)
  double_mlplr_obj$set_ml_nuisance_params(
-    treat_var = "d2", learner = "ml_g",
-    params = learner_pars$params$params_g)
+    treat_var = "d2", learner = "ml_l",
+    params = learner_pars$params$params_l)
  double_mlplr_obj$set_ml_nuisance_params(
    treat_var = "d1", learner = "ml_m",
    params = learner_pars$params$params_m)
  double_mlplr_obj$set_ml_nuisance_params(
    treat_var = "d2", learner = "ml_m",
    params = learner_pars$params$params_m)
+  if (score == "IV-type") {
+    double_mlplr_obj$set_ml_nuisance_params(
+      treat_var = "d1", learner = "ml_g",
+      params = learner_pars$params$params_g)
+    double_mlplr_obj$set_ml_nuisance_params(
+      treat_var = "d2", learner = "ml_g",
+      params = learner_pars$params$params_g)
+  }
 
  double_mlplr_obj$fit()
  theta = double_mlplr_obj$coef
  se = double_mlplr_obj$se
 
-  params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
+  params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep)
  params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep)
+  params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
 
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+  } else {
+    ml_g = NULL
+  }
  dml_plr_fold_wise = DoubleMLPLR$new(data_ml,
    n_folds = n_folds,
-    ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+    ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
    ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+    ml_g = ml_g,
    dml_procedure = dml_procedure,
    score = score,
    n_rep = n_rep)
 
  dml_plr_fold_wise$set_ml_nuisance_params(
-    treat_var = "d1", learner = "ml_g",
-    params = params_g_fold_wise,
+    treat_var = "d1", learner = "ml_l",
+    params = params_l_fold_wise,
    set_fold_specific = TRUE)
  dml_plr_fold_wise$set_ml_nuisance_params(
-    treat_var = "d2", learner = "ml_g",
-    params = params_g_fold_wise,
+    treat_var = "d2", learner = "ml_l",
+    params = params_l_fold_wise,
    set_fold_specific = TRUE)
  dml_plr_fold_wise$set_ml_nuisance_params(
    treat_var = "d1", learner = "ml_m",
@@ -235,6 +300,16 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (fol
    treat_var = "d2", learner = "ml_m",
    params = params_m_fold_wise,
    set_fold_specific = TRUE)
+  if (score == "IV-type") {
+    dml_plr_fold_wise$set_ml_nuisance_params(
+      treat_var = "d1", learner = "ml_g",
+      params = params_g_fold_wise,
+      set_fold_specific = TRUE)
+    dml_plr_fold_wise$set_ml_nuisance_params(
+      treat_var = "d2", learner = "ml_g",
+      params = params_g_fold_wise,
+      set_fold_specific = TRUE)
+  }
 
  dml_plr_fold_wise$fit()
  theta_fold_wise = dml_plr_fold_wise$coef
@@ -251,8 +326,9 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (def
  n_folds = 2
  n_rep = 3
 
-  params_g = list(cp = 0.01, minsplit = 20) # this are defaults
+  params_l = list(cp = 0.01, minsplit = 20) # these are the defaults
  params_m = list(cp = 0.01, minsplit = 20) # this are defaults
+  params_g = list(cp = 0.01, minsplit = 20) # these are the defaults
 
  Xnames = names(data_plr_multi)[names(data_plr_multi) %in% c("y", "d1", "d2", "z") == FALSE]
  data_ml = double_ml_data_from_data_frame(data_plr_multi,
@@ -260,10 +336,16 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (def
    d_cols = c("d1", "d2"),
    x_cols = Xnames)
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = lrn("regr.rpart")
+  } else {
+    ml_g = NULL
+  }
  dml_plr_default = DoubleMLPLR$new(data_ml,
    n_folds = n_folds,
-    ml_g = lrn("regr.rpart"),
+    ml_l = lrn("regr.rpart"),
    ml_m = lrn("regr.rpart"),
+    ml_g = ml_g,
    dml_procedure = dml_procedure,
    score = score,
    n_rep = n_rep)
@@ -273,25 +355,39 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (def
  se_default = dml_plr_default$se
 
  set.seed(3141)
+  if (score == "IV-type") {
+    ml_g = lrn("regr.rpart")
+  } else {
+    ml_g = NULL
+  }
  double_mlplr_obj = DoubleMLPLR$new(data_ml,
    n_folds = n_folds,
-    ml_g = lrn("regr.rpart"),
+    ml_l = lrn("regr.rpart"),
    ml_m = lrn("regr.rpart"),
+    ml_g = ml_g,
    dml_procedure = dml_procedure,
    score = score,
    n_rep = n_rep)
 
  double_mlplr_obj$set_ml_nuisance_params(
-    treat_var = "d1", learner = "ml_g",
-    params = params_g)
+    treat_var = "d1", learner = "ml_l",
+    params = params_l)
  double_mlplr_obj$set_ml_nuisance_params(
-    treat_var = "d2", learner = "ml_g",
-    params = params_g)
+    treat_var = "d2", learner = "ml_l",
+    params = params_l)
  double_mlplr_obj$set_ml_nuisance_params(
    treat_var = "d1", learner = "ml_m",
    params = params_m)
  double_mlplr_obj$set_ml_nuisance_params(
    treat_var = "d2", learner = "ml_m",
    params = params_m)
+  if (score == "IV-type") {
+    double_mlplr_obj$set_ml_nuisance_params(
+      treat_var = "d1", learner = "ml_g",
+      params = params_g)
+    double_mlplr_obj$set_ml_nuisance_params(
+      treat_var = "d2", learner = "ml_g",
+      params = params_g)
+  }
 
  double_mlplr_obj$fit()
  theta = double_mlplr_obj$coef
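The pattern repeated throughout these parameter-passing tests is the API consequence of the renaming: parameters for the `ml_g` nuisance can only be set when that learner exists, i.e. under the `"IV-type"` score; with `"partialling out"` the object has no `ml_g` slot. A condensed stand-alone sketch of that contract (illustrative only; the rpart learner and parameter values are assumptions):

library(DoubleML)
library(mlr3)
set.seed(3141)
dml_data = make_plr_CCDDHNR2018(n_obs = 100)
score = "IV-type"
# ml_g is only constructed for the IV-type score; otherwise it stays NULL.
ml_g = if (score == "IV-type") lrn("regr.rpart") else NULL
dml_plr = DoubleMLPLR$new(dml_data,
  ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"), ml_g = ml_g,
  score = score, n_folds = 2)
dml_plr$set_ml_nuisance_params(
  treat_var = "d", learner = "ml_l", params = list(cp = 0.01, minsplit = 20))
dml_plr$set_ml_nuisance_params(
  treat_var = "d", learner = "ml_m", params = list(cp = 0.01, minsplit = 20))
if (score == "IV-type") {
  dml_plr$set_ml_nuisance_params(
    treat_var = "d", learner = "ml_g", params = list(cp = 0.01, minsplit = 20))
}
dml_plr$fit()
dml_plr$summary()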
diff --git a/tests/testthat/test-double_ml_plr_rep_cross_fit.R b/tests/testthat/test-double_ml_plr_rep_cross_fit.R
index da73d7f4..9e763a52 100644
--- a/tests/testthat/test-double_ml_plr_rep_cross_fit.R
+++ b/tests/testthat/test-double_ml_plr_rep_cross_fit.R
@@ -30,11 +30,17 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
    set.seed(3141)
    n_folds = 5
+    if (score == "IV-type") {
+      ml_g = learner_pars$ml_g$clone()
+    } else {
+      ml_g = NULL
+    }
    plr_hat = dml_plr(data_plr$df,
      y = "y", d = "d",
      n_folds = n_folds, n_rep = n_rep,
-      ml_g = learner_pars$ml_g$clone(),
+      ml_l = learner_pars$ml_l$clone(),
      ml_m = learner_pars$ml_m$clone(),
+      ml_g = ml_g,
      dml_procedure = dml_procedure,
      score = score)
    theta = plr_hat$coef
    se = plr_hat$se
@@ -52,10 +58,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
      score = score)$boot_coef
 
    set.seed(3141)
+    if (score == "IV-type") {
+      ml_g = learner_pars$ml_g$clone()
+    } else {
+      ml_g = NULL
+    }
    double_mlplr_obj = DoubleMLPLR$new(
      data = data_plr$dml_data,
-      ml_g = learner_pars$ml_g$clone(),
+      ml_l = learner_pars$ml_l$clone(),
      ml_m = learner_pars$ml_m$clone(),
+      ml_g = ml_g,
      dml_procedure = dml_procedure,
      n_folds = n_folds,
      score = score,
diff --git a/tests/testthat/test-double_ml_plr_set_samples.R b/tests/testthat/test-double_ml_plr_set_samples.R
index 094a85ca..d8bebafc 100644
--- a/tests/testthat/test-double_ml_plr_set_samples.R
+++ b/tests/testthat/test-double_ml_plr_set_samples.R
@@ -35,9 +35,15 @@ patrick::with_parameters_test_that("PLR with external sample provision:",
      y_col = "y",
      d_cols = "d",
      x_cols = Xnames)
+    if (score == "IV-type") {
+      ml_g = learner_pars$ml_g$clone()
+    } else {
+      ml_g = NULL
+    }
    double_mlplr_obj = DoubleMLPLR$new(data_ml,
-      ml_g = learner_pars$ml_g$clone(),
+      ml_l = learner_pars$ml_l$clone(),
      ml_m = learner_pars$ml_m$clone(),
+      ml_g = ml_g,
      dml_procedure = dml_procedure,
      n_folds = n_folds,
      score = score,
@@ -52,9 +58,15 @@ patrick::with_parameters_test_that("PLR with external sample provision:",
    # External sample provision
    SAMPLES = double_mlplr_obj$smpls
+    if (score == "IV-type") {
+      ml_g = learner_pars$ml_g$clone()
+    } else {
+      ml_g = NULL
+    }
    double_mlplr_obj_external = DoubleMLPLR$new(data_ml,
-      ml_g = learner_pars$ml_g$clone(),
+      ml_l = learner_pars$ml_l$clone(),
      ml_m = learner_pars$ml_m$clone(),
+      ml_g = ml_g,
      dml_procedure = dml_procedure,
      score = score,
      draw_sample_splitting = FALSE)
diff --git a/tests/testthat/test-double_ml_plr_tuning.R b/tests/testthat/test-double_ml_plr_tuning.R
index 60907783..2a97ec9c 100644
--- a/tests/testthat/test-double_ml_plr_tuning.R
+++ b/tests/testthat/test-double_ml_plr_tuning.R
@@ -56,10 +56,16 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLR:",
        y_col = "y",
        d_cols = c("d1", "d2"),
        x_cols = Xnames)
    }
+    if (score == "IV-type") {
+      ml_g = learner
+    } else {
+      ml_g = NULL
+    }
    double_mlplr_obj_tuned = DoubleMLPLR$new(data_ml,
      n_folds = n_folds,
-      ml_g = learner,
+      ml_l = learner,
      ml_m = m_learner,
+      ml_g = ml_g,
      dml_procedure = dml_procedure,
      score = score,
      n_rep = n_rep)
@@ -73,12 +79,18 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLR:",
      resolution = 5)
 
    param_grid = list(
-      "ml_g" = paradox::ParamSet$new(list(
-        paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
+      "ml_l" = paradox::ParamSet$new(list(
+        paradox::ParamDbl$new("cp", lower = 0.02, upper = 0.03),
        paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
      "ml_m" = paradox::ParamSet$new(list(
-        paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
-        paradox::ParamInt$new("minsplit", lower = 1, upper = 2))))
+        paradox::ParamDbl$new("cp", lower = 0.03, upper = 0.04),
+        paradox::ParamInt$new("minsplit", lower = 2, upper = 3))))
+
+    if (score == "IV-type") {
+      param_grid[["ml_g"]] = paradox::ParamSet$new(list(
+        paradox::ParamDbl$new("cp", lower = 0.015, upper = 0.025),
+        paradox::ParamInt$new("minsplit", lower = 3, upper = 4)))
+    }
 
    double_mlplr_obj_tuned$tune(param_set = param_grid, tune_on_folds = tune_on_folds,
      tune_settings = tune_sets)
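Tuning follows the same rule: `ml_g` gets its own entry in `param_grid` only under the `"IV-type"` score. A minimal end-to-end sketch (illustrative only; the grid bounds mirror the test above, and the tuner settings are left at their package defaults rather than the test's `tune_sets`):

library(DoubleML)
library(mlr3)
library(mlr3tuning)
set.seed(3141)
dml_data = make_plr_CCDDHNR2018(n_obs = 100)
dml_plr = DoubleMLPLR$new(dml_data,
  ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"), ml_g = lrn("regr.rpart"),
  score = "IV-type", n_folds = 2)
param_grid = list(
  "ml_l" = paradox::ParamSet$new(list(
    paradox::ParamDbl$new("cp", lower = 0.02, upper = 0.03),
    paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
  "ml_m" = paradox::ParamSet$new(list(
    paradox::ParamDbl$new("cp", lower = 0.03, upper = 0.04),
    paradox::ParamInt$new("minsplit", lower = 2, upper = 3))),
  "ml_g" = paradox::ParamSet$new(list(
    paradox::ParamDbl$new("cp", lower = 0.015, upper = 0.025),
    paradox::ParamInt$new("minsplit", lower = 3, upper = 4))))
dml_plr$tune(param_set = param_grid)  # default tune_settings
dml_plr$fit()
dml_plr$summary()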
c("d1", "d2"), x_cols = Xnames) } + if (score == "IV-type") { + ml_g = learner + } else { + ml_g = NULL + } double_mlplr_obj_tuned = DoubleMLPLR$new(data_ml, n_folds = n_folds, - ml_g = learner, + ml_l = learner, ml_m = m_learner, + ml_g = ml_g, dml_procedure = dml_procedure, score = score, n_rep = n_rep) @@ -73,12 +79,18 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLR:", resolution = 5) param_grid = list( - "ml_g" = paradox::ParamSet$new(list( - paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), + "ml_l" = paradox::ParamSet$new(list( + paradox::ParamDbl$new("cp", lower = 0.02, upper = 0.03), paradox::ParamInt$new("minsplit", lower = 1, upper = 2))), "ml_m" = paradox::ParamSet$new(list( - paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02), - paradox::ParamInt$new("minsplit", lower = 1, upper = 2)))) + paradox::ParamDbl$new("cp", lower = 0.03, upper = 0.04), + paradox::ParamInt$new("minsplit", lower = 2, upper = 3)))) + + if (score == "IV-type") { + param_grid[["ml_g"]] = paradox::ParamSet$new(list( + paradox::ParamDbl$new("cp", lower = 0.015, upper = 0.025), + paradox::ParamInt$new("minsplit", lower = 3, upper = 4))) + } double_mlplr_obj_tuned$tune(param_set = param_grid, tune_on_folds = tune_on_folds, tune_settings = tune_sets) diff --git a/tests/testthat/test-double_ml_plr_user_score.R b/tests/testthat/test-double_ml_plr_user_score.R index 187442c5..586eaaa8 100644 --- a/tests/testthat/test-double_ml_plr_user_score.R +++ b/tests/testthat/test-double_ml_plr_user_score.R @@ -4,9 +4,9 @@ library("mlr3learners") lgr::get_logger("mlr3")$set_threshold("warn") -score_fct = function(y, d, g_hat, m_hat, smpls) { +score_fct = function(y, d, l_hat, m_hat, g_hat, smpls) { v_hat = d - m_hat - u_hat = y - g_hat + u_hat = y - l_hat v_hatd = v_hat * d psi_a = -v_hat * v_hat psi_b = v_hat * u_hat @@ -39,7 +39,7 @@ patrick::with_parameters_test_that("Unit tests for PLR, callable score:", double_mlplr_obj = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = lrn(learner), + ml_l = lrn(learner), ml_m = lrn(learner), dml_procedure = dml_procedure, n_folds = n_folds, @@ -55,7 +55,7 @@ patrick::with_parameters_test_that("Unit tests for PLR, callable score:", set.seed(3141) double_mlplr_obj_score = DoubleMLPLR$new( data = data_plr$dml_data, - ml_g = lrn(learner), + ml_l = lrn(learner), ml_m = lrn(learner), dml_procedure = dml_procedure, n_folds = n_folds, diff --git a/tests/testthat/test-double_ml_print.R b/tests/testthat/test-double_ml_print.R index 44af8081..60e7452b 100644 --- a/tests/testthat/test-double_ml_print.R +++ b/tests/testthat/test-double_ml_print.R @@ -6,8 +6,8 @@ set.seed(3141) dml_data = make_plr_CCDDHNR2018(n_obs = 100) dml_cluster_data = make_pliv_multiway_cluster_CKMS2021(N = 10, M = 10, dim_X = 5) -ml_g = ml_m = ml_r = "regr.rpart" -dml_plr = DoubleMLPLR$new(dml_data, ml_g, ml_m, n_folds = 2) +ml_l = ml_g = ml_m = ml_r = "regr.rpart" +dml_plr = DoubleMLPLR$new(dml_data, ml_l, ml_m, n_folds = 2) dml_pliv = DoubleMLPLIV$new(dml_cluster_data, ml_g, ml_m, ml_r, n_folds = 2) dml_plr$fit() dml_pliv$fit() diff --git a/tests/testthat/test-double_ml_set_sample_splitting.R b/tests/testthat/test-double_ml_set_sample_splitting.R index e17b9885..41c8db3c 100644 --- a/tests/testthat/test-double_ml_set_sample_splitting.R +++ b/tests/testthat/test-double_ml_set_sample_splitting.R @@ -2,9 +2,9 @@ context("Unit tests for the method set_sample_splitting of class DoubleML") set.seed(3141) dml_data = make_plr_CCDDHNR2018(n_obs = 10) -ml_g = lrn("regr.ranger") -ml_m = ml_g$clone() 
-dml_plr = DoubleMLPLR$new(dml_data, ml_g, ml_m, n_folds = 7, n_rep = 8)
+ml_l = lrn("regr.ranger")
+ml_m = ml_l$clone()
+dml_plr = DoubleMLPLR$new(dml_data, ml_l, ml_m, n_folds = 7, n_rep = 8)
 
 test_that("Unit tests for the method set_sample_splitting of class DoubleML", {
 
@@ -165,37 +165,37 @@ assert_resampling_pars = function(dml_obj0, dml_obj1) {
 test_that("Unit tests for the method set_sample_splitting of class DoubleML (draw vs set)", {
   set.seed(3141)
-  dml_plr_set = DoubleMLPLR$new(dml_data, ml_g, ml_m, n_folds = 7, n_rep = 8)
+  dml_plr_set = DoubleMLPLR$new(dml_data, ml_l, ml_m, n_folds = 7, n_rep = 8)
 
-  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
    n_folds = 1, n_rep = 1, apply_cross_fitting = FALSE)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
 
-  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
    n_folds = 2, n_rep = 1, apply_cross_fitting = FALSE)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
 
-  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
    n_folds = 2, n_rep = 1, apply_cross_fitting = TRUE)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
 
-  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
    n_folds = 5, n_rep = 1, apply_cross_fitting = TRUE)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
 
-  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+  dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
    n_folds = 5, n_rep = 3, apply_cross_fitting = TRUE)
  dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
  assert_resampling_pars(dml_plr_drawn, dml_plr_set)
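The net effect of the signature change for user-supplied scores (see test-double_ml_plr_user_score.R above) in one self-contained snippet. This is an illustrative sketch, not part of the patch: the callable now receives `l_hat` and `g_hat` side by side, and a partialling-out moment uses only `l_hat` and `m_hat`, so `g_hat` may be `NULL` when no `ml_g` learner is passed.

library(DoubleML)
library(mlr3)
set.seed(3141)

# Post-patch callable-score signature: l_hat before m_hat, g_hat after smpls'
# predecessors. This reimplements the "partialling out" moment conditions;
# g_hat is unused here (and NULL, since no ml_g learner is supplied).
partialling_out_score = function(y, d, l_hat, m_hat, g_hat, smpls) {
  u_hat = y - l_hat
  v_hat = d - m_hat
  list(psi_a = -v_hat * v_hat, psi_b = v_hat * u_hat)
}

dml_data = make_plr_CCDDHNR2018(n_obs = 100)
dml_plr = DoubleMLPLR$new(dml_data,
  ml_l = lrn("regr.rpart"), ml_m = lrn("regr.rpart"),
  score = partialling_out_score, n_folds = 2)
dml_plr$fit()
dml_plr$summary()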