}}\preformatted{DoubleMLPLIV$new(
data,
- ml_g,
+ ml_l,
ml_m,
ml_r,
+ ml_g = NULL,
partialX = TRUE,
partialZ = FALSE,
n_folds = 5,
@@ -145,7 +146,7 @@ Creates a new instance of this R6 class.
The \code{DoubleMLData} object providing the data and specifying the variables
of the causal model.}
-\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
+\item{\code{ml_l}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr
A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is
available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its
@@ -156,7 +157,7 @@ Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field
\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly
be passed with specified parameters, for example
\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr
-\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y|X]}.}
+\code{ml_l} refers to the nuisance function \eqn{l_0(X) = E[Y|X]}.}
\item{\code{ml_m}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr
@@ -184,6 +185,21 @@ be passed with specified parameters, for example
\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr
\code{ml_r} refers to the nuisance function \eqn{r_0(X) = E[D|X]}.}
+\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
+\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr
+A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is
+available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its
+extension packages \href{https://mlr3learners.mlr-org.com/}{mlr3learners} or
+\href{https://mlr3extralearners.mlr-org.com/}{mlr3extralearners}.
+Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field
+\code{task_type = "regr"} can be passed, for example of class
+\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly
+be passed with specified parameters, for example
+\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr
+\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y - D\theta_0|X]}.
+Note: The learner \code{ml_g} is only required for the score \code{"IV-type"}.
+Optionally, it can also be specified and estimated when a \verb{function()}
+is passed as \code{score}.}
+
\item{\code{partialX}}{(\code{logical(1)}) \cr
Indicates whether covariates \eqn{X} should be partialled out.
Default is \code{TRUE}.}
@@ -199,10 +215,10 @@ Number of folds. Default is \code{5}.}
Number of repetitions for the sample splitting. Default is \code{1}.}
\item{\code{score}}{(\code{character(1)}, \verb{function()}) \cr
-A \code{character(1)} (\code{"partialling out"} is the only choice) or a
-\verb{function()} specifying the score function.
+A \code{character(1)} (\code{"partialling out"} or \code{"IV-type"}) or a \verb{function()}
+specifying the score function.
If a \verb{function()} is provided, it must be of the form
-\verb{function(y, z, d, g_hat, m_hat, r_hat, smpls)} and
+\verb{function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls)} and
the returned output must be a named \code{list()} with elements
\code{psi_a} and \code{psi_b}. Default is \code{"partialling out"}.}
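For illustration, a minimal sketch of a callable score with this new
signature, reproducing the "partialling out" moments (the argument names
follow the documentation above; smpls is unused in this simple case):

  score_po = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) {
    psi_a = -(d - r_hat) * (z - m_hat)
    psi_b = (y - l_hat) * (z - m_hat)
    list(psi_a = psi_a, psi_b = psi_b)
  }

The estimator then solves mean(psi_a) * theta + mean(psi_b) = 0.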
@@ -221,6 +237,130 @@ Indicates whether cross-fitting should be applied. Default is \code{TRUE}.}
}
}
\if{html}{\out{
}}
+\if{latex}{\out{\hypertarget{method-set_ml_nuisance_params}{}}}
+\subsection{Method \code{set_ml_nuisance_params()}}{
+Set hyperparameters for the nuisance models of DoubleML models.
+
+Note that in the current implementation, either all parameters have to
+be set globally or all parameters have to be provided fold-specific.
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{DoubleMLPLIV$set_ml_nuisance_params(
+ learner = NULL,
+ treat_var = NULL,
+ params,
+ set_fold_specific = FALSE
+)}\if{html}{\out{
}}
+\describe{
+\item{\code{learner}}{(\code{character(1)}) \cr
+The nuisance model/learner (see method \code{params_names}).}
+
+\item{\code{treat_var}}{(\code{character(1)}) \cr
+The treatment variable (hyperparameters can be set treatment-variable
+specific).}
+
+\item{\code{params}}{(named \code{list()}) \cr
+A named \code{list()} with estimator parameters. Parameters are used for all
+folds by default. Alternatively, parameters can be passed in a
+fold-specific way if option \code{set_fold_specific} is \code{TRUE}. In this case, the
+outer list needs to be of length \code{n_rep} and the inner list of length
+\code{n_folds}.}
+
+\item{\code{set_fold_specific}}{(\code{logical(1)}) \cr
+Indicates if the parameters passed in \code{params} should be passed in a
+fold-specific way. Default is \code{FALSE}. If \code{TRUE}, the outer list needs
+to be of length \code{n_rep} and the inner list of length \code{n_folds}.
+Note that in the current implementation, either all parameters have to
+be set globally or all parameters have to be provided fold-specific.}
+}
+\if{html}{\out{
}}
+\if{latex}{\out{\hypertarget{method-tune}{}}}
+\subsection{Method \code{tune()}}{
+Hyperparameter-tuning for DoubleML models.
+
+The hyperparameter-tuning is performed using the tuning methods provided
+in the \href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} package. For more
+information on tuning in \href{https://mlr3.mlr-org.com/}{mlr3}, we refer to
+the section on parameter tuning in the
+\href{https://mlr3book.mlr-org.com/optimization.html#tuning}{mlr3 book}.
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{DoubleMLPLIV$tune(
+ param_set,
+ tune_settings = list(n_folds_tune = 5, rsmp_tune = mlr3::rsmp("cv", folds = 5),
+ measure = NULL, terminator = mlr3tuning::trm("evals", n_evals = 20), algorithm =
+ mlr3tuning::tnr("grid_search"), resolution = 5),
+ tune_on_folds = FALSE
+)}\if{html}{\out{
}}
+\describe{
+\item{\code{param_set}}{(named \code{list()}) \cr
+A named \code{list} with a parameter grid for each nuisance model/learner
+(see method \code{learner_names()}). The parameter grid must be an object of
+class \link[paradox:ParamSet]{ParamSet}.}
+
+\item{\code{tune_settings}}{(named \code{list()}) \cr
+A named \code{list()} with arguments passed to the hyperparameter-tuning with
+\href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} to set up
+\link[mlr3tuning:TuningInstanceSingleCrit]{TuningInstance} objects.
+\code{tune_settings} has entries
+\itemize{
+\item \code{terminator} (\link[bbotk:Terminator]{Terminator}) \cr
+A \link[bbotk:Terminator]{Terminator} object. Specification of \code{terminator}
+is required to perform tuning.
+\item \code{algorithm} (\link[mlr3tuning:Tuner]{Tuner} or \code{character(1)}) \cr
+A \link[mlr3tuning:Tuner]{Tuner} object (recommended) or key passed to the
+respective dictionary to specify the tuning algorithm used in
+\link[mlr3tuning:tnr]{tnr()}. \code{algorithm} is passed as an argument to
+\link[mlr3tuning:tnr]{tnr()}. If \code{algorithm} is not specified by the user,
+the default is \code{"grid_search"}, in which case the additional argument
+\code{"resolution"} is required.
+\item \code{rsmp_tune} (\link[mlr3:Resampling]{Resampling} or \code{character(1)})\cr
+A \link[mlr3:Resampling]{Resampling} object (recommended) or option passed
+to \link[mlr3:mlr_sugar]{rsmp()} to initialize a
+\link[mlr3:Resampling]{Resampling} for parameter tuning in \code{mlr3}.
+If not specified by the user, default is set to \code{"cv"}
+(cross-validation).
+\item \code{n_folds_tune} (\code{integer(1)}, optional) \cr
+If \code{rsmp_tune = "cv"}, number of folds used for cross-validation.
+If not specified by the user, default is set to \code{5}.
+\item \code{measure} (\code{NULL}, named \code{list()}, optional) \cr
+Named list containing the measures used for parameter tuning. Entries in
+the list must either be \link[mlr3:Measure]{Measure} objects or keys
+passed to \link[mlr3:mlr_sugar]{msr()}. The names of the entries must
+match the learner names (see method \code{learner_names()}). If set to \code{NULL},
+default measures are used, i.e., \code{"regr.mse"} for continuous outcome
+variables and \code{"classif.ce"} for binary outcomes.
+\item \code{resolution} (\code{integer(1)}) \cr The resolution of the
+parameter grid if \code{algorithm = "grid_search"}. \code{resolution} is
+passed as an argument to \link[mlr3tuning:tnr]{tnr()}.
+}}
+
+\item{\code{tune_on_folds}}{(\code{logical(1)}) \cr
+Indicates whether the tuning should be done fold-specifically or globally.
+Default is \code{FALSE}.}
+}
+\if{html}{\out{
}}
\if{latex}{\out{\hypertarget{method-clone}{}}}
\subsection{Method \code{clone()}}{
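Before moving on to DoubleMLPLR.Rd, a usage sketch of the renamed PLIV
interface may help: with score = "IV-type" all four learners ml_l, ml_m,
ml_r and ml_g must be supplied, while ml_g stays optional otherwise. A
minimal sketch, assuming the data generator make_pliv_CHS2015 that is
also used in the tests below:

  library(DoubleML)
  library(mlr3)
  library(mlr3learners)
  set.seed(3141)
  dml_data = make_pliv_CHS2015(n_obs = 100, dim_z = 1)
  dml_pliv = DoubleMLPLIV$new(dml_data,
    ml_l = lrn("regr.cv_glmnet", s = "lambda.min"),
    ml_m = lrn("regr.cv_glmnet", s = "lambda.min"),
    ml_r = lrn("regr.cv_glmnet", s = "lambda.min"),
    ml_g = lrn("regr.cv_glmnet", s = "lambda.min"),
    score = "IV-type")
  dml_pliv$fit()
  dml_pliv$summary()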
diff --git a/man/DoubleMLPLR.Rd b/man/DoubleMLPLR.Rd
index 1fa63ff1..8aa46212 100644
--- a/man/DoubleMLPLR.Rd
+++ b/man/DoubleMLPLR.Rd
@@ -43,13 +43,13 @@ library(mlr3learners)
library(mlr3tuning)
library(data.table)
set.seed(2)
-ml_g = lrn("regr.rpart")
-ml_m = ml_g$clone()
+ml_l = lrn("regr.rpart")
+ml_m = ml_l$clone()
obj_dml_data = make_plr_CCDDHNR2018(alpha = 0.5)
-dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m)
+dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_l, ml_m)
param_grid = list(
- "ml_g" = paradox::ParamSet$new(list(
+ "ml_l" = paradox::ParamSet$new(list(
paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
"ml_m" = paradox::ParamSet$new(list(
@@ -80,6 +80,8 @@ Other DoubleML:
\subsection{Public methods}{
\itemize{
\item \href{#method-new}{\code{DoubleMLPLR$new()}}
+\item \href{#method-set_ml_nuisance_params}{\code{DoubleMLPLR$set_ml_nuisance_params()}}
+\item \href{#method-tune}{\code{DoubleMLPLR$tune()}}
\item \href{#method-clone}{\code{DoubleMLPLR$clone()}}
}
}
@@ -94,11 +96,9 @@ Other DoubleML:
\item \out{
}
}
\out{}
}
@@ -110,8 +110,9 @@ Creates a new instance of this R6 class.
\subsection{Usage}{
\if{html}{\out{
}}\preformatted{DoubleMLPLR$new(
data,
- ml_g,
+ ml_l,
ml_m,
+ ml_g = NULL,
n_folds = 5,
n_rep = 1,
score = "partialling out",
@@ -128,7 +129,7 @@ Creates a new instance of this R6 class.
The \code{DoubleMLData} object providing the data and specifying the
variables of the causal model.}
-\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
+\item{\code{ml_l}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr
A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is
available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its
@@ -139,7 +140,7 @@ Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field
\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly
be passed with specified parameters, for example
\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr
-\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y|X]}.}
+\code{ml_l} refers to the nuisance function \eqn{l_0(X) = E[Y|X]}.}
\item{\code{ml_m}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
\code{\link[mlr3:LearnerClassif]{LearnerClassif}}, \code{\link[mlr3:Learner]{Learner}},
@@ -157,6 +158,21 @@ respectively, for example of class
\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. \cr
\code{ml_m} refers to the nuisance function \eqn{m_0(X) = E[D|X]}.}
+\item{\code{ml_g}}{(\code{\link[mlr3:LearnerRegr]{LearnerRegr}},
+\code{\link[mlr3:Learner]{Learner}}, \code{character(1)}) \cr
+A learner of the class \code{\link[mlr3:LearnerRegr]{LearnerRegr}}, which is
+available from \href{https://mlr3.mlr-org.com/index.html}{mlr3} or its
+extension packages \href{https://mlr3learners.mlr-org.com/}{mlr3learners} or
+\href{https://mlr3extralearners.mlr-org.com/}{mlr3extralearners}.
+Alternatively, a \code{\link[mlr3:Learner]{Learner}} object with public field
+\code{task_type = "regr"} can be passed, for example of class
+\code{\link[mlr3pipelines:mlr_learners_graph]{GraphLearner}}. The learner can possibly
+be passed with specified parameters, for example
+\code{lrn("regr.cv_glmnet", s = "lambda.min")}. \cr
+\code{ml_g} refers to the nuisance function \eqn{g_0(X) = E[Y - D\theta_0|X]}.
+Note: The learner \code{ml_g} is only required for the score \code{"IV-type"}.
+Optionally, it can also be specified and estimated when a \verb{function()}
+is passed as \code{score}.}
+
\item{\code{n_folds}}{(\code{integer(1)})\cr
Number of folds. Default is \code{5}.}
@@ -164,10 +180,10 @@ Number of folds. Default is \code{5}.}
Number of repetitions for the sample splitting. Default is \code{1}.}
\item{\code{score}}{(\code{character(1)}, \verb{function()}) \cr
-A \code{character(1)} (\code{"partialling out"} or \code{IV-type}) or a \verb{function()}
+A \code{character(1)} (\code{"partialling out"} or \code{"IV-type"}) or a \verb{function()}
specifying the score function.
If a \verb{function()} is provided, it must be of the form
-\verb{function(y, d, g_hat, m_hat, smpls)} and
+\verb{function(y, d, l_hat, m_hat, g_hat, smpls)} and
the returned output must be a named \code{list()} with elements \code{psi_a} and
\code{psi_b}. Default is \code{"partialling out"}.}
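Analogously, a sketch of a callable score for DoubleMLPLR with the new
signature, reproducing the "IV-type" moments (this is what makes g_hat
part of the signature):

  score_iv = function(y, d, l_hat, m_hat, g_hat, smpls) {
    psi_a = -(d - m_hat) * d
    psi_b = (d - m_hat) * (y - g_hat)
    list(psi_a = psi_a, psi_b = psi_b)
  }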
@@ -186,6 +202,130 @@ Indicates whether cross-fitting should be applied. Default is \code{TRUE}.}
}
}
\if{html}{\out{
}}
+\if{html}{\out{
}}
+\if{latex}{\out{\hypertarget{method-set_ml_nuisance_params}{}}}
+\subsection{Method \code{set_ml_nuisance_params()}}{
+Set hyperparameters for the nuisance models of DoubleML models.
+
+Note that in the current implementation, either all parameters have to
+be set globally or all parameters have to be provided fold-specific.
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{DoubleMLPLR$set_ml_nuisance_params(
+ learner = NULL,
+ treat_var = NULL,
+ params,
+ set_fold_specific = FALSE
+)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{
}}
+\describe{
+\item{\code{learner}}{(\code{character(1)}) \cr
+The nuisance model/learner (see method \code{params_names}).}
+
+\item{\code{treat_var}}{(\code{character(1)}) \cr
+The treatment variable (hyperparameters can be set treatment-variable
+specific).}
+
+\item{\code{params}}{(named \code{list()}) \cr
+A named \code{list()} with estimator parameters. Parameters are used for all
+folds by default. Alternatively, parameters can be passed in a
+fold-specific way if option \code{set_fold_specific} is \code{TRUE}. In this case, the
+outer list needs to be of length \code{n_rep} and the inner list of length
+\code{n_folds}.}
+
+\item{\code{set_fold_specific}}{(\code{logical(1)}) \cr
+Indicates if the parameters passed in \code{params} should be passed in a
+fold-specific way. Default is \code{FALSE}. If \code{TRUE}, the outer list needs
+to be of length \code{n_rep} and the inner list of length \code{n_folds}.
+Note that in the current implementation, either all parameters have to
+be set globally or all parameters have to be provided fold-specific.}
+}
+\if{html}{\out{
}}
+}
+\subsection{Returns}{
+self
+}
+}
+\if{html}{\out{
}}
+\if{html}{\out{
}}
+\if{latex}{\out{\hypertarget{method-tune}{}}}
+\subsection{Method \code{tune()}}{
+Hyperparameter-tuning for DoubleML models.
+
+The hyperparameter-tuning is performed using the tuning methods provided
+in the \href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} package. For more
+information on tuning in \href{https://mlr3.mlr-org.com/}{mlr3}, we refer to
+the section on parameter tuning in the
+\href{https://mlr3book.mlr-org.com/optimization.html#tuning}{mlr3 book}.
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{DoubleMLPLR$tune(
+ param_set,
+ tune_settings = list(n_folds_tune = 5, rsmp_tune = mlr3::rsmp("cv", folds = 5),
+ measure = NULL, terminator = mlr3tuning::trm("evals", n_evals = 20), algorithm =
+ mlr3tuning::tnr("grid_search"), resolution = 5),
+ tune_on_folds = FALSE
+)}\if{html}{\out{
}}
+}
+
+\subsection{Arguments}{
+\if{html}{\out{
}}
+\describe{
+\item{\code{param_set}}{(named \code{list()}) \cr
+A named \code{list} with a parameter grid for each nuisance model/learner
+(see method \code{learner_names()}). The parameter grid must be an object of
+class \link[paradox:ParamSet]{ParamSet}.}
+
+\item{\code{tune_settings}}{(named \code{list()}) \cr
+A named \code{list()} with arguments passed to the hyperparameter-tuning with
+\href{https://mlr3tuning.mlr-org.com/}{mlr3tuning} to set up
+\link[mlr3tuning:TuningInstanceSingleCrit]{TuningInstance} objects.
+\code{tune_settings} has entries
+\itemize{
+\item \code{terminator} (\link[bbotk:Terminator]{Terminator}) \cr
+A \link[bbotk:Terminator]{Terminator} object. Specification of \code{terminator}
+is required to perform tuning.
+\item \code{algorithm} (\link[mlr3tuning:Tuner]{Tuner} or \code{character(1)}) \cr
+A \link[mlr3tuning:Tuner]{Tuner} object (recommended) or key passed to the
+respective dictionary to specify the tuning algorithm used in
+\link[mlr3tuning:tnr]{tnr()}. \code{algorithm} is passed as an argument to
+\link[mlr3tuning:tnr]{tnr()}. If \code{algorithm} is not specified by the user,
+the default is \code{"grid_search"}, in which case the additional argument
+\code{"resolution"} is required.
+\item \code{rsmp_tune} (\link[mlr3:Resampling]{Resampling} or \code{character(1)})\cr
+A \link[mlr3:Resampling]{Resampling} object (recommended) or option passed
+to \link[mlr3:mlr_sugar]{rsmp()} to initialize a
+\link[mlr3:Resampling]{Resampling} for parameter tuning in \code{mlr3}.
+If not specified by the user, default is set to \code{"cv"}
+(cross-validation).
+\item \code{n_folds_tune} (\code{integer(1)}, optional) \cr
+If \code{rsmp_tune = "cv"}, number of folds used for cross-validation.
+If not specified by the user, default is set to \code{5}.
+\item \code{measure} (\code{NULL}, named \code{list()}, optional) \cr
+Named list containing the measures used for parameter tuning. Entries in
+the list must either be \link[mlr3:Measure]{Measure} objects or keys
+passed to \link[mlr3:mlr_sugar]{msr()}. The names of the entries must
+match the learner names (see method \code{learner_names()}). If set to \code{NULL},
+default measures are used, i.e., \code{"regr.mse"} for continuous outcome
+variables and \code{"classif.ce"} for binary outcomes.
+\item \code{resolution} (\code{integer(1)}) \cr The resolution of the
+parameter grid if \code{algorithm = "grid_search"}. \code{resolution} is
+passed as an argument to \link[mlr3tuning:tnr]{tnr()}.
+}}
+
+\item{\code{tune_on_folds}}{(\code{logical(1)}) \cr
+Indicates whether the tuning should be done fold-specifically or globally.
+Default is \code{FALSE}.}
+}
+\if{html}{\out{
}}
+}
+\subsection{Returns}{
+self
+}
+}
+\if{html}{\out{
}}
\if{html}{\out{
}}
\if{latex}{\out{\hypertarget{method-clone}{}}}
\subsection{Method \code{clone()}}{
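A sketch of the fold-specific parameter format accepted by
set_ml_nuisance_params(), for a hypothetical DoubleMLPLR object dml_obj
with n_rep = 1 and n_folds = 2 (the outer list has length n_rep, the
inner list length n_folds, one named parameter list per fold):

  params_fold_wise = list(            # outer list: length n_rep
    list(                             # inner list: length n_folds
      list(cp = 0.01, minsplit = 20), # fold 1
      list(cp = 0.02, minsplit = 10)  # fold 2
    ))
  dml_obj$set_ml_nuisance_params(
    learner = "ml_l", treat_var = "d",
    params = params_fold_wise, set_fold_specific = TRUE)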
diff --git a/tests/testthat/helper-05-ml-learner.R b/tests/testthat/helper-05-ml-learner.R
index 53a64531..390e6a27 100644
--- a/tests/testthat/helper-05-ml-learner.R
+++ b/tests/testthat/helper-05-ml-learner.R
@@ -2,36 +2,42 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
if (default == FALSE) {
if (learner == "regr.lm") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
mlmethod_g = learner
)
params = list(
- params_g = list(),
- params_m = list()
+ params_l = list(),
+ params_m = list(),
+ params_g = list()
)
}
else if (learner == "regr.ranger") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
mlmethod_g = learner
)
params = list(
- params_g = list(num.trees = 100),
- params_m = list(num.trees = 120)
+ params_l = list(num.trees = 60),
+ params_m = list(num.trees = 120),
+ params_g = list(num.trees = 100)
)
}
else if (learner == "regr.rpart") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
mlmethod_g = learner
)
params = list(
- params_g = list(cp = 0.01, minsplit = 20),
- params_m = list(cp = 0.01, minsplit = 20)
+ params_l = list(cp = 0.013, minsplit = 18),
+ params_m = list(cp = 0.01, minsplit = 20),
+ params_g = list(cp = 0.005, minsplit = 10)
)
}
@@ -48,11 +54,16 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
# }
else if (learner == "regr.cv_glmnet") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
mlmethod_g = learner
)
params = list(
+ params_l = list(
+ s = "lambda.min",
+ family = "gaussian"
+ ),
params_m = list(
s = "lambda.min",
family = "gaussian"
@@ -65,12 +76,14 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
else if (default == TRUE) {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
mlmethod_g = learner
)
params = list(
- params_g = list(),
- params_m = list())
+ params_l = list(),
+ params_m = list(),
+ params_g = list())
}
if (learner == "graph_learner") {
@@ -80,109 +93,129 @@ get_default_mlmethod_plr = function(learner, default = FALSE) {
lambda = 0.01,
family = "gaussian")
mlmethod = list(
+ mlmethod_l = "graph_learner",
mlmethod_m = "graph_learner",
mlmethod_g = "graph_learner")
    params = list(
-      params_g = list(),
-      params_m = list())
+      params_l = list(),
+      params_m = list(),
+      params_g = list())
- ml_g = mlr3::as_learner(pipe_learner)
+ ml_l = mlr3::as_learner(pipe_learner)
ml_m = mlr3::as_learner(pipe_learner)
+ ml_g = mlr3::as_learner(pipe_learner)
} else {
- ml_g = mlr3::lrn(mlmethod$mlmethod_g)
- ml_g$param_set$values = params$params_g
+ ml_l = mlr3::lrn(mlmethod$mlmethod_l)
+ ml_l$param_set$values = params$params_l
ml_m = mlr3::lrn(mlmethod$mlmethod_m)
ml_m$param_set$values = params$params_m
+ ml_g = mlr3::lrn(mlmethod$mlmethod_g)
+ ml_g$param_set$values = params$params_g
}
return(list(
mlmethod = mlmethod, params = params,
- ml_g = ml_g, ml_m = ml_m
+ ml_l = ml_l, ml_m = ml_m, ml_g = ml_g
))
}
get_default_mlmethod_pliv = function(learner) {
if (learner == "regr.lm") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
- mlmethod_g = learner,
- mlmethod_r = learner
+ mlmethod_r = learner,
+ mlmethod_g = learner
)
params = list(
- params_g = list(),
+ params_l = list(),
params_m = list(),
- params_r = list()
+ params_r = list(),
+ params_g = list()
)
}
else if (learner == "regr.ranger") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
- mlmethod_g = learner,
- mlmethod_r = learner
+ mlmethod_r = learner,
+ mlmethod_g = learner
)
params = list(
- params_g = list(num.trees = 100),
+ params_l = list(num.trees = 100),
params_m = list(num.trees = 120),
- params_r = list(num.trees = 100)
+ params_r = list(num.trees = 100),
+ params_g = list(num.trees = 100)
)
}
else if (learner == "regr.rpart") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
- mlmethod_g = learner,
- mlmethod_r = learner
+ mlmethod_r = learner,
+ mlmethod_g = learner
)
params = list(
- params_g = list(cp = 0.01, minsplit = 20),
+ params_l = list(cp = 0.01, minsplit = 20),
params_m = list(cp = 0.01, minsplit = 20),
- params_r = list(cp = 0.01, minsplit = 20)
+ params_r = list(cp = 0.01, minsplit = 20),
+ params_g = list(cp = 0.01, minsplit = 20)
)
}
else if (learner == "regr.cv_glmnet") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
- mlmethod_g = learner,
- mlmethod_r = learner
+ mlmethod_r = learner,
+ mlmethod_g = learner
)
params = list(
- params_m = list(
+ params_l = list(
s = "lambda.min",
family = "gaussian"
),
- params_g = list(
+ params_m = list(
s = "lambda.min",
family = "gaussian"
),
params_r = list(
s = "lambda.min",
family = "gaussian"
+ ),
+ params_g = list(
+ s = "lambda.min",
+ family = "gaussian"
)
)
} else if (learner == "regr.glmnet") {
mlmethod = list(
+ mlmethod_l = learner,
mlmethod_m = learner,
- mlmethod_g = learner,
- mlmethod_r = learner
+ mlmethod_r = learner,
+ mlmethod_g = learner
)
params = list(
- params_m = list(
+ params_l = list(
lambda = 0.01,
family = "gaussian"
),
- params_g = list(
+ params_m = list(
lambda = 0.01,
family = "gaussian"
),
params_r = list(
lambda = 0.01,
family = "gaussian"
+ ),
+ params_g = list(
+ lambda = 0.01,
+ family = "gaussian"
)
)
@@ -195,27 +228,33 @@ get_default_mlmethod_pliv = function(learner) {
lambda = 0.01,
family = "gaussian")
mlmethod = list(
+ mlmethod_l = "graph_learner",
mlmethod_m = "graph_learner",
- mlmethod_g = "graph_learner",
- mlmethod_r = "graph_learner")
+ mlmethod_r = "graph_learner",
+ mlmethod_g = "graph_learner")
params = list(
- params_g = list(),
- params_m = list())
- ml_g = mlr3::as_learner(pipe_learner)
+ params_l = list(),
+ params_m = list(),
+ params_r = list(),
+ params_g = list())
+ ml_l = mlr3::as_learner(pipe_learner)
ml_m = mlr3::as_learner(pipe_learner)
ml_r = mlr3::as_learner(pipe_learner)
+ ml_g = mlr3::as_learner(pipe_learner)
} else {
- ml_g = mlr3::lrn(mlmethod$mlmethod_g)
- ml_g$param_set$values = params$params_g
+ ml_l = mlr3::lrn(mlmethod$mlmethod_l)
+ ml_l$param_set$values = params$params_l
ml_m = mlr3::lrn(mlmethod$mlmethod_m)
ml_m$param_set$values = params$params_m
ml_r = mlr3::lrn(mlmethod$mlmethod_r)
ml_r$param_set$values = params$params_r
+ ml_g = mlr3::lrn(mlmethod$mlmethod_g)
+ ml_g$param_set$values = params$params_g
}
return(list(
mlmethod = mlmethod, params = params,
- ml_g = ml_g, ml_m = ml_m, ml_r = ml_r
+ ml_l = ml_l, ml_m = ml_m, ml_r = ml_r, ml_g = ml_g
))
}
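Usage sketch of the updated helpers (assuming the helper file is
sourced): the return lists now carry ml_l alongside ml_m, ml_r and ml_g,
with matching params_* entries.

  learner_pars = get_default_mlmethod_plr("regr.rpart")
  names(learner_pars)
  # "mlmethod" "params" "ml_l" "ml_m" "ml_g"
  learner_pars$params$params_l
  # cp = 0.013, minsplit = 18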
diff --git a/tests/testthat/helper-08-dml_plr.R b/tests/testthat/helper-08-dml_plr.R
index 6c1c94d6..305c4016 100644
--- a/tests/testthat/helper-08-dml_plr.R
+++ b/tests/testthat/helper-08-dml_plr.R
@@ -1,9 +1,9 @@
# Double Machine Learning for Partially Linear Regression.
dml_plr = function(data, y, d,
- n_folds, ml_g, ml_m,
+ n_folds, ml_l, ml_m, ml_g,
dml_procedure, score,
n_rep = 1, smpls = NULL,
- params_g = NULL, params_m = NULL) {
+ params_l = NULL, params_m = NULL, params_g = NULL) {
if (is.null(smpls)) {
smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data))
@@ -21,10 +21,10 @@ dml_plr = function(data, y, d,
res_single_split = fit_plr_single_split(
data, y, d,
- n_folds, ml_g, ml_m,
+ n_folds, ml_l, ml_m, ml_g,
dml_procedure, score,
this_smpl,
- params_g, params_m)
+ params_l, params_m, params_g)
all_preds[[i_rep]] = res_single_split$all_preds
all_thetas[i_rep] = res_single_split$theta
@@ -53,10 +53,10 @@ dml_plr = function(data, y, d,
dml_plr_multitreat = function(data, y, d,
- n_folds, ml_g, ml_m,
+ n_folds, ml_l, ml_m, ml_g,
dml_procedure, score,
n_rep = 1, smpls = NULL,
- params_g = NULL, params_m = NULL) {
+ params_l = NULL, params_m = NULL, params_g = NULL) {
if (is.null(smpls)) {
smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data))
@@ -71,6 +71,11 @@ dml_plr_multitreat = function(data, y, d,
all_preds_this_rep = list()
for (i_d in seq(n_d)) {
+ if (!is.null(params_l)) {
+ this_params_l = params_l[[i_d]]
+ } else {
+ this_params_l = NULL
+ }
if (!is.null(params_g)) {
this_params_g = params_g[[i_d]]
} else {
@@ -83,10 +88,10 @@ dml_plr_multitreat = function(data, y, d,
}
res_single_split = fit_plr_single_split(
data, y, d[i_d],
- n_folds, ml_g, ml_m,
+ n_folds, ml_l, ml_m, ml_g,
dml_procedure, score,
this_smpl,
- this_params_g, this_params_m)
+ this_params_l, this_params_m, this_params_g)
all_preds_this_rep[[i_d]] = res_single_split$all_preds
thetas_this_rep[i_d] = res_single_split$theta
@@ -125,27 +130,28 @@ dml_plr_multitreat = function(data, y, d,
fit_plr_single_split = function(data, y, d,
- n_folds, ml_g, ml_m,
+ n_folds, ml_l, ml_m, ml_g,
dml_procedure, score, smpl,
- params_g, params_m) {
+ params_l, params_m, params_g) {
train_ids = smpl$train_ids
test_ids = smpl$test_ids
+ fit_g = is.function(score) || (score == "IV-type")
all_preds = fit_nuisance_plr(
data, y, d,
- ml_g, ml_m,
- smpl,
- params_g, params_m)
+ ml_l, ml_m, ml_g,
+ n_folds, smpl, fit_g,
+ params_l, params_m, params_g)
residuals = compute_plr_residuals(
data, y, d, n_folds, smpl,
all_preds)
- u_hat = residuals$u_hat
- v_hat = residuals$v_hat
+ y_minus_l_hat = residuals$y_minus_l_hat
+ d_minus_m_hat = residuals$d_minus_m_hat
+ y_minus_g_hat = residuals$y_minus_g_hat
D = data[, d]
Y = data[, y]
- v_hatd = v_hat * D
# DML 1
if (dml_procedure == "dml1") {
@@ -154,30 +160,37 @@ fit_plr_single_split = function(data, y, d,
test_index = test_ids[[i]]
orth_est = orth_plr_dml(
- u_hat = u_hat[test_index], v_hat = v_hat[test_index],
- v_hatd = v_hatd[test_index],
+ y_minus_l_hat = y_minus_l_hat[test_index],
+ d_minus_m_hat = d_minus_m_hat[test_index],
+ y_minus_g_hat = y_minus_g_hat[test_index],
+ d = D[test_index],
score = score)
thetas[i] = orth_est$theta
}
theta = mean(thetas, na.rm = TRUE)
if (length(train_ids) == 1) {
D = D[test_index]
- u_hat = u_hat[test_index]
- v_hat = v_hat[test_index]
- v_hatd = v_hatd[test_index]
+ y_minus_l_hat = y_minus_l_hat[test_index]
+ d_minus_m_hat = d_minus_m_hat[test_index]
+ y_minus_g_hat = y_minus_g_hat[test_index]
}
}
if (dml_procedure == "dml2") {
orth_est = orth_plr_dml(
- u_hat = u_hat, v_hat = v_hat,
- v_hatd = v_hatd, score = score)
+ y_minus_l_hat = y_minus_l_hat,
+ d_minus_m_hat = d_minus_m_hat,
+ y_minus_g_hat = y_minus_g_hat,
+ d = D, score = score)
theta = orth_est$theta
}
se = sqrt(var_plr(
- theta = theta, d = D, u_hat = u_hat, v_hat = v_hat,
- v_hatd = v_hatd, score = score))
+ theta = theta, d = D,
+ y_minus_l_hat = y_minus_l_hat,
+ d_minus_m_hat = d_minus_m_hat,
+ y_minus_g_hat = y_minus_g_hat,
+ score = score))
res = list(
theta = theta, se = se,
@@ -188,27 +201,29 @@ fit_plr_single_split = function(data, y, d,
fit_nuisance_plr = function(data, y, d,
- ml_g, ml_m,
- smpls,
- params_g, params_m) {
+ ml_l, ml_m, ml_g,
+ n_folds, smpls, fit_g,
+ params_l, params_m, params_g) {
train_ids = smpls$train_ids
test_ids = smpls$test_ids
- # nuisance g
- g_indx = names(data) != d
- data_g = data[, g_indx, drop = FALSE]
- task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y)
+ # nuisance l
+ l_indx = names(data) != d
+ data_l = data[, l_indx, drop = FALSE]
+ task_l = mlr3::TaskRegr$new(
+ id = paste0("nuis_l_", d),
+ backend = data_l, target = y)
- resampling_g = mlr3::rsmp("custom")
- resampling_g$instantiate(task_g, train_ids, test_ids)
+ resampling_l = mlr3::rsmp("custom")
+ resampling_l$instantiate(task_l, train_ids, test_ids)
- if (!is.null(params_g)) {
- ml_g$param_set$values = params_g
+ if (!is.null(params_l)) {
+ ml_l$param_set$values = params_l
}
- r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE)
- g_hat_list = lapply(r_g$predictions(), function(x) x$response)
+ r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE)
+ l_hat_list = lapply(r_l$predictions(), function(x) x$response)
# nuisance m
if (!is.null(params_m)) {
@@ -239,7 +254,45 @@ fit_nuisance_plr = function(data, y, d,
m_hat_list = lapply(r_m$predictions(), function(x) as.data.table(x)$prob.1)
}
+ if (fit_g) {
+ # nuisance g
+ residuals = compute_plr_residuals(
+ data, y, d, n_folds,
+ smpls, list(
+ l_hat_list = l_hat_list,
+ g_hat_list = NULL,
+ m_hat_list = m_hat_list))
+ y_minus_l_hat = residuals$y_minus_l_hat
+ d_minus_m_hat = residuals$d_minus_m_hat
+ psi_a = -d_minus_m_hat * d_minus_m_hat
+ psi_b = d_minus_m_hat * y_minus_l_hat
+ theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)
+
+ D = data[, d]
+ Y = data[, y]
+ g_indx = names(data) != y & names(data) != d
+ y_minus_theta_d = Y - theta_initial * D
+ data_g = cbind(data[, g_indx, drop = FALSE], y_minus_theta_d)
+
+ task_g = mlr3::TaskRegr$new(
+ id = paste0("nuis_g_", d), backend = data_g,
+ target = "y_minus_theta_d")
+
+ resampling_g = mlr3::rsmp("custom")
+ resampling_g$instantiate(task_g, train_ids, test_ids)
+
+ if (!is.null(params_g)) {
+ ml_g$param_set$values = params_g
+ }
+
+ r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE)
+ g_hat_list = lapply(r_g$predictions(), function(x) x$response)
+ } else {
+ g_hat_list = NULL
+ }
+
all_preds = list(
+ l_hat_list = l_hat_list,
m_hat_list = m_hat_list,
g_hat_list = g_hat_list)
@@ -250,42 +303,50 @@ compute_plr_residuals = function(data, y, d, n_folds, smpls, all_preds) {
test_ids = smpls$test_ids
- g_hat_list = all_preds$g_hat_list
+ l_hat_list = all_preds$l_hat_list
m_hat_list = all_preds$m_hat_list
+ g_hat_list = all_preds$g_hat_list
n = nrow(data)
D = data[, d]
Y = data[, y]
- v_hat = u_hat = w_hat = rep(NA_real_, n)
+ y_minus_l_hat = d_minus_m_hat = y_minus_g_hat = rep(NA_real_, n)
for (i in 1:n_folds) {
test_index = test_ids[[i]]
- g_hat = g_hat_list[[i]]
+ l_hat = l_hat_list[[i]]
m_hat = m_hat_list[[i]]
- u_hat[test_index] = Y[test_index] - g_hat
- v_hat[test_index] = D[test_index] - m_hat
+ y_minus_l_hat[test_index] = Y[test_index] - l_hat
+ d_minus_m_hat[test_index] = D[test_index] - m_hat
+
+ if (!is.null(g_hat_list)) {
+ g_hat = g_hat_list[[i]]
+ y_minus_g_hat[test_index] = Y[test_index] - g_hat
+ }
}
- residuals = list(u_hat = u_hat, v_hat = v_hat)
+ residuals = list(
+ y_minus_l_hat = y_minus_l_hat,
+ d_minus_m_hat = d_minus_m_hat,
+ y_minus_g_hat = y_minus_g_hat)
return(residuals)
}
# Orthogonalized Estimation of Coefficient in PLR
-orth_plr_dml = function(u_hat, v_hat, v_hatd, score) {
+orth_plr_dml = function(y_minus_l_hat, d_minus_m_hat, y_minus_g_hat, d, score) {
theta = NA_real_
if (score == "partialling out") {
- res_fit = stats::lm(u_hat ~ 0 + v_hat)
+ res_fit = stats::lm(y_minus_l_hat ~ 0 + d_minus_m_hat)
theta = stats::coef(res_fit)
}
else if (score == "IV-type") {
- theta = mean(v_hat * u_hat) / mean(v_hatd)
- # se = 1/(mean(u_hat)^2) * mean((v_hat - theta*u_hat)*u_hat)^2
+ theta = mean(d_minus_m_hat * y_minus_g_hat) / mean(d_minus_m_hat * d)
}
else {
@@ -298,15 +359,15 @@ orth_plr_dml = function(u_hat, v_hat, v_hatd, score) {
# Variance estimation for DML estimator in the partially linear regression model
-var_plr = function(theta, d, u_hat, v_hat, v_hatd, score) {
+var_plr = function(theta, d, y_minus_l_hat, d_minus_m_hat, y_minus_g_hat, score) {
n = length(d)
if (score == "partialling out") {
- var = 1 / n * 1 / (mean(v_hat^2))^2 *
- mean(((u_hat - v_hat * theta) * v_hat)^2)
+ var = 1 / n * 1 / (mean(d_minus_m_hat^2))^2 *
+ mean(((y_minus_l_hat - d_minus_m_hat * theta) * d_minus_m_hat)^2)
}
else if (score == "IV-type") {
- var = 1 / n * 1 / mean(v_hatd)^2 *
- mean(((u_hat - d * theta) * v_hat)^2)
+ var = 1 / n * 1 / mean(d_minus_m_hat * d)^2 *
+ mean(((y_minus_g_hat - d * theta) * d_minus_m_hat)^2)
}
return(c(var))
}
@@ -375,18 +436,19 @@ boot_plr_single_split = function(theta, se, data, y, d,
residuals = compute_plr_residuals(
data, y, d, n_folds,
smpl, all_preds)
- u_hat = residuals$u_hat
- v_hat = residuals$v_hat
+ y_minus_l_hat = residuals$y_minus_l_hat
+ d_minus_m_hat = residuals$d_minus_m_hat
+ y_minus_g_hat = residuals$y_minus_g_hat
+
D = data[, d]
- v_hatd = v_hat * D
if (score == "partialling out") {
- psi = (u_hat - v_hat * theta) * v_hat
- psi_a = -v_hat * v_hat
+ psi = (y_minus_l_hat - d_minus_m_hat * theta) * d_minus_m_hat
+ psi_a = -d_minus_m_hat * d_minus_m_hat
}
else if (score == "IV-type") {
- psi = (u_hat - D * theta) * v_hat
- psi_a = -v_hatd
+ psi = (y_minus_g_hat - D * theta) * d_minus_m_hat
+ psi_a = -d_minus_m_hat * D
}
res = functional_bootstrap(
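For score = "IV-type", fit_nuisance_plr() above first computes a
partialling-out estimate theta_initial and then regresses
Y - theta_initial * D on X to obtain g_hat. A standalone sketch of that
two-step idea on simulated data, with plain lm() standing in for the
cross-fitted learners:

  set.seed(1)
  n = 500
  x = rnorm(n)
  d = 0.5 * x + rnorm(n)
  y = 0.5 * d + x + rnorm(n)
  l_hat = predict(lm(y ~ x))  # stand-in for cross-fitted E[Y|X]
  m_hat = predict(lm(d ~ x))  # stand-in for cross-fitted E[D|X]
  # partialling-out estimate, cf. theta_initial above
  theta_initial = mean((d - m_hat) * (y - l_hat)) / mean((d - m_hat)^2)
  # fit g on the partially-residualized outcome
  g_hat = predict(lm(I(y - theta_initial * d) ~ x))
  # IV-type estimate, cf. orth_plr_dml()
  theta_iv = mean((d - m_hat) * (y - g_hat)) / mean((d - m_hat) * d)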
diff --git a/tests/testthat/helper-09-dml_pliv.R b/tests/testthat/helper-09-dml_pliv.R
index afe10667..2c2f8279 100644
--- a/tests/testthat/helper-09-dml_pliv.R
+++ b/tests/testthat/helper-09-dml_pliv.R
@@ -1,10 +1,10 @@
# Double Machine Learning for Partially Linear Instrumental Variable Regression.
dml_pliv = function(data, y, d, z,
n_folds,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r, ml_g,
params, dml_procedure, score,
n_rep = 1, smpls = NULL,
- params_g = NULL, params_m = NULL, params_r = NULL) {
+ params_l = NULL, params_m = NULL, params_r = NULL, params_g = NULL) {
if (is.null(smpls)) {
smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data))
@@ -18,18 +18,21 @@ dml_pliv = function(data, y, d, z,
train_ids = this_smpl$train_ids
test_ids = this_smpl$test_ids
+ fit_g = is.function(score) || (score == "IV-type")
all_preds[[i_rep]] = fit_nuisance_pliv(
data, y, d, z,
- ml_g, ml_m, ml_r,
- this_smpl,
- params_g, params_m, params_r)
+ ml_l, ml_m, ml_r, ml_g,
+ n_folds, this_smpl, fit_g,
+ params_l, params_m, params_r, params_g)
residuals = compute_pliv_residuals(
data, y, d, z, n_folds, this_smpl,
all_preds[[i_rep]])
- u_hat = residuals$u_hat
- v_hat = residuals$v_hat
- w_hat = residuals$w_hat
+ y_minus_l_hat = residuals$y_minus_l_hat
+ z_minus_m_hat = residuals$z_minus_m_hat
+ d_minus_r_hat = residuals$d_minus_r_hat
+ y_minus_g_hat = residuals$y_minus_g_hat
+ D = data[, d]
# DML 1
if (dml_procedure == "dml1") {
@@ -37,29 +40,35 @@ dml_pliv = function(data, y, d, z,
for (i in 1:n_folds) {
test_index = test_ids[[i]]
orth_est = orth_pliv_dml(
- u_hat = u_hat[test_index],
- v_hat = v_hat[test_index],
- w_hat = w_hat[test_index],
+ y_minus_l_hat = y_minus_l_hat[test_index],
+ z_minus_m_hat = z_minus_m_hat[test_index],
+ d_minus_r_hat = d_minus_r_hat[test_index],
+ y_minus_g_hat = y_minus_g_hat[test_index],
+ D = D[test_index],
score = score)
thetas[i] = orth_est$theta
}
all_thetas[i_rep] = mean(thetas, na.rm = TRUE)
if (length(train_ids) == 1) {
- u_hat = u_hat[test_index]
- v_hat = v_hat[test_index]
- w_hat = w_hat[test_index]
+ y_minus_l_hat = y_minus_l_hat[test_index]
+ z_minus_m_hat = z_minus_m_hat[test_index]
+ d_minus_r_hat = d_minus_r_hat[test_index]
+ y_minus_g_hat = y_minus_g_hat[test_index]
}
}
if (dml_procedure == "dml2") {
orth_est = orth_pliv_dml(
- u_hat = u_hat, v_hat = v_hat, w_hat = w_hat,
- score = score)
+ y_minus_l_hat = y_minus_l_hat, z_minus_m_hat = z_minus_m_hat,
+ d_minus_r_hat = d_minus_r_hat, y_minus_g_hat = y_minus_g_hat,
+ D = D, score = score)
all_thetas[i_rep] = orth_est$theta
}
all_ses[i_rep] = sqrt(var_pliv(
- theta = all_thetas[i_rep], u_hat = u_hat, v_hat = v_hat,
- w_hat = w_hat, score = score))
+ D = D, theta = all_thetas[i_rep],
+ y_minus_l_hat = y_minus_l_hat, z_minus_m_hat = z_minus_m_hat,
+ d_minus_r_hat = d_minus_r_hat, y_minus_g_hat = y_minus_g_hat,
+ score = score))
}
theta = stats::median(all_thetas)
@@ -83,27 +92,27 @@ dml_pliv = function(data, y, d, z,
}
fit_nuisance_pliv = function(data, y, d, z,
- ml_g, ml_m, ml_r,
- smpls,
- params_g, params_m, params_r) {
+ ml_l, ml_m, ml_r, ml_g,
+ n_folds, smpls, fit_g,
+ params_l, params_m, params_r, params_g) {
train_ids = smpls$train_ids
test_ids = smpls$test_ids
- # nuisance g: E[Y|X]
- g_indx = names(data) != d & names(data) != z
- data_g = data[, g_indx, drop = FALSE]
- task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y)
+ # nuisance l: E[Y|X]
+ l_indx = names(data) != d & names(data) != z
+ data_l = data[, l_indx, drop = FALSE]
+ task_l = mlr3::TaskRegr$new(id = paste0("nuis_l_", d), backend = data_l, target = y)
- resampling_g = mlr3::rsmp("custom")
- resampling_g$instantiate(task_g, train_ids, test_ids)
+ resampling_l = mlr3::rsmp("custom")
+ resampling_l$instantiate(task_l, train_ids, test_ids)
- if (!is.null(params_g)) {
- ml_g$param_set$values = params_g
+ if (!is.null(params_l)) {
+ ml_l$param_set$values = params_l
}
- r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE)
- g_hat_list = lapply(r_g$predictions(), function(x) x$response)
+ r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE)
+ l_hat_list = lapply(r_l$predictions(), function(x) x$response)
# nuisance m: E[Z|X]
m_indx = names(data) != y & names(data) != d
@@ -133,10 +142,50 @@ fit_nuisance_pliv = function(data, y, d, z,
r_r = mlr3::resample(task_r, ml_r, resampling_r, store_models = TRUE)
r_hat_list = lapply(r_r$predictions(), function(x) x$response)
+ if (fit_g) {
+ # nuisance g
+ residuals = compute_pliv_residuals(
+ data, y, d, z, n_folds,
+ smpls, list(
+ l_hat_list = l_hat_list,
+ m_hat_list = m_hat_list,
+ r_hat_list = r_hat_list,
+ g_hat_list = NULL))
+ y_minus_l_hat = residuals$y_minus_l_hat
+ z_minus_m_hat = residuals$z_minus_m_hat
+ d_minus_r_hat = residuals$d_minus_r_hat
+ psi_a = -d_minus_r_hat * z_minus_m_hat
+ psi_b = z_minus_m_hat * y_minus_l_hat
+ theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)
+
+ D = data[, d]
+ Y = data[, y]
+ g_indx = names(data) != y & names(data) != d & names(data) != z
+ y_minus_theta_d = Y - theta_initial * D
+ data_g = cbind(data[, g_indx, drop = FALSE], y_minus_theta_d)
+
+ task_g = mlr3::TaskRegr$new(
+ id = paste0("nuis_g_", d), backend = data_g,
+ target = "y_minus_theta_d")
+
+ resampling_g = mlr3::rsmp("custom")
+ resampling_g$instantiate(task_g, train_ids, test_ids)
+
+ if (!is.null(params_g)) {
+ ml_g$param_set$values = params_g
+ }
+
+ r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE)
+ g_hat_list = lapply(r_g$predictions(), function(x) x$response)
+ } else {
+ g_hat_list = NULL
+ }
+
all_preds = list(
+ l_hat_list = l_hat_list,
m_hat_list = m_hat_list,
- g_hat_list = g_hat_list,
- r_hat_list = r_hat_list)
+ r_hat_list = r_hat_list,
+ g_hat_list = g_hat_list)
return(all_preds)
}
@@ -145,37 +194,50 @@ compute_pliv_residuals = function(data, y, d, z, n_folds, smpls, all_preds) {
test_ids = smpls$test_ids
+ l_hat_list = all_preds$l_hat_list
m_hat_list = all_preds$m_hat_list
- g_hat_list = all_preds$g_hat_list
r_hat_list = all_preds$r_hat_list
+ g_hat_list = all_preds$g_hat_list
n = nrow(data)
D = data[, d]
Y = data[, y]
Z = data[, z]
- v_hat = u_hat = w_hat = rep(NA_real_, n)
+ y_minus_l_hat = z_minus_m_hat = d_minus_r_hat = y_minus_g_hat = rep(NA_real_, n)
for (i in 1:n_folds) {
test_index = test_ids[[i]]
+ l_hat = l_hat_list[[i]]
m_hat = m_hat_list[[i]]
- g_hat = g_hat_list[[i]]
r_hat = r_hat_list[[i]]
- v_hat[test_index] = D[test_index] - r_hat
- u_hat[test_index] = Y[test_index] - g_hat
- w_hat[test_index] = Z[test_index] - m_hat
+ y_minus_l_hat[test_index] = Y[test_index] - l_hat
+ z_minus_m_hat[test_index] = Z[test_index] - m_hat
+ d_minus_r_hat[test_index] = D[test_index] - r_hat
+
+ if (!is.null(g_hat_list)) {
+ g_hat = g_hat_list[[i]]
+ y_minus_g_hat[test_index] = Y[test_index] - g_hat
+ }
}
- residuals = list(u_hat = u_hat, v_hat = v_hat, w_hat = w_hat)
+ residuals = list(
+ y_minus_l_hat = y_minus_l_hat,
+ z_minus_m_hat = z_minus_m_hat,
+ d_minus_r_hat = d_minus_r_hat,
+ y_minus_g_hat = y_minus_g_hat)
return(residuals)
}
-# Orthogonalized Estimation of Coefficient in PLR
+# Orthogonalized Estimation of Coefficient in PLIV
-orth_pliv_dml = function(u_hat, v_hat, w_hat, score) {
+orth_pliv_dml = function(y_minus_l_hat, z_minus_m_hat,
+ d_minus_r_hat, y_minus_g_hat, D, score) {
if (score == "partialling out") {
- theta = mean(u_hat * w_hat) / mean(v_hat * w_hat)
+ theta = mean(y_minus_l_hat * z_minus_m_hat) / mean(d_minus_r_hat * z_minus_m_hat)
+ } else if (score == "IV-type") {
+ theta = mean(y_minus_g_hat * z_minus_m_hat) / mean(D * z_minus_m_hat)
} else {
stop("Inference framework for orthogonal estimation unknown")
}
@@ -184,10 +246,14 @@ orth_pliv_dml = function(u_hat, v_hat, w_hat, score) {
}
-# Variance estimation for DML estimator in the partially linear regression model
+# Variance estimation for the DML estimator in the partially linear IV model
-var_pliv = function(theta, u_hat, v_hat, w_hat, score) {
+var_pliv = function(theta, D, y_minus_l_hat, z_minus_m_hat,
+ d_minus_r_hat, y_minus_g_hat, score) {
if (score == "partialling out") {
- var = mean(1 / length(u_hat) * 1 / (mean(v_hat * w_hat))^2 *
- mean(((u_hat - v_hat * theta) * w_hat)^2))
+ var = mean(1 / length(y_minus_l_hat) * 1 / (mean(d_minus_r_hat * z_minus_m_hat))^2 *
+ mean(((y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat)^2))
+ } else if (score == "IV-type") {
+ var = mean(1 / length(y_minus_l_hat) * 1 / (mean(D * z_minus_m_hat))^2 *
+ mean(((y_minus_g_hat - D * theta) * z_minus_m_hat)^2))
} else {
stop("Inference framework for variance estimation unknown")
}
@@ -196,18 +262,25 @@ var_pliv = function(theta, u_hat, v_hat, w_hat, score) {
-# Bootstrap Implementation for Partially Linear Regression Model
+# Bootstrap Implementation for the Partially Linear IV Model
bootstrap_pliv = function(theta, se, data, y, d, z, n_folds, smpls,
- all_preds, bootstrap, n_rep_boot,
+ all_preds, bootstrap, n_rep_boot, score,
n_rep = 1) {
for (i_rep in 1:n_rep) {
residuals = compute_pliv_residuals(
data, y, d, z, n_folds,
smpls[[i_rep]], all_preds[[i_rep]])
- u_hat = residuals$u_hat
- v_hat = residuals$v_hat
- w_hat = residuals$w_hat
-
- psi = (u_hat - v_hat * theta[i_rep]) * w_hat
- psi_a = -v_hat * w_hat
+ y_minus_l_hat = residuals$y_minus_l_hat
+ d_minus_r_hat = residuals$d_minus_r_hat
+ z_minus_m_hat = residuals$z_minus_m_hat
+ y_minus_g_hat = residuals$y_minus_g_hat
+
+ if (score == "partialling out") {
+ psi = (y_minus_l_hat - d_minus_r_hat * theta[i_rep]) * z_minus_m_hat
+ psi_a = -d_minus_r_hat * z_minus_m_hat
+ } else if (score == "IV-type") {
+ D = data[, d]
+ psi = (y_minus_g_hat - D * theta[i_rep]) * z_minus_m_hat
+ psi_a = -D * z_minus_m_hat
+ }
n = length(psi)
weights = draw_bootstrap_weights(bootstrap, n_rep_boot, n)
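The same two-step construction is used in fit_nuisance_pliv(), with the
instrument residual Z - m_hat taking over the orthogonalizing role. A
standalone sketch contrasting the two PLIV moment conditions (simulated
data, lm() as stand-in for the cross-fitted learners):

  set.seed(1)
  n = 500
  x = rnorm(n)
  z = 0.5 * x + rnorm(n)            # instrument
  d = 0.5 * x + 0.7 * z + rnorm(n)
  y = 0.5 * d + x + rnorm(n)
  l_hat = predict(lm(y ~ x))        # E[Y|X]
  m_hat = predict(lm(z ~ x))        # E[Z|X]
  r_hat = predict(lm(d ~ x))        # E[D|X]
  # "partialling out"; also serves as theta_initial for the g-fit
  theta_po = mean((y - l_hat) * (z - m_hat)) / mean((d - r_hat) * (z - m_hat))
  g_hat = predict(lm(I(y - theta_po * d) ~ x))
  # "IV-type", cf. orth_pliv_dml()
  theta_iv = mean((y - g_hat) * (z - m_hat)) / mean(d * (z - m_hat))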
diff --git a/tests/testthat/helper-13-dml_pliv_partial_x.R b/tests/testthat/helper-13-dml_pliv_partial_x.R
index daefdfd8..c4fb9f5d 100644
--- a/tests/testthat/helper-13-dml_pliv_partial_x.R
+++ b/tests/testthat/helper-13-dml_pliv_partial_x.R
@@ -1,9 +1,9 @@
dml_pliv_partial_x = function(data, y, d, z,
n_folds,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r,
params, dml_procedure, score,
n_rep = 1, smpls = NULL,
- params_g = NULL, params_m = NULL, params_r = NULL) {
+ params_l = NULL, params_m = NULL, params_r = NULL) {
stopifnot(length(z) > 1)
if (is.null(smpls)) {
@@ -18,9 +18,9 @@ dml_pliv_partial_x = function(data, y, d, z,
all_preds[[i_rep]] = fit_nuisance_pliv_partial_x(
data, y, d, z,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r,
this_smpl,
- params_g, params_m, params_r)
+ params_l, params_m, params_r)
residuals = compute_pliv_partial_x_residuals(
data, y, d, z, n_folds,
@@ -77,27 +77,27 @@ dml_pliv_partial_x = function(data, y, d, z,
}
fit_nuisance_pliv_partial_x = function(data, y, d, z,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r,
smpls,
- params_g, params_m, params_r) {
+ params_l, params_m, params_r) {
train_ids = smpls$train_ids
test_ids = smpls$test_ids
- # nuisance g: E[Y|X]
- g_indx = names(data) != d & (names(data) %in% z == FALSE)
- data_g = data[, g_indx, drop = FALSE]
- task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y)
+ # nuisance l: E[Y|X]
+ l_indx = names(data) != d & (names(data) %in% z == FALSE)
+ data_l = data[, l_indx, drop = FALSE]
+ task_l = mlr3::TaskRegr$new(id = paste0("nuis_l_", d), backend = data_l, target = y)
- resampling_g = mlr3::rsmp("custom")
- resampling_g$instantiate(task_g, train_ids, test_ids)
+ resampling_l = mlr3::rsmp("custom")
+ resampling_l$instantiate(task_l, train_ids, test_ids)
- if (!is.null(params_g)) {
- ml_g$param_set$values = params_g
+ if (!is.null(params_l)) {
+ ml_l$param_set$values = params_l
}
- r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE)
- g_hat_list = lapply(r_g$predictions(), function(x) x$response)
+ r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE)
+ l_hat_list = lapply(r_l$predictions(), function(x) x$response)
# nuisance m: E[Z|X]
n_z = length(z)
@@ -150,7 +150,7 @@ fit_nuisance_pliv_partial_x = function(data, y, d, z,
Z - m_hat_array)
all_preds = list(
- g_hat_list = g_hat_list,
+ l_hat_list = l_hat_list,
r_hat_list = r_hat_list,
r_hat_tilde = r_hat_tilde)
@@ -162,7 +162,7 @@ compute_pliv_partial_x_residuals = function(data, y, d, z, n_folds, smpls,
test_ids = smpls$test_ids
- g_hat_list = all_preds$g_hat_list
+ l_hat_list = all_preds$l_hat_list
r_hat_list = all_preds$r_hat_list
r_hat_tilde = all_preds$r_hat_tilde
@@ -175,10 +175,10 @@ compute_pliv_partial_x_residuals = function(data, y, d, z, n_folds, smpls,
for (i in 1:n_folds) {
test_index = test_ids[[i]]
- g_hat = g_hat_list[[i]]
+ l_hat = l_hat_list[[i]]
r_hat = r_hat_list[[i]]
- u_hat[test_index] = Y[test_index] - g_hat
+ u_hat[test_index] = Y[test_index] - l_hat
w_hat[test_index] = D[test_index] - r_hat
}
residuals = list(u_hat = u_hat, w_hat = w_hat, r_hat_tilde = r_hat_tilde)
diff --git a/tests/testthat/helper-15-dml_pliv_partial_xz.R b/tests/testthat/helper-15-dml_pliv_partial_xz.R
index a6dc29a9..58450cea 100644
--- a/tests/testthat/helper-15-dml_pliv_partial_xz.R
+++ b/tests/testthat/helper-15-dml_pliv_partial_xz.R
@@ -1,9 +1,9 @@
dml_pliv_partial_xz = function(data, y, d, z,
n_folds,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r,
params, dml_procedure, score,
n_rep = 1, smpls = NULL,
- params_g = NULL, params_m = NULL, params_r = NULL) {
+ params_l = NULL, params_m = NULL, params_r = NULL) {
if (is.null(smpls)) {
smpls = lapply(1:n_rep, function(x) sample_splitting(n_folds, data))
@@ -17,9 +17,9 @@ dml_pliv_partial_xz = function(data, y, d, z,
all_preds[[i_rep]] = fit_nuisance_pliv_partial_xz(
data, y, d, z,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r,
this_smpl,
- params_g, params_m, params_r)
+ params_l, params_m, params_r)
residuals = compute_pliv_partial_xz_residuals(
data, y, d, z, n_folds,
@@ -81,27 +81,27 @@ dml_pliv_partial_xz = function(data, y, d, z,
}
fit_nuisance_pliv_partial_xz = function(data, y, d, z,
- ml_g, ml_m, ml_r,
+ ml_l, ml_m, ml_r,
smpls,
- params_g, params_m, params_r) {
+ params_l, params_m, params_r) {
train_ids = smpls$train_ids
test_ids = smpls$test_ids
- # nuisance g: E[Y|X]
- g_indx = names(data) != d & (names(data) %in% z == FALSE)
- data_g = data[, g_indx, drop = FALSE]
- task_g = mlr3::TaskRegr$new(id = paste0("nuis_g_", d), backend = data_g, target = y)
+ # nuisance l: E[Y|X]
+ l_indx = names(data) != d & (names(data) %in% z == FALSE)
+ data_l = data[, l_indx, drop = FALSE]
+ task_l = mlr3::TaskRegr$new(id = paste0("nuis_l_", d), backend = data_l, target = y)
- resampling_g = mlr3::rsmp("custom")
- resampling_g$instantiate(task_g, train_ids, test_ids)
+ resampling_l = mlr3::rsmp("custom")
+ resampling_l$instantiate(task_l, train_ids, test_ids)
- if (!is.null(params_g)) {
- ml_g$param_set$values = params_g
+ if (!is.null(params_l)) {
+ ml_l$param_set$values = params_l
}
- r_g = mlr3::resample(task_g, ml_g, resampling_g, store_models = TRUE)
- g_hat_list = lapply(r_g$predictions(), function(x) x$response)
+ r_l = mlr3::resample(task_l, ml_l, resampling_l, store_models = TRUE)
+ l_hat_list = lapply(r_l$predictions(), function(x) x$response)
# nuisance m: E[D|XZ]
m_indx = (names(data) != y)
@@ -141,7 +141,7 @@ fit_nuisance_pliv_partial_xz = function(data, y, d, z,
}
all_preds = list(
- g_hat_list = g_hat_list,
+ l_hat_list = l_hat_list,
m_hat_list = m_hat_list,
r_hat_list = r_hat_list)
@@ -153,7 +153,7 @@ compute_pliv_partial_xz_residuals = function(data, y, d, z, n_folds, smpls,
test_ids = smpls$test_ids
- g_hat_list = all_preds$g_hat_list
+ l_hat_list = all_preds$l_hat_list
m_hat_list = all_preds$m_hat_list
r_hat_list = all_preds$r_hat_list
@@ -166,11 +166,11 @@ compute_pliv_partial_xz_residuals = function(data, y, d, z, n_folds, smpls,
for (i in 1:n_folds) {
test_index = test_ids[[i]]
- g_hat = g_hat_list[[i]]
+ l_hat = l_hat_list[[i]]
m_hat = m_hat_list[[i]]
r_hat = r_hat_list[[i]]
- u_hat[test_index] = Y[test_index] - g_hat
+ u_hat[test_index] = Y[test_index] - l_hat
v_hat[test_index] = m_hat - r_hat
w_hat[test_index] = D[test_index] - r_hat
}
diff --git a/tests/testthat/print_outputs/dml_pliv.txt b/tests/testthat/print_outputs/dml_pliv.txt
index e0ec4b2a..59c434fa 100644
--- a/tests/testthat/print_outputs/dml_pliv.txt
+++ b/tests/testthat/print_outputs/dml_pliv.txt
@@ -15,7 +15,7 @@ Score function: partialling out
DML algorithm: dml2
------------------ Machine learner ------------------
-ml_g: regr.rpart
+ml_l: regr.rpart
ml_m: regr.rpart
ml_r: regr.rpart
diff --git a/tests/testthat/print_outputs/dml_plr.txt b/tests/testthat/print_outputs/dml_plr.txt
index b1f0def8..5e982686 100644
--- a/tests/testthat/print_outputs/dml_plr.txt
+++ b/tests/testthat/print_outputs/dml_plr.txt
@@ -14,7 +14,7 @@ Score function: partialling out
DML algorithm: dml2
------------------ Machine learner ------------------
-ml_g: regr.rpart
+ml_l: regr.rpart
ml_m: regr.rpart
------------------ Resampling ------------------
diff --git a/tests/testthat/test-double_ml_data.R b/tests/testthat/test-double_ml_data.R
index 4db052c3..cd295473 100644
--- a/tests/testthat/test-double_ml_data.R
+++ b/tests/testthat/test-double_ml_data.R
@@ -386,7 +386,7 @@ test_that("Unit tests for invalid data", {
"DoubleMLPLIV instead of DoubleMLPLR.")
expect_error(DoubleMLPLR$new(
data = data_pliv$dml_data,
- ml_g = mlr3::lrn("regr.rpart"),
+ ml_l = mlr3::lrn("regr.rpart"),
ml_m = mlr3::lrn("regr.rpart")),
regexp = msg)
@@ -398,7 +398,7 @@ test_that("Unit tests for invalid data", {
"variable\\(s\\) use DoubleMLPLR instead of DoubleMLPLIV.")
expect_error(DoubleMLPLIV$new(
data = data_plr$dml_data,
- ml_g = mlr3::lrn("regr.rpart"),
+ ml_l = mlr3::lrn("regr.rpart"),
ml_m = mlr3::lrn("regr.rpart"),
ml_r = mlr3::lrn("regr.rpart")),
regexp = msg)
diff --git a/tests/testthat/test-double_ml_pliv.R b/tests/testthat/test-double_ml_pliv.R
index 42e48a14..19cb2f31 100644
--- a/tests/testthat/test-double_ml_pliv.R
+++ b/tests/testthat/test-double_ml_pliv.R
@@ -15,7 +15,7 @@ if (on_cran) {
test_cases = expand.grid(
learner = c("regr.lm", "regr.glmnet", "graph_learner"),
dml_procedure = c("dml1", "dml2"),
- score = "partialling out",
+ score = c("partialling out", "IV-type"),
stringsAsFactors = FALSE)
}
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
@@ -29,9 +29,10 @@ patrick::with_parameters_test_that("Unit tests for PLIV:",
pliv_hat = dml_pliv(data_pliv$df,
y = "y", d = "d", z = "z",
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
+ ml_g = learner_pars$ml_g$clone(),
dml_procedure = dml_procedure, score = score)
theta = pliv_hat$coef
se = pliv_hat$se
@@ -41,19 +42,32 @@ patrick::with_parameters_test_that("Unit tests for PLIV:",
y = "y", d = "d", z = "z",
n_folds = 5, smpls = pliv_hat$smpls,
all_preds = pliv_hat$all_preds,
- bootstrap = "normal", n_rep_boot = n_rep_boot)$boot_coef
+ bootstrap = "normal", n_rep_boot = n_rep_boot,
+ score = score)$boot_coef
set.seed(3141)
- double_mlpliv_obj = DoubleMLPLIV$new(
- data = data_pliv$dml_data,
- n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
- ml_m = learner_pars$ml_m$clone(),
- ml_r = learner_pars$ml_r$clone(),
- dml_procedure = dml_procedure,
- score = score)
+ if (score == "partialling out") {
+ double_mlpliv_obj = DoubleMLPLIV$new(
+ data = data_pliv$dml_data,
+ n_folds = 5,
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ ml_r = learner_pars$ml_r$clone(),
+ dml_procedure = dml_procedure,
+ score = score)
+ } else {
+ double_mlpliv_obj = DoubleMLPLIV$new(
+ data = data_pliv$dml_data,
+ n_folds = 5,
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ ml_r = learner_pars$ml_r$clone(),
+ ml_g = learner_pars$ml_g$clone(),
+ dml_procedure = dml_procedure,
+ score = score)
+ }
- double_mlpliv_obj$fit()
+ double_mlpliv_obj$fit(store_predictions = TRUE)
theta_obj = double_mlpliv_obj$coef
se_obj = double_mlpliv_obj$se
diff --git a/tests/testthat/test-double_ml_pliv_exception_handling.R b/tests/testthat/test-double_ml_pliv_exception_handling.R
new file mode 100644
index 00000000..b852d643
--- /dev/null
+++ b/tests/testthat/test-double_ml_pliv_exception_handling.R
@@ -0,0 +1,79 @@
+context("Unit tests for exception handling and deprecation warnings of PLIV")
+
+library("mlr3learners")
+
+logger = lgr::get_logger("bbotk")
+logger$set_threshold("warn")
+lgr::get_logger("mlr3")$set_threshold("warn")
+
+test_that("Unit tests for deprecation warnings of PLIV", {
+ set.seed(3141)
+ dml_data_pliv = make_pliv_CHS2015(n_obs = 51, dim_z = 1)
+ ml_l = lrn("regr.ranger")
+ ml_g = lrn("regr.ranger")
+ ml_m = lrn("regr.ranger")
+ ml_r = lrn("regr.ranger")
+ msg = paste0("The argument ml_g was renamed to ml_l.")
+ expect_warning(DoubleMLPLIV$new(dml_data_pliv,
+ ml_g = ml_g, ml_m = ml_m, ml_r = ml_r),
+ regexp = msg)
+
+ msg = paste(
+ "For score = 'IV-type', learners",
+ "ml_l, ml_m, ml_r and ml_g need to be specified.")
+ expect_error(DoubleMLPLIV$new(dml_data_pliv,
+ ml_l = ml_l, ml_m = ml_m, ml_r = ml_r,
+ score = "IV-type"),
+ regexp = msg)
+
+ dml_obj = DoubleMLPLIV$new(dml_data_pliv,
+ ml_l = ml_g, ml_m = ml_m, ml_r = ml_r)
+
+ msg = paste0("Learner ml_g was renamed to ml_l.")
+ expect_warning(dml_obj$set_ml_nuisance_params(
+ "ml_g", "d", list("num.trees" = 10)),
+ regexp = msg)
+
+ par_grids = list(
+ "ml_g" = paradox::ParamSet$new(list(
+ paradox::ParamInt$new("num.trees", lower = 9, upper = 10))),
+ "ml_m" = paradox::ParamSet$new(list(
+ paradox::ParamInt$new("num.trees", lower = 10, upper = 11))),
+ "ml_r" = paradox::ParamSet$new(list(
+ paradox::ParamInt$new("num.trees", lower = 10, upper = 11))))
+
+ msg = paste0("Learner ml_g was renamed to ml_l.")
+ expect_warning(dml_obj$tune(par_grids),
+ regexp = msg)
+
+ tune_settings = list(
+ n_folds_tune = 5,
+ rsmp_tune = mlr3::rsmp("cv", folds = 5),
+ measure = list(ml_g = "regr.mse", ml_m = "regr.mae"),
+ terminator = mlr3tuning::trm("evals", n_evals = 20),
+ algorithm = mlr3tuning::tnr("grid_search"),
+ resolution = 5)
+ expect_warning(dml_obj$tune(par_grids, tune_settings = tune_settings),
+ regexp = msg)
+}
+)
+
+test_that("Unit tests of exception handling for DoubleMLPLIV", {
+ set.seed(3141)
+ dml_data_pliv = make_pliv_CHS2015(n_obs = 51, dim_z = 1)
+ ml_l = lrn("regr.ranger")
+ ml_m = lrn("regr.ranger")
+ ml_r = lrn("regr.ranger")
+ ml_g = lrn("regr.ranger")
+
+ msg = paste0(
+ "A learner ml_g has been provided for ",
+ "score = 'partialling out' but will be ignored.")
+ expect_warning(DoubleMLPLIV$new(dml_data_pliv,
+ ml_l = ml_l, ml_m = ml_m, ml_r = ml_r,
+ ml_g = ml_g,
+ score = "partialling out"),
+ regexp = msg)
+}
+)
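
For reference, a minimal sketch of the renamed PLIV interface that these deprecation and exception tests exercise (a sketch only, assuming the DoubleML, mlr3 and mlr3learners packages plus the ranger backend are available):

library(DoubleML)
library(mlr3)
library(mlr3learners)

set.seed(3141)
dml_data = make_pliv_CHS2015(n_obs = 100, dim_z = 1)

# score = "partialling out" needs only ml_l, ml_m and ml_r ...
obj_po = DoubleMLPLIV$new(dml_data,
  ml_l = lrn("regr.ranger"),
  ml_m = lrn("regr.ranger"),
  ml_r = lrn("regr.ranger"),
  score = "partialling out")

# ... while score = "IV-type" additionally requires ml_g.
obj_iv = DoubleMLPLIV$new(dml_data,
  ml_l = lrn("regr.ranger"),
  ml_m = lrn("regr.ranger"),
  ml_r = lrn("regr.ranger"),
  ml_g = lrn("regr.ranger"),
  score = "IV-type")
obj_iv$fit()
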
diff --git a/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R b/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R
index 5065cd71..e6cf0fa4 100644
--- a/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R
+++ b/tests/testthat/test-double_ml_pliv_multi_z_parameter_passing.R
@@ -34,10 +34,10 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
pliv_hat = dml_pliv_partial_x(df,
y = "y", d = "d", z = c("z", "z2"),
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
- params_g = learner_pars$params$params_g,
+ params_l = learner_pars$params$params_l,
params_m = learner_pars$params$params_m,
params_r = learner_pars$params$params_r,
dml_procedure = dml_procedure, score = score)
@@ -62,7 +62,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj = DoubleMLPLIV.partialX(
data = dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -70,8 +70,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_obj$set_ml_nuisance_params(
learner = "ml_m_z",
treat_var = "d",
@@ -118,7 +118,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
set.seed(3141)
dml_pliv_obj = DoubleMLPLIV.partialX(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -126,8 +126,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_obj$set_ml_nuisance_params(
learner = "ml_m_z",
treat_var = "d",
@@ -145,14 +145,14 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
theta = dml_pliv_obj$coef
se = dml_pliv_obj$se
- params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
+ params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep)
params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep)
params_r_fold_wise = rep(list(rep(list(learner_pars$params$params_r), n_folds)), n_rep)
set.seed(3141)
dml_pliv_obj_fold_wise = DoubleMLPLIV.partialX(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -160,8 +160,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj_fold_wise$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = params_g_fold_wise,
+ learner = "ml_l",
+ params = params_l_fold_wise,
set_fold_specific = TRUE)
dml_pliv_obj_fold_wise$set_ml_nuisance_params(
treat_var = "d",
@@ -193,7 +193,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
n_folds = 2
n_rep = 3
- params_g = list(cp = 0.01, minsplit = 20) # this are defaults
+ params_l = list(cp = 0.01, minsplit = 20) # these are the defaults
params_m = list(cp = 0.01, minsplit = 20) # this are defaults
params_r = list(cp = 0.01, minsplit = 20) # this are defaults
@@ -206,7 +206,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
set.seed(3141)
dml_pliv_default = DoubleMLPLIV.partialX(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
ml_r = lrn("regr.rpart"),
dml_procedure = dml_procedure,
@@ -219,7 +219,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
set.seed(3141)
dml_pliv_obj = DoubleMLPLIV.partialX(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
ml_r = lrn("regr.rpart"),
dml_procedure = dml_procedure,
@@ -227,8 +227,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = params_g)
+ learner = "ml_l",
+ params = params_l)
dml_pliv_obj$set_ml_nuisance_params(
learner = "ml_m_z",
treat_var = "d",
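
The fold-specific parameter objects used with set_fold_specific = TRUE above follow a nested-list layout: one inner list per fold, wrapped in one outer list per repetition. A small standalone sketch of that structure (names are illustrative):

# rpart-style parameter set, replicated per fold and per repetition
params = list(cp = 0.01, minsplit = 20)
n_folds = 2
n_rep = 3
params_fold_wise = rep(list(rep(list(params), n_folds)), n_rep)
str(params_fold_wise, max.level = 2)
# -> a list of 3 repetitions, each holding a list of 2 fold-wise parameter sets
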
diff --git a/tests/testthat/test-double_ml_pliv_one_way_cluster.R b/tests/testthat/test-double_ml_pliv_one_way_cluster.R
index 73c589c5..4f6f70e4 100644
--- a/tests/testthat/test-double_ml_pliv_one_way_cluster.R
+++ b/tests/testthat/test-double_ml_pliv_one_way_cluster.R
@@ -13,7 +13,7 @@ if (on_cran) {
test_cases = expand.grid(
learner = c("regr.lm", "regr.glmnet"),
dml_procedure = c("dml1", "dml2"),
- score = "partialling out",
+ score = c("partialling out", "IV-type"),
stringsAsFactors = FALSE)
}
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
@@ -35,15 +35,20 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering:
n_folds = 2
n_rep = 2
-
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlpliv_obj = DoubleMLPLIV$new(
data = data_one_way,
n_folds = n_folds,
n_rep = n_rep,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
@@ -53,6 +58,11 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering:
se_obj = double_mlpliv_obj$se
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
df = as.data.frame(data_one_way$data)
cluster_var = df$cluster_var_i
# need to drop variables as x is not explicitly set
@@ -60,9 +70,10 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering:
pliv_hat = dml_pliv(df,
y = "Y", d = "D", z = "Z",
n_folds = n_folds,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure, score = score,
smpls = double_mlpliv_obj$smpls,
n_rep = n_rep)
@@ -74,15 +85,19 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering:
residuals = compute_pliv_residuals(df,
y = "Y", d = "D", z = "Z",
n_folds = n_folds,
- this_smpl,
- pliv_hat$all_preds[[i_rep]])
- u_hat = residuals$u_hat
- v_hat = residuals$v_hat
- w_hat = residuals$w_hat
+ smpls = this_smpl,
+ all_preds = pliv_hat$all_preds[[i_rep]])
+ y_minus_l_hat = residuals$y_minus_l_hat
+ d_minus_r_hat = residuals$d_minus_r_hat
+ z_minus_m_hat = residuals$z_minus_m_hat
+ y_minus_g_hat = residuals$y_minus_g_hat
+ D = df[, "D"]
- psi_a = -w_hat * v_hat
+ if (score == "partialling out") psi_a = -z_minus_m_hat * d_minus_r_hat
+ if (score == "IV-type") psi_a = -D * z_minus_m_hat
if (dml_procedure == "dml2") {
- psi_b = w_hat * u_hat
+ if (score == "partialling out") psi_b = z_minus_m_hat * y_minus_l_hat
+ if (score == "IV-type") psi_b = z_minus_m_hat * y_minus_g_hat
theta = est_one_way_cluster_dml2(
psi_a, psi_b,
cluster_var,
@@ -90,7 +105,8 @@ patrick::with_parameters_test_that("Unit tests for PLIV with one-way clustering:
} else {
theta = pliv_hat$thetas[i_rep]
}
- psi = (u_hat - v_hat * theta) * w_hat
+ if (score == "partialling out") psi = (y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat
+ if (score == "IV-type") psi = (y_minus_g_hat - D * theta) * z_minus_m_hat
var = var_one_way_cluster(
psi, psi_a,
cluster_var,
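
The manual recomputation above encodes the two PLIV score variants side by side. As a standalone sketch (notation as in the tests: l_hat estimates E[Y|X], m_hat estimates E[Z|X], r_hat estimates E[D|X], and g_hat is the IV-type nuisance):

# Sketch of the two score variants; psi = psi_a * theta + psi_b reproduces
# (y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat, i.e. the old
# (u_hat - v_hat * theta) * w_hat, for the partialling-out case.
pliv_score = function(y, d, z, l_hat, m_hat, r_hat, g_hat, theta, score) {
  if (score == "partialling out") {
    psi_a = -(z - m_hat) * (d - r_hat)
    psi_b = (z - m_hat) * (y - l_hat)
  } else { # score == "IV-type"
    psi_a = -d * (z - m_hat)
    psi_b = (z - m_hat) * (y - g_hat)
  }
  list(psi_a = psi_a, psi_b = psi_b, psi = psi_a * theta + psi_b)
}
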
diff --git a/tests/testthat/test-double_ml_pliv_parameter_passing.R b/tests/testthat/test-double_ml_pliv_parameter_passing.R
index 1105a4b1..bb6afd75 100644
--- a/tests/testthat/test-double_ml_pliv_parameter_passing.R
+++ b/tests/testthat/test-double_ml_pliv_parameter_passing.R
@@ -13,7 +13,7 @@ if (on_cran) {
test_cases = expand.grid(
learner = "regr.rpart",
dml_procedure = c("dml1", "dml2"),
- score = "partialling out",
+ score = c("partialling out", "IV-type"),
stringsAsFactors = FALSE)
}
@@ -35,15 +35,22 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (oo
learner_pars = get_default_mlmethod_pliv(learner)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
pliv_hat = dml_pliv(data_pliv$df,
y = "y", d = "d", z = "z",
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
- params_g = learner_pars$params$params_g,
+ ml_g = ml_g,
+ params_l = learner_pars$params$params_l,
params_m = learner_pars$params$params_m,
params_r = learner_pars$params$params_r,
+ params_g = learner_pars$params$params_g,
dml_procedure = dml_procedure, score = score)
theta = pliv_hat$coef
se = pliv_hat$se
@@ -54,22 +61,29 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (oo
n_folds = n_folds, n_rep = n_rep,
smpls = pliv_hat$smpls,
all_preds = pliv_hat$all_preds,
- bootstrap = "normal", n_rep_boot = n_rep_boot)$boot_coef
+ bootstrap = "normal", n_rep_boot = n_rep_boot,
+ score = score)$boot_coef
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
dml_pliv_obj = DoubleMLPLIV$new(
data = data_pliv$dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -78,6 +92,13 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (oo
treat_var = "d",
learner = "ml_r",
params = learner_pars$params$params_r)
+ if (score == "IV-type") {
+ dml_pliv_obj$set_ml_nuisance_params(
+ treat_var = "d",
+ learner = "ml_g",
+ params = learner_pars$params$params_g)
+ }
+
dml_pliv_obj$fit()
@@ -109,35 +130,48 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (no
test_ids = list(my_sampling$test_set(1))
smpls = list(list(train_ids = train_ids, test_ids = test_ids))
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
pliv_hat = dml_pliv(data_pliv$df,
y = "y", d = "d", z = "z",
n_folds = 1,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
- params_g = learner_pars$params$params_g,
+ ml_g = ml_g,
+ params_l = learner_pars$params$params_l,
params_m = learner_pars$params$params_m,
params_r = learner_pars$params$params_r,
+ params_g = learner_pars$params$params_g,
dml_procedure = dml_procedure, score = score,
smpls = smpls)
theta = pliv_hat$coef
se = pliv_hat$se
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
dml_pliv_nocf = DoubleMLPLIV$new(
data = data_pliv$dml_data,
n_folds = n_folds,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
apply_cross_fitting = FALSE)
dml_pliv_nocf$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_nocf$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -146,6 +180,12 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (no
treat_var = "d",
learner = "ml_r",
params = learner_pars$params$params_r)
+ if (score == "IV-type") {
+ dml_pliv_nocf$set_ml_nuisance_params(
+ treat_var = "d",
+ learner = "ml_g",
+ params = learner_pars$params$params_g)
+ }
dml_pliv_nocf$fit()
theta_obj = dml_pliv_nocf$coef
@@ -164,18 +204,24 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (fo
learner_pars = get_default_mlmethod_pliv(learner)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
dml_pliv_obj = DoubleMLPLIV$new(data_pliv$dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -184,28 +230,41 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (fo
treat_var = "d",
learner = "ml_r",
params = learner_pars$params$params_r)
+ if (score == "IV-type") {
+ dml_pliv_obj$set_ml_nuisance_params(
+ treat_var = "d",
+ learner = "ml_g",
+ params = learner_pars$params$params_g)
+ }
dml_pliv_obj$fit()
theta = dml_pliv_obj$coef
se = dml_pliv_obj$se
- params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
+ params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep)
params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep)
params_r_fold_wise = rep(list(rep(list(learner_pars$params$params_r), n_folds)), n_rep)
+ params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
dml_pliv_obj_fold_wise = DoubleMLPLIV$new(data_pliv$dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
dml_pliv_obj_fold_wise$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = params_g_fold_wise,
+ learner = "ml_l",
+ params = params_l_fold_wise,
set_fold_specific = TRUE)
dml_pliv_obj_fold_wise$set_ml_nuisance_params(
treat_var = "d",
@@ -217,6 +276,13 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (fo
learner = "ml_r",
params = params_r_fold_wise,
set_fold_specific = TRUE)
+ if (score == "IV-type") {
+ dml_pliv_obj_fold_wise$set_ml_nuisance_params(
+ treat_var = "d",
+ learner = "ml_g",
+ params = params_g_fold_wise,
+ set_fold_specific = TRUE)
+ }
dml_pliv_obj_fold_wise$fit()
theta_fold_wise = dml_pliv_obj_fold_wise$coef
@@ -232,16 +298,23 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (de
n_folds = 2
n_rep = 3
- params_g = list(cp = 0.01, minsplit = 20) # this are defaults
+ params_l = list(cp = 0.01, minsplit = 20) # these are the defaults
params_m = list(cp = 0.01, minsplit = 20) # this are defaults
params_r = list(cp = 0.01, minsplit = 20) # this are defaults
+ params_g = list(cp = 0.01, minsplit = 20) # these are the defaults
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = lrn("regr.rpart")
+ } else {
+ ml_g = NULL
+ }
dml_pliv_default = DoubleMLPLIV$new(data_pliv$dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
ml_r = lrn("regr.rpart"),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
@@ -250,19 +323,25 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (de
se_default = dml_pliv_default$se
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = lrn("regr.rpart")
+ } else {
+ ml_g = NULL
+ }
dml_pliv_obj = DoubleMLPLIV$new(
data = data_pliv$dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
ml_r = lrn("regr.rpart"),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = params_g)
+ learner = "ml_l",
+ params = params_l)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -271,6 +350,12 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV (de
treat_var = "d",
learner = "ml_r",
params = params_r)
+ if (score == "IV-type") {
+ dml_pliv_obj$set_ml_nuisance_params(
+ treat_var = "d",
+ learner = "ml_g",
+ params = params_g)
+ }
dml_pliv_obj$fit()
theta = dml_pliv_obj$coef
diff --git a/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R b/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R
index 253903e6..36902402 100644
--- a/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R
+++ b/tests/testthat/test-double_ml_pliv_partial_functional_initializer.R
@@ -31,7 +31,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional ini
set.seed(3141)
double_mlpliv_obj = DoubleMLPLIV$new(data_ml,
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure,
@@ -45,7 +45,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional ini
set.seed(3141)
double_mlpliv_partX = DoubleMLPLIV.partialX(data_ml,
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure,
@@ -72,7 +72,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialZ functional ini
set.seed(3141)
double_mlpliv_partZ = DoubleMLPLIV$new(data_ml,
n_folds = 5,
- ml_g = NULL,
+ ml_l = NULL,
ml_m = NULL,
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure,
@@ -111,7 +111,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialXZ functional in
set.seed(3141)
double_mlpliv_partXZ = DoubleMLPLIV$new(data_ml,
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure,
@@ -125,7 +125,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV (partialXZ functional in
set.seed(3141)
double_mlpliv_partXZ_fun = DoubleMLPLIV.partialXZ(data_ml,
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure,
diff --git a/tests/testthat/test-double_ml_pliv_partial_functional_initializer_IVtype.R b/tests/testthat/test-double_ml_pliv_partial_functional_initializer_IVtype.R
new file mode 100644
index 00000000..35a277a1
--- /dev/null
+++ b/tests/testthat/test-double_ml_pliv_partial_functional_initializer_IVtype.R
@@ -0,0 +1,63 @@
+context("Unit tests for PLIV, partialling out X, Z, XZ")
+
+lgr::get_logger("mlr3")$set_threshold("warn")
+
+on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
+if (on_cran) {
+ test_cases = expand.grid(
+ learner = "regr.lm",
+ dml_procedure = "dml2",
+ score = "IV-type",
+ stringsAsFactors = FALSE)
+} else {
+ test_cases = expand.grid(
+ learner = c("regr.lm", "regr.cv_glmnet"),
+ dml_procedure = c("dml1", "dml2"),
+ score = "IV-type",
+ stringsAsFactors = FALSE)
+}
+test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
+
+patrick::with_parameters_test_that("Unit tests for PLIV (partialX functional initialization):",
+ .cases = test_cases, {
+ learner_pars = get_default_mlmethod_pliv(learner)
+ df = data_pliv$df
+ Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE]
+ data_ml = double_ml_data_from_data_frame(df,
+ y_col = "y",
+ d_cols = "d", x_cols = Xnames, z_cols = "z")
+
+ # Partial out X (default PLIV)
+ set.seed(3141)
+ double_mlpliv_obj = DoubleMLPLIV$new(data_ml,
+ n_folds = 5,
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ ml_r = learner_pars$ml_r$clone(),
+ ml_g = learner_pars$ml_g$clone(),
+ dml_procedure = dml_procedure,
+ score = score)
+
+ double_mlpliv_obj$fit()
+ theta_obj = double_mlpliv_obj$coef
+ se_obj = double_mlpliv_obj$se
+
+ # Partial out X
+ set.seed(3141)
+ double_mlpliv_partX = DoubleMLPLIV.partialX(data_ml,
+ n_folds = 5,
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ ml_r = learner_pars$ml_r$clone(),
+ ml_g = learner_pars$ml_g$clone(),
+ dml_procedure = dml_procedure,
+ score = score)
+
+ double_mlpliv_partX$fit()
+ theta_partX = double_mlpliv_partX$coef
+ se_partX = double_mlpliv_partX$se
+
+ expect_equal(theta_partX, theta_obj, tolerance = 1e-8)
+ expect_equal(se_partX, se_obj, tolerance = 1e-8)
+ }
+)
diff --git a/tests/testthat/test-double_ml_pliv_partial_x.R b/tests/testthat/test-double_ml_pliv_partial_x.R
index 00c1df72..bbfc6be8 100644
--- a/tests/testthat/test-double_ml_pliv_partial_x.R
+++ b/tests/testthat/test-double_ml_pliv_partial_x.R
@@ -21,7 +21,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialX:",
pliv_hat = dml_pliv_partial_x(data_pliv_partialX$df,
y = "y", d = "d", z = paste0("Z", 1:dim_z),
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure, score = score)
@@ -38,7 +38,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialX:",
set.seed(3141)
double_mlpliv_obj = DoubleMLPLIV.partialX(data_pliv_partialX$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
n_folds = 5,
@@ -67,7 +67,7 @@ test_that("Unit tests for PLIV.partialX invalid score", {
"partialX=TRUE and partialZ=FALSE with several instruments.")
double_mlplr_obj <- DoubleMLPLIV.partialX(
data_pliv_partialX$dml_data,
- ml_g = mlr3::lrn("regr.rpart"),
+ ml_l = mlr3::lrn("regr.rpart"),
ml_m = mlr3::lrn("regr.rpart"),
ml_r = mlr3::lrn("regr.rpart"),
score = function(x) {
diff --git a/tests/testthat/test-double_ml_pliv_partial_xz.R b/tests/testthat/test-double_ml_pliv_partial_xz.R
index 2812c36e..befef206 100644
--- a/tests/testthat/test-double_ml_pliv_partial_xz.R
+++ b/tests/testthat/test-double_ml_pliv_partial_xz.R
@@ -23,7 +23,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialXZ:",
pliv_hat = dml_pliv_partial_xz(data_pliv_partialXZ$df,
y = "y", d = "d", z = paste0("Z", 1:dim_z),
n_folds = 5,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
dml_procedure = dml_procedure, score = score)
@@ -39,7 +39,7 @@ patrick::with_parameters_test_that("Unit tests for PLIV.partialXZ:",
set.seed(3141)
double_mlpliv_obj = DoubleMLPLIV.partialXZ(data_pliv_partialXZ$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
n_folds = 5,
@@ -67,7 +67,7 @@ test_that("Unit tests for PLIV.partialXZ invalid score", {
"partialX=TRUE and partialZ=TRUE.")
double_mlplr_obj <- DoubleMLPLIV.partialXZ(
data_pliv_partialXZ$dml_data,
- ml_g = mlr3::lrn("regr.rpart"),
+ ml_l = mlr3::lrn("regr.rpart"),
ml_m = mlr3::lrn("regr.rpart"),
ml_r = mlr3::lrn("regr.rpart"),
score = function(x) {
diff --git a/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R b/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R
index 77798f29..a07ffef6 100644
--- a/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R
+++ b/tests/testthat/test-double_ml_pliv_partial_xz_parameter_passing.R
@@ -32,10 +32,10 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
pliv_hat = dml_pliv_partial_xz(df,
y = "y", d = "d", z = c("z", "z2"),
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
- params_g = learner_pars$params$params_g,
+ params_l = learner_pars$params$params_l,
params_m = learner_pars$params$params_m,
params_r = learner_pars$params$params_r,
dml_procedure = dml_procedure, score = score)
@@ -59,7 +59,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj = DoubleMLPLIV.partialXZ(
data = dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -67,8 +67,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -112,10 +112,10 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
pliv_hat = dml_pliv_partial_xz(df,
y = "y", d = "d", z = c("z", "z2"),
n_folds = 1,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
- params_g = learner_pars$params$params_g,
+ params_l = learner_pars$params$params_l,
params_m = learner_pars$params$params_m,
params_r = learner_pars$params$params_r,
dml_procedure = dml_procedure, score = score,
@@ -132,7 +132,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_nocf = DoubleMLPLIV.partialXZ(
data = dml_data,
n_folds = n_folds,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -141,8 +141,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_nocf$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_nocf$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -177,7 +177,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
set.seed(3141)
dml_pliv_obj = DoubleMLPLIV.partialXZ(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -185,8 +185,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = learner_pars$params$params_g)
+ learner = "ml_l",
+ params = learner_pars$params$params_l)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
@@ -200,14 +200,14 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
theta = dml_pliv_obj$coef
se = dml_pliv_obj$se
- params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
+ params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep)
params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep)
params_r_fold_wise = rep(list(rep(list(learner_pars$params$params_r), n_folds)), n_rep)
set.seed(3141)
dml_pliv_obj_fold_wise = DoubleMLPLIV.partialXZ(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
ml_r = mlr3::lrn(learner_pars$mlmethod$mlmethod_r),
dml_procedure = dml_procedure,
@@ -215,8 +215,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj_fold_wise$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = params_g_fold_wise,
+ learner = "ml_l",
+ params = params_l_fold_wise,
set_fold_specific = TRUE)
dml_pliv_obj_fold_wise$set_ml_nuisance_params(
treat_var = "d",
@@ -243,7 +243,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
n_folds = 2
n_rep = 3
- params_g = list(cp = 0.01, minsplit = 20) # this are defaults
+ params_l = list(cp = 0.01, minsplit = 20) # these are the defaults
params_m = list(cp = 0.01, minsplit = 20) # this are defaults
params_r = list(cp = 0.01, minsplit = 20) # this are defaults
@@ -256,7 +256,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
set.seed(3141)
dml_pliv_default = DoubleMLPLIV.partialXZ(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
ml_r = lrn("regr.rpart"),
dml_procedure = dml_procedure,
@@ -269,7 +269,7 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
set.seed(3141)
dml_pliv_obj = DoubleMLPLIV.partialXZ(dml_data,
n_folds = n_folds, n_rep = n_rep,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
ml_r = lrn("regr.rpart"),
dml_procedure = dml_procedure,
@@ -277,8 +277,8 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLIV.par
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
- learner = "ml_g",
- params = params_g)
+ learner = "ml_l",
+ params = params_l)
dml_pliv_obj$set_ml_nuisance_params(
treat_var = "d",
learner = "ml_m",
diff --git a/tests/testthat/test-double_ml_pliv_tuning.R b/tests/testthat/test-double_ml_pliv_tuning.R
index 0687d250..7cee8b40 100644
--- a/tests/testthat/test-double_ml_pliv_tuning.R
+++ b/tests/testthat/test-double_ml_pliv_tuning.R
@@ -15,41 +15,141 @@ tune_settings = list(
n_rep_tune = 1,
rsmp_tune = "cv",
measure = list(
- "ml_g" = "regr.mse",
+ "ml_l" = "regr.mse",
"ml_r" = "regr.mse",
"ml_m" = "regr.mse"),
terminator = mlr3tuning::trm("evals", n_evals = 2),
algorithm = "grid_search",
- tuning_instance_g = NULL,
+ tuning_instance_l = NULL,
tuning_instance_m = NULL,
tuner = "grid_search",
resolution = 5)
on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
if (on_cran) {
- test_cases = expand.grid(
+ test_cases_one_z = expand.grid(
dml_procedure = "dml2",
score = "partialling out",
n_rep = c(1),
tune_on_folds = FALSE,
- z_indx = c(1),
stringsAsFactors = FALSE)
} else {
- test_cases = expand.grid(
+ test_cases_one_z = expand.grid(
dml_procedure = c("dml1", "dml2"),
- score = "partialling out",
+ score = c("partialling out", "IV-type"),
n_rep = c(1, 3),
tune_on_folds = c(FALSE, TRUE),
- z_indx = c(1, 2),
stringsAsFactors = FALSE)
}
-test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
+test_cases_one_z[".test_name"] = apply(test_cases_one_z, 1, paste, collapse = "_")
# skip('Skip tests for tuning')
patrick::with_parameters_test_that("Unit tests for tuning of PLIV",
- .cases = test_cases, {
+ .cases = test_cases_one_z, {
+
+ # TBD: Functional Test Case
+
+ set.seed(3141)
+ n_folds = 2
+ n_rep_boot = 498
+
+ z_cols = "z"
+ set.seed(3141)
+ df = data_pliv$df
+ Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE]
+ data_ml = double_ml_data_from_data_frame(df,
+ y_col = "y",
+ d_cols = "d", x_cols = Xnames, z_cols = z_cols)
+
+ if (score == "IV-type") {
+ ml_g = learner
+ } else {
+ ml_g = NULL
+ }
+ double_mlpliv_obj_tuned = DoubleMLPLIV$new(data_ml,
+ n_folds = n_folds,
+ ml_l = learner,
+ ml_m = learner,
+ ml_r = learner,
+ ml_g = ml_g,
+ dml_procedure = dml_procedure,
+ score = score,
+ n_rep = n_rep)
+
+ param_grid = list(
+ "ml_l" = paradox::ParamSet$new(list(
+ paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
+ paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
+ "ml_m" = paradox::ParamSet$new(list(
+ paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
+ paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
+ "ml_r" = paradox::ParamSet$new(list(
+ paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
+ paradox::ParamInt$new("minsplit", lower = 1, upper = 2))))
+ if (score == "IV-type") {
+ param_grid[["ml_g"]] = paradox::ParamSet$new(list(
+ paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
+ paradox::ParamInt$new("minsplit", lower = 1, upper = 2)))
+ tune_settings[["measure"]][["ml_g"]] = "regr.mse"
+ }
+
+ double_mlpliv_obj_tuned$tune(
+ param_set = param_grid, tune_settings = tune_settings,
+ tune_on_folds = tune_on_folds)
+ double_mlpliv_obj_tuned$fit()
+
+ theta_obj_tuned = double_mlpliv_obj_tuned$coef
+ se_obj_tuned = double_mlpliv_obj_tuned$se
+
+ # bootstrap
+ # double_mlplr_obj_exact$bootstrap(method = 'normal', n_rep = n_rep_boot)
+ # boot_theta_obj_exact = double_mlplr_obj_exact$boot_coef
+
+ expect_is(theta_obj_tuned, "numeric")
+ expect_is(se_obj_tuned, "numeric")
+
+ # if (data_ml$n_instr() == 1) {
+ # double_mlpliv_obj_tuned_Z = DoubleMLPLIV.partialZ(data_ml,
+ # n_folds = n_folds,
+ # ml_r = learner,
+ # dml_procedure = dml_procedure,
+ # score = score,
+ # n_rep = n_rep)
+ #
+ # double_mlpliv_obj_tuned_Z$tune(param_set = param_grid, tune_on_folds = tune_on_folds)
+ # double_mlpliv_obj_tuned_Z$fit()
+ #
+ # theta_obj_tuned_Z = double_mlpliv_obj_tuned_Z$coef
+ # se_obj_tuned_Z = double_mlpliv_obj_tuned_Z$se
+ #
+ # expect_is(theta_obj_tuned_Z, "numeric")
+ # expect_is(se_obj_tuned_Z, "numeric")
+ # }
+ #
+ }
+)
+
+on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
+if (on_cran) {
+ test_cases_multiple_z = expand.grid(
+ dml_procedure = "dml2",
+ score = "partialling out",
+ n_rep = c(1),
+ tune_on_folds = FALSE,
+ stringsAsFactors = FALSE)
+} else {
+ test_cases_multiple_z = expand.grid(
+ dml_procedure = c("dml1", "dml2"),
+ score = "partialling out",
+ n_rep = c(1, 3),
+ tune_on_folds = c(FALSE, TRUE),
+ stringsAsFactors = FALSE)
+}
+
+test_cases_multiple_z[".test_name"] = apply(test_cases_multiple_z, 1, paste, collapse = "_")
+
+patrick::with_parameters_test_that("Unit tests for tuning of PLIV (multiple Z)",
+ .cases = test_cases_multiple_z, {
# TBD: Functional Test Case
@@ -57,17 +157,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV",
n_folds = 2
n_rep_boot = 498
- # set.seed(3141)
- # pliv_hat = dml_plriv(data_pliv, y = "y", d = "d", z = 'z',
- # n_folds = n_folds, mlmethod = learner_list,
- # params = learner_pars$params,
- # dml_procedure = dml_procedure, score = score,
- # bootstrap = "normal", n_rep_boot = n_rep_boot)
- # theta = coef(pliv_hat)
- # se = pliv_hat$se
-
- z_vars = list("z", c("z", "z2"))
- z_cols = z_vars[[z_indx]]
+ z_cols = c("z", "z2")
set.seed(3141)
df = data_pliv$df
Xnames = names(df)[names(df) %in% c("y", "d", "z", "z2") == FALSE]
@@ -77,7 +167,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV",
double_mlpliv_obj_tuned = DoubleMLPLIV$new(data_ml,
n_folds = n_folds,
- ml_g = learner,
+ ml_l = learner,
ml_m = learner,
ml_r = learner,
dml_procedure = dml_procedure,
@@ -85,7 +175,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV",
n_rep = n_rep)
param_grid = list(
- "ml_g" = paradox::ParamSet$new(list(
+ "ml_l" = paradox::ParamSet$new(list(
paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
"ml_m" = paradox::ParamSet$new(list(
@@ -137,7 +227,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV",
param_grid_r = list("ml_r" = param_grid[["ml_r"]])
tune_settings_r = tune_settings
- tune_settings_r$measure$ml_g = tune_settings_r$measure$ml_m = NULL
+ tune_settings_r$measure$ml_l = tune_settings_r$measure$ml_m = NULL
double_mlpliv_obj_tuned_Z$tune(
param_set = param_grid_r, tune_on_folds = tune_on_folds,
tune_settings = tune_settings_r)
@@ -152,7 +242,7 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLIV",
set.seed(3141)
double_mlpliv_obj_tuned_XZ = DoubleMLPLIV.partialXZ(data_ml,
n_folds = n_folds,
- ml_g = learner,
+ ml_l = learner,
ml_m = learner,
ml_r = learner,
dml_procedure = dml_procedure,
diff --git a/tests/testthat/test-double_ml_pliv_two_way_cluster.R b/tests/testthat/test-double_ml_pliv_two_way_cluster.R
index 5353f372..4512a1d0 100644
--- a/tests/testthat/test-double_ml_pliv_two_way_cluster.R
+++ b/tests/testthat/test-double_ml_pliv_two_way_cluster.R
@@ -15,7 +15,7 @@ if (on_cran) {
test_cases = expand.grid(
learner = c("regr.lm", "regr.glmnet"),
dml_procedure = c("dml1", "dml2"),
- score = "partialling out",
+ score = c("partialling out", "IV-type"),
stringsAsFactors = FALSE)
}
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
@@ -31,12 +31,18 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering:
learner_pars = get_default_mlmethod_pliv(learner)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlpliv_obj = DoubleMLPLIV$new(
data = data_two_way,
n_folds = 2,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score)
@@ -51,12 +57,18 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering:
cluster_var2 = df$cluster_var_j
# need to drop variables as x is not explicitly set
df = df[, !(names(df) %in% c("cluster_var_i", "cluster_var_j"))]
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
pliv_hat = dml_pliv(df,
y = "Y", d = "D", z = "Z",
n_folds = 4,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
ml_r = learner_pars$ml_r$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure, score = score,
smpls = double_mlpliv_obj$smpls)
@@ -65,15 +77,19 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering:
residuals = compute_pliv_residuals(df,
y = "Y", d = "D", z = "Z",
n_folds = 4,
- this_smpl,
- pliv_hat$all_preds[[1]])
- u_hat = residuals$u_hat
- v_hat = residuals$v_hat
- w_hat = residuals$w_hat
+ smpls = this_smpl,
+ all_preds = pliv_hat$all_preds[[1]])
+ y_minus_l_hat = residuals$y_minus_l_hat
+ d_minus_r_hat = residuals$d_minus_r_hat
+ z_minus_m_hat = residuals$z_minus_m_hat
+ y_minus_g_hat = residuals$y_minus_g_hat
+ D = df[, "D"]
- psi_a = -w_hat * v_hat
+ if (score == "partialling out") psi_a = -z_minus_m_hat * d_minus_r_hat
+ if (score == "IV-type") psi_a = -D * z_minus_m_hat
if (dml_procedure == "dml2") {
- psi_b = w_hat * u_hat
+ if (score == "partialling out") psi_b = z_minus_m_hat * y_minus_l_hat
+ if (score == "IV-type") psi_b = z_minus_m_hat * y_minus_g_hat
theta = est_two_way_cluster_dml2(
psi_a, psi_b,
cluster_var1,
@@ -82,7 +98,8 @@ patrick::with_parameters_test_that("Unit tests for PLIV with two-way clustering:
} else {
theta = pliv_hat$coef
}
- psi = (u_hat - v_hat * theta) * w_hat
+ if (score == "partialling out") psi = (y_minus_l_hat - d_minus_r_hat * theta) * z_minus_m_hat
+ if (score == "IV-type") psi = (y_minus_g_hat - D * theta) * z_minus_m_hat
var = var_two_way_cluster(
psi, psi_a,
cluster_var1,
diff --git a/tests/testthat/test-double_ml_pliv_user_score.R b/tests/testthat/test-double_ml_pliv_user_score.R
index b9d53063..0bd000b1 100644
--- a/tests/testthat/test-double_ml_pliv_user_score.R
+++ b/tests/testthat/test-double_ml_pliv_user_score.R
@@ -4,8 +4,8 @@ library("mlr3learners")
lgr::get_logger("mlr3")$set_threshold("warn")
-score_fct = function(y, z, d, g_hat, m_hat, r_hat, smpls) {
- u_hat = y - g_hat
+score_fct_po = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) {
+ u_hat = y - l_hat
w_hat = d - r_hat
v_hat = z - m_hat
psi_a = -w_hat * v_hat
@@ -15,16 +15,27 @@ score_fct = function(y, z, d, g_hat, m_hat, r_hat, smpls) {
psi_b = psi_b)
}
+score_fct_iv = function(y, z, d, l_hat, m_hat, r_hat, g_hat, smpls) {
+ v_hat = z - m_hat
+ psi_a = -d * v_hat
+ psi_b = v_hat * (y - g_hat)
+ psis = list(
+ psi_a = psi_a,
+ psi_b = psi_b)
+}
+
on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
if (on_cran) {
test_cases = expand.grid(
learner = "regr.lm",
dml_procedure = "dml2",
+ score = "partialling out",
stringsAsFactors = FALSE)
} else {
test_cases = expand.grid(
learner = c("regr.lm", "regr.glmnet"),
dml_procedure = c("dml1", "dml2"),
+ score = c("partialling out", "IV-type"),
stringsAsFactors = FALSE)
}
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
@@ -33,15 +44,24 @@ patrick::with_parameters_test_that("Unit tests for PLIV, callable score:",
.cases = test_cases, {
n_rep_boot = 498
+ if (score == "partialling out") {
+ score_fct = score_fct_po
+ ml_g = NULL
+ } else if (score == "IV-type") {
+ score_fct = score_fct_iv
+ ml_g = lrn(learner)
+ }
+
set.seed(3141)
double_mlpliv_obj = DoubleMLPLIV$new(
data = data_pliv$dml_data,
n_folds = 5,
- ml_g = lrn(learner),
+ ml_l = lrn(learner),
ml_m = lrn(learner),
ml_r = lrn(learner),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
- score = "partialling out")
+ score = score)
double_mlpliv_obj$fit()
theta_obj = double_mlpliv_obj$coef
@@ -54,9 +74,10 @@ patrick::with_parameters_test_that("Unit tests for PLIV, callable score:",
double_mlpliv_obj_score = DoubleMLPLIV$new(
data = data_pliv$dml_data,
n_folds = 5,
- ml_g = lrn(learner),
+ ml_l = lrn(learner),
ml_m = lrn(learner),
ml_r = lrn(learner),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score_fct)
diff --git a/tests/testthat/test-double_ml_plr.R b/tests/testthat/test-double_ml_plr.R
index 9869b0d1..469ac338 100644
--- a/tests/testthat/test-double_ml_plr.R
+++ b/tests/testthat/test-double_ml_plr.R
@@ -30,7 +30,9 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
plr_hat = dml_plr(data_plr$df,
y = "y", d = "d",
n_folds = n_folds,
- ml_g = learner_pars$ml_g$clone(), ml_m = learner_pars$ml_m$clone(),
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ ml_g = learner_pars$ml_g$clone(),
dml_procedure = dml_procedure, score = score)
theta = plr_hat$coef
se = plr_hat$se
@@ -47,13 +49,24 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
score = score)$boot_coef
set.seed(3141)
- double_mlplr_obj = DoubleMLPLR$new(
- data = data_plr$dml_data,
- ml_g = learner_pars$ml_g$clone(),
- ml_m = learner_pars$ml_m$clone(),
- dml_procedure = dml_procedure,
- n_folds = n_folds,
- score = score)
+ if (score == "partialling out") {
+ double_mlplr_obj = DoubleMLPLR$new(
+ data = data_plr$dml_data,
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ dml_procedure = dml_procedure,
+ n_folds = n_folds,
+ score = score)
+ } else {
+ double_mlplr_obj = DoubleMLPLR$new(
+ data = data_plr$dml_data,
+ ml_l = learner_pars$ml_l$clone(),
+ ml_m = learner_pars$ml_m$clone(),
+ ml_g = learner_pars$ml_g$clone(),
+ dml_procedure = dml_procedure,
+ n_folds = n_folds,
+ score = score)
+ }
double_mlplr_obj$fit()
theta_obj = double_mlplr_obj$coef
diff --git a/tests/testthat/test-double_ml_plr_classifier.R b/tests/testthat/test-double_ml_plr_classifier.R
index b453a846..5b29dad1 100644
--- a/tests/testthat/test-double_ml_plr_classifier.R
+++ b/tests/testthat/test-double_ml_plr_classifier.R
@@ -7,15 +7,17 @@ lgr::get_logger("mlr3")$set_threshold("warn")
on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
if (on_cran) {
test_cases = expand.grid(
- g_learner = c("regr.rpart", "classif.rpart"),
+ l_learner = c("regr.rpart", "classif.rpart"),
m_learner = "classif.rpart",
+ g_learner = "regr.rpart",
dml_procedure = "dml2",
score = "partialling out",
stringsAsFactors = FALSE)
} else {
test_cases = expand.grid(
- g_learner = "regr.cv_glmnet",
+ l_learner = c("regr.rpart", "classif.rpart"),
m_learner = "classif.cv_glmnet",
+ g_learner = "regr.cv_glmnet",
dml_procedure = c("dml1", "dml2"),
score = c("IV-type", "partialling out"),
stringsAsFactors = FALSE)
@@ -27,15 +29,24 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m:
n_rep_boot = 498
n_folds = 3
- if (g_learner == "regr.cv_glmnet") {
- ml_g = mlr3::lrn(g_learner)
- ml_m = mlr3::lrn(m_learner)
+ ml_l = mlr3::lrn(l_learner)
+ ml_m = mlr3::lrn(m_learner)
+ ml_g = mlr3::lrn(g_learner)
+
+ if (ml_l$task_type == "regr") {
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
plr_hat = dml_plr(data_irm$df,
y = "y", d = "d",
n_folds = n_folds,
- ml_g = ml_g$clone(), ml_m = ml_m$clone(),
+ ml_l = ml_l$clone(),
+ ml_m = ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure, score = score)
theta = plr_hat$coef
se = plr_hat$se
@@ -52,10 +63,16 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m:
pval = plr_hat$pval
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(
data = data_irm$dml_data,
- ml_g = ml_g$clone(),
+ ml_l = ml_l$clone(),
ml_m = ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
@@ -76,12 +93,18 @@ patrick::with_parameters_test_that("Unit tests for PLR with classifier for ml_m:
expect_equal(pval, pval_obj, tolerance = 1e-8)
# expect_equal(ci, ci_obj, tolerance = 1e-8)
- } else if (g_learner == "classif.cv_glmnet") {
- msg = "Invalid learner provided for ml_g: must be of class 'LearnerRegr'"
+ } else if (ml_l$task_type == "classif") {
+ msg = "Invalid learner provided for ml_l: 'learner\\$task_type' must be 'regr'"
+ if (score == "IV-type") {
+ ml_g = ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
expect_error(DoubleMLPLR$new(
data = data_irm$dml_data,
- ml_g = lrn(g_learner),
- ml_m = lrn(m_learner),
+ ml_l = ml_l$clone(),
+ ml_m = ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score),
@@ -99,7 +122,7 @@ test_that("Unit tests for exception handling of PLR with classifier for ml_m:",
dml_data = double_ml_data_from_data_frame(df, y_col = "y", d_cols = "d")
double_mlplr_obj = DoubleMLPLR$new(
data = dml_data,
- ml_g = mlr3::lrn("regr.rpart"),
+ ml_l = mlr3::lrn("regr.rpart"),
ml_m = mlr3::lrn("classif.rpart"))
msg = paste(
"Assertion on 'levels\\(data\\[\\[target\\]\\])' failed: .* set \\{'0','1'\\}")
@@ -112,7 +135,7 @@ test_that("Unit tests for exception handling of PLR with classifier for ml_m:",
dml_data = double_ml_data_from_data_frame(df, y_col = "y", d_cols = "d")
double_mlplr_obj = DoubleMLPLR$new(
data = dml_data,
- ml_g = mlr3::lrn("regr.rpart"),
+ ml_l = mlr3::lrn("regr.rpart"),
ml_m = mlr3::lrn("classif.rpart"))
msg = paste(
"Assertion on 'levels\\(data\\[\\[target\\]\\])' failed: .* set \\{'0','1'\\}")
diff --git a/tests/testthat/test-double_ml_plr_exception_handling.R b/tests/testthat/test-double_ml_plr_exception_handling.R
index 2b6ed203..5de1cd25 100644
--- a/tests/testthat/test-double_ml_plr_exception_handling.R
+++ b/tests/testthat/test-double_ml_plr_exception_handling.R
@@ -2,7 +2,10 @@ context("Unit tests for exception handling if fit() or bootstrap() was not run y
library("mlr3learners")
+logger = lgr::get_logger("bbotk")
+logger$set_threshold("warn")
lgr::get_logger("mlr3")$set_threshold("warn")
+
on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
if (on_cran) {
test_cases = expand.grid(
@@ -48,10 +51,16 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:",
} else {
msg = "Assertion on 'i' failed: Element 1 is not <= 1."
}
+ if (score == "IV-type") {
+ ml_g = learner_pars$mlmethod$mlmethod_g
+ } else {
+ ml_g = NULL
+ }
expect_error(DoubleMLPLR$new(
data = data_ml,
- ml_g = learner_pars$mlmethod$mlmethod_g,
+ ml_l = learner_pars$mlmethod$mlmethod_l,
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
n_rep = n_rep,
@@ -59,10 +68,16 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:",
apply_cross_fitting = apply_cross_fitting),
regexp = msg)
} else {
+ if (score == "IV-type") {
+ ml_g = learner_pars$mlmethod$mlmethod_g
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(
data = data_ml,
- ml_g = learner_pars$mlmethod$mlmethod_g,
+ ml_l = learner_pars$mlmethod$mlmethod_l,
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
n_rep = n_rep,
@@ -76,11 +91,20 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:",
treat_var = "d",
params = learner_pars$params$params_m)
- # set params for nuisance part g
+ # set params for nuisance part l
double_mlplr_obj$set_ml_nuisance_params(
- learner = "ml_g",
+ learner = "ml_l",
treat_var = "d",
- params = learner_pars$params$params_g)
+ params = learner_pars$params$params_l)
+
+ if (score == "IV-type") {
+ # set params for nuisance part g
+ double_mlplr_obj$set_ml_nuisance_params(
+ learner = "ml_g",
+ treat_var = "d",
+ params = learner_pars$params$params_g)
+ }
+
}
# currently, no warning or message printed
@@ -100,6 +124,66 @@ patrick::with_parameters_test_that("Unit tests for exception handling of PLR:",
msg = "Multiplier bootstrap has not yet been performed. First call bootstrap\\(\\) and then try confint\\(\\) again."
expect_error(double_mlplr_obj$confint(joint = TRUE, level = 0.95),
regexp = msg)
+
+ set.seed(3141)
+ dml_data = make_plr_CCDDHNR2018(n_obs = 101)
+ ml_l = lrn("regr.ranger")
+ ml_m = ml_l$clone()
+ ml_g = ml_l$clone()
+
+ if (score == "partialling out") {
+ msg = paste0(
+ "A learner ml_g has been provided for ",
+ "score = 'partialling out' but will be ignored.")
+ expect_warning(DoubleMLPLR$new(dml_data,
+ ml_l = ml_l, ml_m = ml_m, ml_g = ml_g,
+ score = score),
+ regexp = msg)
+ }
}
}
)
+
+test_that("Unit tests for deprecation warnings of PLR", {
+ set.seed(3141)
+ dml_data = make_plr_CCDDHNR2018(n_obs = 101)
+ ml_l = lrn("regr.ranger")
+ ml_m = ml_l$clone()
+ ml_g = ml_l$clone()
+ msg = paste0("The argument ml_g was renamed to ml_l.")
+ expect_warning(DoubleMLPLR$new(dml_data, ml_g = ml_g, ml_m = ml_m),
+ regexp = msg)
+
+ msg = "learners ml_l and ml_g should be specified"
+ expect_warning(DoubleMLPLR$new(dml_data, ml_l, ml_m,
+ score = "IV-type"),
+ regexp = msg)
+
+ dml_obj = DoubleMLPLR$new(dml_data, ml_l = ml_l, ml_m = ml_m)
+
+ msg = paste0("Learner ml_g was renamed to ml_l.")
+ expect_warning(dml_obj$set_ml_nuisance_params(
+ "ml_g", "d", list("num.trees" = 10)),
+ regexp = msg)
+
+ par_grids = list(
+ "ml_g" = paradox::ParamSet$new(list(
+ paradox::ParamInt$new("num.trees", lower = 9, upper = 10))),
+ "ml_m" = paradox::ParamSet$new(list(
+ paradox::ParamInt$new("num.trees", lower = 10, upper = 11))))
+
+ msg = paste0("Learner ml_g was renamed to ml_l.")
+ expect_warning(dml_obj$tune(par_grids),
+ regexp = msg)
+
+ tune_settings = list(
+ n_folds_tune = 5,
+ rsmp_tune = mlr3::rsmp("cv", folds = 5),
+ measure = list(ml_g = "regr.mse", ml_m = "regr.mae"),
+ terminator = mlr3tuning::trm("evals", n_evals = 20),
+ algorithm = mlr3tuning::tnr("grid_search"),
+ resolution = 5)
+ expect_warning(dml_obj$tune(par_grids, tune_settings = tune_settings),
+ regexp = msg)
+}
+)
diff --git a/tests/testthat/test-double_ml_plr_export_preds.R b/tests/testthat/test-double_ml_plr_export_preds.R
index 555606c4..3929ecab 100644
--- a/tests/testthat/test-double_ml_plr_export_preds.R
+++ b/tests/testthat/test-double_ml_plr_export_preds.R
@@ -7,17 +7,19 @@ lgr::get_logger("mlr3")$set_threshold("warn")
on_cran = !identical(Sys.getenv("NOT_CRAN"), "true")
if (on_cran) {
test_cases = expand.grid(
- g_learner = "regr.rpart",
+ l_learner = "regr.rpart",
m_learner = "regr.rpart",
+ g_learner = "regr.rpart",
dml_procedure = "dml2",
score = "partialling out",
stringsAsFactors = FALSE)
} else {
test_cases = expand.grid(
- g_learner = c("regr.rpart", "regr.lm"),
+ l_learner = c("regr.rpart", "regr.lm"),
m_learner = c("regr.rpart", "regr.lm"),
+ g_learner = c("regr.rpart", "regr.lm"),
dml_procedure = "dml2",
- score = "partialling out",
+ score = c("partialling out", "IV-type"),
stringsAsFactors = FALSE)
}
test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
@@ -30,10 +32,16 @@ patrick::with_parameters_test_that("Unit tests for for the export of predictions
df = data_plr$df
dml_data = data_plr$dml_data
+ if (score == "IV-type") {
+ ml_g = lrn(g_learner)
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(
data = dml_data,
- ml_g = lrn(g_learner),
+ ml_l = lrn(l_learner),
ml_m = lrn(m_learner),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
@@ -44,14 +52,14 @@ patrick::with_parameters_test_that("Unit tests for for the export of predictions
Xnames = names(df)[names(df) %in% c("y", "d", "z") == FALSE]
indx = (names(df) %in% c(Xnames, "y"))
data = df[, indx]
- task = mlr3::TaskRegr$new(id = "ml_g", backend = data, target = "y")
+ task = mlr3::TaskRegr$new(id = "ml_l", backend = data, target = "y")
resampling_smpls = rsmp("custom")$instantiate(
task,
double_mlplr_obj$smpls[[1]]$train_ids,
double_mlplr_obj$smpls[[1]]$test_ids)
- resampling_pred = resample(task, lrn(g_learner), resampling_smpls)
- preds_g = as.data.table(resampling_pred$prediction())
- data.table::setorder(preds_g, "row_ids")
+ resampling_pred = resample(task, lrn(l_learner), resampling_smpls)
+ preds_l = as.data.table(resampling_pred$prediction())
+ data.table::setorder(preds_l, "row_ids")
Xnames = names(df)[names(df) %in% c("y", "d", "z") == FALSE]
indx = (names(df) %in% c(Xnames, "d"))
@@ -65,8 +73,36 @@ patrick::with_parameters_test_that("Unit tests for for the export of predictions
preds_m = as.data.table(resampling_pred$prediction())
data.table::setorder(preds_m, "row_ids")
- expect_equal(as.vector(double_mlplr_obj$predictions$ml_g),
- as.vector(preds_g$response),
+ if (score == "IV-type") {
+ d = df[["d"]]
+ y = df[["y"]]
+ psi_a = -(d - preds_m$response) * (d - preds_m$response)
+ psi_b = (d - preds_m$response) * (y - preds_l$response)
+ theta_initial = -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)
+
+ data_aux = cbind(df, "y_minus_theta_d" = y - theta_initial * d)
+ Xnames = names(data_aux)[names(data_aux) %in%
+ c("y", "d", "z", "y_minus_theta_d") == FALSE]
+ indx = (names(data_aux) %in% c(Xnames, "y_minus_theta_d"))
+ data = data_aux[, indx]
+ task = mlr3::TaskRegr$new(
+ id = "ml_g", backend = data,
+ target = "y_minus_theta_d")
+ resampling_smpls = rsmp("custom")$instantiate(
+ task,
+ double_mlplr_obj$smpls[[1]]$train_ids,
+ double_mlplr_obj$smpls[[1]]$test_ids)
+ resampling_pred = resample(task, lrn(g_learner), resampling_smpls)
+ preds_g = as.data.table(resampling_pred$prediction())
+ data.table::setorder(preds_g, "row_ids")
+
+ expect_equal(as.vector(double_mlplr_obj$predictions$ml_g),
+ as.vector(preds_g$response),
+ tolerance = 1e-8)
+ }
+
+ expect_equal(as.vector(double_mlplr_obj$predictions$ml_l),
+ as.vector(preds_l$response),
tolerance = 1e-8)
expect_equal(as.vector(double_mlplr_obj$predictions$ml_m),
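
The IV-type branch above reconstructs how the ml_g predictions are assumed to be produced: a preliminary partialling-out estimate is formed from the ml_l and ml_m residuals, and ml_g then regresses y - theta_initial * d on the covariates. Condensed into a sketch:

# Preliminary estimate from the partialling-out moments, as recomputed above
preliminary_theta = function(y, d, l_hat, m_hat) {
  psi_a = -(d - m_hat)^2
  psi_b = (d - m_hat) * (y - l_hat)
  -mean(psi_b, na.rm = TRUE) / mean(psi_a, na.rm = TRUE)
}
# ml_g is then fit with target y - preliminary_theta(y, d, l_hat, m_hat) * d
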
diff --git a/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R b/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R
index 2c93b153..eff13312 100644
--- a/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R
+++ b/tests/testthat/test-double_ml_plr_loaded_mlr3learner.R
@@ -29,10 +29,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
params = list("cp" = 0.01, "minsplit" = 20)
set.seed(123)
+ if (score == "IV-type") {
+ ml_g = learner_name
+ } else {
+ ml_g = NULL
+ }
double_mlplr = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = learner_name,
+ ml_l = learner_name,
ml_m = learner_name,
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
@@ -43,12 +49,20 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
treat_var = "d",
params = params)
- # set params for nuisance part g
+ # set params for nuisance part l
double_mlplr$set_ml_nuisance_params(
- learner = "ml_g",
+ learner = "ml_l",
treat_var = "d",
params = params)
+ if (score == "IV-type") {
+ # set params for nuisance part g
+ double_mlplr$set_ml_nuisance_params(
+ learner = "ml_g",
+ treat_var = "d",
+ params = params)
+ }
+
double_mlplr$fit()
theta = double_mlplr$coef
se = double_mlplr$se
@@ -60,10 +74,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
set.seed(123)
loaded_learner = mlr3::lrn("regr.rpart", "cp" = 0.01, "minsplit" = 20)
+ if (score == "IV-type") {
+ ml_g = loaded_learner
+ } else {
+ ml_g = NULL
+ }
double_mlplr_loaded = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = loaded_learner,
+ ml_l = loaded_learner,
ml_m = loaded_learner,
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
@@ -79,10 +99,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
set.seed(123)
semiloaded_learner = mlr3::lrn("regr.rpart")
+ if (score == "IV-type") {
+ ml_g = semiloaded_learner
+ } else {
+ ml_g = NULL
+ }
double_mlplr_semiloaded = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = semiloaded_learner,
+ ml_l = semiloaded_learner,
ml_m = semiloaded_learner,
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
@@ -92,12 +118,20 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
treat_var = "d",
params = params)
- # set params for nuisance part g
+ # set params for nuisance part l
double_mlplr_semiloaded$set_ml_nuisance_params(
- learner = "ml_g",
+ learner = "ml_l",
treat_var = "d",
params = params)
+ if (score == "IV-type") {
+ # set params for nuisance part g
+ double_mlplr_semiloaded$set_ml_nuisance_params(
+ learner = "ml_g",
+ treat_var = "d",
+ params = params)
+ }
+
double_mlplr_semiloaded$fit()
theta_semiloaded = double_mlplr_semiloaded$coef
se_semiloaded = double_mlplr_semiloaded$se
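The if (score == "IV-type") guard above recurs throughout the touched tests; purely as an illustration (this helper is hypothetical, not something the patch adds), the pattern amounts to:

# Hypothetical helper: ml_g is only needed by the 'IV-type' score and must
# stay NULL otherwise, so no unused nuisance model is fitted.
library(mlr3)
maybe_ml_g = function(score, learner_name) {
  if (score == "IV-type") lrn(learner_name) else NULL
}
# e.g. DoubleMLPLR$new(data, ml_l = lrn("regr.rpart"),
#   ml_m = lrn("regr.rpart"), ml_g = maybe_ml_g(score, "regr.rpart"), ...)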
diff --git a/tests/testthat/test-double_ml_plr_multitreat.R b/tests/testthat/test-double_ml_plr_multitreat.R
index b5b96447..0a31d76f 100644
--- a/tests/testthat/test-double_ml_plr_multitreat.R
+++ b/tests/testthat/test-double_ml_plr_multitreat.R
@@ -29,11 +29,17 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
n_folds = 5
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
plr_hat = dml_plr_multitreat(data_plr_multi,
y = "y", d = c("d1", "d2", "d3"),
n_folds = n_folds,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure, score = score)
theta = plr_hat$coef
se = plr_hat$se
@@ -56,9 +62,15 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
y_col = "y",
d_cols = c("d1", "d2", "d3"), x_cols = Xnames)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(data_ml,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
diff --git a/tests/testthat/test-double_ml_plr_nocrossfit.R b/tests/testthat/test-double_ml_plr_nocrossfit.R
index c46393de..9279ca9b 100644
--- a/tests/testthat/test-double_ml_plr_nocrossfit.R
+++ b/tests/testthat/test-double_ml_plr_nocrossfit.R
@@ -43,11 +43,17 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
train_ids = list(seq(nrow(df))),
test_ids = list(seq(nrow(df)))))
}
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
plr_hat = dml_plr(df,
y = "y", d = "d",
n_folds = 1,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure, score = score,
smpls = smpls)
theta = plr_hat$coef
@@ -56,10 +62,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
pval = plr_hat$pval
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score,
@@ -74,10 +86,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
if (n_folds == 2) {
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
dml_plr_obj_external = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score,
diff --git a/tests/testthat/test-double_ml_plr_nonorth.R b/tests/testthat/test-double_ml_plr_nonorth.R
index 7d6da79d..cdada942 100644
--- a/tests/testthat/test-double_ml_plr_nonorth.R
+++ b/tests/testthat/test-double_ml_plr_nonorth.R
@@ -4,7 +4,21 @@ library("mlr3learners")
lgr::get_logger("mlr3")$set_threshold("warn")
-non_orth_score = function(y, d, g_hat, m_hat, smpls) {
+non_orth_score_w_g = function(y, d, l_hat, m_hat, g_hat, smpls) {
+ u_hat = y - g_hat
+ psi_a = -1 * d * d
+ psi_b = d * u_hat
+ psis = list(psi_a = psi_a, psi_b = psi_b)
+ return(psis)
+}
+
+non_orth_score_w_l = function(y, d, l_hat, m_hat, g_hat, smpls) {
+
+ p_a = -(d - m_hat) * (d - m_hat)
+ p_b = (d - m_hat) * (y - l_hat)
+ theta_initial = -mean(p_b, na.rm = TRUE) / mean(p_a, na.rm = TRUE)
+ g_hat = l_hat - theta_initial * m_hat
+
u_hat = y - g_hat
psi_a = -1 * d * d
psi_b = d * u_hat
@@ -17,7 +31,7 @@ if (on_cran) {
test_cases = expand.grid(
learner = "regr.lm",
dml_procedure = "dml1",
- score = c(non_orth_score),
+ which_score = c("non_orth_score_w_g"),
n_folds = c(3),
n_rep = c(2),
stringsAsFactors = FALSE)
@@ -25,7 +39,9 @@ if (on_cran) {
test_cases = expand.grid(
learner = c("regr.lm", "regr.cv_glmnet"),
dml_procedure = c("dml1", "dml2"),
- score = c(non_orth_score),
+ which_score = c(
+ "non_orth_score_w_g",
+ "non_orth_score_w_l"),
n_folds = c(2, 3),
n_rep = c(1, 2),
stringsAsFactors = FALSE)
@@ -35,12 +51,22 @@ test_cases[".test_name"] = apply(test_cases, 1, paste, collapse = "_")
patrick::with_parameters_test_that("Unit tests for PLR:",
.cases = test_cases, {
learner_pars = get_default_mlmethod_plr(learner)
+
+ if (which_score == "non_orth_score_w_g") {
+ score = non_orth_score_w_g
+ ml_g = learner_pars$ml_g$clone()
+ } else if (which_score == "non_orth_score_w_l") {
+ score = non_orth_score_w_l
+ ml_g = NULL
+ }
+
n_rep_boot = 498
set.seed(3141)
double_mlplr_obj = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score)
@@ -63,8 +89,9 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
if (n_folds == 2 & n_rep == 1) {
double_mlplr_nocf = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score,
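The second score variant exploits the PLR identity l_0(X) = E[Y|X] = theta_0 * m_0(X) + g_0(X), so g_0 can be recovered as l_0 - theta_0 * m_0 once a preliminary theta is available. A small simulated check, hypothetical and not part of the test file, using the two score functions defined above:

# With the true l, m, g plugged in, theta_initial inside non_orth_score_w_l
# converges to theta_0, so the two scores agree up to sampling noise.
set.seed(1)
n = 1e5; theta = 0.5
x = rnorm(n)
m = plogis(x); g = sin(x)
d = m + rnorm(n)
y = theta * d + g + rnorm(n)
l = theta * m + g # E[Y|X]
s_g = non_orth_score_w_g(y, d, l_hat = l, m_hat = m, g_hat = g, smpls = NULL)
s_l = non_orth_score_w_l(y, d, l_hat = l, m_hat = m, g_hat = NULL, smpls = NULL)
all.equal(s_g$psi_b, s_l$psi_b, tolerance = 1e-2) # TRUE up to the tolerance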
diff --git a/tests/testthat/test-double_ml_plr_p_adjust.R b/tests/testthat/test-double_ml_plr_p_adjust.R
index e98211dc..ae021215 100644
--- a/tests/testthat/test-double_ml_plr_p_adjust.R
+++ b/tests/testthat/test-double_ml_plr_p_adjust.R
@@ -53,9 +53,15 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
x_cols = colnames(X)[(k + 1):p],
y_col = "y",
d_cols = colnames(X)[1:k])
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(data_ml,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score,
diff --git a/tests/testthat/test-double_ml_plr_parameter_passing.R b/tests/testthat/test-double_ml_plr_parameter_passing.R
index b4fe9641..4c127050 100644
--- a/tests/testthat/test-double_ml_plr_parameter_passing.R
+++ b/tests/testthat/test-double_ml_plr_parameter_passing.R
@@ -35,17 +35,25 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (oop
n_rep = 3
learner_pars = get_default_mlmethod_plr(learner)
- params_g = rep(list(learner_pars$params$params_g), 2)
+ params_l = rep(list(learner_pars$params$params_l), 2)
params_m = rep(list(learner_pars$params$params_m), 2)
+ params_g = rep(list(learner_pars$params$params_g), 2)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
plr_hat = dml_plr_multitreat(data_plr_multi,
y = "y", d = c("d1", "d2"),
n_folds = n_folds, n_rep = n_rep,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
- params_g = params_g,
+ ml_g = ml_g,
+ params_l = params_l,
params_m = params_m,
+ params_g = params_g,
dml_procedure = dml_procedure, score = score)
theta = plr_hat$coef
se = plr_hat$se
@@ -65,26 +73,40 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (oop
d_cols = c("d1", "d2"), x_cols = Xnames)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
n_rep = n_rep)
double_mlplr_obj$set_ml_nuisance_params(
- treat_var = "d1", learner = "ml_g",
- params = learner_pars$params$params_g)
+ treat_var = "d1", learner = "ml_l",
+ params = learner_pars$params$params_l)
double_mlplr_obj$set_ml_nuisance_params(
- treat_var = "d2", learner = "ml_g",
- params = learner_pars$params$params_g)
+ treat_var = "d2", learner = "ml_l",
+ params = learner_pars$params$params_l)
double_mlplr_obj$set_ml_nuisance_params(
treat_var = "d1", learner = "ml_m",
params = learner_pars$params$params_m)
double_mlplr_obj$set_ml_nuisance_params(
treat_var = "d2", learner = "ml_m",
params = learner_pars$params$params_m)
+ if (score == "IV-type") {
+ double_mlplr_obj$set_ml_nuisance_params(
+ treat_var = "d1", learner = "ml_g",
+ params = learner_pars$params$params_g)
+ double_mlplr_obj$set_ml_nuisance_params(
+ treat_var = "d2", learner = "ml_g",
+ params = learner_pars$params$params_g)
+ }
double_mlplr_obj$fit()
@@ -108,8 +130,9 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no
n_folds = 2
learner_pars = get_default_mlmethod_plr(learner)
- params_g = rep(list(learner_pars$params$params_g), 2)
+ params_l = rep(list(learner_pars$params$params_l), 2)
params_m = rep(list(learner_pars$params$params_m), 2)
+ params_g = rep(list(learner_pars$params$params_g), 2)
# Passing for non-cross-fitting case
set.seed(3141)
@@ -119,13 +142,20 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no
test_ids = list(my_sampling$test_set(1))
smpls = list(list(train_ids = train_ids, test_ids = test_ids))
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
plr_hat = dml_plr_multitreat(data_plr_multi,
y = "y", d = c("d1", "d2"),
n_folds = 1,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
- params_g = params_g,
+ ml_g = ml_g,
+ params_l = params_l,
params_m = params_m,
+ params_g = params_g,
dml_procedure = dml_procedure, score = score,
smpls = smpls)
theta = plr_hat$coef
@@ -137,26 +167,40 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (no
d_cols = c("d1", "d2"), x_cols = Xnames)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj_nocf = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
apply_cross_fitting = FALSE)
double_mlplr_obj_nocf$set_ml_nuisance_params(
- treat_var = "d1", learner = "ml_g",
- params = learner_pars$params$params_g)
+ treat_var = "d1", learner = "ml_l",
+ params = learner_pars$params$params_l)
double_mlplr_obj_nocf$set_ml_nuisance_params(
- treat_var = "d2", learner = "ml_g",
- params = learner_pars$params$params_g)
+ treat_var = "d2", learner = "ml_l",
+ params = learner_pars$params$params_l)
double_mlplr_obj_nocf$set_ml_nuisance_params(
treat_var = "d1", learner = "ml_m",
params = learner_pars$params$params_m)
double_mlplr_obj_nocf$set_ml_nuisance_params(
treat_var = "d2", learner = "ml_m",
params = learner_pars$params$params_m)
+ if (score == "IV-type") {
+ double_mlplr_obj_nocf$set_ml_nuisance_params(
+ treat_var = "d1", learner = "ml_g",
+ params = learner_pars$params$params_g)
+ double_mlplr_obj_nocf$set_ml_nuisance_params(
+ treat_var = "d2", learner = "ml_g",
+ params = learner_pars$params$params_g)
+ }
double_mlplr_obj_nocf$fit()
@@ -182,50 +226,71 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (fol
d_cols = c("d1", "d2"), x_cols = Xnames)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
n_rep = n_rep)
double_mlplr_obj$set_ml_nuisance_params(
- treat_var = "d1", learner = "ml_g",
- params = learner_pars$params$params_g)
+ treat_var = "d1", learner = "ml_l",
+ params = learner_pars$params$params_l)
double_mlplr_obj$set_ml_nuisance_params(
- treat_var = "d2", learner = "ml_g",
- params = learner_pars$params$params_g)
+ treat_var = "d2", learner = "ml_l",
+ params = learner_pars$params$params_l)
double_mlplr_obj$set_ml_nuisance_params(
treat_var = "d1", learner = "ml_m",
params = learner_pars$params$params_m)
double_mlplr_obj$set_ml_nuisance_params(
treat_var = "d2", learner = "ml_m",
params = learner_pars$params$params_m)
+ if (score == "IV-type") {
+ double_mlplr_obj$set_ml_nuisance_params(
+ treat_var = "d1", learner = "ml_g",
+ params = learner_pars$params$params_g)
+ double_mlplr_obj$set_ml_nuisance_params(
+ treat_var = "d2", learner = "ml_g",
+ params = learner_pars$params$params_g)
+ }
double_mlplr_obj$fit()
theta = double_mlplr_obj$coef
se = double_mlplr_obj$se
- params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
+ params_l_fold_wise = rep(list(rep(list(learner_pars$params$params_l), n_folds)), n_rep)
params_m_fold_wise = rep(list(rep(list(learner_pars$params$params_m), n_folds)), n_rep)
+ params_g_fold_wise = rep(list(rep(list(learner_pars$params$params_g), n_folds)), n_rep)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g)
+ } else {
+ ml_g = NULL
+ }
dml_plr_fold_wise = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = mlr3::lrn(learner_pars$mlmethod$mlmethod_g),
+ ml_l = mlr3::lrn(learner_pars$mlmethod$mlmethod_l),
ml_m = mlr3::lrn(learner_pars$mlmethod$mlmethod_m),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
n_rep = n_rep)
dml_plr_fold_wise$set_ml_nuisance_params(
- treat_var = "d1", learner = "ml_g",
- params = params_g_fold_wise,
+ treat_var = "d1", learner = "ml_l",
+ params = params_l_fold_wise,
set_fold_specific = TRUE)
dml_plr_fold_wise$set_ml_nuisance_params(
- treat_var = "d2", learner = "ml_g",
- params = params_g_fold_wise,
+ treat_var = "d2", learner = "ml_l",
+ params = params_l_fold_wise,
set_fold_specific = TRUE)
dml_plr_fold_wise$set_ml_nuisance_params(
treat_var = "d1", learner = "ml_m",
@@ -235,6 +300,16 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (fol
treat_var = "d2", learner = "ml_m",
params = params_m_fold_wise,
set_fold_specific = TRUE)
+ if (score == "IV-type") {
+ dml_plr_fold_wise$set_ml_nuisance_params(
+ treat_var = "d1", learner = "ml_g",
+ params = params_g_fold_wise,
+ set_fold_specific = TRUE)
+ dml_plr_fold_wise$set_ml_nuisance_params(
+ treat_var = "d2", learner = "ml_g",
+ params = params_g_fold_wise,
+ set_fold_specific = TRUE)
+ }
dml_plr_fold_wise$fit()
theta_fold_wise = dml_plr_fold_wise$coef
@@ -251,8 +326,9 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (def
n_folds = 2
n_rep = 3
- params_g = list(cp = 0.01, minsplit = 20) # this are defaults
+ params_l = list(cp = 0.01, minsplit = 20) # these are the defaults
params_m = list(cp = 0.01, minsplit = 20) # this are defaults
+ params_g = list(cp = 0.01, minsplit = 20) # these are the defaults
Xnames = names(data_plr_multi)[names(data_plr_multi) %in% c("y", "d1", "d2", "z") == FALSE]
data_ml = double_ml_data_from_data_frame(data_plr_multi,
@@ -260,10 +336,16 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (def
d_cols = c("d1", "d2"), x_cols = Xnames)
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = lrn("regr.rpart")
+ } else {
+ ml_g = NULL
+ }
dml_plr_default = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
n_rep = n_rep)
@@ -273,25 +355,39 @@ patrick::with_parameters_test_that("Unit tests for parameter passing of PLR (def
se_default = dml_plr_default$se
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = lrn("regr.rpart")
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = lrn("regr.rpart"),
+ ml_l = lrn("regr.rpart"),
ml_m = lrn("regr.rpart"),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
n_rep = n_rep)
double_mlplr_obj$set_ml_nuisance_params(
- treat_var = "d1", learner = "ml_g",
- params = params_g)
+ treat_var = "d1", learner = "ml_l",
+ params = params_l)
double_mlplr_obj$set_ml_nuisance_params(
- treat_var = "d2", learner = "ml_g",
- params = params_g)
+ treat_var = "d2", learner = "ml_l",
+ params = params_l)
double_mlplr_obj$set_ml_nuisance_params(
treat_var = "d1", learner = "ml_m",
params = params_m)
double_mlplr_obj$set_ml_nuisance_params(
treat_var = "d2", learner = "ml_m",
params = params_m)
+ if (score == "IV-type") {
+ double_mlplr_obj$set_ml_nuisance_params(
+ treat_var = "d1", learner = "ml_g",
+ params = params_g)
+ double_mlplr_obj$set_ml_nuisance_params(
+ treat_var = "d2", learner = "ml_g",
+ params = params_g)
+ }
double_mlplr_obj$fit()
theta = double_mlplr_obj$coef
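For the fold-wise branch above, set_ml_nuisance_params with set_fold_specific = TRUE expects a nested list: outer over repetitions, inner over folds. A standalone sketch of the shape (illustrative values):

# One parameter list per fold, replicated for each sample-splitting
# repetition: outer length n_rep, each element of length n_folds.
n_folds = 2; n_rep = 3
params_g = list(cp = 0.01, minsplit = 20)
params_g_fold_wise = rep(list(rep(list(params_g), n_folds)), n_rep)
str(params_g_fold_wise, max.level = 2)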
diff --git a/tests/testthat/test-double_ml_plr_rep_cross_fit.R b/tests/testthat/test-double_ml_plr_rep_cross_fit.R
index da73d7f4..9e763a52 100644
--- a/tests/testthat/test-double_ml_plr_rep_cross_fit.R
+++ b/tests/testthat/test-double_ml_plr_rep_cross_fit.R
@@ -30,11 +30,17 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
set.seed(3141)
n_folds = 5
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
plr_hat = dml_plr(data_plr$df,
y = "y", d = "d",
n_folds = n_folds, n_rep = n_rep,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure, score = score)
theta = plr_hat$coef
se = plr_hat$se
@@ -52,10 +58,16 @@ patrick::with_parameters_test_that("Unit tests for PLR:",
score = score)$boot_coef
set.seed(3141)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score,
diff --git a/tests/testthat/test-double_ml_plr_set_samples.R b/tests/testthat/test-double_ml_plr_set_samples.R
index 094a85ca..d8bebafc 100644
--- a/tests/testthat/test-double_ml_plr_set_samples.R
+++ b/tests/testthat/test-double_ml_plr_set_samples.R
@@ -35,9 +35,15 @@ patrick::with_parameters_test_that("PLR with external sample provision:",
y_col = "y",
d_cols = "d", x_cols = Xnames)
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj = DoubleMLPLR$new(data_ml,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
n_folds = n_folds,
score = score,
@@ -52,9 +58,15 @@ patrick::with_parameters_test_that("PLR with external sample provision:",
# External sample provision
SAMPLES = double_mlplr_obj$smpls
+ if (score == "IV-type") {
+ ml_g = learner_pars$ml_g$clone()
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj_external = DoubleMLPLR$new(data_ml,
- ml_g = learner_pars$ml_g$clone(),
+ ml_l = learner_pars$ml_l$clone(),
ml_m = learner_pars$ml_m$clone(),
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
draw_sample_splitting = FALSE)
diff --git a/tests/testthat/test-double_ml_plr_tuning.R b/tests/testthat/test-double_ml_plr_tuning.R
index 60907783..2a97ec9c 100644
--- a/tests/testthat/test-double_ml_plr_tuning.R
+++ b/tests/testthat/test-double_ml_plr_tuning.R
@@ -56,10 +56,16 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLR:",
y_col = "y",
d_cols = c("d1", "d2"), x_cols = Xnames)
}
+ if (score == "IV-type") {
+ ml_g = learner
+ } else {
+ ml_g = NULL
+ }
double_mlplr_obj_tuned = DoubleMLPLR$new(data_ml,
n_folds = n_folds,
- ml_g = learner,
+ ml_l = learner,
ml_m = m_learner,
+ ml_g = ml_g,
dml_procedure = dml_procedure,
score = score,
n_rep = n_rep)
@@ -73,12 +79,18 @@ patrick::with_parameters_test_that("Unit tests for tuning of PLR:",
resolution = 5)
param_grid = list(
- "ml_g" = paradox::ParamSet$new(list(
- paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
+ "ml_l" = paradox::ParamSet$new(list(
+ paradox::ParamDbl$new("cp", lower = 0.02, upper = 0.03),
paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
"ml_m" = paradox::ParamSet$new(list(
- paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
- paradox::ParamInt$new("minsplit", lower = 1, upper = 2))))
+ paradox::ParamDbl$new("cp", lower = 0.03, upper = 0.04),
+ paradox::ParamInt$new("minsplit", lower = 2, upper = 3))))
+
+ if (score == "IV-type") {
+ param_grid[["ml_g"]] = paradox::ParamSet$new(list(
+ paradox::ParamDbl$new("cp", lower = 0.015, upper = 0.025),
+ paradox::ParamInt$new("minsplit", lower = 3, upper = 4)))
+ }
double_mlplr_obj_tuned$tune(param_set = param_grid, tune_on_folds = tune_on_folds, tune_settings = tune_sets)
diff --git a/tests/testthat/test-double_ml_plr_user_score.R b/tests/testthat/test-double_ml_plr_user_score.R
index 187442c5..586eaaa8 100644
--- a/tests/testthat/test-double_ml_plr_user_score.R
+++ b/tests/testthat/test-double_ml_plr_user_score.R
@@ -4,9 +4,9 @@ library("mlr3learners")
lgr::get_logger("mlr3")$set_threshold("warn")
-score_fct = function(y, d, g_hat, m_hat, smpls) {
+score_fct = function(y, d, l_hat, m_hat, g_hat, smpls) {
v_hat = d - m_hat
- u_hat = y - g_hat
+ u_hat = y - l_hat
v_hatd = v_hat * d
psi_a = -v_hat * v_hat
psi_b = v_hat * u_hat
@@ -39,7 +39,7 @@ patrick::with_parameters_test_that("Unit tests for PLR, callable score:",
double_mlplr_obj = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = lrn(learner),
+ ml_l = lrn(learner),
ml_m = lrn(learner),
dml_procedure = dml_procedure,
n_folds = n_folds,
@@ -55,7 +55,7 @@ patrick::with_parameters_test_that("Unit tests for PLR, callable score:",
set.seed(3141)
double_mlplr_obj_score = DoubleMLPLR$new(
data = data_plr$dml_data,
- ml_g = lrn(learner),
+ ml_l = lrn(learner),
ml_m = lrn(learner),
dml_procedure = dml_procedure,
n_folds = n_folds,
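Since score_fct reproduces the "partialling out" moment conditions, the estimate it implies has a closed form. A hypothetical check with simulated residualized quantities, assuming score_fct returns the named list of psi_a and psi_b required of callable scores:

# Solving mean(psi_a) * theta + mean(psi_b) = 0 yields the familiar
# residual-on-residual estimator.
set.seed(1)
n = 1000; theta = 0.5
v_hat = rnorm(n)                 # stand-in for d - m_hat
u_hat = theta * v_hat + rnorm(n) # stand-in for y - l_hat
psi = score_fct(y = u_hat, d = v_hat, l_hat = 0, m_hat = 0,
  g_hat = NULL, smpls = NULL)
theta_hat = -mean(psi$psi_b) / mean(psi$psi_a)
# equivalently mean(v_hat * u_hat) / mean(v_hat^2); close to 0.5 here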
diff --git a/tests/testthat/test-double_ml_print.R b/tests/testthat/test-double_ml_print.R
index 44af8081..60e7452b 100644
--- a/tests/testthat/test-double_ml_print.R
+++ b/tests/testthat/test-double_ml_print.R
@@ -6,8 +6,8 @@ set.seed(3141)
dml_data = make_plr_CCDDHNR2018(n_obs = 100)
dml_cluster_data = make_pliv_multiway_cluster_CKMS2021(N = 10, M = 10, dim_X = 5)
-ml_g = ml_m = ml_r = "regr.rpart"
-dml_plr = DoubleMLPLR$new(dml_data, ml_g, ml_m, n_folds = 2)
+ml_l = ml_g = ml_m = ml_r = "regr.rpart"
+dml_plr = DoubleMLPLR$new(dml_data, ml_l, ml_m, n_folds = 2)
dml_pliv = DoubleMLPLIV$new(dml_cluster_data, ml_g, ml_m, ml_r, n_folds = 2)
dml_plr$fit()
dml_pliv$fit()
diff --git a/tests/testthat/test-double_ml_set_sample_splitting.R b/tests/testthat/test-double_ml_set_sample_splitting.R
index e17b9885..41c8db3c 100644
--- a/tests/testthat/test-double_ml_set_sample_splitting.R
+++ b/tests/testthat/test-double_ml_set_sample_splitting.R
@@ -2,9 +2,9 @@ context("Unit tests for the method set_sample_splitting of class DoubleML")
set.seed(3141)
dml_data = make_plr_CCDDHNR2018(n_obs = 10)
-ml_g = lrn("regr.ranger")
-ml_m = ml_g$clone()
-dml_plr = DoubleMLPLR$new(dml_data, ml_g, ml_m, n_folds = 7, n_rep = 8)
+ml_l = lrn("regr.ranger")
+ml_m = ml_l$clone()
+dml_plr = DoubleMLPLR$new(dml_data, ml_l, ml_m, n_folds = 7, n_rep = 8)
test_that("Unit tests for the method set_sample_splitting of class DoubleML", {
@@ -165,37 +165,37 @@ assert_resampling_pars = function(dml_obj0, dml_obj1) {
test_that("Unit tests for the method set_sample_splitting of class DoubleML (draw vs set)", {
set.seed(3141)
- dml_plr_set = DoubleMLPLR$new(dml_data, ml_g, ml_m, n_folds = 7, n_rep = 8)
+ dml_plr_set = DoubleMLPLR$new(dml_data, ml_l, ml_m, n_folds = 7, n_rep = 8)
- dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+ dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
n_folds = 1, n_rep = 1, apply_cross_fitting = FALSE)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
- dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+ dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
n_folds = 2, n_rep = 1, apply_cross_fitting = FALSE)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
- dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+ dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
n_folds = 2, n_rep = 1, apply_cross_fitting = TRUE)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
- dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+ dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
n_folds = 5, n_rep = 1, apply_cross_fitting = TRUE)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls[[1]])
assert_resampling_pars(dml_plr_drawn, dml_plr_set)
- dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_g, ml_m,
+ dml_plr_drawn = DoubleMLPLR$new(dml_data, ml_l, ml_m,
n_folds = 5, n_rep = 3, apply_cross_fitting = TRUE)
dml_plr_set$set_sample_splitting(dml_plr_drawn$smpls)
assert_resampling_pars(dml_plr_drawn, dml_plr_set)