diff --git a/DESCRIPTION b/DESCRIPTION
index 15e4c0c0..1e1b9806 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -29,7 +29,7 @@
 Description: Flexible stochastic tree ensemble software.
 License: MIT + file LICENSE
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.3.2
+RoxygenNote: 7.3.3
 LinkingTo: cpp11, BH
 Suggests:
diff --git a/NAMESPACE b/NAMESPACE
index 2f4103c0..5303244d 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -10,6 +10,10 @@ export(calibrateInverseGammaErrorVariance)
 export(computeForestLeafIndices)
 export(computeForestLeafVariances)
 export(computeForestMaxLeafIndex)
+export(compute_bart_posterior_interval)
+export(compute_bcf_posterior_interval)
+export(compute_contrast_bart_model)
+export(compute_contrast_bcf_model)
 export(convertPreprocessorToJson)
 export(createBARTModelFromCombinedJson)
 export(createBARTModelFromCombinedJsonString)
@@ -60,6 +64,8 @@ export(rootResetRandomEffectsModel)
 export(rootResetRandomEffectsTracker)
 export(sampleGlobalErrorVarianceOneIteration)
 export(sampleLeafVarianceOneIteration)
+export(sample_bart_posterior_predictive)
+export(sample_bcf_posterior_predictive)
 export(sample_without_replacement)
 export(saveBARTModelToJson)
 export(saveBARTModelToJsonFile)
@@ -77,6 +83,7 @@ importFrom(stats,pnorm)
 importFrom(stats,predict)
 importFrom(stats,qgamma)
 importFrom(stats,qnorm)
+importFrom(stats,quantile)
 importFrom(stats,resid)
 importFrom(stats,rnorm)
 importFrom(stats,runif)
diff --git a/NEWS.md b/NEWS.md
index fc00c05e..c794b856 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,13 +2,30 @@

 ## New Features

+* Support for multithreading in various elements of the GFR and MCMC algorithms ([#182](https://github.com/StochasticTree/stochtree/pull/182))
 * Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164))
+* Enable "restricted sweep" of tree algorithms over a handful of trees ([#173](https://github.com/StochasticTree/stochtree/pull/173))
+* Support for multivariate treatment in R ([#183](https://github.com/StochasticTree/stochtree/pull/183))
+* Enable modification of dataset variables (weights, etc.) via the low-level interface ([#194](https://github.com/StochasticTree/stochtree/pull/194))
+
+## Computational Improvements
+
+* Modified default random effects initialization ([#190](https://github.com/StochasticTree/stochtree/pull/190))
+* Avoid double prediction on the training set ([#178](https://github.com/StochasticTree/stochtree/pull/178))

 ## Bug Fixes

 * Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
 * Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided (rather than a `bartmodel` or `bcfmodel` object)

+## Other Changes
+
+* Standardized naming conventions for out-of-sample data in prediction and posterior computation routines (warnings are raised when data are passed through the legacy `y`, `X`, `Z`, etc. arguments)
+    * Covariates / features are always referred to as "covariates" rather than "X"
+    * Treatment is referred to as "treatment" rather than "Z"
+    * Propensity scores are referred to as "propensity" rather than "pi_X"
+    * Outcomes are referred to as "outcome" rather than "Y"
+
 # stochtree 0.1.1

 * Fixed initialization bug in several R package code examples for random effects models
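As a usage sketch of the naming conventions and new exports listed above (the `covariates` argument name follows the convention documented in NEWS.md; the final two calls are commented out because this diff does not show their signatures, so the argument names there are assumptions):

```r
library(stochtree)

# Toy data, mirroring the package's roxygen examples
n <- 100
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- 10 * X[, 1] + rnorm(n)
bart_model <- bart(
  X_train = X, y_train = y,
  num_gfr = 10, num_burnin = 0, num_mcmc = 10
)

# Out-of-sample data now use standardized argument names:
# "covariates" rather than "X" (legacy names raise a warning)
preds <- predict(bart_model, covariates = X)

# New exports in this PR; argument names below are assumed for illustration
# intervals <- compute_bart_posterior_interval(bart_model, covariates = X)
# ppd_draws <- sample_bart_posterior_predictive(bart_model, covariates = X)
```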
diff --git a/R/bart.R b/R/bart.R
index 5a7be03a..4f152ba2 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -45,12 +45,6 @@
 #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
 #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
 #' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
-#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
-#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
-#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
-#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
-#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
-#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
 #' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
 #'
 #' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
@@ -83,6 +77,16 @@
 #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
 #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset.
 #'
+#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
+#'
+#' - `model_spec` Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time, and `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. Default: "custom".
+#' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
+#' - `group_parameter_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector.
+#' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
+#' - `group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix.
+#' - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
+#' - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`.
+#'
 #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
 #' @export
 #'
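A minimal sketch of the new `random_effects_params` interface documented above, using toy data (per the docs, `"intercept_only"` dispatches an all-ones basis internally, so no `rfx_basis_train` is supplied; the prior settings shown are the documented defaults):

```r
library(stochtree)

# Toy data with four random effect groups
n <- 100
p <- 5
X_train <- matrix(runif(n * p), ncol = p)
rfx_group_ids_train <- sample(1:4, n, replace = TRUE)
y_train <- 5 * X_train[, 1] + rnorm(n)

# Random intercept model: no user-provided basis needed
bart_model <- bart(
  X_train = X_train,
  y_train = y_train,
  rfx_group_ids_train = rfx_group_ids_train,
  random_effects_params = list(
    model_spec = "intercept_only",
    variance_prior_shape = 1,
    variance_prior_scale = 1
  ),
  num_gfr = 10, num_burnin = 0, num_mcmc = 10
)
```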
@@ -111,1685 +115,1719 @@
 #' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
 #'   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
 bart <- function(
-    X_train,
-    y_train,
-    leaf_basis_train = NULL,
-    rfx_group_ids_train = NULL,
-    rfx_basis_train = NULL,
-    X_test = NULL,
-    leaf_basis_test = NULL,
-    rfx_group_ids_test = NULL,
-    rfx_basis_test = NULL,
-    num_gfr = 5,
-    num_burnin = 0,
-    num_mcmc = 100,
-    previous_model_json = NULL,
-    previous_model_warmstart_sample_num = NULL,
-    general_params = list(),
-    mean_forest_params = list(),
-    variance_forest_params = list()
+  X_train,
+  y_train,
+  leaf_basis_train = NULL,
+  rfx_group_ids_train = NULL,
+  rfx_basis_train = NULL,
+  X_test = NULL,
+  leaf_basis_test = NULL,
+  rfx_group_ids_test = NULL,
+  rfx_basis_test = NULL,
+  num_gfr = 5,
+  num_burnin = 0,
+  num_mcmc = 100,
+  previous_model_json = NULL,
+  previous_model_warmstart_sample_num = NULL,
+  general_params = list(),
+  mean_forest_params = list(),
+  variance_forest_params = list(),
+  random_effects_params = list()
 ) {
-    # Update general BART parameters
-    general_params_default <- list(
-        cutpoint_grid_size = 100,
-        standardize = TRUE,
-        sample_sigma2_global = TRUE,
-        sigma2_global_init = NULL,
-        sigma2_global_shape = 0,
-        sigma2_global_scale = 0,
-        variable_weights = NULL,
-        random_seed = -1,
-        keep_burnin = FALSE,
-        keep_gfr = FALSE,
-        keep_every = 1,
-        num_chains = 1,
-        verbose = FALSE,
-        probit_outcome_model = FALSE,
-        rfx_working_parameter_prior_mean = NULL,
-        rfx_group_parameter_prior_mean = NULL,
-        rfx_working_parameter_prior_cov = NULL,
-        rfx_group_parameter_prior_cov = NULL,
-        rfx_variance_prior_shape = 1,
-        rfx_variance_prior_scale = 1,
-        num_threads = -1
-    )
-    general_params_updated <- preprocessParams(
-        general_params_default,
-        general_params
-    )
-
-    # Update mean forest BART parameters
-    mean_forest_params_default <- list(
-        num_trees = 200,
-        alpha = 0.95,
-        beta = 2.0,
-        min_samples_leaf = 5,
-        max_depth = 10,
-        sample_sigma2_leaf = TRUE,
-        sigma2_leaf_init = NULL,
-        sigma2_leaf_shape = 3,
-        sigma2_leaf_scale
= NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - mean_forest_params_updated <- preprocessParams( - mean_forest_params_default, - mean_forest_params + # Update general BART parameters + general_params_default <- list( + cutpoint_grid_size = 100, + standardize = TRUE, + sample_sigma2_global = TRUE, + sigma2_global_init = NULL, + sigma2_global_shape = 0, + sigma2_global_scale = 0, + variable_weights = NULL, + random_seed = -1, + keep_burnin = FALSE, + keep_gfr = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, + probit_outcome_model = FALSE, + num_threads = -1 + ) + general_params_updated <- preprocessParams( + general_params_default, + general_params + ) + + # Update mean forest BART parameters + mean_forest_params_default <- list( + num_trees = 200, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + sample_sigma2_leaf = TRUE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + mean_forest_params_updated <- preprocessParams( + mean_forest_params_default, + mean_forest_params + ) + + # Update variance forest BART parameters + variance_forest_params_default <- list( + num_trees = 0, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + leaf_prior_calibration_param = 1.5, + var_forest_leaf_init = NULL, + var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + variance_forest_params_updated <- preprocessParams( + variance_forest_params_default, + variance_forest_params + ) + + # Update rfx parameters + rfx_params_default <- list( + model_spec = "custom", + working_parameter_prior_mean = NULL, + group_parameter_prior_mean = NULL, + working_parameter_prior_cov = NULL, + group_parameter_prior_cov = NULL, + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + rfx_params_updated <- preprocessParams( + rfx_params_default, + random_effects_params + ) + + ### Unpack all parameter values + # 1. General parameters + cutpoint_grid_size <- general_params_updated$cutpoint_grid_size + standardize <- general_params_updated$standardize + sample_sigma2_global <- general_params_updated$sample_sigma2_global + sigma2_init <- general_params_updated$sigma2_global_init + a_global <- general_params_updated$sigma2_global_shape + b_global <- general_params_updated$sigma2_global_scale + variable_weights <- general_params_updated$variable_weights + random_seed <- general_params_updated$random_seed + keep_burnin <- general_params_updated$keep_burnin + keep_gfr <- general_params_updated$keep_gfr + keep_every <- general_params_updated$keep_every + num_chains <- general_params_updated$num_chains + verbose <- general_params_updated$verbose + probit_outcome_model <- general_params_updated$probit_outcome_model + num_threads <- general_params_updated$num_threads + + # 2. 
Mean forest parameters + num_trees_mean <- mean_forest_params_updated$num_trees + alpha_mean <- mean_forest_params_updated$alpha + beta_mean <- mean_forest_params_updated$beta + min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf + max_depth_mean <- mean_forest_params_updated$max_depth + sample_sigma2_leaf <- mean_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_init <- mean_forest_params_updated$sigma2_leaf_init + a_leaf <- mean_forest_params_updated$sigma2_leaf_shape + b_leaf <- mean_forest_params_updated$sigma2_leaf_scale + keep_vars_mean <- mean_forest_params_updated$keep_vars + drop_vars_mean <- mean_forest_params_updated$drop_vars + num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample + + # 3. Variance forest parameters + num_trees_variance <- variance_forest_params_updated$num_trees + alpha_variance <- variance_forest_params_updated$alpha + beta_variance <- variance_forest_params_updated$beta + min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf + max_depth_variance <- variance_forest_params_updated$max_depth + a_0 <- variance_forest_params_updated$leaf_prior_calibration_param + variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init + a_forest <- variance_forest_params_updated$var_forest_prior_shape + b_forest <- variance_forest_params_updated$var_forest_prior_scale + keep_vars_variance <- variance_forest_params_updated$keep_vars + drop_vars_variance <- variance_forest_params_updated$drop_vars + num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample + + # 4. RFX parameters + rfx_model_spec <- rfx_params_updated$model_spec + rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean + rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean + rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov + rfx_group_parameter_prior_cov <- rfx_params_updated$group_parameter_prior_cov + rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape + rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale + + # Set a function-scoped RNG if user provided a random seed + custom_rng <- random_seed >= 0 + if (custom_rng) { + # Store original global environment RNG state + original_global_seed <- .Random.seed + # Set new seed and store associated RNG state + set.seed(random_seed) + function_scoped_seed <- .Random.seed + } + + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop( + "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" + ) + } + } + + # Override keep_gfr if there are no MCMC samples + if (num_mcmc == 0) { + keep_gfr <- TRUE + } + + # Check if previous model JSON is provided and parse it if so + has_prev_model <- !is.null(previous_model_json) + if (has_prev_model) { + previous_bart_model <- createBARTModelFromJsonString( + previous_model_json ) - - # Update variance forest BART parameters - variance_forest_params_default <- list( - num_trees = 0, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - leaf_prior_calibration_param = 1.5, - var_forest_leaf_init = NULL, - var_forest_prior_shape = NULL, - var_forest_prior_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - variance_forest_params_updated <- preprocessParams( - variance_forest_params_default, - variance_forest_params - ) - - 
### Unpack all parameter values - # 1. General parameters - cutpoint_grid_size <- general_params_updated$cutpoint_grid_size - standardize <- general_params_updated$standardize - sample_sigma2_global <- general_params_updated$sample_sigma2_global - sigma2_init <- general_params_updated$sigma2_global_init - a_global <- general_params_updated$sigma2_global_shape - b_global <- general_params_updated$sigma2_global_scale - variable_weights <- general_params_updated$variable_weights - random_seed <- general_params_updated$random_seed - keep_burnin <- general_params_updated$keep_burnin - keep_gfr <- general_params_updated$keep_gfr - keep_every <- general_params_updated$keep_every - num_chains <- general_params_updated$num_chains - verbose <- general_params_updated$verbose - probit_outcome_model <- general_params_updated$probit_outcome_model - rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean - rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean - rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov - rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov - rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape - rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale - num_threads <- general_params_updated$num_threads - - # 2. Mean forest parameters - num_trees_mean <- mean_forest_params_updated$num_trees - alpha_mean <- mean_forest_params_updated$alpha - beta_mean <- mean_forest_params_updated$beta - min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf - max_depth_mean <- mean_forest_params_updated$max_depth - sample_sigma2_leaf <- mean_forest_params_updated$sample_sigma2_leaf - sigma2_leaf_init <- mean_forest_params_updated$sigma2_leaf_init - a_leaf <- mean_forest_params_updated$sigma2_leaf_shape - b_leaf <- mean_forest_params_updated$sigma2_leaf_scale - keep_vars_mean <- mean_forest_params_updated$keep_vars - drop_vars_mean <- mean_forest_params_updated$drop_vars - num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample - - # 3. 
Variance forest parameters - num_trees_variance <- variance_forest_params_updated$num_trees - alpha_variance <- variance_forest_params_updated$alpha - beta_variance <- variance_forest_params_updated$beta - min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf - max_depth_variance <- variance_forest_params_updated$max_depth - a_0 <- variance_forest_params_updated$leaf_prior_calibration_param - variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init - a_forest <- variance_forest_params_updated$var_forest_prior_shape - b_forest <- variance_forest_params_updated$var_forest_prior_scale - keep_vars_variance <- variance_forest_params_updated$keep_vars - drop_vars_variance <- variance_forest_params_updated$drop_vars - num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - - # Set a function-scoped RNG if user provided a random seed - custom_rng <- random_seed >= 0 - if (custom_rng) { - # Store original global environment RNG state - original_global_seed <- .Random.seed - # Set new seed and store associated RNG state - set.seed(random_seed) - function_scoped_seed <- .Random.seed - } - - # Check if there are enough GFR samples to seed num_chains samplers - if (num_gfr > 0) { - if (num_chains > num_gfr) { - stop( - "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" - ) - } - } - - # Override keep_gfr if there are no MCMC samples - if (num_mcmc == 0) { - keep_gfr <- TRUE - } - - # Check if previous model JSON is provided and parse it if so - has_prev_model <- !is.null(previous_model_json) - if (has_prev_model) { - previous_bart_model <- createBARTModelFromJsonString( - previous_model_json - ) - previous_y_bar <- previous_bart_model$model_params$outcome_mean - previous_y_scale <- previous_bart_model$model_params$outcome_scale - if (previous_bart_model$model_params$include_mean_forest) { - previous_forest_samples_mean <- previous_bart_model$mean_forests - } else { - previous_forest_samples_mean <- NULL - } - if (previous_bart_model$model_params$include_variance_forest) { - previous_forest_samples_variance <- previous_bart_model$variance_forests - } else { - previous_forest_samples_variance <- NULL - } - if (previous_bart_model$model_params$sample_sigma2_global) { - previous_global_var_samples <- previous_bart_model$sigma2_global_samples / - (previous_y_scale * previous_y_scale) - } else { - previous_global_var_samples <- NULL - } - if (previous_bart_model$model_params$sample_sigma2_leaf) { - previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples - } else { - previous_leaf_var_samples <- NULL - } - if (previous_bart_model$model_params$has_rfx) { - previous_rfx_samples <- previous_bart_model$rfx_samples - } else { - previous_rfx_samples <- NULL - } - previous_model_num_samples <- previous_bart_model$model_params$num_samples - if (previous_model_warmstart_sample_num > previous_model_num_samples) { - stop( - "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" - ) - } - } else { - previous_y_bar <- NULL - previous_y_scale <- NULL - previous_global_var_samples <- NULL - previous_leaf_var_samples <- NULL - previous_rfx_samples <- NULL - previous_forest_samples_mean <- NULL - previous_forest_samples_variance <- NULL - previous_model_num_samples <- 0 - } - - # Determine whether conditional mean, variance, or both will be modeled - if (num_trees_variance > 0) { - include_variance_forest = TRUE + previous_y_bar <- 
previous_bart_model$model_params$outcome_mean + previous_y_scale <- previous_bart_model$model_params$outcome_scale + if (previous_bart_model$model_params$include_mean_forest) { + previous_forest_samples_mean <- previous_bart_model$mean_forests } else { - include_variance_forest = FALSE + previous_forest_samples_mean <- NULL } - if (num_trees_mean > 0) { - include_mean_forest = TRUE + if (previous_bart_model$model_params$include_variance_forest) { + previous_forest_samples_variance <- previous_bart_model$variance_forests } else { - include_mean_forest = FALSE + previous_forest_samples_variance <- NULL } - - # Set the variance forest priors if not set - if (include_variance_forest) { - if (is.null(a_forest)) { - a_forest <- num_trees_variance / (a_0^2) + 0.5 - } - if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + if (previous_bart_model$model_params$sample_sigma2_global) { + previous_global_var_samples <- previous_bart_model$sigma2_global_samples / + (previous_y_scale * previous_y_scale) } else { - a_forest <- 1. - b_forest <- 1. + previous_global_var_samples <- NULL } - - # Override tau sampling if there is no mean forest - if (!include_mean_forest) { - sample_sigma2_leaf <- FALSE - } - - # Variable weight preprocessing (and initialization if necessary) - if (is.null(variable_weights)) { - variable_weights = rep(1 / ncol(X_train), ncol(X_train)) - } - if (any(variable_weights < 0)) { - stop("variable_weights cannot have any negative weights") - } - - # Check covariates are matrix or dataframe - if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { - stop("X_train must be a matrix or dataframe") - } - if (!is.null(X_test)) { - if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { - stop("X_test must be a matrix or dataframe") - } + if (previous_bart_model$model_params$sample_sigma2_leaf) { + previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples + } else { + previous_leaf_var_samples <- NULL } - num_cov_orig <- ncol(X_train) - - # Standardize the keep variable lists to numeric indices - if (!is.null(keep_vars_mean)) { - if (is.character(keep_vars_mean)) { - if (!all(keep_vars_mean %in% names(X_train))) { - stop( - "keep_vars_mean includes some variable names that are not in X_train" - ) - } - variable_subset_mu <- unname(which( - names(X_train) %in% keep_vars_mean - )) - } else { - if (any(keep_vars_mean > ncol(X_train))) { - stop( - "keep_vars_mean includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_mean < 0)) { - stop("keep_vars_mean includes some negative variable indices") - } - variable_subset_mu <- keep_vars_mean - } - } else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) { - if (is.character(drop_vars_mean)) { - if (!all(drop_vars_mean %in% names(X_train))) { - stop( - "drop_vars_mean includes some variable names that are not in X_train" - ) - } - variable_subset_mean <- unname(which( - !(names(X_train) %in% drop_vars_mean) - )) - } else { - if (any(drop_vars_mean > ncol(X_train))) { - stop( - "drop_vars_mean includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_mean < 0)) { - stop("drop_vars_mean includes some negative variable indices") - } - variable_subset_mean <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_mean) - ] - } + if (previous_bart_model$model_params$has_rfx) { + previous_rfx_samples <- previous_bart_model$rfx_samples } else { - variable_subset_mean <- 1:ncol(X_train) - } - if (!is.null(keep_vars_variance)) { - 
if (is.character(keep_vars_variance)) { - if (!all(keep_vars_variance %in% names(X_train))) { - stop( - "keep_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - names(X_train) %in% keep_vars_variance - )) - } else { - if (any(keep_vars_variance > ncol(X_train))) { - stop( - "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_variance < 0)) { - stop( - "keep_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- keep_vars_variance - } - } else if ( - (is.null(keep_vars_variance)) && (!is.null(drop_vars_variance)) - ) { - if (is.character(drop_vars_variance)) { - if (!all(drop_vars_variance %in% names(X_train))) { - stop( - "drop_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - !(names(X_train) %in% drop_vars_variance) - )) - } else { - if (any(drop_vars_variance > ncol(X_train))) { - stop( - "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_variance < 0)) { - stop( - "drop_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_variance) - ] - } + previous_rfx_samples <- NULL + } + previous_model_num_samples <- previous_bart_model$model_params$num_samples + if (previous_model_warmstart_sample_num > previous_model_num_samples) { + stop( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) + } + } else { + previous_y_bar <- NULL + previous_y_scale <- NULL + previous_global_var_samples <- NULL + previous_leaf_var_samples <- NULL + previous_rfx_samples <- NULL + previous_forest_samples_mean <- NULL + previous_forest_samples_variance <- NULL + previous_model_num_samples <- 0 + } + + # Determine whether conditional mean, variance, or both will be modeled + if (num_trees_variance > 0) { + include_variance_forest = TRUE + } else { + include_variance_forest = FALSE + } + if (num_trees_mean > 0) { + include_mean_forest = TRUE + } else { + include_mean_forest = FALSE + } + + # Set the variance forest priors if not set + if (include_variance_forest) { + if (is.null(a_forest)) { + a_forest <- num_trees_variance / (a_0^2) + 0.5 + } + if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + } else { + a_forest <- 1. + b_forest <- 1. 
+ } + + # Override tau sampling if there is no mean forest + if (!include_mean_forest) { + sample_sigma2_leaf <- FALSE + } + + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1 / ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Check covariates are matrix or dataframe + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)) { + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + num_cov_orig <- ncol(X_train) + + # Standardize the keep variable lists to numeric indices + if (!is.null(keep_vars_mean)) { + if (is.character(keep_vars_mean)) { + if (!all(keep_vars_mean %in% names(X_train))) { + stop( + "keep_vars_mean includes some variable names that are not in X_train" + ) + } + variable_subset_mu <- unname(which( + names(X_train) %in% keep_vars_mean + )) } else { - variable_subset_variance <- 1:ncol(X_train) - } - - # Preprocess covariates - if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { - stop("X_train must be a matrix or dataframe") - } - if (!is.null(X_test)) { - if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { - stop("X_test must be a matrix or dataframe") - } - } - if (ncol(X_train) != length(variable_weights)) { - stop("length(variable_weights) must equal ncol(X_train)") - } - train_cov_preprocess_list <- preprocessTrainData(X_train) - X_train_metadata <- train_cov_preprocess_list$metadata - X_train <- train_cov_preprocess_list$data - original_var_indices <- X_train_metadata$original_var_indices - feature_types <- X_train_metadata$feature_types - if (!is.null(X_test)) { - X_test <- preprocessPredictionData(X_test, X_train_metadata) - } - - # Update variable weights - variable_weights_mean <- variable_weights_variance <- variable_weights - variable_weights_adj <- 1 / - sapply(original_var_indices, function(x) sum(original_var_indices == x)) - if (include_mean_forest) { - variable_weights_mean <- variable_weights_mean[original_var_indices] * - variable_weights_adj - variable_weights_mean[ - !(original_var_indices %in% variable_subset_mean) - ] <- 0 - } - if (include_variance_forest) { - variable_weights_variance <- variable_weights_variance[ - original_var_indices - ] * - variable_weights_adj - variable_weights_variance[ - !(original_var_indices %in% variable_subset_variance) - ] <- 0 - } - - # Set num_features_subsample to default, ncol(X_train), if not already set - if (is.null(num_features_subsample_mean)) { - num_features_subsample_mean <- ncol(X_train) - } - if (is.null(num_features_subsample_variance)) { - num_features_subsample_variance <- ncol(X_train) - } - - # Convert all input data to matrices if not already converted - if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) { - leaf_basis_train <- as.matrix(leaf_basis_train) - } - if ((is.null(dim(leaf_basis_test))) && (!is.null(leaf_basis_test))) { - leaf_basis_test <- as.matrix(leaf_basis_test) - } - if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { - rfx_basis_train <- as.matrix(rfx_basis_train) - } - if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { - rfx_basis_test <- as.matrix(rfx_basis_test) + if (any(keep_vars_mean > ncol(X_train))) { + stop( + "keep_vars_mean includes some variable indices that exceed the number of columns in X_train" + ) + } + 
if (any(keep_vars_mean < 0)) { + stop("keep_vars_mean includes some negative variable indices") + } + variable_subset_mu <- keep_vars_mean + } + } else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) { + if (is.character(drop_vars_mean)) { + if (!all(drop_vars_mean %in% names(X_train))) { + stop( + "drop_vars_mean includes some variable names that are not in X_train" + ) + } + variable_subset_mean <- unname(which( + !(names(X_train) %in% drop_vars_mean) + )) + } else { + if (any(drop_vars_mean > ncol(X_train))) { + stop( + "drop_vars_mean includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_mean < 0)) { + stop("drop_vars_mean includes some negative variable indices") + } + variable_subset_mean <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_mean) + ] + } + } else { + variable_subset_mean <- 1:ncol(X_train) + } + if (!is.null(keep_vars_variance)) { + if (is.character(keep_vars_variance)) { + if (!all(keep_vars_variance %in% names(X_train))) { + stop( + "keep_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + names(X_train) %in% keep_vars_variance + )) + } else { + if (any(keep_vars_variance > ncol(X_train))) { + stop( + "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_variance < 0)) { + stop( + "keep_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- keep_vars_variance } - - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE - has_rfx_test <- FALSE - if (!is.null(rfx_group_ids_train)) { - group_ids_factor <- factor(rfx_group_ids_train) - rfx_group_ids_train <- as.integer(group_ids_factor) - has_rfx <- TRUE - if (!is.null(rfx_group_ids_test)) { - group_ids_factor_test <- factor( - rfx_group_ids_test, - levels = levels(group_ids_factor) - ) - if (sum(is.na(group_ids_factor_test)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" - ) - } - rfx_group_ids_test <- as.integer(group_ids_factor_test) - has_rfx_test <- TRUE - } + } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) { + if (is.character(drop_vars_variance)) { + if (!all(drop_vars_variance %in% names(X_train))) { + stop( + "drop_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + !(names(X_train) %in% drop_vars_variance) + )) + } else { + if (any(drop_vars_variance > ncol(X_train))) { + stop( + "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_variance < 0)) { + stop( + "drop_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_variance) + ] + } + } else { + variable_subset_variance <- 1:ncol(X_train) + } + + # Preprocess covariates + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)) { + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- 
train_cov_preprocess_list$metadata + X_train <- train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + } + + # Update variable weights + variable_weights_mean <- variable_weights_variance <- variable_weights + variable_weights_adj <- 1 / + sapply(original_var_indices, function(x) sum(original_var_indices == x)) + if (include_mean_forest) { + variable_weights_mean <- variable_weights_mean[original_var_indices] * + variable_weights_adj + variable_weights_mean[ + !(original_var_indices %in% variable_subset_mean) + ] <- 0 + } + if (include_variance_forest) { + variable_weights_variance <- variable_weights_variance[ + original_var_indices + ] * + variable_weights_adj + variable_weights_variance[ + !(original_var_indices %in% variable_subset_variance) + ] <- 0 + } + + # Set num_features_subsample to default, ncol(X_train), if not already set + if (is.null(num_features_subsample_mean)) { + num_features_subsample_mean <- ncol(X_train) + } + if (is.null(num_features_subsample_variance)) { + num_features_subsample_variance <- ncol(X_train) + } + + # Convert all input data to matrices if not already converted + if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) { + leaf_basis_train <- as.matrix(leaf_basis_train) + } + if ((is.null(dim(leaf_basis_test))) && (!is.null(leaf_basis_test))) { + leaf_basis_test <- as.matrix(leaf_basis_test) + } + if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { + rfx_basis_train <- as.matrix(rfx_basis_train) + } + if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { + rfx_basis_test <- as.matrix(rfx_basis_test) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
+ has_rfx <- FALSE + has_rfx_test <- FALSE + if (!is.null(rfx_group_ids_train)) { + group_ids_factor <- factor(rfx_group_ids_train) + rfx_group_ids_train <- as.integer(group_ids_factor) + has_rfx <- TRUE + if (!is.null(rfx_group_ids_test)) { + group_ids_factor_test <- factor( + rfx_group_ids_test, + levels = levels(group_ids_factor) + ) + if (sum(is.na(group_ids_factor_test)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + } + rfx_group_ids_test <- as.integer(group_ids_factor_test) + has_rfx_test <- TRUE + } + } + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if ( + (!is.null(leaf_basis_test)) && + (ncol(leaf_basis_test) != ncol(leaf_basis_train)) + ) { + stop( + "leaf_basis_train and leaf_basis_test must have the same number of columns" + ) + } + if ( + (!is.null(leaf_basis_train)) && + (nrow(leaf_basis_train) != nrow(X_train)) + ) { + stop("leaf_basis_train and X_train must have the same number of rows") + } + if ((!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test))) { + stop("leaf_basis_test and X_test must have the same number of rows") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + if ( + (!is.null(rfx_basis_test)) && + (ncol(rfx_basis_test) != ncol(rfx_basis_train)) + ) { + stop( + "rfx_basis_train and rfx_basis_test must have the same number of columns" + ) + } + if (!is.null(rfx_group_ids_train)) { + if (!is.null(rfx_group_ids_test)) { + if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { + stop( + "rfx_basis_train is provided but rfx_basis_test is not provided" + ) + } } + } - # Data consistency checks - if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { - stop("X_train and X_test must have the same number of columns") - } - if ( - (!is.null(leaf_basis_test)) && - (ncol(leaf_basis_test) != ncol(leaf_basis_train)) - ) { + # Handle the rfx basis matrices + has_basis_rfx <- FALSE + num_basis_rfx <- 0 + if (has_rfx) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_train)) { stop( - "leaf_basis_train and leaf_basis_test must have the same number of columns" + "A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'" ) - } - if ( - (!is.null(leaf_basis_train)) && - (nrow(leaf_basis_train) != nrow(X_train)) - ) { - stop("leaf_basis_train and X_train must have the same number of rows") - } - if ( - (!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test)) - ) { - stop("leaf_basis_test and X_test must have the same number of rows") - } - if (nrow(X_train) != length(y_train)) { - stop("X_train and y_train must have the same number of observations") - } - if ( - (!is.null(rfx_basis_test)) && - (ncol(rfx_basis_test) != ncol(rfx_basis_train)) - ) { + } + has_basis_rfx <- TRUE + num_basis_rfx <- ncol(rfx_basis_train) + } else if (rfx_model_spec == "intercept_only") { + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + } + num_rfx_groups <- length(unique(rfx_group_ids_train)) + num_rfx_components <- ncol(rfx_basis_train) + if (num_rfx_groups == 1) { + warning( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) + } + } + if (has_rfx_test) { + if (rfx_model_spec == 
"custom") { + if (is.null(rfx_basis_test)) { stop( - "rfx_basis_train and rfx_basis_test must have the same number of columns" + "A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'" ) + } + } else if (rfx_model_spec == "intercept_only") { + rfx_basis_test <- matrix( + rep(1, nrow(X_test)), + nrow = nrow(X_test), + ncol = 1 + ) + } + } + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Determine whether a basis vector is provided + has_basis = !is.null(leaf_basis_train) + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Preliminary runtime checks for probit link + if (!include_mean_forest) { + probit_outcome_model <- FALSE + } + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0, 1)))) { + stop( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) } - if (!is.null(rfx_group_ids_train)) { - if (!is.null(rfx_group_ids_test)) { - if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { - stop( - "rfx_basis_train is provided but rfx_basis_test is not provided" - ) - } - } + if (include_variance_forest) { + stop("We do not support heteroskedasticity with a probit link") } - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx <- FALSE - num_basis_rfx <- 0 - if (has_rfx) { - if (is.null(rfx_basis_train)) { - rfx_basis_train <- matrix( - rep(1, nrow(X_train)), - nrow = nrow(X_train), - ncol = 1 - ) + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global <- F + } + } + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- qnorm(mean(y_train)) + y_std_train <- 1 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train <- y_train - mean(y_train) + + # Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean <- 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + b_leaf <- 1 / (num_trees_mean) + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 / (num_trees_mean), + ncol(leaf_basis_train) + ) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) } else { - has_basis_rfx <- TRUE - num_basis_rfx <- ncol(rfx_basis_train) + current_leaf_scale <- sigma2_leaf_init } - num_rfx_groups <- length(unique(rfx_group_ids_train)) - num_rfx_components <- ncol(rfx_basis_train) - if (num_rfx_groups == 1) { - warning( - "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" - ) + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) } - } - if (has_rfx_test) { - if (is.null(rfx_basis_test)) { - if (has_basis_rfx) { - stop( - 
"Random effects basis provided for training set, must also be provided for the test set" - ) - } - rfx_basis_test <- matrix( - rep(1, nrow(X_test)), - nrow = nrow(X_test), - ncol = 1 - ) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init } + } + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init + } + } + current_sigma2 <- sigma2_init + } else { + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + } else { + y_bar_train <- 0 + y_std_train <- 1 } - # Convert y_train to numeric vector if not already converted - if (!is.null(dim(y_train))) { - y_train <- as.matrix(y_train) - } - - # Determine whether a basis vector is provided - has_basis = !is.null(leaf_basis_train) + # Compute standardized outcome + resid_train <- (y_train - y_bar_train) / y_std_train - # Determine whether a test set is provided - has_test = !is.null(X_test) + # Compute initial value of root nodes in mean forest + init_val_mean <- mean(resid_train) - # Preliminary runtime checks for probit link - if (!include_mean_forest) { - probit_outcome_model <- FALSE + # Calibrate priors for sigma^2 and tau + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var(resid_train) } - if (probit_outcome_model) { - if (!(length(unique(y_train)) == 2)) { - stop( - "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" - ) - } - unique_outcomes <- sort(unique(y_train)) - if (!(all(unique_outcomes == c(0, 1)))) { - stop( - "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" - ) - } - if (include_variance_forest) { - stop("We do not support heteroskedasticity with a probit link") - } - if (sample_sigma2_global) { - warning( - "Global error variance will not be sampled with a probit link as it is fixed at 1" - ) - sample_sigma2_global <- F - } + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var(resid_train) } - - # Handle standardization, prior calibration, and initialization of forest - # differently for binary and continuous outcomes - if (probit_outcome_model) { - # Compute a probit-scale offset and fix scale to 1 - y_bar_train <- qnorm(mean(y_train)) - y_std_train <- 1 - - # Set a pseudo outcome by subtracting mean(y_train) from y_train - resid_train <- y_train - mean(y_train) - - # Set initial values of root nodes to 0.0 (in probit scale) - init_val_mean <- 0.0 - - # Calibrate priors for sigma^2 and tau - # Set sigma2_init to 1, ignoring default provided - sigma2_init <- 1.0 - # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf <- 1 / (num_trees_mean) - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- diag( - 2 / (num_trees_mean), - ncol(leaf_basis_train) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag( - sigma2_leaf_init, - ncol(leaf_basis_train) - )) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - 
current_leaf_scale <- sigma2_leaf_init - } - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } + if (is.null(b_leaf)) { + b_leaf <- var(resid_train) / (2 * num_trees_mean) + } + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 * var(resid_train) / (num_trees_mean), + ncol(leaf_basis_train) + ) } - current_sigma2 <- sigma2_init - } else { - # Only standardize if user requested - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) } else { - y_bar_train <- 0 - y_std_train <- 1 + current_leaf_scale <- sigma2_leaf_init } - - # Compute standardized outcome - resid_train <- (y_train - y_bar_train) / y_std_train - - # Compute initial value of root nodes in mean forest - init_val_mean <- mean(resid_train) - - # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) { - sigma2_init <- 1.0 * var(resid_train) - } - if (is.null(variance_forest_init)) { - variance_forest_init <- 1.0 * var(resid_train) - } - if (is.null(b_leaf)) { - b_leaf <- var(resid_train) / (2 * num_trees_mean) + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var(resid_train) / (num_trees_mean) + ) } - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- diag( - 2 * var(resid_train) / (num_trees_mean), - ncol(leaf_basis_train) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag( - sigma2_leaf_init, - ncol(leaf_basis_train) - )) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix( - 2 * var(resid_train) / (num_trees_mean) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix( - 2 * var(resid_train) / (num_trees_mean) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } + current_leaf_scale <- sigma2_leaf_init } - current_sigma2 <- sigma2_init + } + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var(resid_train) / (num_trees_mean) + ) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init + } + } + current_sigma2 <- sigma2_init + } + + # Determine leaf model type + if (!has_basis) { + leaf_model_mean_forest <- 0 + } else if (ncol(leaf_basis_train) == 1) { + leaf_model_mean_forest <- 1 + } else if (ncol(leaf_basis_train) > 1) { + leaf_model_mean_forest <- 2 + } else { + stop("leaf_basis_train passed must be a matrix with at least 1 column") + } + + # Set variance leaf model type (currently only one option) + leaf_model_variance_forest <- 3 + + # Unpack model type info + if (leaf_model_mean_forest == 0) { + leaf_dimension = 1 + is_leaf_constant = TRUE + leaf_regression = 
FALSE + } else if (leaf_model_mean_forest == 1) { + stopifnot(has_basis) + stopifnot(ncol(leaf_basis_train) == 1) + leaf_dimension = 1 + is_leaf_constant = FALSE + leaf_regression = TRUE + } else if (leaf_model_mean_forest == 2) { + stopifnot(has_basis) + stopifnot(ncol(leaf_basis_train) > 1) + leaf_dimension = ncol(leaf_basis_train) + is_leaf_constant = FALSE + leaf_regression = TRUE + if (sample_sigma2_leaf) { + warning( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + ) + sample_sigma2_leaf <- FALSE } + } - # Determine leaf model type - if (!has_basis) { - leaf_model_mean_forest <- 0 - } else if (ncol(leaf_basis_train) == 1) { - leaf_model_mean_forest <- 1 - } else if (ncol(leaf_basis_train) > 1) { - leaf_model_mean_forest <- 2 + # Data + if (leaf_regression) { + forest_dataset_train <- createForestDataset(X_train, leaf_basis_train) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) + } + requires_basis <- TRUE + } else { + forest_dataset_train <- createForestDataset(X_train) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test) + } + requires_basis <- FALSE + } + outcome_train <- createOutcome(resid_train) + + # Random number generator (std::mt19937) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } + rng <- createCppRNG(random_seed) + + # Sampling data structures + feature_types <- as.integer(feature_types) + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + if (include_mean_forest) { + forest_model_config_mean <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mean, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mean, + leaf_dimension = leaf_dimension, + alpha = alpha_mean, + beta = beta_mean, + min_samples_leaf = min_samples_leaf_mean, + max_depth = max_depth_mean, + leaf_model_type = leaf_model_mean_forest, + leaf_model_scale = current_leaf_scale, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mean + ) + forest_model_mean <- createForestModel( + forest_dataset_train, + forest_model_config_mean, + global_model_config + ) + } + if (include_variance_forest) { + forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = 1, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance + ) + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config + ) + } + + # Container of forest samples + if (include_mean_forest) { + forest_samples_mean <- createForestSamples( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + active_forest_mean <- createForest( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + } + if (include_variance_forest) { + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE + ) + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE + ) + } + + # Random 
effects initialization + if (has_rfx) { + # Prior parameters + if (is.null(rfx_working_parameter_prior_mean)) { + if (num_rfx_components == 1) { + alpha_init <- c(0) + } else if (num_rfx_components > 1) { + alpha_init <- rep(0, num_rfx_components) + } else { + stop("There must be at least 1 random effect component") + } } else { - stop("leaf_basis_train passed must be a matrix with at least 1 column") - } - - # Set variance leaf model type (currently only one option) - leaf_model_variance_forest <- 3 - - # Unpack model type info - if (leaf_model_mean_forest == 0) { - leaf_dimension = 1 - is_leaf_constant = TRUE - leaf_regression = FALSE - } else if (leaf_model_mean_forest == 1) { - stopifnot(has_basis) - stopifnot(ncol(leaf_basis_train) == 1) - leaf_dimension = 1 - is_leaf_constant = FALSE - leaf_regression = TRUE - } else if (leaf_model_mean_forest == 2) { - stopifnot(has_basis) - stopifnot(ncol(leaf_basis_train) > 1) - leaf_dimension = ncol(leaf_basis_train) - is_leaf_constant = FALSE - leaf_regression = TRUE - if (sample_sigma2_leaf) { - warning( - "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." - ) - sample_sigma2_leaf <- FALSE - } + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) + } + + if (is.null(rfx_group_parameter_prior_mean)) { + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) + } else { + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) } - # Data - if (leaf_regression) { - forest_dataset_train <- createForestDataset(X_train, leaf_basis_train) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) - } - requires_basis <- TRUE + if (is.null(rfx_working_parameter_prior_cov)) { + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) } else { - forest_dataset_train <- createForestDataset(X_train) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test) - } - requires_basis <- FALSE + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) } - outcome_train <- createOutcome(resid_train) - # Random number generator (std::mt19937) - if (is.null(random_seed)) { - random_seed = sample(1:10000, 1, FALSE) + if (is.null(rfx_group_parameter_prior_cov)) { + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) } - rng <- createCppRNG(random_seed) - # Sampling data structures - feature_types <- as.integer(feature_types) - global_model_config <- createGlobalModelConfig( - global_error_variance = current_sigma2 - ) - if (include_mean_forest) { - forest_model_config_mean <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_mean, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_mean, - leaf_dimension = leaf_dimension, - alpha = alpha_mean, - beta = beta_mean, - min_samples_leaf = min_samples_leaf_mean, - max_depth = max_depth_mean, - leaf_model_type = leaf_model_mean_forest, - leaf_model_scale = current_leaf_scale, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_mean - ) - forest_model_mean <- createForestModel( - forest_dataset_train, - forest_model_config_mean, - global_model_config - ) - } - if 
(include_variance_forest) { - forest_model_config_variance <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_variance, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_variance, - leaf_dimension = 1, - alpha = alpha_variance, - beta = beta_variance, - min_samples_leaf = min_samples_leaf_variance, - max_depth = max_depth_variance, - leaf_model_type = leaf_model_variance_forest, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_variance - ) - forest_model_variance <- createForestModel( - forest_dataset_train, - forest_model_config_variance, - global_model_config - ) - } + sigma_xi_shape <- rfx_variance_prior_shape + sigma_xi_scale <- rfx_variance_prior_scale - # Container of forest samples - if (include_mean_forest) { - forest_samples_mean <- createForestSamples( - num_trees_mean, - leaf_dimension, - is_leaf_constant, - FALSE - ) - active_forest_mean <- createForest( - num_trees_mean, - leaf_dimension, - is_leaf_constant, - FALSE - ) - } - if (include_variance_forest) { - forest_samples_variance <- createForestSamples( - num_trees_variance, - 1, - TRUE, - TRUE - ) - active_forest_variance <- createForest( - num_trees_variance, - 1, - TRUE, - TRUE - ) - } + # Random effects data structure and storage container + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train + ) + rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups + ) + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train + ) + } + + # Container of variance parameter samples + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf) { + leaf_scale_samples <- rep(NA, num_retained_samples) + } + if (include_mean_forest) { + mean_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + if (include_variance_forest) { + variance_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + sample_counter <- 0 - # Random effects initialization - if (has_rfx) { - # Prior parameters - if (is.null(rfx_working_parameter_prior_mean)) { - if (num_rfx_components == 1) { - alpha_init <- c(0) - } else if (num_rfx_components > 1) { - alpha_init <- rep(0, num_rfx_components) - } else { - stop("There must be at least 1 random effect component") - } - } else { - alpha_init <- expand_dims_1d( - rfx_working_parameter_prior_mean, - num_rfx_components - ) + # Initialize the leaves of each tree in the mean forest + if (include_mean_forest) { + if (requires_basis) { + init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) + } else { + 
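# A minimal sketch of the sample bookkeeping above (illustrative only, not
# part of this diff): every GFR draw is retained while sampling so that each
# MCMC chain can later be warm-started from one of them, burn-in draws are
# kept only when keep_burnin is TRUE, and each chain contributes num_mcmc
# thinned draws. retained_during_sampling is a hypothetical helper name.
retained_during_sampling <- function(num_gfr, num_burnin, num_mcmc,
                                     num_chains, keep_burnin = FALSE) {
  num_gfr + ifelse(keep_burnin, num_burnin, 0) + num_mcmc * num_chains
}
retained_during_sampling(num_gfr = 10, num_burnin = 100, num_mcmc = 50,
                         num_chains = 4)
# 210 columns allocated; the 10 GFR draws are deleted after sampling
# whenever keep_gfr = FALSE, leaving 200 retained samples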
init_values_mean_forest <- 0. + } + active_forest_mean$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mean, + leaf_model_mean_forest, + init_values_mean_forest + ) + active_forest_mean$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_mean, + requires_basis, + FALSE + ) + } + + # Initialize the leaves of each tree in the variance forest + if (include_variance_forest) { + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init + ) + } + + # Run GFR (warm start) if specified + if (num_gfr > 0) { + for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) + keep_sample <- TRUE + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if ((i %% 10 == 0) || (i == num_gfr)) { + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBART (grow-from-root) draws\n" + ) } - - if (is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix( - rep(alpha_init, num_rfx_groups), - num_rfx_components, - num_rfx_groups - ) - } else { - xi_init <- expand_dims_2d( - rfx_group_parameter_prior_mean, - num_rfx_components, - num_rfx_groups + } + + if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + outcome_pred <- active_forest_mean$predict( + forest_dataset_train + ) + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - outcome_pred) } - if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_alpha_init <- expand_dims_2d_diag( - rfx_working_parameter_prior_cov, - num_rfx_components - ) - } + # Sample mean forest + forest_model_mean$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) - if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_xi_init <- expand_dims_2d_diag( - rfx_group_parameter_prior_cov, - num_rfx_components - ) + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() } - - sigma_xi_shape <- rfx_variance_prior_shape - sigma_xi_scale <- rfx_variance_prior_scale - - # Random effects data structure and storage container - rfx_dataset_train <- createRandomEffectsDataset( - rfx_group_ids_train, - rfx_basis_train - ) - rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel( - num_rfx_components, - num_rfx_groups + } + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset 
= forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = TRUE ) - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples( - num_rfx_components, - num_rfx_groups, - rfx_tracker_train - ) - } - # Container of variance parameter samples - num_actual_mcmc_iter <- num_mcmc * keep_every - num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter - # Delete GFR samples from these containers after the fact if desired - # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + - ifelse(keep_burnin, num_burnin, 0) + - num_mcmc * num_chains - if (sample_sigma2_global) { - global_var_samples <- rep(NA, num_retained_samples) - } - if (sample_sigma2_leaf) { - leaf_scale_samples <- rep(NA, num_retained_samples) - } - if (include_mean_forest) { - mean_forest_pred_train <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } + } + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global ) - } - if (include_variance_forest) { - variance_forest_pred_train <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance(current_sigma2) + } + if (sample_sigma2_leaf) { + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf ) - } - sample_counter <- 0 - - # Initialize the leaves of each tree in the mean forest - if (include_mean_forest) { - if (requires_basis) { - init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) - } else { - init_values_mean_forest <- 0. 
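# A minimal sketch of the truncated-normal draw used in the probit
# data-augmentation step above (illustrative only, not part of this diff):
# given the latent mean mu, z is drawn from N(mu, 1) truncated to (-Inf, 0]
# when y == 0 and to (0, Inf) when y == 1, via the inverse-CDF trick.
set.seed(101)
mu <- 0.3
u0 <- runif(10000, 0, pnorm(0 - mu)) # y == 0 case
z0 <- mu + qnorm(u0)
u1 <- runif(10000, pnorm(0 - mu), 1) # y == 1 case
z1 <- mu + qnorm(u1)
stopifnot(all(z0 <= 0), all(z1 > 0)) # draws land in the correct half-lines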
+ current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double } - active_forest_mean$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_mean, - leaf_model_mean_forest, - init_values_mean_forest + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale ) - active_forest_mean$adjust_residual( - forest_dataset_train, - outcome_train, - forest_model_mean, - requires_basis, - FALSE + } + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng ) + } } + } - # Initialize the leaves of each tree in the variance forest - if (include_variance_forest) { - active_forest_variance$prepare_for_sampler( + # Run MCMC + if (num_burnin + num_mcmc > 0) { + for (chain_num in 1:num_chains) { + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + if (include_mean_forest) { + resetActiveForest( + active_forest_mean, + forest_samples_mean, + forest_ind + ) + resetForestModel( + forest_model_mean, + active_forest_mean, forest_dataset_train, outcome_train, + TRUE + ) + if (sample_sigma2_leaf) { + leaf_scale_double <- leaf_scale_samples[forest_ind + 1] + current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + } + if (include_variance_forest) { + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind + ) + resetForestModel( forest_model_variance, - leaf_model_variance_forest, - variance_forest_init - ) - } - - # Run GFR (warm start) if specified - if (num_gfr > 0) { - for (i in 1:num_gfr) { - # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC - # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) - keep_sample <- TRUE - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if ((i %% 10 == 0) || (i == num_gfr)) { - cat( - "Sampling", - i, - "out of", - num_gfr, - "XBART (grow-from-root) draws\n" - ) - } - } - - if (include_mean_forest) { - if (probit_outcome_model) { - # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict( - forest_dataset_train - ) - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } - - # Sample mean forest - forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mean, - active_forest = active_forest_mean, - rng = rng, - forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - mean_forest_pred_train[, - sample_counter - ] <- forest_model_mean$get_cached_forest_predictions() - } - } - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng 
= rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - variance_forest_pred_train[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance(current_sigma2) - } - if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration( - active_forest_mean, - rng, - a_leaf, - b_leaf - ) - current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) { - leaf_scale_samples[sample_counter] <- leaf_scale_double - } - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } - } - - # Run MCMC - if (num_burnin + num_mcmc > 0) { - for (chain_num in 1:num_chains) { - if (num_gfr > 0) { - # Reset state of active_forest and forest_model based on a previous GFR sample - forest_ind <- num_gfr - chain_num - if (include_mean_forest) { - resetActiveForest( - active_forest_mean, - forest_samples_mean, - forest_ind - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf) { - leaf_scale_double <- leaf_scale_samples[forest_ind + 1] - current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - forest_samples_variance, - forest_ind - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - resetRandomEffectsModel( - rfx_model, - rfx_samples, - forest_ind, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - if (sample_sigma2_global) { - current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } else if (has_prev_model) { - if (include_mean_forest) { - resetActiveForest( - active_forest_mean, - previous_forest_samples_mean, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if ( - sample_sigma2_leaf && - (!is.null(previous_leaf_var_samples)) - ) { - leaf_scale_double <- previous_leaf_var_samples[ - previous_model_warmstart_sample_num - ] - current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - previous_forest_samples_variance, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - if 
(is.null(previous_rfx_samples)) { - warning( - "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" - ) - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } else { - resetRandomEffectsModel( - rfx_model, - previous_rfx_samples, - previous_model_warmstart_sample_num - 1, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - } - if (sample_sigma2_global) { - if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[ - previous_model_warmstart_sample_num - ] - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - } else { - if (include_mean_forest) { - resetActiveForest(active_forest_mean) - active_forest_mean$set_root_leaves( - init_values_mean_forest / num_trees_mean - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf) { - current_leaf_scale <- as.matrix(sigma2_leaf_init) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - } - if (include_variance_forest) { - resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves( - log(variance_forest_init) / num_trees_variance - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } - if (sample_sigma2_global) { - current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - for (i in (num_gfr + 1):num_samples) { - is_mcmc <- i > (num_gfr + num_burnin) - if (is_mcmc) { - mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } else { - if (keep_burnin) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if (num_burnin > 0) { - if ( - ((i - num_gfr) %% 100 == 0) || - ((i - num_gfr) == num_burnin) - ) { - cat( - "Sampling", - i - num_gfr, - "out of", - num_burnin, - "BART burn-in draws; Chain number ", - chain_num, - "\n" - ) - } - } - if (num_mcmc > 0) { - if ( - ((i - num_gfr - num_burnin) %% 100 == 0) || - (i == num_samples) - ) { - cat( - "Sampling", - i - num_burnin - num_gfr, - "out of", - num_mcmc, - "BART MCMC draws; Chain number ", - chain_num, - "\n" - ) - } - } - } - - if (include_mean_forest) { - if (probit_outcome_model) { - # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict( - forest_dataset_train - ) - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - 
outcome_train$update_data(resid_train - forest_pred) - } - - forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mean, - active_forest = active_forest_mean, - rng = rng, - forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - mean_forest_pred_train[, - sample_counter - ] <- forest_model_mean$get_cached_forest_predictions() - } - } - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - variance_forest_pred_train[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration( - active_forest_mean, - rng, - a_leaf, - b_leaf - ) - current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) { - leaf_scale_samples[sample_counter] <- leaf_scale_double - } - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } - } + if (has_rfx) { + resetRandomEffectsModel( + rfx_model, + rfx_samples, + forest_ind, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } - } - - # Remove GFR samples if they are not to be retained - if ((!keep_gfr) && (num_gfr > 0)) { - for (i in 1:num_gfr) { - if (include_mean_forest) { - forest_samples_mean$delete_sample(0) - } - if (include_variance_forest) { - forest_samples_variance$delete_sample(0) - } - if (has_rfx) { - rfx_samples$delete_sample(0) - } + if (sample_sigma2_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + global_model_config$update_global_error_variance( + current_sigma2 + ) } + } else if (has_prev_model) { if (include_mean_forest) { - mean_forest_pred_train <- mean_forest_pred_train[, - (num_gfr + 1):ncol(mean_forest_pred_train) + resetActiveForest( + active_forest_mean, + previous_forest_samples_mean, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) + if ( + sample_sigma2_leaf && + (!is.null(previous_leaf_var_samples)) + ) { + leaf_scale_double <- previous_leaf_var_samples[ + previous_model_warmstart_sample_num ] + current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } } if (include_variance_forest) { - variance_forest_pred_train <- 
variance_forest_pred_train[, - (num_gfr + 1):ncol(variance_forest_pred_train) - ] + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } - if (sample_sigma2_global) { - global_var_samples <- global_var_samples[ - (num_gfr + 1):length(global_var_samples) - ] + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + ) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) + } else { + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + previous_model_warmstart_sample_num - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) + } } - if (sample_sigma2_leaf) { - leaf_scale_samples <- leaf_scale_samples[ - (num_gfr + 1):length(leaf_scale_samples) + if (sample_sigma2_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[ + previous_model_warmstart_sample_num ] + global_model_config$update_global_error_variance( + current_sigma2 + ) + } } - num_retained_samples <- num_retained_samples - num_gfr - } - - # Mean forest predictions - if (include_mean_forest) { - # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train - y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train - if (has_test) { - y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * - y_std_train + - y_bar_train - } - } - - # Variance forest predictions - if (include_variance_forest) { - # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) - sigma2_x_hat_train <- exp(variance_forest_pred_train) - if (has_test) { - sigma2_x_hat_test <- forest_samples_variance$predict( - forest_dataset_test + } else { + if (include_mean_forest) { + resetActiveForest(active_forest_mean) + active_forest_mean$set_root_leaves( + init_values_mean_forest / num_trees_mean + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) + if (sample_sigma2_leaf) { + current_leaf_scale <- as.matrix(sigma2_leaf_init) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale ) + } + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance) + active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) } - } - - # Random effects predictions - if (has_rfx) { - rfx_preds_train <- rfx_samples$predict( - rfx_group_ids_train, - rfx_basis_train - ) * - y_std_train - y_hat_train <- y_hat_train + rfx_preds_train - } - if ((has_rfx_test) && (has_test)) { - 
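# A minimal sketch of the keep_every thinning rule in the MCMC loop above
# (illustrative only, not part of this diff): within each chain, a post-burnin
# iteration is retained only when its MCMC counter is a multiple of keep_every.
num_gfr <- 0; num_burnin <- 2; num_mcmc <- 3; keep_every <- 2
num_samples <- num_gfr + num_burnin + num_mcmc * keep_every
kept <- sapply((num_gfr + 1):num_samples, function(i) {
  is_mcmc <- i > (num_gfr + num_burnin)
  is_mcmc && ((i - (num_gfr + num_burnin)) %% keep_every == 0)
})
which(kept) # iterations 4, 6, and 8 are retained: num_mcmc = 3 draws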
rfx_preds_test <- rfx_samples$predict( - rfx_group_ids_test, - rfx_basis_test - ) * - y_std_train - y_hat_test <- y_hat_test + rfx_preds_test - } - - # Global error variance - if (sample_sigma2_global) { - sigma2_global_samples <- global_var_samples * (y_std_train^2) - } - - # Leaf parameter variance - if (sample_sigma2_leaf) { - tau_samples <- leaf_scale_samples - } - - # Rescale variance forest prediction by global sigma2 (sampled or constant) - if (include_variance_forest) { if (sample_sigma2_global) { - sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { - sigma2_x_hat_train[, i] * sigma2_global_samples[i] - }) - if (has_test) { - sigma2_x_hat_test <- sapply( - 1:num_retained_samples, - function(i) { - sigma2_x_hat_test[, i] * sigma2_global_samples[i] - } - ) - } + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } + for (i in (num_gfr + 1):num_samples) { + is_mcmc <- i > (num_gfr + num_burnin) + if (is_mcmc) { + mcmc_counter <- i - (num_gfr + num_burnin) + if (mcmc_counter %% keep_every == 0) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } else { - sigma2_x_hat_train <- sigma2_x_hat_train * - sigma2_init * - y_std_train * - y_std_train - if (has_test) { - sigma2_x_hat_test <- sigma2_x_hat_test * - sigma2_init * - y_std_train * - y_std_train + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } + } + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if (num_burnin > 0) { + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_burnin, + "BART burn-in draws; Chain number ", + chain_num, + "\n" + ) + } + } + if (num_mcmc > 0) { + if ( + ((i - num_gfr - num_burnin) %% 100 == 0) || + (i == num_samples) + ) { + cat( + "Sampling", + i - num_burnin - num_gfr, + "out of", + num_mcmc, + "BART MCMC draws; Chain number ", + chain_num, + "\n" + ) } + } } - } - # Return results as a list - model_params <- list( - "sigma2_init" = sigma2_init, - "sigma2_leaf_init" = sigma2_leaf_init, - "a_global" = a_global, - "b_global" = b_global, - "a_leaf" = a_leaf, - "b_leaf" = b_leaf, - "a_forest" = a_forest, - "b_forest" = b_forest, - "outcome_mean" = y_bar_train, - "outcome_scale" = y_std_train, - "standardize" = standardize, - "leaf_dimension" = leaf_dimension, - "is_leaf_constant" = is_leaf_constant, - "leaf_regression" = leaf_regression, - "requires_basis" = requires_basis, - "num_covariates" = ncol(X_train), - "num_basis" = ifelse( - is.null(leaf_basis_train), - 0, - ncol(leaf_basis_train) - ), - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, - "keep_every" = keep_every, - "num_chains" = num_chains, - "has_basis" = !is.null(leaf_basis_train), - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, - "sample_sigma2_global" = sample_sigma2_global, - "sample_sigma2_leaf" = sample_sigma2_leaf, - "include_mean_forest" = include_mean_forest, - "include_variance_forest" = include_variance_forest, - "probit_outcome_model" = probit_outcome_model - ) - result <- list( - "model_params" = model_params, - "train_set_metadata" = X_train_metadata - ) + if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + outcome_pred <- active_forest_mean$predict( + forest_dataset_train + ) + if (has_rfx) { + rfx_pred <- rfx_model$predict( + 
rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - outcome_pred) + } + + forest_model_mean$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = FALSE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() + } + } + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = FALSE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } + } + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + if (sample_sigma2_leaf) { + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf + ) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double + } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) + } + } + } + } + + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && (num_gfr > 0)) { + for (i in 1:num_gfr) { + if (include_mean_forest) { + forest_samples_mean$delete_sample(0) + } + if (include_variance_forest) { + forest_samples_variance$delete_sample(0) + } + if (has_rfx) { + rfx_samples$delete_sample(0) + } + } if (include_mean_forest) { - result[["mean_forests"]] = forest_samples_mean - result[["y_hat_train"]] = y_hat_train - if (has_test) result[["y_hat_test"]] = y_hat_test + mean_forest_pred_train <- mean_forest_pred_train[, + (num_gfr + 1):ncol(mean_forest_pred_train) + ] } if (include_variance_forest) { - result[["variance_forests"]] = forest_samples_variance - result[["sigma2_x_hat_train"]] = sigma2_x_hat_train - if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + variance_forest_pred_train <- variance_forest_pred_train[, + (num_gfr + 1):ncol(variance_forest_pred_train) + ] } if (sample_sigma2_global) { - result[["sigma2_global_samples"]] = sigma2_global_samples + global_var_samples <- global_var_samples[ + (num_gfr + 
1):length(global_var_samples) + ] } if (sample_sigma2_leaf) { - result[["sigma2_leaf_samples"]] = tau_samples - } - if (has_rfx) { - result[["rfx_samples"]] = rfx_samples - result[["rfx_preds_train"]] = rfx_preds_train - result[["rfx_unique_group_ids"]] = levels(group_ids_factor) - } - if ((has_rfx_test) && (has_test)) { - result[["rfx_preds_test"]] = rfx_preds_test + leaf_scale_samples <- leaf_scale_samples[ + (num_gfr + 1):length(leaf_scale_samples) + ] } - class(result) <- "bartmodel" + num_retained_samples <- num_retained_samples - num_gfr + } - # Clean up classes with external pointers to C++ data structures - if (include_mean_forest) { - rm(forest_model_mean) - } - if (include_variance_forest) { - rm(forest_model_variance) - } - rm(forest_dataset_train) + # Mean forest predictions + if (include_mean_forest) { + # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train + y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train if (has_test) { - rm(forest_dataset_test) - } - if (has_rfx) { - rm(rfx_dataset_train, rfx_tracker_train, rfx_model) - } - rm(outcome_train) - rm(rng) - - # Restore global RNG state if user provided a random seed - if (custom_rng) { - .Random.seed <- original_global_seed + y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * + y_std_train + + y_bar_train } + } - return(result) + # Variance forest predictions + if (include_variance_forest) { + # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + sigma2_x_hat_train <- exp(variance_forest_pred_train) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } + } + + # Random effects predictions + if (has_rfx) { + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train + y_hat_train <- y_hat_train + rfx_preds_train + } + if ((has_rfx_test) && (has_test)) { + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train + y_hat_test <- y_hat_test + rfx_preds_test + } + + # Global error variance + if (sample_sigma2_global) { + sigma2_global_samples <- global_var_samples * (y_std_train^2) + } + + # Leaf parameter variance + if (sample_sigma2_leaf) { + tau_samples <- leaf_scale_samples + } + + # Rescale variance forest prediction by global sigma2 (sampled or constant) + if (include_variance_forest) { + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } + } else { + sigma2_x_hat_train <- sigma2_x_hat_train * + sigma2_init * + y_std_train * + y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } + } + } + + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "sigma2_leaf_init" = sigma2_leaf_init, + "a_global" = a_global, + "b_global" = b_global, + "a_leaf" = a_leaf, + "b_leaf" = b_leaf, + "a_forest" = a_forest, + "b_forest" = b_forest, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "standardize" = standardize, + "leaf_dimension" = leaf_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = requires_basis, + "num_covariates" = num_cov_orig, + "num_basis" = ifelse( + 
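# A minimal sketch of the variance rescaling above (illustrative only, not
# part of this diff): variance forest leaves act on the log scale of the
# standardized outcome, so raw predictions are exponentiated and then mapped
# back to the original outcome scale via the global error variance and the
# squared standard deviation of the training outcome. rescale_variance is a
# hypothetical helper name.
rescale_variance <- function(raw_log_pred, sigma2, y_std) {
  exp(raw_log_pred) * sigma2 * y_std^2
}
rescale_variance(raw_log_pred = 0.1, sigma2 = 0.8, y_std = 2.5) # ~5.53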
is.null(leaf_basis_train), + 0, + ncol(leaf_basis_train) + ), + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, + "has_basis" = !is.null(leaf_basis_train), + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf" = sample_sigma2_leaf, + "include_mean_forest" = include_mean_forest, + "include_variance_forest" = include_variance_forest, + "probit_outcome_model" = probit_outcome_model, + "rfx_model_spec" = rfx_model_spec + ) + result <- list( + "model_params" = model_params, + "train_set_metadata" = X_train_metadata + ) + if (include_mean_forest) { + result[["mean_forests"]] = forest_samples_mean + result[["y_hat_train"]] = y_hat_train + if (has_test) result[["y_hat_test"]] = y_hat_test + } + if (include_variance_forest) { + result[["variance_forests"]] = forest_samples_variance + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + } + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf) { + result[["sigma2_leaf_samples"]] = tau_samples + } + if (has_rfx) { + result[["rfx_samples"]] = rfx_samples + result[["rfx_preds_train"]] = rfx_preds_train + result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + } + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } + class(result) <- "bartmodel" + + # Clean up classes with external pointers to C++ data structures + if (include_mean_forest) { + rm(forest_model_mean) + } + if (include_variance_forest) { + rm(forest_model_variance) + } + rm(forest_dataset_train) + if (has_test) { + rm(forest_dataset_test) + } + if (has_rfx) { + rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + } + rm(outcome_train) + rm(rng) + + # Restore global RNG state if user provided a random seed + if (custom_rng) { + .Random.seed <- original_global_seed + } + + return(result) } #' Predict from a sampled BART model on new data #' #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. -#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. +#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. #' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`. #' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". 
If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
+#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
 #' @param ... (Optional) Other prediction parameters.
 #'
-#' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest.
-#' If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (`y_hat`).
+#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested.
 #' @export
 #'
 #' @examples
@@ -1817,162 +1855,320 @@ bart <- function(
 #' num_gfr = 10, num_burnin = 0, num_mcmc = 10)
 #' y_hat_test <- predict(bart_model, X_test)$y_hat
 predict.bartmodel <- function(
-    object,
-    X,
-    leaf_basis = NULL,
-    rfx_group_ids = NULL,
-    rfx_basis = NULL,
-    ...
+  object,
+  covariates,
+  leaf_basis = NULL,
+  rfx_group_ids = NULL,
+  rfx_basis = NULL,
+  type = "posterior",
+  terms = "all",
+  scale = "linear",
+  ...
 ) {
-    # Preprocess covariates
-    if ((!is.data.frame(X)) && (!is.matrix(X))) {
-        stop("X must be a matrix or dataframe")
-    }
-    train_set_metadata <- object$train_set_metadata
-    X <- preprocessPredictionData(X, train_set_metadata)
-
-    # Convert all input data to matrices if not already converted
-    if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) {
-        leaf_basis <- as.matrix(leaf_basis)
-    }
-    if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) {
-        rfx_basis <- as.matrix(rfx_basis)
-    }
-
-    # Data checks
-    if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
-        stop("Basis (leaf_basis) must be provided for this model")
-    }
-    if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) {
-        stop("X and leaf_basis must have the same number of rows")
-    }
-    if (object$model_params$num_covariates != ncol(X)) {
-        stop("X and leaf_basis must have the same number of rows")
-    }
-    if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) {
+  # Handle mean function scale
+  if (!is.character(scale)) {
+    stop("scale must be a string or character vector")
+  }
+  if (!(scale %in% c("linear", "probability"))) {
+    stop("scale must either be 'linear' or 'probability'")
+  }
+  is_probit <- object$model_params$probit_outcome_model
+  if ((scale == "probability") && (!is_probit)) {
+    stop(
+      "scale cannot be 'probability' for models not fit with a probit outcome model"
+    )
+  }
+  probability_scale <- scale == "probability"
+
+  # Handle prediction type
+  if (!is.character(type)) {
+    stop("type must be a string or character vector")
+  }
+  if (!(type %in% c("mean", "posterior"))) {
+    stop("type must either be 'mean' or 'posterior'")
+  }
+  predict_mean <- type == "mean"
+
+  # Handle prediction terms
+  rfx_model_spec <- object$model_params$rfx_model_spec
+  rfx_intercept <- rfx_model_spec == "intercept_only"
+  if (!is.character(terms)) {
+    stop("terms must be a string or character vector")
+  }
+  num_terms <- length(terms)
+  has_mean_forest <- object$model_params$include_mean_forest
+  has_variance_forest <- object$model_params$include_variance_forest
+  has_rfx <- object$model_params$has_rfx
+  has_y_hat <- has_mean_forest || has_rfx
+  predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) ||
+    ((has_y_hat) && ("all" %in% terms)))
+  predict_mean_forest <- (((has_mean_forest) && ("mean_forest" %in% terms)) ||
+    ((has_mean_forest) && ("all" %in% terms)))
+  predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) ||
+    ((has_rfx) && ("all" %in% terms)))
+  predict_variance_forest <- (((has_variance_forest) &&
+    ("variance_forest" %in% terms)) ||
+    ((has_variance_forest) && ("all" %in% terms)))
+  predict_count <- sum(c(
+    predict_y_hat,
+    predict_mean_forest,
+    predict_rfx,
+    predict_variance_forest
+  ))
+  if (predict_count == 0) {
+    warning(paste0(
+      "None of the requested model terms, ",
+      paste(terms, collapse = ", "),
+      ", were fit in this model"
+    ))
+    return(NULL)
+  }
+  predict_rfx_intermediate <- (predict_y_hat && has_rfx)
+  predict_mean_forest_intermediate <- (predict_y_hat && has_mean_forest)
+
+  # Check that we have at least one term to predict on probability scale
+  if (
+    probability_scale &&
+      !predict_y_hat &&
+      !predict_mean_forest &&
+      !predict_rfx
+  ) {
+    stop(
+      "scale can only be 'probability' if at least one mean term is requested"
+    )
+  }
+
+  # Check that covariates are matrix or data frame
+  if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
+    stop("covariates must be a matrix or dataframe")
+  }
+
+  # Convert all input data to matrices if not already converted
+  if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) {
+    leaf_basis <- as.matrix(leaf_basis)
+  }
+  if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) {
+    if (predict_rfx) rfx_basis <- as.matrix(rfx_basis)
+  }
+
+  # Data checks
+  if ((object$model_params$requires_basis) && (is.null(leaf_basis))) {
+    stop("Basis (leaf_basis) must be provided for this model")
+  }
+  if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) {
+    stop("covariates and leaf_basis must have the same number of rows")
+  }
+  if (object$model_params$num_covariates != ncol(covariates)) {
+    stop(
+      "covariates must contain the same number of columns as the BART model's training dataset"
+    )
+  }
+  if ((predict_rfx) && (is.null(rfx_group_ids))) {
+    stop(
+      "Random effect group labels (rfx_group_ids) must be provided for this model"
+    )
+  }
+  if ((predict_rfx) && (is.null(rfx_basis)) && (!rfx_intercept)) {
+    stop("Random effects basis (rfx_basis) must be provided for this model")
+  }
+  if (
+    (object$model_params$num_rfx_basis > 0) &&
+      (!rfx_intercept) &&
+      (!is.null(rfx_basis))
+  ) {
+    if (ncol(rfx_basis) != object$model_params$num_rfx_basis) {
+      stop(
+        "Random effects basis has a different dimension than the basis used to train this model"
+      )
+    }
+  }
+
+  # Preprocess covariates
+  train_set_metadata <- object$train_set_metadata
+  covariates <- preprocessPredictionData(covariates, train_set_metadata)
+
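# A hedged usage sketch of the new `type` / `terms` / `scale` arguments
# (illustrative only, not part of this diff; assumes a fitted `bart_model`
# and covariate matrix `X_test` as in the examples above):
#
#   # Full posterior matrix of combined predictions (the default)
#   y_hat_draws <- predict(bart_model, X_test)$y_hat
#
#   # Posterior-mean predictions from the mean forest alone; a single
#   # requested term is returned directly rather than in a list
#   mean_forest_hat <- predict(bart_model, X_test, type = "mean",
#                              terms = "mean_forest")
#
#   # Posterior-mean event probabilities from a probit outcome model
#   p_hat <- predict(bart_model, X_test, type = "mean",
#                    scale = "probability")$y_hat
+  # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)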
+ has_rfx <- FALSE + if (predict_rfx) { + if (!is.null(rfx_group_ids)) { + rfx_unique_group_ids <- object$rfx_unique_group_ids + group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) + if (sum(is.na(group_ids_factor)) > 0) { stop( - "Random effect group labels (rfx_group_ids) must be provided for this model" + "All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train" ) + } + rfx_group_ids <- as.integer(group_ids_factor) + has_rfx <- TRUE } - if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { - stop("Random effects basis (rfx_basis) must be provided for this model") - } - if ( - (object$model_params$num_rfx_basis > 0) && - (ncol(rfx_basis) != object$model_params$num_rfx_basis) - ) { + } + + # Handle RFX model specification + if (has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { stop( - "Random effects basis has a different dimension than the basis used to train this model" + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) - } - - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE - if (!is.null(rfx_group_ids)) { - rfx_unique_group_ids <- object$rfx_unique_group_ids - group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) - if (sum(is.na(group_ids_factor)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" - ) - } - rfx_group_ids <- as.integer(group_ids_factor) - has_rfx <- TRUE - } - - # Produce basis for the "intercept-only" random effects case - if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) - } - - # Create prediction dataset - if (!is.null(leaf_basis)) { - prediction_dataset <- createForestDataset(X, leaf_basis) + } + } else if (object$model_params$rfx_model_spec == "intercept_only") { + # Only construct a basis if user-provided basis missing + if (is.null(rfx_basis)) { + rfx_basis <- matrix( + rep(1, nrow(covariates)), + nrow = nrow(covariates), + ncol = 1 + ) + } + } + } + + # Create prediction dataset + if (!is.null(leaf_basis)) { + prediction_dataset <- createForestDataset(covariates, leaf_basis) + } else { + prediction_dataset <- createForestDataset(covariates) + } + + # Compute variance forest predictions + if (predict_variance_forest) { + s_x_raw <- object$variance_forests$predict(prediction_dataset) + } + + # Scale variance forest predictions + num_samples <- object$model_params$num_samples + y_std <- object$model_params$outcome_scale + y_bar <- object$model_params$outcome_mean + sigma2_init <- object$model_params$sigma2_init + if (predict_variance_forest) { + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) } else { - prediction_dataset <- createForestDataset(X) - } - - # Compute mean forest predictions - num_samples <- object$model_params$num_samples - y_std <- object$model_params$outcome_scale - y_bar <- object$model_params$outcome_mean - sigma2_init <- object$model_params$sigma2_init - if (object$model_params$include_mean_forest) { - mean_forest_predictions <- object$mean_forests$predict( - prediction_dataset - ) * - y_std + - y_bar - } - - # Compute variance forest predictions - if 
(object$model_params$include_variance_forest) {
-        s_x_raw <- object$variance_forests$predict(prediction_dataset)
-    }
-
-    # Compute rfx predictions (if needed)
-    if (object$model_params$has_rfx) {
-        rfx_predictions <- object$rfx_samples$predict(
-            rfx_group_ids,
-            rfx_basis
-        ) *
-            y_std
-    }
-
-    # Scale variance forest predictions
-    if (object$model_params$include_variance_forest) {
-        if (object$model_params$sample_sigma2_global) {
-            sigma2_global_samples <- object$sigma2_global_samples
-            variance_forest_predictions <- sapply(1:num_samples, function(i) {
-                s_x_raw[, i] * sigma2_global_samples[i]
-            })
-        } else {
-            variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std
-        }
-    }
-
-    if (
-        (object$model_params$include_mean_forest) &&
-            (object$model_params$has_rfx)
-    ) {
-        y_hat <- mean_forest_predictions + rfx_predictions
-    } else if (
-        (object$model_params$include_mean_forest) &&
-            (!object$model_params$has_rfx)
-    ) {
-        y_hat <- mean_forest_predictions
-    } else if (
-        (!object$model_params$include_mean_forest) &&
-            (object$model_params$has_rfx)
-    ) {
-        y_hat <- rfx_predictions
-    }
-
+      variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std
+    }
+    if (predict_mean) {
+      variance_forest_predictions <- rowMeans(variance_forest_predictions)
+    }
+  }
+
+  # Compute mean forest predictions
+  if (predict_mean_forest || predict_mean_forest_intermediate) {
+    mean_forest_predictions <- object$mean_forests$predict(
+      prediction_dataset
+    ) *
+      y_std +
+      y_bar
+  }
+
+  # Compute rfx predictions (if needed)
+  if (predict_rfx || predict_rfx_intermediate) {
+    if (!is.null(rfx_basis)) {
+      rfx_predictions <- object$rfx_samples$predict(
+        rfx_group_ids,
+        rfx_basis
+      ) *
+        y_std
+    } else {
+      # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only"
+      if (!rfx_intercept) {
+        stop(
+          "rfx_basis must be provided for random effects models with random slopes"
+        )
+      }
+
+      # Extract the raw RFX samples and scale by train set outcome standard deviation
+      rfx_param_list <- object$rfx_samples$extract_parameter_samples()
+      rfx_beta_draws <- rfx_param_list$beta_samples * y_std
+
+      # Construct a matrix with the appropriate group random effects arranged for each observation
+      rfx_predictions_raw <- array(
+        NA,
+        dim = c(
+          nrow(covariates),
+          ncol(rfx_basis),
+          object$model_params$num_samples
+        )
+      )
+      for (i in 1:nrow(covariates)) {
+        rfx_predictions_raw[i, , ] <-
+          rfx_beta_draws[, rfx_group_ids[i], ]
+      }
+
+      # Intercept-only model, so the random effect prediction is simply the
+      # value of the respective group's intercept coefficient for each observation
+      rfx_predictions <- rfx_predictions_raw[, 1, ]
+    }
+  }
+
+  # Combine into y hat predictions
+  if (probability_scale) {
+    if (predict_y_hat && has_mean_forest && has_rfx) {
+      y_hat <- pnorm(mean_forest_predictions + rfx_predictions)
+      mean_forest_predictions <- pnorm(mean_forest_predictions)
+      rfx_predictions <- pnorm(rfx_predictions)
+    } else if (predict_y_hat && has_mean_forest) {
+      y_hat <- pnorm(mean_forest_predictions)
+      mean_forest_predictions <- pnorm(mean_forest_predictions)
+    } else if (predict_y_hat && has_rfx) {
+      y_hat <- pnorm(rfx_predictions)
+      rfx_predictions <- pnorm(rfx_predictions)
+    }
+  } else {
+    if (predict_y_hat && has_mean_forest && has_rfx) {
+      y_hat <- mean_forest_predictions + rfx_predictions
+    } else if (predict_y_hat && has_mean_forest) {
+      y_hat <- mean_forest_predictions
+    } else if (predict_y_hat && has_rfx) {
+      y_hat <- rfx_predictions
+    }
+  }
+
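# A minimal numerical note on the probability scale used above (illustrative
# only, not part of this diff): pnorm() is applied to each posterior draw
# before any averaging, so type = "mean" reports the posterior mean of
# P(y = 1 | x) rather than pnorm() of the posterior mean.
f_draws <- c(-2, 0.5, 1) # stand-in posterior draws of a latent mean f(x)
mean(pnorm(f_draws)) # ~0.519, matches what the code above computes
pnorm(mean(f_draws)) # ~0.434, a different quantity
+
+  # Collapse to posterior mean predictions if requested
+  if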
(predict_mean) { + if (predict_mean_forest) { + mean_forest_predictions <- rowMeans(mean_forest_predictions) + } + if (predict_rfx) { + rfx_predictions <- rowMeans(rfx_predictions) + } + if (predict_y_hat) { + y_hat <- rowMeans(y_hat) + } + } + + if (predict_count == 1) { + if (predict_y_hat) { + return(y_hat) + } else if (predict_mean_forest) { + return(mean_forest_predictions) + } else if (predict_rfx) { + return(rfx_predictions) + } else if (predict_variance_forest) { + return(variance_forest_predictions) + } + } else { result <- list() - if ( - (object$model_params$has_rfx) || - (object$model_params$include_mean_forest) - ) { - result[["y_hat"]] = y_hat + if (predict_y_hat) { + result[["y_hat"]] = y_hat } else { - result[["y_hat"]] <- NULL + result[["y_hat"]] <- NULL } - if (object$model_params$include_mean_forest) { - result[["mean_forest_predictions"]] = mean_forest_predictions + if (predict_mean_forest) { + result[["mean_forest_predictions"]] = mean_forest_predictions } else { - result[["mean_forest_predictions"]] <- NULL + result[["mean_forest_predictions"]] <- NULL } - if (object$model_params$has_rfx) { - result[["rfx_predictions"]] = rfx_predictions + if (predict_rfx) { + result[["rfx_predictions"]] = rfx_predictions } else { - result[["rfx_predictions"]] <- NULL + result[["rfx_predictions"]] <- NULL } - if (object$model_params$include_variance_forest) { - result[["variance_forest_predictions"]] = variance_forest_predictions + if (predict_variance_forest) { + result[["variance_forest_predictions"]] = variance_forest_predictions } else { - result[["variance_forest_predictions"]] <- NULL + result[["variance_forest_predictions"]] <- NULL } return(result) + } } #' Extract raw sample values for each of the random effect parameter terms. @@ -2024,26 +2220,26 @@ predict.bartmodel <- function( #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' rfx_samples <- getRandomEffectSamples(bart_model) getRandomEffectSamples.bartmodel <- function(object, ...) { - result = list() + result = list() - if (!object$model_params$has_rfx) { - warning("This model has no RFX terms, returning an empty list") - return(result) - } + if (!object$model_params$has_rfx) { + warning("This model has no RFX terms, returning an empty list") + return(result) + } - # Extract the samples - result <- object$rfx_samples$extract_parameter_samples() + # Extract the samples + result <- object$rfx_samples$extract_parameter_samples() - # Scale by sd(y_train) - result$beta_samples <- result$beta_samples * - object$model_params$outcome_scale - result$xi_samples <- result$xi_samples * object$model_params$outcome_scale - result$alpha_samples <- result$alpha_samples * - object$model_params$outcome_scale - result$sigma_samples <- result$sigma_samples * - (object$model_params$outcome_scale^2) + # Scale by sd(y_train) + result$beta_samples <- result$beta_samples * + object$model_params$outcome_scale + result$xi_samples <- result$xi_samples * object$model_params$outcome_scale + result$alpha_samples <- result$alpha_samples * + object$model_params$outcome_scale + result$sigma_samples <- result$sigma_samples * + (object$model_params$outcome_scale^2) - return(result) + return(result) } #' Convert the persistent aspects of a BART model to (in-memory) JSON @@ -2078,132 +2274,136 @@ getRandomEffectSamples.bartmodel <- function(object, ...) 
#' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJson(bart_model) saveBARTModelToJson <- function(object) { - jsonobj <- createCppJson() - - if (!inherits(object, "bartmodel")) { - stop("`object` must be a BART model") - } - - if (is.null(object$model_params)) { - stop("This BCF model has not yet been sampled") - } - - # Add the forests - if (object$model_params$include_mean_forest) { - jsonobj$add_forest(object$mean_forests) - } - if (object$model_params$include_variance_forest) { - jsonobj$add_forest(object$variance_forests) - } - - # Add metadata - jsonobj$add_scalar( - "num_numeric_vars", - object$train_set_metadata$num_numeric_vars + jsonobj <- createCppJson() + + if (!inherits(object, "bartmodel")) { + stop("`object` must be a BART model") + } + + if (is.null(object$model_params)) { + stop("This BART model has not yet been sampled") + } + + # Add the forests + if (object$model_params$include_mean_forest) { + jsonobj$add_forest(object$mean_forests) + } + if (object$model_params$include_variance_forest) { + jsonobj$add_forest(object$variance_forests) + } + + # Add metadata + jsonobj$add_scalar( + "num_numeric_vars", + object$train_set_metadata$num_numeric_vars + ) + jsonobj$add_scalar( + "num_ordered_cat_vars", + object$train_set_metadata$num_ordered_cat_vars + ) + jsonobj$add_scalar( + "num_unordered_cat_vars", + object$train_set_metadata$num_unordered_cat_vars + ) + if (object$train_set_metadata$num_numeric_vars > 0) { + jsonobj$add_string_vector( + "numeric_vars", + object$train_set_metadata$numeric_vars ) - jsonobj$add_scalar( - "num_ordered_cat_vars", - object$train_set_metadata$num_ordered_cat_vars + } + if (object$train_set_metadata$num_ordered_cat_vars > 0) { + jsonobj$add_string_vector( + "ordered_cat_vars", + object$train_set_metadata$ordered_cat_vars ) - jsonobj$add_scalar( - "num_unordered_cat_vars", - object$train_set_metadata$num_unordered_cat_vars + jsonobj$add_string_list( + "ordered_unique_levels", + object$train_set_metadata$ordered_unique_levels ) - if (object$train_set_metadata$num_numeric_vars > 0) { - jsonobj$add_string_vector( - "numeric_vars", - object$train_set_metadata$numeric_vars - ) - } - if (object$train_set_metadata$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector( - "ordered_cat_vars", - object$train_set_metadata$ordered_cat_vars - ) - jsonobj$add_string_list( - "ordered_unique_levels", - object$train_set_metadata$ordered_unique_levels - ) - } - if (object$train_set_metadata$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector( - "unordered_cat_vars", - object$train_set_metadata$unordered_cat_vars - ) - jsonobj$add_string_list( - "unordered_unique_levels", - object$train_set_metadata$unordered_unique_levels - ) - } - - # Add global parameters - jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) - jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) - jsonobj$add_boolean("standardize", object$model_params$standardize) - jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) - jsonobj$add_boolean( - "sample_sigma2_global", - object$model_params$sample_sigma2_global + } + if (object$train_set_metadata$num_unordered_cat_vars > 0) { + jsonobj$add_string_vector( + "unordered_cat_vars", + object$train_set_metadata$unordered_cat_vars ) - jsonobj$add_boolean( - "sample_sigma2_leaf", - object$model_params$sample_sigma2_leaf + jsonobj$add_string_list( + "unordered_unique_levels", + object$train_set_metadata$unordered_unique_levels ) - jsonobj$add_boolean( -
"include_mean_forest", - object$model_params$include_mean_forest + } + + # Add global parameters + jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) + jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) + jsonobj$add_boolean("standardize", object$model_params$standardize) + jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) + jsonobj$add_boolean( + "sample_sigma2_global", + object$model_params$sample_sigma2_global + ) + jsonobj$add_boolean( + "sample_sigma2_leaf", + object$model_params$sample_sigma2_leaf + ) + jsonobj$add_boolean( + "include_mean_forest", + object$model_params$include_mean_forest + ) + jsonobj$add_boolean( + "include_variance_forest", + object$model_params$include_variance_forest + ) + jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) + jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) + jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) + jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) + jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) + jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) + jsonobj$add_scalar("num_samples", object$model_params$num_samples) + jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) + jsonobj$add_scalar("num_basis", object$model_params$num_basis) + jsonobj$add_scalar("num_chains", object$model_params$num_chains) + jsonobj$add_scalar("keep_every", object$model_params$keep_every) + jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) + jsonobj$add_boolean( + "probit_outcome_model", + object$model_params$probit_outcome_model + ) + jsonobj$add_string( + "rfx_model_spec", + object$model_params$rfx_model_spec + ) + if (object$model_params$sample_sigma2_global) { + jsonobj$add_vector( + "sigma2_global_samples", + object$sigma2_global_samples, + "parameters" ) - jsonobj$add_boolean( - "include_variance_forest", - object$model_params$include_variance_forest + } + if (object$model_params$sample_sigma2_leaf) { + jsonobj$add_vector( + "sigma2_leaf_samples", + object$sigma2_leaf_samples, + "parameters" ) - jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) - jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) - jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) - jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) - jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) - jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) - jsonobj$add_scalar("num_samples", object$model_params$num_samples) - jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) - jsonobj$add_scalar("num_basis", object$model_params$num_basis) - jsonobj$add_scalar("num_chains", object$model_params$num_chains) - jsonobj$add_scalar("keep_every", object$model_params$keep_every) - jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) - jsonobj$add_boolean( - "probit_outcome_model", - object$model_params$probit_outcome_model + } + + # Add random effects (if present) + if (object$model_params$has_rfx) { + jsonobj$add_random_effects(object$rfx_samples) + jsonobj$add_string_vector( + "rfx_unique_group_ids", + object$rfx_unique_group_ids ) - if (object$model_params$sample_sigma2_global) { - jsonobj$add_vector( - "sigma2_global_samples", - object$sigma2_global_samples, - "parameters" - ) - } - if (object$model_params$sample_sigma2_leaf) { - jsonobj$add_vector( - "sigma2_leaf_samples", - 
object$sigma2_leaf_samples, - "parameters" - ) - } - - # Add random effects (if present) - if (object$model_params$has_rfx) { - jsonobj$add_random_effects(object$rfx_samples) - jsonobj$add_string_vector( - "rfx_unique_group_ids", - object$rfx_unique_group_ids - ) - } + } - # Add covariate preprocessor metadata - preprocessor_metadata_string <- savePreprocessorToJsonString( - object$train_set_metadata - ) - jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) + # Add covariate preprocessor metadata + preprocessor_metadata_string <- savePreprocessorToJsonString( + object$train_set_metadata + ) + jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) - return(jsonobj) + return(jsonobj) } #' Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file @@ -2241,11 +2441,11 @@ saveBARTModelToJson <- function(object) { #' saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) #' unlink(tmpjson) saveBARTModelToJsonFile <- function(object, filename) { - # Convert to Json - jsonobj <- saveBARTModelToJson(object) + # Convert to Json + jsonobj <- saveBARTModelToJson(object) - # Save to file - jsonobj$save_file(filename) + # Save to file + jsonobj$save_file(filename) } #' Convert the persistent aspects of a BART model to (in-memory) JSON string @@ -2279,11 +2479,11 @@ saveBARTModelToJsonFile <- function(object, filename) { #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json_string <- saveBARTModelToJsonString(bart_model) saveBARTModelToJsonString <- function(object) { - # Convert to Json - jsonobj <- saveBARTModelToJson(object) + # Convert to Json + jsonobj <- saveBARTModelToJson(object) - # Dump to string - return(jsonobj$return_json_string()) + # Dump to string + return(jsonobj$return_json_string()) } #' Convert an (in-memory) JSON representation of a BART model to a BART model object @@ -2320,138 +2520,141 @@ saveBARTModelToJsonString <- function(object) { #' bart_json <- saveBARTModelToJson(bart_model) #' bart_model_roundtrip <- createBARTModelFromJson(bart_json) createBARTModelFromJson <- function(json_object) { - # Initialize the BCF model - output <- list() - - # Unpack the forests - include_mean_forest <- json_object$get_boolean("include_mean_forest") - include_variance_forest <- json_object$get_boolean( - "include_variance_forest" + # Initialize the BART model + output <- list() + + # Unpack the forests + include_mean_forest <- json_object$get_boolean("include_mean_forest") + include_variance_forest <- json_object$get_boolean( + "include_variance_forest" + ) + if (include_mean_forest) { + output[["mean_forests"]] <- loadForestContainerJson( + json_object, + "forest_0" ) - if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerJson( - json_object, - "forest_0" - ) - if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerJson( - json_object, - "forest_1" - ) - } - } else { - output[["variance_forests"]] <- loadForestContainerJson( - json_object, - "forest_0" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( - "num_numeric_vars" - ) - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( - "num_ordered_cat_vars" + if (include_variance_forest) { + output[["variance_forests"]] <- loadForestContainerJson( + json_object, + "forest_1" + ) + } + } else { + output[["variance_forests"]] <- loadForestContainerJson( + json_object, + "forest_0" ) - 
train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( - "num_unordered_cat_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( + "num_ordered_cat_vars" + ) + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( + "num_unordered_cat_vars" + ) + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" ) - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( - "numeric_vars" - ) - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] - ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] - ) - } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object$get_boolean("standardize") - model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") - model_params[["sample_sigma2_global"]] <- json_object$get_boolean( - "sample_sigma2_global" + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) - model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean( - "sample_sigma2_leaf" + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] ) - model_params[["include_mean_forest"]] <- include_mean_forest - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - model_params[["num_basis"]] <- json_object$get_scalar("num_basis") - model_params[["num_chains"]] <- json_object$get_scalar("num_chains") - model_params[["keep_every"]] <- 
json_object$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object$get_boolean( - "requires_basis" + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["standardize"]] <- json_object$get_boolean("standardize") + model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean( + "sample_sigma2_leaf" + ) + model_params[["include_mean_forest"]] <- include_mean_forest + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + model_params[["num_basis"]] <- json_object$get_scalar("num_basis") + model_params[["num_chains"]] <- json_object$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object$get_scalar("keep_every") + model_params[["requires_basis"]] <- json_object$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) + model_params[["rfx_model_spec"]] <- json_object$get_string( + "rfx_model_spec" + ) + + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - model_params[["probit_outcome_model"]] <- json_object$get_boolean( - "probit_outcome_model" + } + if (model_params[["sample_sigma2_leaf"]]) { + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" ) + } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } - if (model_params[["sample_sigma2_leaf"]]) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector( - "sigma2_leaf_samples", - "parameters" - ) - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( - "rfx_unique_group_ids" - ) - output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) - } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( + "rfx_unique_group_ids" ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) - - class(output) <- "bartmodel" - return(output) + output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) + } + + # Unpack covariate 
preprocessor + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bartmodel" + return(output) } #' Convert a JSON file containing sample information on a trained BART model @@ -2490,13 +2693,13 @@ createBARTModelFromJson <- function(json_object) { #' bart_model_roundtrip <- createBARTModelFromJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createBARTModelFromJsonFile <- function(json_filename) { - # Load a `CppJson` object from file - bart_json <- createCppJsonFile(json_filename) + # Load a `CppJson` object from file + bart_json <- createCppJsonFile(json_filename) - # Create and return the BART object - bart_object <- createBARTModelFromJson(bart_json) + # Create and return the BART object + bart_object <- createBARTModelFromJson(bart_json) - return(bart_object) + return(bart_object) } #' Convert a JSON string containing sample information on a trained BART model @@ -2534,13 +2737,13 @@ createBARTModelFromJsonFile <- function(json_filename) { #' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) #' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) createBARTModelFromJsonString <- function(json_string) { - # Load a `CppJson` object from string - bart_json <- createCppJsonString(json_string) + # Load a `CppJson` object from string + bart_json <- createCppJsonString(json_string) - # Create and return the BART object - bart_object <- createBARTModelFromJson(bart_json) + # Create and return the BART object + bart_object <- createBARTModelFromJson(bart_json) - return(bart_object) + return(bart_object) } #' Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object @@ -2577,202 +2780,205 @@ createBARTModelFromJsonString <- function(json_string) { #' bart_json <- list(saveBARTModelToJson(bart_model)) #' bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) createBARTModelFromCombinedJson <- function(json_object_list) { - # Initialize the BCF model - output <- list() - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - include_mean_forest <- json_object_default$get_boolean( - "include_mean_forest" + # Initialize the BART model + output <- list() + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + include_mean_forest <- json_object_default$get_boolean( + "include_mean_forest" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_mean_forest) { + output[["mean_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" + if (include_variance_forest) { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + } + } 
else { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["sigma2_init"]] <- json_object_default$get_scalar( + "sigma2_init" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf" + ) + model_params[["include_mean_forest"]] <- include_mean_forest + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["requires_basis"]] <- json_object_default$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- 
json_object_default$get_boolean( + "probit_outcome_model" + ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["sigma2_init"]] <- json_object_default$get_scalar( - "sigma2_init" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf" - ) - model_params[["include_mean_forest"]] <- include_mean_forest - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") - model_params[["requires_basis"]] <- json_object_default$get_boolean( - "requires_basis" - ) - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") 
- model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector( - "sigma2_leaf_samples", - "parameters" + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" ) - } else { - output[["sigma2_leaf_samples"]] <- c( - output[["sigma2_leaf_samples"]], - json_object$get_vector("sigma2_leaf_samples", "parameters") - ) - } - } - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 + } else { + output[["sigma2_leaf_samples"]] <- c( + output[["sigma2_leaf_samples"]], + json_object$get_vector("sigma2_leaf_samples", "parameters") + ) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 ) - } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bartmodel" + return(output) } #' Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object @@ -2809,207 +3015,210 @@ createBARTModelFromCombinedJson <- function(json_object_list) {
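+ # createBARTModelFromCombinedJsonString() below mirrors
+ # createBARTModelFromCombinedJson(), but accepts JSON strings rather than
+ # in-memory JSON objects, which is convenient when independent chains are
+ # sampled on parallel workers. A minimal sketch (bart_model_1 and
+ # bart_model_2 are hypothetical fits of the same model specification):
+ #   json_string_list <- list(
+ #     saveBARTModelToJsonString(bart_model_1),
+ #     saveBARTModelToJsonString(bart_model_2)
+ #   )
+ #   combined_model <- createBARTModelFromCombinedJsonString(json_string_list)
+ #   # num_gfr, num_burnin, num_mcmc, and num_samples are summed across inputs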
#' bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) #' bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) createBARTModelFromCombinedJsonString <- function(json_string_list) { - # Initialize the BCF model - output <- list() - - # Convert JSON strings - json_object_list <- list() - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - json_object_list[[i]] <- createCppJsonString(json_string) - } - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - include_mean_forest <- json_object_default$get_boolean( - "include_mean_forest" + # Initialize the BART model + output <- list() + + # Convert JSON strings + json_object_list <- list() + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + json_object_list[[i]] <- createCppJsonString(json_string) + } + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + include_mean_forest <- json_object_default$get_boolean( + "include_mean_forest" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_mean_forest) { + output[["mean_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" + if (include_variance_forest) { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + } + } else { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" - ) - } - } else { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- 
json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["sigma2_init"]] <- json_object_default$get_scalar( + "sigma2_init" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf" + ) + model_params[["include_mean_forest"]] <- include_mean_forest + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["requires_basis"]] <- json_object_default$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- 
json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["sigma2_init"]] <- json_object_default$get_scalar( - "sigma2_init" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf" - ) - model_params[["include_mean_forest"]] <- include_mean_forest - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object_default$get_boolean( - "requires_basis" - ) - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( 
- "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector( - "sigma2_leaf_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_samples"]] <- c( - output[["sigma2_leaf_samples"]], - json_object$get_vector("sigma2_leaf_samples", "parameters") - ) - } - } - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" ) - } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + } else { + output[["sigma2_leaf_samples"]] <- c( + output[["sigma2_leaf_samples"]], + json_object$get_vector("sigma2_leaf_samples", "parameters") + ) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 ) - - class(output) <- "bartmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bartmodel" + return(output) } diff --git a/R/bcf.R b/R/bcf.R index 2d1564fa..e5e18e65 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -47,12 +47,6 @@ #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. #' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`. -#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." 
Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. -#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. #' #' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. @@ -103,6 +97,16 @@ #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. #' +#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +#' +#' - `model_spec` Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. +#' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +#' - `group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +#' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +#' - `group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +#' - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. +#' - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. 
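+#'
+#' A minimal sketch of supplying this list to `bcf()` (assuming simulated `X_train`, `Z_train`, `y_train`, `pi_train`, and `rfx_group_ids_train` objects like those in the examples below; only `model_spec` and the variance prior are overridden, all other entries keep their internally processed defaults):
+#'
+#' ```r
+#' bcf_model <- bcf(
+#'   X_train = X_train, Z_train = Z_train, y_train = y_train,
+#'   propensity_train = pi_train,
+#'   rfx_group_ids_train = rfx_group_ids_train,
+#'   num_gfr = 10, num_burnin = 0, num_mcmc = 10,
+#'   random_effects_params = list(
+#'     model_spec = "intercept_only",
+#'     variance_prior_shape = 2,
+#'     variance_prior_scale = 1
+#'   )
+#' )
+#' ```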
+#' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). #' @export #' @@ -153,2414 +157,2479 @@ #' propensity_test = pi_test, num_gfr = 10, #' num_burnin = 0, num_mcmc = 10) bcf <- function( - X_train, - Z_train, - y_train, - propensity_train = NULL, - rfx_group_ids_train = NULL, - rfx_basis_train = NULL, - X_test = NULL, - Z_test = NULL, - propensity_test = NULL, - rfx_group_ids_test = NULL, - rfx_basis_test = NULL, - num_gfr = 5, - num_burnin = 0, - num_mcmc = 100, - previous_model_json = NULL, - previous_model_warmstart_sample_num = NULL, - general_params = list(), - prognostic_forest_params = list(), - treatment_effect_forest_params = list(), - variance_forest_params = list() + X_train, + Z_train, + y_train, + propensity_train = NULL, + rfx_group_ids_train = NULL, + rfx_basis_train = NULL, + X_test = NULL, + Z_test = NULL, + propensity_test = NULL, + rfx_group_ids_test = NULL, + rfx_basis_test = NULL, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + previous_model_json = NULL, + previous_model_warmstart_sample_num = NULL, + general_params = list(), + prognostic_forest_params = list(), + treatment_effect_forest_params = list(), + variance_forest_params = list(), + random_effects_params = list() ) { - # Update general BCF parameters - general_params_default <- list( - cutpoint_grid_size = 100, - standardize = TRUE, - sample_sigma2_global = TRUE, - sigma2_global_init = NULL, - sigma2_global_shape = 0, - sigma2_global_scale = 0, - variable_weights = NULL, - propensity_covariate = "mu", - adaptive_coding = TRUE, - control_coding_init = -0.5, - treated_coding_init = 0.5, - rfx_prior_var = NULL, - random_seed = -1, - keep_burnin = FALSE, - keep_gfr = FALSE, - keep_every = 1, - num_chains = 1, - verbose = FALSE, - probit_outcome_model = FALSE, - rfx_working_parameter_prior_mean = NULL, - rfx_group_parameter_prior_mean = NULL, - rfx_working_parameter_prior_cov = NULL, - rfx_group_parameter_prior_cov = NULL, - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1, - num_threads = -1 - ) - general_params_updated <- preprocessParams( - general_params_default, - general_params - ) - - # Update mu forest BCF parameters - prognostic_forest_params_default <- list( - num_trees = 250, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - sample_sigma2_leaf = TRUE, - sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, - sigma2_leaf_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - prognostic_forest_params_updated <- preprocessParams( - prognostic_forest_params_default, - prognostic_forest_params - ) - - # Update tau forest BCF parameters - treatment_effect_forest_params_default <- list( - num_trees = 50, - alpha = 0.25, - beta = 3.0, - min_samples_leaf = 5, - max_depth = 5, - sample_sigma2_leaf = FALSE, - sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, - sigma2_leaf_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - delta_max = 0.9, - num_features_subsample = NULL - ) - treatment_effect_forest_params_updated <- preprocessParams( - treatment_effect_forest_params_default, - treatment_effect_forest_params - ) - - # Update variance forest BCF parameters - variance_forest_params_default <- list( - num_trees = 0, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - leaf_prior_calibration_param = 1.5, - variance_forest_init = NULL, - var_forest_prior_shape = NULL, - var_forest_prior_scale = NULL, - 
keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - variance_forest_params_updated <- preprocessParams( - variance_forest_params_default, - variance_forest_params + # Update general BCF parameters + general_params_default <- list( + cutpoint_grid_size = 100, + standardize = TRUE, + sample_sigma2_global = TRUE, + sigma2_global_init = NULL, + sigma2_global_shape = 0, + sigma2_global_scale = 0, + variable_weights = NULL, + propensity_covariate = "mu", + adaptive_coding = TRUE, + control_coding_init = -0.5, + treated_coding_init = 0.5, + rfx_prior_var = NULL, + random_seed = -1, + keep_burnin = FALSE, + keep_gfr = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, + probit_outcome_model = FALSE, + num_threads = -1 + ) + general_params_updated <- preprocessParams( + general_params_default, + general_params + ) + + # Update mu forest BCF parameters + prognostic_forest_params_default <- list( + num_trees = 250, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + sample_sigma2_leaf = TRUE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + prognostic_forest_params_updated <- preprocessParams( + prognostic_forest_params_default, + prognostic_forest_params + ) + + # Update tau forest BCF parameters + treatment_effect_forest_params_default <- list( + num_trees = 50, + alpha = 0.25, + beta = 3.0, + min_samples_leaf = 5, + max_depth = 5, + sample_sigma2_leaf = FALSE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + delta_max = 0.9, + num_features_subsample = NULL + ) + treatment_effect_forest_params_updated <- preprocessParams( + treatment_effect_forest_params_default, + treatment_effect_forest_params + ) + + # Update variance forest BCF parameters + variance_forest_params_default <- list( + num_trees = 0, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + leaf_prior_calibration_param = 1.5, + variance_forest_init = NULL, + var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + variance_forest_params_updated <- preprocessParams( + variance_forest_params_default, + variance_forest_params + ) + + # Update random effects parameters + rfx_params_default <- list( + model_spec = "custom", + working_parameter_prior_mean = NULL, + group_parameter_prior_mean = NULL, + working_parameter_prior_cov = NULL, + group_parameter_prior_cov = NULL, + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + rfx_params_updated <- preprocessParams( + rfx_params_default, + random_effects_params + ) + + ### Unpack all parameter values + # 1. 
General parameters + cutpoint_grid_size <- general_params_updated$cutpoint_grid_size + standardize <- general_params_updated$standardize + sample_sigma2_global <- general_params_updated$sample_sigma2_global + sigma2_init <- general_params_updated$sigma2_global_init + a_global <- general_params_updated$sigma2_global_shape + b_global <- general_params_updated$sigma2_global_scale + variable_weights <- general_params_updated$variable_weights + propensity_covariate <- general_params_updated$propensity_covariate + adaptive_coding <- general_params_updated$adaptive_coding + b_0 <- general_params_updated$control_coding_init + b_1 <- general_params_updated$treated_coding_init + rfx_prior_var <- general_params_updated$rfx_prior_var + random_seed <- general_params_updated$random_seed + keep_burnin <- general_params_updated$keep_burnin + keep_gfr <- general_params_updated$keep_gfr + keep_every <- general_params_updated$keep_every + num_chains <- general_params_updated$num_chains + verbose <- general_params_updated$verbose + probit_outcome_model <- general_params_updated$probit_outcome_model + num_threads <- general_params_updated$num_threads + + # 2. Mu forest parameters + num_trees_mu <- prognostic_forest_params_updated$num_trees + alpha_mu <- prognostic_forest_params_updated$alpha + beta_mu <- prognostic_forest_params_updated$beta + min_samples_leaf_mu <- prognostic_forest_params_updated$min_samples_leaf + max_depth_mu <- prognostic_forest_params_updated$max_depth + sample_sigma2_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init + a_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_shape + b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale + keep_vars_mu <- prognostic_forest_params_updated$keep_vars + drop_vars_mu <- prognostic_forest_params_updated$drop_vars + num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample + + # 3. Tau forest parameters + num_trees_tau <- treatment_effect_forest_params_updated$num_trees + alpha_tau <- treatment_effect_forest_params_updated$alpha + beta_tau <- treatment_effect_forest_params_updated$beta + min_samples_leaf_tau <- treatment_effect_forest_params_updated$min_samples_leaf + max_depth_tau <- treatment_effect_forest_params_updated$max_depth + sample_sigma2_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init + a_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_shape + b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale + keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars + drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars + delta_max <- treatment_effect_forest_params_updated$delta_max + num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample + + # 4. 
Variance forest parameters + num_trees_variance <- variance_forest_params_updated$num_trees + alpha_variance <- variance_forest_params_updated$alpha + beta_variance <- variance_forest_params_updated$beta + min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf + max_depth_variance <- variance_forest_params_updated$max_depth + a_0 <- variance_forest_params_updated$leaf_prior_calibration_param + variance_forest_init <- variance_forest_params_updated$variance_forest_init + a_forest <- variance_forest_params_updated$var_forest_prior_shape + b_forest <- variance_forest_params_updated$var_forest_prior_scale + keep_vars_variance <- variance_forest_params_updated$keep_vars + drop_vars_variance <- variance_forest_params_updated$drop_vars + num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample + + # 5. Random effects parameters + rfx_model_spec <- rfx_params_updated$model_spec + rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean + rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean + rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov + rfx_group_parameter_prior_cov <- rfx_params_updated$group_parameter_prior_cov + rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape + rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale + + # Handle random effects specification + if (!is.character(rfx_model_spec)) { + stop("rfx_model_spec must be a string or character vector") + } + if ( + !(rfx_model_spec %in% + c("custom", "intercept_only", "intercept_plus_treatment")) + ) { + stop( + "rfx_model_spec must either be 'custom', 'intercept_only', or 'intercept_plus_treatment'" )
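+ # Random effects prior settings are now read from the dedicated
+ # `random_effects_params` list rather than from `general_params`.
+ # Illustrative override (hypothetical values, sketching intended usage):
+ #   random_effects_params = list(
+ #       model_spec = "intercept_only",
+ #       variance_prior_shape = 2,
+ #       variance_prior_scale = 0.5
+ #   )
- - ### Unpack all parameter values - # 1. General parameters - cutpoint_grid_size <- general_params_updated$cutpoint_grid_size - standardize <- general_params_updated$standardize - sample_sigma2_global <- general_params_updated$sample_sigma2_global - sigma2_init <- general_params_updated$sigma2_global_init - a_global <- general_params_updated$sigma2_global_shape - b_global <- general_params_updated$sigma2_global_scale - variable_weights <- general_params_updated$variable_weights - propensity_covariate <- general_params_updated$propensity_covariate - adaptive_coding <- general_params_updated$adaptive_coding - b_0 <- general_params_updated$control_coding_init - b_1 <- general_params_updated$treated_coding_init - rfx_prior_var <- general_params_updated$rfx_prior_var - random_seed <- general_params_updated$random_seed - keep_burnin <- general_params_updated$keep_burnin - keep_gfr <- general_params_updated$keep_gfr - keep_every <- general_params_updated$keep_every - num_chains <- general_params_updated$num_chains - verbose <- general_params_updated$verbose - probit_outcome_model <- general_params_updated$probit_outcome_model - rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean - rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean - rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov - rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov - rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape - rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale - num_threads <- general_params_updated$num_threads - - # 2.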
Mu forest parameters - num_trees_mu <- prognostic_forest_params_updated$num_trees - alpha_mu <- prognostic_forest_params_updated$alpha - beta_mu <- prognostic_forest_params_updated$beta - min_samples_leaf_mu <- prognostic_forest_params_updated$min_samples_leaf - max_depth_mu <- prognostic_forest_params_updated$max_depth - sample_sigma2_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf - sigma2_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init - a_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_shape - b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale - keep_vars_mu <- prognostic_forest_params_updated$keep_vars - drop_vars_mu <- prognostic_forest_params_updated$drop_vars - num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample - - # 3. Tau forest parameters - num_trees_tau <- treatment_effect_forest_params_updated$num_trees - alpha_tau <- treatment_effect_forest_params_updated$alpha - beta_tau <- treatment_effect_forest_params_updated$beta - min_samples_leaf_tau <- treatment_effect_forest_params_updated$min_samples_leaf - max_depth_tau <- treatment_effect_forest_params_updated$max_depth - sample_sigma2_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf - sigma2_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init - a_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_shape - b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale - keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars - drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars - delta_max <- treatment_effect_forest_params_updated$delta_max - num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample - - # 4. 
Variance forest parameters - num_trees_variance <- variance_forest_params_updated$num_trees - alpha_variance <- variance_forest_params_updated$alpha - beta_variance <- variance_forest_params_updated$beta - min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf - max_depth_variance <- variance_forest_params_updated$max_depth - a_0 <- variance_forest_params_updated$leaf_prior_calibration_param - variance_forest_init <- variance_forest_params_updated$init_root_val - a_forest <- variance_forest_params_updated$var_forest_prior_shape - b_forest <- variance_forest_params_updated$var_forest_prior_scale - keep_vars_variance <- variance_forest_params_updated$keep_vars - drop_vars_variance <- variance_forest_params_updated$drop_vars - num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - - # Set a function-scoped RNG if user provided a random seed - custom_rng <- random_seed >= 0 - if (custom_rng) { - # Store original global environment RNG state - original_global_seed <- .Random.seed - # Set new seed and store associated RNG state - set.seed(random_seed) - function_scoped_seed <- .Random.seed - } - - # Check if there are enough GFR samples to seed num_chains samplers - if (num_gfr > 0) { - if (num_chains > num_gfr) { - stop( - "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" - ) - } - } - - # Override keep_gfr if there are no MCMC samples - if (num_mcmc == 0) { - keep_gfr <- TRUE - } - - # Check if previous model JSON is provided and parse it if so - has_prev_model <- !is.null(previous_model_json) - if (has_prev_model) { - previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) - previous_y_bar <- previous_bcf_model$model_params$outcome_mean - previous_y_scale <- previous_bcf_model$model_params$outcome_scale - previous_forest_samples_mu <- previous_bcf_model$forests_mu - previous_forest_samples_tau <- previous_bcf_model$forests_tau - if (previous_bcf_model$model_params$include_variance_forest) { - previous_forest_samples_variance <- previous_bcf_model$forests_variance - } else { - previous_forest_samples_variance <- NULL - } - if (previous_bcf_model$model_params$sample_sigma2_global) { - previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / - (previous_y_scale * previous_y_scale) - } else { - previous_global_var_samples <- NULL - } - if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) { - previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples - } else { - previous_leaf_var_mu_samples <- NULL - } - if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) { - previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples - } else { - previous_leaf_var_tau_samples <- NULL - } - if (previous_bcf_model$model_params$has_rfx) { - previous_rfx_samples <- previous_bcf_model$rfx_samples - } else { - previous_rfx_samples <- NULL - } - if (previous_bcf_model$model_params$adaptive_coding) { - previous_b_1_samples <- previous_bcf_model$b_1_samples - previous_b_0_samples <- previous_bcf_model$b_0_samples - } else { - previous_b_1_samples <- NULL - previous_b_0_samples <- NULL - } - previous_model_num_samples <- previous_bcf_model$model_params$num_samples - if (previous_model_warmstart_sample_num > previous_model_num_samples) { - stop( - "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" - ) - } + } + + # Set a function-scoped RNG if user provided a random seed + custom_rng <- 
random_seed >= 0 + if (custom_rng) { + # Store original global environment RNG state + original_global_seed <- .Random.seed + # Set new seed and store associated RNG state + set.seed(random_seed) + function_scoped_seed <- .Random.seed + } + + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop( + "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" + ) + } + } + + # Override keep_gfr if there are no MCMC samples + if (num_mcmc == 0) { + keep_gfr <- TRUE + } + + # Check if previous model JSON is provided and parse it if so + has_prev_model <- !is.null(previous_model_json) + if (has_prev_model) { + previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) + previous_y_bar <- previous_bcf_model$model_params$outcome_mean + previous_y_scale <- previous_bcf_model$model_params$outcome_scale + previous_forest_samples_mu <- previous_bcf_model$forests_mu + previous_forest_samples_tau <- previous_bcf_model$forests_tau + if (previous_bcf_model$model_params$include_variance_forest) { + previous_forest_samples_variance <- previous_bcf_model$forests_variance } else { - previous_y_bar <- NULL - previous_y_scale <- NULL - previous_global_var_samples <- NULL - previous_leaf_var_mu_samples <- NULL - previous_leaf_var_tau_samples <- NULL - previous_rfx_samples <- NULL - previous_forest_samples_mu <- NULL - previous_forest_samples_tau <- NULL - previous_forest_samples_variance <- NULL - previous_b_1_samples <- NULL - previous_b_0_samples <- NULL + previous_forest_samples_variance <- NULL } - - # Determine whether conditional variance will be modeled - if (num_trees_variance > 0) { - include_variance_forest = TRUE + if (previous_bcf_model$model_params$sample_sigma2_global) { + previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / + (previous_y_scale * previous_y_scale) } else { - include_variance_forest = FALSE + previous_global_var_samples <- NULL } - - # Set the variance forest priors if not set - if (include_variance_forest) { - if (is.null(a_forest)) { - a_forest <- num_trees_variance / (a_0^2) + 0.5 - } - if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) { + previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples } else { - a_forest <- 1. - b_forest <- 1. 
- } - - # Variable weight preprocessing (and initialization if necessary) - if (is.null(variable_weights)) { - variable_weights = rep(1 / ncol(X_train), ncol(X_train)) - } - if (any(variable_weights < 0)) { - stop("variable_weights cannot have any negative weights") - } - - # Check covariates are matrix or dataframe - if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { - stop("X_train must be a matrix or dataframe") - } - if (!is.null(X_test)) { - if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { - stop("X_test must be a matrix or dataframe") - } - } - num_cov_orig <- ncol(X_train) - - # Check delta_max is valid - if ((delta_max <= 0) || (delta_max >= 1)) { - stop("delta_max must be > 0 and < 1") + previous_leaf_var_mu_samples <- NULL } - - # Standardize the keep variable lists to numeric indices - if (!is.null(keep_vars_mu)) { - if (is.character(keep_vars_mu)) { - if (!all(keep_vars_mu %in% names(X_train))) { - stop( - "keep_vars_mu includes some variable names that are not in X_train" - ) - } - variable_subset_mu <- unname(which( - names(X_train) %in% keep_vars_mu - )) - } else { - if (any(keep_vars_mu > ncol(X_train))) { - stop( - "keep_vars_mu includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_mu < 0)) { - stop("keep_vars_mu includes some negative variable indices") - } - variable_subset_mu <- keep_vars_mu - } - } else if ((is.null(keep_vars_mu)) && (!is.null(drop_vars_mu))) { - if (is.character(drop_vars_mu)) { - if (!all(drop_vars_mu %in% names(X_train))) { - stop( - "drop_vars_mu includes some variable names that are not in X_train" - ) - } - variable_subset_mu <- unname(which( - !(names(X_train) %in% drop_vars_mu) - )) - } else { - if (any(drop_vars_mu > ncol(X_train))) { - stop( - "drop_vars_mu includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_mu < 0)) { - stop("drop_vars_mu includes some negative variable indices") - } - variable_subset_mu <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_mu) - ] - } + if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) { + previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples } else { - variable_subset_mu <- 1:ncol(X_train) + previous_leaf_var_tau_samples <- NULL } - if (!is.null(keep_vars_tau)) { - if (is.character(keep_vars_tau)) { - if (!all(keep_vars_tau %in% names(X_train))) { - stop( - "keep_vars_tau includes some variable names that are not in X_train" - ) - } - variable_subset_tau <- unname(which( - names(X_train) %in% keep_vars_tau - )) - } else { - if (any(keep_vars_tau > ncol(X_train))) { - stop( - "keep_vars_tau includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_tau < 0)) { - stop("keep_vars_tau includes some negative variable indices") - } - variable_subset_tau <- keep_vars_tau - } - } else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) { - if (is.character(drop_vars_tau)) { - if (!all(drop_vars_tau %in% names(X_train))) { - stop( - "drop_vars_tau includes some variable names that are not in X_train" - ) - } - variable_subset_tau <- unname(which( - !(names(X_train) %in% drop_vars_tau) - )) - } else { - if (any(drop_vars_tau > ncol(X_train))) { - stop( - "drop_vars_tau includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_tau < 0)) { - stop("drop_vars_tau includes some negative variable indices") - } - variable_subset_tau <- (1:ncol(X_train))[ - !(1:ncol(X_train) 
%in% drop_vars_tau) - ] - } + if (previous_bcf_model$model_params$has_rfx) { + previous_rfx_samples <- previous_bcf_model$rfx_samples } else { - variable_subset_tau <- 1:ncol(X_train) + previous_rfx_samples <- NULL } - if (!is.null(keep_vars_variance)) { - if (is.character(keep_vars_variance)) { - if (!all(keep_vars_variance %in% names(X_train))) { - stop( - "keep_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - names(X_train) %in% keep_vars_variance - )) - } else { - if (any(keep_vars_variance > ncol(X_train))) { - stop( - "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_variance < 0)) { - stop( - "keep_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- keep_vars_variance - } - } else if ( - (is.null(keep_vars_variance)) && (!is.null(drop_vars_variance)) - ) { - if (is.character(drop_vars_variance)) { - if (!all(drop_vars_variance %in% names(X_train))) { - stop( - "drop_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - !(names(X_train) %in% drop_vars_variance) - )) - } else { - if (any(drop_vars_variance > ncol(X_train))) { - stop( - "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_variance < 0)) { - stop( - "drop_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_variance) - ] - } + if (previous_bcf_model$model_params$adaptive_coding) { + previous_b_1_samples <- previous_bcf_model$b_1_samples + previous_b_0_samples <- previous_bcf_model$b_0_samples } else { - variable_subset_variance <- 1:ncol(X_train) - } - - # Preprocess covariates - if (ncol(X_train) != length(variable_weights)) { - stop("length(variable_weights) must equal ncol(X_train)") - } - train_cov_preprocess_list <- preprocessTrainData(X_train) - X_train_metadata <- train_cov_preprocess_list$metadata - X_train_raw <- X_train - X_train <- train_cov_preprocess_list$data - original_var_indices <- X_train_metadata$original_var_indices - feature_types <- X_train_metadata$feature_types - X_test_raw <- X_test - if (!is.null(X_test)) { - X_test <- preprocessPredictionData(X_test, X_train_metadata) - } - - # Convert all input data to matrices if not already converted - Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train)) - Z_train <- matrix(as.numeric(Z_train), ncol = Z_col) - if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { - propensity_train <- as.matrix(propensity_train) - } - if (!is.null(Z_test)) { - Z_test <- matrix(as.numeric(Z_test), ncol = Z_col) - } - if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { - propensity_test <- as.matrix(propensity_test) - } - if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { - rfx_basis_train <- as.matrix(rfx_basis_train) - } - if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { - rfx_basis_test <- as.matrix(rfx_basis_test) - } - - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
- has_rfx <- FALSE - has_rfx_test <- FALSE - if (!is.null(rfx_group_ids_train)) { - group_ids_factor <- factor(rfx_group_ids_train) - rfx_group_ids_train <- as.integer(group_ids_factor) - has_rfx <- TRUE - if (!is.null(rfx_group_ids_test)) { - group_ids_factor_test <- factor( - rfx_group_ids_test, - levels = levels(group_ids_factor) - ) - if (sum(is.na(group_ids_factor_test)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" - ) - } - rfx_group_ids_test <- as.integer(group_ids_factor_test) - has_rfx_test <- TRUE - } - } - - # Check that outcome and treatment are numeric - if (!is.numeric(y_train)) { - stop("y_train must be numeric") - } - if (!is.numeric(Z_train)) { - stop("Z_train must be numeric") - } - if (!is.null(Z_test)) { - if (!is.numeric(Z_test)) stop("Z_test must be numeric") - } - - # Data consistency checks - if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { - stop("X_train and X_test must have the same number of columns") - } - if ((!is.null(Z_test)) && (ncol(Z_test) != ncol(Z_train))) { - stop("Z_train and Z_test must have the same number of columns") - } - if ((!is.null(Z_train)) && (nrow(Z_train) != nrow(X_train))) { - stop("Z_train and X_train must have the same number of rows") - } - if ( - (!is.null(propensity_train)) && - (nrow(propensity_train) != nrow(X_train)) - ) { - stop("propensity_train and X_train must have the same number of rows") - } - if ((!is.null(Z_test)) && (nrow(Z_test) != nrow(X_test))) { - stop("Z_test and X_test must have the same number of rows") - } - if ( - (!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test)) - ) { - stop("propensity_test and X_test must have the same number of rows") - } - if (nrow(X_train) != length(y_train)) { - stop("X_train and y_train must have the same number of observations") - } - if ( - (!is.null(rfx_basis_test)) && - (ncol(rfx_basis_test) != ncol(rfx_basis_train)) - ) { + previous_b_1_samples <- NULL + previous_b_0_samples <- NULL + } + previous_model_num_samples <- previous_bcf_model$model_params$num_samples + if (previous_model_warmstart_sample_num > previous_model_num_samples) { + stop( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) + } + } else { + previous_y_bar <- NULL + previous_y_scale <- NULL + previous_global_var_samples <- NULL + previous_leaf_var_mu_samples <- NULL + previous_leaf_var_tau_samples <- NULL + previous_rfx_samples <- NULL + previous_forest_samples_mu <- NULL + previous_forest_samples_tau <- NULL + previous_forest_samples_variance <- NULL + previous_b_1_samples <- NULL + previous_b_0_samples <- NULL + } + + # Determine whether conditional variance will be modeled + if (num_trees_variance > 0) { + include_variance_forest = TRUE + } else { + include_variance_forest = FALSE + } + + # Set the variance forest priors if not set + if (include_variance_forest) { + if (is.null(a_forest)) { + a_forest <- num_trees_variance / (a_0^2) + 0.5 + } + if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + } else { + a_forest <- 1. + b_forest <- 1. 
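+ # Note (illustrative, not package output): these are inert placeholder values
+ # when no variance forest is sampled. When a variance forest is used, the
+ # defaults above scale with ensemble size: e.g., num_trees_variance = 50 with
+ # the default a_0 = 1.5 gives a_forest = 50 / 1.5^2 + 0.5 ~ 22.72 and
+ # b_forest = 50 / 1.5^2 ~ 22.22.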
+ } + + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1 / ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Check covariates are matrix or dataframe + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)) { + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + num_cov_orig <- ncol(X_train) + + # Check delta_max is valid + if ((delta_max <= 0) || (delta_max >= 1)) { + stop("delta_max must be > 0 and < 1") + } + + # Standardize the keep variable lists to numeric indices + if (!is.null(keep_vars_mu)) { + if (is.character(keep_vars_mu)) { + if (!all(keep_vars_mu %in% names(X_train))) { stop( - "rfx_basis_train and rfx_basis_test must have the same number of columns" + "keep_vars_mu includes some variable names that are not in X_train" ) + } + variable_subset_mu <- unname(which( + names(X_train) %in% keep_vars_mu + )) + } else { + if (any(keep_vars_mu > ncol(X_train))) { + stop( + "keep_vars_mu includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_mu < 0)) { + stop("keep_vars_mu includes some negative variable indices") + } + variable_subset_mu <- keep_vars_mu + } + } else if ((is.null(keep_vars_mu)) && (!is.null(drop_vars_mu))) { + if (is.character(drop_vars_mu)) { + if (!all(drop_vars_mu %in% names(X_train))) { + stop( + "drop_vars_mu includes some variable names that are not in X_train" + ) + } + variable_subset_mu <- unname(which( + !(names(X_train) %in% drop_vars_mu) + )) + } else { + if (any(drop_vars_mu > ncol(X_train))) { + stop( + "drop_vars_mu includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_mu < 0)) { + stop("drop_vars_mu includes some negative variable indices") + } + variable_subset_mu <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_mu) + ] + } + } else { + variable_subset_mu <- 1:ncol(X_train) + } + if (!is.null(keep_vars_tau)) { + if (is.character(keep_vars_tau)) { + if (!all(keep_vars_tau %in% names(X_train))) { + stop( + "keep_vars_tau includes some variable names that are not in X_train" + ) + } + variable_subset_tau <- unname(which( + names(X_train) %in% keep_vars_tau + )) + } else { + if (any(keep_vars_tau > ncol(X_train))) { + stop( + "keep_vars_tau includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_tau < 0)) { + stop("keep_vars_tau includes some negative variable indices") + } + variable_subset_tau <- keep_vars_tau + } + } else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) { + if (is.character(drop_vars_tau)) { + if (!all(drop_vars_tau %in% names(X_train))) { + stop( + "drop_vars_tau includes some variable names that are not in X_train" + ) + } + variable_subset_tau <- unname(which( + !(names(X_train) %in% drop_vars_tau) + )) + } else { + if (any(drop_vars_tau > ncol(X_train))) { + stop( + "drop_vars_tau includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_tau < 0)) { + stop("drop_vars_tau includes some negative variable indices") + } + variable_subset_tau <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_tau) + ] + } + } else { + variable_subset_tau <- 1:ncol(X_train) + } + if (!is.null(keep_vars_variance)) { + if 
(is.character(keep_vars_variance)) { + if (!all(keep_vars_variance %in% names(X_train))) { + stop( + "keep_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + names(X_train) %in% keep_vars_variance + )) + } else { + if (any(keep_vars_variance > ncol(X_train))) { + stop( + "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_variance < 0)) { + stop( + "keep_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- keep_vars_variance } - if (!is.null(rfx_group_ids_train)) { - if (!is.null(rfx_group_ids_test)) { - if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { - stop( - "rfx_basis_train is provided but rfx_basis_test is not provided" - ) - } - } - } - - # # Stop if multivariate treatment is provided - # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported") - - # Handle multivariate treatment - has_multivariate_treatment <- ncol(Z_train) > 1 - if (has_multivariate_treatment) { - # Disable adaptive coding, internal propensity model, and - # leaf scale sampling if treatment is multivariate - if (adaptive_coding) { - warning( - "Adaptive coding is incompatible with multivariate treatment and will be ignored" - ) - adaptive_coding <- FALSE - } - if (is.null(propensity_train)) { - if (propensity_covariate != "none") { - warning( - "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" - ) - propensity_covariate <- "none" - } - } - if (sample_sigma2_leaf_tau) { - warning( - "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." 
- ) - sample_sigma2_leaf_tau <- FALSE - } - } - - # Random effects covariance prior - if (has_rfx) { - if (is.null(rfx_prior_var)) { - rfx_prior_var <- rep(1, ncol(rfx_basis_train)) - } else { - if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { - stop("rfx_prior_var must be a numeric vector") - } - if (length(rfx_prior_var) != ncol(rfx_basis_train)) { - stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") - } - } - } - - # Update variable weights - variable_weights_adj <- 1 / - sapply(original_var_indices, function(x) sum(original_var_indices == x)) - variable_weights <- variable_weights[original_var_indices] * - variable_weights_adj - - # Create mu and tau (and variance) specific variable weights with weights zeroed out for excluded variables - variable_weights_variance <- variable_weights_tau <- variable_weights_mu <- variable_weights - variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0 - variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0 - if (include_variance_forest) { - variable_weights_variance[ - !(original_var_indices %in% variable_subset_variance) - ] <- 0 - } - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx <- FALSE - num_basis_rfx <- 0 - if (has_rfx) { - if (is.null(rfx_basis_train)) { - rfx_basis_train <- matrix( - rep(1, nrow(X_train)), - nrow = nrow(X_train), - ncol = 1 - ) - } else { - has_basis_rfx <- TRUE - num_basis_rfx <- ncol(rfx_basis_train) - } - num_rfx_groups <- length(unique(rfx_group_ids_train)) - num_rfx_components <- ncol(rfx_basis_train) - if (num_rfx_groups == 1) { - warning( - "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" - ) - } - } - if (has_rfx_test) { - if (is.null(rfx_basis_test)) { - if (!is.null(rfx_basis_train)) { - stop( - "Random effects basis provided for training set, must also be provided for the test set" - ) - } - rfx_basis_test <- matrix( - rep(1, nrow(X_test)), - nrow = nrow(X_test), - ncol = 1 - ) - } - } - - # Check that number of samples are all nonnegative - stopifnot(num_gfr >= 0) - stopifnot(num_burnin >= 0) - stopifnot(num_mcmc >= 0) - - # Determine whether a test set is provided - has_test = !is.null(X_test) - - # Convert y_train to numeric vector if not already converted - if (!is.null(dim(y_train))) { - y_train <- as.matrix(y_train) + } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) { + if (is.character(drop_vars_variance)) { + if (!all(drop_vars_variance %in% names(X_train))) { + stop( + "drop_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + !(names(X_train) %in% drop_vars_variance) + )) + } else { + if (any(drop_vars_variance > ncol(X_train))) { + stop( + "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_variance < 0)) { + stop( + "drop_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_variance) + ] + } + } else { + variable_subset_variance <- 1:ncol(X_train) + } + + # Preprocess covariates + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train_raw <- X_train + X_train <- 
train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + X_test_raw <- X_test + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + } + + # Convert all input data to matrices if not already converted + Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train)) + Z_train <- matrix(as.numeric(Z_train), ncol = Z_col) + if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { + propensity_train <- as.matrix(propensity_train) + } + if (!is.null(Z_test)) { + Z_test <- matrix(as.numeric(Z_test), ncol = Z_col) + } + if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- as.matrix(propensity_test) + } + if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { + rfx_basis_train <- as.matrix(rfx_basis_train) + } + if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { + rfx_basis_test <- as.matrix(rfx_basis_test) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) + has_rfx <- FALSE + has_rfx_test <- FALSE + if (!is.null(rfx_group_ids_train)) { + group_ids_factor <- factor(rfx_group_ids_train) + rfx_group_ids_train <- as.integer(group_ids_factor) + has_rfx <- TRUE + if (!is.null(rfx_group_ids_test)) { + group_ids_factor_test <- factor( + rfx_group_ids_test, + levels = levels(group_ids_factor) + ) + if (sum(is.na(group_ids_factor_test)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + } + rfx_group_ids_test <- as.integer(group_ids_factor_test) + has_rfx_test <- TRUE + } + } + + # Check that outcome and treatment are numeric + if (!is.numeric(y_train)) { + stop("y_train must be numeric") + } + if (!is.numeric(Z_train)) { + stop("Z_train must be numeric") + } + if (!is.null(Z_test)) { + if (!is.numeric(Z_test)) stop("Z_test must be numeric") + } + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if ((!is.null(Z_test)) && (ncol(Z_test) != ncol(Z_train))) { + stop("Z_train and Z_test must have the same number of columns") + } + if ((!is.null(Z_train)) && (nrow(Z_train) != nrow(X_train))) { + stop("Z_train and X_train must have the same number of rows") + } + if ( + (!is.null(propensity_train)) && + (nrow(propensity_train) != nrow(X_train)) + ) { + stop("propensity_train and X_train must have the same number of rows") + } + if ((!is.null(Z_test)) && (nrow(Z_test) != nrow(X_test))) { + stop("Z_test and X_test must have the same number of rows") + } + if ((!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test))) { + stop("propensity_test and X_test must have the same number of rows") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + if ( + (!is.null(rfx_basis_test)) && + (ncol(rfx_basis_test) != ncol(rfx_basis_train)) + ) { + stop( + "rfx_basis_train and rfx_basis_test must have the same number of columns" + ) + } + if (!is.null(rfx_group_ids_train)) { + if (!is.null(rfx_group_ids_test)) { + if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { + stop( + "rfx_basis_train is provided but rfx_basis_test is not provided" + ) + } } + } - # Check whether treatment is binary (specifically 0-1 binary) - binary_treatment <- length(unique(Z_train)) == 2 - if (binary_treatment) { - 
unique_treatments <- sort(unique(Z_train)) - if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE - } + # # Stop if multivariate treatment is provided + # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported") - # Adaptive coding will be ignored for continuous / ordered categorical treatments - if ((!binary_treatment) && (adaptive_coding)) { - adaptive_coding <- FALSE + # Handle multivariate treatment + has_multivariate_treatment <- ncol(Z_train) > 1 + if (has_multivariate_treatment) { + # Disable adaptive coding, internal propensity model, and + # leaf scale sampling if treatment is multivariate + if (adaptive_coding) { + warning( + "Adaptive coding is incompatible with multivariate treatment and will be ignored" + ) + adaptive_coding <- FALSE + } + if (is.null(propensity_train)) { + if (propensity_covariate != "none") { + warning( + "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted and propensity_covariate will be set to 'none'" + ) + propensity_covariate <- "none" + } } - - # Check if propensity_covariate is one of the required inputs - if (!(propensity_covariate %in% c("mu", "tau", "both", "none"))) { + if (sample_sigma2_leaf_tau) { + warning( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." + ) + sample_sigma2_leaf_tau <- FALSE + } + } + + # Update variable weights + variable_weights_adj <- 1 / + sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices] * + variable_weights_adj + + # Create mu and tau (and variance) specific variable weights with weights zeroed out for excluded variables + variable_weights_variance <- variable_weights_tau <- variable_weights_mu <- variable_weights + variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0 + variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0 + if (include_variance_forest) { + variable_weights_variance[ + !(original_var_indices %in% variable_subset_variance) + ] <- 0 + } + + # Handle the rfx basis matrices + has_basis_rfx <- FALSE + num_basis_rfx <- 0 + if (has_rfx) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_train)) { stop( - "propensity_covariate must equal one of 'none', 'mu', 'tau', or 'both'" + "A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'" ) - } - - # Estimate if pre-estimated propensity score is not provided - internal_propensity_model <- FALSE - if ((is.null(propensity_train)) && (propensity_covariate != "none")) { - internal_propensity_model <- TRUE - # Estimate using the last of several iterations of GFR BART - num_burnin <- 10 - num_total <- 50 - bart_model_propensity <- bart( - X_train = X_train, - y_train = as.numeric(Z_train), - X_test = X_test_raw, - num_gfr = num_total, - num_burnin = 0, - num_mcmc = 0 + } + has_basis_rfx <- TRUE + num_basis_rfx <- ncol(rfx_basis_train) + } else if (rfx_model_spec == "intercept_only") { + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + } else if (rfx_model_spec == "intercept_plus_treatment") { + rfx_basis_train <- cbind( + rep(1, nrow(X_train)), + Z_train + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + ncol(Z_train) + }
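+ # The dispatch above mirrors the documented `model_spec` options for
+ # `random_effects_params`: "custom" requires a user-supplied `rfx_basis_train`,
+ # "intercept_only" generates a basis of all ones internally, and
+ # "intercept_plus_treatment" binds that intercept column to `Z_train`.
+ # Illustrative call (hypothetical objects X, y, Z, group_ids):
+ #   bcf_model <- bcf(
+ #       X_train = X, y_train = y, Z_train = Z,
+ #       rfx_group_ids_train = group_ids,
+ #       random_effects_params = list(model_spec = "intercept_only")
+ #   )
+ num_rfx_groups <-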
length(unique(rfx_group_ids_train)) + num_rfx_components <- ncol(rfx_basis_train) + if (num_rfx_groups == 1) { + warning( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) + } + } + if (has_rfx_test) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_test)) { + stop( + "A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'" ) - propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, - (num_burnin + 1):num_total - ]) - if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { - propensity_train <- as.matrix(propensity_train) - } - if (has_test) { - propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, - (num_burnin + 1):num_total - ]) - if ( - (is.null(dim(propensity_test))) && (!is.null(propensity_test)) - ) { - propensity_test <- as.matrix(propensity_test) - } - } + } + } else if (rfx_model_spec == "intercept_only") { + rfx_basis_test <- matrix( + rep(1, nrow(X_test)), + nrow = nrow(X_test), + ncol = 1 + ) + } else if (rfx_model_spec == "intercept_plus_treatment") { + rfx_basis_test <- cbind( + rep(1, nrow(X_test)), + Z_test + ) + } + } + + # Random effects covariance prior + if (has_rfx) { + if (is.null(rfx_prior_var)) { + rfx_prior_var <- rep(1, ncol(rfx_basis_train)) + } else { + if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { + stop("rfx_prior_var must be a numeric vector") + } + if (length(rfx_prior_var) != ncol(rfx_basis_train)) { + stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") + } + } + } + + # Check that number of samples are all nonnegative + stopifnot(num_gfr >= 0) + stopifnot(num_burnin >= 0) + stopifnot(num_mcmc >= 0) + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Check whether treatment is binary (specifically 0-1 binary) + binary_treatment <- length(unique(Z_train)) == 2 + if (binary_treatment) { + unique_treatments <- sort(unique(Z_train)) + if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + } + + # Adaptive coding will be ignored for continuous / ordered categorical treatments + if ((!binary_treatment) && (adaptive_coding)) { + adaptive_coding <- FALSE + } + + # Check if propensity_covariate is one of the required inputs + if (!(propensity_covariate %in% c("mu", "tau", "both", "none"))) { + stop( + "propensity_covariate must equal one of 'none', 'mu', 'tau', or 'both'" + ) + } + + # Estimate if pre-estimated propensity score is not provided + internal_propensity_model <- FALSE + if ((is.null(propensity_train)) && (propensity_covariate != "none")) { + internal_propensity_model <- TRUE + # Estimate using the last of several iterations of GFR BART + num_burnin <- 10 + num_total <- 50 + bart_model_propensity <- bart( + X_train = X_train, + y_train = as.numeric(Z_train), + X_test = X_test_raw, + num_gfr = num_total, + num_burnin = 0, + num_mcmc = 0 + ) + propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, + (num_burnin + 1):num_total + ]) + if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { + propensity_train <- as.matrix(propensity_train) } - if (has_test) { - if (is.null(propensity_test)) { - stop( - "Propensity score must be provided for the test set if provided for the training set" - ) - } - } - - # Update feature_types and covariates - feature_types <- 
as.integer(feature_types) - if (propensity_covariate != "none") { - feature_types <- as.integer(c( - feature_types, - rep(0, ncol(propensity_train)) - )) - X_train <- cbind(X_train, propensity_train) - if (propensity_covariate == "mu") { - variable_weights_mu <- c( - variable_weights_mu, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - variable_weights_tau <- c( - variable_weights_tau, - rep(0, ncol(propensity_train)) - ) - if (include_variance_forest) { - variable_weights_variance <- c( - variable_weights_variance, - rep(0, ncol(propensity_train)) - ) - } - } else if (propensity_covariate == "tau") { - variable_weights_mu <- c( - variable_weights_mu, - rep(0, ncol(propensity_train)) - ) - variable_weights_tau <- c( - variable_weights_tau, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - if (include_variance_forest) { - variable_weights_variance <- c( - variable_weights_variance, - rep(0, ncol(propensity_train)) - ) - } - } else if (propensity_covariate == "both") { - variable_weights_mu <- c( - variable_weights_mu, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - variable_weights_tau <- c( - variable_weights_tau, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - if (include_variance_forest) { - variable_weights_variance <- c( - variable_weights_variance, - rep(0, ncol(propensity_train)) - ) - } - } - if (has_test) X_test <- cbind(X_test, propensity_test) + propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, + (num_burnin + 1):num_total + ]) + if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- as.matrix(propensity_test) + } + } + } + + if (has_test) { + if (is.null(propensity_test)) { + stop( + "Propensity score must be provided for the test set if provided for the training set" + ) + } + } + + # Update feature_types and covariates + feature_types <- as.integer(feature_types) + if (propensity_covariate != "none") { + feature_types <- as.integer(c( + feature_types, + rep(0, ncol(propensity_train)) + )) + X_train <- cbind(X_train, propensity_train) + if (propensity_covariate == "mu") { + variable_weights_mu <- c( + variable_weights_mu, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(0, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } + } else if (propensity_covariate == "tau") { + variable_weights_mu <- c( + variable_weights_mu, + rep(0, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } + } else if (propensity_covariate == "both") { + variable_weights_mu <- c( + variable_weights_mu, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(1. 
/ num_cov_orig, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } + } + if (has_test) X_test <- cbind(X_test, propensity_test) + } + + # Renormalize variable weights + variable_weights_mu <- variable_weights_mu / sum(variable_weights_mu) + variable_weights_tau <- variable_weights_tau / sum(variable_weights_tau) + if (include_variance_forest) { + variable_weights_variance <- variable_weights_variance / + sum(variable_weights_variance) + } + + # Set num_features_subsample to default, ncol(X_train), if not already set + if (is.null(num_features_subsample_mu)) { + num_features_subsample_mu <- ncol(X_train) + } + if (is.null(num_features_subsample_tau)) { + num_features_subsample_tau <- ncol(X_train) + } + if (is.null(num_features_subsample_variance)) { + num_features_subsample_variance <- ncol(X_train) + } + + # Preliminary runtime checks for probit link + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0, 1)))) { + stop( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) } - - # Renormalize variable weights - variable_weights_mu <- variable_weights_mu / sum(variable_weights_mu) - variable_weights_tau <- variable_weights_tau / sum(variable_weights_tau) if (include_variance_forest) { - variable_weights_variance <- variable_weights_variance / - sum(variable_weights_variance) - } - - # Set num_features_subsample to default, ncol(X_train), if not already set - if (is.null(num_features_subsample_mu)) { - num_features_subsample_mu <- ncol(X_train) - } - if (is.null(num_features_subsample_tau)) { - num_features_subsample_tau <- ncol(X_train) - } - if (is.null(num_features_subsample_variance)) { - num_features_subsample_variance <- ncol(X_train) + stop("We do not support heteroskedasticity with a probit link") } - - # Preliminary runtime checks for probit link - if (probit_outcome_model) { - if (!(length(unique(y_train)) == 2)) { - stop( - "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" - ) - } - unique_outcomes <- sort(unique(y_train)) - if (!(all(unique_outcomes == c(0, 1)))) { - stop( - "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" - ) - } - if (include_variance_forest) { - stop("We do not support heteroskedasticity with a probit link") - } - if (sample_sigma2_global) { - warning( - "Global error variance will not be sampled with a probit link as it is fixed at 1" - ) - sample_sigma2_global <- F - } + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global <- F } + } - # Handle standardization, prior calibration, and initialization of forest - # differently for binary and continuous outcomes - if (probit_outcome_model) { - # Compute a probit-scale offset and fix scale to 1 - y_bar_train <- qnorm(mean(y_train)) - y_std_train <- 1 - - # Set a pseudo outcome by subtracting mean(y_train) from y_train - resid_train <- y_train - mean(y_train) - - # Set initial value for the mu forest - init_mu <- 0.0 - - # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau - # Set sigma2_init to 1, ignoring any defaults 
provided - sigma2_init <- 1.0 - # Skip variance_forest_init, since variance forests are not supported with probit link - if (is.null(b_leaf_mu)) { - b_leaf_mu <- 1 / num_trees_mu - } - if (is.null(b_leaf_tau)) { - b_leaf_tau <- 1 / (2 * num_trees_tau) - } - if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2 / (num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - if (!is.matrix(sigma2_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - current_leaf_scale_mu <- sigma2_leaf_mu - } - } - if (is.null(sigma2_leaf_tau)) { - # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p - # Use p = 0.9 as an internal default rather than adding another - # user-facing "parameter" of the binary outcome BCF prior. - # Can be overriden by specifying `sigma2_leaf_init` in - # treatment_effect_forest_params. - p <- 0.6827 - q_quantile <- qnorm((p + 1) / 2) - sigma2_leaf_tau <- ((delta_max / (q_quantile * dnorm(0)))^2) / - num_trees_tau - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - current_leaf_scale_tau <- sigma2_leaf_tau - } - } - current_sigma2 <- sigma2_init + # Runtime checks for variance forest + if (include_variance_forest) { + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled when a variance forest is used to model heteroskedasticity" + ) + sample_sigma2_global <- F + } + } + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- qnorm(mean(y_train)) + y_std_train <- 1 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train <- y_train - mean(y_train) + + # Set initial value for the mu forest + init_mu <- 0.0 + + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau + # Set sigma2_init to 1, ignoring any defaults provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + if (is.null(b_leaf_mu)) { + b_leaf_mu <- 1 / num_trees_mu + } + if (is.null(b_leaf_tau)) { + b_leaf_tau <- 1 / (2 * num_trees_tau) + } + if (is.null(sigma2_leaf_mu)) { + sigma2_leaf_mu <- 2 / (num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu)
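+ # Note (illustrative): sigma2_leaf_mu = 2 / num_trees_mu fixes the prior
+ # variance of the prognostic forest's summed leaf contribution at 2 on the
+ # latent probit scale regardless of ensemble size; the default
+ # num_trees_mu = 250 gives a per-leaf scale of 0.008.
} else { - # Only standardize if user requested - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - } else { - y_bar_train <- 0 - y_std_train <- 1 - } - - # Compute standardized outcome - resid_train <- (y_train - y_bar_train) / y_std_train - - # Set initial value for the mu forest - init_mu <- mean(resid_train) - - # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau - if (is.null(sigma2_init)) { - sigma2_init <- 1.0 * var(resid_train) - } - if (is.null(variance_forest_init)) { - variance_forest_init <- 1.0 * var(resid_train) - } - if (is.null(b_leaf_mu)) { - b_leaf_mu <- var(resid_train) / (num_trees_mu) - } - if (is.null(b_leaf_tau)) { - b_leaf_tau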
<- var(resid_train) / (2 * num_trees_tau) - } - if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - if (!is.matrix(sigma2_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - current_leaf_scale_mu <- sigma2_leaf_mu - } + if (!is.matrix(sigma2_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + current_leaf_scale_mu <- sigma2_leaf_mu + } + } + if (is.null(sigma2_leaf_tau)) { + # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p + # Use p = 0.6827 (roughly one-standard-deviation coverage) as an internal
+ # default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overridden by specifying `sigma2_leaf_init` in + # treatment_effect_forest_params. + p <- 0.6827 + q_quantile <- qnorm((p + 1) / 2) + sigma2_leaf_tau <- ((delta_max / (q_quantile * dnorm(0)))^2) / + num_trees_tau + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (!is.matrix(sigma2_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) } - if (is.null(sigma2_leaf_tau)) { - sigma2_leaf_tau <- var(resid_train) / (num_trees_tau) - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - current_leaf_scale_tau <- sigma2_leaf_tau - } + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) + } + current_leaf_scale_tau <- sigma2_leaf_tau + } + } + current_sigma2 <- sigma2_init + } else { + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) } else { - y_bar_train <- 0 - y_std_train <- 1 } - - # Compute standardized outcome - resid_train <- (y_train - y_bar_train) / y_std_train - - # Set initial value for the mu forest - init_mu <- mean(resid_train) - - # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau - if (is.null(sigma2_init)) { - sigma2_init <- 1.0 * var(resid_train) - } - if (is.null(variance_forest_init)) { - variance_forest_init <- 1.0 * var(resid_train) - } - if (is.null(b_leaf_mu)) { - b_leaf_mu <- var(resid_train) / (num_trees_mu) - } - if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - if (!is.matrix(sigma2_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - current_leaf_scale_mu <- sigma2_leaf_mu - } - } - if (is.null(sigma2_leaf_tau)) { - sigma2_leaf_tau <- var(resid_train) / (num_trees_tau) - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - current_leaf_scale_tau <- sigma2_leaf_tau - } - } - current_sigma2 <- sigma2_init - } - - # Set mu and tau leaf models / dimensions - leaf_model_mu_forest <- 0 - leaf_dimension_mu_forest <- 1 - if (has_multivariate_treatment) { - leaf_model_tau_forest <- 2 - leaf_dimension_tau_forest <- ncol(Z_train) + current_leaf_scale_tau <- sigma2_leaf_tau + } + } + current_sigma2 <- sigma2_init - } else { - leaf_model_tau_forest <- 1 - leaf_dimension_tau_forest <- 1 + y_bar_train <- 0 + y_std_train <- 1 } - # Set variance leaf model type (currently only one option) - leaf_model_variance_forest <- 3 - leaf_dimension_variance_forest <- 1 - - # Random effects prior parameters - if (has_rfx) { - # Prior parameters - if (is.null(rfx_working_parameter_prior_mean)) { - if (num_rfx_components == 1) { - alpha_init <- c(0) - } else if (num_rfx_components > 1) { - alpha_init <- rep(0, num_rfx_components) - } else { - stop("There must be at least 1 random effect component") - } - } else { - alpha_init <- expand_dims_1d( - rfx_working_parameter_prior_mean, - num_rfx_components - ) - } - if
(is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix( - rep(alpha_init, num_rfx_groups), - num_rfx_components, - num_rfx_groups - ) - } else { - xi_init <- expand_dims_2d( - rfx_group_parameter_prior_mean, - num_rfx_components, - num_rfx_groups - ) - } + # Set initial value for the mu forest + init_mu <- mean(resid_train) - if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_alpha_init <- expand_dims_2d_diag( - rfx_working_parameter_prior_cov, - num_rfx_components - ) - } - - if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_xi_init <- expand_dims_2d_diag( - rfx_group_parameter_prior_cov, - num_rfx_components - ) - } - - sigma_xi_shape <- rfx_variance_prior_shape - sigma_xi_scale <- rfx_variance_prior_scale - } - - # Random effects data structure and storage container - if (has_rfx) { - rfx_dataset_train <- createRandomEffectsDataset( - rfx_group_ids_train, - rfx_basis_train - ) - rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel( - num_rfx_components, - num_rfx_groups - ) - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples( - num_rfx_components, - num_rfx_groups, - rfx_tracker_train - ) - } - - # Container of variance parameter samples - num_actual_mcmc_iter <- num_mcmc * keep_every - num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter - # Delete GFR samples from these containers after the fact if desired - # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + - ifelse(keep_burnin, num_burnin, 0) + - num_mcmc * num_chains - if (sample_sigma2_global) { - global_var_samples <- rep(NA, num_retained_samples) + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var(resid_train) } - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_samples <- rep(NA, num_retained_samples) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_samples <- rep(NA, num_retained_samples) + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var(resid_train) } - muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) - if (include_variance_forest) { - sigma2_x_train_raw <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples - ) + if (is.null(b_leaf_mu)) { + b_leaf_mu <- var(resid_train) / (num_trees_mu) } - sample_counter <- 0 - - # Prepare adaptive coding structure - if ( - (!is.numeric(b_0)) || - (!is.numeric(b_1)) || - (length(b_0) > 1) || - (length(b_1) > 1) - ) { - stop("b_0 and b_1 must be single numeric values") + if (is.null(b_leaf_tau)) { + b_leaf_tau <- var(resid_train) / (2 * num_trees_tau) } - if (adaptive_coding) { - b_0_samples <- rep(NA, num_retained_samples) - b_1_samples <- rep(NA, num_retained_samples) - current_b_0 <- b_0 - current_b_1 <- b_1 - tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 - if (has_test) { - tau_basis_test <- (1 - Z_test) * current_b_0 + Z_test * current_b_1 + if (is.null(sigma2_leaf_mu)) { + sigma2_leaf_mu <- 2.0 * var(resid_train) / 
(num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + if (!is.matrix(sigma2_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + current_leaf_scale_mu <- sigma2_leaf_mu + } + } + if (is.null(sigma2_leaf_tau)) { + sigma2_leaf_tau <- var(resid_train) / (num_trees_tau) + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (!is.matrix(sigma2_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) } + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) + } + current_leaf_scale_tau <- sigma2_leaf_tau + } + } + current_sigma2 <- sigma2_init + } + + # Set mu and tau leaf models / dimensions + leaf_model_mu_forest <- 0 + leaf_dimension_mu_forest <- 1 + if (has_multivariate_treatment) { + leaf_model_tau_forest <- 2 + leaf_dimension_tau_forest <- ncol(Z_train) + } else { + leaf_model_tau_forest <- 1 + leaf_dimension_tau_forest <- 1 + } + + # Set variance leaf model type (currently only one option) + leaf_model_variance_forest <- 3 + leaf_dimension_variance_forest <- 1 + + # Random effects prior parameters + if (has_rfx) { + # Prior parameters + if (is.null(rfx_working_parameter_prior_mean)) { + if (num_rfx_components == 1) { + alpha_init <- c(0) + } else if (num_rfx_components > 1) { + alpha_init <- rep(0, num_rfx_components) + } else { + stop("There must be at least 1 random effect component") + } } else { - tau_basis_train <- Z_train - if (has_test) tau_basis_test <- Z_test - } - - # Data - forest_dataset_train <- createForestDataset(X_train, tau_basis_train) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test, tau_basis_test) - } - outcome_train <- createOutcome(resid_train) - - # Random number generator (std::mt19937) - if (is.null(random_seed)) { - random_seed = sample(1:10000, 1, FALSE) - } - rng <- createCppRNG(random_seed) - - # Sampling data structures - global_model_config <- createGlobalModelConfig( - global_error_variance = current_sigma2 - ) - forest_model_config_mu <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_mu, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_mu, - leaf_dimension = leaf_dimension_mu_forest, - alpha = alpha_mu, - beta = beta_mu, - min_samples_leaf = min_samples_leaf_mu, - max_depth = max_depth_mu, - leaf_model_type = leaf_model_mu_forest, - leaf_model_scale = current_leaf_scale_mu, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_mu - ) - forest_model_config_tau <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_tau, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_tau, - leaf_dimension = leaf_dimension_tau_forest, - alpha = alpha_tau, - beta = beta_tau, - min_samples_leaf = min_samples_leaf_tau, - max_depth = max_depth_tau, - leaf_model_type = leaf_model_tau_forest, - leaf_model_scale = current_leaf_scale_tau, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_tau - ) - forest_model_mu <- createForestModel( - forest_dataset_train, - 
forest_model_config_mu, - global_model_config - ) - forest_model_tau <- createForestModel( - forest_dataset_train, - forest_model_config_tau, - global_model_config - ) - if (include_variance_forest) { - forest_model_config_variance <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_variance, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_variance, - leaf_dimension = leaf_dimension_variance_forest, - alpha = alpha_variance, - beta = beta_variance, - min_samples_leaf = min_samples_leaf_variance, - max_depth = max_depth_variance, - leaf_model_type = leaf_model_variance_forest, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_variance - ) - forest_model_variance <- createForestModel( - forest_dataset_train, - forest_model_config_variance, - global_model_config - ) - } - - # Container of forest samples - forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE) - forest_samples_tau <- createForestSamples( - num_trees_tau, - ncol(Z_train), - FALSE - ) - active_forest_mu <- createForest(num_trees_mu, 1, TRUE) - active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE) - if (include_variance_forest) { - forest_samples_variance <- createForestSamples( - num_trees_variance, - 1, - TRUE, - TRUE - ) - active_forest_variance <- createForest( - num_trees_variance, - 1, - TRUE, - TRUE - ) - } - - # Initialize the leaves of each tree in the prognostic forest - active_forest_mu$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_mu, - leaf_model_mu_forest, - init_mu - ) - active_forest_mu$adjust_residual( - forest_dataset_train, - outcome_train, - forest_model_mu, - FALSE, - FALSE - ) - - # Initialize the leaves of each tree in the treatment effect forest - init_tau <- rep(0., ncol(Z_train)) - active_forest_tau$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_tau, - leaf_model_tau_forest, - init_tau - ) - active_forest_tau$adjust_residual( - forest_dataset_train, - outcome_train, - forest_model_tau, - TRUE, - FALSE - ) - - # Initialize the leaves of each tree in the variance forest - if (include_variance_forest) { - active_forest_variance$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_variance, - leaf_model_variance_forest, - variance_forest_init - ) + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) + } + + if (is.null(rfx_group_parameter_prior_mean)) { + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) + } else { + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) } - # Run GFR (warm start) if specified - if (num_gfr > 0) { - for (i in 1:num_gfr) { - # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC - # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) - keep_sample <- TRUE - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if ((i %% 10 == 0) || (i == num_gfr)) { - cat( - "Sampling", - i, - "out of", - num_gfr, - "XBCF (grow-from-root) draws\n" - ) - } - } - - if (probit_outcome_model) { - # Sample latent probit variable, z | - - mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) - tau_forest_pred <- active_forest_tau$predict( - forest_dataset_train - ) - forest_pred <- mu_forest_pred + tau_forest_pred - mu0 <- forest_pred[y_train == 0] - mu1 
<- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } + if (is.null(rfx_working_parameter_prior_cov)) { + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) + } - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mu, - active_forest = active_forest_mu, - rng = rng, - forest_model_config = forest_model_config_mu, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) + if (is.null(rfx_group_parameter_prior_cov)) { + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) + } - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - muhat_train_raw[, - sample_counter - ] <- forest_model_mu$get_cached_forest_predictions() - } + sigma_xi_shape <- rfx_variance_prior_shape + sigma_xi_scale <- rfx_variance_prior_scale + } - # Sample variance parameters (if requested) - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - global_model_config$update_global_error_variance(current_sigma2) - } - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- sampleLeafVarianceOneIteration( - active_forest_mu, - rng, - a_leaf_mu, - b_leaf_mu - ) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - if (keep_sample) { - leaf_scale_mu_samples[ - sample_counter - ] <- leaf_scale_mu_double - } - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } + # Random effects data structure and storage container + if (has_rfx) { + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train + ) + rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups + ) + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train + ) + } + + # Container of variance parameter samples + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_samples <- rep(NA, 
num_retained_samples) + } + muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) + if (include_variance_forest) { + sigma2_x_train_raw <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + sample_counter <- 0 + + # Prepare adaptive coding structure + if ( + (!is.numeric(b_0)) || + (!is.numeric(b_1)) || + (length(b_0) > 1) || + (length(b_1) > 1) + ) { + stop("b_0 and b_1 must be single numeric values") + } + if (adaptive_coding) { + b_0_samples <- rep(NA, num_retained_samples) + b_1_samples <- rep(NA, num_retained_samples) + current_b_0 <- b_0 + current_b_1 <- b_1 + tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 + if (has_test) { + tau_basis_test <- (1 - Z_test) * current_b_0 + Z_test * current_b_1 + } + } else { + tau_basis_train <- Z_train + if (has_test) tau_basis_test <- Z_test + } + + # Data + forest_dataset_train <- createForestDataset(X_train, tau_basis_train) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, tau_basis_test) + } + outcome_train <- createOutcome(resid_train) + + # Random number generator (std::mt19937) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } + rng <- createCppRNG(random_seed) + + # Sampling data structures + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + forest_model_config_mu <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mu, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mu, + leaf_dimension = leaf_dimension_mu_forest, + alpha = alpha_mu, + beta = beta_mu, + min_samples_leaf = min_samples_leaf_mu, + max_depth = max_depth_mu, + leaf_model_type = leaf_model_mu_forest, + leaf_model_scale = current_leaf_scale_mu, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mu + ) + forest_model_config_tau <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_tau, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_tau, + leaf_dimension = leaf_dimension_tau_forest, + alpha = alpha_tau, + beta = beta_tau, + min_samples_leaf = min_samples_leaf_tau, + max_depth = max_depth_tau, + leaf_model_type = leaf_model_tau_forest, + leaf_model_scale = current_leaf_scale_tau, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_tau + ) + forest_model_mu <- createForestModel( + forest_dataset_train, + forest_model_config_mu, + global_model_config + ) + forest_model_tau <- createForestModel( + forest_dataset_train, + forest_model_config_tau, + global_model_config + ) + if (include_variance_forest) { + forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = leaf_dimension_variance_forest, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance + ) + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config + ) + } + + # Container of forest samples + forest_samples_mu <- 
createForestSamples(num_trees_mu, 1, TRUE) + forest_samples_tau <- createForestSamples( + num_trees_tau, + ncol(Z_train), + FALSE + ) + active_forest_mu <- createForest(num_trees_mu, 1, TRUE) + active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE) + if (include_variance_forest) { + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE + ) + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE + ) + } + + # Initialize the leaves of each tree in the prognostic forest + active_forest_mu$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mu, + leaf_model_mu_forest, + init_mu + ) + active_forest_mu$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_mu, + FALSE, + FALSE + ) + + # Initialize the leaves of each tree in the treatment effect forest + init_tau <- rep(0., ncol(Z_train)) + active_forest_tau$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_tau, + leaf_model_tau_forest, + init_tau + ) + active_forest_tau$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_tau, + TRUE, + FALSE + ) + + # Initialize the leaves of each tree in the variance forest + if (include_variance_forest) { + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init + ) + } + + # Run GFR (warm start) if specified + if (num_gfr > 0) { + for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) + keep_sample <- TRUE + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if ((i %% 10 == 0) || (i == num_gfr)) { + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBCF (grow-from-root) draws\n" + ) + } + } - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_tau, - active_forest = active_forest_tau, - rng = rng, - forest_model_config = forest_model_config_tau, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) + if (probit_outcome_model) { + # Sample latent probit variable, z | - + mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) + tau_forest_pred <- active_forest_tau$predict( + forest_dataset_train + ) + outcome_pred <- mu_forest_pred + tau_forest_pred + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - outcome_pred) + } + + # Sample the prognostic forest + forest_model_mu$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mu, + active_forest = active_forest_mu, + rng = rng, + forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cache train set predictions since they are already computed during 
sampling + if (keep_sample) { + muhat_train_raw[, + sample_counter + ] <- forest_model_mu$get_cached_forest_predictions() + } + + # Sample variance parameters (if requested) + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + global_model_config$update_global_error_variance(current_sigma2) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_double <- sampleLeafVarianceOneIteration( + active_forest_mu, + rng, + a_leaf_mu, + b_leaf_mu + ) + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + if (keep_sample) { + leaf_scale_mu_samples[ + sample_counter + ] <- leaf_scale_mu_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + + # Sample the treatment forest + forest_model_tau$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_tau, + active_forest = active_forest_tau, + rng = rng, + forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... + + # Sample coding parameters (if requested) + if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) + mu_x_raw_train <- active_forest_mu$predict_raw( + forest_dataset_train + ) + tau_x_raw_train <- active_forest_tau$predict_raw( + forest_dataset_train + ) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + partial_resid_mu_train <- partial_resid_mu_train - + rfx_preds_train + } - # Cannot cache train set predictions for tau because the cached predictions in the - # tracking data structures are pre-multiplied by the basis (treatment) - # ... 
- - # Sample coding parameters (if requested) - if (adaptive_coding) { - # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- active_forest_mu$predict_raw( - forest_dataset_train - ) - tau_x_raw_train <- active_forest_tau$predict_raw( - forest_dataset_train - ) - partial_resid_mu_train <- resid_train - mu_x_raw_train - if (has_rfx) { - rfx_preds_train <- rfx_model$predict( - rfx_dataset_train, - rfx_tracker_train - ) - partial_resid_mu_train <- partial_resid_mu_train - - rfx_preds_train - } - - # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0)) - s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1)) - s_ty0 <- sum( - tau_x_raw_train * partial_resid_mu_train * (Z_train == 0) - ) - s_ty1 <- sum( - tau_x_raw_train * partial_resid_mu_train * (Z_train == 1) - ) - - # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) - current_b_0 <- rnorm( - 1, - (s_ty0 / (s_tt0 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) - ) - current_b_1 <- rnorm( - 1, - (s_ty1 / (s_tt1 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) - ) - - # Update basis for the leaf regression - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (keep_sample) { - b_0_samples[sample_counter] <- current_b_0 - b_1_samples[sample_counter] <- current_b_1 - } - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - - # Update leaf predictions and residual - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] + s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0)) + s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1)) + s_ty0 <- sum( + tau_x_raw_train * partial_resid_mu_train * (Z_train == 0) + ) + s_ty1 <- sum( + tau_x_raw_train * partial_resid_mu_train * (Z_train == 1) + ) - # Sample variance parameters (if requested) - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - sigma2_x_train_raw[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance(current_sigma2) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- sampleLeafVarianceOneIteration( - active_forest_tau, - rng, - a_leaf_tau, - b_leaf_tau - ) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - if (keep_sample) { - leaf_scale_tau_samples[ - sample_counter - ] <- leaf_scale_tau_double - } - 
forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + current_b_0 <- rnorm( + 1, + (s_ty0 / (s_tt0 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) + ) + current_b_1 <- rnorm( + 1, + (s_ty1 / (s_tt1 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) + ) - # Sample random effects parameters (if requested) - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } + # Update basis for the leaf regression + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (keep_sample) { + b_0_samples[sample_counter] <- current_b_0 + b_1_samples[sample_counter] <- current_b_1 } - } - - # Run MCMC - if (num_burnin + num_mcmc > 0) { - for (chain_num in 1:num_chains) { - if (num_gfr > 0) { - # Reset state of active_forest and forest_model based on a previous GFR sample - forest_ind <- num_gfr - chain_num - resetActiveForest( - active_forest_mu, - forest_samples_mu, - forest_ind - ) - resetForestModel( - forest_model_mu, - active_forest_mu, - forest_dataset_train, - outcome_train, - TRUE - ) - resetActiveForest( - active_forest_tau, - forest_samples_tau, - forest_ind - ) - resetForestModel( - forest_model_tau, - active_forest_tau, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- leaf_scale_mu_samples[ - forest_ind + 1 - ] - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- leaf_scale_tau_samples[ - forest_ind + 1 - ] - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - forest_samples_variance, - forest_ind - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - resetRandomEffectsModel( - rfx_model, - rfx_samples, - forest_ind, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - if (adaptive_coding) { - current_b_1 <- b_1_samples[forest_ind + 1] - current_b_0 <- b_0_samples[forest_ind + 1] - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - if (sample_sigma2_global) { - current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } else if (has_prev_model) { - resetActiveForest( - active_forest_mu, - previous_forest_samples_mu, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_mu, - active_forest_mu, - forest_dataset_train, - outcome_train, - TRUE - ) - resetActiveForest( - active_forest_tau, - previous_forest_samples_tau, - previous_model_warmstart_sample_num - 1 - ) - 
resetForestModel( - forest_model_tau, - active_forest_tau, - forest_dataset_train, - outcome_train, - TRUE - ) - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - previous_forest_samples_variance, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if ( - sample_sigma2_leaf_mu && - (!is.null(previous_leaf_var_mu_samples)) - ) { - leaf_scale_mu_double <- previous_leaf_var_mu_samples[ - previous_model_warmstart_sample_num - ] - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - if ( - sample_sigma2_leaf_tau && - (!is.null(previous_leaf_var_tau_samples)) - ) { - leaf_scale_tau_double <- previous_leaf_var_tau_samples[ - previous_model_warmstart_sample_num - ] - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - if (adaptive_coding) { - if (!is.null(previous_b_1_samples)) { - current_b_1 <- previous_b_1_samples[ - previous_model_warmstart_sample_num - ] - } - if (!is.null(previous_b_0_samples)) { - current_b_0 <- previous_b_0_samples[ - previous_model_warmstart_sample_num - ] - } - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - if (has_rfx) { - if (is.null(previous_rfx_samples)) { - warning( - "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" - ) - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } else { - resetRandomEffectsModel( - rfx_model, - previous_rfx_samples, - previous_model_warmstart_sample_num - 1, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - } - if (sample_sigma2_global) { - if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[ - previous_model_warmstart_sample_num - ] - } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } else { - resetActiveForest(active_forest_mu) - active_forest_mu$set_root_leaves(init_mu / num_trees_mu) - resetForestModel( - forest_model_mu, - active_forest_mu, - forest_dataset_train, - outcome_train, - TRUE - ) - resetActiveForest(active_forest_tau) - active_forest_tau$set_root_leaves(init_tau / num_trees_tau) - resetForestModel( - forest_model_tau, - active_forest_tau, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf_mu) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - if (sample_sigma2_leaf_tau) { - current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau) - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - if (include_variance_forest) { - 
resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves( - log(variance_forest_init) / num_trees_variance - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } - if (adaptive_coding) { - current_b_1 <- b_1 - current_b_0 <- b_0 - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - if (sample_sigma2_global) { - current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - for (i in (num_gfr + 1):num_samples) { - is_mcmc <- i > (num_gfr + num_burnin) - if (is_mcmc) { - mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } else { - if (keep_burnin) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if (num_burnin > 0) { - if ( - ((i - num_gfr) %% 100 == 0) || - ((i - num_gfr) == num_burnin) - ) { - cat( - "Sampling", - i - num_gfr, - "out of", - num_gfr, - "BCF burn-in draws\n" - ) - } - } - if (num_mcmc > 0) { - if ( - ((i - num_gfr - num_burnin) %% 100 == 0) || - (i == num_samples) - ) { - cat( - "Sampling", - i - num_burnin - num_gfr, - "out of", - num_mcmc, - "BCF MCMC draws\n" - ) - } - } - } - - if (probit_outcome_model) { - # Sample latent probit variable, z | - - mu_forest_pred <- active_forest_mu$predict( - forest_dataset_train - ) - tau_forest_pred <- active_forest_tau$predict( - forest_dataset_train - ) - forest_pred <- mu_forest_pred + tau_forest_pred - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } - - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mu, - active_forest = active_forest_mu, - rng = rng, - forest_model_config = forest_model_config_mu, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - muhat_train_raw[, - sample_counter - ] <- forest_model_mu$get_cached_forest_predictions() - } - - # Sample variance parameters (if requested) - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- 
sampleLeafVarianceOneIteration( - active_forest_mu, - rng, - a_leaf_mu, - b_leaf_mu - ) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - if (keep_sample) { - leaf_scale_mu_samples[ - sample_counter - ] <- leaf_scale_mu_double - } - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_tau, - active_forest = active_forest_tau, - rng = rng, - forest_model_config = forest_model_config_tau, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cannot cache train set predictions for tau because the cached predictions in the - # tracking data structures are pre-multiplied by the basis (treatment) - # ... - - # Sample coding parameters (if requested) - if (adaptive_coding) { - # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- active_forest_mu$predict_raw( - forest_dataset_train - ) - tau_x_raw_train <- active_forest_tau$predict_raw( - forest_dataset_train - ) - partial_resid_mu_train <- resid_train - mu_x_raw_train - if (has_rfx) { - rfx_preds_train <- rfx_model$predict( - rfx_dataset_train, - rfx_tracker_train - ) - partial_resid_mu_train <- partial_resid_mu_train - - rfx_preds_train - } - - # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum( - tau_x_raw_train * tau_x_raw_train * (Z_train == 0) - ) - s_tt1 <- sum( - tau_x_raw_train * tau_x_raw_train * (Z_train == 1) - ) - s_ty0 <- sum( - tau_x_raw_train * - partial_resid_mu_train * - (Z_train == 0) - ) - s_ty1 <- sum( - tau_x_raw_train * - partial_resid_mu_train * - (Z_train == 1) - ) - - # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) - current_b_0 <- rnorm( - 1, - (s_ty0 / (s_tt0 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) - ) - current_b_1 <- rnorm( - 1, - (s_ty1 / (s_tt1 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) - ) - - # Update basis for the leaf regression - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (keep_sample) { - b_0_samples[sample_counter] <- current_b_0 - b_1_samples[sample_counter] <- current_b_1 - } - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - - # Update leaf predictions and residual - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - - # Sample variance parameters (if requested) - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - sigma2_x_train_raw[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - 
rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- sampleLeafVarianceOneIteration( - active_forest_tau, - rng, - a_leaf_tau, - b_leaf_tau - ) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - if (keep_sample) { - leaf_scale_tau_samples[ - sample_counter - ] <- leaf_scale_tau_double - } - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - - # Sample random effects parameters (if requested) - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } - } + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) }
- } - # Remove GFR samples if they are not to be retained - if ((!keep_gfr) && (num_gfr > 0)) { - for (i in 1:num_gfr) { - forest_samples_mu$delete_sample(0) - forest_samples_tau$delete_sample(0) - if (include_variance_forest) { - forest_samples_variance$delete_sample(0) - } - if (has_rfx) { - rfx_samples$delete_sample(0) - } + # Update leaf predictions and residual + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + + # Sample variance parameters (if requested) + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + sigma2_x_train_raw[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } - if (sample_sigma2_global) { - global_var_samples <- global_var_samples[ - (num_gfr + 1):length(global_var_samples) - ] + } + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance(current_sigma2) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_double <- sampleLeafVarianceOneIteration( + active_forest_tau, + rng, + a_leaf_tau, + b_leaf_tau + ) + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + if (keep_sample) { + leaf_scale_tau_samples[ + sample_counter + ] <- leaf_scale_tau_double }
+ # Update the tau forest config with the newly sampled tau leaf scale + # (previously this wrote the mu scale into the mu config, which is a + # no-op for the tau sampler; the MCMC loop below updates the tau config) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + + # Sample random effects parameters (if requested) + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) + } + } + } + + # Run MCMC + if (num_burnin + num_mcmc > 0) { + for (chain_num in 1:num_chains) { + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + resetActiveForest( + active_forest_mu, + forest_samples_mu, + forest_ind + ) + resetForestModel( + forest_model_mu, + active_forest_mu, + 
forest_dataset_train, + outcome_train, + TRUE + ) + resetActiveForest( + active_forest_tau, + forest_samples_tau, + forest_ind + ) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (sample_sigma2_leaf_mu) { - leaf_scale_mu_samples <- leaf_scale_mu_samples[ - (num_gfr + 1):length(leaf_scale_mu_samples) - ] + leaf_scale_mu_double <- leaf_scale_mu_samples[ + forest_ind + 1 + ] + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } if (sample_sigma2_leaf_tau) { - leaf_scale_tau_samples <- leaf_scale_tau_samples[ - (num_gfr + 1):length(leaf_scale_tau_samples) - ] + leaf_scale_tau_double <- leaf_scale_tau_samples[ + forest_ind + 1 + ] + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + if (include_variance_forest) { + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + resetRandomEffectsModel( + rfx_model, + rfx_samples, + forest_ind, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } if (adaptive_coding) { - b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)] - b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)] + current_b_1 <- b_1_samples[forest_ind + 1] + current_b_0 <- b_0_samples[forest_ind + 1] + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + if (sample_sigma2_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + global_model_config$update_global_error_variance( + current_sigma2 + ) } - muhat_train_raw <- muhat_train_raw[, - (num_gfr + 1):ncol(muhat_train_raw) - ] + } else if (has_prev_model) { + resetActiveForest( + active_forest_mu, + previous_forest_samples_mu, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE + ) + resetActiveForest( + active_forest_tau, + previous_forest_samples_tau, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (include_variance_forest) { - sigma2_x_train_raw <- sigma2_x_train_raw[, - (num_gfr + 1):ncol(sigma2_x_train_raw) + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if ( + sample_sigma2_leaf_mu && + (!is.null(previous_leaf_var_mu_samples)) + ) { + leaf_scale_mu_double <- previous_leaf_var_mu_samples[ + previous_model_warmstart_sample_num + ] + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + if ( + sample_sigma2_leaf_tau && + (!is.null(previous_leaf_var_tau_samples)) + ) 
{ + leaf_scale_tau_double <- previous_leaf_var_tau_samples[ + previous_model_warmstart_sample_num + ] + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + if (adaptive_coding) { + if (!is.null(previous_b_1_samples)) { + current_b_1 <- previous_b_1_samples[ + previous_model_warmstart_sample_num ] + } + if (!is.null(previous_b_0_samples)) { + current_b_0 <- previous_b_0_samples[ + previous_model_warmstart_sample_num + ] + } + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } - num_retained_samples <- num_retained_samples - num_gfr - } - - # Forest predictions - mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train - if (adaptive_coding) { - tau_hat_train_raw <- forest_samples_tau$predict_raw( - forest_dataset_train + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + ) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) + } else { + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + previous_model_warmstart_sample_num - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) + } + } + if (sample_sigma2_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[ + previous_model_warmstart_sample_num + ] + } + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } else { + resetActiveForest(active_forest_mu) + active_forest_mu$set_root_leaves(init_mu / num_trees_mu) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE ) - tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) * - y_std_train - } else { - tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * - y_std_train - } - if (has_multivariate_treatment) { - tau_train_dim <- dim(tau_hat_train) - tau_num_obs <- tau_train_dim[1] - tau_num_samples <- tau_train_dim[3] - treatment_term_train <- matrix( - NA_real_, - nrow = tau_num_obs, - tau_num_samples + resetActiveForest(active_forest_tau) + active_forest_tau$set_root_leaves(init_tau / num_trees_tau) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE ) - for (i in 1:nrow(Z_train)) { - treatment_term_train[i, ] <- colSums( - tau_hat_train[i, , ] * Z_train[i, ] - ) + if (sample_sigma2_leaf_mu) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + if (sample_sigma2_leaf_tau) { + current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance) + 
active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) }
- } else { - treatment_term_train <- tau_hat_train * as.numeric(Z_train) - } - y_hat_train <- mu_hat_train + treatment_term_train - if (has_test) { - mu_hat_test <- forest_samples_mu$predict(forest_dataset_test) * - y_std_train + - y_bar_train if (adaptive_coding) { - tau_hat_test_raw <- forest_samples_tau$predict_raw( - forest_dataset_test - ) - tau_hat_test <- t( - t(tau_hat_test_raw) * (b_1_samples - b_0_samples) - ) * - y_std_train + current_b_1 <- b_1 + current_b_0 <- b_0 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + if (sample_sigma2_global) { + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } + for (i in (num_gfr + 1):num_samples) { + is_mcmc <- i > (num_gfr + num_burnin) + if (is_mcmc) { + mcmc_counter <- i - (num_gfr + num_burnin) + if (mcmc_counter %% keep_every == 0) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } else { - tau_hat_test <- forest_samples_tau$predict_raw( - forest_dataset_test - ) * - y_std_train + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } - if (has_multivariate_treatment) { - tau_test_dim <- dim(tau_hat_test) - tau_num_obs <- tau_test_dim[1] - tau_num_samples <- tau_test_dim[3] - treatment_term_test <- matrix( - NA_real_, - nrow = tau_num_obs, - tau_num_samples - ) - for (i in 1:nrow(Z_test)) { - treatment_term_test[i, ] <- colSums( - tau_hat_test[i, , ] * Z_test[i, ] - ) + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if (num_burnin > 0) { + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_burnin, + "BCF burn-in draws\n" + ) }
- } else { - treatment_term_test <- tau_hat_test * as.numeric(Z_test) + } + if (num_mcmc > 0) { + if ( + ((i - num_gfr - num_burnin) %% 100 == 0) || + (i == num_samples) + ) { + cat( + "Sampling", + i - num_burnin - num_gfr, + "out of", + num_mcmc, + "BCF MCMC draws\n" + ) + } + } } - y_hat_test <- mu_hat_test + treatment_term_test - } - if (include_variance_forest) { - sigma2_x_hat_train <- exp(sigma2_x_train_raw) - if (has_test) { - sigma2_x_hat_test <- forest_samples_variance$predict( - forest_dataset_test + + if (probit_outcome_model) { + # Sample latent probit variable, z | - + mu_forest_pred <- active_forest_mu$predict( + forest_dataset_train + ) + tau_forest_pred <- active_forest_tau$predict( + forest_dataset_train + ) + outcome_pred <- mu_forest_pred + tau_forest_pred + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- 
outcome_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - outcome_pred) } - } - # Random effects predictions - if (has_rfx) { - rfx_preds_train <- rfx_samples$predict( - rfx_group_ids_train, - rfx_basis_train - ) * - y_std_train - y_hat_train <- y_hat_train + rfx_preds_train - } - if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict( - rfx_group_ids_test, - rfx_basis_test - ) * - y_std_train - y_hat_test <- y_hat_test + rfx_preds_test - } + # Sample the prognostic forest + forest_model_mu$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mu, + active_forest = active_forest_mu, + rng = rng, + forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) - # Global error variance - if (sample_sigma2_global) { - sigma2_global_samples <- global_var_samples * (y_std_train^2) - } + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + muhat_train_raw[, + sample_counter + ] <- forest_model_mu$get_cached_forest_predictions() + } - # Leaf parameter variance for prognostic forest - if (sample_sigma2_leaf_mu) { - sigma2_leaf_mu_samples <- leaf_scale_mu_samples - } + # Sample variance parameters (if requested) + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_double <- sampleLeafVarianceOneIteration( + active_forest_mu, + rng, + a_leaf_mu, + b_leaf_mu + ) + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + if (keep_sample) { + leaf_scale_mu_samples[ + sample_counter + ] <- leaf_scale_mu_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } - # Leaf parameter variance for treatment effect forest - if (sample_sigma2_leaf_tau) { - sigma2_leaf_tau_samples <- leaf_scale_tau_samples - } + # Sample the treatment forest + forest_model_tau$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_tau, + active_forest = active_forest_tau, + rng = rng, + forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) - # Rescale variance forest prediction by global sigma2 (sampled or constant) - if (include_variance_forest) { + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... 
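+ # The b_0 / b_1 draws below are conjugate normal (Gibbs) updates: writing + # tau_basis = b_0 * (1 - Z) + b_1 * Z, each coefficient's full conditional + # given tau(X) and the partial residual y - mu(X) (minus any random effects) + # is Gaussian, and the "+ 2 * current_sigma2" terms in the mean and variance + # correspond to an implicit N(0, 1/2) prior on each coefficient.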
+ + # Sample coding parameters (if requested) + if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) + mu_x_raw_train <- active_forest_mu$predict_raw( + forest_dataset_train + ) + tau_x_raw_train <- active_forest_tau$predict_raw( + forest_dataset_train + ) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + partial_resid_mu_train <- partial_resid_mu_train - + rfx_preds_train + } + + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] + s_tt0 <- sum( + tau_x_raw_train * tau_x_raw_train * (Z_train == 0) + ) + s_tt1 <- sum( + tau_x_raw_train * tau_x_raw_train * (Z_train == 1) + ) + s_ty0 <- sum( + tau_x_raw_train * + partial_resid_mu_train * + (Z_train == 0) + ) + s_ty1 <- sum( + tau_x_raw_train * + partial_resid_mu_train * + (Z_train == 1) + ) + + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + current_b_0 <- rnorm( + 1, + (s_ty0 / (s_tt0 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) + ) + current_b_1 <- rnorm( + 1, + (s_ty1 / (s_tt1 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) + ) + + # Update basis for the leaf regression + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (keep_sample) { + b_0_samples[sample_counter] <- current_b_0 + b_1_samples[sample_counter] <- current_b_1 + } + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + + # Update leaf predictions and residual + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + + # Sample variance parameters (if requested) + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + sigma2_x_train_raw[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } + } if (sample_sigma2_global) { - sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { - sigma2_x_hat_train[, i] * sigma2_global_samples[i] - }) - if (has_test) { - sigma2_x_hat_test <- sapply( - 1:num_retained_samples, - function(i) { - sigma2_x_hat_test[, i] * sigma2_global_samples[i] - } - ) - } - } else { - sigma2_x_hat_train <- sigma2_x_hat_train * - sigma2_init * - y_std_train * - y_std_train - if (has_test) { - sigma2_x_hat_test <- sigma2_x_hat_test * - sigma2_init * - y_std_train * - y_std_train - } + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_double <- sampleLeafVarianceOneIteration( + active_forest_tau, + rng, + a_leaf_tau, + b_leaf_tau + ) + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + if 
(keep_sample) { + leaf_scale_tau_samples[ + sample_counter + ] <- leaf_scale_tau_double + } + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) } - } - # Return results as a list - if (include_variance_forest) { - num_variance_covariates <- sum(variable_weights_variance > 0) - } else { - num_variance_covariates <- 0 - } - model_params <- list( - "initial_sigma2" = sigma2_init, - "initial_sigma2_leaf_mu" = sigma2_leaf_mu, - "initial_sigma2_leaf_tau" = sigma2_leaf_tau, - "initial_b_0" = b_0, - "initial_b_1" = b_1, - "a_global" = a_global, - "b_global" = b_global, - "a_leaf_mu" = a_leaf_mu, - "b_leaf_mu" = b_leaf_mu, - "a_leaf_tau" = a_leaf_tau, - "b_leaf_tau" = b_leaf_tau, - "a_forest" = a_forest, - "b_forest" = b_forest, - "outcome_mean" = y_bar_train, - "outcome_scale" = y_std_train, - "standardize" = standardize, - "num_covariates" = num_cov_orig, - "num_prognostic_covariates" = sum(variable_weights_mu > 0), - "num_treatment_covariates" = sum(variable_weights_tau > 0), - "num_variance_covariates" = num_variance_covariates, - "treatment_dim" = ncol(Z_train), - "propensity_covariate" = propensity_covariate, - "binary_treatment" = binary_treatment, - "multivariate_treatment" = has_multivariate_treatment, - "adaptive_coding" = adaptive_coding, - "internal_propensity_model" = internal_propensity_model, - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, - "keep_every" = keep_every, - "num_chains" = num_chains, - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, - "include_variance_forest" = include_variance_forest, - "sample_sigma2_global" = sample_sigma2_global, - "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu, - "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, - "probit_outcome_model" = probit_outcome_model - ) - result <- list( - "forests_mu" = forest_samples_mu, - "forests_tau" = forest_samples_tau, - "model_params" = model_params, - "mu_hat_train" = mu_hat_train, - "tau_hat_train" = tau_hat_train, - "y_hat_train" = y_hat_train, - "train_set_metadata" = X_train_metadata - ) - if (has_test) { - result[["mu_hat_test"]] = mu_hat_test - } - if (has_test) { - result[["tau_hat_test"]] = tau_hat_test - } - if (has_test) { - result[["y_hat_test"]] = y_hat_test - } - if (include_variance_forest) { - result[["forests_variance"]] = forest_samples_variance - result[["sigma2_x_hat_train"]] = sigma2_x_hat_train - if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + # Sample random effects parameters (if requested) + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) + } + } + } + } + + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && (num_gfr > 0)) { + for (i in 1:num_gfr) { + forest_samples_mu$delete_sample(0) + forest_samples_tau$delete_sample(0) + if (include_variance_forest) { + forest_samples_variance$delete_sample(0) + } + if (has_rfx) { + rfx_samples$delete_sample(0) + } } if (sample_sigma2_global) { - result[["sigma2_global_samples"]] = sigma2_global_samples + global_var_samples <- global_var_samples[ + (num_gfr + 1):length(global_var_samples) + ] } if (sample_sigma2_leaf_mu) { - result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples + leaf_scale_mu_samples <- leaf_scale_mu_samples[ + (num_gfr + 1):length(leaf_scale_mu_samples) + ] } if (sample_sigma2_leaf_tau) { - result[["sigma2_leaf_tau_samples"]] = 
sigma2_leaf_tau_samples + leaf_scale_tau_samples <- leaf_scale_tau_samples[ + (num_gfr + 1):length(leaf_scale_tau_samples) + ] } if (adaptive_coding) { - result[["b_0_samples"]] = b_0_samples - result[["b_1_samples"]] = b_1_samples - } - if (has_rfx) { - result[["rfx_samples"]] = rfx_samples - result[["rfx_preds_train"]] = rfx_preds_train - result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)] + b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)] } - if ((has_rfx_test) && (has_test)) { - result[["rfx_preds_test"]] = rfx_preds_test - } - if (internal_propensity_model) { - result[["bart_propensity_model"]] = bart_model_propensity + muhat_train_raw <- muhat_train_raw[, + (num_gfr + 1):ncol(muhat_train_raw) + ] + if (include_variance_forest) { + sigma2_x_train_raw <- sigma2_x_train_raw[, + (num_gfr + 1):ncol(sigma2_x_train_raw) + ] + } + num_retained_samples <- num_retained_samples - num_gfr + } + + # Forest predictions + mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train + if (adaptive_coding) { + tau_hat_train_raw <- forest_samples_tau$predict_raw( + forest_dataset_train + ) + tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) * + y_std_train + } else { + tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * + y_std_train + } + if (has_multivariate_treatment) { + tau_train_dim <- dim(tau_hat_train) + tau_num_obs <- tau_train_dim[1] + tau_num_samples <- tau_train_dim[3] + treatment_term_train <- matrix( + NA_real_, + nrow = tau_num_obs, + tau_num_samples + ) + for (i in 1:nrow(Z_train)) { + treatment_term_train[i, ] <- colSums( + tau_hat_train[i, , ] * Z_train[i, ] + ) + } + } else { + treatment_term_train <- tau_hat_train * as.numeric(Z_train) + } + y_hat_train <- mu_hat_train + treatment_term_train + if (has_test) { + mu_hat_test <- forest_samples_mu$predict(forest_dataset_test) * + y_std_train + + y_bar_train + if (adaptive_coding) { + tau_hat_test_raw <- forest_samples_tau$predict_raw( + forest_dataset_test + ) + tau_hat_test <- t( + t(tau_hat_test_raw) * (b_1_samples - b_0_samples) + ) * + y_std_train + } else { + tau_hat_test <- forest_samples_tau$predict_raw( + forest_dataset_test + ) * + y_std_train } - class(result) <- "bcfmodel" - - # Restore global RNG state if user provided a random seed - if (custom_rng) { - .Random.seed <- original_global_seed + if (has_multivariate_treatment) { + tau_test_dim <- dim(tau_hat_test) + tau_num_obs <- tau_test_dim[1] + tau_num_samples <- tau_test_dim[3] + treatment_term_test <- matrix( + NA_real_, + nrow = tau_num_obs, + tau_num_samples + ) + for (i in 1:nrow(Z_test)) { + treatment_term_test[i, ] <- colSums( + tau_hat_test[i, , ] * Z_test[i, ] + ) + } + } else { + treatment_term_test <- tau_hat_test * as.numeric(Z_test) } - - return(result) + y_hat_test <- mu_hat_test + treatment_term_test + } + if (include_variance_forest) { + sigma2_x_hat_train <- exp(sigma2_x_train_raw) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } + } + + # Random effects predictions + if (has_rfx) { + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train + y_hat_train <- y_hat_train + rfx_preds_train + } + if ((has_rfx_test) && (has_test)) { + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train + y_hat_test <- y_hat_test + rfx_preds_test + } + + # Global error variance + if (sample_sigma2_global) { + 
sigma2_global_samples <- global_var_samples * (y_std_train^2) + } + + # Leaf parameter variance for prognostic forest + if (sample_sigma2_leaf_mu) { + sigma2_leaf_mu_samples <- leaf_scale_mu_samples + } + + # Leaf parameter variance for treatment effect forest + if (sample_sigma2_leaf_tau) { + sigma2_leaf_tau_samples <- leaf_scale_tau_samples + } + + # Rescale variance forest prediction by global sigma2 (sampled or constant) + if (include_variance_forest) { + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } + } else { + sigma2_x_hat_train <- sigma2_x_hat_train * + sigma2_init * + y_std_train * + y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } + } + } + + # Return results as a list + if (include_variance_forest) { + num_variance_covariates <- sum(variable_weights_variance > 0) + } else { + num_variance_covariates <- 0 + } + model_params <- list( + "initial_sigma2" = sigma2_init, + "initial_sigma2_leaf_mu" = sigma2_leaf_mu, + "initial_sigma2_leaf_tau" = sigma2_leaf_tau, + "initial_b_0" = b_0, + "initial_b_1" = b_1, + "a_global" = a_global, + "b_global" = b_global, + "a_leaf_mu" = a_leaf_mu, + "b_leaf_mu" = b_leaf_mu, + "a_leaf_tau" = a_leaf_tau, + "b_leaf_tau" = b_leaf_tau, + "a_forest" = a_forest, + "b_forest" = b_forest, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "standardize" = standardize, + "num_covariates" = num_cov_orig, + "num_prognostic_covariates" = sum(variable_weights_mu > 0), + "num_treatment_covariates" = sum(variable_weights_tau > 0), + "num_variance_covariates" = num_variance_covariates, + "treatment_dim" = ncol(Z_train), + "propensity_covariate" = propensity_covariate, + "binary_treatment" = binary_treatment, + "multivariate_treatment" = has_multivariate_treatment, + "adaptive_coding" = adaptive_coding, + "internal_propensity_model" = internal_propensity_model, + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "include_variance_forest" = include_variance_forest, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu, + "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, + "probit_outcome_model" = probit_outcome_model, + "rfx_model_spec" = rfx_model_spec + ) + result <- list( + "forests_mu" = forest_samples_mu, + "forests_tau" = forest_samples_tau, + "model_params" = model_params, + "mu_hat_train" = mu_hat_train, + "tau_hat_train" = tau_hat_train, + "y_hat_train" = y_hat_train, + "train_set_metadata" = X_train_metadata + ) + if (has_test) { + result[["mu_hat_test"]] = mu_hat_test + } + if (has_test) { + result[["tau_hat_test"]] = tau_hat_test + } + if (has_test) { + result[["y_hat_test"]] = y_hat_test + } + if (include_variance_forest) { + result[["forests_variance"]] = forest_samples_variance + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + } + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf_mu) { + 
result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples + } + if (sample_sigma2_leaf_tau) { + result[["sigma2_leaf_tau_samples"]] = sigma2_leaf_tau_samples + } + if (adaptive_coding) { + result[["b_0_samples"]] = b_0_samples + result[["b_1_samples"]] = b_1_samples + } + if (has_rfx) { + result[["rfx_samples"]] = rfx_samples + result[["rfx_preds_train"]] = rfx_preds_train + result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + } + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } + if (internal_propensity_model) { + result[["bart_propensity_model"]] = bart_model_propensity + } + class(result) <- "bcfmodel" + + # Restore global RNG state if user provided a random seed + if (custom_rng) { + .Random.seed <- original_global_seed + } + + return(result) } #' Predict from a sampled BCF model on new data @@ -2572,10 +2641,13 @@ bcf <- function( #' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. -#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. #' -#' @return List of 3-5 `nrow(X)` by `object$num_samples` matrices: prognostic function estimates, treatment effect estimates, (optionally) random effects predictions, (optionally) variance forest predictions, and outcome predictions. +#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. #' @export #' #' @examples @@ -2625,172 +2697,411 @@ bcf <- function( #' num_burnin = 0, num_mcmc = 10) #' preds <- predict(bcf_model, X_test, Z_test, pi_test) predict.bcfmodel <- function( - object, - X, - Z, - propensity = NULL, - rfx_group_ids = NULL, - rfx_basis = NULL, - ... + object, + X, + Z, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + type = "posterior", + terms = "all", + scale = "linear", + ... 
predict.bcfmodel <- function(
-    object,
-    X,
-    Z,
-    propensity = NULL,
-    rfx_group_ids = NULL,
-    rfx_basis = NULL,
-    ...
+  object,
+  X,
+  Z,
+  propensity = NULL,
+  rfx_group_ids = NULL,
+  rfx_basis = NULL,
+  type = "posterior",
+  terms = "all",
+  scale = "linear",
+  ...
) {
-    # Preprocess covariates
-    if ((!is.data.frame(X)) && (!is.matrix(X))) {
-        stop("X must be a matrix or dataframe")
-    }
-    train_set_metadata <- object$train_set_metadata
-    X <- preprocessPredictionData(X, train_set_metadata)
-
-    # Convert all input data to matrices if not already converted
-    if ((is.null(dim(Z))) && (!is.null(Z))) {
-        Z <- as.matrix(as.numeric(Z))
-    }
-    if ((is.null(dim(propensity))) && (!is.null(propensity))) {
-        propensity <- as.matrix(propensity)
-    }
-    if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) {
-        rfx_basis <- as.matrix(rfx_basis)
-    }
-
-    # Data checks
-    if (
-        (object$model_params$propensity_covariate != "none") &&
-            (is.null(propensity))
-    ) {
-        if (!object$model_params$internal_propensity_model) {
-            stop("propensity must be provided for this model")
-        }
-        # Compute propensity score using the internal bart model
-        propensity <- rowMeans(predict(object$bart_propensity_model, X)$y_hat)
-    }
-    if (nrow(X) != nrow(Z)) {
-        stop("X and Z must have the same number of rows")
-    }
-    if (object$model_params$num_covariates != ncol(X)) {
+  # Handle mean function scale
+  if (!is.character(scale)) {
+    stop("scale must be a string or character vector")
+  }
+  if (!(scale %in% c("linear", "probability"))) {
+    stop("scale must either be 'linear' or 'probability'")
+  }
+  is_probit <- object$model_params$probit_outcome_model
+  if ((scale == "probability") && (!is_probit)) {
+    stop(
+      "scale cannot be 'probability' for models not fit with a probit outcome model"
+    )
+  }
+  probability_scale <- scale == "probability"
+
+  # Handle prediction type
+  if (!is.character(type)) {
+    stop("type must be a string or character vector")
+  }
+  if (!(type %in% c("mean", "posterior"))) {
+    stop("type must either be 'mean' or 'posterior'")
+  }
+  predict_mean <- type == "mean"
+
+  # Handle prediction terms
+  rfx_model_spec <- object$model_params$rfx_model_spec
+  rfx_intercept_only <- rfx_model_spec == "intercept_only"
+  rfx_intercept_plus_treatment <- (rfx_model_spec == "intercept_plus_treatment")
+  rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment
+  if (!is.character(terms)) {
+    stop("terms must be a string or character vector")
+  }
+  num_terms <- length(terms)
+  has_mu_forest <- TRUE
+  has_tau_forest <- TRUE
+  has_variance_forest <- object$model_params$include_variance_forest
+  has_rfx <- object$model_params$has_rfx
+  has_y_hat <- TRUE
+  predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) ||
+    ((has_y_hat) && ("all" %in% terms)))
+  predict_mu_forest <- (((has_mu_forest) &&
+    ("prognostic_function" %in% terms)) ||
+    ((has_mu_forest) && ("all" %in% terms)))
+  predict_tau_forest <- (((has_tau_forest) && ("cate" %in% terms)) ||
+    ((has_tau_forest) && ("all" %in% terms)))
+  predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) ||
+    ((has_rfx) && ("all" %in% terms)))
+  predict_variance_forest <- (((has_variance_forest) &&
+    ("variance_forest" %in% terms)) ||
+    ((has_variance_forest) && ("all" %in% terms)))
+  predict_count <- sum(c(
+    predict_y_hat,
+    predict_mu_forest,
+    predict_tau_forest,
+    predict_rfx,
+    predict_variance_forest
+  ))
+  if (predict_count == 0) {
+    warning(paste0(
+      "None of the requested model terms, ",
+      paste(terms, collapse = ", "),
+      ", were fit in this model"
+    ))
+    return(NULL)
+  }
+  predict_rfx_intermediate <- (predict_y_hat && has_rfx)
+  predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept) ||
+    (predict_tau_forest && has_rfx && rfx_intercept_plus_treatment))
+  predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest)
+  predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest)
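+  # A worked example of the flag resolution above (the request shown is
+  # purely illustrative): with terms = c("cate", "rfx") on a model fit
+  # without random effects, predict_tau_forest is TRUE, the inapplicable
+  # "rfx" request leaves predict_rfx FALSE, and predict_count == 1, so the
+  # CATE draws are returned directly rather than wrapped in a list.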
+
+  # Make sure covariates are matrix or data frame
+  if ((!is.data.frame(X)) && (!is.matrix(X))) {
+    stop("X must be a matrix or dataframe")
+  }
+
+  # Convert all input data to matrices if not already converted
+  if ((is.null(dim(Z))) && (!is.null(Z))) {
+    Z <- as.matrix(as.numeric(Z))
+  }
+  if ((is.null(dim(propensity))) && (!is.null(propensity))) {
+    propensity <- as.matrix(propensity)
+  }
+  if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) {
+    rfx_basis <- as.matrix(rfx_basis)
+  }
+
+  # Data checks
+  if (
+    (object$model_params$propensity_covariate != "none") &&
+      (is.null(propensity))
+  ) {
+    if (!object$model_params$internal_propensity_model) {
+      stop("propensity must be provided for this model")
+    }
+    # Compute propensity score using the internal bart model
+    propensity <- rowMeans(predict(object$bart_propensity_model, X)$y_hat)
+  }
+  if (nrow(X) != nrow(Z)) {
+    stop("X and Z must have the same number of rows")
+  }
+  if (object$model_params$num_covariates != ncol(X)) {
+    stop(
+      "X must have the same number of columns as the covariates used to train the model"
+    )
+  }
+  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) {
+    stop(
+      "Random effect group labels (rfx_group_ids) must be provided for this model"
+    )
+  }
+  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) {
+    if (object$model_params$rfx_model_spec == "custom") {
+      stop("Random effects basis (rfx_basis) must be provided for this model")
+    }
+  }
+  if ((object$model_params$num_rfx_basis > 0) && (!is.null(rfx_basis))) {
+    if (ncol(rfx_basis) != object$model_params$num_rfx_basis) {
+      stop(
+        "Random effects basis has a different dimension than the basis used to train this model"
+      )
+    }
+  }
+
+  # Preprocess covariates
+  train_set_metadata <- object$train_set_metadata
+  X <- preprocessPredictionData(X, train_set_metadata)
+
+  # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
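+  # (For example, if the model was sampled with group labels c("Cook", "DuPage"),
+  # then rfx_group_ids = c("Cook", "Cook", "DuPage") recodes to c(1L, 1L, 2L);
+  # any label unseen at sampling time triggers the error below. The county
+  # names are purely illustrative.)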
+ has_rfx <- FALSE + if (!is.null(rfx_group_ids)) { + rfx_unique_group_ids <- object$rfx_unique_group_ids + group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) + if (sum(is.na(group_ids_factor)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids must have been present at sampling time" + ) + } + rfx_group_ids <- as.integer(group_ids_factor) + has_rfx <- TRUE + } + + # Handle RFX model specification + if (has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { stop( - "X and must have the same number of columns as the covariates used to train the model" + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" ) - } - if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) { - stop( - "Random effect group labels (rfx_group_ids) must be provided for this model" + } + } else if (object$model_params$rfx_model_spec == "intercept_only") { + # Only construct a basis if user-provided basis missing + if (is.null(rfx_basis)) { + rfx_basis <- matrix( + rep(1, nrow(X)), + nrow = nrow(X), + ncol = 1 ) - } - if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { - stop("Random effects basis (rfx_basis) must be provided for this model") - } - if ( - (object$model_params$num_rfx_basis > 0) && - (ncol(rfx_basis) != object$model_params$num_rfx_basis) + } + } else if ( + object$model_params$rfx_model_spec == "intercept_plus_treatment" ) { - stop( - "Random effects basis has a different dimension than the basis used to train this model" + # Only construct a basis if user-provided basis missing + if (is.null(rfx_basis)) { + rfx_basis <- cbind( + rep(1, nrow(X)), + Z ) + } + } + } + + # Add propensities to covariate set if necessary + X_combined <- X + if (object$model_params$propensity_covariate != "none") { + X_combined <- cbind(X, propensity) + } + + # Create prediction datasets + forest_dataset_pred <- createForestDataset(X_combined, Z) + + # Compute variance forest predictions + if (predict_variance_forest) { + s_x_raw <- object$forests_variance$predict(forest_dataset_pred) + } + + # Scale variance forest predictions + num_samples <- object$model_params$num_samples + y_std <- object$model_params$outcome_scale + y_bar <- object$model_params$outcome_mean + initial_sigma2 <- object$model_params$initial_sigma2 + if (predict_variance_forest) { + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) + } else { + variance_forest_predictions <- s_x_raw * + initial_sigma2 * + y_std * + y_std } - - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) 
- has_rfx <- FALSE - if (!is.null(rfx_group_ids)) { - rfx_unique_group_ids <- object$rfx_unique_group_ids - group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) - if (sum(is.na(group_ids_factor)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" - ) - } - rfx_group_ids <- as.integer(group_ids_factor) - has_rfx <- TRUE - } - - # Produce basis for the "intercept-only" random effects case - if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) - } - - # Add propensities to covariate set if necessary - X_combined <- X - if (object$model_params$propensity_covariate != "none") { - X_combined <- cbind(X, propensity) + if (predict_mean) { + variance_forest_predictions <- rowMeans(variance_forest_predictions) } + } - # Create prediction datasets - forest_dataset_pred <- createForestDataset(X_combined, Z) + # Compute mu forest predictions + if (predict_mu_forest || predict_mu_forest_intermediate) { + mu_hat_forest <- object$forests_mu$predict(forest_dataset_pred) * + y_std + + y_bar + } - # Compute forest predictions - num_samples <- object$model_params$num_samples - y_std <- object$model_params$outcome_scale - y_bar <- object$model_params$outcome_mean - initial_sigma2 <- object$model_params$initial_sigma2 - mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar + # Compute CATE forest predictions + if (predict_tau_forest || predict_tau_forest_intermediate) { if (object$model_params$adaptive_coding) { - tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) - tau_hat <- t( - t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) - ) * - y_std + tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) + tau_hat_forest <- t( + t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) + ) * + y_std } else { - tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std + tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) * + y_std } if (object$model_params$multivariate_treatment) { - tau_dim <- dim(tau_hat) - tau_num_obs <- tau_dim[1] - tau_num_samples <- tau_dim[3] - treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) - for (i in 1:nrow(Z)) { - treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) - } + tau_dim <- dim(tau_hat_forest) + tau_num_obs <- tau_dim[1] + tau_num_samples <- tau_dim[3] + treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) + for (i in 1:nrow(Z)) { + treatment_term[i, ] <- colSums(tau_hat_forest[i, , ] * Z[i, ]) + } } else { - treatment_term <- tau_hat * as.numeric(Z) - } - if (object$model_params$include_variance_forest) { - s_x_raw <- object$forests_variance$predict(forest_dataset_pred) - } - - # Compute rfx predictions (if needed) - if (object$model_params$has_rfx) { - rfx_predictions <- object$rfx_samples$predict( - rfx_group_ids, - rfx_basis - ) * - y_std + treatment_term <- tau_hat_forest * as.numeric(Z) + } + } + + # Compute rfx predictions + if (predict_rfx || predict_rfx_intermediate) { + rfx_predictions <- object$rfx_samples$predict( + rfx_group_ids, + rfx_basis + ) * + y_std + } + + # Extract "raw" rfx coefficients for each rfx basis term if needed + if (predict_rfx_raw) { + # Extract the raw RFX samples and scale by train set outcome standard deviation + rfx_param_list <- object$rfx_samples$extract_parameter_samples() + rfx_beta_draws <- rfx_param_list$beta_samples * + object$model_params$outcome_scale + + # 
Construct a matrix with the appropriate group random effects arranged for each observation + rfx_predictions_raw <- array( + NA, + dim = c( + nrow(X), + ncol(rfx_basis), + object$model_params$num_samples + ) + ) + for (i in 1:nrow(X)) { + rfx_predictions_raw[i, , ] <- + rfx_beta_draws[, rfx_group_ids[i], ] } + } - # Compute overall "y_hat" predictions - y_hat <- mu_hat + treatment_term - if (object$model_params$has_rfx) { - y_hat <- y_hat + rfx_predictions + # Add raw RFX predictions to mu and tau if warranted by the RFX model spec + if (predict_mu_forest || predict_mu_forest_intermediate) { + if (rfx_intercept && predict_rfx_raw) { + mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ] + } else { + mu_hat_final <- mu_hat_forest } - - # Scale variance forest predictions - if (object$model_params$include_variance_forest) { - if (object$model_params$sample_sigma2_global) { - sigma2_global_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) { - s_x_raw[, i] * sigma2_global_samples[i] - }) - } else { - variance_forest_predictions <- s_x_raw * - initial_sigma2 * - y_std * - y_std + } + if (predict_tau_forest || predict_tau_forest_intermediate) { + if (rfx_intercept_plus_treatment && predict_rfx_raw) { + tau_hat_final <- (tau_hat_forest + + rfx_predictions_raw[, 2:ncol(rfx_basis), ]) + } else { + tau_hat_final <- tau_hat_forest + } + } + + # Combine into y hat predictions + needs_mean_term_preds <- predict_y_hat || + predict_mu_forest || + predict_tau_forest || + predict_rfx + if (needs_mean_term_preds) { + if (probability_scale) { + if (has_rfx) { + if (predict_y_hat) { + y_hat <- pnorm(mu_hat_forest + treatment_term + rfx_predictions) } + if (predict_rfx) { + rfx_predictions <- pnorm(rfx_predictions) + } + } else { + if (predict_y_hat) { + y_hat <- pnorm(mu_hat_forest + treatment_term) + } + } + if (predict_mu_forest) { + mu_hat <- pnorm(mu_hat_final) + } + if (predict_tau_forest) { + tau_hat <- pnorm(tau_hat_final) + } + } else { + if (has_rfx) { + if (predict_y_hat) { + y_hat <- mu_hat_forest + treatment_term + rfx_predictions + } + } else { + if (predict_y_hat) { + y_hat <- mu_hat_forest + treatment_term + } + } + if (predict_mu_forest) { + mu_hat <- mu_hat_final + } + if (predict_tau_forest) { + tau_hat <- tau_hat_final + } + } + } + + # Collapse to posterior mean predictions if requested + if (predict_mean) { + if (predict_mu_forest) { + mu_hat <- rowMeans(mu_hat) + } + if (predict_tau_forest) { + if (object$model_params$multivariate_treatment) { + tau_hat <- apply(tau_hat, c(1, 2), mean) + } else { + tau_hat <- rowMeans(tau_hat) + } + } + if (predict_rfx) { + rfx_predictions <- rowMeans(rfx_predictions) + } + if (predict_y_hat) { + y_hat <- rowMeans(y_hat) + } + } + + # Return results + if (predict_count == 1) { + if (predict_y_hat) { + return(y_hat) + } else if (predict_mu_forest) { + return(mu_hat) + } else if (predict_tau_forest) { + return(tau_hat) + } else if (predict_rfx) { + return(rfx_predictions) + } else if (predict_variance_forest) { + return(variance_forest_predictions) + } + } else { + result <- list() + if (predict_y_hat) { + result[["y_hat"]] = y_hat + } else { + result[["y_hat"]] <- NULL + } + if (predict_mu_forest) { + result[["mu_hat"]] = mu_hat + } else { + result[["mu_hat"]] <- NULL } - - result <- list( - "mu_hat" = mu_hat, - "tau_hat" = tau_hat, - "y_hat" = y_hat - ) - if (object$model_params$has_rfx) { - result[["rfx_predictions"]] <- rfx_predictions + if (predict_tau_forest) { + result[["tau_hat"]] = 
tau_hat } else { - result[["rfx_predictions"]] <- NULL + result[["tau_hat"]] <- NULL } - if (object$model_params$include_variance_forest) { - result[["variance_forest_predictions"]] <- variance_forest_predictions + if (predict_rfx) { + result[["rfx_predictions"]] = rfx_predictions } else { - result[["variance_forest_predictions"]] <- NULL + result[["rfx_predictions"]] <- NULL } - return(result) + if (predict_variance_forest) { + result[["variance_forest_predictions"]] = variance_forest_predictions + } else { + result[["variance_forest_predictions"]] <- NULL + } + } + return(result) } #' Extract raw sample values for each of the random effect parameter terms. @@ -2869,26 +3180,26 @@ predict.bcfmodel <- function( #' treatment_effect_forest_params = tau_params) #' rfx_samples <- getRandomEffectSamples(bcf_model) getRandomEffectSamples.bcfmodel <- function(object, ...) { - result = list() + result = list() - if (!object$model_params$has_rfx) { - warning("This model has no RFX terms, returning an empty list") - return(result) - } + if (!object$model_params$has_rfx) { + warning("This model has no RFX terms, returning an empty list") + return(result) + } - # Extract the samples - result <- object$rfx_samples$extract_parameter_samples() + # Extract the samples + result <- object$rfx_samples$extract_parameter_samples() - # Scale by sd(y_train) - result$beta_samples <- result$beta_samples * - object$model_params$outcome_scale - result$xi_samples <- result$xi_samples * object$model_params$outcome_scale - result$alpha_samples <- result$alpha_samples * - object$model_params$outcome_scale - result$sigma_samples <- result$sigma_samples * - (object$model_params$outcome_scale^2) + # Scale by sd(y_train) + result$beta_samples <- result$beta_samples * + object$model_params$outcome_scale + result$xi_samples <- result$xi_samples * object$model_params$outcome_scale + result$alpha_samples <- result$alpha_samples * + object$model_params$outcome_scale + result$sigma_samples <- result$sigma_samples * + (object$model_params$outcome_scale^2) - return(result) + return(result) } #' Convert the persistent aspects of a BCF model to (in-memory) JSON @@ -2965,161 +3276,165 @@ getRandomEffectSamples.bcfmodel <- function(object, ...) 
{ #' treatment_effect_forest_params = tau_params) #' bcf_json <- saveBCFModelToJson(bcf_model) saveBCFModelToJson <- function(object) { - jsonobj <- createCppJson() - - if (!inherits(object, "bcfmodel")) { - stop("`object` must be a BCF model") - } - - if (is.null(object$model_params)) { - stop("This BCF model has not yet been sampled") - } - - # Add the forests - jsonobj$add_forest(object$forests_mu) - jsonobj$add_forest(object$forests_tau) - if (object$model_params$include_variance_forest) { - jsonobj$add_forest(object$forests_variance) - } - - # Add metadata - jsonobj$add_scalar( - "num_numeric_vars", - object$train_set_metadata$num_numeric_vars - ) - jsonobj$add_scalar( - "num_ordered_cat_vars", - object$train_set_metadata$num_ordered_cat_vars + jsonobj <- createCppJson() + + if (!inherits(object, "bcfmodel")) { + stop("`object` must be a BCF model") + } + + if (is.null(object$model_params)) { + stop("This BCF model has not yet been sampled") + } + + # Add the forests + jsonobj$add_forest(object$forests_mu) + jsonobj$add_forest(object$forests_tau) + if (object$model_params$include_variance_forest) { + jsonobj$add_forest(object$forests_variance) + } + + # Add metadata + jsonobj$add_scalar( + "num_numeric_vars", + object$train_set_metadata$num_numeric_vars + ) + jsonobj$add_scalar( + "num_ordered_cat_vars", + object$train_set_metadata$num_ordered_cat_vars + ) + jsonobj$add_scalar( + "num_unordered_cat_vars", + object$train_set_metadata$num_unordered_cat_vars + ) + if (object$train_set_metadata$num_numeric_vars > 0) { + jsonobj$add_string_vector( + "numeric_vars", + object$train_set_metadata$numeric_vars ) - jsonobj$add_scalar( - "num_unordered_cat_vars", - object$train_set_metadata$num_unordered_cat_vars + } + if (object$train_set_metadata$num_ordered_cat_vars > 0) { + jsonobj$add_string_vector( + "ordered_cat_vars", + object$train_set_metadata$ordered_cat_vars ) - if (object$train_set_metadata$num_numeric_vars > 0) { - jsonobj$add_string_vector( - "numeric_vars", - object$train_set_metadata$numeric_vars - ) - } - if (object$train_set_metadata$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector( - "ordered_cat_vars", - object$train_set_metadata$ordered_cat_vars - ) - jsonobj$add_string_list( - "ordered_unique_levels", - object$train_set_metadata$ordered_unique_levels - ) - } - if (object$train_set_metadata$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector( - "unordered_cat_vars", - object$train_set_metadata$unordered_cat_vars - ) - jsonobj$add_string_list( - "unordered_unique_levels", - object$train_set_metadata$unordered_unique_levels - ) - } - - # Add global parameters - jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) - jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) - jsonobj$add_boolean("standardize", object$model_params$standardize) - jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) - jsonobj$add_boolean( - "sample_sigma2_global", - object$model_params$sample_sigma2_global + jsonobj$add_string_list( + "ordered_unique_levels", + object$train_set_metadata$ordered_unique_levels ) - jsonobj$add_boolean( - "sample_sigma2_leaf_mu", - object$model_params$sample_sigma2_leaf_mu + } + if (object$train_set_metadata$num_unordered_cat_vars > 0) { + jsonobj$add_string_vector( + "unordered_cat_vars", + object$train_set_metadata$unordered_cat_vars ) - jsonobj$add_boolean( - "sample_sigma2_leaf_tau", - object$model_params$sample_sigma2_leaf_tau + jsonobj$add_string_list( + "unordered_unique_levels", + 
object$train_set_metadata$unordered_unique_levels ) - jsonobj$add_boolean( - "include_variance_forest", - object$model_params$include_variance_forest + } + + # Add global parameters + jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) + jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) + jsonobj$add_boolean("standardize", object$model_params$standardize) + jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) + jsonobj$add_boolean( + "sample_sigma2_global", + object$model_params$sample_sigma2_global + ) + jsonobj$add_boolean( + "sample_sigma2_leaf_mu", + object$model_params$sample_sigma2_leaf_mu + ) + jsonobj$add_boolean( + "sample_sigma2_leaf_tau", + object$model_params$sample_sigma2_leaf_tau + ) + jsonobj$add_boolean( + "include_variance_forest", + object$model_params$include_variance_forest + ) + jsonobj$add_string( + "propensity_covariate", + object$model_params$propensity_covariate + ) + jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) + jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) + jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) + jsonobj$add_boolean( + "multivariate_treatment", + object$model_params$multivariate_treatment + ) + jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) + jsonobj$add_boolean( + "internal_propensity_model", + object$model_params$internal_propensity_model + ) + jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) + jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) + jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) + jsonobj$add_scalar("num_samples", object$model_params$num_samples) + jsonobj$add_scalar("keep_every", object$model_params$keep_every) + jsonobj$add_scalar("num_chains", object$model_params$num_chains) + jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) + jsonobj$add_boolean( + "probit_outcome_model", + object$model_params$probit_outcome_model + ) + if (object$model_params$sample_sigma2_global) { + jsonobj$add_vector( + "sigma2_global_samples", + object$sigma2_global_samples, + "parameters" ) - jsonobj$add_string( - "propensity_covariate", - object$model_params$propensity_covariate + } + if (object$model_params$sample_sigma2_leaf_mu) { + jsonobj$add_vector( + "sigma2_leaf_mu_samples", + object$sigma2_leaf_mu_samples, + "parameters" ) - jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) - jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) - jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) - jsonobj$add_boolean( - "multivariate_treatment", - object$model_params$multivariate_treatment + } + if (object$model_params$sample_sigma2_leaf_tau) { + jsonobj$add_vector( + "sigma2_leaf_tau_samples", + object$sigma2_leaf_tau_samples, + "parameters" ) - jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) - jsonobj$add_boolean( - "internal_propensity_model", - object$model_params$internal_propensity_model + } + if (object$model_params$adaptive_coding) { + jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") + jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters") + } + + # Add random effects (if present) + if (object$model_params$has_rfx) { + jsonobj$add_random_effects(object$rfx_samples) + jsonobj$add_string_vector( + "rfx_unique_group_ids", + object$rfx_unique_group_ids ) - jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) - 
jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) - jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) - jsonobj$add_scalar("num_samples", object$model_params$num_samples) - jsonobj$add_scalar("keep_every", object$model_params$keep_every) - jsonobj$add_scalar("num_chains", object$model_params$num_chains) - jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) - jsonobj$add_boolean( - "probit_outcome_model", - object$model_params$probit_outcome_model + } + jsonobj$add_string( + "rfx_model_spec", + object$model_params$rfx_model_spec + ) + + # Add propensity model (if it exists) + if (object$model_params$internal_propensity_model) { + bart_propensity_string <- saveBARTModelToJsonString( + object$bart_propensity_model ) - if (object$model_params$sample_sigma2_global) { - jsonobj$add_vector( - "sigma2_global_samples", - object$sigma2_global_samples, - "parameters" - ) - } - if (object$model_params$sample_sigma2_leaf_mu) { - jsonobj$add_vector( - "sigma2_leaf_mu_samples", - object$sigma2_leaf_mu_samples, - "parameters" - ) - } - if (object$model_params$sample_sigma2_leaf_tau) { - jsonobj$add_vector( - "sigma2_leaf_tau_samples", - object$sigma2_leaf_tau_samples, - "parameters" - ) - } - if (object$model_params$adaptive_coding) { - jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") - jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters") - } + jsonobj$add_string("bart_propensity_model", bart_propensity_string) + } - # Add random effects (if present) - if (object$model_params$has_rfx) { - jsonobj$add_random_effects(object$rfx_samples) - jsonobj$add_string_vector( - "rfx_unique_group_ids", - object$rfx_unique_group_ids - ) - } - - # Add propensity model (if it exists) - if (object$model_params$internal_propensity_model) { - bart_propensity_string <- saveBARTModelToJsonString( - object$bart_propensity_model - ) - jsonobj$add_string("bart_propensity_model", bart_propensity_string) - } - - # Add covariate preprocessor metadata - preprocessor_metadata_string <- savePreprocessorToJsonString( - object$train_set_metadata - ) - jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) + # Add covariate preprocessor metadata + preprocessor_metadata_string <- savePreprocessorToJsonString( + object$train_set_metadata + ) + jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) - return(jsonobj) + return(jsonobj) } #' Convert the persistent aspects of a BCF model to (in-memory) JSON and save to a file @@ -3199,11 +3514,11 @@ saveBCFModelToJson <- function(object) { #' saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) #' unlink(tmpjson) saveBCFModelToJsonFile <- function(object, filename) { - # Convert to Json - jsonobj <- saveBCFModelToJson(object) + # Convert to Json + jsonobj <- saveBCFModelToJson(object) - # Save to file - jsonobj$save_file(filename) + # Save to file + jsonobj$save_file(filename) } #' Convert the persistent aspects of a BCF model to (in-memory) JSON string @@ -3279,11 +3594,11 @@ saveBCFModelToJsonFile <- function(object, filename) { #' treatment_effect_forest_params = tau_params) #' saveBCFModelToJsonString(bcf_model) saveBCFModelToJsonString <- function(object) { - # Convert to Json - jsonobj <- saveBCFModelToJson(object) + # Convert to Json + jsonobj <- saveBCFModelToJson(object) - # Dump to string - return(jsonobj$return_json_string()) + # Dump to string + return(jsonobj$return_json_string()) } #' Convert an (in-memory) JSON representation of a BCF model to a BCF model object 
@@ -3362,161 +3677,164 @@ saveBCFModelToJsonString <- function(object) { #' bcf_json <- saveBCFModelToJson(bcf_model) #' bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) createBCFModelFromJson <- function(json_object) { - # Initialize the BCF model - output <- list() - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") - output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") - include_variance_forest <- json_object$get_boolean( - "include_variance_forest" - ) - if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerJson( - json_object, - "forest_2" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( - "num_numeric_vars" + # Initialize the BCF model + output <- list() + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") + output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") + include_variance_forest <- json_object$get_boolean( + "include_variance_forest" + ) + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerJson( + json_object, + "forest_2" ) - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( - "num_ordered_cat_vars" - ) - train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( - "num_unordered_cat_vars" - ) - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( - "numeric_vars" - ) - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] - ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] - ) - } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object$get_boolean("standardize") - model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") - model_params[["sample_sigma2_global"]] <- json_object$get_boolean( - "sample_sigma2_global" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( + "num_ordered_cat_vars" + ) + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( + "num_unordered_cat_vars" + ) + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" ) - model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean( - "sample_sigma2_leaf_mu" + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object$get_string_vector("ordered_cat_vars") + 
train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) - model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean( - "sample_sigma2_leaf_tau" + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] ) - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object$get_string( - "propensity_covariate" + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["standardize"]] <- json_object$get_boolean("standardize") + model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean( + "sample_sigma2_leaf_tau" + ) + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object$get_string( + "propensity_covariate" + ) + model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + model_params[["adaptive_coding"]] <- json_object$get_boolean( + "adaptive_coding" + ) + model_params[["multivariate_treatment"]] <- json_object$get_boolean( + "multivariate_treatment" + ) + model_params[["internal_propensity_model"]] <- json_object$get_boolean( + "internal_propensity_model" + ) + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) + model_params[["rfx_model_spec"]] <- json_object$get_string( + "rfx_model_spec" + ) + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") - model_params[["adaptive_coding"]] <- json_object$get_boolean( - "adaptive_coding" + } + if (model_params[["sample_sigma2_leaf_mu"]]) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" ) - model_params[["multivariate_treatment"]] <- json_object$get_boolean( - "multivariate_treatment" + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + 
output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" ) - model_params[["internal_propensity_model"]] <- json_object$get_boolean( - "internal_propensity_model" + } + if (model_params[["adaptive_coding"]]) { + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" ) - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - model_params[["probit_outcome_model"]] <- json_object$get_boolean( - "probit_outcome_model" + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" ) - output[["model_params"]] <- model_params + } - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } - if (model_params[["sample_sigma2_leaf_mu"]]) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } - if (model_params[["adaptive_coding"]]) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( - "rfx_unique_group_ids" - ) - output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) - } - - # Unpack propensity model (if it exists) - if (model_params[["internal_propensity_model"]]) { - bart_propensity_string <- json_object$get_string( - "bart_propensity_model" - ) - output[["bart_propensity_model"]] <- createBARTModelFromJsonString( - bart_propensity_string - ) - } + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( + "rfx_unique_group_ids" + ) + output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) + } - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" + # Unpack propensity model (if it exists) + if (model_params[["internal_propensity_model"]]) { + bart_propensity_string <- json_object$get_string( + "bart_propensity_model" ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + output[["bart_propensity_model"]] <- createBARTModelFromJsonString( + bart_propensity_string ) - - class(output) <- "bcfmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bcfmodel" + return(output) } #' Convert a JSON file containing sample information on a trained BCF model @@ -3597,13 +3915,13 @@ createBCFModelFromJson <- function(json_object) { #' bcf_model_roundtrip <- createBCFModelFromJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createBCFModelFromJsonFile <- function(json_filename) 
{ - # Load a `CppJson` object from file - bcf_json <- createCppJsonFile(json_filename) + # Load a `CppJson` object from file + bcf_json <- createCppJsonFile(json_filename) - # Create and return the BCF object - bcf_object <- createBCFModelFromJson(bcf_json) + # Create and return the BCF object + bcf_object <- createBCFModelFromJson(bcf_json) - return(bcf_object) + return(bcf_object) } #' Convert a JSON string containing sample information on a trained BCF model @@ -3678,13 +3996,13 @@ createBCFModelFromJsonFile <- function(json_filename) { #' bcf_json <- saveBCFModelToJsonString(bcf_model) #' bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) createBCFModelFromJsonString <- function(json_string) { - # Load a `CppJson` object from string - bcf_json <- createCppJsonString(json_string) + # Load a `CppJson` object from string + bcf_json <- createCppJsonString(json_string) - # Create and return the BCF object - bcf_object <- createBCFModelFromJson(bcf_json) + # Create and return the BCF object + bcf_object <- createBCFModelFromJson(bcf_json) - return(bcf_object) + return(bcf_object) } #' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object @@ -3759,271 +4077,274 @@ createBCFModelFromJsonString <- function(json_string) { #' bcf_json_list <- list(saveBCFModelToJson(bcf_model)) #' bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) createBCFModelFromCombinedJson <- function(json_object_list) { - # Initialize the BCF model - output <- list() - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - output[["forests_tau"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" + # Initialize the BCF model + output <- list() + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) + output[["forests_tau"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_2" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" - ) - if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_2" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + 
train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_tau" + ) + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object_default$get_string( + "propensity_covariate" + ) + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["adaptive_coding"]] <- json_object_default$get_boolean( + "adaptive_coding" + ) + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- 
json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["initial_sigma2"]] <- json_object_default$get_scalar( - "initial_sigma2" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_mu" - ) - model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_tau" - ) - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object_default$get_string( - "propensity_covariate" - ) - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["adaptive_coding"]] <- json_object_default$get_boolean( - "adaptive_coding" - ) - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( - "multivariate_treatment" - ) - model_params[[ - "internal_propensity_model" - ]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf_mu"]]) { for (i in 1:length(json_object_list)) { - json_object <- 
json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_mu"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_mu_samples"]] <- c( - output[["sigma2_leaf_mu_samples"]], - json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_mu_samples"]] <- c( + output[["sigma2_leaf_mu_samples"]], + json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + ) + } } - if (model_params[["adaptive_coding"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) - } else { - output[["b_1_samples"]] <- c( - output[["b_1_samples"]], - json_object$get_vector("b_1_samples", "parameters") - ) - output[["b_0_samples"]] <- c( - output[["b_0_samples"]], - json_object$get_vector("b_0_samples", "parameters") - ) - } - } + } + if 
(model_params[["sample_sigma2_leaf_tau"]]) {
+    for (i in 1:length(json_object_list)) {
+      json_object <- json_object_list[[i]]
+      if (i == 1) {
+        output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector(
+          "sigma2_leaf_tau_samples",
+          "parameters"
+        )
+      } else {
+        output[["sigma2_leaf_tau_samples"]] <- c(
+          output[["sigma2_leaf_tau_samples"]],
+          json_object$get_vector(
+            "sigma2_leaf_tau_samples",
+            "parameters"
+          )
+        )
+      }
+    }
+  }
+  if (model_params[["adaptive_coding"]]) {
+    for (i in 1:length(json_object_list)) {
+      json_object <- json_object_list[[i]]
+      if (i == 1) {
+        output[["b_1_samples"]] <- json_object$get_vector(
+          "b_1_samples",
+          "parameters"
+        )
+        output[["b_0_samples"]] <- json_object$get_vector(
+          "b_0_samples",
+          "parameters"
+        )
+      } else {
+        output[["b_1_samples"]] <- c(
+          output[["b_1_samples"]],
+          json_object$get_vector("b_1_samples", "parameters")
+        )
+        output[["b_0_samples"]] <- c(
+          output[["b_0_samples"]],
+          json_object$get_vector("b_0_samples", "parameters")
+        )
+      }
+    }
+  }
+
+  # Unpack random effects
+  if (model_params[["has_rfx"]]) {
+    output[[
+      "rfx_unique_group_ids"
+    ]] <- json_object_default$get_string_vector("rfx_unique_group_ids")
+    output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(
+      json_object_list,
+      0
+    )
+  }
+
+  # Unpack covariate preprocessor
+  preprocessor_metadata_string <- json_object_default$get_string(
+    "preprocessor_metadata"
+  )
+  output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
+    preprocessor_metadata_string
+  )
+
+  class(output) <- "bcfmodel"
+  return(output)
}

#' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object
@@ -4098,284 +4419,287 @@ createBCFModelFromCombinedJson <- function(json_object_list) {
#' bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model))
#' bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list)
createBCFModelFromCombinedJsonString <- function(json_string_list) {
-    # Initialize the BCF model
-    output <- list()
-
-    # Convert JSON strings
-    json_object_list <- list()
-    for (i in 1:length(json_string_list)) {
-        json_string <- json_string_list[[i]]
-        json_object_list[[i]] <- createCppJsonString(json_string)
-        # Add runtime check for separately serialized propensity models
-        # We don't support merging BCF models with independent propensity models
-        # this way at the moment
-        if (json_object_list[[i]]$get_boolean("internal_propensity_model")) {
-            stop(
-                "Combining separate BCF models with cached internal propensity models is currently
unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling." - ) - } - } - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - output[["forests_tau"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" + # Initialize the BCF model + output <- list() + + # Convert JSON strings + json_object_list <- list() + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + json_object_list[[i]] <- createCppJsonString(json_string) + # Add runtime check for separately serialized propensity models + # We don't support merging BCF models with independent propensity models + # this way at the moment + if (json_object_list[[i]]$get_boolean("internal_propensity_model")) { + stop( + "Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling." + ) + } + } + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) + output[["forests_tau"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_2" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" - ) - if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_2" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if 
(train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_tau" + ) + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object_default$get_string( + "propensity_covariate" + ) + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + model_params[["adaptive_coding"]] <- json_object_default$get_boolean( + "adaptive_coding" + ) + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] 
<- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["initial_sigma2"]] <- json_object_default$get_scalar( - "initial_sigma2" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_mu" - ) - model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_tau" - ) - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object_default$get_string( - "propensity_covariate" - ) - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( - "multivariate_treatment" - ) - model_params[["adaptive_coding"]] <- json_object_default$get_boolean( - "adaptive_coding" - ) - model_params[[ - "internal_propensity_model" - ]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf_mu"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - 
model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_mu"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_mu_samples"]] <- c( - output[["sigma2_leaf_mu_samples"]], - json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_mu_samples"]] <- c( + output[["sigma2_leaf_mu_samples"]], + json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + ) + } } - if (model_params[["adaptive_coding"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) - } else { - output[["b_1_samples"]] <- c( - output[["b_1_samples"]], - json_object$get_vector("b_1_samples", "parameters") - ) - output[["b_0_samples"]] <- c( - output[["b_0_samples"]], - json_object$get_vector("b_0_samples", "parameters") - ) - } - } + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) + } } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 
+  }
+  if (model_params[["adaptive_coding"]]) {
+    for (i in 1:length(json_object_list)) {
+      json_object <- json_object_list[[i]]
+      if (i == 1) {
+        output[["b_1_samples"]] <- json_object$get_vector(
+          "b_1_samples",
+          "parameters"
+        )
+        output[["b_0_samples"]] <- json_object$get_vector(
+          "b_0_samples",
+          "parameters"
+        )
+      } else {
+        output[["b_1_samples"]] <- c(
+          output[["b_1_samples"]],
+          json_object$get_vector("b_1_samples", "parameters")
+        )
+        output[["b_0_samples"]] <- c(
+          output[["b_0_samples"]],
+          json_object$get_vector("b_0_samples", "parameters")
+        )
+      }
+    }
+  }
+
+  # Unpack random effects
+  if (model_params[["has_rfx"]]) {
+    output[[
+      "rfx_unique_group_ids"
+    ]] <- json_object_default$get_string_vector("rfx_unique_group_ids")
+    output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson(
+      json_object_list,
+      0
+    )
+  }
+
+  # Unpack covariate preprocessor
+  preprocessor_metadata_string <- json_object_default$get_string(
+    "preprocessor_metadata"
+  )
+  output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
+    preprocessor_metadata_string
+  )
+
+  class(output) <- "bcfmodel"
+  return(output)
}
diff --git a/R/calibration.R b/R/calibration.R
index cbcd293e..a82f84f4 100644
--- a/R/calibration.R
+++ b/R/calibration.R
@@ -22,25 +22,25 @@
#' sigma2hat <- mean(resid(lm(y~X))^2)
#' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat)
calibrateInverseGammaErrorVariance <- function(
-    y,
-    X,
-    W = NULL,
-    nu = 3,
-    quant = 0.9,
-    standardize = TRUE
+  y,
+  X,
+  W = NULL,
+  nu = 3,
+  quant = 0.9,
+  standardize = TRUE
) {
-    # Compute regression basis
-    if (!is.null(W)) {
-        basis <- cbind(X, W)
-    } else {
-        basis <- X
-    }
-    # Standardize outcome if requested
-    if (standardize) {
-        y <- (y - mean(y)) / sd(y)
-    }
-    # Compute the "regression-based" overestimate of sigma^2
-    sigma2hat <- mean(resid(lm(y ~ basis))^2)
-    # Calibrate lambda based on the implied quantile of sigma2hat
-    return((sigma2hat * qgamma(1 - quant, nu)) / nu)
+  # Compute regression basis
+  if (!is.null(W)) {
+    basis <- cbind(X, W)
+  } else {
+    basis <- X
+  }
+  # Standardize outcome if requested
+  if (standardize) {
+    y <- (y - mean(y)) / sd(y)
+  }
+  # Compute the "regression-based" overestimate of sigma^2
+  sigma2hat <- mean(resid(lm(y ~ basis))^2)
+  # Calibrate lambda based on the implied quantile of sigma2hat
+  return((sigma2hat * qgamma(1 - quant, nu)) / nu)
}
diff --git a/R/config.R b/R/config.R
index bf4d51bc..8110a9f5 100644
--- a/R/config.R
+++ b/R/config.R
@@ -10,423 +10,423 @@
#' forest model they wish to run.
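#' @examples
#' # A usage sketch for illustration, relying only on the constructor and
#' # accessors defined below: configure a 50-tree forest over 10 numeric
#' # features, then update and query the tree prior's root split probability
#' config <- ForestModelConfig$new(
#'   feature_types = rep(0, 10),
#'   num_trees = 50,
#'   num_features = 10,
#'   num_observations = 100,
#'   variable_weights = rep(1 / 10, 10)
#' )
#' config$update_alpha(0.9)
#' config$get_alpha()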
ForestModelConfig <- R6::R6Class( - classname = "ForestModelConfig", - cloneable = FALSE, - public = list( - #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - feature_types = NULL, - - #' @field sweep_update_indices Vector of trees to update in a sweep - sweep_update_indices = NULL, - - #' @field num_trees Number of trees in the forest being sampled - num_trees = NULL, - - #' @field num_features Number of features in training dataset - num_features = NULL, - - #' @field num_observations Number of observations in training dataset - num_observations = NULL, - - #' @field leaf_dimension Dimension of the leaf model - leaf_dimension = NULL, - - #' @field alpha Root node split probability in tree prior - alpha = NULL, - - #' @field beta Depth prior penalty in tree prior - beta = NULL, - - #' @field min_samples_leaf Minimum number of samples in a tree leaf - min_samples_leaf = NULL, - - #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. - max_depth = NULL, - - #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) - leaf_model_type = NULL, - - #' @field leaf_model_scale Scale parameter used in Gaussian leaf models - leaf_model_scale = NULL, - - #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset - variable_weights = NULL, - - #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) - variance_forest_shape = NULL, - - #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) - variance_forest_scale = NULL, - - #' @field cutpoint_grid_size Number of unique cutpoints to consider - cutpoint_grid_size = NULL, - - #' @field num_features_subsample Number of features to subsample for the GFR algorithm - num_features_subsample = NULL, - - #' Create a new ForestModelConfig object. - #' - #' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep - #' @param num_trees Number of trees in the forest being sampled - #' @param num_features Number of features in training dataset - #' @param num_observations Number of observations in training dataset - #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset - #' @param leaf_dimension Dimension of the leaf model (default: `1`) - #' @param alpha Root node split probability in tree prior (default: `0.95`) - #' @param beta Depth prior penalty in tree prior (default: `2.0`) - #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) - #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. - #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `0`. - #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). 
Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. - #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. - #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. - #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) - #' @param num_features_subsample Number of features to subsample for the GFR algorithm - #' - #' @return A new ForestModelConfig object. - initialize = function( - feature_types = NULL, - sweep_update_indices = NULL, - num_trees = NULL, - num_features = NULL, - num_observations = NULL, - variable_weights = NULL, - leaf_dimension = 1, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = -1, - leaf_model_type = 1, - leaf_model_scale = NULL, - variance_forest_shape = 1.0, - variance_forest_scale = 1.0, - cutpoint_grid_size = 100, - num_features_subsample = NULL - ) { - if (is.null(feature_types)) { - if (is.null(num_features)) { - stop( - "Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object." - ) - } - warning( - "`feature_types` not provided, will be assumed to be numeric" - ) - feature_types <- rep(0, num_features) - } else { - if (is.null(num_features)) { - num_features <- length(feature_types) - } - } - if (is.null(variable_weights)) { - warning( - "`variable_weights` not provided, will be assumed to be equal-weighted" - ) - variable_weights <- rep(1 / num_features, num_features) - } - if (is.null(num_trees)) { - stop("num_trees must be provided") - } - if (!is.null(sweep_update_indices)) { - stopifnot(min(sweep_update_indices) >= 0) - stopifnot(max(sweep_update_indices) < num_trees) - } - if (is.null(num_observations)) { - stop("num_observations must be provided") - } - if (num_features != length(feature_types)) { - stop("`feature_types` must have `num_features` total elements") - } - if (num_features != length(variable_weights)) { - stop( - "`variable_weights` must have `num_features` total elements" - ) - } - self$feature_types <- feature_types - self$sweep_update_indices <- sweep_update_indices - self$variable_weights <- variable_weights - self$num_trees <- num_trees - self$num_features <- num_features - self$num_observations <- num_observations - self$leaf_dimension <- leaf_dimension - self$alpha <- alpha - self$beta <- beta - self$min_samples_leaf <- min_samples_leaf - self$max_depth <- max_depth - self$variance_forest_shape <- variance_forest_shape - self$variance_forest_scale <- variance_forest_scale - self$cutpoint_grid_size <- cutpoint_grid_size - if (is.null(num_features_subsample)) { - num_features_subsample <- num_features - } - if (num_features_subsample > num_features) { - stop( - "`num_features_subsample` cannot be larger than `num_features`" - ) - } - if (num_features_subsample <= 0) { - stop("`num_features_subsample` must be at least 1") - } - self$num_features_subsample <- num_features_subsample - - if (!(as.integer(leaf_model_type) == leaf_model_type)) { - stop("`leaf_model_type` must be an integer between 0 and 3") - if ((leaf_model_type < 0) | (leaf_model_type > 3)) { - stop("`leaf_model_type` must be an integer between 0 and 3") - } - } - self$leaf_model_type <- leaf_model_type - - if (is.null(leaf_model_scale)) { - self$leaf_model_scale <- diag(1 / num_trees, 
leaf_dimension) - } else if (is.matrix(leaf_model_scale)) { - if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { - stop("`leaf_model_scale` must be a square matrix") - } - if (ncol(leaf_model_scale) != leaf_dimension) { - stop( - "`leaf_model_scale` must have `leaf_dimension` rows and columns" - ) - } - self$leaf_model_scale <- leaf_model_scale - } else { - if (leaf_model_scale <= 0) { - stop( - "`leaf_model_scale` must be positive, if provided as scalar" - ) - } - self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) - } - }, - - #' @description - #' Update feature types - #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - update_feature_types = function(feature_types) { - stopifnot(length(feature_types) == self$num_features) - self$feature_types <- feature_types - }, - - #' @description - #' Update sweep update indices - #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep - update_sweep_indices = function(sweep_update_indices) { - if (!is.null(sweep_update_indices)) { - stopifnot(min(sweep_update_indices) >= 0) - stopifnot(max(sweep_update_indices) < self$num_trees) - } - self$sweep_update_indices <- sweep_update_indices - }, - - #' @description - #' Update variable weights - #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset - update_variable_weights = function(variable_weights) { - stopifnot(length(variable_weights) == self$num_features) - self$variable_weights <- variable_weights - }, - - #' @description - #' Update root node split probability in tree prior - #' @param alpha Root node split probability in tree prior - update_alpha = function(alpha) { - self$alpha <- alpha - }, - - #' @description - #' Update depth prior penalty in tree prior - #' @param beta Depth prior penalty in tree prior - update_beta = function(beta) { - self$beta <- beta - }, - - #' @description - #' Update minimum number of samples per leaf node in the tree prior - #' @param min_samples_leaf Minimum number of samples in a tree leaf - update_min_samples_leaf = function(min_samples_leaf) { - self$min_samples_leaf <- min_samples_leaf - }, - - #' @description - #' Update max depth in the tree prior - #' @param max_depth Maximum depth of any tree in the ensemble in the model - update_max_depth = function(max_depth) { - self$max_depth <- max_depth - }, - - #' @description - #' Update scale parameter used in Gaussian leaf models - #' @param leaf_model_scale Scale parameter used in Gaussian leaf models - update_leaf_model_scale = function(leaf_model_scale) { - if (is.matrix(leaf_model_scale)) { - if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { - stop("`leaf_model_scale` must be a square matrix") - } - if (ncol(leaf_model_scale) != self$leaf_dimension) { - stop( - "`leaf_model_scale` must have `leaf_dimension` rows and columns" - ) - } - self$leaf_model_scale <- leaf_model_scale - } else { - if (leaf_model_scale <= 0) { - stop( - "`leaf_model_scale` must be positive, if provided as scalar" - ) - } - self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension) - } - }, - - #' @description - #' Update shape parameter for IG leaf models - #' @param variance_forest_shape Shape parameter for IG leaf models - update_variance_forest_shape = function(variance_forest_shape) { - self$variance_forest_shape <- variance_forest_shape - }, - - #' @description - #' Update scale parameter for IG leaf models - #' @param 
variance_forest_scale Scale parameter for IG leaf models - update_variance_forest_scale = function(variance_forest_scale) { - self$variance_forest_scale <- variance_forest_scale - }, - - #' @description - #' Update number of unique cutpoints to consider - #' @param cutpoint_grid_size Number of unique cutpoints to consider - update_cutpoint_grid_size = function(cutpoint_grid_size) { - self$cutpoint_grid_size <- cutpoint_grid_size - }, - - #' @description - #' Update number of features to subsample for the GFR algorithm - #' @param num_features_subsample Number of features to subsample for the GFR algorithm - update_num_features_subsample = function(num_features_subsample) { - if (num_features_subsample > self$num_features) { - stop( - "`num_features_subsample` cannot be larger than `num_features`" - ) - } - if (num_features_subsample <= 0) { - stop("`num_features_subsample` must at least 1") - } - self$num_features_subsample <- num_features_subsample - }, - - #' @description - #' Query feature types for this ForestModelConfig object - #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - get_feature_types = function() { - return(self$feature_types) - }, - - #' @description - #' Query sweep update indices for this ForestModelConfig object - #' @returns Vector of (0-indexed) indices of trees to update in a sweep - get_sweep_indices = function() { - return(self$sweep_update_indices) - }, - - #' @description - #' Query variable weights for this ForestModelConfig object - #' @returns Vector specifying sampling probability for all p covariates in ForestDataset - get_variable_weights = function() { - return(self$variable_weights) - }, - - #' @description - #' Query number of trees - #' @returns Number of trees in a forest - get_num_trees = function() { - return(self$num_trees) - }, - - #' @description - #' Query number of features - #' @returns Number of features in a forest model training set - get_num_features = function() { - return(self$num_features) - }, - - #' @description - #' Query number of observations - #' @returns Number of observations in a forest model training set - get_num_observations = function() { - return(self$num_observations) - }, - - #' @description - #' Query root node split probability in tree prior for this ForestModelConfig object - #' @returns Root node split probability in tree prior - get_alpha = function() { - return(self$alpha) - }, - - #' @description - #' Query depth prior penalty in tree prior for this ForestModelConfig object - #' @returns Depth prior penalty in tree prior - get_beta = function() { - return(self$beta) - }, - - #' @description - #' Query root node split probability in tree prior for this ForestModelConfig object - #' @returns Minimum number of samples in a tree leaf - get_min_samples_leaf = function() { - return(self$min_samples_leaf) - }, - - #' @description - #' Query root node split probability in tree prior for this ForestModelConfig object - #' @returns Maximum depth of any tree in the ensemble in the model - get_max_depth = function() { - return(self$max_depth) - }, - - #' @description - #' Query (integer-coded) type of leaf model - #' @returns Integer coded leaf model type - get_leaf_model_type = function() { - return(self$leaf_model_type) - }, - - #' @description - #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object - #' @returns Scale parameter used in Gaussian leaf models - get_leaf_model_scale = function() { - 
return(self$leaf_model_scale) - }, - - #' @description - #' Query shape parameter for IG leaf models for this ForestModelConfig object - #' @returns Shape parameter for IG leaf models - get_variance_forest_shape = function() { - return(self$variance_forest_shape) - }, - - #' @description - #' Query scale parameter for IG leaf models for this ForestModelConfig object - #' @returns Scale parameter for IG leaf models - get_variance_forest_scale = function() { - return(self$variance_forest_scale) - }, - - #' @description - #' Query number of unique cutpoints to consider for this ForestModelConfig object - #' @returns Number of unique cutpoints to consider - get_cutpoint_grid_size = function() { - return(self$cutpoint_grid_size) - }, - - #' @description - #' Query number of features to subsample for the GFR algorithm - #' @returns Number of features to subsample for the GFR algorithm - get_num_features_subsample = function() { - return(self$num_features_subsample) + classname = "ForestModelConfig", + cloneable = FALSE, + public = list( + #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + feature_types = NULL, + + #' @field sweep_update_indices Vector of trees to update in a sweep + sweep_update_indices = NULL, + + #' @field num_trees Number of trees in the forest being sampled + num_trees = NULL, + + #' @field num_features Number of features in training dataset + num_features = NULL, + + #' @field num_observations Number of observations in training dataset + num_observations = NULL, + + #' @field leaf_dimension Dimension of the leaf model + leaf_dimension = NULL, + + #' @field alpha Root node split probability in tree prior + alpha = NULL, + + #' @field beta Depth prior penalty in tree prior + beta = NULL, + + #' @field min_samples_leaf Minimum number of samples in a tree leaf + min_samples_leaf = NULL, + + #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. + max_depth = NULL, + + #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) + leaf_model_type = NULL, + + #' @field leaf_model_scale Scale parameter used in Gaussian leaf models + leaf_model_scale = NULL, + + #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + variable_weights = NULL, + + #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_shape = NULL, + + #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_scale = NULL, + + #' @field cutpoint_grid_size Number of unique cutpoints to consider + cutpoint_grid_size = NULL, + + #' @field num_features_subsample Number of features to subsample for the GFR algorithm + num_features_subsample = NULL, + + #' Create a new ForestModelConfig object. 
+    #'
+    #' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
+    #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
+    #' @param num_trees Number of trees in the forest being sampled
+    #' @param num_features Number of features in training dataset
+    #' @param num_observations Number of observations in training dataset
+    #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
+    #' @param leaf_dimension Dimension of the leaf model (default: `1`)
+    #' @param alpha Root node split probability in tree prior (default: `0.95`)
+    #' @param beta Depth prior penalty in tree prior (default: `2.0`)
+    #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`)
+    #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`.
+    #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `1`.
+    #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models.
+    #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
+    #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
+    #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`)
+    #' @param num_features_subsample Number of features to subsample for the GFR algorithm
+    #'
+    #' @return A new ForestModelConfig object.
+    initialize = function(
+      feature_types = NULL,
+      sweep_update_indices = NULL,
+      num_trees = NULL,
+      num_features = NULL,
+      num_observations = NULL,
+      variable_weights = NULL,
+      leaf_dimension = 1,
+      alpha = 0.95,
+      beta = 2.0,
+      min_samples_leaf = 5,
+      max_depth = -1,
+      leaf_model_type = 1,
+      leaf_model_scale = NULL,
+      variance_forest_shape = 1.0,
+      variance_forest_scale = 1.0,
+      cutpoint_grid_size = 100,
+      num_features_subsample = NULL
+    ) {
+      if (is.null(feature_types)) {
+        if (is.null(num_features)) {
+          stop(
+            "Neither `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object."
+          )
+        }
+        warning(
+          "`feature_types` not provided, will be assumed to be numeric"
+        )
+        feature_types <- rep(0, num_features)
+      } else {
+        if (is.null(num_features)) {
+          num_features <- length(feature_types)
+        }
+      }
+      if (is.null(variable_weights)) {
+        warning(
+          "`variable_weights` not provided, will be assumed to be equal-weighted"
+        )
+        variable_weights <- rep(1 / num_features, num_features)
+      }
+      if (is.null(num_trees)) {
+        stop("num_trees must be provided")
+      }
+      if (!is.null(sweep_update_indices)) {
+        stopifnot(min(sweep_update_indices) >= 0)
+        stopifnot(max(sweep_update_indices) < num_trees)
+      }
+      if (is.null(num_observations)) {
+        stop("num_observations must be provided")
+      }
+      if (num_features != length(feature_types)) {
+        stop("`feature_types` must have `num_features` total elements")
+      }
+      if (num_features != length(variable_weights)) {
+        stop(
+          "`variable_weights` must have `num_features` total elements"
+        )
+      }
+      self$feature_types <- feature_types
+      self$sweep_update_indices <- sweep_update_indices
+      self$variable_weights <- variable_weights
+      self$num_trees <- num_trees
+      self$num_features <- num_features
+      self$num_observations <- num_observations
+      self$leaf_dimension <- leaf_dimension
+      self$alpha <- alpha
+      self$beta <- beta
+      self$min_samples_leaf <- min_samples_leaf
+      self$max_depth <- max_depth
+      self$variance_forest_shape <- variance_forest_shape
+      self$variance_forest_scale <- variance_forest_scale
+      self$cutpoint_grid_size <- cutpoint_grid_size
+      if (is.null(num_features_subsample)) {
+        num_features_subsample <- num_features
+      }
+      if (num_features_subsample > num_features) {
+        stop(
+          "`num_features_subsample` cannot be larger than `num_features`"
+        )
+      }
+      if (num_features_subsample <= 0) {
+        stop("`num_features_subsample` must be at least 1")
+      }
+      self$num_features_subsample <- num_features_subsample
+
+      if (!(as.integer(leaf_model_type) == leaf_model_type)) {
+        stop("`leaf_model_type` must be an integer between 0 and 3")
+      }
+      if ((leaf_model_type < 0) | (leaf_model_type > 3)) {
+        stop("`leaf_model_type` must be an integer between 0 and 3")
+      }
+      self$leaf_model_type <- leaf_model_type
+
+      if (is.null(leaf_model_scale)) {
+        self$leaf_model_scale <- diag(1 / num_trees, leaf_dimension)
+      } else if (is.matrix(leaf_model_scale)) {
+        if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) {
+          stop("`leaf_model_scale` must be a square matrix")
+        }
+        if (ncol(leaf_model_scale) != leaf_dimension) {
+          stop(
+            "`leaf_model_scale` must have `leaf_dimension` rows and columns"
+          )
+        }
+        self$leaf_model_scale <- leaf_model_scale
+      } else {
+        if (leaf_model_scale <= 0) {
+          stop(
+            "`leaf_model_scale` must be positive if provided as a scalar"
+          )
+        }
+        self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension)
+      }
+    },
+
+    #' @description
+    #' Update feature types
+    #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
+    update_feature_types = function(feature_types) {
+      stopifnot(length(feature_types) == self$num_features)
+      self$feature_types <- feature_types
+    },
+
+    #' @description
+    #' Update sweep update indices
+    #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep
+    update_sweep_indices = function(sweep_update_indices) {
+      if (!is.null(sweep_update_indices)) {
+        stopifnot(min(sweep_update_indices) >= 0)
+        stopifnot(max(sweep_update_indices) < self$num_trees)
+      }
+      self$sweep_update_indices <- sweep_update_indices
+    },
+
+    #' @description
+    #' Update variable weights
+    #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset
+    update_variable_weights = function(variable_weights) {
+      stopifnot(length(variable_weights) == self$num_features)
+      self$variable_weights <- variable_weights
+    },
+
+    #' @description
+    #' Update root node split probability in tree prior
+    #' @param alpha Root node split probability in tree prior
+    update_alpha = function(alpha) {
+      self$alpha <- alpha
+    },
+
+    #' @description
+    #' Update depth prior penalty in tree prior
+    #' @param beta Depth prior penalty in tree prior
+    update_beta = function(beta) {
+      self$beta <- beta
+    },
+
+    #' @description
+    #' Update minimum number of samples per leaf node in the tree prior
+    #' @param min_samples_leaf Minimum number of samples in a tree leaf
+    update_min_samples_leaf = function(min_samples_leaf) {
+      self$min_samples_leaf <- min_samples_leaf
+    },
+
+    #' @description
+    #' Update max depth in the tree prior
+    #' @param max_depth Maximum depth of any tree in the ensemble in the model
+    update_max_depth = function(max_depth) {
+      self$max_depth <- max_depth
+    },
+
+    #' @description
+    #' Update scale parameter used in Gaussian leaf models
+    #' @param leaf_model_scale Scale parameter used in Gaussian leaf models
+    update_leaf_model_scale = function(leaf_model_scale) {
+      if (is.matrix(leaf_model_scale)) {
+        if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) {
+          stop("`leaf_model_scale` must be a square matrix")
+        }
+        if (ncol(leaf_model_scale) != self$leaf_dimension) {
+          stop(
+            "`leaf_model_scale` must have `leaf_dimension` rows and columns"
+          )
+        }
+        self$leaf_model_scale <- leaf_model_scale
+      } else {
+        if (leaf_model_scale <= 0) {
+          stop(
+            "`leaf_model_scale` must be positive if provided as a scalar"
+          )
+        }
+        self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension)
+      }
+    },
+
+    #' @description
+    #' Update shape parameter for IG leaf models
+    #' @param variance_forest_shape Shape parameter for IG leaf models
+    update_variance_forest_shape = function(variance_forest_shape) {
+      self$variance_forest_shape <- variance_forest_shape
+    },
+
+    #' @description
+    #' Update scale parameter for IG leaf models
+    #' @param variance_forest_scale Scale parameter for IG leaf models
+    update_variance_forest_scale = function(variance_forest_scale) {
+      self$variance_forest_scale <- variance_forest_scale
+    },
+
+    #' @description
+    #' Update number of unique cutpoints to consider
+    #' @param cutpoint_grid_size Number of unique cutpoints to consider
+    update_cutpoint_grid_size = function(cutpoint_grid_size) {
+      self$cutpoint_grid_size <- cutpoint_grid_size
+    },
+
+    #' @description
+    #' Update number of features to subsample for the GFR algorithm
+    #' @param num_features_subsample Number of features to subsample for the GFR algorithm
+    update_num_features_subsample = function(num_features_subsample) {
+      if (num_features_subsample > self$num_features) {
+        stop(
+          "`num_features_subsample` cannot be larger than `num_features`"
+        )
+      }
+      if (num_features_subsample <= 0) {
+        stop("`num_features_subsample` must be at least 1")
+      }
+      self$num_features_subsample <- num_features_subsample
+    },
+
+    #' @description
+    #' Query feature types for this ForestModelConfig object
+    #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
+    get_feature_types = function() {
+      return(self$feature_types)
+    },
+
+    #' @description
+    #' Query sweep update indices for this ForestModelConfig object
+    #' @returns Vector of (0-indexed) indices of trees to update in a sweep
+    get_sweep_indices = function() {
+      return(self$sweep_update_indices)
+    },
+
+    #' @description
+    #' Query variable weights for this ForestModelConfig object
+    #' @returns Vector specifying sampling probability for all p covariates in ForestDataset
+    get_variable_weights = function() {
+      return(self$variable_weights)
+    },
+
+    #' @description
+    #' Query number of trees
+    #' @returns Number of trees in a forest
+    get_num_trees = function() {
+      return(self$num_trees)
+    },
+
+    #' @description
+    #' Query number of features
+    #' @returns Number of features in a forest model training set
+    get_num_features = function() {
+      return(self$num_features)
+    },
+
+    #' @description
+    #' Query number of observations
+    #' @returns Number of observations in a forest model training set
+    get_num_observations = function() {
+      return(self$num_observations)
+    },
+
+    #' @description
+    #' Query root node split probability in tree prior for this ForestModelConfig object
+    #' @returns Root node split probability in tree prior
+    get_alpha = function() {
+      return(self$alpha)
+    },
+
+    #' @description
+    #' Query depth prior penalty in tree prior for this ForestModelConfig object
+    #' @returns Depth prior penalty in tree prior
+    get_beta = function() {
+      return(self$beta)
+    },
+
+    #' @description
+    #' Query minimum number of samples per tree leaf in the tree prior for this ForestModelConfig object
+    #' @returns Minimum number of samples in a tree leaf
+    get_min_samples_leaf = function() {
+      return(self$min_samples_leaf)
+    },
+
+    #' @description
+    #' Query maximum tree depth in the tree prior for this ForestModelConfig object
+    #' @returns Maximum depth of any tree in the ensemble in the model
+    get_max_depth = function() {
+      return(self$max_depth)
+    },
+
+    #' @description
+    #' Query (integer-coded) type of leaf model
+    #' @returns Integer-coded leaf model type
+    get_leaf_model_type = function() {
+      return(self$leaf_model_type)
+    },
+
+    #' @description
+    #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object
+    #' @returns Scale parameter used in Gaussian leaf models
+    get_leaf_model_scale = function() {
+      return(self$leaf_model_scale)
+    },
+
+    #' @description
+    #' Query shape parameter for IG leaf models for this ForestModelConfig object
+    #' @returns Shape parameter for IG leaf models
+    get_variance_forest_shape = function() {
+      return(self$variance_forest_shape)
+    },
+
+    #' @description
+    #' Query scale parameter for IG leaf models for this ForestModelConfig object
+    #' @returns Scale parameter for IG leaf models
+    get_variance_forest_scale = function() {
+      return(self$variance_forest_scale)
+    },
+
+    #' @description
+    #' Query number of unique cutpoints to consider for this ForestModelConfig object
+    #' @returns Number of unique cutpoints to consider
+    get_cutpoint_grid_size = function() {
+      return(self$cutpoint_grid_size)
+    },
+
+    #' @description
+    #' Query number of features to subsample for the GFR algorithm
+    #' @returns Number of features to subsample for the GFR algorithm
+    get_num_features_subsample = function() {
+      return(self$num_features_subsample)
+    }
+  )
)

#' Object used to get / set global parameters and other global model
@@ -441,35 +441,35 @@ ForestModelConfig <- R6::R6Class(
#' of a model they wish to run.
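#' @examples
#' # A usage sketch for illustration, relying only on the methods defined
#' # below: construct, update, and query the global error variance parameter
#' global_config <- GlobalModelConfig$new(global_error_variance = 1.0)
#' global_config$update_global_error_variance(0.5)
#' global_config$get_global_error_variance()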
GlobalModelConfig <- R6::R6Class( - classname = "GlobalModelConfig", - cloneable = FALSE, - public = list( - #' @field global_error_variance Global error variance parameter - global_error_variance = NULL, - - #' Create a new GlobalModelConfig object. - #' - #' @param global_error_variance Global error variance parameter (default: `1.0`) - #' - #' @return A new GlobalModelConfig object. - initialize = function(global_error_variance = 1.0) { - self$global_error_variance <- global_error_variance - }, - - #' @description - #' Update global error variance parameter - #' @param global_error_variance Global error variance parameter - update_global_error_variance = function(global_error_variance) { - self$global_error_variance <- global_error_variance - }, - - #' @description - #' Query global error variance parameter for this GlobalModelConfig object - #' @returns Global error variance parameter - get_global_error_variance = function() { - return(self$global_error_variance) - } - ) + classname = "GlobalModelConfig", + cloneable = FALSE, + public = list( + #' @field global_error_variance Global error variance parameter + global_error_variance = NULL, + + #' Create a new GlobalModelConfig object. + #' + #' @param global_error_variance Global error variance parameter (default: `1.0`) + #' + #' @return A new GlobalModelConfig object. + initialize = function(global_error_variance = 1.0) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Update global error variance parameter + #' @param global_error_variance Global error variance parameter + update_global_error_variance = function(global_error_variance) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Query global error variance parameter for this GlobalModelConfig object + #' @returns Global error variance parameter + get_global_error_variance = function() { + return(self$global_error_variance) + } + ) ) #' Create a forest model config object @@ -497,45 +497,45 @@ GlobalModelConfig <- R6::R6Class( #' @examples #' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) createForestModelConfig <- function( - feature_types = NULL, - sweep_update_indices = NULL, - num_trees = NULL, - num_features = NULL, - num_observations = NULL, - variable_weights = NULL, - leaf_dimension = 1, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = -1, - leaf_model_type = 1, - leaf_model_scale = NULL, - variance_forest_shape = 1.0, - variance_forest_scale = 1.0, - cutpoint_grid_size = 100, - num_features_subsample = NULL + feature_types = NULL, + sweep_update_indices = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1.0, + variance_forest_scale = 1.0, + cutpoint_grid_size = 100, + num_features_subsample = NULL ) { - return(invisible( - (ForestModelConfig$new( - feature_types, - sweep_update_indices, - num_trees, - num_features, - num_observations, - variable_weights, - leaf_dimension, - alpha, - beta, - min_samples_leaf, - max_depth, - leaf_model_type, - leaf_model_scale, - variance_forest_shape, - variance_forest_scale, - cutpoint_grid_size, - num_features_subsample - )) + return(invisible( + (ForestModelConfig$new( + feature_types, + sweep_update_indices, + num_trees, + num_features, + num_observations, + variable_weights, + 
leaf_dimension, + alpha, + beta, + min_samples_leaf, + max_depth, + leaf_model_type, + leaf_model_scale, + variance_forest_shape, + variance_forest_scale, + cutpoint_grid_size, + num_features_subsample )) + )) } #' Create a global model config object @@ -547,5 +547,5 @@ createForestModelConfig <- function( #' @examples #' config <- createGlobalModelConfig(global_error_variance = 100) createGlobalModelConfig <- function(global_error_variance = 1.0) { - return(invisible((GlobalModelConfig$new(global_error_variance)))) + return(invisible((GlobalModelConfig$new(global_error_variance)))) } diff --git a/R/data.R b/R/data.R index 13cd714f..5fcdeb4d 100644 --- a/R/data.R +++ b/R/data.R @@ -6,110 +6,110 @@ #' weights are optional. ForestDataset <- R6::R6Class( - classname = "ForestDataset", - cloneable = FALSE, - public = list( - #' @field data_ptr External pointer to a C++ ForestDataset class - data_ptr = NULL, - - #' @description - #' Create a new ForestDataset object. - #' @param covariates Matrix of covariates - #' @param basis (Optional) Matrix of bases used to define a leaf regression - #' @param variance_weights (Optional) Vector of observation-specific variance weights - #' @return A new `ForestDataset` object. - initialize = function( - covariates, - basis = NULL, - variance_weights = NULL - ) { - self$data_ptr <- create_forest_dataset_cpp() - forest_dataset_add_covariates_cpp(self$data_ptr, covariates) - if (!is.null(basis)) { - forest_dataset_add_basis_cpp(self$data_ptr, basis) - } - if (!is.null(variance_weights)) { - forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) - } - }, - - #' @description - #' Update basis matrix in a dataset - #' @param basis Updated matrix of bases used to define a leaf regression - update_basis = function(basis) { - stopifnot(self$has_basis()) - forest_dataset_update_basis_cpp(self$data_ptr, basis) - }, - - #' @description - #' Update variance_weights in a dataset - #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights - #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. 
- update_variance_weights = function(variance_weights, exponentiate = F) { - stopifnot(self$has_variance_weights()) - forest_dataset_update_var_weights_cpp( - self$data_ptr, - variance_weights, - exponentiate - ) - }, - - #' @description - #' Return number of observations in a `ForestDataset` object - #' @return Observation count - num_observations = function() { - return(dataset_num_rows_cpp(self$data_ptr)) - }, - - #' @description - #' Return number of covariates in a `ForestDataset` object - #' @return Covariate count - num_covariates = function() { - return(dataset_num_covariates_cpp(self$data_ptr)) - }, - - #' @description - #' Return number of bases in a `ForestDataset` object - #' @return Basis count - num_basis = function() { - return(dataset_num_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return covariates as an R matrix - #' @return Covariate data - get_covariates = function() { - return(forest_dataset_get_covariates_cpp(self$data_ptr)) - }, - - #' @description - #' Return bases as an R matrix - #' @return Basis data - get_basis = function() { - return(forest_dataset_get_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return variance weights as an R vector - #' @return Variance weight data - get_variance_weights = function() { - return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has a basis matrix - #' @return True if basis matrix is loaded, false otherwise - has_basis = function() { - return(dataset_has_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has variance weights - #' @return True if variance weights are loaded, false otherwise - has_variance_weights = function() { - return(dataset_has_variance_weights_cpp(self$data_ptr)) - } - ) + classname = "ForestDataset", + cloneable = FALSE, + public = list( + #' @field data_ptr External pointer to a C++ ForestDataset class + data_ptr = NULL, + + #' @description + #' Create a new ForestDataset object. + #' @param covariates Matrix of covariates + #' @param basis (Optional) Matrix of bases used to define a leaf regression + #' @param variance_weights (Optional) Vector of observation-specific variance weights + #' @return A new `ForestDataset` object. + initialize = function( + covariates, + basis = NULL, + variance_weights = NULL + ) { + self$data_ptr <- create_forest_dataset_cpp() + forest_dataset_add_covariates_cpp(self$data_ptr, covariates) + if (!is.null(basis)) { + forest_dataset_add_basis_cpp(self$data_ptr, basis) + } + if (!is.null(variance_weights)) { + forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) + } + }, + + #' @description + #' Update basis matrix in a dataset + #' @param basis Updated matrix of bases used to define a leaf regression + update_basis = function(basis) { + stopifnot(self$has_basis()) + forest_dataset_update_basis_cpp(self$data_ptr, basis) + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. 
+ update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + forest_dataset_update_var_weights_cpp( + self$data_ptr, + variance_weights, + exponentiate + ) + }, + + #' @description + #' Return number of observations in a `ForestDataset` object + #' @return Observation count + num_observations = function() { + return(dataset_num_rows_cpp(self$data_ptr)) + }, + + #' @description + #' Return number of covariates in a `ForestDataset` object + #' @return Covariate count + num_covariates = function() { + return(dataset_num_covariates_cpp(self$data_ptr)) + }, + + #' @description + #' Return number of bases in a `ForestDataset` object + #' @return Basis count + num_basis = function() { + return(dataset_num_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return covariates as an R matrix + #' @return Covariate data + get_covariates = function() { + return(forest_dataset_get_covariates_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(forest_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has a basis matrix + #' @return True if basis matrix is loaded, false otherwise + has_basis = function() { + return(dataset_has_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has variance weights + #' @return True if variance weights are loaded, false otherwise + has_variance_weights = function() { + return(dataset_has_variance_weights_cpp(self$data_ptr)) + } + ) ) #' Outcome / partial residual used to sample an additive model. @@ -123,90 +123,90 @@ ForestDataset <- R6::R6Class( #' (trees, group random effects, etc...). Outcome <- R6::R6Class( - classname = "Outcome", - cloneable = FALSE, - public = list( - #' @field data_ptr External pointer to a C++ Outcome class - data_ptr = NULL, - - #' @description - #' Create a new Outcome object. - #' @param outcome Vector of outcome values - #' @return A new `Outcome` object. - initialize = function(outcome) { - self$data_ptr <- create_column_vector_cpp(outcome) - }, - - #' @description - #' Extract raw data in R from the underlying C++ object - #' @return R vector containing (copy of) the values in `Outcome` object - get_data = function() { - return(get_residual_cpp(self$data_ptr)) - }, - - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector` - #' @param update_vector Vector to be added to outcome - #' @return None - add_vector = function(update_vector) { - if (!is.numeric(update_vector)) { - stop("update_vector must be a numeric vector or 2d matrix") - } else { - dim_vec <- dim(update_vector) - if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) { - stop( - "if update_vector is provided as a matrix, it must be 2d" - ) - } - update_vector <- as.numeric(update_vector) - } - } - add_to_column_vector_cpp(self$data_ptr, update_vector) - }, - - #' @description - #' Update the current state of the outcome (i.e. 
partial residual) data by subtracting the values of `update_vector` - #' @param update_vector Vector to be subtracted from outcome - #' @return None - subtract_vector = function(update_vector) { - if (!is.numeric(update_vector)) { - stop("update_vector must be a numeric vector or 2d matrix") - } else { - dim_vec <- dim(update_vector) - if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) { - stop( - "if update_vector is provided as a matrix, it must be 2d" - ) - } - update_vector <- as.numeric(update_vector) - } - } - subtract_from_column_vector_cpp(self$data_ptr, update_vector) - }, - - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector` - #' @param new_vector Vector from which to overwrite the current data - #' @return None - update_data = function(new_vector) { - if (!is.numeric(new_vector)) { - stop("update_vector must be a numeric vector or 2d matrix") - } else { - dim_vec <- dim(new_vector) - if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) { - stop( - "if update_vector is provided as a matrix, it must be 2d" - ) - } - new_vector <- as.numeric(new_vector) - } - } - overwrite_column_vector_cpp(self$data_ptr, new_vector) + classname = "Outcome", + cloneable = FALSE, + public = list( + #' @field data_ptr External pointer to a C++ Outcome class + data_ptr = NULL, + + #' @description + #' Create a new Outcome object. + #' @param outcome Vector of outcome values + #' @return A new `Outcome` object. + initialize = function(outcome) { + self$data_ptr <- create_column_vector_cpp(outcome) + }, + + #' @description + #' Extract raw data in R from the underlying C++ object + #' @return R vector containing (copy of) the values in `Outcome` object + get_data = function() { + return(get_residual_cpp(self$data_ptr)) + }, + + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector` + #' @param update_vector Vector to be added to outcome + #' @return None + add_vector = function(update_vector) { + if (!is.numeric(update_vector)) { + stop("update_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(update_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } + update_vector <- as.numeric(update_vector) + } + } + add_to_column_vector_cpp(self$data_ptr, update_vector) + }, + + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector` + #' @param update_vector Vector to be subtracted from outcome + #' @return None + subtract_vector = function(update_vector) { + if (!is.numeric(update_vector)) { + stop("update_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(update_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } + update_vector <- as.numeric(update_vector) } - ) + } + subtract_from_column_vector_cpp(self$data_ptr, update_vector) + }, + + #' @description + #' Update the current state of the outcome (i.e. 
partial residual) data by replacing each element with the elements of `new_vector` + #' @param new_vector Vector from which to overwrite the current data + #' @return None + update_data = function(new_vector) { + if (!is.numeric(new_vector)) { + stop("new_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(new_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) { + stop( + "if new_vector is provided as a matrix, it must be 2d" + ) + } + new_vector <- as.numeric(new_vector) + } + } + overwrite_column_vector_cpp(self$data_ptr, new_vector) + } + ) )
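Since `Outcome` tracks the partial residual that each sampler updates in place, a short round trip clarifies the add / subtract / overwrite semantics (a minimal sketch, assuming the package is installed; `createOutcome()` is defined later in this file):

```r
library(stochtree)

y <- c(1.5, -0.5, 2.0)
outcome <- createOutcome(y)

# Subtract a forest's predictions to form a partial residual ...
preds <- c(1.0, 0.0, 1.0)
outcome$subtract_vector(preds)
print(outcome$get_data())  # 0.5 -0.5 1.0

# ... then add them back, or overwrite wholesale with update_data
outcome$add_vector(preds)
outcome$update_data(y)
print(outcome$get_data())  # 1.5 -0.5 2.0
```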
#' Dataset used to sample a random effects model @@ -216,104 +216,104 @@ Outcome <- R6::R6Class( #' bases, and variance weights. Variance weights are optional. RandomEffectsDataset <- R6::R6Class( - classname = "RandomEffectsDataset", - cloneable = FALSE, - public = list( - #' @field data_ptr External pointer to a C++ RandomEffectsDataset class - data_ptr = NULL, - - #' @description - #' Create a new RandomEffectsDataset object. - #' @param group_labels Vector of group labels - #' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones) - #' @param variance_weights (Optional) Vector of observation-specific variance weights - #' @return A new `RandomEffectsDataset` object. - initialize = function(group_labels, basis, variance_weights = NULL) { - self$data_ptr <- create_rfx_dataset_cpp() - rfx_dataset_add_group_labels_cpp(self$data_ptr, group_labels) - rfx_dataset_add_basis_cpp(self$data_ptr, basis) - if (!is.null(variance_weights)) { - rfx_dataset_add_weights_cpp(self$data_ptr, variance_weights) - } - }, - - #' @description - #' Update basis matrix in a dataset - #' @param basis Updated matrix of bases used to define random slopes / intercepts - update_basis = function(basis) { - stopifnot(self$has_basis()) - rfx_dataset_update_basis_cpp(self$data_ptr, basis) - }, - - #' @description - #' Update variance_weights in a dataset - #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights - #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F. - update_variance_weights = function(variance_weights, exponentiate = F) { - stopifnot(self$has_variance_weights()) - rfx_dataset_update_var_weights_cpp( - self$data_ptr, - variance_weights, - exponentiate - ) - }, - - #' @description - #' Return number of observations in a `RandomEffectsDataset` object - #' @return Observation count - num_observations = function() { - return(rfx_dataset_num_rows_cpp(self$data_ptr)) - }, - - #' @description - #' Return dimension of the basis matrix in a `RandomEffectsDataset` object - #' @return Basis vector count - num_basis = function() { - return(rfx_dataset_num_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return group labels as an R vector - #' @return Group label data - get_group_labels = function() { - return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) - }, - - #' @description - #' Return bases as an R matrix - #' @return Basis data - get_basis = function() { - return(rfx_dataset_get_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return variance weights as an R vector - #' @return Variance weight data - get_variance_weights = function() { - return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has group label indices - #' @return True if group label vector is loaded, false otherwise - has_group_labels = function() { - return(rfx_dataset_has_group_labels_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has a basis matrix - #' @return True if basis matrix is loaded, false otherwise - has_basis = function() { - return(rfx_dataset_has_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has variance weights - #' @return True if variance weights are loaded, false otherwise - has_variance_weights = function() { - return(rfx_dataset_has_variance_weights_cpp(self$data_ptr)) - } - ) + classname = "RandomEffectsDataset", + cloneable = FALSE, + public = list( + #' @field data_ptr External pointer to a C++ RandomEffectsDataset class + data_ptr = NULL, + + #' @description + #' Create a new RandomEffectsDataset object. + #' @param group_labels Vector of group labels + #' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones) + #' @param variance_weights (Optional) Vector of observation-specific variance weights + #' @return A new `RandomEffectsDataset` object. + initialize = function(group_labels, basis, variance_weights = NULL) { + self$data_ptr <- create_rfx_dataset_cpp() + rfx_dataset_add_group_labels_cpp(self$data_ptr, group_labels) + rfx_dataset_add_basis_cpp(self$data_ptr, basis) + if (!is.null(variance_weights)) { + rfx_dataset_add_weights_cpp(self$data_ptr, variance_weights) + } + }, + + #' @description + #' Update basis matrix in a dataset + #' @param basis Updated matrix of bases used to define random slopes / intercepts + update_basis = function(basis) { + stopifnot(self$has_basis()) + rfx_dataset_update_basis_cpp(self$data_ptr, basis) + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F.
+ update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + rfx_dataset_update_var_weights_cpp( + self$data_ptr, + variance_weights, + exponentiate + ) + }, + + #' @description + #' Return number of observations in a `RandomEffectsDataset` object + #' @return Observation count + num_observations = function() { + return(rfx_dataset_num_rows_cpp(self$data_ptr)) + }, + + #' @description + #' Return dimension of the basis matrix in a `RandomEffectsDataset` object + #' @return Basis vector count + num_basis = function() { + return(rfx_dataset_num_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return group labels as an R vector + #' @return Group label data + get_group_labels = function() { + return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(rfx_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has group label indices + #' @return True if group label vector is loaded, false otherwise + has_group_labels = function() { + return(rfx_dataset_has_group_labels_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has a basis matrix + #' @return True if basis matrix is loaded, false otherwise + has_basis = function() { + return(rfx_dataset_has_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has variance weights + #' @return True if variance weights are loaded, false otherwise + has_variance_weights = function() { + return(rfx_dataset_has_variance_weights_cpp(self$data_ptr)) + } + ) ) #' Create a forest dataset object @@ -333,11 +333,11 @@ RandomEffectsDataset <- R6::R6Class( #' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix) #' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector) createForestDataset <- function( - covariates, - basis = NULL, - variance_weights = NULL + covariates, + basis = NULL, + variance_weights = NULL ) { - return(invisible((ForestDataset$new(covariates, basis, variance_weights)))) + return(invisible((ForestDataset$new(covariates, basis, variance_weights)))) } #' Create an outcome object @@ -352,7 +352,7 @@ createForestDataset <- function( #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) #' outcome <- createOutcome(y) createOutcome <- function(outcome) { - return(invisible((Outcome$new(outcome)))) + return(invisible((Outcome$new(outcome)))) } #' Create a random effects dataset object @@ -371,11 +371,11 @@ createOutcome <- function(outcome) { #' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis) #' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector) createRandomEffectsDataset <- function( - group_labels, - basis, - variance_weights = NULL + group_labels, + basis, + variance_weights = NULL ) { - return(invisible( - (RandomEffectsDataset$new(group_labels, basis, variance_weights)) - )) + return(invisible( + (RandomEffectsDataset$new(group_labels, basis, variance_weights)) + )) } diff --git a/R/forest.R b/R/forest.R index a554c2d5..ae136a97 100644 --- a/R/forest.R +++ b/R/forest.R @@ -4,897 +4,897 @@ #' Wrapper around a C++ container of tree ensembles ForestSamples <- R6::R6Class( - classname = 
"ForestSamples", - cloneable = FALSE, - public = list( - #' @field forest_container_ptr External pointer to a C++ ForestContainer class - forest_container_ptr = NULL, - - #' @description - #' Create a new ForestContainer object. - #' @param num_trees Number of trees - #' @param leaf_dimension Dimensionality of the outcome model - #' @param is_leaf_constant Whether leaf is constant - #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned - #' @return A new `ForestContainer` object. - initialize = function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE - ) { - self$forest_container_ptr <- forest_container_cpp( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - ) - }, - - #' @description - #' Collapse forests in this container by a pre-specified batch size. - #' For example, if we have a container of twenty 10-tree forests, and we - #' specify a `batch_size` of 5, then this method will yield four 50-tree - #' forests. "Excess" forests remaining after the size of a forest container - #' is divided by `batch_size` will be pruned from the beginning of the - #' container (i.e. earlier sampled forests will be deleted). This method - #' has no effect if `batch_size` is larger than the number of forests - #' in a container. - #' @param batch_size Number of forests to be collapsed into a single forest - collapse = function(batch_size) { - container_size <- self$num_samples() - if ((batch_size <= container_size) && (batch_size > 1)) { - reverse_container_inds <- seq(container_size, 1, -1) - num_clean_batches <- container_size %/% batch_size - batch_inds <- (reverse_container_inds - - (container_size - - (container_size %/% num_clean_batches) * - num_clean_batches) - - 1) %/% - batch_size - for (batch_ind in unique(batch_inds[batch_inds >= 0])) { - merge_forest_inds <- sort( - reverse_container_inds[batch_inds == batch_ind] - 1 - ) - num_merge_forests <- length(merge_forest_inds) - self$combine_forests(merge_forest_inds) - for (i in num_merge_forests:2) { - self$delete_sample(merge_forest_inds[i]) - } - forest_scale_factor <- 1.0 / num_merge_forests - self$multiply_forest( - merge_forest_inds[1], - forest_scale_factor - ) - } - if (min(batch_inds) < 0) { - delete_forest_inds <- sort( - reverse_container_inds[batch_inds < 0] - 1 - ) - for (i in length(delete_forest_inds):1) { - self$delete_sample(delete_forest_inds[i]) - } - } - } - }, - - #' @description - #' Merge specified forests into a single forest - #' @param forest_inds Indices of forests to be combined (0-indexed) - combine_forests = function(forest_inds) { - stopifnot(max(forest_inds) < self$num_samples()) - stopifnot(min(forest_inds) >= 0) - stopifnot(length(forest_inds) > 1) - stopifnot(all(as.integer(forest_inds) == forest_inds)) - forest_inds_sorted <- as.integer(sort(forest_inds)) - combine_forests_forest_container_cpp( - self$forest_container_ptr, - forest_inds_sorted - ) - }, - - #' @description - #' Add a constant value to every leaf of every tree of a given forest - #' @param forest_index Index of forest whose leaves will be modified (0-indexed) - #' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index` - add_to_forest = function(forest_index, constant_value) { - stopifnot(forest_index < self$num_samples()) - stopifnot(forest_index >= 0) - add_to_forest_forest_container_cpp( - self$forest_container_ptr, - forest_index, - constant_value - ) - }, - - #' @description - #' Multiply every 
leaf of every tree of a given forest by constant value - #' @param forest_index Index of forest whose leaves will be modified (0-indexed) - #' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index` - multiply_forest = function(forest_index, constant_multiple) { - stopifnot(forest_index < self$num_samples()) - stopifnot(forest_index >= 0) - multiply_forest_forest_container_cpp( - self$forest_container_ptr, - forest_index, - constant_multiple - ) - }, - - #' @description - #' Create a new `ForestContainer` object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return A new `ForestContainer` object. - load_from_json = function(json_object, json_forest_label) { - self$forest_container_ptr <- forest_container_from_json_cpp( - json_object$json_ptr, - json_forest_label - ) - }, - - #' @description - #' Append to a `ForestContainer` object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return None - append_from_json = function(json_object, json_forest_label) { - forest_container_append_from_json_cpp( - self$forest_container_ptr, - json_object$json_ptr, - json_forest_label - ) - }, - - #' @description - #' Create a new `ForestContainer` object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return A new `ForestContainer` object. - load_from_json_string = function(json_string, json_forest_label) { - self$forest_container_ptr <- forest_container_from_json_string_cpp( - json_string, - json_forest_label - ) - }, - - #' @description - #' Append to a `ForestContainer` object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return None - append_from_json_string = function(json_string, json_forest_label) { - forest_container_append_from_json_string_cpp( - self$forest_container_ptr, - json_string, - json_forest_label - ) - }, - - #' @description - #' Predict every tree ensemble on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return matrix of predictions with as many rows as in forest_dataset - #' and as many columns as samples in the `ForestContainer` - predict = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - return(predict_forest_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr - )) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return Array of predictions for each observation in `forest_dataset` and - #' each sample in the `ForestSamples` class with each prediction having the - #' dimensionality of the forests' leaf model. In the case of a constant leaf model - #' or univariate leaf regression, this array is two-dimensional (number of observations, - #' number of forest samples). 
In the case of a multivariate leaf regression, - #' this array is three-dimension (number of observations, leaf model dimension, - #' number of samples). - predict_raw = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - # Unpack dimensions - output_dim <- leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) - num_samples <- num_samples_forest_container_cpp( - self$forest_container_ptr - ) - n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - - # Predict leaf values from forest - predictions <- predict_forest_raw_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr - ) - if (output_dim > 1) { - dim(predictions) <- c(n, output_dim, num_samples) - } else { - dim(predictions) <- c(n, num_samples) - } - - return(predictions) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @param forest_num Index of the forest sample within the container - #' @return matrix of predictions with as many rows as in forest_dataset - #' and as many columns as dimensions in the leaves of trees in `ForestContainer` - predict_raw_single_forest = function(forest_dataset, forest_num) { - stopifnot(!is.null(forest_dataset$data_ptr)) - # Unpack dimensions - output_dim <- leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) - n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - - # Predict leaf values from forest - output <- predict_forest_raw_single_forest_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr, - forest_num - ) - return(output) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @param forest_num Index of the forest sample within the container - #' @param tree_num Index of the tree to be queried - #' @return matrix of predictions with as many rows as in `forest_dataset` - #' and as many columns as dimensions in the leaves of trees in `ForestContainer` - predict_raw_single_tree = function( - forest_dataset, - forest_num, - tree_num - ) { - stopifnot(!is.null(forest_dataset$data_ptr)) - - # Predict leaf values from forest - output <- predict_forest_raw_single_tree_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr, - forest_num, - tree_num - ) - return(output) - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param forest_num Index of the forest sample within the container. - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. 
- set_root_leaves = function(forest_num, leaf_value) { - stopifnot(!is.null(self$forest_container_ptr)) - stopifnot( - num_samples_forest_container_cpp(self$forest_container_ptr) == 0 - ) - - # Set leaf values - if (length(leaf_value) == 1) { - stopifnot( - leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) == - 1 - ) - set_leaf_value_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } else if (length(leaf_value) > 1) { - stopifnot( - leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) == - length(leaf_value) - ) - set_leaf_vector_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } else { - stop( - "leaf_value must be a numeric value or vector of length >= 1" - ) - } - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) - #' @param outcome `Outcome` Outcome class (residual / partial residual) - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - prepare_for_sampler = function( - dataset, - outcome, - forest_model, - leaf_model_int, - leaf_value - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_container_ptr)) - stopifnot( - num_samples_forest_container_cpp(self$forest_container_ptr) == 0 - ) - - # Initialize the model - initialize_forest_model_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_container_ptr, - forest_model$tracker_ptr, - leaf_value, - leaf_model_int - ) - }, - - #' @description - #' Adjusts residual based on the predictions of a forest - #' - #' This is typically run just once at the beginning of a forest sampling algorithm. - #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. 
- #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest - #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param requires_basis Whether or not a forest requires a basis for prediction - #' @param forest_num Index of forest used to update residuals - #' @param add Whether forest predictions should be added to or subtracted from residuals - adjust_residual = function( - dataset, - outcome, - forest_model, - requires_basis, - forest_num, - add - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_container_ptr)) - - adjust_residual_forest_container_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_container_ptr, - forest_model$tracker_ptr, - requires_basis, - forest_num, - add - ) - }, - - #' @description - #' Store the trees and metadata of `ForestDataset` class in a json file - #' @param json_filename Name of output json file (must end in ".json") - save_json = function(json_filename) { - invisible(json_save_forest_container_cpp( - self$forest_container_ptr, - json_filename - )) - }, - - #' @description - #' Load trees and metadata for an ensemble from a json file. Note that - #' any trees and metadata already present in `ForestDataset` class will - #' be overwritten. - #' @param json_filename Name of model input json file (must end in ".json") - load_json = function(json_filename) { - invisible(json_load_forest_container_cpp( - self$forest_container_ptr, - json_filename - )) - }, - - #' @description - #' Return number of samples in a `ForestContainer` object - #' @return Sample count - num_samples = function() { - return(num_samples_forest_container_cpp(self$forest_container_ptr)) - }, - - #' @description - #' Return number of trees in each ensemble of a `ForestContainer` object - #' @return Tree count - num_trees = function() { - return(num_trees_forest_container_cpp(self$forest_container_ptr)) - }, - - #' @description - #' Return output dimension of trees in a `ForestContainer` object - #' @return Leaf node parameter size - leaf_dimension = function() { - return(leaf_dimension_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Return constant leaf status of trees in a `ForestContainer` object - #' @return `TRUE` if leaves are constant, `FALSE` otherwise - is_constant_leaf = function() { - return(is_constant_leaf_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Return exponentiation status of trees in a `ForestContainer` object - #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise - is_exponentiated = function() { - return(is_exponentiated_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Add a new all-root ensemble to the container, with all of the leaves - #' set to the value / vector provided - #' @param leaf_value Value (or vector of values) to initialize root nodes in tree - add_forest_with_constant_leaves = function(leaf_value) { - if (length(leaf_value) > 1) { - add_sample_vector_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } else { - add_sample_value_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } - }, - - #' @description - #' Add a numeric (i.e. 
`X[,i] <= c`) split to a given tree in the ensemble - #' @param forest_num Index of the forest which contains the tree to be split - #' @param tree_num Index of the tree to be split - #' @param leaf_num Leaf to be split - #' @param feature_num Feature that defines the new split - #' @param split_threshold Value that defines the cutoff of the new split - #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node - #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node - add_numeric_split_tree = function( + classname = "ForestSamples", + cloneable = FALSE, + public = list( + #' @field forest_container_ptr External pointer to a C++ ForestContainer class + forest_container_ptr = NULL, + + #' @description + #' Create a new ForestContainer object. + #' @param num_trees Number of trees + #' @param leaf_dimension Dimensionality of the outcome model + #' @param is_leaf_constant Whether leaf is constant + #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned + #' @return A new `ForestContainer` object. + initialize = function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE + ) { + self$forest_container_ptr <- forest_container_cpp( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + ) + }, + + #' @description + #' Collapse forests in this container by a pre-specified batch size. + #' For example, if we have a container of twenty 10-tree forests, and we + #' specify a `batch_size` of 5, then this method will yield four 50-tree + #' forests. "Excess" forests remaining after the size of a forest container + #' is divided by `batch_size` will be pruned from the beginning of the + #' container (i.e. earlier sampled forests will be deleted). This method + #' has no effect if `batch_size` is larger than the number of forests + #' in a container. 
+ #' @param batch_size Number of forests to be collapsed into a single forest + collapse = function(batch_size) { + container_size <- self$num_samples() + if ((batch_size <= container_size) && (batch_size > 1)) { + reverse_container_inds <- seq(container_size, 1, -1) + num_clean_batches <- container_size %/% batch_size + batch_inds <- (reverse_container_inds - + (container_size - + (container_size %/% num_clean_batches) * + num_clean_batches) - + 1) %/% + batch_size + for (batch_ind in unique(batch_inds[batch_inds >= 0])) { + merge_forest_inds <- sort( + reverse_container_inds[batch_inds == batch_ind] - 1 + ) + num_merge_forests <- length(merge_forest_inds) + self$combine_forests(merge_forest_inds) + for (i in num_merge_forests:2) { + self$delete_sample(merge_forest_inds[i]) + } + forest_scale_factor <- 1.0 / num_merge_forests + self$multiply_forest( + merge_forest_inds[1], + forest_scale_factor + ) + } + if (min(batch_inds) < 0) { + delete_forest_inds <- sort( + reverse_container_inds[batch_inds < 0] - 1 + ) + for (i in length(delete_forest_inds):1) { + self$delete_sample(delete_forest_inds[i]) + } + } + } + }, + + #' @description + #' Merge specified forests into a single forest + #' @param forest_inds Indices of forests to be combined (0-indexed) + combine_forests = function(forest_inds) { + stopifnot(max(forest_inds) < self$num_samples()) + stopifnot(min(forest_inds) >= 0) + stopifnot(length(forest_inds) > 1) + stopifnot(all(as.integer(forest_inds) == forest_inds)) + forest_inds_sorted <- as.integer(sort(forest_inds)) + combine_forests_forest_container_cpp( + self$forest_container_ptr, + forest_inds_sorted + ) + }, + + #' @description + #' Add a constant value to every leaf of every tree of a given forest + #' @param forest_index Index of forest whose leaves will be modified (0-indexed) + #' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index` + add_to_forest = function(forest_index, constant_value) { + stopifnot(forest_index < self$num_samples()) + stopifnot(forest_index >= 0) + add_to_forest_forest_container_cpp( + self$forest_container_ptr, + forest_index, + constant_value + ) + }, + + #' @description + #' Multiply every leaf of every tree of a given forest by constant value + #' @param forest_index Index of forest whose leaves will be modified (0-indexed) + #' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index` + multiply_forest = function(forest_index, constant_multiple) { + stopifnot(forest_index < self$num_samples()) + stopifnot(forest_index >= 0) + multiply_forest_forest_container_cpp( + self$forest_container_ptr, + forest_index, + constant_multiple + ) + }, + + #' @description + #' Create a new `ForestContainer` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return A new `ForestContainer` object. + load_from_json = function(json_object, json_forest_label) { + self$forest_container_ptr <- forest_container_from_json_cpp( + json_object$json_ptr, + json_forest_label + ) + }, + + #' @description + #' Append to a `ForestContainer` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. 
"forest_0") in the overall json hierarchy + #' @return None + append_from_json = function(json_object, json_forest_label) { + forest_container_append_from_json_cpp( + self$forest_container_ptr, + json_object$json_ptr, + json_forest_label + ) + }, + + #' @description + #' Create a new `ForestContainer` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return A new `ForestContainer` object. + load_from_json_string = function(json_string, json_forest_label) { + self$forest_container_ptr <- forest_container_from_json_string_cpp( + json_string, + json_forest_label + ) + }, + + #' @description + #' Append to a `ForestContainer` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return None + append_from_json_string = function(json_string, json_forest_label) { + forest_container_append_from_json_string_cpp( + self$forest_container_ptr, + json_string, + json_forest_label + ) + }, + + #' @description + #' Predict every tree ensemble on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return matrix of predictions with as many rows as in forest_dataset + #' and as many columns as samples in the `ForestContainer` + predict = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + return(predict_forest_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr + )) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is two-dimensional (number of observations, + #' number of forest samples). In the case of a multivariate leaf regression, + #' this array is three-dimension (number of observations, leaf model dimension, + #' number of samples). 
+ + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is two-dimensional (number of observations, + #' number of forest samples). In the case of a multivariate leaf regression, + #' this array is three-dimensional (number of observations, leaf model dimension, + #' number of samples). + predict_raw = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + # Unpack dimensions + output_dim <- leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) + num_samples <- num_samples_forest_container_cpp( + self$forest_container_ptr + ) + n <- dataset_num_rows_cpp(forest_dataset$data_ptr) + + # Predict leaf values from forest + predictions <- predict_forest_raw_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr + ) + if (output_dim > 1) { + dim(predictions) <- c(n, output_dim, num_samples) + } else { + dim(predictions) <- c(n, num_samples) + } + + return(predictions) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @param forest_num Index of the forest sample within the container + #' @return matrix of predictions with as many rows as in forest_dataset + #' and as many columns as dimensions in the leaves of trees in `ForestContainer` + predict_raw_single_forest = function(forest_dataset, forest_num) { + stopifnot(!is.null(forest_dataset$data_ptr)) + # Unpack dimensions + output_dim <- leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) + n <- dataset_num_rows_cpp(forest_dataset$data_ptr) + + # Predict leaf values from forest + output <- predict_forest_raw_single_forest_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr, + forest_num + ) + return(output) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @param forest_num Index of the forest sample within the container + #' @param tree_num Index of the tree to be queried + #' @return matrix of predictions with as many rows as in `forest_dataset` + #' and as many columns as dimensions in the leaves of trees in `ForestContainer` + predict_raw_single_tree = function( + forest_dataset, + forest_num, + tree_num + ) { + stopifnot(!is.null(forest_dataset$data_ptr)) + + # Predict leaf values from forest + output <- predict_forest_raw_single_tree_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr, + forest_num, + tree_num + ) + return(output) + }, + + #' @description + #' Set a constant predicted value for every tree in the ensemble. + #' Stops the program if any tree is more than a root node. + #' @param forest_num Index of the forest sample within the container. + #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
+ set_root_leaves = function(forest_num, leaf_value) { + stopifnot(!is.null(self$forest_container_ptr)) + stopifnot( + num_samples_forest_container_cpp(self$forest_container_ptr) == 0 + ) + + # Set leaf values + if (length(leaf_value) == 1) { + stopifnot( + leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) == + 1 + ) + set_leaf_value_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } else if (length(leaf_value) > 1) { + stopifnot( + leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) == + length(leaf_value) + ) + set_leaf_vector_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } else { + stop( + "leaf_value must be a numeric value or vector of length >= 1" + ) + } + }, + + #' @description + #' Initialize a forest model for sampling, setting a constant root node prediction for every tree in the ensemble. + #' Stops the program if the container already holds sampled forests. + #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) + #' @param outcome `Outcome` Outcome class (residual / partial residual) + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). + #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension. + prepare_for_sampler = function( + dataset, + outcome, + forest_model, + leaf_model_int, + leaf_value + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_container_ptr)) + stopifnot( + num_samples_forest_container_cpp(self$forest_container_ptr) == 0 + ) + + # Initialize the model + initialize_forest_model_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_container_ptr, + forest_model$tracker_ptr, + leaf_value, + leaf_model_int + ) + }, + + #' @description + #' Adjusts residual based on the predictions of a forest + #' + #' This is typically run just once at the beginning of a forest sampling algorithm. + #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
+ #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest + #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param requires_basis Whether or not a forest requires a basis for prediction + #' @param forest_num Index of forest used to update residuals + #' @param add Whether forest predictions should be added to or subtracted from residuals + adjust_residual = function( + dataset, + outcome, + forest_model, + requires_basis, + forest_num, + add + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_container_ptr)) + + adjust_residual_forest_container_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_container_ptr, + forest_model$tracker_ptr, + requires_basis, + forest_num, + add + ) + }, + + #' @description + #' Store the trees and metadata of a `ForestSamples` object in a json file + #' @param json_filename Name of output json file (must end in ".json") + save_json = function(json_filename) { + invisible(json_save_forest_container_cpp( + self$forest_container_ptr, + json_filename + )) + }, + + #' @description + #' Load trees and metadata for an ensemble from a json file. Note that + #' any trees and metadata already present in a `ForestSamples` object will + #' be overwritten. + #' @param json_filename Name of model input json file (must end in ".json") + load_json = function(json_filename) { + invisible(json_load_forest_container_cpp( + self$forest_container_ptr, + json_filename + )) + }, + + #' @description + #' Return number of samples in a `ForestContainer` object + #' @return Sample count + num_samples = function() { + return(num_samples_forest_container_cpp(self$forest_container_ptr)) + }, + + #' @description + #' Return number of trees in each ensemble of a `ForestContainer` object + #' @return Tree count + num_trees = function() { + return(num_trees_forest_container_cpp(self$forest_container_ptr)) + }, + + #' @description + #' Return output dimension of trees in a `ForestContainer` object + #' @return Leaf node parameter size + leaf_dimension = function() { + return(leaf_dimension_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Return constant leaf status of trees in a `ForestContainer` object + #' @return `TRUE` if leaves are constant, `FALSE` otherwise + is_constant_leaf = function() { + return(is_constant_leaf_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Return exponentiation status of trees in a `ForestContainer` object + #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise + is_exponentiated = function() { + return(is_exponentiated_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Add a new all-root ensemble to the container, with all of the leaves + #' set to the value / vector provided + #' @param leaf_value Value (or vector of values) to initialize root nodes in tree + add_forest_with_constant_leaves = function(leaf_value) { + if (length(leaf_value) > 1) { + add_sample_vector_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } else { + add_sample_value_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } + }, + + #' @description + #' Add a numeric (i.e.
+ #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble
+ #' @param forest_num Index of the forest which contains the tree to be split
+ #' @param tree_num Index of the tree to be split
+ #' @param leaf_num Leaf to be split
+ #' @param feature_num Feature that defines the new split
+ #' @param split_threshold Value that defines the cutoff of the new split
+ #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node
+ #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node
+ add_numeric_split_tree = function(
+ forest_num,
+ tree_num,
+ leaf_num,
+ feature_num,
+ split_threshold,
+ left_leaf_value,
+ right_leaf_value
+ ) {
+ if (length(left_leaf_value) > 1) {
+ add_numeric_split_tree_vector_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ leaf_num,
+ feature_num,
+ split_threshold,
+ left_leaf_value,
+ right_leaf_value
+ )
+ } else {
+ add_numeric_split_tree_value_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ leaf_num,
+ feature_num,
+ split_threshold,
+ left_leaf_value,
+ right_leaf_value
+ )
+ }
+ },
+
+ #' @description
+ #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest
+ #' @param forest_num Index of the forest which contains tree `tree_num`
+ #' @param tree_num Index of the tree for which leaf indices will be retrieved
+ get_tree_leaves = function(forest_num, tree_num) {
+ return(get_tree_leaves_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num
+ ))
+ },
+
+ #' @description
+ #' Retrieve a vector of split counts for every training set variable in a given tree in a given forest
+ #' @param forest_num Index of the forest which contains tree `tree_num`
+ #' @param tree_num Index of the tree for which split counts will be retrieved
+ #' @param num_features Total number of features in the training set
+ get_tree_split_counts = function(forest_num, tree_num, num_features) {
+ return(get_tree_split_counts_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ num_features
+ ))
+ },
+
+ #' @description
+ #' Retrieve a vector of split counts for every training set variable in a given forest
+ #' @param forest_num Index of the forest for which split counts will be retrieved
+ #' @param num_features Total number of features in the training set
+ get_forest_split_counts = function(forest_num, num_features) {
+ return(get_forest_split_counts_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ num_features
+ ))
+ },
+
+ #' @description
+ #' Retrieve a vector of split counts for every training set variable, aggregated across all forests and trees in the container
+ #' @param num_features Total number of features in the training set
+ get_aggregate_split_counts = function(num_features) {
+ return(get_overall_split_counts_forest_container_cpp(
+ self$forest_container_ptr,
+ num_features
+ ))
+ },
+
+ #' @description
+ #' Retrieve an array of split counts for every training set variable, reported separately for each forest and tree in the container
+ #' @param num_features Total number of features in the training set
+ get_granular_split_counts = function(num_features) {
+ n_samples <- self$num_samples()
+ n_trees <- self$num_trees()
+ output <- get_granular_split_count_array_forest_container_cpp(
+ self$forest_container_ptr,
+ num_features
+ )
+ dim(output) <- c(n_samples, n_trees, num_features)
+ return(output)
+ },
+
+ #' @description
+ #' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object
+ #' @param ensemble_num Ensemble number
+ #' @param tree_num Tree index within ensemble `ensemble_num`
+ #' @return Maximum leaf depth
+ ensemble_tree_max_depth = function(ensemble_num, tree_num) {
+ return(ensemble_tree_max_depth_forest_container_cpp(
+ self$forest_container_ptr,
+ ensemble_num,
+ tree_num
+ ))
+ },
+
+ #' @description
+ #' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object
+ #' @param ensemble_num Ensemble number
+ #' @return Average maximum depth
+ average_ensemble_max_depth = function(ensemble_num) {
+ return(ensemble_average_max_depth_forest_container_cpp(
+ self$forest_container_ptr,
+ ensemble_num
+ ))
+ },
+
+ #' @description
+ #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object
+ #' @return Average maximum depth
+ average_max_depth = function() {
+ return(average_max_depth_forest_container_cpp(
+ self$forest_container_ptr
+ ))
+ },
+
+ #' @description
+ #' Number of leaves in a given ensemble in a `ForestSamples` object
+ #' @param forest_num Index of the ensemble to be queried
+ #' @return Count of leaves in the ensemble stored at `forest_num`
+ num_forest_leaves = function(forest_num) {
+ return(num_leaves_ensemble_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num
+ ))
+ },
+
+ #' @description
+ #' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object
+ #' @param forest_num Index of the ensemble to be queried
+ #' @return Sum of squared leaf values in the ensemble stored at `forest_num`
+ sum_leaves_squared = function(forest_num) {
+ return(sum_leaves_squared_ensemble_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num
+ ))
+ },
+
+ #' @description
+ #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return `TRUE` if node is a leaf, `FALSE` otherwise
+ is_leaf_node = function(forest_num, tree_num, node_id) {
+ return(is_leaf_node_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return `TRUE` if node is a numeric split node, `FALSE` otherwise
+ is_numeric_split_node = function(forest_num, tree_num, node_id) {
+ return(is_numeric_split_node_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return `TRUE` if node is a categorical split node, `FALSE` otherwise
+ is_categorical_split_node = function(forest_num, tree_num, node_id) {
+ return(is_categorical_split_node_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Parent node of given node of a given tree in a given forest in a `ForestSamples` object
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return Integer ID of the parent node
+ parent_node = function(forest_num, tree_num, node_id) {
+ return(parent_node_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Left child node of given node of a given tree in a given forest in a `ForestSamples` object
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return Integer ID of the left child node
+ left_child_node = function(forest_num, tree_num, node_id) {
+ return(left_child_node_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Right child node of given node of a given tree in a given forest in a `ForestSamples` object
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return Integer ID of the right child node
+ right_child_node = function(forest_num, tree_num, node_id) {
+ return(right_child_node_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node.
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return Integer valued depth of the node
+ node_depth = function(forest_num, tree_num, node_id) {
+ return(node_depth_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ },
+
+ #' @description
+ #' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` if the node is a leaf.
+ #' @param forest_num Index of the forest to be queried
+ #' @param tree_num Index of the tree to be queried
+ #' @param node_id Index of the node to be queried
+ #' @return Integer valued index of the feature that defines the node's split
+ node_split_index = function(forest_num, tree_num, node_id) {
+ if (self$is_leaf_node(forest_num, tree_num, node_id)) {
+ return(-1)
+ } else {
+ return(split_index_forest_container_cpp(
+ self$forest_container_ptr,
+ forest_num,
+ tree_num,
+ node_id
+ ))
+ }
+ },
+
+ #' @description
+ #' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object.
+ #' Returns `Inf` if the node is a leaf or a categorical split node.
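+ #' @examples
+ #' # Minimal sketch of the node query interface on a hand-built tree
+ #' # (node ids assume the root is 0 and its children are 1 and 2):
+ #' fc <- createForestSamples(num_trees = 1)
+ #' fc$add_forest_with_constant_leaves(0.0)
+ #' fc$add_numeric_split_tree(0, 0, 0, 1, 0.5, -1.0, 1.0)
+ #' fc$node_split_index(0, 0, 0) # 1, the feature defining the root split
+ #' fc$node_split_threshold(0, 0, 0) # 0.5
+ #' fc$node_split_threshold(0, 0, 1) # Inf, since the child node is a leaf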
+ #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Threshold defining a split for the node + node_split_threshold = function(forest_num, tree_num, node_id) { + if ( + self$is_leaf_node(forest_num, tree_num, node_id) || + self$is_categorical_split_node( forest_num, tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) { - if (length(left_leaf_value) > 1) { - add_numeric_split_tree_vector_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } else { - add_numeric_split_tree_value_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } - }, - - #' @description - #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest - #' @param forest_num Index of the forest which contains tree `tree_num` - #' @param tree_num Index of the tree for which leaf indices will be retrieved - get_tree_leaves = function(forest_num, tree_num) { - return(get_tree_leaves_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given tree in a given forest - #' @param forest_num Index of the forest which contains tree `tree_num` - #' @param tree_num Index of the tree for which split counts will be retrieved - #' @param num_features Total number of features in the training set - get_tree_split_counts = function(forest_num, tree_num, num_features) { - return(get_tree_split_counts_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given forest - #' @param forest_num Index of the forest for which split counts will be retrieved - #' @param num_features Total number of features in the training set - get_forest_split_counts = function(forest_num, num_features) { - return(get_forest_split_counts_forest_container_cpp( - self$forest_container_ptr, - forest_num, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees - #' @param num_features Total number of features in the training set - get_aggregate_split_counts = function(num_features) { - return(get_overall_split_counts_forest_container_cpp( - self$forest_container_ptr, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree - #' @param num_features Total number of features in the training set - get_granular_split_counts = function(num_features) { - n_samples <- self$num_samples() - n_trees <- self$num_trees() - output <- get_granular_split_count_array_forest_container_cpp( - self$forest_container_ptr, - num_features - ) - dim(output) <- c(n_samples, n_trees, num_features) - return(output) - }, - - #' @description - #' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object - #' @param ensemble_num Ensemble number - #' @param tree_num Tree index within ensemble `ensemble_num` - #' @return Maximum leaf depth - ensemble_tree_max_depth 
= function(ensemble_num, tree_num) { - return(ensemble_tree_max_depth_forest_container_cpp( - self$forest_container_ptr, - ensemble_num, - tree_num - )) - }, - - #' @description - #' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object - #' @param ensemble_num Ensemble number - #' @return Average maximum depth - average_ensemble_max_depth = function(ensemble_num) { - return(ensemble_average_max_depth_forest_container_cpp( - self$forest_container_ptr, - ensemble_num - )) - }, - - #' @description - #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object - #' @return Average maximum depth - average_max_depth = function() { - return(average_max_depth_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Number of leaves in a given ensemble in a `ForestSamples` object - #' @param forest_num Index of the ensemble to be queried - #' @return Count of leaves in the ensemble stored at `forest_num` - num_forest_leaves = function(forest_num) { - return(num_leaves_ensemble_forest_container_cpp( - self$forest_container_ptr, - forest_num - )) - }, - - #' @description - #' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object - #' @param forest_num Index of the ensemble to be queried - #' @return Average maximum depth - sum_leaves_squared = function(forest_num) { - return(sum_leaves_squared_ensemble_forest_container_cpp( - self$forest_container_ptr, - forest_num - )) - }, - - #' @description - #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return `TRUE` if node is a leaf, `FALSE` otherwise - is_leaf_node = function(forest_num, tree_num, node_id) { - return(is_leaf_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return `TRUE` if node is a numeric split node, `FALSE` otherwise - is_numeric_split_node = function(forest_num, tree_num, node_id) { - return(is_numeric_split_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return `TRUE` if node is a categorical split node, `FALSE` otherwise - is_categorical_split_node = function(forest_num, tree_num, node_id) { - return(is_categorical_split_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Parent node of given node of a given tree in a given forest in a `ForestSamples` object - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer ID of the parent node - parent_node = function(forest_num, tree_num, node_id) { - 
return(parent_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Left child node of given node of a given tree in a given forest in a `ForestSamples` object - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer ID of the left child node - left_child_node = function(forest_num, tree_num, node_id) { - return(left_child_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Right child node of given node of a given tree in a given forest in a `ForestSamples` object - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer ID of the right child node - right_child_node = function(forest_num, tree_num, node_id) { - return(right_child_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer valued depth of the node - node_depth = function(forest_num, tree_num, node_id) { - return(node_depth_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` is node is a leaf. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer valued depth of the node - node_split_index = function(forest_num, tree_num, node_id) { - if (self$is_leaf_node(forest_num, tree_num, node_id)) { - return(-1) - } else { - return(split_index_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - } - }, - - #' @description - #' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object. - #' Returns `Inf` if the node is a leaf or a categorical split node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Threshold defining a split for the node - node_split_threshold = function(forest_num, tree_num, node_id) { - if ( - self$is_leaf_node(forest_num, tree_num, node_id) || - self$is_categorical_split_node( - forest_num, - tree_num, - node_id - ) - ) { - return(Inf) - } else { - return(split_theshold_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - } - }, - - #' @description - #' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object. - #' Returns `c(Inf)` if the node is a leaf or a numeric split node. 
- #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Categories defining a split for the node - node_split_categories = function(forest_num, tree_num, node_id) { - if ( - self$is_leaf_node(forest_num, tree_num, node_id) || - self$is_numeric_split_node(forest_num, tree_num, node_id) - ) { - return(c(Inf)) - } else { - return(split_categories_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - } - }, - - #' @description - #' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object. - #' Values are stale if the node is a split node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Vector (often univariate) of leaf values - node_leaf_values = function(forest_num, tree_num, node_id) { - return(leaf_values_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Number of nodes in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree nodes - num_nodes = function(forest_num, tree_num) { - return(num_nodes_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Number of leaves in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree leaves - num_leaves = function(forest_num, tree_num) { - return(num_leaves_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree leaf parents - num_leaf_parents = function(forest_num, tree_num) { - return(num_leaf_parents_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Number of split nodes in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree split nodes - num_split_nodes = function(forest_num, tree_num) { - return(num_split_nodes_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Array of node indices in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Indices of tree nodes - nodes = function(forest_num, tree_num) { - return(nodes_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object. 
- #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Indices of leaf nodes - leaves = function(forest_num, tree_num) { - return(leaves_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Modify the ``ForestSamples`` object by removing the forest sample indexed by `forest_num - #' @param forest_num Index of the forest to be removed - delete_sample = function(forest_num) { - return(remove_sample_forest_container_cpp( - self$forest_container_ptr, - forest_num - )) - } - ) + node_id + ) + ) { + return(Inf) + } else { + return(split_theshold_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + } + }, + + #' @description + #' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object. + #' Returns `c(Inf)` if the node is a leaf or a numeric split node. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Categories defining a split for the node + node_split_categories = function(forest_num, tree_num, node_id) { + if ( + self$is_leaf_node(forest_num, tree_num, node_id) || + self$is_numeric_split_node(forest_num, tree_num, node_id) + ) { + return(c(Inf)) + } else { + return(split_categories_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + } + }, + + #' @description + #' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object. + #' Values are stale if the node is a split node. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Vector (often univariate) of leaf values + node_leaf_values = function(forest_num, tree_num, node_id) { + return(leaf_values_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Number of nodes in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree nodes + num_nodes = function(forest_num, tree_num) { + return(num_nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Number of leaves in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree leaves + num_leaves = function(forest_num, tree_num) { + return(num_leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree leaf parents + num_leaf_parents = function(forest_num, tree_num) { + return(num_leaf_parents_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Number of split nodes in a given tree in a given forest in a `ForestSamples` object. 
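+ #' @examples
+ #' # Minimal sketch: node counts after a single split of a root-only tree
+ #' fc <- createForestSamples(num_trees = 1)
+ #' fc$add_forest_with_constant_leaves(0.0)
+ #' fc$add_numeric_split_tree(0, 0, 0, 0, 0.5, -1.0, 1.0)
+ #' fc$num_nodes(0, 0) # 3
+ #' fc$num_leaves(0, 0) # 2
+ #' fc$num_split_nodes(0, 0) # 1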
+ #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree split nodes + num_split_nodes = function(forest_num, tree_num) { + return(num_split_nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Array of node indices in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Indices of tree nodes + nodes = function(forest_num, tree_num) { + return(nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Indices of leaf nodes + leaves = function(forest_num, tree_num) { + return(leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Modify the ``ForestSamples`` object by removing the forest sample indexed by `forest_num + #' @param forest_num Index of the forest to be removed + delete_sample = function(forest_num) { + return(remove_sample_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) + } + ) ) #' Class that stores a single ensemble of decision trees (often treated as the "active forest") @@ -903,330 +903,330 @@ ForestSamples <- R6::R6Class( #' Wrapper around a C++ tree ensemble Forest <- R6::R6Class( - classname = "Forest", - cloneable = FALSE, - public = list( - #' @field forest_ptr External pointer to a C++ TreeEnsemble class - forest_ptr = NULL, - - #' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called. - internal_forest_is_empty = TRUE, - - #' @description - #' Create a new Forest object. - #' @param num_trees Number of trees in the forest - #' @param leaf_dimension Dimensionality of the outcome model - #' @param is_leaf_constant Whether leaf is constant - #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned - #' @return A new `Forest` object. - initialize = function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE - ) { - self$forest_ptr <- active_forest_cpp( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - ) - self$internal_forest_is_empty <- TRUE - }, - - #' @description - #' Create a larger forest by merging the trees of this forest with those of another forest - #' @param forest Forest to be merged into this forest - merge_forest = function(forest) { - stopifnot(self$leaf_dimension() == forest$leaf_dimension()) - stopifnot(self$is_constant_leaf() == forest$is_constant_leaf()) - stopifnot(self$is_exponentiated() == forest$is_exponentiated()) - forest_merge_cpp(self$forest_ptr, forest$forest_ptr) - }, - - #' @description - #' Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. - #' @param constant_value Value that will be added to every leaf of every tree - add_constant = function(constant_value) { - forest_add_constant_cpp(self$forest_ptr, constant_value) - }, - - #' @description - #' Multiply every leaf of every tree by a constant value. 
If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. - #' @param constant_multiple Value that will be multiplied by every leaf of every tree - multiply_constant = function(constant_multiple) { - forest_multiply_constant_cpp(self$forest_ptr, constant_multiple) - }, - - #' @description - #' Predict forest on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return vector of predictions with as many rows as in `forest_dataset` - predict = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - stopifnot(!is.null(self$forest_ptr)) - return(predict_active_forest_cpp( - self$forest_ptr, - forest_dataset$data_ptr - )) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return Array of predictions for each observation in `forest_dataset` and - #' each sample in the `ForestSamples` class with each prediction having the - #' dimensionality of the forests' leaf model. In the case of a constant leaf model - #' or univariate leaf regression, this array is a vector (length is the number of - #' observations). In the case of a multivariate leaf regression, - #' this array is a matrix (number of observations by leaf model dimension, - #' number of samples). - predict_raw = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - # Unpack dimensions - output_dim <- leaf_dimension_active_forest_cpp(self$forest_ptr) - n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - - # Predict leaf values from forest - predictions <- predict_raw_active_forest_cpp( - self$forest_ptr, - forest_dataset$data_ptr - ) - if (output_dim > 1) { - dim(predictions) <- c(n, output_dim) - } - - return(predictions) - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - set_root_leaves = function(leaf_value) { - stopifnot(!is.null(self$forest_ptr)) - stopifnot(self$internal_forest_is_empty) - - # Set leaf values - if (length(leaf_value) == 1) { - stopifnot( - leaf_dimension_active_forest_cpp(self$forest_ptr) == 1 - ) - set_leaf_value_active_forest_cpp(self$forest_ptr, leaf_value) - } else if (length(leaf_value) > 1) { - stopifnot( - leaf_dimension_active_forest_cpp(self$forest_ptr) == - length(leaf_value) - ) - set_leaf_vector_active_forest_cpp(self$forest_ptr, leaf_value) - } else { - stop( - "leaf_value must be a numeric value or vector of length >= 1" - ) - } - - self$internal_forest_is_empty = FALSE - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) - #' @param outcome `Outcome` Outcome class (residual / partial residual) - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. 
Can be either a single number or a vector, depending on the forest's leaf dimension. - prepare_for_sampler = function( - dataset, - outcome, - forest_model, - leaf_model_int, - leaf_value - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_ptr)) - stopifnot(self$internal_forest_is_empty) - - # Initialize the model - initialize_forest_model_active_forest_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_ptr, - forest_model$tracker_ptr, - leaf_value, - leaf_model_int - ) - - self$internal_forest_is_empty = FALSE - }, - - #' @description - #' Adjusts residual based on the predictions of a forest - #' - #' This is typically run just once at the beginning of a forest sampling algorithm. - #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. - #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest - #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param requires_basis Whether or not a forest requires a basis for prediction - #' @param add Whether forest predictions should be added to or subtracted from residuals - adjust_residual = function( - dataset, - outcome, - forest_model, - requires_basis, - add - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_ptr)) - - adjust_residual_active_forest_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_ptr, - forest_model$tracker_ptr, - requires_basis, - add - ) - }, - - #' @description - #' Return number of trees in each ensemble of a `Forest` object - #' @return Tree count - num_trees = function() { - return(num_trees_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Return output dimension of trees in a `Forest` object - #' @return Leaf node parameter size - leaf_dimension = function() { - return(leaf_dimension_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Return constant leaf status of trees in a `Forest` object - #' @return `TRUE` if leaves are constant, `FALSE` otherwise - is_constant_leaf = function() { - return(is_leaf_constant_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Return exponentiation status of trees in a `Forest` object - #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise - is_exponentiated = function() { - return(is_exponentiated_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Add a numeric (i.e. 
`X[,i] <= c`) split to a given tree in the ensemble - #' @param tree_num Index of the tree to be split - #' @param leaf_num Leaf to be split - #' @param feature_num Feature that defines the new split - #' @param split_threshold Value that defines the cutoff of the new split - #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node - #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node - add_numeric_split_tree = function( - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) { - if (length(left_leaf_value) > 1) { - add_numeric_split_tree_vector_active_forest_cpp( - self$forest_ptr, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } else { - add_numeric_split_tree_value_active_forest_cpp( - self$forest_ptr, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } - }, - - #' @description - #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest - #' @param tree_num Index of the tree for which leaf indices will be retrieved - get_tree_leaves = function(tree_num) { - return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num)) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given tree in the forest - #' @param tree_num Index of the tree for which split counts will be retrieved - #' @param num_features Total number of features in the training set - get_tree_split_counts = function(tree_num, num_features) { - return(get_tree_split_counts_active_forest_cpp( - self$forest_ptr, - tree_num, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in the forest - #' @param num_features Total number of features in the training set - get_forest_split_counts = function(num_features) { - return(get_overall_split_counts_active_forest_cpp( - self$forest_ptr, - num_features - )) - }, - - #' @description - #' Maximum depth of a specific tree in the forest - #' @param tree_num Tree index within forest - #' @return Maximum leaf depth - tree_max_depth = function(tree_num) { - return(ensemble_tree_max_depth_active_forest_cpp( - self$forest_ptr, - tree_num - )) - }, - - #' @description - #' Average the maximum depth of each tree in the forest - #' @return Average maximum depth - average_max_depth = function() { - return(ensemble_average_max_depth_active_forest_cpp( - self$forest_ptr - )) - }, - - #' @description - #' When a forest object is created, it is "empty" in the sense that none - #' of its component trees have leaves with values. There are two ways to - #' "initialize" a Forest object. First, the `set_root_leaves()` method - #' simply initializes every tree in the forest to a single node carrying - #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()` - #' method initializes every tree in the forest to a single node with the - #' same value and also propagates this information through to a ForestModel - #' object, which must be synchronized with a Forest during a forest - #' sampler loop. - #' @return `TRUE` if a Forest has not yet been initialized with a constant - #' root value, `FALSE` otherwise if the forest has already been - #' initialized / grown. 
- is_empty = function() { - return(self$internal_forest_is_empty) - } - ) + classname = "Forest", + cloneable = FALSE, + public = list( + #' @field forest_ptr External pointer to a C++ TreeEnsemble class + forest_ptr = NULL, + + #' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called. + internal_forest_is_empty = TRUE, + + #' @description + #' Create a new Forest object. + #' @param num_trees Number of trees in the forest + #' @param leaf_dimension Dimensionality of the outcome model + #' @param is_leaf_constant Whether leaf is constant + #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned + #' @return A new `Forest` object. + initialize = function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE + ) { + self$forest_ptr <- active_forest_cpp( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + ) + self$internal_forest_is_empty <- TRUE + }, + + #' @description + #' Create a larger forest by merging the trees of this forest with those of another forest + #' @param forest Forest to be merged into this forest + merge_forest = function(forest) { + stopifnot(self$leaf_dimension() == forest$leaf_dimension()) + stopifnot(self$is_constant_leaf() == forest$is_constant_leaf()) + stopifnot(self$is_exponentiated() == forest$is_exponentiated()) + forest_merge_cpp(self$forest_ptr, forest$forest_ptr) + }, + + #' @description + #' Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. + #' @param constant_value Value that will be added to every leaf of every tree + add_constant = function(constant_value) { + forest_add_constant_cpp(self$forest_ptr, constant_value) + }, + + #' @description + #' Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. + #' @param constant_multiple Value that will be multiplied by every leaf of every tree + multiply_constant = function(constant_multiple) { + forest_multiply_constant_cpp(self$forest_ptr, constant_multiple) + }, + + #' @description + #' Predict forest on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return vector of predictions with as many rows as in `forest_dataset` + predict = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + stopifnot(!is.null(self$forest_ptr)) + return(predict_active_forest_cpp( + self$forest_ptr, + forest_dataset$data_ptr + )) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is a vector (length is the number of + #' observations). In the case of a multivariate leaf regression, + #' this array is a matrix (number of observations by leaf model dimension, + #' number of samples). 
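+ #' @examples
+ #' # Minimal sketch, assuming a `createForestDataset()` constructor that wraps a
+ #' # covariate matrix (as used elsewhere in the package):
+ #' \dontrun{
+ #' f <- createForest(num_trees = 10)
+ #' f$set_root_leaves(0.1)
+ #' dataset <- createForestDataset(matrix(runif(50), ncol = 5))
+ #' f$predict(dataset) # each prediction sums the 10 root leaves: 1.0
+ #' f$predict_raw(dataset) # raw leaf values, not multiplied by any basis
+ #' }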
+ predict_raw = function(forest_dataset) {
+ stopifnot(!is.null(forest_dataset$data_ptr))
+ # Unpack dimensions
+ output_dim <- leaf_dimension_active_forest_cpp(self$forest_ptr)
+ n <- dataset_num_rows_cpp(forest_dataset$data_ptr)
+
+ # Predict leaf values from forest
+ predictions <- predict_raw_active_forest_cpp(
+ self$forest_ptr,
+ forest_dataset$data_ptr
+ )
+ if (output_dim > 1) {
+ dim(predictions) <- c(n, output_dim)
+ }
+
+ return(predictions)
+ },
+
+ #' @description
+ #' Set a constant predicted value for every tree in the ensemble.
+ #' Raises an error if any tree has been grown beyond a root node.
+ #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
+ set_root_leaves = function(leaf_value) {
+ stopifnot(!is.null(self$forest_ptr))
+ stopifnot(self$internal_forest_is_empty)
+
+ # Set leaf values
+ if (length(leaf_value) == 1) {
+ stopifnot(
+ leaf_dimension_active_forest_cpp(self$forest_ptr) == 1
+ )
+ set_leaf_value_active_forest_cpp(self$forest_ptr, leaf_value)
+ } else if (length(leaf_value) > 1) {
+ stopifnot(
+ leaf_dimension_active_forest_cpp(self$forest_ptr) ==
+ length(leaf_value)
+ )
+ set_leaf_vector_active_forest_cpp(self$forest_ptr, leaf_value)
+ } else {
+ stop(
+ "leaf_value must be a numeric value or vector of length >= 1"
+ )
+ }
+
+ self$internal_forest_is_empty = FALSE
+ },
+
+ #' @description
+ #' Set a constant predicted value for every tree in the ensemble.
+ #' Raises an error if any tree has been grown beyond a root node.
+ #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...)
+ #' @param outcome `Outcome` Outcome class (residual / partial residual)
+ #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling
+ #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
+ #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
+ prepare_for_sampler = function(
+ dataset,
+ outcome,
+ forest_model,
+ leaf_model_int,
+ leaf_value
+ ) {
+ stopifnot(!is.null(dataset$data_ptr))
+ stopifnot(!is.null(outcome$data_ptr))
+ stopifnot(!is.null(forest_model$tracker_ptr))
+ stopifnot(!is.null(self$forest_ptr))
+ stopifnot(self$internal_forest_is_empty)
+
+ # Initialize the model
+ initialize_forest_model_active_forest_cpp(
+ dataset$data_ptr,
+ outcome$data_ptr,
+ self$forest_ptr,
+ forest_model$tracker_ptr,
+ leaf_value,
+ leaf_model_int
+ )
+
+ self$internal_forest_is_empty = FALSE
+ },
+
+ #' @description
+ #' Adjusts residual based on the predictions of a forest
+ #'
+ #' This is typically run just once at the beginning of a forest sampling algorithm.
+ #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
+ #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest + #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param requires_basis Whether or not a forest requires a basis for prediction + #' @param add Whether forest predictions should be added to or subtracted from residuals + adjust_residual = function( + dataset, + outcome, + forest_model, + requires_basis, + add + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_ptr)) + + adjust_residual_active_forest_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_ptr, + forest_model$tracker_ptr, + requires_basis, + add + ) + }, + + #' @description + #' Return number of trees in each ensemble of a `Forest` object + #' @return Tree count + num_trees = function() { + return(num_trees_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Return output dimension of trees in a `Forest` object + #' @return Leaf node parameter size + leaf_dimension = function() { + return(leaf_dimension_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Return constant leaf status of trees in a `Forest` object + #' @return `TRUE` if leaves are constant, `FALSE` otherwise + is_constant_leaf = function() { + return(is_leaf_constant_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Return exponentiation status of trees in a `Forest` object + #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise + is_exponentiated = function() { + return(is_exponentiated_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Add a numeric (i.e. 
`X[,i] <= c`) split to a given tree in the ensemble + #' @param tree_num Index of the tree to be split + #' @param leaf_num Leaf to be split + #' @param feature_num Feature that defines the new split + #' @param split_threshold Value that defines the cutoff of the new split + #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node + #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node + add_numeric_split_tree = function( + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) { + if (length(left_leaf_value) > 1) { + add_numeric_split_tree_vector_active_forest_cpp( + self$forest_ptr, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) + } else { + add_numeric_split_tree_value_active_forest_cpp( + self$forest_ptr, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) + } + }, + + #' @description + #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest + #' @param tree_num Index of the tree for which leaf indices will be retrieved + get_tree_leaves = function(tree_num) { + return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num)) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given tree in the forest + #' @param tree_num Index of the tree for which split counts will be retrieved + #' @param num_features Total number of features in the training set + get_tree_split_counts = function(tree_num, num_features) { + return(get_tree_split_counts_active_forest_cpp( + self$forest_ptr, + tree_num, + num_features + )) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in the forest + #' @param num_features Total number of features in the training set + get_forest_split_counts = function(num_features) { + return(get_overall_split_counts_active_forest_cpp( + self$forest_ptr, + num_features + )) + }, + + #' @description + #' Maximum depth of a specific tree in the forest + #' @param tree_num Tree index within forest + #' @return Maximum leaf depth + tree_max_depth = function(tree_num) { + return(ensemble_tree_max_depth_active_forest_cpp( + self$forest_ptr, + tree_num + )) + }, + + #' @description + #' Average the maximum depth of each tree in the forest + #' @return Average maximum depth + average_max_depth = function() { + return(ensemble_average_max_depth_active_forest_cpp( + self$forest_ptr + )) + }, + + #' @description + #' When a forest object is created, it is "empty" in the sense that none + #' of its component trees have leaves with values. There are two ways to + #' "initialize" a Forest object. First, the `set_root_leaves()` method + #' simply initializes every tree in the forest to a single node carrying + #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()` + #' method initializes every tree in the forest to a single node with the + #' same value and also propagates this information through to a ForestModel + #' object, which must be synchronized with a Forest during a forest + #' sampler loop. + #' @return `TRUE` if a Forest has not yet been initialized with a constant + #' root value, `FALSE` otherwise if the forest has already been + #' initialized / grown. 
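+ #' @examples
+ #' # Minimal sketch of the first initialization path; `prepare_for_sampler()`
+ #' # additionally syncs tracking structures in a `ForestModel` object.
+ #' f <- createForest(num_trees = 5)
+ #' f$is_empty() # TRUE
+ #' f$set_root_leaves(0.0)
+ #' f$is_empty() # FALSE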
+ is_empty = function() { + return(self$internal_forest_is_empty) + } + ) ) #' Create a container of forest samples @@ -1246,19 +1246,19 @@ Forest <- R6::R6Class( #' is_exponentiated <- FALSE #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) createForestSamples <- function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE ) { - return(invisible( - (ForestSamples$new( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - )) + return(invisible( + (ForestSamples$new( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated )) + )) } #' Create a forest @@ -1278,19 +1278,19 @@ createForestSamples <- function( #' is_exponentiated <- FALSE #' forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) createForest <- function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE ) { - return(invisible( - (Forest$new( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - )) + return(invisible( + (Forest$new( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated )) + )) } #' Reset an active forest, either from a specific forest in a `ForestContainer` @@ -1316,25 +1316,25 @@ createForest <- function( #' resetActiveForest(active_forest, forest_samples, 0) #' resetActiveForest(active_forest) resetActiveForest <- function( - active_forest, - forest_samples = NULL, - forest_num = NULL + active_forest, + forest_samples = NULL, + forest_num = NULL ) { - if (is.null(forest_samples)) { - root_reset_active_forest_cpp(active_forest$forest_ptr) - active_forest$internal_forest_is_empty = TRUE - } else { - if (is.null(forest_num)) { - stop( - "`forest_num` must be specified if `forest_samples` is provided" - ) - } - reset_active_forest_cpp( - active_forest$forest_ptr, - forest_samples$forest_container_ptr, - forest_num - ) + if (is.null(forest_samples)) { + root_reset_active_forest_cpp(active_forest$forest_ptr) + active_forest$internal_forest_is_empty = TRUE + } else { + if (is.null(forest_num)) { + stop( + "`forest_num` must be specified if `forest_samples` is provided" + ) } + reset_active_forest_cpp( + active_forest$forest_ptr, + forest_samples$forest_container_ptr, + forest_num + ) + } } #' Re-initialize a forest model (tracking data structures) from a specific forest in a `ForestContainer` @@ -1394,17 +1394,17 @@ resetActiveForest <- function( #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) resetForestModel <- function( - forest_model, - forest, - dataset, - residual, - is_mean_model + forest_model, + forest, + dataset, + residual, + is_mean_model ) { - reset_forest_model_cpp( - forest_model$tracker_ptr, - forest$forest_ptr, - dataset$data_ptr, - residual$data_ptr, - is_mean_model - ) + reset_forest_model_cpp( + forest_model$tracker_ptr, + forest$forest_ptr, + dataset$data_ptr, + residual$data_ptr, + is_mean_model + ) } diff --git a/R/kernel.R b/R/kernel.R index d7e9661e..2b643b98 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -48,113 +48,113 @@ #' computeForestLeafIndices(bart_model, X, "mean", 0) #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) computeForestLeafIndices <- function( - model_object, - covariates, - 
forest_type = NULL, - propensity = NULL, - forest_inds = NULL + model_object, + covariates, + forest_type = NULL, + propensity = NULL, + forest_inds = NULL ) { - # Extract relevant forest container - stopifnot(any(c( - inherits(model_object, "bartmodel"), - inherits(model_object, "bcfmodel"), - inherits(model_object, "ForestSamples") - ))) - model_type <- ifelse( - inherits(model_object, "bartmodel"), - "bart", - ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") - ) - if (model_type == "bart") { - stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type == "mean") { - if (!model_object$model_params$include_mean_forest) { - stop("Mean forest was not sampled in the bart model provided") - } - forest_container <- model_object$mean_forests - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bart model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else if (model_type == "bcf") { - stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type == "prognostic") { - forest_container <- model_object$forests_mu - } else if (forest_type == "treatment") { - forest_container <- model_object$forests_tau - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bcf model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else { - forest_container <- model_object - } - - # Preprocess covariates - if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { - stop("covariates must be a matrix or dataframe") + # Extract relevant forest container + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel"), + inherits(model_object, "ForestSamples") + ))) + model_type <- ifelse( + inherits(model_object, "bartmodel"), + "bart", + ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") + ) + if (model_type == "bart") { + stopifnot(forest_type %in% c("mean", "variance")) + if (forest_type == "mean") { + if (!model_object$model_params$include_mean_forest) { + stop("Mean forest was not sampled in the bart model provided") + } + forest_container <- model_object$mean_forests + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bart model provided" + ) + } + forest_container <- model_object$variance_forests } - if (model_type %in% c("bart", "bcf")) { - train_set_metadata <- model_object$train_set_metadata - covariates_processed <- preprocessPredictionData( - covariates, - train_set_metadata + } else if (model_type == "bcf") { + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) + if (forest_type == "prognostic") { + forest_container <- model_object$forests_mu + } else if (forest_type == "treatment") { + forest_container <- model_object$forests_tau + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bcf model provided" ) - } else { - if (!is.matrix(covariates)) { - stop( - "covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`" - ) - } - covariates_processed <- covariates + } + forest_container <- model_object$variance_forests } + } else { + forest_container <- model_object + } - # 
Handle BCF propensity covariate - if (model_type == "bcf") { - # Add propensities to covariate set if necessary - if (model_object$model_params$propensity_covariate != "none") { - if (is.null(propensity)) { - if (!model_object$model_params$internal_propensity_model) { - stop("propensity must be provided for this model") - } - # Compute propensity score using the internal bart model - propensity <- rowMeans( - predict( - model_object$bart_propensity_model, - covariates - )$y_hat - ) - } - covariates_processed <- cbind(covariates_processed, propensity) - } + # Preprocess covariates + if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { + stop("covariates must be a matrix or dataframe") + } + if (model_type %in% c("bart", "bcf")) { + train_set_metadata <- model_object$train_set_metadata + covariates_processed <- preprocessPredictionData( + covariates, + train_set_metadata + ) + } else { + if (!is.matrix(covariates)) { + stop( + "covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`" + ) } + covariates_processed <- covariates + } - # Preprocess forest indices - num_forests <- forest_container$num_samples() - if (is.null(forest_inds)) { - forest_inds <- as.integer(1:num_forests - 1) - } else { - stopifnot(all(forest_inds <= num_forests - 1)) - stopifnot(all(forest_inds >= 0)) - forest_inds <- as.integer(forest_inds) + # Handle BCF propensity covariate + if (model_type == "bcf") { + # Add propensities to covariate set if necessary + if (model_object$model_params$propensity_covariate != "none") { + if (is.null(propensity)) { + if (!model_object$model_params$internal_propensity_model) { + stop("propensity must be provided for this model") + } + # Compute propensity score using the internal bart model + propensity <- rowMeans( + predict( + model_object$bart_propensity_model, + covariates + )$y_hat + ) + } + covariates_processed <- cbind(covariates_processed, propensity) } + } - # Compute leaf indices - leaf_ind_matrix <- compute_leaf_indices_cpp( - forest_container$forest_container_ptr, - covariates_processed, - forest_inds - ) + # Preprocess forest indices + num_forests <- forest_container$num_samples() + if (is.null(forest_inds)) { + forest_inds <- as.integer(1:num_forests - 1) + } else { + stopifnot(all(forest_inds <= num_forests - 1)) + stopifnot(all(forest_inds >= 0)) + forest_inds <- as.integer(forest_inds) + } - return(leaf_ind_matrix) + # Compute leaf indices + leaf_ind_matrix <- compute_leaf_indices_cpp( + forest_container$forest_container_ptr, + covariates_processed, + forest_inds + ) + + return(leaf_ind_matrix) } #' Compute vector of forest leaf scale parameters @@ -193,80 +193,80 @@ computeForestLeafIndices <- function( #' computeForestLeafVariances(bart_model, "mean", 0) #' computeForestLeafVariances(bart_model, "mean", c(1,3,5)) computeForestLeafVariances <- function( - model_object, - forest_type, - forest_inds = NULL + model_object, + forest_type, + forest_inds = NULL ) { - # Extract relevant forest container - stopifnot(any(c( - inherits(model_object, "bartmodel"), - inherits(model_object, "bcfmodel") - ))) - model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") - if (model_type == "bart") { - stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type == "mean") { - if (!model_object$model_params$include_mean_forest) { - stop("Mean forest was not sampled in the bart model provided") - } - if (!model_object$model_params$sample_sigma2_leaf) { - stop( - "Leaf scale parameter 
was not sampled for the mean forest in the bart model provided" - ) - } - leaf_scale_vector <- model_object$sigma2_leaf_samples - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bart model provided" - ) - } - stop( - "Leaf scale parameter was not sampled for the variance forest in the bart model provided" - ) - } - } else { - stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type == "prognostic") { - if (!model_object$model_params$sample_sigma2_leaf_mu) { - stop( - "Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided" - ) - } - leaf_scale_vector <- model_object$sigma2_leaf_mu_samples - } else if (forest_type == "treatment") { - if (!model_object$model_params$sample_sigma2_leaf_tau) { - stop( - "Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided" - ) - } - leaf_scale_vector <- model_object$sigma2_leaf_tau_samples - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bcf model provided" - ) - } - stop( - "Leaf scale parameter was not sampled for the variance forest in the bcf model provided" - ) - } + # Extract relevant forest container + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel") + ))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") + if (model_type == "bart") { + stopifnot(forest_type %in% c("mean", "variance")) + if (forest_type == "mean") { + if (!model_object$model_params$include_mean_forest) { + stop("Mean forest was not sampled in the bart model provided") + } + if (!model_object$model_params$sample_sigma2_leaf) { + stop( + "Leaf scale parameter was not sampled for the mean forest in the bart model provided" + ) + } + leaf_scale_vector <- model_object$sigma2_leaf_samples + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bart model provided" + ) + } + stop( + "Leaf scale parameter was not sampled for the variance forest in the bart model provided" + ) } - - # Preprocess forest indices - num_forests <- model_object$model_params$num_samples - if (is.null(forest_inds)) { - forest_inds <- as.integer(1:num_forests) - } else { - stopifnot(all(forest_inds <= num_forests - 1)) - stopifnot(all(forest_inds >= 0)) - forest_inds <- as.integer(forest_inds + 1) + } else { + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) + if (forest_type == "prognostic") { + if (!model_object$model_params$sample_sigma2_leaf_mu) { + stop( + "Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided" + ) + } + leaf_scale_vector <- model_object$sigma2_leaf_mu_samples + } else if (forest_type == "treatment") { + if (!model_object$model_params$sample_sigma2_leaf_tau) { + stop( + "Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided" + ) + } + leaf_scale_vector <- model_object$sigma2_leaf_tau_samples + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bcf model provided" + ) + } + stop( + "Leaf scale parameter was not sampled for the variance forest in the bcf model provided" + ) } + } + + # Preprocess forest indices + num_forests <- 
model_object$model_params$num_samples + if (is.null(forest_inds)) { + forest_inds <- as.integer(1:num_forests) + } else { + stopifnot(all(forest_inds <= num_forests - 1)) + stopifnot(all(forest_inds >= 0)) + forest_inds <- as.integer(forest_inds + 1) + } - # Gather leaf scale parameters - leaf_scale_params <- leaf_scale_vector[forest_inds] + # Gather leaf scale parameters + leaf_scale_params <- leaf_scale_vector[forest_inds] - return(leaf_scale_params) + return(leaf_scale_params) } #' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container. @@ -304,72 +304,72 @@ computeForestLeafVariances <- function( #' computeForestMaxLeafIndex(bart_model, "mean", 0) #' computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9)) computeForestMaxLeafIndex <- function( - model_object, - forest_type = NULL, - forest_inds = NULL + model_object, + forest_type = NULL, + forest_inds = NULL ) { - # Extract relevant forest container - stopifnot(any(c( - inherits(model_object, "bartmodel"), - inherits(model_object, "bcfmodel"), - inherits(model_object, "ForestSamples") - ))) - model_type <- ifelse( - inherits(model_object, "bartmodel"), - "bart", - ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") - ) - if (model_type == "bart") { - stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type == "mean") { - if (!model_object$model_params$include_mean_forest) { - stop("Mean forest was not sampled in the bart model provided") - } - forest_container <- model_object$mean_forests - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bart model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else if (model_type == "bcf") { - stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type == "prognostic") { - forest_container <- model_object$forests_mu - } else if (forest_type == "treatment") { - forest_container <- model_object$forests_tau - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bcf model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else { - forest_container <- model_object - } - - # Preprocess forest indices - num_forests <- forest_container$num_samples() - if (is.null(forest_inds)) { - forest_inds <- as.integer(1:num_forests - 1) - } else { - stopifnot(all(forest_inds <= num_forests - 1)) - stopifnot(all(forest_inds >= 0)) - forest_inds <- as.integer(forest_inds) + # Extract relevant forest container + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel"), + inherits(model_object, "ForestSamples") + ))) + model_type <- ifelse( + inherits(model_object, "bartmodel"), + "bart", + ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") + ) + if (model_type == "bart") { + stopifnot(forest_type %in% c("mean", "variance")) + if (forest_type == "mean") { + if (!model_object$model_params$include_mean_forest) { + stop("Mean forest was not sampled in the bart model provided") + } + forest_container <- model_object$mean_forests + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bart model provided" + ) + } + forest_container <- model_object$variance_forests } - - # Compute leaf 
indices - output <- rep(NA, length(forest_inds)) - for (i in 1:length(forest_inds)) { - output[i] <- forest_container_get_max_leaf_index_cpp( - forest_container$forest_container_ptr, - forest_inds[i] + } else if (model_type == "bcf") { + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) + if (forest_type == "prognostic") { + forest_container <- model_object$forests_mu + } else if (forest_type == "treatment") { + forest_container <- model_object$forests_tau + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bcf model provided" ) + } + forest_container <- model_object$variance_forests } + } else { + forest_container <- model_object + } + + # Preprocess forest indices + num_forests <- forest_container$num_samples() + if (is.null(forest_inds)) { + forest_inds <- as.integer(1:num_forests - 1) + } else { + stopifnot(all(forest_inds <= num_forests - 1)) + stopifnot(all(forest_inds >= 0)) + forest_inds <- as.integer(forest_inds) + } + + # Compute leaf indices + output <- rep(NA, length(forest_inds)) + for (i in 1:length(forest_inds)) { + output[i] <- forest_container_get_max_leaf_index_cpp( + forest_container$forest_container_ptr, + forest_inds[i] + ) + } - return(output) + return(output) } diff --git a/R/model.R b/R/model.R index 38df5970..5549880d 100644 --- a/R/model.R +++ b/R/model.R @@ -6,20 +6,20 @@ #' the C++ random number generator is initialized using `std::random_device`. CppRNG <- R6::R6Class( - classname = "CppRNG", - cloneable = FALSE, - public = list( - #' @field rng_ptr External pointer to a C++ std::mt19937 class - rng_ptr = NULL, + classname = "CppRNG", + cloneable = FALSE, + public = list( + #' @field rng_ptr External pointer to a C++ std::mt19937 class + rng_ptr = NULL, - #' @description - #' Create a new CppRNG object. - #' @param random_seed (Optional) random seed for sampling - #' @return A new `CppRNG` object. - initialize = function(random_seed = -1) { - self$rng_ptr <- rng_cpp(random_seed) - } - ) + #' @description + #' Create a new CppRNG object. + #' @param random_seed (Optional) random seed for sampling + #' @return A new `CppRNG` object. + initialize = function(random_seed = -1) { + self$rng_ptr <- rng_cpp(random_seed) + } + ) ) #' Class that defines and samples a forest model @@ -30,292 +30,291 @@ CppRNG <- R6::R6Class( #' (using either MCMC or the grow-from-root algorithm). ForestModel <- R6::R6Class( - classname = "ForestModel", - cloneable = FALSE, - public = list( - #' @field tracker_ptr External pointer to a C++ ForestTracker class - tracker_ptr = NULL, + classname = "ForestModel", + cloneable = FALSE, + public = list( + #' @field tracker_ptr External pointer to a C++ ForestTracker class + tracker_ptr = NULL, - #' @field tree_prior_ptr External pointer to a C++ TreePrior class - tree_prior_ptr = NULL, + #' @field tree_prior_ptr External pointer to a C++ TreePrior class + tree_prior_ptr = NULL, - #' @description - #' Create a new ForestModel object. 
- #' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures - #' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - #' @param num_trees Number of trees in the forest being sampled - #' @param n Number of observations in `forest_dataset` - #' @param alpha Root node split probability in tree prior - #' @param beta Depth prior penalty in tree prior - #' @param min_samples_leaf Minimum number of samples in a tree leaf - #' @param max_depth Maximum depth that any tree can reach - #' @return A new `ForestModel` object. - initialize = function( - forest_dataset, - feature_types, - num_trees, - n, - alpha, - beta, - min_samples_leaf, - max_depth = -1 - ) { - stopifnot(!is.null(forest_dataset$data_ptr)) - self$tracker_ptr <- forest_tracker_cpp( - forest_dataset$data_ptr, - feature_types, - num_trees, - n - ) - self$tree_prior_ptr <- tree_prior_cpp( - alpha, - beta, - min_samples_leaf, - max_depth - ) - }, + #' @description + #' Create a new ForestModel object. + #' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures + #' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + #' @param num_trees Number of trees in the forest being sampled + #' @param n Number of observations in `forest_dataset` + #' @param alpha Root node split probability in tree prior + #' @param beta Depth prior penalty in tree prior + #' @param min_samples_leaf Minimum number of samples in a tree leaf + #' @param max_depth Maximum depth that any tree can reach + #' @return A new `ForestModel` object. + initialize = function( + forest_dataset, + feature_types, + num_trees, + n, + alpha, + beta, + min_samples_leaf, + max_depth = -1 + ) { + stopifnot(!is.null(forest_dataset$data_ptr)) + self$tracker_ptr <- forest_tracker_cpp( + forest_dataset$data_ptr, + feature_types, + num_trees, + n + ) + self$tree_prior_ptr <- tree_prior_cpp( + alpha, + beta, + min_samples_leaf, + max_depth + ) + }, - #' @description - #' Run a single iteration of the forest sampling algorithm (MCMC or GFR) - #' @param forest_dataset Dataset used to sample the forest - #' @param residual Outcome used to sample the forest - #' @param forest_samples Container of forest samples - #' @param active_forest "Active" forest updated by the sampler in each iteration - #' @param rng Wrapper around C++ random number generator - #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings - #' @param global_model_config GlobalModelConfig object containing global model parameters and settings - #' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads. - #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. - #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. - sample_one_iteration = function( - forest_dataset, - residual, - forest_samples, - active_forest, - rng, - forest_model_config, - global_model_config, - num_threads = -1, - keep_forest = TRUE, - gfr = TRUE - ) { - if (active_forest$is_empty()) { - stop( - "`active_forest` has not yet been initialized, which is necessary to run the sampler. 
Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods." - ) - } + #' @description + #' Run a single iteration of the forest sampling algorithm (MCMC or GFR) + #' @param forest_dataset Dataset used to sample the forest + #' @param residual Outcome used to sample the forest + #' @param forest_samples Container of forest samples + #' @param active_forest "Active" forest updated by the sampler in each iteration + #' @param rng Wrapper around C++ random number generator + #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings + #' @param global_model_config GlobalModelConfig object containing global model parameters and settings + #' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads. + #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. + #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. + sample_one_iteration = function( + forest_dataset, + residual, + forest_samples, + active_forest, + rng, + forest_model_config, + global_model_config, + num_threads = -1, + keep_forest = TRUE, + gfr = TRUE + ) { + if (active_forest$is_empty()) { + stop( + "`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods." + ) + } - # Unpack parameters from model config object - feature_types <- forest_model_config$feature_types - sweep_update_indices <- forest_model_config$sweep_update_indices - leaf_model_int <- forest_model_config$leaf_model_type - leaf_model_scale <- forest_model_config$leaf_model_scale - variable_weights <- forest_model_config$variable_weights - a_forest <- forest_model_config$variance_forest_shape - b_forest <- forest_model_config$variance_forest_scale - global_scale <- global_model_config$global_error_variance - cutpoint_grid_size <- forest_model_config$cutpoint_grid_size - num_features_subsample <- forest_model_config$num_features_subsample + # Unpack parameters from model config object + feature_types <- forest_model_config$feature_types + sweep_update_indices <- forest_model_config$sweep_update_indices + leaf_model_int <- forest_model_config$leaf_model_type + leaf_model_scale <- forest_model_config$leaf_model_scale + variable_weights <- forest_model_config$variable_weights + a_forest <- forest_model_config$variance_forest_shape + b_forest <- forest_model_config$variance_forest_scale + global_scale <- global_model_config$global_error_variance + cutpoint_grid_size <- forest_model_config$cutpoint_grid_size + num_features_subsample <- forest_model_config$num_features_subsample - # Default to empty integer vector if sweep_update_indices is NULL - if (is.null(sweep_update_indices)) { - # sweep_update_indices <- integer(0) - sweep_update_indices <- 0:(forest_model_config$num_trees - 1) - } + # Default to empty integer vector if sweep_update_indices is NULL + if (is.null(sweep_update_indices)) { + # sweep_update_indices <- integer(0) + sweep_update_indices <- 0:(forest_model_config$num_trees - 1) + } - # Detect changes to tree prior - if ( - forest_model_config$alpha != - get_alpha_tree_prior_cpp(self$tree_prior_ptr) 
- ) { - update_alpha_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$alpha - ) - } - if ( - forest_model_config$beta != - get_beta_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_beta_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$beta - ) - } - if ( - forest_model_config$min_samples_leaf != - get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_min_samples_leaf_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$min_samples_leaf - ) - } - if ( - forest_model_config$max_depth != - get_max_depth_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_max_depth_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$max_depth - ) - } + # Detect changes to tree prior + if ( + forest_model_config$alpha != + get_alpha_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_alpha_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$alpha + ) + } + if ( + forest_model_config$beta != get_beta_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_beta_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$beta + ) + } + if ( + forest_model_config$min_samples_leaf != + get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_min_samples_leaf_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$min_samples_leaf + ) + } + if ( + forest_model_config$max_depth != + get_max_depth_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_max_depth_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$max_depth + ) + } - # Run the sampler - if (gfr) { - sample_gfr_one_iteration_cpp( - forest_dataset$data_ptr, - residual$data_ptr, - forest_samples$forest_container_ptr, - active_forest$forest_ptr, - self$tracker_ptr, - self$tree_prior_ptr, - rng$rng_ptr, - sweep_update_indices, - feature_types, - cutpoint_grid_size, - leaf_model_scale, - variable_weights, - a_forest, - b_forest, - global_scale, - leaf_model_int, - keep_forest, - num_features_subsample, - num_threads - ) - } else { - sample_mcmc_one_iteration_cpp( - forest_dataset$data_ptr, - residual$data_ptr, - forest_samples$forest_container_ptr, - active_forest$forest_ptr, - self$tracker_ptr, - self$tree_prior_ptr, - rng$rng_ptr, - sweep_update_indices, - feature_types, - cutpoint_grid_size, - leaf_model_scale, - variable_weights, - a_forest, - b_forest, - global_scale, - leaf_model_int, - keep_forest, - num_threads - ) - } - }, + # Run the sampler + if (gfr) { + sample_gfr_one_iteration_cpp( + forest_dataset$data_ptr, + residual$data_ptr, + forest_samples$forest_container_ptr, + active_forest$forest_ptr, + self$tracker_ptr, + self$tree_prior_ptr, + rng$rng_ptr, + sweep_update_indices, + feature_types, + cutpoint_grid_size, + leaf_model_scale, + variable_weights, + a_forest, + b_forest, + global_scale, + leaf_model_int, + keep_forest, + num_features_subsample, + num_threads + ) + } else { + sample_mcmc_one_iteration_cpp( + forest_dataset$data_ptr, + residual$data_ptr, + forest_samples$forest_container_ptr, + active_forest$forest_ptr, + self$tracker_ptr, + self$tree_prior_ptr, + rng$rng_ptr, + sweep_update_indices, + feature_types, + cutpoint_grid_size, + leaf_model_scale, + variable_weights, + a_forest, + b_forest, + global_scale, + leaf_model_int, + keep_forest, + num_threads + ) + } + }, - #' @description - #' Extract an internally-cached prediction of a forest on the training dataset in a sampler. 
- #' @return Vector with as many elements as observations in the training dataset - get_cached_forest_predictions = function() { - get_cached_forest_predictions_cpp(self$tracker_ptr) - }, + #' @description + #' Extract an internally-cached prediction of a forest on the training dataset in a sampler. + #' @return Vector with as many elements as observations in the training dataset + get_cached_forest_predictions = function() { + get_cached_forest_predictions_cpp(self$tracker_ptr) + }, - #' @description - #' Propagates basis update through to the (full/partial) residual by iteratively - #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions - #' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual. - #' - #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside - #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). - #' Once a basis has been updated, the overall "function" represented by a tree model has - #' changed and this should be reflected through to the residual before the next sampling loop is run. - #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest - #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param active_forest "Active" forest updated by the sampler in each iteration - propagate_basis_update = function(dataset, outcome, active_forest) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(self$tracker_ptr)) - stopifnot(!is.null(active_forest$forest_ptr)) + #' @description + #' Propagates basis update through to the (full/partial) residual by iteratively + #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions + #' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual. + #' + #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside + #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). + #' Once a basis has been updated, the overall "function" represented by a tree model has + #' changed and this should be reflected through to the residual before the next sampling loop is run. + #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest + #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions + #' @param active_forest "Active" forest updated by the sampler in each iteration + propagate_basis_update = function(dataset, outcome, active_forest) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(self$tracker_ptr)) + stopifnot(!is.null(active_forest$forest_ptr)) - propagate_basis_update_active_forest_cpp( - dataset$data_ptr, - outcome$data_ptr, - active_forest$forest_ptr, - self$tracker_ptr - ) - }, + propagate_basis_update_active_forest_cpp( + dataset$data_ptr, + outcome$data_ptr, + active_forest$forest_ptr, + self$tracker_ptr + ) + }, - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. - #' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data. 
- #' @param residual Outcome used to sample the forest - #' @return None - propagate_residual_update = function(residual) { - propagate_trees_column_vector_cpp( - self$tracker_ptr, - residual$data_ptr - ) - }, + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. + #' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data. + #' @param residual Outcome used to sample the forest + #' @return None + propagate_residual_update = function(residual) { + propagate_trees_column_vector_cpp( + self$tracker_ptr, + residual$data_ptr + ) + }, - #' @description - #' Update alpha in the tree prior - #' @param alpha New value of alpha to be used - #' @return None - update_alpha = function(alpha) { - update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha) - }, + #' @description + #' Update alpha in the tree prior + #' @param alpha New value of alpha to be used + #' @return None + update_alpha = function(alpha) { + update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha) + }, - #' @description - #' Update beta in the tree prior - #' @param beta New value of beta to be used - #' @return None - update_beta = function(beta) { - update_beta_tree_prior_cpp(self$tree_prior_ptr, beta) - }, + #' @description + #' Update beta in the tree prior + #' @param beta New value of beta to be used + #' @return None + update_beta = function(beta) { + update_beta_tree_prior_cpp(self$tree_prior_ptr, beta) + }, - #' @description - #' Update min_samples_leaf in the tree prior - #' @param min_samples_leaf New value of min_samples_leaf to be used - #' @return None - update_min_samples_leaf = function(min_samples_leaf) { - update_min_samples_leaf_tree_prior_cpp( - self$tree_prior_ptr, - min_samples_leaf - ) - }, + #' @description + #' Update min_samples_leaf in the tree prior + #' @param min_samples_leaf New value of min_samples_leaf to be used + #' @return None + update_min_samples_leaf = function(min_samples_leaf) { + update_min_samples_leaf_tree_prior_cpp( + self$tree_prior_ptr, + min_samples_leaf + ) + }, - #' @description - #' Update max_depth in the tree prior - #' @param max_depth New value of max_depth to be used - #' @return None - update_max_depth = function(max_depth) { - update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth) - }, + #' @description + #' Update max_depth in the tree prior + #' @param max_depth New value of max_depth to be used + #' @return None + update_max_depth = function(max_depth) { + update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth) + }, - #' @description - #' Update alpha in the tree prior - #' @return Value of alpha in the tree prior - get_alpha = function() { - get_alpha_tree_prior_cpp(self$tree_prior_ptr) - }, + #' @description + #' Update alpha in the tree prior + #' @return Value of alpha in the tree prior + get_alpha = function() { + get_alpha_tree_prior_cpp(self$tree_prior_ptr) + }, - #' @description - #' Update beta in the tree prior - #' @return Value of beta in the tree prior - get_beta = function() { - get_beta_tree_prior_cpp(self$tree_prior_ptr) - }, + #' @description + #' Update beta in the tree prior + #' @return Value of beta in the tree prior + get_beta = function() { + get_beta_tree_prior_cpp(self$tree_prior_ptr) + }, - #' @description - #' Query min_samples_leaf in the tree prior - #' @return Value of min_samples_leaf in the tree prior - get_min_samples_leaf = function() { - 
get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) - }, + #' @description + #' Query min_samples_leaf in the tree prior + #' @return Value of min_samples_leaf in the tree prior + get_min_samples_leaf = function() { + get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) + }, - #' @description - #' Query max_depth in the tree prior - #' @return Value of max_depth in the tree prior - get_max_depth = function() { - get_max_depth_tree_prior_cpp(self$tree_prior_ptr) - } - ) + #' @description + #' Query max_depth in the tree prior + #' @return Value of max_depth in the tree prior + get_max_depth = function() { + get_max_depth_tree_prior_cpp(self$tree_prior_ptr) + } + ) ) #' Create an R class that wraps a C++ random number generator @@ -329,7 +328,7 @@ ForestModel <- R6::R6Class( #' rng <- createCppRNG(1234) #' rng <- createCppRNG() createCppRNG <- function(random_seed = -1) { - return(invisible((CppRNG$new(random_seed)))) + return(invisible((CppRNG$new(random_seed)))) } #' Create a forest model object @@ -360,22 +359,22 @@ createCppRNG <- function(random_seed = -1) { #' global_model_config <- createGlobalModelConfig(global_error_variance=1.0) #' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) createForestModel <- function( - forest_dataset, - forest_model_config, - global_model_config + forest_dataset, + forest_model_config, + global_model_config ) { - return(invisible( - (ForestModel$new( - forest_dataset, - forest_model_config$feature_types, - forest_model_config$num_trees, - forest_model_config$num_observations, - forest_model_config$alpha, - forest_model_config$beta, - forest_model_config$min_samples_leaf, - forest_model_config$max_depth - )) + return(invisible( + (ForestModel$new( + forest_dataset, + forest_model_config$feature_types, + forest_model_config$num_trees, + forest_model_config$num_observations, + forest_model_config$alpha, + forest_model_config$beta, + forest_model_config$min_samples_leaf, + forest_model_config$max_depth )) + )) } @@ -394,13 +393,13 @@ createForestModel <- function( #' num_samples <- 5 #' sample_without_replacement(a, p, num_samples) sample_without_replacement <- function( + population_vector, + sampling_probabilities, + sample_size +) { + return(sample_without_replacement_integer_cpp( population_vector, sampling_probabilities, sample_size -) { - return(sample_without_replacement_integer_cpp( - population_vector, - sampling_probabilities, - sample_size - )) + )) } diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R new file mode 100644 index 00000000..b401cffd --- /dev/null +++ b/R/posterior_transformation.R @@ -0,0 +1,1294 @@ +#' Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference. +#' For simple BCF models with binary treatment, this will yield the same prediction as requesting `terms = "cate"` +#' in the `predict.bcfmodel` function. For more general models, such as models with continuous / multivariate treatments or +#' an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute a +#' any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two term +#' contrast. 
For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend of the
+#' contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" terminology of a classic
+#' two-arm experiment. We mirror the function calls and terminology of the `predict.bcfmodel` function, labeling each prediction
+#' data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the
+#' control prediction.
+#'
+#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs.
+#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
+#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
+#' @param Z_0 Treatments used for prediction in the "control" case. Must be a matrix or vector.
+#' @param Z_1 Treatments used for prediction in the "treatment" case. Must be a matrix or vector.
+#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector.
+#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector.
+#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
+#' model in the "control" case. We do not currently support (but plan to in the near future) test set evaluation
+#' for group labels that were not in the training set. Must be a vector.
+#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
+#' model in the "treatment" case. We do not currently support (but plan to in the near future) test set evaluation
+#' for group labels that were not in the training set. Must be a vector.
+#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
+#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
+#' @param type (Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
+#' @param scale (Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
+#'
+#' @return Contrast matrix or vector, depending on whether type = "mean" or "posterior".
+#' @export
+#'
+#' @examples
+#' n <- 500
+#' p <- 5
+#' X <- matrix(runif(n*p), ncol = p)
+#' mu_x <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
+#' )
+#' pi_x <- (
+#'     ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
+#'     ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
+#'     ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
+#'     ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
+#' )
+#' tau_x <- (
+#'     ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
+#'     ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
+#'     ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
+#'     ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
+#' )
+#' Z <- rbinom(n, 1, pi_x)
+#' noise_sd <- 1
+#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
+#' test_set_pct <- 0.2
+#' n_test <- round(test_set_pct*n)
+#' n_train <- n - n_test
+#' test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+#' train_inds <- (1:n)[!((1:n) %in% test_inds)]
+#' X_test <- X[test_inds,]
+#' X_train <- X[train_inds,]
+#' pi_test <- pi_x[test_inds]
+#' pi_train <- pi_x[train_inds]
+#' Z_test <- Z[test_inds]
+#' Z_train <- Z[train_inds]
+#' y_test <- y[test_inds]
+#' y_train <- y[train_inds]
+#' mu_test <- mu_x[test_inds]
+#' mu_train <- mu_x[train_inds]
+#' tau_test <- tau_x[test_inds]
+#' tau_train <- tau_x[train_inds]
+#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
+#'                  propensity_train = pi_train, num_gfr = 10,
+#'                  num_burnin = 0, num_mcmc = 10)
+#' tau_hat_test <- compute_contrast_bcf_model(
+#'     bcf_model, X_0=X_test, X_1=X_test, Z_0=rep(0, n_test), Z_1=rep(1, n_test),
+#'     propensity_0 = pi_test, propensity_1 = pi_test
+#' )
+compute_contrast_bcf_model <- function(
+  object,
+  X_0,
+  X_1,
+  Z_0,
+  Z_1,
+  propensity_0 = NULL,
+  propensity_1 = NULL,
+  rfx_group_ids_0 = NULL,
+  rfx_group_ids_1 = NULL,
+  rfx_basis_0 = NULL,
+  rfx_basis_1 = NULL,
+  type = "posterior",
+  scale = "linear"
+) {
+  # Handle mean function scale
+  if (!is.character(scale)) {
+    stop("scale must be a string or character vector")
+  }
+  if (!(scale %in% c("linear", "probability"))) {
+    stop("scale must either be 'linear' or 'probability'")
+  }
+  is_probit <- object$model_params$probit_outcome_model
+  if ((scale == "probability") && (!is_probit)) {
+    stop(
+      "scale cannot be 'probability' for models not fit with a probit outcome model"
+    )
+  }
+  probability_scale <- scale == "probability"
+
+  # Handle prediction type
+  if (!is.character(type)) {
+    stop("type must be a string or character vector")
+  }
+  if (!(type %in% c("mean", "posterior"))) {
+    stop("type must either be 'mean' or 'posterior'")
+  }
+  predict_mean <- type == "mean"
+
+  # Make sure covariates are matrix or data frame
+  if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) {
+    stop("X_0 must be a matrix or dataframe")
+  }
+  if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) {
+    stop("X_1 must be a matrix or dataframe")
+  }
+
+  # Convert all input data to matrices if not already converted
+  if ((is.null(dim(Z_0))) && (!is.null(Z_0))) {
+    Z_0 <- as.matrix(as.numeric(Z_0))
+  }
+  if ((is.null(dim(Z_1))) && (!is.null(Z_1))) {
+    Z_1 <- as.matrix(as.numeric(Z_1))
+  }
+  if ((is.null(dim(propensity_0))) && (!is.null(propensity_0))) {
+    propensity_0 <- as.matrix(propensity_0)
+  }
+  if ((is.null(dim(propensity_1))) && (!is.null(propensity_1))) {
+    propensity_1 <- as.matrix(propensity_1)
+  }
+  if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) {
+    rfx_basis_0 <- as.matrix(rfx_basis_0)
+  }
+  if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) {
+    rfx_basis_1 <- as.matrix(rfx_basis_1)
+  }
+
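+  # The dimension checks below compare each input against the training-time
+  # metadata stored in `object$model_params`, so both arms of the contrast
+  # must be shaped like the data used to fit the model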
+  # Data checks
+  if (
+    (object$model_params$propensity_covariate != "none") &&
+      ((is.null(propensity_0)) ||
+        (is.null(propensity_1)))
+  ) {
+    if (!object$model_params$internal_propensity_model) {
+      stop("propensity_0 and propensity_1 must be provided for this model")
+    }
+  }
+  if (nrow(X_0) != nrow(Z_0)) {
+    stop("X_0 and Z_0 must have the same number of rows")
+  }
+  if (nrow(X_1) != nrow(Z_1)) {
+    stop("X_1 and Z_1 must have the same number of rows")
+  }
+  if (object$model_params$num_covariates != ncol(X_0)) {
+    stop(
+      "X_0 must have the same number of columns as the covariates used to train the model"
+    )
+  }
+  if (object$model_params$num_covariates != ncol(X_1)) {
+    stop(
+      "X_1 must have the same number of columns as the covariates used to train the model"
+    )
+  }
+  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_0))) {
+    stop(
+      "Random effect group labels (rfx_group_ids_0) must be provided for this model"
+    )
+  }
+  if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_1))) {
+    stop(
+      "Random effect group labels (rfx_group_ids_1) must be provided for this model"
+    )
+  }
+  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_0))) {
+    stop("Random effects basis (rfx_basis_0) must be provided for this model")
+  }
+  if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_1))) {
+    stop("Random effects basis (rfx_basis_1) must be provided for this model")
+  }
+  if (
+    (object$model_params$num_rfx_basis > 0) &&
+      (ncol(rfx_basis_0) != object$model_params$num_rfx_basis)
+  ) {
+    stop(
+      "Random effects basis has a different dimension than the basis used to train this model"
+    )
+  }
+  if (
+    (object$model_params$num_rfx_basis > 0) &&
+      (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)
+  ) {
+    stop(
+      "Random effects basis has a different dimension than the basis used to train this model"
+    )
+  }
+
+  # Predict for the control arm
+  control_preds <- predict(
+    object = object,
+    X = X_0,
+    Z = Z_0,
+    propensity = propensity_0,
+    rfx_group_ids = rfx_group_ids_0,
+    rfx_basis = rfx_basis_0,
+    type = "posterior",
+    term = "y_hat",
+    scale = "linear"
+  )
+
+  # Predict for the treatment arm
+  treatment_preds <- predict(
+    object = object,
+    X = X_1,
+    Z = Z_1,
+    propensity = propensity_1,
+    rfx_group_ids = rfx_group_ids_1,
+    rfx_basis = rfx_basis_1,
+    type = "posterior",
+    term = "y_hat",
+    scale = "linear"
+  )
+
+  # Transform to probability scale if requested
+  if (probability_scale) {
+    treatment_preds <- pnorm(treatment_preds)
+    control_preds <- pnorm(control_preds)
+  }
+
+  # Compute and return contrast
+  if (predict_mean) {
+    return(rowMeans(treatment_preds - control_preds))
+  } else {
+    return(treatment_preds - control_preds)
+  }
+}
+
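Because `type = "posterior"` returns an observations-by-draws matrix, downstream summaries are simple matrix operations. A minimal sketch, reusing `tau_hat_test` from the example above, where the two arms differ only in treatment, so a column mean is one posterior draw of the sample average treatment effect:

```r
# Each column of tau_hat_test is one posterior draw of the contrast;
# averaging within a column gives one posterior draw of the sample ATE
ate_draws <- colMeans(tau_hat_test)

# Posterior median and 95% credible interval for the ATE
quantile(ate_draws, probs = c(0.025, 0.5, 0.975))
```

The same pattern applies to any other contrast (e.g., probability-scale risk differences under a probit outcome model with `scale = "probability"`).
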
+#' Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference.
+#' This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects
+#' bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or
+#' `Y0` term and the minuend of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
+#' terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the `predict.bartmodel`
+#' function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and
+#' `0` to denote inclusion in the control prediction.
+#'
+#' Only valid when there is either a mean forest or a random effects term in the BART model.
+#'
+#' @param object Object of type `bartmodel` containing draws of a regression forest and associated sampling outputs.
+#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe.
+#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.
+#' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`.
+#' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`.
+#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects
+#' model in the "control" case. We do not currently support (but plan to in the near future) test set evaluation
+#' for group labels that were not in the training set. Must be a vector.
+#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects
+#' model in the "treatment" case. We do not currently support (but plan to in the near future) test set evaluation
+#' for group labels that were not in the training set. Must be a vector.
+#' @param rfx_basis_0 (Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.
+#' @param rfx_basis_1 (Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.
+#' @param type (Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
+#' @param scale (Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
+#'
+#' @return Contrast matrix or vector, depending on whether type = "mean" or "posterior".
+#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' W <- matrix(runif(n*1), ncol = 1) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' W_test <- W[test_inds,] +#' W_train <- W[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10) +#' contrast_test <- compute_contrast_bart_model( +#' bart_model, +#' covariates_0 = X_test, +#' covariates_1 = X_test, +#' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), +#' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), +#' type = "posterior", +#' scale = "linear" +#' ) +compute_contrast_bart_model <- function( + object, + covariates_0, + covariates_1, + leaf_basis_0 = NULL, + leaf_basis_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear" +) { + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + + # Handle prediction type + if (!is.character(type)) { + stop("type must be a string or character vector") + } + if (!(type %in% c("mean", "posterior"))) { + stop("type must either be 'mean' or 'posterior'") + } + predict_mean <- type == "mean" + + # Handle prediction terms + has_mean_forest <- object$model_params$include_mean_forest + has_rfx <- object$model_params$has_rfx + if ((!has_mean_forest) && (!has_rfx)) { + stop( + "Model must have either or both of mean forest or random effects terms to compute the requested contrast." 
+ ) + } + + # Check that covariates are matrix or data frame + if ((!is.data.frame(covariates_0)) && (!is.matrix(covariates_0))) { + stop("covariates_0 must be a matrix or dataframe") + } + if ((!is.data.frame(covariates_1)) && (!is.matrix(covariates_1))) { + stop("covariates_1 must be a matrix or dataframe") + } + + # Convert all input data to matrices if not already converted + if ((is.null(dim(leaf_basis_0))) && (!is.null(leaf_basis_0))) { + leaf_basis_0 <- as.matrix(leaf_basis_0) + } + if ((is.null(dim(leaf_basis_1))) && (!is.null(leaf_basis_1))) { + leaf_basis_1 <- as.matrix(leaf_basis_1) + } + if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) { + rfx_basis_0 <- as.matrix(rfx_basis_0) + } + if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) { + rfx_basis_1 <- as.matrix(rfx_basis_1) + } + + # Data checks + if ( + (object$model_params$requires_basis) && + (is.null(leaf_basis_0) || is.null(leaf_basis_1)) + ) { + stop("leaf_basis_0 and leaf_basis_1 must be provided for this model") + } + if ((!is.null(leaf_basis_0)) && (nrow(covariates_0) != nrow(leaf_basis_0))) { + stop("covariates_0 and leaf_basis_0 must have the same number of rows") + } + if ((!is.null(leaf_basis_1)) && (nrow(covariates_1) != nrow(leaf_basis_1))) { + stop("covariates_1 and leaf_basis_1 must have the same number of rows") + } + if (object$model_params$num_covariates != ncol(covariates_0)) { + stop( + "covariates_0 must contain the same number of columns as the BART model's training dataset" + ) + } + if (object$model_params$num_covariates != ncol(covariates_1)) { + stop( + "covariates_1 must contain the same number of columns as the BART model's training dataset" + ) + } + if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) { + stop( + "rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model" + ) + } + if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) { + stop( + "rfx_basis_0 and rfx_basis_1 must be provided for this model" + ) + } + if ( + (object$model_params$num_rfx_basis > 0) && + ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) || + (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)) + ) { + stop( + "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model" + ) + } + + # Predict for the control arm + control_preds <- predict( + object = object, + covariates = covariates_0, + leaf_basis = leaf_basis_0, + rfx_group_ids = rfx_group_ids_0, + rfx_basis = rfx_basis_0, + type = "posterior", + term = "y_hat", + scale = "linear" + ) + + # Predict for the treatment arm + treatment_preds <- predict( + object = object, + covariates = covariates_1, + leaf_basis = leaf_basis_1, + rfx_group_ids = rfx_group_ids_1, + rfx_basis = rfx_basis_1, + type = "posterior", + term = "y_hat", + scale = "linear" + ) + + # Transform to probability scale if requested + if (probability_scale) { + treatment_preds <- pnorm(treatment_preds) + control_preds <- pnorm(control_preds) + } + + # Compute and return contrast + if (predict_mean) { + return(rowMeans(treatment_preds - control_preds)) + } else { + return(treatment_preds - control_preds) + } +} + +#' Sample from the posterior predictive distribution for outcomes modeled by BCF +#' +#' @param model_object A fitted BCF model object of class `bcfmodel`. +#' @param covariates A matrix or data frame of covariates. +#' @param treatment A vector or matrix of treatment assignments. +#' @param propensity (Optional) A vector or matrix of propensity scores. 
Required if the underlying model depends on user-provided propensities.
+#' @param rfx_group_ids (Optional) A vector of group IDs for the random effects model. Required if the BCF model includes random effects.
+#' @param rfx_basis (Optional) A matrix of bases for the random effects model. Required if the BCF model includes random effects.
+#' @param num_draws_per_sample (Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has 1000 or more draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).
+#'
+#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
+#'
+#' @export
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
+#' pi_X <- pnorm(X[,1] / 2)
+#' Z <- rbinom(n, 1, pi_X)
+#' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n)
+#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X)
+#' ppd_samples <- sample_bcf_posterior_predictive(
+#'   model_object = bcf_model, covariates = X,
+#'   treatment = Z, propensity = pi_X
+#' )
+sample_bcf_posterior_predictive <- function(
+  model_object,
+  covariates = NULL,
+  treatment = NULL,
+  propensity = NULL,
+  rfx_group_ids = NULL,
+  rfx_basis = NULL,
+  num_draws_per_sample = NULL
+) {
+  # Check the provided model object
+  check_model_is_valid(model_object)
+
+  # Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
+  is_probit <- model_object$model_params$probit_outcome_model
+
+  # Check that all the necessary inputs were provided for posterior predictive sampling
+  needs_covariates <- TRUE
+  if (needs_covariates) {
+    if (is.null(covariates)) {
+      stop(
+        "'covariates' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(covariates) && !is.data.frame(covariates)) {
+      stop("'covariates' must be a matrix or data frame")
+    }
+  }
+  needs_treatment <- needs_covariates
+  if (needs_treatment) {
+    if (is.null(treatment)) {
+      stop(
+        "'treatment' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(treatment) && !is.numeric(treatment)) {
+      stop("'treatment' must be a numeric vector or matrix")
+    }
+    if (is.matrix(treatment)) {
+      if (nrow(treatment) != nrow(covariates)) {
+        stop("'treatment' must have the same number of rows as 'covariates'")
+      }
+    } else {
+      if (length(treatment) != nrow(covariates)) {
+        stop(
+          "'treatment' must have the same number of elements as 'covariates'"
+        )
+      }
+    }
+  }
+  uses_propensity <- model_object$model_params$propensity_covariate != "none"
+  internal_propensity_model <- model_object$model_params$internal_propensity_model
+  needs_propensity <- (needs_covariates &&
+    uses_propensity &&
+    (!internal_propensity_model))
+  if (needs_propensity) {
+    if (is.null(propensity)) {
+      stop(
+        "'propensity' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(propensity) && !is.numeric(propensity)) {
+      stop("'propensity' must be a numeric vector or matrix")
+    }
+    if (is.matrix(propensity)) {
+      if (nrow(propensity) != nrow(covariates)) {
+        stop("'propensity' must have the same number of rows as 'covariates'")
+      }
+    } else {
+      if (length(propensity) != nrow(covariates)) {
+        stop(
+          "'propensity' must have the same number of elements as 'covariates'"
+        )
+      }
+    }
+  }
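+  # Random effects inputs are only validated (and required) below when the
+  # model was trained with an additive random effects term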
+  needs_rfx_data <- model_object$model_params$has_rfx
+  if (needs_rfx_data) {
+    if (is.null(rfx_group_ids)) {
+      stop(
+        "'rfx_group_ids' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (length(rfx_group_ids) != nrow(covariates)) {
+      stop(
+        "'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
+      )
+    }
+    if (is.null(rfx_basis)) {
+      stop(
+        "'rfx_basis' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(rfx_basis)) {
+      stop("'rfx_basis' must be a matrix")
+    }
+    if (nrow(rfx_basis) != nrow(covariates)) {
+      stop("'rfx_basis' must have the same number of rows as 'covariates'")
+    }
+  }
+
+  # Compute posterior samples
+  bcf_preds <- predict(
+    model_object,
+    X = covariates,
+    Z = treatment,
+    propensity = propensity,
+    rfx_group_ids = rfx_group_ids,
+    rfx_basis = rfx_basis,
+    type = "posterior",
+    terms = c("all"),
+    scale = "linear"
+  )
+
+  # Compute outcome mean and variance for every posterior draw
+  has_rfx <- model_object$model_params$has_rfx
+  has_variance_forest <- model_object$model_params$include_variance_forest
+  samples_global_variance <- model_object$model_params$sample_sigma2_global
+  num_posterior_draws <- model_object$model_params$num_samples
+  num_observations <- nrow(covariates)
+  ppd_mean <- bcf_preds$y_hat
+  if (has_variance_forest) {
+    ppd_variance <- bcf_preds$variance_forest_predictions
+  } else {
+    if (samples_global_variance) {
+      ppd_variance <- matrix(
+        rep(
+          model_object$sigma2_global_samples,
+          each = num_observations
+        ),
+        nrow = num_observations
+      )
+    } else {
+      ppd_variance <- model_object$model_params$initial_sigma2
+    }
+  }
+
+  # Sample from the posterior predictive distribution
+  if (is.null(num_draws_per_sample)) {
+    ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier(
+      num_posterior_draws,
+      num_observations
+    )
+  } else {
+    ppd_draw_multiplier <- num_draws_per_sample
+  }
+  num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations
+  ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance))
+
+  # Reshape data
+  if (ppd_draw_multiplier > 1) {
+    ppd_array <- array(
+      ppd_vector,
+      dim = c(num_observations, num_posterior_draws, ppd_draw_multiplier)
+    )
+  } else {
+    ppd_array <- array(
+      ppd_vector,
+      dim = c(num_observations, num_posterior_draws)
+    )
+  }
+
+  # Binarize outcomes for probit models
+  if (is_probit) {
+    ppd_array <- (ppd_array > 0.0) * 1
+  }
+
+  return(ppd_array)
+}
+
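A minimal sketch of how these draws are typically consumed, reusing `ppd_samples` from the example above. Because `apply()` flattens each observation's slice before passing it to `quantile()`, the same line works whether the returned array is 2-D (one likelihood draw per posterior sample) or 3-D (upsampled via the heuristic):

```r
# Row-wise 95% posterior predictive intervals: each observation's draws
# (a vector, or a draws-by-replicates slice) are pooled by quantile()
ppd_interval <- t(apply(ppd_samples, 1, quantile, probs = c(0.025, 0.975)))
head(ppd_interval)
```
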
if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).
+#'
+#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
+#'
+#' @export
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
+#' y <- 2 * X[,1] + rnorm(n)
+#' bart_model <- bart(y_train = y, X_train = X)
+#' ppd_samples <- sample_bart_posterior_predictive(
+#'     model_object = bart_model, covariates = X
+#' )
+sample_bart_posterior_predictive <- function(
+  model_object,
+  covariates = NULL,
+  basis = NULL,
+  rfx_group_ids = NULL,
+  rfx_basis = NULL,
+  num_draws_per_sample = NULL
+) {
+  # Check the provided model object
+  check_model_is_valid(model_object)
+
+  # Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
+  is_probit <- model_object$model_params$probit_outcome_model
+
+  # Check that all the necessary inputs were provided for posterior predictive
+  # sampling (covariates are required whenever the model contains a mean or
+  # variance forest)
+  needs_covariates <- (model_object$model_params$include_mean_forest ||
+    model_object$model_params$include_variance_forest)
+  if (needs_covariates) {
+    if (is.null(covariates)) {
+      stop(
+        "'covariates' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(covariates) && !is.data.frame(covariates)) {
+      stop("'covariates' must be a matrix or data frame")
+    }
+  }
+  needs_basis <- needs_covariates && model_object$model_params$has_basis
+  if (needs_basis) {
+    if (is.null(basis)) {
+      stop(
+        "'basis' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(basis)) {
+      stop("'basis' must be a matrix")
+    }
+    if (nrow(basis) != nrow(covariates)) {
+      stop("'basis' must have the same number of rows as 'covariates'")
+    }
+  }
+  needs_rfx_data <- model_object$model_params$has_rfx
+  if (needs_rfx_data) {
+    if (is.null(rfx_group_ids)) {
+      stop(
+        "'rfx_group_ids' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (length(rfx_group_ids) != nrow(covariates)) {
+      stop(
+        "'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
+      )
+    }
+    if (is.null(rfx_basis)) {
+      stop(
+        "'rfx_basis' must be provided in order to sample from the posterior predictive distribution"
+      )
+    }
+    if (!is.matrix(rfx_basis)) {
+      stop("'rfx_basis' must be a matrix")
+    }
+    if (nrow(rfx_basis) != nrow(covariates)) {
+      stop("'rfx_basis' must have the same number of rows as 'covariates'")
+    }
+  }
+
+  # Compute posterior samples
+  bart_preds <- predict(
+    model_object,
+    covariates = covariates,
+    leaf_basis = basis,
+    rfx_group_ids = rfx_group_ids,
+    rfx_basis = rfx_basis,
+    type = "posterior",
+    terms = c("all"),
+    scale = "linear"
+  )
+
+  # Compute outcome mean and variance for every posterior draw
+  has_mean_term <- (model_object$model_params$include_mean_forest ||
+    model_object$model_params$has_rfx)
+  has_variance_forest <- model_object$model_params$include_variance_forest
+  samples_global_variance <- model_object$model_params$sample_sigma2_global
+  num_posterior_draws <- model_object$model_params$num_samples
+  num_observations <- nrow(covariates)
+  if (has_mean_term) {
+    ppd_mean <- bart_preds$y_hat
+  } else {
+    ppd_mean <- 0
+  }
+  if (has_variance_forest) {
+    ppd_variance <- bart_preds$variance_forest_predictions
+  } else {
+    if (samples_global_variance) {
+      ppd_variance <- matrix(
+        rep(
+          model_object$sigma2_global_samples,
+          each = num_observations
+        ),
+        nrow = num_observations
+      )
+    } else {
+      ppd_variance <- model_object$model_params$sigma2_init
+    }
+  }
+
+  # Sample from the posterior predictive distribution
+  if (is.null(num_draws_per_sample)) {
+    ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier(
+      num_posterior_draws,
+      num_observations
+    )
+  } else {
+    ppd_draw_multiplier <- num_draws_per_sample
+  }
+  num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations
+  # rnorm() recycles ppd_mean and ppd_variance across the stacked draws
+  ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance))
+
+  # Reshape data
+  if (ppd_draw_multiplier > 1) {
+    ppd_array <- array(
+      ppd_vector,
+      dim = c(num_observations, num_posterior_draws, ppd_draw_multiplier)
+    )
+  } else {
+    ppd_array <- array(
+      ppd_vector,
+      dim = c(num_observations, num_posterior_draws)
+    )
+  }
+
+  # Binarize outcomes for probit models
+  if (is_probit) {
+    ppd_array <- (ppd_array > 0.0) * 1
+  }
+
+  return(ppd_array)
+}
+
+posterior_predictive_heuristic_multiplier <- function(
+  num_samples,
+  num_observations
+) {
+  # Upsample so that at least 1000 posterior predictive draws are available
+  if (num_samples >= 1000) {
+    return(1)
+  } else {
+    return(ceiling(1000 / num_samples))
+  }
+}
+
+#' Compute posterior credible intervals for BCF model terms
+#'
+#' This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions.
+#' @param model_object A fitted BCF model object of class `bcfmodel`.
+#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
+#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
+#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
+#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
+#' @param treatment (Optional) A vector or matrix of treatment assignments. Required whenever `covariates` are required, since BCF predictions are computed from covariates and treatment jointly.
+#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.
+#' @param rfx_group_ids (Optional) A vector of group IDs for random effects. Required if the requested term includes random effects.
+#' @param rfx_basis (Optional) A matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
+#'
+#' @returns A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned.
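
As a quick aside on the `posterior_predictive_heuristic_multiplier()` defined above: with fewer than 1,000 retained posterior samples, each posterior draw is paired with multiple likelihood draws so that at least 1,000 posterior predictive draws are available per observation. Below is a minimal sketch of how the resulting array might be summarized; the array is simulated rather than produced by `sample_bcf_posterior_predictive()`, and the quantile step is an assumed post-processing choice, not part of the package API.

```r
set.seed(2024)
num_obs <- 5
num_post <- 100                          # retained posterior samples
multiplier <- ceiling(1000 / num_post)   # heuristic: 10 likelihood draws per sample
ppd <- array(
  rnorm(num_obs * num_post * multiplier),
  dim = c(num_obs, num_post, multiplier) # (observations, posterior draws, likelihood draws)
)
# Pool posterior and likelihood draws (dims 2 and 3) for a 95% predictive interval
lower <- apply(ppd, 1, quantile, probs = 0.025, names = FALSE)
upper <- apply(ppd, 1, quantile, probs = 0.975, names = FALSE)
```
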
+#'
+#' @export
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
+#' pi_X <- pnorm(0.5 * X[,1])
+#' Z <- rbinom(n, 1, pi_X)
+#' mu_X <- X[,1]
+#' tau_X <- 0.25 * X[,2]
+#' y <- mu_X + tau_X * Z + rnorm(n)
+#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y,
+#'                  propensity_train = pi_X)
+#' intervals <- compute_bcf_posterior_interval(
+#'     model_object = bcf_model,
+#'     terms = c("prognostic_function", "cate"),
+#'     covariates = X,
+#'     treatment = Z,
+#'     propensity = pi_X,
+#'     level = 0.90
+#' )
+compute_bcf_posterior_interval <- function(
+  model_object,
+  terms,
+  level = 0.95,
+  scale = "linear",
+  covariates = NULL,
+  treatment = NULL,
+  propensity = NULL,
+  rfx_group_ids = NULL,
+  rfx_basis = NULL
+) {
+  # Check the provided model object and requested term(s)
+  check_model_is_valid(model_object)
+  for (term in terms) {
+    if (!check_model_has_term(model_object, term)) {
+      stop(paste0("Term '", term, "' was not sampled in the provided model"))
+    }
+  }
+
+  # Handle mean function scale
+  if (!is.character(scale)) {
+    stop("scale must be a string or character vector")
+  }
+  if (!(scale %in% c("linear", "probability"))) {
+    stop("scale must either be 'linear' or 'probability'")
+  }
+  is_probit <- model_object$model_params$probit_outcome_model
+  if ((scale == "probability") && (!is_probit)) {
+    stop(
+      "scale cannot be 'probability' for models not fit with a probit outcome model"
+    )
+  }
+
+  # Check that all the necessary inputs were provided for interval computation
+  needs_covariates_intermediate <- (("y_hat" %in% terms) ||
+    ("all" %in% terms))
+  needs_covariates <- (("prognostic_function" %in% terms) ||
+    ("cate" %in% terms) ||
+    ("variance_forest" %in% terms) ||
+    (needs_covariates_intermediate))
+  if (needs_covariates) {
+    if (is.null(covariates)) {
+      stop(
+        "'covariates' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(covariates) && !is.data.frame(covariates)) {
+      stop("'covariates' must be a matrix or data frame")
+    }
+  }
+  needs_treatment <- needs_covariates
+  if (needs_treatment) {
+    if (is.null(treatment)) {
+      stop(
+        "'treatment' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(treatment) && !is.numeric(treatment)) {
+      stop("'treatment' must be a numeric vector or matrix")
+    }
+    if (is.matrix(treatment)) {
+      if (nrow(treatment) != nrow(covariates)) {
+        stop("'treatment' must have the same number of rows as 'covariates'")
+      }
+    } else {
+      if (length(treatment) != nrow(covariates)) {
+        stop(
+          "'treatment' must have the same number of elements as 'covariates'"
+        )
+      }
+    }
+  }
+  uses_propensity <- model_object$model_params$propensity_covariate != "none"
+  internal_propensity_model <- model_object$model_params$internal_propensity_model
+  needs_propensity <- (needs_covariates &&
+    uses_propensity &&
+    (!internal_propensity_model))
+  if (needs_propensity) {
+    if (is.null(propensity)) {
+      stop(
+        "'propensity' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(propensity) && !is.numeric(propensity)) {
+      stop("'propensity' must be a numeric vector or matrix")
+    }
+    if (is.matrix(propensity)) {
+      if (nrow(propensity) != nrow(covariates)) {
+        stop("'propensity' must have the same number of rows as 'covariates'")
+      }
+    } else {
+      if (length(propensity) != nrow(covariates)) {
+        stop(
+          "'propensity' must have the same number of elements as 'covariates'"
+        )
+      }
+    }
+  }
+  needs_rfx_data_intermediate <- ((("y_hat" %in% terms) ||
+    ("all" %in% terms)) &&
+    model_object$model_params$has_rfx)
+  needs_rfx_data <- (("rfx" %in% terms) ||
+    (needs_rfx_data_intermediate))
+  if (needs_rfx_data) {
+    if (is.null(rfx_group_ids)) {
+      stop(
+        "'rfx_group_ids' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (length(rfx_group_ids) != nrow(covariates)) {
+      stop(
+        "'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
+      )
+    }
+    if (is.null(rfx_basis)) {
+      stop(
+        "'rfx_basis' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(rfx_basis)) {
+      stop("'rfx_basis' must be a matrix")
+    }
+    if (nrow(rfx_basis) != nrow(covariates)) {
+      stop("'rfx_basis' must have the same number of rows as 'covariates'")
+    }
+  }
+
+  # Compute posterior matrices for the requested model terms
+  predictions <- predict(
+    model_object,
+    X = covariates,
+    Z = treatment,
+    propensity = propensity,
+    rfx_group_ids = rfx_group_ids,
+    rfx_basis = rfx_basis,
+    type = "posterior",
+    terms = terms,
+    scale = scale
+  )
+  has_multiple_terms <- is.list(predictions)
+
+  # Compute the interval
+  if (has_multiple_terms) {
+    result <- list()
+    for (term_name in names(predictions)) {
+      result[[term_name]] <- summarize_interval(
+        predictions[[term_name]],
+        sample_dim = 2,
+        level = level
+      )
+    }
+    return(result)
+  } else {
+    return(summarize_interval(
+      predictions,
+      sample_dim = 2,
+      level = level
+    ))
+  }
+}
+
+#' Compute posterior credible intervals for BART model terms
+#'
+#' This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
+#' @param model_object A fitted BART model object of class `bartmodel`.
+#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
+#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
+#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
+#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
+#' @param basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.
+#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects.
+#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
+#'
+#' @returns A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned.
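
The interval endpoints returned by these functions are plain posterior quantiles, computed per observation over the sample dimension (see the `summarize_interval()` helper later in this file). A minimal standalone sketch of the same computation on a toy posterior matrix; the matrix is simulated purely for illustration.

```r
set.seed(1)
num_obs <- 50
num_post <- 200
# (observations x posterior samples), as returned by type = "posterior" predictions
posterior_draws <- matrix(rnorm(num_obs * num_post, mean = 1), nrow = num_obs)
level <- 0.90
quantile_lb <- (1 - level) / 2  # 0.05
quantile_ub <- 1 - quantile_lb  # 0.95
interval <- list(
  lower = apply(posterior_draws, 1, quantile, probs = quantile_lb, names = FALSE),
  upper = apply(posterior_draws, 1, quantile, probs = quantile_ub, names = FALSE)
)
```
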
+#'
+#' @export
+#' @examples
+#' n <- 100
+#' p <- 5
+#' X <- matrix(rnorm(n * p), nrow = n, ncol = p)
+#' y <- 2 * X[,1] + rnorm(n)
+#' bart_model <- bart(y_train = y, X_train = X)
+#' intervals <- compute_bart_posterior_interval(
+#'     model_object = bart_model,
+#'     terms = c("mean_forest", "y_hat"),
+#'     covariates = X,
+#'     level = 0.90
+#' )
+compute_bart_posterior_interval <- function(
+  model_object,
+  terms,
+  level = 0.95,
+  scale = "linear",
+  covariates = NULL,
+  basis = NULL,
+  rfx_group_ids = NULL,
+  rfx_basis = NULL
+) {
+  # Check the provided model object and requested term(s)
+  check_model_is_valid(model_object)
+  for (term in terms) {
+    if (!check_model_has_term(model_object, term)) {
+      stop(paste0("Term '", term, "' was not sampled in the provided model"))
+    }
+  }
+
+  # Handle mean function scale
+  if (!is.character(scale)) {
+    stop("scale must be a string or character vector")
+  }
+  if (!(scale %in% c("linear", "probability"))) {
+    stop("scale must either be 'linear' or 'probability'")
+  }
+  is_probit <- model_object$model_params$probit_outcome_model
+  if ((scale == "probability") && (!is_probit)) {
+    stop(
+      "scale cannot be 'probability' for models not fit with a probit outcome model"
+    )
+  }
+
+  # Check that all the necessary inputs were provided for interval computation
+  needs_covariates_intermediate <- ((("y_hat" %in% terms) ||
+    ("all" %in% terms)) &&
+    model_object$model_params$include_mean_forest)
+  needs_covariates <- (("mean_forest" %in% terms) ||
+    ("variance_forest" %in% terms) ||
+    (needs_covariates_intermediate))
+  if (needs_covariates) {
+    if (is.null(covariates)) {
+      stop(
+        "'covariates' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(covariates) && !is.data.frame(covariates)) {
+      stop("'covariates' must be a matrix or data frame")
+    }
+  }
+  needs_basis <- needs_covariates && model_object$model_params$has_basis
+  if (needs_basis) {
+    if (is.null(basis)) {
+      stop(
+        "'basis' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(basis)) {
+      stop("'basis' must be a matrix")
+    }
+    if (nrow(basis) != nrow(covariates)) {
+      stop("'basis' must have the same number of rows as 'covariates'")
+    }
+  }
+  needs_rfx_data_intermediate <- ((("y_hat" %in% terms) ||
+    ("all" %in% terms)) &&
+    model_object$model_params$has_rfx)
+  needs_rfx_data <- (("rfx" %in% terms) ||
+    (needs_rfx_data_intermediate))
+  if (needs_rfx_data) {
+    if (is.null(rfx_group_ids)) {
+      stop(
+        "'rfx_group_ids' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (length(rfx_group_ids) != nrow(covariates)) {
+      stop(
+        "'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
+      )
+    }
+    if (is.null(rfx_basis)) {
+      stop(
+        "'rfx_basis' must be provided in order to compute the requested intervals"
+      )
+    }
+    if (!is.matrix(rfx_basis)) {
+      stop("'rfx_basis' must be a matrix")
+    }
+    if (nrow(rfx_basis) != nrow(covariates)) {
+      stop("'rfx_basis' must have the same number of rows as 'covariates'")
+    }
+  }
+
+  # Compute posterior matrices for the requested model terms
+  predictions <- predict(
+    model_object,
+    covariates = covariates,
+    leaf_basis = basis,
+    rfx_group_ids = rfx_group_ids,
+    rfx_basis = rfx_basis,
+    type = "posterior",
+    terms = terms,
+    scale = scale
+  )
+  has_multiple_terms <- is.list(predictions)
+
+  # Compute the interval
+  if (has_multiple_terms) {
+    result <- list()
+    for
(term_name in names(predictions)) { + result[[term_name]] <- summarize_interval( + predictions[[term_name]], + sample_dim = 2, + level = level + ) + } + return(result) + } else { + return(summarize_interval( + predictions, + sample_dim = 2, + level = level + )) + } +} + +summarize_interval <- function(array, sample_dim = 2, level = 0.95) { + # Check that the array is numeric and at least 2 dimensional + stopifnot(is.numeric(array) && length(dim(array)) >= 2) + + # Compute lower and upper quantiles based on the requested interval + quantile_lb <- (1 - level) / 2 + quantile_ub <- 1 - quantile_lb + + # Determine the dimensions over which interval is computed + apply_dim <- setdiff(1:length(dim(array)), sample_dim) + + # Calculate the interval + result_lb <- apply(array, apply_dim, function(x) { + quantile(x, probs = quantile_lb, names = FALSE) + }) + result_ub <- apply(array, apply_dim, function(x) { + quantile(x, probs = quantile_ub, names = FALSE) + }) + + return(list(lower = result_lb, upper = result_ub)) +} + +check_model_is_valid <- function(model_object) { + if ( + (!inherits(model_object, "bartmodel")) && + (!inherits(model_object, "bcfmodel")) + ) { + stop("'model_object' must be a bartmodel or bcfmodel") + } +} + +check_model_has_term <- function(model_object, term) { + # Parse inputs + if (!is.character(term) || length(term) != 1) { + stop("'term' must be a single character string") + } + if ( + (!inherits(model_object, "bartmodel")) && + (!inherits(model_object, "bcfmodel")) + ) { + stop("'model_object' must be a bartmodel or bcfmodel") + } + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") + + # Check if the term was fitted as part of the provided model + if (model_type == "bart") { + validate_bart_term(term) + return(bart_model_has_term(model_object, term)) + } else { + validate_bcf_term(term) + return(bcf_model_has_term(model_object, term)) + } +} + +bart_model_has_term <- function(model_object, term) { + if (term == "mean_forest") { + return(model_object$model_params$include_mean_forest) + } else if (term == "variance_forest") { + return(model_object$model_params$include_variance_forest) + } else if (term == "rfx") { + return(model_object$model_params$has_rfx) + } else if (term == "y_hat") { + return( + model_object$model_params$include_mean_forest || + model_object$model_params$has_rfx + ) + } else if (term == "all") { + return(TRUE) + } else { + return(FALSE) + } +} + +bcf_model_has_term <- function(model_object, term) { + if (term == "prognostic_function") { + return(TRUE) + } else if (term == "cate") { + return(TRUE) + } else if (term == "variance_forest") { + return(model_object$model_params$include_variance_forest) + } else if (term == "rfx") { + return(model_object$model_params$has_rfx) + } else if (term == "y_hat") { + return(TRUE) + } else if (term == "all") { + return(TRUE) + } else { + return(FALSE) + } +} + +validate_bart_term <- function(term) { + model_terms <- c("mean_forest", "variance_forest", "rfx", "y_hat", "all") + if (!(term %in% model_terms)) { + stop( + "'term' must be one of 'mean_forest', 'variance_forest', 'rfx', 'y_hat', or 'all' for bartmodel objects" + ) + } +} + +validate_bcf_term <- function(term) { + model_terms <- c( + "prognostic_function", + "cate", + "variance_forest", + "rfx", + "y_hat", + "all" + ) + if (!(term %in% model_terms)) { + stop( + "'term' must be one of 'prognostic_function', 'cate', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects" + ) + } +} diff --git a/R/random_effects.R 
b/R/random_effects.R index b91c2678..868897ed 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -9,234 +9,234 @@ #' needed for prediction / serialization RandomEffectSamples <- R6::R6Class( - classname = "RandomEffectSamples", - cloneable = FALSE, - public = list( - #' @field rfx_container_ptr External pointer to a C++ StochTree::RandomEffectsContainer class - rfx_container_ptr = NULL, - - #' @field label_mapper_ptr External pointer to a C++ StochTree::LabelMapper class - label_mapper_ptr = NULL, - - #' @field training_group_ids Unique vector of group IDs that were in the training dataset - training_group_ids = NULL, - - #' @description - #' Create a new RandomEffectSamples object. - #' @return A new `RandomEffectSamples` object. - initialize = function() {}, - - #' @description - #' Construct RandomEffectSamples object from other "in-session" R objects - #' @param num_components Number of "components" or bases defining the random effects regression - #' @param num_groups Number of random effects groups - #' @param random_effects_tracker Object of type `RandomEffectsTracker` - #' @return None - load_in_session = function( - num_components, - num_groups, - random_effects_tracker - ) { - # Initialize - self$rfx_container_ptr <- rfx_container_cpp( - num_components, - num_groups - ) - self$label_mapper_ptr <- rfx_label_mapper_cpp( - random_effects_tracker$rfx_tracker_ptr - ) - self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp( - random_effects_tracker$rfx_tracker_ptr - ) - }, - - #' @description - #' Construct RandomEffectSamples object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return A new `RandomEffectSamples` object. - load_from_json = function( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - self$rfx_container_ptr <- rfx_container_from_json_cpp( - json_object$json_ptr, - json_rfx_container_label - ) - self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp( - json_object$json_ptr, - json_rfx_mapper_label - ) - self$training_group_ids <- rfx_group_ids_from_json_cpp( - json_object$json_ptr, - json_rfx_groupids_label - ) - }, - - #' @description - #' Append random effect draws to `RandomEffectSamples` object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. 
"random_effect_groupids_0") in the overall json hierarchy - #' @return None - append_from_json = function( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - rfx_container_append_from_json_cpp( - self$rfx_container_ptr, - json_object$json_ptr, - json_rfx_container_label - ) - }, - - #' @description - #' Construct RandomEffectSamples object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return A new `RandomEffectSamples` object. - load_from_json_string = function( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - self$rfx_container_ptr <- rfx_container_from_json_string_cpp( - json_string, - json_rfx_container_label - ) - self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp( - json_string, - json_rfx_mapper_label - ) - self$training_group_ids <- rfx_group_ids_from_json_string_cpp( - json_string, - json_rfx_groupids_label - ) - }, - - #' @description - #' Append random effect draws to `RandomEffectSamples` object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return None - append_from_json_string = function( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - # Append RFX objects - rfx_container_append_from_json_string_cpp( - self$rfx_container_ptr, - json_string, - json_rfx_container_label - ) - }, - - #' @description - #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`. - #' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`. - #' @param rfx_group_ids Indices of random effects groups in a prediction set - #' @param rfx_basis (Optional) Basis used for random effects prediction - #' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model. 
- predict = function(rfx_group_ids, rfx_basis = NULL) { - num_obs = length(rfx_group_ids) - if (is.null(rfx_basis)) { - rfx_basis <- matrix(rep(1, num_obs), ncol = 1) - } - num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) - num_components = rfx_container_num_components_cpp( - self$rfx_container_ptr - ) - num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) - rfx_group_ids_int <- as.integer(rfx_group_ids) - stopifnot(sum(abs(rfx_group_ids_int - rfx_group_ids)) < 1e-6) - stopifnot(sum(!(rfx_group_ids %in% self$training_group_ids)) == 0) - stopifnot(ncol(rfx_basis) == num_components) - rfx_dataset <- createRandomEffectsDataset( - rfx_group_ids_int, - rfx_basis - ) - output <- rfx_container_predict_cpp( - self$rfx_container_ptr, - rfx_dataset$data_ptr, - self$label_mapper_ptr - ) - dim(output) <- c(num_obs, num_samples) - return(output) - }, - - #' @description - #' Extract the random effects parameters sampled. With the "redundant parameterization" - #' of Gelman et al (2008), this includes four parameters: alpha (the "working parameter" - #' shared across every group), xi (the "group parameter" sampled separately for each group), - #' beta (the product of alpha and xi, which corresponds to the overall group-level random effects), - #' and sigma (group-independent prior variance for each component of xi). - #' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. - #' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and is simply a matrix if `num_components = 1`. - #' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. - extract_parameter_samples = function() { - num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) - num_components = rfx_container_num_components_cpp( - self$rfx_container_ptr - ) - num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) - beta_samples <- rfx_container_get_beta_cpp(self$rfx_container_ptr) - xi_samples <- rfx_container_get_xi_cpp(self$rfx_container_ptr) - alpha_samples <- rfx_container_get_alpha_cpp(self$rfx_container_ptr) - sigma_samples <- rfx_container_get_sigma_cpp(self$rfx_container_ptr) - if (num_components == 1) { - dim(beta_samples) <- c(num_groups, num_samples) - dim(xi_samples) <- c(num_groups, num_samples) - } else if (num_components > 1) { - dim(beta_samples) <- c(num_components, num_groups, num_samples) - dim(xi_samples) <- c(num_components, num_groups, num_samples) - dim(alpha_samples) <- c(num_components, num_samples) - dim(sigma_samples) <- c(num_components, num_samples) - } else { - stop( - "Invalid random effects sample container, num_components is less than 1" - ) - } - - output = list( - "beta_samples" = beta_samples, - "xi_samples" = xi_samples, - "alpha_samples" = alpha_samples, - "sigma_samples" = sigma_samples - ) - return(output) - }, - - #' @description - #' Modify the `RandomEffectsSamples` object by removing the parameter samples index by `sample_num`. - #' @param sample_num Index of the RFX sample to be removed - delete_sample = function(sample_num) { - rfx_container_delete_sample_cpp(self$rfx_container_ptr, sample_num) - }, - - #' @description - #' Convert the mapping of group IDs to random effect components indices from C++ to R native format - #' @return List mapping group ID to random effect components. 
- extract_label_mapping = function() { - keys_and_vals <- rfx_label_mapper_to_list_cpp(self$label_mapper_ptr) - result <- as.list(keys_and_vals[[2]] + 1) - setNames(result, keys_and_vals[[1]]) - return(result) - } - ) + classname = "RandomEffectSamples", + cloneable = FALSE, + public = list( + #' @field rfx_container_ptr External pointer to a C++ StochTree::RandomEffectsContainer class + rfx_container_ptr = NULL, + + #' @field label_mapper_ptr External pointer to a C++ StochTree::LabelMapper class + label_mapper_ptr = NULL, + + #' @field training_group_ids Unique vector of group IDs that were in the training dataset + training_group_ids = NULL, + + #' @description + #' Create a new RandomEffectSamples object. + #' @return A new `RandomEffectSamples` object. + initialize = function() {}, + + #' @description + #' Construct RandomEffectSamples object from other "in-session" R objects + #' @param num_components Number of "components" or bases defining the random effects regression + #' @param num_groups Number of random effects groups + #' @param random_effects_tracker Object of type `RandomEffectsTracker` + #' @return None + load_in_session = function( + num_components, + num_groups, + random_effects_tracker + ) { + # Initialize + self$rfx_container_ptr <- rfx_container_cpp( + num_components, + num_groups + ) + self$label_mapper_ptr <- rfx_label_mapper_cpp( + random_effects_tracker$rfx_tracker_ptr + ) + self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp( + random_effects_tracker$rfx_tracker_ptr + ) + }, + + #' @description + #' Construct RandomEffectSamples object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return A new `RandomEffectSamples` object. + load_from_json = function( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + self$rfx_container_ptr <- rfx_container_from_json_cpp( + json_object$json_ptr, + json_rfx_container_label + ) + self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp( + json_object$json_ptr, + json_rfx_mapper_label + ) + self$training_group_ids <- rfx_group_ids_from_json_cpp( + json_object$json_ptr, + json_rfx_groupids_label + ) + }, + + #' @description + #' Append random effect draws to `RandomEffectSamples` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. 
"random_effect_groupids_0") in the overall json hierarchy + #' @return None + append_from_json = function( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + rfx_container_append_from_json_cpp( + self$rfx_container_ptr, + json_object$json_ptr, + json_rfx_container_label + ) + }, + + #' @description + #' Construct RandomEffectSamples object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return A new `RandomEffectSamples` object. + load_from_json_string = function( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + self$rfx_container_ptr <- rfx_container_from_json_string_cpp( + json_string, + json_rfx_container_label + ) + self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp( + json_string, + json_rfx_mapper_label + ) + self$training_group_ids <- rfx_group_ids_from_json_string_cpp( + json_string, + json_rfx_groupids_label + ) + }, + + #' @description + #' Append random effect draws to `RandomEffectSamples` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return None + append_from_json_string = function( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + # Append RFX objects + rfx_container_append_from_json_string_cpp( + self$rfx_container_ptr, + json_string, + json_rfx_container_label + ) + }, + + #' @description + #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`. + #' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`. + #' @param rfx_group_ids Indices of random effects groups in a prediction set + #' @param rfx_basis (Optional) Basis used for random effects prediction + #' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model. 
+    predict = function(rfx_group_ids, rfx_basis = NULL) {
+      num_obs = length(rfx_group_ids)
+      if (is.null(rfx_basis)) {
+        rfx_basis <- matrix(rep(1, num_obs), ncol = 1)
+      }
+      num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr)
+      num_components = rfx_container_num_components_cpp(
+        self$rfx_container_ptr
+      )
+      num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr)
+      rfx_group_ids_int <- as.integer(rfx_group_ids)
+      stopifnot(sum(abs(rfx_group_ids_int - rfx_group_ids)) < 1e-6)
+      stopifnot(sum(!(rfx_group_ids %in% self$training_group_ids)) == 0)
+      stopifnot(ncol(rfx_basis) == num_components)
+      rfx_dataset <- createRandomEffectsDataset(
+        rfx_group_ids_int,
+        rfx_basis
+      )
+      output <- rfx_container_predict_cpp(
+        self$rfx_container_ptr,
+        rfx_dataset$data_ptr,
+        self$label_mapper_ptr
+      )
+      dim(output) <- c(num_obs, num_samples)
+      return(output)
+    },
+
+    #' @description
+    #' Extract the random effects parameters sampled. With the "redundant parameterization"
+    #' of Gelman et al (2008), this includes four parameters: alpha (the "working parameter"
+    #' shared across every group), xi (the "group parameter" sampled separately for each group),
+    #' beta (the product of alpha and xi, which corresponds to the overall group-level random effects),
+    #' and sigma (group-independent prior variance for each component of xi).
+    #' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
+    #' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`.
+    #' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`.
+    extract_parameter_samples = function() {
+      num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr)
+      num_components = rfx_container_num_components_cpp(
+        self$rfx_container_ptr
+      )
+      num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr)
+      beta_samples <- rfx_container_get_beta_cpp(self$rfx_container_ptr)
+      xi_samples <- rfx_container_get_xi_cpp(self$rfx_container_ptr)
+      alpha_samples <- rfx_container_get_alpha_cpp(self$rfx_container_ptr)
+      sigma_samples <- rfx_container_get_sigma_cpp(self$rfx_container_ptr)
+      if (num_components == 1) {
+        dim(beta_samples) <- c(num_groups, num_samples)
+        dim(xi_samples) <- c(num_groups, num_samples)
+      } else if (num_components > 1) {
+        dim(beta_samples) <- c(num_components, num_groups, num_samples)
+        dim(xi_samples) <- c(num_components, num_groups, num_samples)
+        dim(alpha_samples) <- c(num_components, num_samples)
+        dim(sigma_samples) <- c(num_components, num_samples)
+      } else {
+        stop(
+          "Invalid random effects sample container, num_components is less than 1"
+        )
+      }
+
+      output = list(
+        "beta_samples" = beta_samples,
+        "xi_samples" = xi_samples,
+        "alpha_samples" = alpha_samples,
+        "sigma_samples" = sigma_samples
+      )
+      return(output)
+    },
+
+    #' @description
+    #' Modify the `RandomEffectSamples` object by removing the parameter samples indexed by `sample_num`.
+    #' @param sample_num Index of the RFX sample to be removed
+    delete_sample = function(sample_num) {
+      rfx_container_delete_sample_cpp(self$rfx_container_ptr, sample_num)
+    },
+
+    #' @description
+    #' Convert the mapping of group IDs to random effect components indices from C++ to R native format
+    #' @return List mapping group ID to random effect components.
+    extract_label_mapping = function() {
+      keys_and_vals <- rfx_label_mapper_to_list_cpp(self$label_mapper_ptr)
+      # Convert 0-based C++ indices to 1-based R indices, then attach group IDs as names
+      result <- as.list(keys_and_vals[[2]] + 1)
+      result <- setNames(result, keys_and_vals[[1]])
+      return(result)
+    }
+  )
)

#' Class that defines a "tracker" for random effects models, most notably
@@ -249,21 +249,21 @@ RandomEffectSamples <- R6::R6Class(
#' group, and predictions for each observation.

RandomEffectsTracker <- R6::R6Class(
-    classname = "RandomEffectsTracker",
-    cloneable = FALSE,
-    public = list(
-        #' @field rfx_tracker_ptr External pointer to a C++ StochTree::RandomEffectsTracker class
-        rfx_tracker_ptr = NULL,
-
-        #' @description
-        #' Create a new RandomEffectsTracker object.
-        #' @param rfx_group_indices Integer indices indicating groups used to define random effects
-        #' @return A new `RandomEffectsTracker` object.
-        initialize = function(rfx_group_indices) {
-            # Initialize
-            self$rfx_tracker_ptr <- rfx_tracker_cpp(rfx_group_indices)
-        }
-    )
+  classname = "RandomEffectsTracker",
+  cloneable = FALSE,
+  public = list(
+    #' @field rfx_tracker_ptr External pointer to a C++ StochTree::RandomEffectsTracker class
+    rfx_tracker_ptr = NULL,
+
+    #' @description
+    #' Create a new RandomEffectsTracker object.
+    #' @param rfx_group_indices Integer indices indicating groups used to define random effects
+    #' @return A new `RandomEffectsTracker` object.
+    initialize = function(rfx_group_indices) {
+      # Initialize
+      self$rfx_tracker_ptr <- rfx_tracker_cpp(rfx_group_indices)
+    }
+  )
)

#' The core "model" class for sampling random effects.
@@ -273,158 +273,158 @@ RandomEffectsTracker <- R6::R6Class(
#' sampling from the conditional posterior of each parameter.

RandomEffectsModel <- R6::R6Class(
-    classname = "RandomEffectsModel",
-    cloneable = FALSE,
-    public = list(
-        #' @field rfx_model_ptr External pointer to a C++ StochTree::RandomEffectsModel class
-        rfx_model_ptr = NULL,
-
-        #' @field num_groups Number of groups in the random effects model
-        num_groups = NULL,
-
-        #' @field num_components Number of components (i.e. dimension of basis) in the random effects model
-        num_components = NULL,
-
-        #' @description
-        #' Create a new RandomEffectsModel object.
-        #' @param num_components Number of "components" or bases defining the random effects regression
-        #' @param num_groups Number of random effects groups
-        #' @return A new `RandomEffectsModel` object.
-        initialize = function(num_components, num_groups) {
-            # Initialize
-            self$rfx_model_ptr <- rfx_model_cpp(num_components, num_groups)
-            self$num_components <- num_components
-            self$num_groups <- num_groups
-        },
-
-        #' @description
-        #' Sample from random effects model.
-        #' @param rfx_dataset Object of type `RandomEffectsDataset`
-        #' @param residual Object of type `Outcome`
-        #' @param rfx_tracker Object of type `RandomEffectsTracker`
-        #' @param rfx_samples Object of type `RandomEffectSamples`
-        #' @param keep_sample Whether sample should be retained in `rfx_samples`. If `FALSE`, the state of `rfx_tracker` will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning.
- #' @param global_variance Scalar global variance parameter - #' @param rng Object of type `CppRNG` - #' @return None - sample_random_effect = function( - rfx_dataset, - residual, - rfx_tracker, - rfx_samples, - keep_sample, - global_variance, - rng - ) { - rfx_model_sample_random_effects_cpp( - self$rfx_model_ptr, - rfx_dataset$data_ptr, - residual$data_ptr, - rfx_tracker$rfx_tracker_ptr, - rfx_samples$rfx_container_ptr, - keep_sample, - global_variance, - rng$rng_ptr - ) - }, - - #' @description - #' Predict from (a single sample of a) random effects model. - #' @param rfx_dataset Object of type `RandomEffectsDataset` - #' @param rfx_tracker Object of type `RandomEffectsTracker` - #' @return Vector of predictions with size matching number of observations in rfx_dataset - predict = function(rfx_dataset, rfx_tracker) { - pred <- rfx_model_predict_cpp( - self$rfx_model_ptr, - rfx_dataset$data_ptr, - rfx_tracker$rfx_tracker_ptr - ) - return(pred) - }, - - #' @description - #' Set value for the "working parameter." This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_working_parameter = function(value) { - stopifnot(is.double(value)) - stopifnot(!is.matrix(value)) - stopifnot(length(value) == self$num_components) - rfx_model_set_working_parameter_cpp(self$rfx_model_ptr, value) - }, - - #' @description - #' Set value for the "group parameters." This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_group_parameters = function(value) { - stopifnot(is.double(value)) - stopifnot(is.matrix(value)) - stopifnot(nrow(value) == self$num_components) - stopifnot(ncol(value) == self$num_groups) - rfx_model_set_group_parameters_cpp(self$rfx_model_ptr, value) - }, - - #' @description - #' Set value for the working parameter covariance. This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_working_parameter_cov = function(value) { - stopifnot(is.double(value)) - stopifnot(is.matrix(value)) - stopifnot(nrow(value) == self$num_components) - stopifnot(ncol(value) == self$num_components) - rfx_model_set_working_parameter_covariance_cpp( - self$rfx_model_ptr, - value - ) - }, - - #' @description - #' Set value for the group parameter covariance. This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_group_parameter_cov = function(value) { - stopifnot(is.double(value)) - stopifnot(is.matrix(value)) - stopifnot(nrow(value) == self$num_components) - stopifnot(ncol(value) == self$num_components) - rfx_model_set_group_parameter_covariance_cpp( - self$rfx_model_ptr, - value - ) - }, - - #' @description - #' Set shape parameter for the group parameter variance prior. - #' @param value Parameter input - #' @return None - set_variance_prior_shape = function(value) { - stopifnot(is.double(value)) - stopifnot(!is.matrix(value)) - stopifnot(length(value) == 1) - rfx_model_set_variance_prior_shape_cpp(self$rfx_model_ptr, value) - }, - - #' @description - #' Set shape parameter for the group parameter variance prior. 
- #' @param value Parameter input - #' @return None - set_variance_prior_scale = function(value) { - stopifnot(is.double(value)) - stopifnot(!is.matrix(value)) - stopifnot(length(value) == 1) - rfx_model_set_variance_prior_scale_cpp(self$rfx_model_ptr, value) - } - ) + classname = "RandomEffectsModel", + cloneable = FALSE, + public = list( + #' @field rfx_model_ptr External pointer to a C++ StochTree::RandomEffectsModel class + rfx_model_ptr = NULL, + + #' @field num_groups Number of groups in the random effects model + num_groups = NULL, + + #' @field num_components Number of components (i.e. dimension of basis) in the random effects model + num_components = NULL, + + #' @description + #' Create a new RandomEffectsModel object. + #' @param num_components Number of "components" or bases defining the random effects regression + #' @param num_groups Number of random effects groups + #' @return A new `RandomEffectsModel` object. + initialize = function(num_components, num_groups) { + # Initialize + self$rfx_model_ptr <- rfx_model_cpp(num_components, num_groups) + self$num_components <- num_components + self$num_groups <- num_groups + }, + + #' @description + #' Sample from random effects model. + #' @param rfx_dataset Object of type `RandomEffectsDataset` + #' @param residual Object of type `Outcome` + #' @param rfx_tracker Object of type `RandomEffectsTracker` + #' @param rfx_samples Object of type `RandomEffectSamples` + #' @param keep_sample Whether sample should be retained in `rfx_samples`. If `FALSE`, the state of `rfx_tracker` will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning. + #' @param global_variance Scalar global variance parameter + #' @param rng Object of type `CppRNG` + #' @return None + sample_random_effect = function( + rfx_dataset, + residual, + rfx_tracker, + rfx_samples, + keep_sample, + global_variance, + rng + ) { + rfx_model_sample_random_effects_cpp( + self$rfx_model_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_tracker$rfx_tracker_ptr, + rfx_samples$rfx_container_ptr, + keep_sample, + global_variance, + rng$rng_ptr + ) + }, + + #' @description + #' Predict from (a single sample of a) random effects model. + #' @param rfx_dataset Object of type `RandomEffectsDataset` + #' @param rfx_tracker Object of type `RandomEffectsTracker` + #' @return Vector of predictions with size matching number of observations in rfx_dataset + predict = function(rfx_dataset, rfx_tracker) { + pred <- rfx_model_predict_cpp( + self$rfx_model_ptr, + rfx_dataset$data_ptr, + rfx_tracker$rfx_tracker_ptr + ) + return(pred) + }, + + #' @description + #' Set value for the "working parameter." This is typically + #' used for initialization, but could also be used to interrupt + #' or override the sampler. + #' @param value Parameter input + #' @return None + set_working_parameter = function(value) { + stopifnot(is.double(value)) + stopifnot(!is.matrix(value)) + stopifnot(length(value) == self$num_components) + rfx_model_set_working_parameter_cpp(self$rfx_model_ptr, value) + }, + + #' @description + #' Set value for the "group parameters." This is typically + #' used for initialization, but could also be used to interrupt + #' or override the sampler. 
+    #' @param value Parameter input
+    #' @return None
+    set_group_parameters = function(value) {
+      stopifnot(is.double(value))
+      stopifnot(is.matrix(value))
+      stopifnot(nrow(value) == self$num_components)
+      stopifnot(ncol(value) == self$num_groups)
+      rfx_model_set_group_parameters_cpp(self$rfx_model_ptr, value)
+    },
+
+    #' @description
+    #' Set value for the working parameter covariance. This is typically
+    #' used for initialization, but could also be used to interrupt
+    #' or override the sampler.
+    #' @param value Parameter input
+    #' @return None
+    set_working_parameter_cov = function(value) {
+      stopifnot(is.double(value))
+      stopifnot(is.matrix(value))
+      stopifnot(nrow(value) == self$num_components)
+      stopifnot(ncol(value) == self$num_components)
+      rfx_model_set_working_parameter_covariance_cpp(
+        self$rfx_model_ptr,
+        value
+      )
+    },
+
+    #' @description
+    #' Set value for the group parameter covariance. This is typically
+    #' used for initialization, but could also be used to interrupt
+    #' or override the sampler.
+    #' @param value Parameter input
+    #' @return None
+    set_group_parameter_cov = function(value) {
+      stopifnot(is.double(value))
+      stopifnot(is.matrix(value))
+      stopifnot(nrow(value) == self$num_components)
+      stopifnot(ncol(value) == self$num_components)
+      rfx_model_set_group_parameter_covariance_cpp(
+        self$rfx_model_ptr,
+        value
+      )
+    },
+
+    #' @description
+    #' Set shape parameter for the group parameter variance prior.
+    #' @param value Parameter input
+    #' @return None
+    set_variance_prior_shape = function(value) {
+      stopifnot(is.double(value))
+      stopifnot(!is.matrix(value))
+      stopifnot(length(value) == 1)
+      rfx_model_set_variance_prior_shape_cpp(self$rfx_model_ptr, value)
+    },
+
+    #' @description
+    #' Set scale parameter for the group parameter variance prior.
+ #' @param value Parameter input + #' @return None + set_variance_prior_scale = function(value) { + stopifnot(is.double(value)) + stopifnot(!is.matrix(value)) + stopifnot(length(value) == 1) + rfx_model_set_variance_prior_scale_cpp(self$rfx_model_ptr, value) + } + ) ) #' Create a `RandomEffectSamples` object @@ -444,13 +444,13 @@ RandomEffectsModel <- R6::R6Class( #' rfx_tracker <- createRandomEffectsTracker(rfx_group_ids) #' rfx_samples <- createRandomEffectSamples(num_components, num_groups, rfx_tracker) createRandomEffectSamples <- function( - num_components, - num_groups, - random_effects_tracker + num_components, + num_groups, + random_effects_tracker ) { - invisible(output <- RandomEffectSamples$new()) - output$load_in_session(num_components, num_groups, random_effects_tracker) - return(output) + invisible(output <- RandomEffectSamples$new()) + output$load_in_session(num_components, num_groups, random_effects_tracker) + return(output) } #' Create a `RandomEffectsTracker` object @@ -467,7 +467,7 @@ createRandomEffectSamples <- function( #' num_components <- ncol(rfx_basis) #' rfx_tracker <- createRandomEffectsTracker(rfx_group_ids) createRandomEffectsTracker <- function(rfx_group_indices) { - return(invisible((RandomEffectsTracker$new(rfx_group_indices)))) + return(invisible((RandomEffectsTracker$new(rfx_group_indices)))) } #' Create a `RandomEffectsModel` object @@ -485,7 +485,7 @@ createRandomEffectsTracker <- function(rfx_group_indices) { #' num_components <- ncol(rfx_basis) #' rfx_model <- createRandomEffectsModel(num_components, num_groups) createRandomEffectsModel <- function(num_components, num_groups) { - return(invisible((RandomEffectsModel$new(num_components, num_groups)))) + return(invisible((RandomEffectsModel$new(num_components, num_groups)))) } #' Reset a `RandomEffectsModel` object based on the parameters indexed by `sample_num` in a `RandomEffectsSamples` object @@ -531,23 +531,23 @@ createRandomEffectsModel <- function(num_components, num_groups) { #' } #' resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) resetRandomEffectsModel <- function( - rfx_model, - rfx_samples, - sample_num, - sigma_alpha_init + rfx_model, + rfx_samples, + sample_num, + sigma_alpha_init ) { - if (!is.matrix(sigma_alpha_init)) { - if (!is.double(sigma_alpha_init)) { - stop("`sigma_alpha_init` must be a numeric scalar or matrix") - } - sigma_alpha_init <- as.matrix(sigma_alpha_init) + if (!is.matrix(sigma_alpha_init)) { + if (!is.double(sigma_alpha_init)) { + stop("`sigma_alpha_init` must be a numeric scalar or matrix") } - reset_rfx_model_cpp( - rfx_model$rfx_model_ptr, - rfx_samples$rfx_container_ptr, - sample_num - ) - rfx_model$set_working_parameter_cov(sigma_alpha_init) + sigma_alpha_init <- as.matrix(sigma_alpha_init) + } + reset_rfx_model_cpp( + rfx_model$rfx_model_ptr, + rfx_samples$rfx_container_ptr, + sample_num + ) + rfx_model$set_working_parameter_cov(sigma_alpha_init) } #' Reset a `RandomEffectsTracker` object based on the parameters indexed by `sample_num` in a `RandomEffectsSamples` object @@ -595,18 +595,18 @@ resetRandomEffectsModel <- function( #' resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) #' resetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome, rfx_samples) resetRandomEffectsTracker <- function( - rfx_tracker, - rfx_model, - rfx_dataset, - residual, - rfx_samples + rfx_tracker, + rfx_model, + rfx_dataset, + residual, + rfx_samples ) { - reset_rfx_tracker_cpp( - rfx_tracker$rfx_tracker_ptr, - rfx_dataset$data_ptr, - residual$data_ptr, - 
rfx_model$rfx_model_ptr - ) + reset_rfx_tracker_cpp( + rfx_tracker$rfx_tracker_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_model$rfx_model_ptr + ) } #' Reset a `RandomEffectsModel` object to its "default" state @@ -656,20 +656,20 @@ resetRandomEffectsTracker <- function( #' rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, #' sigma_xi_init, sigma_xi_shape, sigma_xi_scale) rootResetRandomEffectsModel <- function( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale ) { - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) } #' Reset a `RandomEffectsTracker` object to its "default" state @@ -717,15 +717,15 @@ rootResetRandomEffectsModel <- function( #' sigma_xi_init, sigma_xi_shape, sigma_xi_scale) #' rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome) rootResetRandomEffectsTracker <- function( - rfx_tracker, - rfx_model, - rfx_dataset, - residual + rfx_tracker, + rfx_model, + rfx_dataset, + residual ) { - root_reset_rfx_tracker_cpp( - rfx_tracker$rfx_tracker_ptr, - rfx_dataset$data_ptr, - residual$data_ptr, - rfx_model$rfx_model_ptr - ) + root_reset_rfx_tracker_cpp( + rfx_tracker$rfx_tracker_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_model$rfx_model_ptr + ) } diff --git a/R/serialization.R b/R/serialization.R index c52e9fec..1b579bd3 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -4,521 +4,521 @@ #' Wrapper around a C++ container of tree ensembles CppJson <- R6::R6Class( - classname = "CppJson", - cloneable = FALSE, - public = list( - #' @field json_ptr External pointer to a C++ nlohmann::json object - json_ptr = NULL, - - #' @field num_forests Number of forests in the nlohmann::json object - num_forests = NULL, - - #' @field forest_labels Names of forest objects in the overall nlohmann::json object - forest_labels = NULL, - - #' @field num_rfx Number of random effects terms in the nlohman::json object - num_rfx = NULL, - - #' @field rfx_container_labels Names of rfx container objects in the overall nlohmann::json object - rfx_container_labels = NULL, - - #' @field rfx_mapper_labels Names of rfx label mapper objects in the overall nlohmann::json object - rfx_mapper_labels = NULL, - - #' @field rfx_groupid_labels Names of rfx group id objects in the overall nlohmann::json object - rfx_groupid_labels = NULL, - - #' @description - #' Create a new CppJson object. - #' @return A new `CppJson` object. 
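
The `CppJson` methods reindented below follow a simple key/value pattern, with an optional subfolder layer in the json hierarchy. A hypothetical usage sketch (assuming a container created via `CppJson$new()`; the field names are purely illustrative):

```r
json <- CppJson$new()
json$add_scalar("outcome_scale", 2.5)  # top-level numeric field
json$add_integer("num_trees", 200L, subfolder_name = "config")
json$add_boolean("probit_outcome", FALSE, subfolder_name = "config")
```
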
- initialize = function() { - self$json_ptr <- init_json_cpp() - self$num_forests <- 0 - self$forest_labels <- c() - self$num_rfx <- 0 - self$rfx_container_labels <- c() - self$rfx_mapper_labels <- c() - self$rfx_groupid_labels <- c() - }, - - #' @description - #' Convert a forest container to json and add to the current `CppJson` object - #' @param forest_samples `ForestSamples` R class - #' @return None - add_forest = function(forest_samples) { - forest_label <- json_add_forest_cpp( - self$json_ptr, - forest_samples$forest_container_ptr - ) - self$num_forests <- self$num_forests + 1 - self$forest_labels <- c(self$forest_labels, forest_label) - }, - - #' @description - #' Convert a random effects container to json and add to the current `CppJson` object - #' @param rfx_samples `RandomEffectSamples` R class - #' @return None - add_random_effects = function(rfx_samples) { - rfx_container_label <- json_add_rfx_container_cpp( - self$json_ptr, - rfx_samples$rfx_container_ptr - ) - self$rfx_container_labels <- c( - self$rfx_container_labels, - rfx_container_label - ) - rfx_mapper_label <- json_add_rfx_label_mapper_cpp( - self$json_ptr, - rfx_samples$label_mapper_ptr - ) - self$rfx_mapper_labels <- c( - self$rfx_mapper_labels, - rfx_mapper_label - ) - rfx_groupid_label <- json_add_rfx_groupids_cpp( - self$json_ptr, - rfx_samples$training_group_ids - ) - self$rfx_groupid_labels <- c( - self$rfx_groupid_labels, - rfx_groupid_label - ) - json_increment_rfx_count_cpp(self$json_ptr) - self$num_rfx <- self$num_rfx + 1 - }, - - #' @description - #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Numeric value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_scalar = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_double_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_double_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Integer value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_integer = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_integer_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_integer_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Numeric value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_boolean = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_bool_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_bool_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - 
) - } - }, - - #' @description - #' Add a string value to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Numeric value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_string = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_string_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_string_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_vector Vector to be stored in json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_vector = function(field_name, field_vector, subfolder_name = NULL) { - field_vector <- as.numeric(field_vector) - if (is.null(subfolder_name)) { - json_add_vector_cpp(self$json_ptr, field_name, field_vector) - } else { - json_add_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_vector - ) - } - }, - - #' @description - #' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_vector Vector to be stored in json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_integer_vector = function( - field_name, - field_vector, - subfolder_name = NULL - ) { - field_vector <- as.numeric(field_vector) - if (is.null(subfolder_name)) { - json_add_integer_vector_cpp( - self$json_ptr, - field_name, - field_vector - ) - } else { - json_add_integer_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_vector - ) - } - }, - - #' @description - #' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_vector Character vector to be stored in json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_string_vector = function( - field_name, - field_vector, - subfolder_name = NULL - ) { - if (is.null(subfolder_name)) { - json_add_string_vector_cpp( - self$json_ptr, - field_name, - field_vector - ) - } else { - json_add_string_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_vector - ) - } - }, - - #' @description - #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" - #' @param field_name The name of the field to be added to json - #' @param field_list List to be stored in json - #' @return None - add_list = function(field_name, field_list) { - stopifnot(sum(!sapply(field_list, is.vector)) == 0) - list_names <- names(field_list) - for (i in 1:length(field_list)) { - vec_name <- list_names[i] - vec <- field_list[[i]] - json_add_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name, - vec - ) - } - }, - - #' @description - #' Add a list of vectors (as an object map of arrays) to the json object under the name 
"field_name" - #' @param field_name The name of the field to be added to json - #' @param field_list List to be stored in json - #' @return None - add_string_list = function(field_name, field_list) { - stopifnot(sum(!sapply(field_list, is.vector)) == 0) - list_names <- names(field_list) - for (i in 1:length(field_list)) { - vec_name <- list_names[i] - vec <- field_list[[i]] - json_add_string_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name, - vec - ) - } - }, - - #' @description - #' Retrieve a scalar value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_scalar = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_double_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_double_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_integer = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_integer_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_integer_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_boolean = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_bool_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_bool_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a string value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_string = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_string_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- 
json_extract_string_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_vector = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_vector_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_integer_vector = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_integer_vector_cpp( - self$json_ptr, - field_name - ) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_integer_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_string_vector = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_string_vector_cpp( - self$json_ptr, - field_name - ) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_string_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Reconstruct a list of numeric vectors from the json object stored under "field_name" - #' @param field_name The name of the field to be added to json - #' @param key_names Vector of names of list elements (each of which is a vector) - #' @return None - get_numeric_list = function(field_name, key_names) { - output <- list() - for (i in 1:length(key_names)) { - vec_name <- key_names[i] - output[[vec_name]] <- json_extract_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name - ) - } - return(output) - }, - - #' @description - #' Reconstruct a list of string vectors from the json object stored under "field_name" - #' @param field_name The name of the field to be added to json - #' @param key_names Vector of names of list elements (each of which is a vector) - #' @return None - get_string_list = function(field_name, key_names) { - output <- list() - for (i in 1:length(key_names)) { - vec_name <- key_names[i] - output[[vec_name]] <- 
json_extract_string_vector_subfolder_cpp(
-                    self$json_ptr,
-                    field_name,
-                    vec_name
-                )
-            }
-            return(output)
-        },
-
-        #' @description
-        #' Convert a JSON object to in-memory string
-        #' @return JSON string
-        return_json_string = function() {
-            return(get_json_string_cpp(self$json_ptr))
-        },
-
-        #' @description
-        #' Save a json object to file
-        #' @param filename String of filepath, must end in ".json"
-        #' @return None
-        save_file = function(filename) {
-            json_save_file_cpp(self$json_ptr, filename)
-        },
-
-        #' @description
-        #' Load a json object from file
-        #' @param filename String of filepath, must end in ".json"
-        #' @return None
-        load_from_file = function(filename) {
-            json_load_file_cpp(self$json_ptr, filename)
-        },
-
-        #' @description
-        #' Load a json object from string
-        #' @param json_string JSON string dump
-        #' @return None
-        load_from_string = function(json_string) {
-            json_load_string_cpp(self$json_ptr, json_string)
-        }
-    )
+  classname = "CppJson",
+  cloneable = FALSE,
+  public = list(
+    #' @field json_ptr External pointer to a C++ nlohmann::json object
+    json_ptr = NULL,
+
+    #' @field num_forests Number of forests in the nlohmann::json object
+    num_forests = NULL,
+
+    #' @field forest_labels Names of forest objects in the overall nlohmann::json object
+    forest_labels = NULL,
+
+    #' @field num_rfx Number of random effects terms in the nlohmann::json object
+    num_rfx = NULL,
+
+    #' @field rfx_container_labels Names of rfx container objects in the overall nlohmann::json object
+    rfx_container_labels = NULL,
+
+    #' @field rfx_mapper_labels Names of rfx label mapper objects in the overall nlohmann::json object
+    rfx_mapper_labels = NULL,
+
+    #' @field rfx_groupid_labels Names of rfx group id objects in the overall nlohmann::json object
+    rfx_groupid_labels = NULL,
+
+    #' @description
+    #' Create a new CppJson object.
+    #' @return A new `CppJson` object.
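# [Editor's note] An illustrative round trip (not part of the diff) through the
# CppJson wrapper being reformatted here, following the roxygen examples that
# appear later in this file:
#
#   example_json <- createCppJson()
#   example_json$add_scalar("myscalar", 5.4)
#   example_json$add_vector("myvec", c(1.0, 2.0, 3.0))
#   json_string <- example_json$return_json_string()
#   # Reload from the string dump and recover the stored values
#   example_json_roundtrip <- createCppJsonString(json_string)
#   roundtrip_vec <- loadVectorJson(example_json_roundtrip, "myvec")
#   roundtrip_scalar <- loadScalarJson(example_json_roundtrip, "myscalar")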
+ initialize = function() { + self$json_ptr <- init_json_cpp() + self$num_forests <- 0 + self$forest_labels <- c() + self$num_rfx <- 0 + self$rfx_container_labels <- c() + self$rfx_mapper_labels <- c() + self$rfx_groupid_labels <- c() + }, + + #' @description + #' Convert a forest container to json and add to the current `CppJson` object + #' @param forest_samples `ForestSamples` R class + #' @return None + add_forest = function(forest_samples) { + forest_label <- json_add_forest_cpp( + self$json_ptr, + forest_samples$forest_container_ptr + ) + self$num_forests <- self$num_forests + 1 + self$forest_labels <- c(self$forest_labels, forest_label) + }, + + #' @description + #' Convert a random effects container to json and add to the current `CppJson` object + #' @param rfx_samples `RandomEffectSamples` R class + #' @return None + add_random_effects = function(rfx_samples) { + rfx_container_label <- json_add_rfx_container_cpp( + self$json_ptr, + rfx_samples$rfx_container_ptr + ) + self$rfx_container_labels <- c( + self$rfx_container_labels, + rfx_container_label + ) + rfx_mapper_label <- json_add_rfx_label_mapper_cpp( + self$json_ptr, + rfx_samples$label_mapper_ptr + ) + self$rfx_mapper_labels <- c( + self$rfx_mapper_labels, + rfx_mapper_label + ) + rfx_groupid_label <- json_add_rfx_groupids_cpp( + self$json_ptr, + rfx_samples$training_group_ids + ) + self$rfx_groupid_labels <- c( + self$rfx_groupid_labels, + rfx_groupid_label + ) + json_increment_rfx_count_cpp(self$json_ptr) + self$num_rfx <- self$num_rfx + 1 + }, + + #' @description + #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Numeric value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_scalar = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_double_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_double_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Integer value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_integer = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_integer_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_integer_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Numeric value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_boolean = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_bool_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_bool_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + 
) + } + }, + + #' @description + #' Add a string value to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Numeric value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_string = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_string_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_string_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_vector Vector to be stored in json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_vector = function(field_name, field_vector, subfolder_name = NULL) { + field_vector <- as.numeric(field_vector) + if (is.null(subfolder_name)) { + json_add_vector_cpp(self$json_ptr, field_name, field_vector) + } else { + json_add_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) + } + }, + + #' @description + #' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_vector Vector to be stored in json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_integer_vector = function( + field_name, + field_vector, + subfolder_name = NULL + ) { + field_vector <- as.numeric(field_vector) + if (is.null(subfolder_name)) { + json_add_integer_vector_cpp( + self$json_ptr, + field_name, + field_vector + ) + } else { + json_add_integer_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) + } + }, + + #' @description + #' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_vector Character vector to be stored in json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_string_vector = function( + field_name, + field_vector, + subfolder_name = NULL + ) { + if (is.null(subfolder_name)) { + json_add_string_vector_cpp( + self$json_ptr, + field_name, + field_vector + ) + } else { + json_add_string_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) + } + }, + + #' @description + #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" + #' @param field_name The name of the field to be added to json + #' @param field_list List to be stored in json + #' @return None + add_list = function(field_name, field_list) { + stopifnot(sum(!sapply(field_list, is.vector)) == 0) + list_names <- names(field_list) + for (i in 1:length(field_list)) { + vec_name <- list_names[i] + vec <- field_list[[i]] + json_add_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name, + vec + ) + } + }, + + #' @description + #' Add a list of vectors (as an object map of arrays) to the json object under the name 
"field_name" + #' @param field_name The name of the field to be added to json + #' @param field_list List to be stored in json + #' @return None + add_string_list = function(field_name, field_list) { + stopifnot(sum(!sapply(field_list, is.vector)) == 0) + list_names <- names(field_list) + for (i in 1:length(field_list)) { + vec_name <- list_names[i] + vec <- field_list[[i]] + json_add_string_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name, + vec + ) + } + }, + + #' @description + #' Retrieve a scalar value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_scalar = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_double_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_double_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_integer = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_integer_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_integer_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_boolean = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_bool_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_bool_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a string value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_string = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_string_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- 
json_extract_string_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_vector = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_vector_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_integer_vector = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_integer_vector_cpp( + self$json_ptr, + field_name + ) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_integer_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_string_vector = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_string_vector_cpp( + self$json_ptr, + field_name + ) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_string_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Reconstruct a list of numeric vectors from the json object stored under "field_name" + #' @param field_name The name of the field to be added to json + #' @param key_names Vector of names of list elements (each of which is a vector) + #' @return None + get_numeric_list = function(field_name, key_names) { + output <- list() + for (i in 1:length(key_names)) { + vec_name <- key_names[i] + output[[vec_name]] <- json_extract_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name + ) + } + return(output) + }, + + #' @description + #' Reconstruct a list of string vectors from the json object stored under "field_name" + #' @param field_name The name of the field to be added to json + #' @param key_names Vector of names of list elements (each of which is a vector) + #' @return None + get_string_list = function(field_name, key_names) { + output <- list() + for (i in 1:length(key_names)) { + vec_name <- key_names[i] + output[[vec_name]] <- 
json_extract_string_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name + ) + } + return(output) + }, + + #' @description + #' Convert a JSON object to in-memory string + #' @return JSON string + return_json_string = function() { + return(get_json_string_cpp(self$json_ptr)) + }, + + #' @description + #' Save a json object to file + #' @param filename String of filepath, must end in ".json" + #' @return None + save_file = function(filename) { + json_save_file_cpp(self$json_ptr, filename) + }, + + #' @description + #' Load a json object from file + #' @param filename String of filepath, must end in ".json" + #' @return None + load_from_file = function(filename) { + json_load_file_cpp(self$json_ptr, filename) + }, + + #' @description + #' Load a json object from string + #' @param json_string JSON string dump + #' @return None + load_from_string = function(json_string) { + json_load_string_cpp(self$json_ptr, json_string) + } + ) ) #' Load a container of forest samples from json @@ -536,9 +536,9 @@ CppJson <- R6::R6Class( #' bart_json <- saveBARTModelToJson(bart_model) #' mean_forest <- loadForestContainerJson(bart_json, "forest_0") loadForestContainerJson <- function(json_object, json_forest_label) { - invisible(output <- ForestSamples$new(0, 1, T)) - output$load_from_json(json_object, json_forest_label) - return(output) + invisible(output <- ForestSamples$new(0, 1, T)) + output$load_from_json(json_object, json_forest_label) + return(output) } #' Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container @@ -556,19 +556,19 @@ loadForestContainerJson <- function(json_object, json_forest_label) { #' bart_json <- list(saveBARTModelToJson(bart_model)) #' mean_forest <- loadForestContainerCombinedJson(bart_json, "forest_0") loadForestContainerCombinedJson <- function( - json_object_list, - json_forest_label + json_object_list, + json_forest_label ) { - invisible(output <- ForestSamples$new(0, 1, T)) - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output$load_from_json(json_object, json_forest_label) - } else { - output$append_from_json(json_object, json_forest_label) - } + invisible(output <- ForestSamples$new(0, 1, T)) + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output$load_from_json(json_object, json_forest_label) + } else { + output$append_from_json(json_object, json_forest_label) } - return(output) + } + return(output) } #' Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container @@ -586,19 +586,19 @@ loadForestContainerCombinedJson <- function( #' bart_json_string <- list(saveBARTModelToJsonString(bart_model)) #' mean_forest <- loadForestContainerCombinedJsonString(bart_json_string, "forest_0") loadForestContainerCombinedJsonString <- function( - json_string_list, - json_forest_label + json_string_list, + json_forest_label ) { - invisible(output <- ForestSamples$new(0, 1, T)) - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - if (i == 1) { - output$load_from_json_string(json_string, json_forest_label) - } else { - output$append_from_json_string(json_string, json_forest_label) - } + invisible(output <- ForestSamples$new(0, 1, T)) + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + if (i == 1) { + output$load_from_json_string(json_string, json_forest_label) + } else { + 
output$append_from_json_string(json_string, json_forest_label) } - return(output) + } + return(output) } #' Load a container of random effect samples from json @@ -621,17 +621,17 @@ loadForestContainerCombinedJsonString <- function( #' bart_json <- saveBARTModelToJson(bart_model) #' rfx_samples <- loadRandomEffectSamplesJson(bart_json, 0) loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { - json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) - json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) - json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) - invisible(output <- RandomEffectSamples$new()) - output$load_from_json( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - return(output) + json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) + json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) + json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) + invisible(output <- RandomEffectSamples$new()) + output$load_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) + return(output) } #' Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container @@ -654,32 +654,32 @@ loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { #' bart_json <- list(saveBARTModelToJson(bart_model)) #' rfx_samples <- loadRandomEffectSamplesCombinedJson(bart_json, 0) loadRandomEffectSamplesCombinedJson <- function( - json_object_list, - json_rfx_num + json_object_list, + json_rfx_num ) { - json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) - json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) - json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) - invisible(output <- RandomEffectSamples$new()) - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output$load_from_json( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } else { - output$append_from_json( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } + json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) + json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) + json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) + invisible(output <- RandomEffectSamples$new()) + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output$load_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) + } else { + output$append_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } - return(output) + } + return(output) } #' Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container @@ -702,32 +702,32 @@ loadRandomEffectSamplesCombinedJson <- function( #' bart_json_string <- list(saveBARTModelToJsonString(bart_model)) #' rfx_samples <- loadRandomEffectSamplesCombinedJsonString(bart_json_string, 0) loadRandomEffectSamplesCombinedJsonString <- function( - json_string_list, - json_rfx_num + json_string_list, + json_rfx_num ) { - 
json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) - json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) - json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) - invisible(output <- RandomEffectSamples$new()) - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - if (i == 1) { - output$load_from_json_string( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } else { - output$append_from_json_string( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } + json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) + json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) + json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) + invisible(output <- RandomEffectSamples$new()) + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + if (i == 1) { + output$load_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) + } else { + output$append_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } - return(output) + } + return(output) } #' Load a vector from json @@ -745,16 +745,16 @@ loadRandomEffectSamplesCombinedJsonString <- function( #' example_json$add_vector("myvec", example_vec) #' roundtrip_vec <- loadVectorJson(example_json, "myvec") loadVectorJson <- function( - json_object, - json_vector_label, - subfolder_name = NULL + json_object, + json_vector_label, + subfolder_name = NULL ) { - if (is.null(subfolder_name)) { - output <- json_object$get_vector(json_vector_label) - } else { - output <- json_object$get_vector(json_vector_label, subfolder_name) - } - return(output) + if (is.null(subfolder_name)) { + output <- json_object$get_vector(json_vector_label) + } else { + output <- json_object$get_vector(json_vector_label, subfolder_name) + } + return(output) } #' Load a scalar from json @@ -772,16 +772,16 @@ loadVectorJson <- function( #' example_json$add_scalar("myscalar", example_scalar) #' roundtrip_scalar <- loadScalarJson(example_json, "myscalar") loadScalarJson <- function( - json_object, - json_scalar_label, - subfolder_name = NULL + json_object, + json_scalar_label, + subfolder_name = NULL ) { - if (is.null(subfolder_name)) { - output <- json_object$get_scalar(json_scalar_label) - } else { - output <- json_object$get_scalar(json_scalar_label, subfolder_name) - } - return(output) + if (is.null(subfolder_name)) { + output <- json_object$get_scalar(json_scalar_label) + } else { + output <- json_object$get_scalar(json_scalar_label, subfolder_name) + } + return(output) } #' Create a new (empty) C++ Json object @@ -794,7 +794,7 @@ loadScalarJson <- function( #' example_json <- createCppJson() #' example_json$add_vector("myvec", example_vec) createCppJson <- function() { - return(invisible((CppJson$new()))) + return(invisible((CppJson$new()))) } #' Create a C++ Json object from a Json file @@ -812,9 +812,9 @@ createCppJson <- function() { #' example_json_roundtrip <- createCppJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createCppJsonFile <- function(json_filename) { - invisible((output <- CppJson$new())) - output$load_from_file(json_filename) - return(output) + invisible((output <- CppJson$new())) + output$load_from_file(json_filename) + return(output) } #' Create a C++ Json object from 
a Json string @@ -830,7 +830,7 @@ createCppJsonFile <- function(json_filename) { #' example_json_string <- example_json$return_json_string() #' example_json_roundtrip <- createCppJsonString(example_json_string) createCppJsonString <- function(json_string) { - invisible((output <- CppJson$new())) - output$load_from_string(json_string) - return(output) + invisible((output <- CppJson$new())) + output$load_from_string(json_string) + return(output) } diff --git a/R/stochtree-package.R b/R/stochtree-package.R index 97eeded1..41fcf99c 100644 --- a/R/stochtree-package.R +++ b/R/stochtree-package.R @@ -6,6 +6,7 @@ #' @importFrom stats predict #' @importFrom stats qgamma #' @importFrom stats qnorm +#' @importFrom stats quantile #' @importFrom stats pnorm #' @importFrom stats resid #' @importFrom stats rnorm diff --git a/R/utils.R b/R/utils.R index ad9752bc..c169cc1c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -6,18 +6,18 @@ #' #' @return Parameter list with defaults overriden by values supplied in `user_params` preprocessParams <- function(default_params, user_params = NULL) { - # Override defaults from general_params - if (!is.null(user_params)) { - for (key in names(user_params)) { - if (key %in% names(default_params)) { - val <- user_params[[key]] - if (!is.null(val)) default_params[[key]] <- val - } - } + # Override defaults from general_params + if (!is.null(user_params)) { + for (key in names(user_params)) { + if (key %in% names(default_params)) { + val <- user_params[[key]] + if (!is.null(val)) default_params[[key]] <- val + } } + } - # Return result - return(default_params) + # Return result + return(default_params) } #' Preprocess covariates. DataFrames will be preprocessed based on their column @@ -35,19 +35,19 @@ preprocessParams <- function(default_params, user_params = NULL) { #' preprocess_list <- preprocessTrainData(cov_mat) #' X <- preprocess_list$X preprocessTrainData <- function(input_data) { - # Input checks - if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { - stop("Covariates provided must be a dataframe or matrix") - } - - # Routing the correct preprocessing function - if (is.matrix(input_data)) { - output <- preprocessTrainMatrix(input_data) - } else { - output <- preprocessTrainDataFrame(input_data) - } - - return(output) + # Input checks + if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { + stop("Covariates provided must be a dataframe or matrix") + } + + # Routing the correct preprocessing function + if (is.matrix(input_data)) { + output <- preprocessTrainMatrix(input_data) + } else { + output <- preprocessTrainDataFrame(input_data) + } + + return(output) } #' Preprocess covariates. 
DataFrames will be preprocessed based on their column @@ -66,19 +66,19 @@ preprocessTrainData <- function(input_data) { #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionData(cov_df, metadata) preprocessPredictionData <- function(input_data, metadata) { - # Input checks - if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { - stop("Covariates provided must be a dataframe or matrix") - } - - # Routing the correct preprocessing function - if (is.matrix(input_data)) { - X <- preprocessPredictionMatrix(input_data, metadata) - } else { - X <- preprocessPredictionDataFrame(input_data, metadata) - } - - return(X) + # Input checks + if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { + stop("Covariates provided must be a dataframe or matrix") + } + + # Routing the correct preprocessing function + if (is.matrix(input_data)) { + X <- preprocessPredictionMatrix(input_data, metadata) + } else { + X <- preprocessPredictionDataFrame(input_data, metadata) + } + + return(X) } #' Preprocess a matrix of covariate values, assuming all columns are numeric. @@ -96,38 +96,38 @@ preprocessPredictionData <- function(input_data, metadata) { #' preprocess_list <- preprocessTrainMatrix(cov_mat) #' X <- preprocess_list$X preprocessTrainMatrix <- function(input_matrix) { - # Input checks - if (!is.matrix(input_matrix)) { - stop("covariates provided must be a matrix") - } - - # Unpack metadata (assuming all variables are numeric) - names(input_matrix) <- paste0("x", 1:ncol(input_matrix)) - df_vars <- names(input_matrix) - num_ordered_cat_vars <- 0 - num_unordered_cat_vars <- 0 - num_numeric_vars <- ncol(input_matrix) - numeric_vars <- names(input_matrix) - feature_types <- rep(0, ncol(input_matrix)) - - # Unpack data - X <- input_matrix - - # Aggregate results into a list - metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, - num_numeric_vars = num_numeric_vars, - numeric_vars = numeric_vars, - original_var_indices = 1:num_numeric_vars - ) - output <- list( - data = X, - metadata = metadata - ) - - return(output) + # Input checks + if (!is.matrix(input_matrix)) { + stop("covariates provided must be a matrix") + } + + # Unpack metadata (assuming all variables are numeric) + names(input_matrix) <- paste0("x", 1:ncol(input_matrix)) + df_vars <- names(input_matrix) + num_ordered_cat_vars <- 0 + num_unordered_cat_vars <- 0 + num_numeric_vars <- ncol(input_matrix) + numeric_vars <- names(input_matrix) + feature_types <- rep(0, ncol(input_matrix)) + + # Unpack data + X <- input_matrix + + # Aggregate results into a list + metadata <- list( + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, + num_numeric_vars = num_numeric_vars, + numeric_vars = numeric_vars, + original_var_indices = 1:num_numeric_vars + ) + output <- list( + data = X, + metadata = metadata + ) + + return(output) } #' Preprocess a matrix of covariate values, assuming all columns are numeric. 
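# [Editor's note] An illustrative train/predict preprocessing round trip (not
# part of the diff) using the helpers above; the data frame is chosen to
# exercise the numeric, ordered-categorical, and unordered-categorical paths
# handled below. Note that the preprocessed matrix lives in the `data` element
# of the returned list:
#
#   train_df <- data.frame(
#     x1 = rnorm(10),
#     x2 = factor(sample(c("low", "high"), 10, replace = TRUE),
#                 levels = c("low", "high"), ordered = TRUE),
#     x3 = sample(c("red", "blue"), 10, replace = TRUE)
#   )
#   preprocess_list <- preprocessTrainData(train_df)
#   X_train <- preprocess_list$data
#   metadata <- preprocess_list$metadata
#   # New data must be preprocessed with the metadata learned at training time
#   X_pred <- preprocessPredictionData(train_df, metadata)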
@@ -145,17 +145,17 @@ preprocessTrainMatrix <- function(input_matrix) { #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionMatrix(cov_mat, metadata) preprocessPredictionMatrix <- function(input_matrix, metadata) { - # Input checks - if (!is.matrix(input_matrix)) { - stop("covariates provided must be a matrix") - } - if (!(ncol(input_matrix) == metadata$num_numeric_vars)) { - stop( - "Prediction set covariates have inconsistent dimension from train set covariates" - ) - } + # Input checks + if (!is.matrix(input_matrix)) { + stop("covariates provided must be a matrix") + } + if (!(ncol(input_matrix) == metadata$num_numeric_vars)) { + stop( + "Prediction set covariates have inconsistent dimension from train set covariates" + ) + } - return(input_matrix) + return(input_matrix) } #' Preprocess a dataframe of covariate values, converting categorical variables @@ -170,126 +170,126 @@ preprocessPredictionMatrix <- function(input_matrix, metadata) { #' of variable, unique categories associated with categorical variables, and the #' vector of feature types needed for calls to BART and BCF. preprocessTrainDataFrame <- function(input_df) { - # Input checks / details - if (!is.data.frame(input_df)) { - stop("covariates provided must be a data frame") - } - df_vars <- names(input_df) - - # Detect ordered and unordered categorical variables - - # First, ordered categorical: users must have explicitly - # converted this to a factor with ordered = TRUE - factor_mask <- sapply(input_df, is.factor) - ordered_mask <- sapply(input_df, is.ordered) - ordered_cat_matches <- factor_mask & ordered_mask - ordered_cat_vars <- df_vars[ordered_cat_matches] - ordered_cat_var_inds <- unname(which(ordered_cat_matches)) - num_ordered_cat_vars <- length(ordered_cat_vars) - if (num_ordered_cat_vars > 0) { - ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] - } - - # Next, unordered categorical: we will convert character - # columns but not integer columns (users must explicitly - # convert these to factor) - character_mask <- sapply(input_df, is.character) - unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask - unordered_cat_vars <- df_vars[unordered_cat_matches] - unordered_cat_var_inds <- unname(which(unordered_cat_matches)) - num_unordered_cat_vars <- length(unordered_cat_vars) - if (num_unordered_cat_vars > 0) { - unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] - } - - # Numeric variables - numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches) - numeric_vars <- df_vars[numeric_matches] - numeric_var_inds <- unname(which(numeric_matches)) - num_numeric_vars <- length(numeric_vars) - if (num_numeric_vars > 0) { - numeric_df <- input_df[, numeric_vars, drop = FALSE] - } - - # Empty outputs - X <- double(0) - unordered_unique_levels <- list() - ordered_unique_levels <- list() - feature_types <- integer(0) - original_var_indices <- integer(0) - - # First, extract the numeric covariates - if (num_numeric_vars > 0) { - Xnum <- double(0) - for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[, i])) - Xnum <- cbind(Xnum, numeric_df[, i]) - } - X <- cbind(X, unname(Xnum)) - feature_types <- c(feature_types, rep(0, ncol(Xnum))) - original_var_indices <- c(original_var_indices, numeric_var_inds) - } - - # Next, run some simple preprocessing on the ordered categorical covariates - if (num_ordered_cat_vars > 0) { - Xordcat <- double(0) - for (i in 1:ncol(ordered_cat_df)) { - var_name <- 
names(ordered_cat_df)[i]
-            preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[,
-                i
-            ])
-            ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels
-            Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed)
-        }
-        X <- cbind(X, unname(Xordcat))
-        feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
-        original_var_indices <- c(original_var_indices, ordered_cat_var_inds)
-    }
-
-    # Finally, one-hot encode the unordered categorical covariates
-    if (num_unordered_cat_vars > 0) {
-        one_hot_mats <- list()
-        for (i in 1:ncol(unordered_cat_df)) {
-            var_name <- names(unordered_cat_df)[i]
-            encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i])
-            unordered_unique_levels[[var_name]] <- encode_list$unique_levels
-            one_hot_mats[[var_name]] <- encode_list$Xtilde
-            one_hot_var <- rep(
-                unordered_cat_var_inds[i],
-                ncol(encode_list$Xtilde)
-            )
-            original_var_indices <- c(original_var_indices, one_hot_var)
-        }
-        Xcat <- do.call(cbind, one_hot_mats)
-        X <- cbind(X, unname(Xcat))
-        feature_types <- c(feature_types, rep(1, ncol(Xcat)))
-    }
-
-    # Aggregate results into a list
-    metadata <- list(
-        feature_types = feature_types,
-        num_ordered_cat_vars = num_ordered_cat_vars,
-        num_unordered_cat_vars = num_unordered_cat_vars,
-        num_numeric_vars = num_numeric_vars,
-        original_var_indices = original_var_indices
-    )
-    if (num_ordered_cat_vars > 0) {
-        metadata[["ordered_cat_vars"]] = ordered_cat_vars
-        metadata[["ordered_unique_levels"]] = ordered_unique_levels
-    }
-    if (num_unordered_cat_vars > 0) {
-        metadata[["unordered_cat_vars"]] = unordered_cat_vars
-        metadata[["unordered_unique_levels"]] = unordered_unique_levels
-    }
-    if (num_numeric_vars > 0) {
-        metadata[["numeric_vars"]] = numeric_vars
-    }
-    output <- list(
-        data = X,
-        metadata = metadata
-    )
-
-    return(output)
+  # Input checks / details
+  if (!is.data.frame(input_df)) {
+    stop("covariates provided must be a data frame")
+  }
+  df_vars <- names(input_df)
+
+  # Detect ordered and unordered categorical variables
+
+  # First, ordered categorical: users must have explicitly
+  # converted this to a factor with ordered = TRUE
+  factor_mask <- sapply(input_df, is.factor)
+  ordered_mask <- sapply(input_df, is.ordered)
+  ordered_cat_matches <- factor_mask & ordered_mask
+  ordered_cat_vars <- df_vars[ordered_cat_matches]
+  ordered_cat_var_inds <- unname(which(ordered_cat_matches))
+  num_ordered_cat_vars <- length(ordered_cat_vars)
+  if (num_ordered_cat_vars > 0) {
+    ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
+  }
+
+  # Next, unordered categorical: we will convert character
+  # columns but not integer columns (users must explicitly
+  # convert these to factor)
+  character_mask <- sapply(input_df, is.character)
+  unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask
+  unordered_cat_vars <- df_vars[unordered_cat_matches]
+  unordered_cat_var_inds <- unname(which(unordered_cat_matches))
+  num_unordered_cat_vars <- length(unordered_cat_vars)
+  if (num_unordered_cat_vars > 0) {
+    unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
+  }
+
+  # Numeric variables
+  numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches)
+  numeric_vars <- df_vars[numeric_matches]
+  numeric_var_inds <- unname(which(numeric_matches))
+  num_numeric_vars <- length(numeric_vars)
+  if (num_numeric_vars > 0) {
+    numeric_df <- input_df[, numeric_vars, drop = FALSE]
+  }
+
+  # Empty outputs
+  X <- double(0)
+  unordered_unique_levels <- list()
+  ordered_unique_levels <- list()
+  feature_types <- integer(0)
+  original_var_indices <- integer(0)
+
+  # First, extract the numeric covariates
+  if (num_numeric_vars > 0) {
+    Xnum <- double(0)
+    for (i in 1:ncol(numeric_df)) {
+      stopifnot(is.numeric(numeric_df[, i]))
+      Xnum <- cbind(Xnum, numeric_df[, i])
+    }
+    X <- cbind(X, unname(Xnum))
+    feature_types <- c(feature_types, rep(0, ncol(Xnum)))
+    original_var_indices <- c(original_var_indices, numeric_var_inds)
+  }
+
+  # Next, run some simple preprocessing on the ordered categorical covariates
+  if (num_ordered_cat_vars > 0) {
+    Xordcat <- double(0)
+    for (i in 1:ncol(ordered_cat_df)) {
+      var_name <- names(ordered_cat_df)[i]
+      preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[,
+        i
+      ])
+      ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels
+      Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed)
+    }
+    X <- cbind(X, unname(Xordcat))
+    feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
+    original_var_indices <- c(original_var_indices, ordered_cat_var_inds)
+  }
+
+  # Finally, one-hot encode the unordered categorical covariates
+  if (num_unordered_cat_vars > 0) {
+    one_hot_mats <- list()
+    for (i in 1:ncol(unordered_cat_df)) {
+      var_name <- names(unordered_cat_df)[i]
+      encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i])
+      unordered_unique_levels[[var_name]] <- encode_list$unique_levels
+      one_hot_mats[[var_name]] <- encode_list$Xtilde
+      one_hot_var <- rep(
+        unordered_cat_var_inds[i],
+        ncol(encode_list$Xtilde)
+      )
+      original_var_indices <- c(original_var_indices, one_hot_var)
+    }
+    Xcat <- do.call(cbind, one_hot_mats)
+    X <- cbind(X, unname(Xcat))
+    feature_types <- c(feature_types, rep(1, ncol(Xcat)))
+  }
+
+  # Aggregate results into a list
+  metadata <- list(
+    feature_types = feature_types,
+    num_ordered_cat_vars = num_ordered_cat_vars,
+    num_unordered_cat_vars = num_unordered_cat_vars,
+    num_numeric_vars = num_numeric_vars,
+    original_var_indices = original_var_indices
+  )
+  if (num_ordered_cat_vars > 0) {
+    metadata[["ordered_cat_vars"]] = ordered_cat_vars
+    metadata[["ordered_unique_levels"]] = ordered_unique_levels
+  }
+  if (num_unordered_cat_vars > 0) {
+    metadata[["unordered_cat_vars"]] = unordered_cat_vars
+    metadata[["unordered_unique_levels"]] = unordered_unique_levels
+  }
+  if (num_numeric_vars > 0) {
+    metadata[["numeric_vars"]] = numeric_vars
+  }
+  output <- list(
+    data = X,
+    metadata = metadata
+  )
+
+  return(output)
 }
 
 #' Preprocess a dataframe of covariate values, converting categorical variables
@@ -309,70 +309,70 @@ preprocessTrainDataFrame <- function(input_df) {
 #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3"))
 #' X_preprocessed <- preprocessPredictionDataFrame(cov_df, metadata)
 preprocessPredictionDataFrame <- function(input_df, metadata) {
-    if (!is.data.frame(input_df)) {
-        stop("covariates provided must be a data frame")
-    }
-    df_vars <- names(input_df)
-    num_ordered_cat_vars <- metadata$num_ordered_cat_vars
-    num_unordered_cat_vars <- metadata$num_unordered_cat_vars
-    num_numeric_vars <- metadata$num_numeric_vars
-
-    if (num_ordered_cat_vars > 0) {
-        ordered_cat_vars <- metadata$ordered_cat_vars
-        ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
-    }
-    if (num_unordered_cat_vars > 0) {
-        unordered_cat_vars <- metadata$unordered_cat_vars
-        unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
-    }
-    if (num_numeric_vars > 0) {
-        numeric_vars <- metadata$numeric_vars
-        numeric_df <- input_df[, numeric_vars, drop = FALSE]
-    }
-
-    # Empty outputs
-    X <- double(0)
-
-    # First, extract the numeric covariates
-    if (num_numeric_vars > 0) {
-        Xnum <- double(0)
-        for (i in 1:ncol(numeric_df)) {
-            stopifnot(is.numeric(numeric_df[, i]))
-            Xnum <- cbind(Xnum, numeric_df[, i])
-        }
-        X <- cbind(X, unname(Xnum))
-    }
-
-    # Next, run some simple preprocessing on the ordered categorical covariates
-    if (num_ordered_cat_vars > 0) {
-        Xordcat <- double(0)
-        for (i in 1:ncol(ordered_cat_df)) {
-            var_name <- names(ordered_cat_df)[i]
-            x_preprocessed <- orderedCatPreprocess(
-                ordered_cat_df[, i],
-                metadata$ordered_unique_levels[[var_name]]
-            )
-            Xordcat <- cbind(Xordcat, x_preprocessed)
-        }
-        X <- cbind(X, unname(Xordcat))
-    }
-
-    # Finally, one-hot encode the unordered categorical covariates
-    if (num_unordered_cat_vars > 0) {
-        one_hot_mats <- list()
-        for (i in 1:ncol(unordered_cat_df)) {
-            var_name <- names(unordered_cat_df)[i]
-            Xtilde <- oneHotEncode(
-                unordered_cat_df[, i],
-                metadata$unordered_unique_levels[[var_name]]
-            )
-            one_hot_mats[[var_name]] <- Xtilde
-        }
-        Xcat <- do.call(cbind, one_hot_mats)
-        X <- cbind(X, unname(Xcat))
-    }
-
-    return(X)
+  if (!is.data.frame(input_df)) {
+    stop("covariates provided must be a data frame")
+  }
+  df_vars <- names(input_df)
+  num_ordered_cat_vars <- metadata$num_ordered_cat_vars
+  num_unordered_cat_vars <- metadata$num_unordered_cat_vars
+  num_numeric_vars <- metadata$num_numeric_vars
+
+  if (num_ordered_cat_vars > 0) {
+    ordered_cat_vars <- metadata$ordered_cat_vars
+    ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
+  }
+  if (num_unordered_cat_vars > 0) {
+    unordered_cat_vars <- metadata$unordered_cat_vars
+    unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
+  }
+  if (num_numeric_vars > 0) {
+    numeric_vars <- metadata$numeric_vars
+    numeric_df <- input_df[, numeric_vars, drop = FALSE]
+  }
+
+  # Empty outputs
+  X <- double(0)
+
+  # First, extract the numeric covariates
+  if (num_numeric_vars > 0) {
+    Xnum <- double(0)
+    for (i in 1:ncol(numeric_df)) {
+      stopifnot(is.numeric(numeric_df[, i]))
+      Xnum <- cbind(Xnum, numeric_df[, i])
+    }
+    X <- cbind(X, unname(Xnum))
+  }
+
+  # Next, run some simple preprocessing on the ordered categorical covariates
+  if (num_ordered_cat_vars > 0) {
+    Xordcat <- double(0)
+    for (i in 1:ncol(ordered_cat_df)) {
+      var_name <- names(ordered_cat_df)[i]
+      x_preprocessed <- orderedCatPreprocess(
+        ordered_cat_df[, i],
+        metadata$ordered_unique_levels[[var_name]]
+      )
+      Xordcat <- cbind(Xordcat, x_preprocessed)
+    }
+    X <- cbind(X, unname(Xordcat))
+  }
+
+  # Finally, one-hot encode the unordered categorical covariates
+  if (num_unordered_cat_vars > 0) {
+    one_hot_mats <- list()
+    for (i in 1:ncol(unordered_cat_df)) {
+      var_name <- names(unordered_cat_df)[i]
+      Xtilde <- oneHotEncode(
+        unordered_cat_df[, i],
+        metadata$unordered_unique_levels[[var_name]]
+      )
+      one_hot_mats[[var_name]] <- Xtilde
+    }
+    Xcat <- do.call(cbind, one_hot_mats)
+    X <- cbind(X, unname(Xcat))
+  }
+
+  return(X)
 }
 
 #' Convert the persistent aspects of a covariate preprocessor to (in-memory) C++ JSON object
@@ -388,59 +388,59 @@ preprocessPredictionDataFrame <- function(input_df, metadata) {
 #' preprocess_list <- preprocessTrainData(cov_mat)
 #' preprocessor_json <- convertPreprocessorToJson(preprocess_list$metadata)
 convertPreprocessorToJson <- function(object) {
-    jsonobj <- createCppJson()
-    if (is.null(object$feature_types)) {
-        stop("This covariate preprocessor has not yet been fit")
-    }
-
-    # Add internal scalars
-    jsonobj$add_integer("num_numeric_vars", object$num_numeric_vars)
-    jsonobj$add_integer("num_ordered_cat_vars", object$num_ordered_cat_vars)
-    jsonobj$add_integer("num_unordered_cat_vars", object$num_unordered_cat_vars)
-
-    # Add internal vectors
-    jsonobj$add_vector("feature_types", object$feature_types)
-    jsonobj$add_vector("original_var_indices", object$original_var_indices)
-    if (object$num_numeric_vars > 0) {
-        jsonobj$add_string_vector("numeric_vars", object$numeric_vars)
-    }
-    if (object$num_ordered_cat_vars > 0) {
-        jsonobj$add_string_vector("ordered_cat_vars", object$ordered_cat_vars)
-        for (i in 1:object$num_ordered_cat_vars) {
-            var_key <- names(object$ordered_unique_levels)[i]
-            jsonobj$add_string(
-                paste0("key_", i),
-                var_key,
-                "ordered_unique_level_keys"
-            )
-            jsonobj$add_string_vector(
-                var_key,
-                object$ordered_unique_levels[[i]],
-                "ordered_unique_levels"
-            )
-        }
-    }
-    if (object$num_unordered_cat_vars > 0) {
-        jsonobj$add_string_vector(
-            "unordered_cat_vars",
-            object$unordered_cat_vars
-        )
-        for (i in 1:object$num_unordered_cat_vars) {
-            var_key <- names(object$unordered_unique_levels)[i]
-            jsonobj$add_string(
-                paste0("key_", i),
-                var_key,
-                "unordered_unique_level_keys"
-            )
-            jsonobj$add_string_vector(
-                var_key,
-                object$unordered_unique_levels[[i]],
-                "unordered_unique_levels"
-            )
-        }
-    }
-
-    return(jsonobj)
+  jsonobj <- createCppJson()
+  if (is.null(object$feature_types)) {
+    stop("This covariate preprocessor has not yet been fit")
+  }
+
+  # Add internal scalars
+  jsonobj$add_integer("num_numeric_vars", object$num_numeric_vars)
+  jsonobj$add_integer("num_ordered_cat_vars", object$num_ordered_cat_vars)
+  jsonobj$add_integer("num_unordered_cat_vars", object$num_unordered_cat_vars)
+
+  # Add internal vectors
+  jsonobj$add_vector("feature_types", object$feature_types)
+  jsonobj$add_vector("original_var_indices", object$original_var_indices)
+  if (object$num_numeric_vars > 0) {
+    jsonobj$add_string_vector("numeric_vars", object$numeric_vars)
+  }
+  if (object$num_ordered_cat_vars > 0) {
+    jsonobj$add_string_vector("ordered_cat_vars", object$ordered_cat_vars)
+    for (i in 1:object$num_ordered_cat_vars) {
+      var_key <- names(object$ordered_unique_levels)[i]
+      jsonobj$add_string(
+        paste0("key_", i),
+        var_key,
+        "ordered_unique_level_keys"
+      )
+      jsonobj$add_string_vector(
+        var_key,
+        object$ordered_unique_levels[[i]],
+        "ordered_unique_levels"
+      )
+    }
+  }
+  if (object$num_unordered_cat_vars > 0) {
+    jsonobj$add_string_vector(
+      "unordered_cat_vars",
+      object$unordered_cat_vars
+    )
+    for (i in 1:object$num_unordered_cat_vars) {
+      var_key <- names(object$unordered_unique_levels)[i]
+      jsonobj$add_string(
+        paste0("key_", i),
+        var_key,
+        "unordered_unique_level_keys"
+      )
+      jsonobj$add_string_vector(
+        var_key,
+        object$unordered_unique_levels[[i]],
+        "unordered_unique_levels"
+      )
+    }
+  }
+
+  return(jsonobj)
 }
 
 #' Convert the persistent aspects of a covariate preprocessor to (in-memory) JSON string
@@ -456,11 +456,11 @@ convertPreprocessorToJson <- function(object) {
 #' preprocess_list <- preprocessTrainData(cov_mat)
 #' preprocessor_json_string <- savePreprocessorToJsonString(preprocess_list$metadata)
 savePreprocessorToJsonString <- function(object) {
-    # Convert to Json
-    jsonobj <- convertPreprocessorToJson(object)
-
-    # Dump to string
-    return(jsonobj$return_json_string())
+  # Convert to Json
+  jsonobj <- convertPreprocessorToJson(object)
+
+  # Dump to string
+  return(jsonobj$return_json_string())
 }
 
 #' Reload a covariate preprocessor object from a JSON string containing a serialized preprocessor
@@ -476,66 +476,66 @@ savePreprocessorToJsonString <- function(object) {
 #' preprocessor_json <- convertPreprocessorToJson(preprocess_list$metadata)
 #' preprocessor_roundtrip <- createPreprocessorFromJson(preprocessor_json)
 createPreprocessorFromJson <- function(json_object) {
-    # Initialize the metadata list
-    metadata <- list()
-
-    # Unpack internal scalars
-    metadata[["num_numeric_vars"]] <- json_object$get_integer(
-        "num_numeric_vars"
-    )
-    metadata[["num_ordered_cat_vars"]] <- json_object$get_integer(
-        "num_ordered_cat_vars"
-    )
-    metadata[["num_unordered_cat_vars"]] <- json_object$get_integer(
-        "num_unordered_cat_vars"
-    )
-
-    # Unpack internal vectors
-    metadata[["feature_types"]] <- json_object$get_vector("feature_types")
-    metadata[["original_var_indices"]] <- json_object$get_vector(
-        "original_var_indices"
-    )
-    if (metadata$num_numeric_vars > 0) {
-        metadata[["numeric_vars"]] <- json_object$get_string_vector(
-            "numeric_vars"
-        )
-    }
-    if (metadata$num_ordered_cat_vars > 0) {
-        metadata[["ordered_cat_vars"]] <- json_object$get_string_vector(
-            "ordered_cat_vars"
-        )
-        ordered_unique_levels <- list()
-        for (i in 1:metadata$num_ordered_cat_vars) {
-            var_key <- json_object$get_string(
-                paste0("key_", i),
-                "ordered_unique_level_keys"
-            )
-            ordered_unique_levels[[var_key]] <- json_object$get_string_vector(
-                var_key,
-                "ordered_unique_levels"
-            )
-        }
-        metadata[["ordered_unique_levels"]] <- ordered_unique_levels
-    }
-    if (metadata$num_unordered_cat_vars > 0) {
-        metadata[["unordered_cat_vars"]] <- json_object$get_string_vector(
-            "unordered_cat_vars"
-        )
-        unordered_unique_levels <- list()
-        for (i in 1:metadata$num_unordered_cat_vars) {
-            var_key <- json_object$get_string(
-                paste0("key_", i),
-                "unordered_unique_level_keys"
-            )
-            unordered_unique_levels[[var_key]] <- json_object$get_string_vector(
-                var_key,
-                "unordered_unique_levels"
-            )
-        }
-        metadata[["unordered_unique_levels"]] <- unordered_unique_levels
-    }
-
-    return(metadata)
+  # Initialize the metadata list
+  metadata <- list()
+
+  # Unpack internal scalars
+  metadata[["num_numeric_vars"]] <- json_object$get_integer(
+    "num_numeric_vars"
+  )
+  metadata[["num_ordered_cat_vars"]] <- json_object$get_integer(
+    "num_ordered_cat_vars"
+  )
+  metadata[["num_unordered_cat_vars"]] <- json_object$get_integer(
+    "num_unordered_cat_vars"
+  )
+
+  # Unpack internal vectors
+  metadata[["feature_types"]] <- json_object$get_vector("feature_types")
+  metadata[["original_var_indices"]] <- json_object$get_vector(
+    "original_var_indices"
+  )
+  if (metadata$num_numeric_vars > 0) {
+    metadata[["numeric_vars"]] <- json_object$get_string_vector(
+      "numeric_vars"
+    )
+  }
+  if (metadata$num_ordered_cat_vars > 0) {
+    metadata[["ordered_cat_vars"]] <- json_object$get_string_vector(
+      "ordered_cat_vars"
+    )
+    ordered_unique_levels <- list()
+    for (i in 1:metadata$num_ordered_cat_vars) {
+      var_key <- json_object$get_string(
+        paste0("key_", i),
+        "ordered_unique_level_keys"
+      )
+      ordered_unique_levels[[var_key]] <- json_object$get_string_vector(
+        var_key,
+        "ordered_unique_levels"
+      )
+    }
+    metadata[["ordered_unique_levels"]] <- ordered_unique_levels
+  }
+  if (metadata$num_unordered_cat_vars > 0) {
+    metadata[["unordered_cat_vars"]] <- json_object$get_string_vector(
+      "unordered_cat_vars"
+    )
+    unordered_unique_levels <- list()
+    for (i in 1:metadata$num_unordered_cat_vars) {
+      var_key <- json_object$get_string(
+        paste0("key_", i),
+        "unordered_unique_level_keys"
+      )
+      unordered_unique_levels[[var_key]] <- json_object$get_string_vector(
+        var_key,
+        "unordered_unique_levels"
+      )
+    }
+    metadata[["unordered_unique_levels"]] <- unordered_unique_levels
+  }
+
+  return(metadata)
 }
 
 #' Reload a covariate preprocessor object from a JSON string containing a serialized preprocessor
@@ -551,13 +551,13 @@ createPreprocessorFromJson <- function(json_object) {
 #' preprocessor_json_string <- savePreprocessorToJsonString(preprocess_list$metadata)
 #' preprocessor_roundtrip <- createPreprocessorFromJsonString(preprocessor_json_string)
 createPreprocessorFromJsonString <- function(json_string) {
-    # Load a `CppJson` object from string
-    preprocessor_json <- createCppJsonString(json_string)
-
-    # Create and return the BCF object
-    preprocessor_object <- createPreprocessorFromJson(preprocessor_json)
-
-    return(preprocessor_object)
+  # Load a `CppJson` object from string
+  preprocessor_json <- createCppJsonString(json_string)
+
+  # Create and return the preprocessor object
+  preprocessor_object <- createPreprocessorFromJson(preprocessor_json)
+
+  return(preprocessor_object)
 }
 
 #' Preprocess a dataframe of covariate values, converting categorical variables
@@ -579,129 +579,129 @@ createPreprocessorFromJsonString <- function(json_string) {
 #' preprocess_list <- createForestCovariates(cov_df)
 #' X <- preprocess_list$X
 createForestCovariates <- function(
-    input_data,
-    ordered_cat_vars = NULL,
-    unordered_cat_vars = NULL
+  input_data,
+  ordered_cat_vars = NULL,
+  unordered_cat_vars = NULL
 ) {
-    if (is.matrix(input_data)) {
-        input_df <- as.data.frame(input_data)
-        names(input_df) <- paste0("x", 1:ncol(input_data))
-        if (!is.null(ordered_cat_vars)) {
-            if (is.numeric(ordered_cat_vars)) {
-                ordered_cat_vars <- paste0("x", as.integer(ordered_cat_vars))
-            }
-        }
-        if (!is.null(unordered_cat_vars)) {
-            if (is.numeric(unordered_cat_vars)) {
-                unordered_cat_vars <- paste0(
-                    "x",
-                    as.integer(unordered_cat_vars)
-                )
-            }
-        }
-    } else if (is.data.frame(input_data)) {
-        input_df <- input_data
-    } else {
-        stop("input_data must be either a matrix or a data frame")
-    }
-    df_vars <- names(input_df)
-    if (is.null(ordered_cat_vars)) {
-        ordered_cat_matches <- rep(FALSE, length(df_vars))
-    } else {
-        ordered_cat_matches <- df_vars %in% ordered_cat_vars
-    }
-    if (is.null(unordered_cat_vars)) {
-        unordered_cat_matches <- rep(FALSE, length(df_vars))
-    } else {
-        unordered_cat_matches <- df_vars %in% unordered_cat_vars
-    }
-    numeric_matches <- ((!ordered_cat_matches) & (!unordered_cat_matches))
-    ordered_cat_vars <- df_vars[ordered_cat_matches]
-    unordered_cat_vars <- df_vars[unordered_cat_matches]
-    numeric_vars <- df_vars[numeric_matches]
-    num_ordered_cat_vars <- length(ordered_cat_vars)
-    num_unordered_cat_vars <- length(unordered_cat_vars)
-    num_numeric_vars <- length(numeric_vars)
-    if (num_ordered_cat_vars > 0) {
-        ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
-    }
-    if (num_unordered_cat_vars > 0) {
-        unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
-    }
-    if (num_numeric_vars > 0) {
-        numeric_df <- input_df[, numeric_vars, drop = FALSE]
-    }
-
-    # Empty outputs
-    X <- double(0)
-    unordered_unique_levels <- list()
-    ordered_unique_levels <- list()
-    feature_types <- integer(0)
-
-    # First, extract the numeric covariates
-    if (num_numeric_vars > 0) {
-        Xnum <- double(0)
-        for (i in 1:ncol(numeric_df)) {
-            stopifnot(is.numeric(numeric_df[, i]))
-            Xnum <- cbind(Xnum, numeric_df[, i])
-        }
-        X <- cbind(X, unname(Xnum))
-        feature_types <- c(feature_types, rep(0, ncol(Xnum)))
-    }
-
-    # Next, run some simple preprocessing on the ordered categorical covariates
-    if (num_ordered_cat_vars > 0) {
-        Xordcat <- double(0)
-        for (i in 1:ncol(ordered_cat_df)) {
-            var_name <- names(ordered_cat_df)[i]
-            preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[,
-                i
-            ])
-            ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels
-            Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed)
-        }
-        X <- cbind(X, unname(Xordcat))
-        feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
-    }
-
-    # Finally, one-hot encode the unordered categorical covariates
-    if (num_unordered_cat_vars > 0) {
-        one_hot_mats <- list()
-        for (i in 1:ncol(unordered_cat_df)) {
-            var_name <- names(unordered_cat_df)[i]
-            encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i])
-            unordered_unique_levels[[var_name]] <- encode_list$unique_levels
-            one_hot_mats[[var_name]] <- encode_list$Xtilde
-        }
-        Xcat <- do.call(cbind, one_hot_mats)
-        X <- cbind(X, unname(Xcat))
-        feature_types <- c(feature_types, rep(1, ncol(Xcat)))
-    }
-
-    # Aggregate results into a list
-    metadata <- list(
-        feature_types = feature_types,
-        num_ordered_cat_vars = num_ordered_cat_vars,
-        num_unordered_cat_vars = num_unordered_cat_vars,
-        num_numeric_vars = num_numeric_vars
-    )
-    if (num_ordered_cat_vars > 0) {
-        metadata[["ordered_cat_vars"]] = ordered_cat_vars
-        metadata[["ordered_unique_levels"]] = ordered_unique_levels
-    }
-    if (num_unordered_cat_vars > 0) {
-        metadata[["unordered_cat_vars"]] = unordered_cat_vars
-        metadata[["unordered_unique_levels"]] = unordered_unique_levels
-    }
-    if (num_numeric_vars > 0) {
-        metadata[["numeric_vars"]] = numeric_vars
-    }
-    output <- list(
-        data = X,
-        metadata = metadata
-    )
-
-    return(output)
+  if (is.matrix(input_data)) {
+    input_df <- as.data.frame(input_data)
+    names(input_df) <- paste0("x", 1:ncol(input_data))
+    if (!is.null(ordered_cat_vars)) {
+      if (is.numeric(ordered_cat_vars)) {
+        ordered_cat_vars <- paste0("x", as.integer(ordered_cat_vars))
+      }
+    }
+    if (!is.null(unordered_cat_vars)) {
+      if (is.numeric(unordered_cat_vars)) {
+        unordered_cat_vars <- paste0(
+          "x",
+          as.integer(unordered_cat_vars)
+        )
+      }
+    }
+  } else if (is.data.frame(input_data)) {
+    input_df <- input_data
+  } else {
+    stop("input_data must be either a matrix or a data frame")
+  }
+  df_vars <- names(input_df)
+  if (is.null(ordered_cat_vars)) {
+    ordered_cat_matches <- rep(FALSE, length(df_vars))
+  } else {
+    ordered_cat_matches <- df_vars %in% ordered_cat_vars
+  }
+  if (is.null(unordered_cat_vars)) {
+    unordered_cat_matches <- rep(FALSE, length(df_vars))
+  } else {
+    unordered_cat_matches <- df_vars %in% unordered_cat_vars
+  }
+  numeric_matches <- ((!ordered_cat_matches) & (!unordered_cat_matches))
+  ordered_cat_vars <- df_vars[ordered_cat_matches]
+  unordered_cat_vars <- df_vars[unordered_cat_matches]
+  numeric_vars <- df_vars[numeric_matches]
+  num_ordered_cat_vars <- length(ordered_cat_vars)
+  num_unordered_cat_vars <- length(unordered_cat_vars)
+  num_numeric_vars <- length(numeric_vars)
+  if (num_ordered_cat_vars > 0) {
+    ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
+  }
+  if (num_unordered_cat_vars > 0) {
+    unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
+  }
+  if (num_numeric_vars > 0) {
+    numeric_df <- input_df[, numeric_vars, drop = FALSE]
+  }
+
+  # Empty outputs
+  X <- double(0)
+  unordered_unique_levels <- list()
+  ordered_unique_levels <- list()
+  feature_types <- integer(0)
+
+  # First, extract the numeric covariates
+  if (num_numeric_vars > 0) {
+    Xnum <- double(0)
+    for (i in 1:ncol(numeric_df)) {
+      stopifnot(is.numeric(numeric_df[, i]))
+      Xnum <- cbind(Xnum, numeric_df[, i])
+    }
+    X <- cbind(X, unname(Xnum))
+    feature_types <- c(feature_types, rep(0, ncol(Xnum)))
+  }
+
+  # Next, run some simple preprocessing on the ordered categorical covariates
+  if (num_ordered_cat_vars > 0) {
+    Xordcat <- double(0)
+    for (i in 1:ncol(ordered_cat_df)) {
+      var_name <- names(ordered_cat_df)[i]
+      preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[,
+        i
+      ])
+      ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels
+      Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed)
+    }
+    X <- cbind(X, unname(Xordcat))
+    feature_types <- c(feature_types, rep(1, ncol(Xordcat)))
+  }
+
+  # Finally, one-hot encode the unordered categorical covariates
+  if (num_unordered_cat_vars > 0) {
+    one_hot_mats <- list()
+    for (i in 1:ncol(unordered_cat_df)) {
+      var_name <- names(unordered_cat_df)[i]
+      encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i])
+      unordered_unique_levels[[var_name]] <- encode_list$unique_levels
+      one_hot_mats[[var_name]] <- encode_list$Xtilde
+    }
+    Xcat <- do.call(cbind, one_hot_mats)
+    X <- cbind(X, unname(Xcat))
+    feature_types <- c(feature_types, rep(1, ncol(Xcat)))
+  }
+
+  # Aggregate results into a list
+  metadata <- list(
+    feature_types = feature_types,
+    num_ordered_cat_vars = num_ordered_cat_vars,
+    num_unordered_cat_vars = num_unordered_cat_vars,
+    num_numeric_vars = num_numeric_vars
+  )
+  if (num_ordered_cat_vars > 0) {
+    metadata[["ordered_cat_vars"]] = ordered_cat_vars
+    metadata[["ordered_unique_levels"]] = ordered_unique_levels
+  }
+  if (num_unordered_cat_vars > 0) {
+    metadata[["unordered_cat_vars"]] = unordered_cat_vars
+    metadata[["unordered_unique_levels"]] = unordered_unique_levels
+  }
+  if (num_numeric_vars > 0) {
+    metadata[["numeric_vars"]] = numeric_vars
+  }
+  output <- list(
+    data = X,
+    metadata = metadata
+  )
+
+  return(output)
 }
 
 #' Preprocess a dataframe of covariate values, converting categorical variables
@@ -722,75 +722,75 @@ createForestCovariates <- function(
 #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3"))
 #' X_preprocessed <- createForestCovariatesFromMetadata(cov_df, metadata)
 createForestCovariatesFromMetadata <- function(input_data, metadata) {
-    if (is.matrix(input_data)) {
-        input_df <- as.data.frame(input_data)
-        names(input_df) <- paste0("x", 1:ncol(input_data))
-    } else if (is.data.frame(input_data)) {
-        input_df <- input_data
-    } else {
-        stop("input_data must be either a matrix or a data frame")
-    }
-    df_vars <- names(input_df)
-    num_ordered_cat_vars <- metadata$num_ordered_cat_vars
-    num_unordered_cat_vars <- metadata$num_unordered_cat_vars
-    num_numeric_vars <- metadata$num_numeric_vars
-
-    if (num_ordered_cat_vars > 0) {
-        ordered_cat_vars <- metadata$ordered_cat_vars
-        ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
-    }
-    if (num_unordered_cat_vars > 0) {
-        unordered_cat_vars <- metadata$unordered_cat_vars
-        unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
-    }
-    if (num_numeric_vars > 0) {
-        numeric_vars <- metadata$numeric_vars
-        numeric_df <- input_df[, numeric_vars, drop = FALSE]
-    }
-
-    # Empty outputs
-    X <- double(0)
-
-    # First, extract the numeric covariates
-    if (num_numeric_vars > 0) {
-        Xnum <- double(0)
-        for (i in 1:ncol(numeric_df)) {
-            stopifnot(is.numeric(numeric_df[, i]))
-            Xnum <- cbind(Xnum, numeric_df[, i])
-        }
-        X <- cbind(X, unname(Xnum))
-    }
-
-    # Next, run some simple preprocessing on the ordered categorical covariates
-    if (num_ordered_cat_vars > 0) {
-        Xordcat <- double(0)
-        for (i in 1:ncol(ordered_cat_df)) {
-            var_name <- names(ordered_cat_df)[i]
-            x_preprocessed <- orderedCatPreprocess(
-                ordered_cat_df[, i],
-                metadata$ordered_unique_levels[[var_name]]
-            )
-            Xordcat <- cbind(Xordcat, x_preprocessed)
-        }
-        X <- cbind(X, unname(Xordcat))
-    }
-
-    # Finally, one-hot encode the unordered categorical covariates
-    if (num_unordered_cat_vars > 0) {
-        one_hot_mats <- list()
-        for (i in 1:ncol(unordered_cat_df)) {
-            var_name <- names(unordered_cat_df)[i]
-            Xtilde <- oneHotEncode(
-                unordered_cat_df[, i],
-                metadata$unordered_unique_levels[[var_name]]
-            )
-            one_hot_mats[[var_name]] <- Xtilde
-        }
-        Xcat <- do.call(cbind, one_hot_mats)
-        X <- cbind(X, unname(Xcat))
-    }
-
-    return(X)
+  if (is.matrix(input_data)) {
+    input_df <- as.data.frame(input_data)
+    names(input_df) <- paste0("x", 1:ncol(input_data))
+  } else if (is.data.frame(input_data)) {
+    input_df <- input_data
+  } else {
+    stop("input_data must be either a matrix or a data frame")
+  }
+  df_vars <- names(input_df)
+  num_ordered_cat_vars <- metadata$num_ordered_cat_vars
+  num_unordered_cat_vars <- metadata$num_unordered_cat_vars
+  num_numeric_vars <- metadata$num_numeric_vars
+
+  if (num_ordered_cat_vars > 0) {
+    ordered_cat_vars <- metadata$ordered_cat_vars
+    ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE]
+  }
+  if (num_unordered_cat_vars > 0) {
+    unordered_cat_vars <- metadata$unordered_cat_vars
+    unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE]
+  }
+  if (num_numeric_vars > 0) {
+    numeric_vars <- metadata$numeric_vars
+    numeric_df <- input_df[, numeric_vars, drop = FALSE]
+  }
+
+  # Empty outputs
+  X <- double(0)
+
+  # First, extract the numeric covariates
+  if (num_numeric_vars > 0) {
+    Xnum <- double(0)
+    for (i in 1:ncol(numeric_df)) {
+      stopifnot(is.numeric(numeric_df[, i]))
+      Xnum <- cbind(Xnum, numeric_df[, i])
+    }
+    X <- cbind(X, unname(Xnum))
+  }
+
+  # Next, run some simple preprocessing on the ordered categorical covariates
+  if (num_ordered_cat_vars > 0) {
+    Xordcat <- double(0)
+    for (i in 1:ncol(ordered_cat_df)) {
+      var_name <- names(ordered_cat_df)[i]
+      x_preprocessed <- orderedCatPreprocess(
+        ordered_cat_df[, i],
+        metadata$ordered_unique_levels[[var_name]]
+      )
+      Xordcat <- cbind(Xordcat, x_preprocessed)
+    }
+    X <- cbind(X, unname(Xordcat))
+  }
+
+  # Finally, one-hot encode the unordered categorical covariates
+  if (num_unordered_cat_vars > 0) {
+    one_hot_mats <- list()
+    for (i in 1:ncol(unordered_cat_df)) {
+      var_name <- names(unordered_cat_df)[i]
+      Xtilde <- oneHotEncode(
+        unordered_cat_df[, i],
+        metadata$unordered_unique_levels[[var_name]]
+      )
+      one_hot_mats[[var_name]] <- Xtilde
+    }
+    Xcat <- do.call(cbind, one_hot_mats)
+    X <- cbind(X, unname(Xcat))
+  }
+
+  return(X)
 }
 
 #' Convert a vector of unordered categorical data (either numeric or character
@@ -813,15 +813,15 @@ createForestCovariatesFromMetadata <- function(input_data, metadata) {
 #' x <- c("a","c","b","c","d","a","c","a","b","d")
 #' x_onehot <- oneHotInitializeAndEncode(x)
 oneHotInitializeAndEncode <- function(x_input) {
-    stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
-    if (is.factor(x_input) && is.ordered(x_input)) {
-        warning("One-hot encoding an ordered categorical variable")
-    }
-    x_factor <- factor(x_input)
-    unique_levels <- levels(x_factor)
-    Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0)
-    output <- list(Xtilde = Xtilde, unique_levels = unique_levels)
-    return(output)
+  stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
+  if (is.factor(x_input) && is.ordered(x_input)) {
+    warning("One-hot encoding an ordered categorical variable")
+  }
+  x_factor <- factor(x_input)
+  unique_levels <- levels(x_factor)
+  Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0)
+  output <- list(Xtilde = Xtilde, unique_levels = unique_levels)
+  return(output)
 }
 
 #' Convert a vector of unordered categorical data (either numeric or character
@@ -848,34 +848,34 @@ oneHotInitializeAndEncode <- function(x_input) {
 #' x_test <- sample(1:9, 10, TRUE)
 #' x_onehot <- oneHotEncode(x_test, levels(factor(x)))
 oneHotEncode <- function(x_input, unique_levels) {
-    stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
-    stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0))
-    num_unique_levels <- length(unique_levels)
-    in_sample <- x_input %in% unique_levels
-    out_of_sample <- !(x_input %in% unique_levels)
-    has_out_of_sample <- sum(out_of_sample) > 0
-    if (has_out_of_sample) {
-        x_factor_insample <- factor(x_input[in_sample], levels = unique_levels)
-        Xtilde <- matrix(
-            0,
-            nrow = length(x_input),
-            ncol = num_unique_levels + 1
-        )
-        Xtilde_insample <- cbind(
-            unname(model.matrix(~ 0 + x_factor_insample)),
-            0
-        )
-        Xtilde_out_of_sample <- cbind(
-            matrix(0, nrow = sum(out_of_sample), ncol = num_unique_levels),
-            1
-        )
-        Xtilde[in_sample, ] <- Xtilde_insample
-        Xtilde[out_of_sample, ] <- Xtilde_out_of_sample
-    } else {
-        x_factor <- factor(x_input, levels = unique_levels)
-        Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0)
-    }
-    return(Xtilde)
+  stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
+  stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0))
+  num_unique_levels <- length(unique_levels)
+  in_sample <- x_input %in% unique_levels
+  out_of_sample <- !(x_input %in% unique_levels)
+  has_out_of_sample <- sum(out_of_sample) > 0
+  if (has_out_of_sample) {
+    x_factor_insample <- factor(x_input[in_sample], levels = unique_levels)
+    Xtilde <- matrix(
+      0,
+      nrow = length(x_input),
+      ncol = num_unique_levels + 1
+    )
+    Xtilde_insample <- cbind(
+      unname(model.matrix(~ 0 + x_factor_insample)),
+      0
+    )
+    Xtilde_out_of_sample <- cbind(
+      matrix(0, nrow = sum(out_of_sample), ncol = num_unique_levels),
+      1
+    )
+    Xtilde[in_sample, ] <- Xtilde_insample
+    Xtilde[out_of_sample, ] <- Xtilde_out_of_sample
+  } else {
+    x_factor <- factor(x_input, levels = unique_levels)
+    Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0)
+  }
+  return(Xtilde)
 }
 
 #' Run some simple preprocessing of ordered categorical variables, converting
@@ -897,17 +897,17 @@ oneHotEncode <- function(x_input, unique_levels) {
 #' preprocess_list <- orderedCatInitializeAndPreprocess(x)
 #' x_preprocessed <- preprocess_list$x_preprocessed
 orderedCatInitializeAndPreprocess <- function(x_input) {
-    stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
-    already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input))
-    if (already_ordered_factor) {
-        x_preprocessed <- as.integer(x_input)
-        unique_levels <- levels(x_input)
-    } else {
-        x_factor <- factor(x_input, ordered = TRUE)
-        x_preprocessed <- as.integer(x_factor)
-        unique_levels <- levels(x_factor)
-    }
-    return(list(x_preprocessed = x_preprocessed, unique_levels = unique_levels))
+  stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
+  already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input))
+  if (already_ordered_factor) {
+    x_preprocessed <- as.integer(x_input)
+    unique_levels <- levels(x_input)
+  } else {
+    x_factor <- factor(x_input, ordered = TRUE)
+    x_preprocessed <- as.integer(x_factor)
+    unique_levels <- levels(x_factor)
+  }
+  return(list(x_preprocessed = x_preprocessed, unique_levels = unique_levels))
 }
 
 #' Run some simple preprocessing of ordered categorical variables, converting
@@ -933,56 +933,56 @@ orderedCatInitializeAndPreprocess <- function(x_input) {
 #' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree")
 #' x_processed <- orderedCatPreprocess(x, x_levels)
 orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) {
-    stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
-    stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0))
-    already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input))
-    if (already_ordered_factor) {
-        # Run time checks
-        levels_not_in_reflist <- !(levels(x_input) %in% unique_levels)
-        if (sum(levels_not_in_reflist) > 0) {
-            if (!is.null(var_name)) {
-                warning_message <- paste0(
-                    "Variable ",
-                    var_name,
-                    " includes ordered categorical levels not included in the original training set"
-                )
-            } else {
-                warning_message <- paste0(
-                    "Variable includes ordered categorical levels not included in the original training set"
-                )
-            }
-            warning(warning_message)
-        }
-        # Preprocessing
-        x_string <- as.character(x_input)
-        x_factor <- factor(x_string, unique_levels, ordered = TRUE)
-        x_preprocessed <- as.integer(x_factor)
-        x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1
-    } else {
-        x_factor <- factor(x_input, ordered = TRUE)
-        # Run time checks
-        levels_not_in_reflist <- !(levels(x_factor) %in% unique_levels)
-        if (sum(levels_not_in_reflist) > 0) {
-            if (!is.null(var_name)) {
-                warning_message <- paste0(
-                    "Variable ",
-                    var_name,
-                    " includes ordered categorical levels not included in the original training set"
-                )
-            } else {
-                warning_message <- paste0(
-                    "Variable includes ordered categorical levels not included in the original training set"
-                )
-            }
-            warning(warning_message)
-        }
-        # Preprocessing
-        x_string <- as.character(x_input)
-        x_factor <- factor(x_string, unique_levels, ordered = TRUE)
-        x_preprocessed <- as.integer(x_factor)
-        x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1
-    }
-    return(x_preprocessed)
+  stopifnot((is.null(dim(x_input)) && length(x_input) > 0))
+  stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0))
+  already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input))
+  if (already_ordered_factor) {
+    # Run time checks
+    levels_not_in_reflist <- !(levels(x_input) %in% unique_levels)
+    if (sum(levels_not_in_reflist) > 0) {
+      if (!is.null(var_name)) {
+        warning_message <- paste0(
+          "Variable ",
+          var_name,
+          " includes ordered categorical levels not included in the original training set"
+        )
+      } else {
+        warning_message <- paste0(
+          "Variable includes ordered categorical levels not included in the original training set"
+        )
+      }
+      warning(warning_message)
+    }
+    # Preprocessing
+    x_string <- as.character(x_input)
+    x_factor <- factor(x_string, unique_levels, ordered = TRUE)
+    x_preprocessed <- as.integer(x_factor)
+    x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1
+  } else {
+    x_factor <- factor(x_input, ordered = TRUE)
+    # Run time checks
+    levels_not_in_reflist <- !(levels(x_factor) %in% unique_levels)
+    if (sum(levels_not_in_reflist) > 0) {
+      if (!is.null(var_name)) {
+        warning_message <- paste0(
+          "Variable ",
+          var_name,
+          " includes ordered categorical levels not included in the original training set"
+        )
+      } else {
+        warning_message <- paste0(
+          "Variable includes ordered categorical levels not included in the original training set"
+        )
+      }
+      warning(warning_message)
+    }
+    # Preprocessing
+    x_string <- as.character(x_input)
+    x_factor <- factor(x_string, unique_levels, ordered = TRUE)
+    x_preprocessed <- as.integer(x_factor)
+    x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1
+  }
+  return(x_preprocessed)
 }
 
 #' Convert scalar input to vector of dimension `output_size`,
@@ -993,19 +993,19 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) {
 #' @return A vector of length `output_size`
 #' @export
 expand_dims_1d <- function(input, output_size) {
-    if (length(input) == 1) {
-        output <- rep(input, output_size)
-    } else if (is.numeric(input)) {
-        if (length(input) != output_size) {
-            stop("`input` must be a 1D numpy array with `output_size` elements")
-        }
-        output <- input
-    } else {
-        stop(
-            "`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times"
-        )
-    }
-    return(output)
+  if (length(input) == 1) {
+    output <- rep(input, output_size)
+  } else if (is.numeric(input)) {
+    if (length(input) != output_size) {
+      stop("`input` must be a numeric vector with `output_size` elements")
+    }
+    output <- input
+  } else {
+    stop(
+      "`input` must be either a numeric vector or a scalar that can be repeated `output_size` times"
+    )
+  }
+  return(output)
 }
 
 #' Ensures that input is propagated appropriately to a matrix of dimension `output_rows` x `output_cols`.
@@ -1022,41 +1022,41 @@ expand_dims_1d <- function(input, output_size) {
 #' @return A matrix of dimension `output_rows` x `output_cols`
 #' @export
 expand_dims_2d <- function(input, output_rows, output_cols) {
-    if (length(input) == 1) {
-        output <- matrix(
-            rep(input, output_rows * output_cols),
-            ncol = output_cols
-        )
-    } else if (is.numeric(input)) {
-        if (length(input) == output_cols) {
-            output <- matrix(
-                rep(input, output_rows),
-                nrow = output_rows,
-                byrow = T
-            )
-        } else if (length(input) == output_rows) {
-            output <- matrix(
-                rep(input, output_cols),
-                ncol = output_cols,
-                byrow = F
-            )
-        } else {
-            stop(
-                "If `input` is a vector, it must either contain `output_rows` or `output_cols` elements"
-            )
-        }
-    } else if (is.matrix(input)) {
-        if (nrow(input) != output_rows) {
-            stop("`input` must be a matrix with `output_rows` rows")
-        }
-        if (ncol(input) != output_cols) {
-            stop("`input` must be a matrix with `output_cols` columns")
-        }
-        output <- input
-    } else {
-        stop("`input` must be either a matrix, vector or a scalar")
-    }
-    return(output)
+  if (length(input) == 1) {
+    output <- matrix(
+      rep(input, output_rows * output_cols),
+      ncol = output_cols
+    )
+  } else if (is.numeric(input)) {
+    if (length(input) == output_cols) {
+      output <- matrix(
+        rep(input, output_rows),
+        nrow = output_rows,
+        byrow = TRUE
+      )
+    } else if (length(input) == output_rows) {
+      output <- matrix(
+        rep(input, output_cols),
+        ncol = output_cols,
+        byrow = FALSE
+      )
+    } else {
+      stop(
+        "If `input` is a vector, it must either contain `output_rows` or `output_cols` elements"
+      )
+    }
+  } else if (is.matrix(input)) {
+    if (nrow(input) != output_rows) {
+      stop("`input` must be a matrix with `output_rows` rows")
+    }
+    if (ncol(input) != output_cols) {
+      stop("`input` must be a matrix with `output_cols` columns")
+    }
+    output <- input
+  } else {
+    stop("`input` must be either a matrix, a vector, or a scalar")
+  }
+  return(output)
 }
 
 #' Convert scalar input to square matrix of dimension `output_size` x `output_size` with `input` along the diagonal,
@@ -1067,20 +1067,20 @@ expand_dims_2d <- function(input, output_rows, output_cols) {
 #' @return A square matrix of dimension `output_size` x `output_size`
 #' @export
 expand_dims_2d_diag <- function(input, output_size) {
-    if (length(input) == 1) {
-        output <- as.matrix(diag(input, output_size))
-    } else if (is.matrix(input)) {
-        if (nrow(input) != ncol(input)) {
-            stop("`input` must be a square matrix")
-        }
-        if (nrow(input) != output_size) {
-            stop(
-                "`input` must be a square matrix with `output_size` rows and columns"
-            )
-        }
-        output <- input
-    } else {
-        stop("`input` must be either a square matrix or a scalar")
-    }
-    return(output)
+  if (length(input) == 1) {
+    output <- as.matrix(diag(input, output_size))
+  } else if (is.matrix(input)) {
+    if (nrow(input) != ncol(input)) {
+      stop("`input` must be a square matrix")
+    }
+    if (nrow(input) != output_size) {
+      stop(
+        "`input` must be a square matrix with `output_size` rows and columns"
+      )
+    }
+    output <- input
+  } else {
+    stop("`input` must be either a square matrix or a scalar")
+  }
+  return(output)
 }
diff --git a/R/variance.R b/R/variance.R
index b0ad722e..11c6c326 100644
--- a/R/variance.R
+++ b/R/variance.R
@@ -19,19 +19,19 @@
 #' b <- 1.0
 #' sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset, rng, a, b)
 sampleGlobalErrorVarianceOneIteration <- function(
-    residual,
-    dataset,
-    rng,
-    a,
-    b
-) {
-    return(sample_sigma2_one_iteration_cpp(
-        residual$data_ptr,
-        dataset$data_ptr,
-        rng$rng_ptr,
-        a,
-        b
-    ))
+  residual,
+  dataset,
+  rng,
+  a,
+  b
+) {
+  return(sample_sigma2_one_iteration_cpp(
+    residual$data_ptr,
+    dataset$data_ptr,
+    rng$rng_ptr,
+    a,
+    b
+  ))
 }
 
 #' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!)
@@ -54,5 +54,5 @@ sampleGlobalErrorVarianceOneIteration <- function(
 #' b <- 1.0
 #' tau <- sampleLeafVarianceOneIteration(active_forest, rng, a, b)
 sampleLeafVarianceOneIteration <- function(forest, rng, a, b) {
-    return(sample_tau_one_iteration_cpp(forest$forest_ptr, rng$rng_ptr, a, b))
+  return(sample_tau_one_iteration_cpp(forest$forest_ptr, rng$rng_ptr, a, b))
 }
diff --git a/demo/debug/bart_contrast_debug.py b/demo/debug/bart_contrast_debug.py
new file mode 100644
index 00000000..15ce5705
--- /dev/null
+++ b/demo/debug/bart_contrast_debug.py
@@ -0,0 +1,181 @@
+# Demo of contrast computation function for BART
+
+# Load libraries
+from stochtree import BARTModel
+from sklearn.model_selection import train_test_split
+import numpy as np
+
+# Generate data
+n = 500
+p = 5
+rng = np.random.default_rng(1234)
+X = rng.uniform(low=0.0, high=1.0, size=(n, p))
+W = rng.normal(loc=0.0, scale=1.0, size=(n, 1))
+f_XW = np.where(
+    ((0 <= X[:, 0]) & (X[:, 0] < 0.25)),
+    -7.5 * W[:, 0],
+    np.where(
+        ((0.25 <= X[:, 0]) & (X[:, 0] < 0.5)),
+        -2.5 * W[:, 0],
+        np.where(
+            ((0.5 <= X[:, 0]) & (X[:, 0] < 0.75)),
+            2.5 * W[:, 0],
+            7.5 * W[:, 0],
+        ),
+    ),
+)
+E_Y = f_XW
+snr = 2
+y = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_Y) / snr)
+
+# Train-test split
+test_set_pct = 0.2
+train_inds, test_inds = train_test_split(
+    np.arange(n), test_size=test_set_pct, random_state=1234
+)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+W_train = W[train_inds, :]
+W_test = W[test_inds, :]
+y_train = y[train_inds]
+y_test = y[test_inds]
+n_test = len(test_inds)
+n_train = len(train_inds)
+
+# Fit BART model
+bart_model = BARTModel()
+bart_model.sample(
+    X_train=X_train,
+    leaf_basis_train=W_train,
+    y_train=y_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Compute contrast posterior
+contrast_posterior_test = bart_model.compute_contrast(
+    covariates_0=X_test,
+    covariates_1=X_test,
+    basis_0=np.zeros((n_test, 1)),
+    basis_1=np.ones((n_test, 1)),
+    type="posterior",
+    scale="linear",
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 = bart_model.predict(
+    covariates=X_test,
+    basis=np.zeros((n_test, 1)),
+    type="posterior",
+    terms="y_hat",
+    scale="linear",
+)
+y_hat_posterior_test_1 = bart_model.predict(
+    covariates=X_test,
+    basis=np.ones((n_test, 1)),
+    type="posterior",
+    terms="y_hat",
+    scale="linear",
+)
+contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
+
+# Compare results
+contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
+np.allclose(contrast_diff, 0, atol=0.001)
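+
+# (Editor's sketch, not part of the original demo) The contrast posterior can
+# also be summarized directly; the (n_obs, n_samples) orientation of the
+# contrast matrix is an assumption here, not a documented guarantee.
+contrast_mean = np.mean(contrast_posterior_test, axis=1)
+contrast_bounds = np.percentile(contrast_posterior_test, [2.5, 97.5], axis=1)
+print(f"Average contrast (posterior mean): {np.mean(contrast_mean):.3f}")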
+
+# Generate data for a BART model with random effects
+X = rng.uniform(low=0.0, high=1.0, size=(n, p))
+W = rng.normal(loc=0.0, scale=1.0, size=(n, 1))
+f_XW = np.where(
+    ((0 <= X[:, 0]) & (X[:, 0] < 0.25)),
+    -7.5 * W[:, 0],
+    np.where(
+        ((0.25 <= X[:, 0]) & (X[:, 0] < 0.5)),
+        -2.5 * W[:, 0],
+        np.where(
+            ((0.5 <= X[:, 0]) & (X[:, 0] < 0.75)),
+            2.5 * W[:, 0],
+            7.5 * W[:, 0],
+        ),
+    ),
+)
+num_rfx_groups = 3
+group_labels = rng.choice(num_rfx_groups, size=n)
+basis = np.empty((n, 2))
+basis[:, 0] = 1.0
+basis[:, 1] = rng.uniform(0, 1, (n,))
+rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]])
+rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1)
+E_Y = f_XW + rfx_term
+snr = 2
+y = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_Y) / snr)
+
+# Train-test split
+train_inds, test_inds = train_test_split(
+    np.arange(n), test_size=test_set_pct, random_state=1234
+)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+W_train = W[train_inds, :]
+W_test = W[test_inds, :]
+y_train = y[train_inds]
+y_test = y[test_inds]
+group_ids_train = group_labels[train_inds]
+group_ids_test = group_labels[test_inds]
+rfx_basis_train = basis[train_inds, :]
+rfx_basis_test = basis[test_inds, :]
+n_test = len(test_inds)
+n_train = len(train_inds)
+
+# Fit BART model
+bart_model = BARTModel()
+bart_model.sample(
+    X_train=X_train,
+    leaf_basis_train=W_train,
+    y_train=y_train,
+    rfx_group_ids_train=group_ids_train,
+    rfx_basis_train=rfx_basis_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Compute contrast posterior
+contrast_posterior_test = bart_model.compute_contrast(
+    covariates_0=X_test,
+    covariates_1=X_test,
+    basis_0=np.zeros((n_test, 1)),
+    basis_1=np.ones((n_test, 1)),
+    rfx_group_ids_0=group_ids_test,
+    rfx_group_ids_1=group_ids_test,
+    rfx_basis_0=rfx_basis_test,
+    rfx_basis_1=rfx_basis_test,
+    type="posterior",
+    scale="linear",
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 = bart_model.predict(
+    covariates=X_test,
+    basis=np.zeros((n_test, 1)),
+    rfx_group_ids=group_ids_test,
+    rfx_basis=rfx_basis_test,
+    type="posterior",
+    terms="y_hat",
+    scale="linear",
+)
+y_hat_posterior_test_1 = bart_model.predict(
+    covariates=X_test,
+    basis=np.ones((n_test, 1)),
+    rfx_group_ids=group_ids_test,
+    rfx_basis=rfx_basis_test,
+    type="posterior",
+    terms="y_hat",
+    scale="linear",
+)
+contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
+
+# Compare results
+contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
+np.allclose(contrast_diff, 0, atol=0.001)
diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py
new file mode 100644
index 00000000..d66b1110
--- /dev/null
+++ b/demo/debug/bart_predict_debug.py
@@ -0,0 +1,96 @@
+# Demo of updated predict method for BART
+
+# Load library
+from stochtree import BARTModel
+import numpy as np
+from sklearn.model_selection import train_test_split
+import matplotlib.pyplot as plt
+
+# Generate data
+rng = np.random.default_rng()
+n = 500
+p = 5
+X = rng.uniform(low=0.0, high=1.0, size=(n, p))
+f_X = np.where(
+    ((0 <= X[:, 0]) & (X[:, 0] <= 0.25)),
+    -7.5,
+    np.where(
+        ((0.25 <= X[:, 0]) & (X[:, 0] <= 0.5)),
+        -2.5,
+        np.where(((0.5 <= X[:, 0]) & (X[:, 0] <= 0.75)), 2.5, 7.5),
+    ),
+)
+noise_sd = 1.0
+y = f_X + rng.normal(loc=0.0, scale=noise_sd, size=(n,))
+
+# Train-test split
+sample_inds = np.arange(n)
+test_set_pct = 0.2
+train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+y_train = y[train_inds]
+y_test = y[test_inds]
+f_X_train = f_X[train_inds]
+f_X_test = f_X[test_inds]
+
+# Fit simple BART model
+bart_model = BARTModel()
+bart_model.sample(
+    X_train=X_train,
+    y_train=y_train,
+    X_test=X_test,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Check several predict approaches
+bart_preds = bart_model.predict(covariates=X_test)
+y_hat_posterior_test = bart_model.predict(covariates=X_test)["y_hat"]
+y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"])
+# Check that this raises a warning (this model has no rfx or variance terms)
+y_hat_test = bart_model.predict(
+    covariates=X_test, type="mean", terms=["rfx", "variance"]
+)
+
+# Plot predicted versus actual
+plt.scatter(y_hat_mean_test, y_test, color="black")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("Y hat")
+plt.show()
+
+# Compute posterior interval
+intervals = bart_model.compute_posterior_interval(
+    terms="all", scale="linear", level=0.95, covariates=X_test
+)
+
+# Check coverage
+mean_coverage = np.mean(
+    (intervals["y_hat"]["lower"] <= f_X_test)
+    & (f_X_test <= intervals["y_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}")
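+
+# (Editor's cross-check sketch, not part of the original demo) The packaged
+# interval above should roughly agree with manual quantiles of the posterior
+# draws retrieved earlier; the (n_obs, n_samples) orientation of
+# y_hat_posterior_test is an assumption here.
+manual_bounds = np.percentile(y_hat_posterior_test, [2.5, 97.5], axis=1)
+lower_gap = np.max(np.abs(manual_bounds[0, :] - intervals["y_hat"]["lower"]))
+print(f"Max gap between manual and packaged lower bounds: {lower_gap:.4f}")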
+
+# Sample from the posterior predictive distribution
+bart_ppd_samples = bart_model.sample_posterior_predictive(
+    covariates=X_test, num_draws_per_sample=10
+)
+
+# Plot PPD mean vs actual
+ppd_mean = np.mean(bart_ppd_samples, axis=(0, 2))
+plt.clf()
+plt.scatter(ppd_mean, y_test, color="blue")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("Posterior Predictive Mean Comparison")
+plt.show()
+
+# Check coverage of posterior predictive distribution
+ppd_intervals = np.percentile(bart_ppd_samples, [2.5, 97.5], axis=(0, 2))
+ppd_coverage = np.mean(
+    (ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :])
+)
+print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}")
diff --git a/demo/debug/bcf_contrast_debug.py b/demo/debug/bcf_contrast_debug.py
new file mode 100644
index 00000000..006780d7
--- /dev/null
+++ b/demo/debug/bcf_contrast_debug.py
@@ -0,0 +1,284 @@
+# Demo of contrast computation function for BCF
+
+# Load libraries
+from stochtree import BCFModel
+from sklearn.model_selection import train_test_split
+from scipy.stats import norm
+import numpy as np
+
+# Generate data
+n = 500
+p = 5
+rng = np.random.default_rng(1234)
+X = rng.uniform(low=0.0, high=1.0, size=(n, p))
+mu_x = X[:, 0]
+tau_x = 0.25 * X[:, 1]
+pi_x = norm.cdf(0.5 * X[:, 0])
+Z = rng.binomial(n=1, p=pi_x, size=(n,))
+E_XZ = mu_x + Z * tau_x
+snr = 2
+y = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_XZ) / snr)
+
+# Train-test split
+test_set_pct = 0.2
+train_inds, test_inds = train_test_split(
+    np.arange(n), test_size=test_set_pct, random_state=1234
+)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+Z_train = Z[train_inds]
+Z_test = Z[test_inds]
+y_train = y[train_inds]
+y_test = y[test_inds]
+n_test = len(test_inds)
+n_train = len(train_inds)
+
+# Fit BCF model
+bcf_model = BCFModel()
+bcf_model.sample(
+    X_train=X_train,
+    Z_train=Z_train,
+    y_train=y_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Compute contrast posterior
+contrast_posterior_test = bcf_model.compute_contrast(
+    X_0=X_test,
+    X_1=X_test,
+    Z_0=np.zeros((n_test, 1)),
+    Z_1=np.ones((n_test, 1)),
+    type="posterior",
+    scale="linear",
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 = bcf_model.predict(
+    X=X_test, Z=np.zeros((n_test, 1)), type="posterior", terms="y_hat", scale="linear"
+)
+y_hat_posterior_test_1 = bcf_model.predict(
+    X=X_test, Z=np.ones((n_test, 1)), type="posterior", terms="y_hat", scale="linear"
+)
+contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
+
+# Compare results
+contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
+np.allclose(contrast_diff, 0, atol=0.001)
+
+# Generate data for a BCF model with random effects
+X = rng.uniform(low=0.0, high=1.0, size=(n, p))
+mu_x = X[:, 0]
+tau_x = 0.25 * X[:, 1]
+pi_x = norm.cdf(0.5 * X[:, 0])
+Z = rng.binomial(n=1, p=pi_x, size=(n,))
+num_rfx_groups = 3
+group_labels = rng.choice(num_rfx_groups, size=n)
+basis = np.empty((n, 2))
+basis[:, 0] = 1.0
+basis[:, 1] = rng.uniform(0, 1, (n,))
+rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]])
+rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1)
+E_XZ = mu_x + Z * tau_x + rfx_term
+snr = 2
+y = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_XZ) / snr)
+
+# Train-test split
+train_inds, test_inds = train_test_split(
+    np.arange(n), test_size=test_set_pct, random_state=1234
+)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+Z_train = Z[train_inds]
+Z_test = Z[test_inds]
+y_train = y[train_inds]
+y_test = y[test_inds]
+group_ids_train = group_labels[train_inds]
+group_ids_test = group_labels[test_inds]
+rfx_basis_train = basis[train_inds, :]
+rfx_basis_test = basis[test_inds, :]
+n_test = len(test_inds)
+n_train = len(train_inds)
+
+# Fit BCF model
+bcf_model = BCFModel()
+bcf_model.sample(
+    X_train=X_train,
+    Z_train=Z_train,
+    y_train=y_train,
+    rfx_group_ids_train=group_ids_train,
+    rfx_basis_train=rfx_basis_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Compute contrast posterior
+contrast_posterior_test = bcf_model.compute_contrast(
+    X_0=X_test,
+    X_1=X_test,
+    Z_0=np.zeros((n_test, 1)),
+    Z_1=np.ones((n_test, 1)),
+    rfx_group_ids_0=group_ids_test,
+    rfx_group_ids_1=group_ids_test,
+    rfx_basis_0=rfx_basis_test,
+    rfx_basis_1=rfx_basis_test,
+    type="posterior",
+    scale="linear",
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 = bcf_model.predict(
+    X=X_test,
+    Z=np.zeros((n_test, 1)),
+    rfx_group_ids=group_ids_test,
+    rfx_basis=rfx_basis_test,
+    type="posterior",
+    terms="y_hat",
+    scale="linear",
+)
+y_hat_posterior_test_1 = bcf_model.predict(
+    X=X_test,
+    Z=np.ones((n_test, 1)),
+    rfx_group_ids=group_ids_test,
+    rfx_basis=rfx_basis_test,
+    type="posterior",
+    terms="y_hat",
+    scale="linear",
+)
+contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
+
+# Compare results
+contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
+np.allclose(contrast_diff, 0, atol=0.001)
+
+# Now repeat the same process, but via the random effects model spec
+bcf_model = BCFModel()
+bcf_model.sample(
+    X_train=X_train,
+    Z_train=Z_train,
+    y_train=y_train,
+    rfx_group_ids_train=group_ids_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+    random_effects_params={"model_spec": "intercept_plus_treatment"},
+)
+
+# Compute CATE posterior
+tau_hat_posterior_test = bcf_model.compute_contrast(
+    X_0=X_test,
+    X_1=X_test,
+    Z_0=np.zeros((n_test, 1)),
+    Z_1=np.ones((n_test, 1)),
+    rfx_group_ids_0=group_ids_test,
+    rfx_group_ids_1=group_ids_test,
+    rfx_basis_0=np.concatenate((np.ones((n_test, 1)), np.zeros((n_test, 1))), axis=1),
+    rfx_basis_1=np.ones((n_test, 2)),
+    type="posterior",
+    scale="linear",
+)
+
+# Compute the same quantity via predict
+tau_hat_posterior_test_comparison = bcf_model.predict(
+    X=X_test,
+    Z=Z_test,
+    rfx_group_ids=group_ids_test,
+    type="posterior",
+    terms="cate",
+    scale="linear",
+)
+
+# Compare results
+contrast_diff = tau_hat_posterior_test_comparison - tau_hat_posterior_test
+np.allclose(contrast_diff, 0, atol=0.001)
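+
+# (Editor's note, an interpretive gloss) With model_spec="intercept_plus_treatment",
+# the two rfx bases above differ only in the treatment column ([1, 0] vs [1, 1]),
+# so the contrast appears to add the group-level random treatment-effect component
+# to the forest CATE, which is why it should line up with terms="cate" above.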
+
+# Generate data for a probit BCF model with random effects
+X = rng.uniform(low=0.0, high=1.0, size=(n, p))
+mu_x = X[:, 0]
+tau_x = 0.25 * X[:, 1]
+pi_x = norm.cdf(0.5 * X[:, 0])
+Z = rng.binomial(n=1, p=pi_x, size=(n,))
+num_rfx_groups = 3
+group_labels = rng.choice(num_rfx_groups, size=n)
+basis = np.empty((n, 2))
+basis[:, 0] = 1.0
+basis[:, 1] = rng.uniform(0, 1, (n,))
+rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]])
+rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1)
+E_XZ = mu_x + Z * tau_x + rfx_term
+W = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,))
+y = (W > 0) * 1.0
+
+# Train-test split
+train_inds, test_inds = train_test_split(
+    np.arange(n), test_size=test_set_pct, random_state=1234
+)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+Z_train = Z[train_inds]
+Z_test = Z[test_inds]
+W_train = W[train_inds]
+W_test = W[test_inds]
+y_train = y[train_inds]
+y_test = y[test_inds]
+group_ids_train = group_labels[train_inds]
+group_ids_test = group_labels[test_inds]
+rfx_basis_train = basis[train_inds, :]
+rfx_basis_test = basis[test_inds, :]
+n_test = len(test_inds)
+n_train = len(train_inds)
+
+# Fit BCF model
+bcf_model = BCFModel()
+bcf_model.sample(
+    X_train=X_train,
+    Z_train=Z_train,
+    y_train=y_train,
+    rfx_group_ids_train=group_ids_train,
+    rfx_basis_train=rfx_basis_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+    general_params={"probit_outcome_model": True},
+)
+
+# Compute contrast posterior
+contrast_posterior_test = bcf_model.compute_contrast(
+    X_0=X_test,
+    X_1=X_test,
+    Z_0=np.zeros((n_test, 1)),
+    Z_1=np.ones((n_test, 1)),
+    rfx_group_ids_0=group_ids_test,
+    rfx_group_ids_1=group_ids_test,
+    rfx_basis_0=rfx_basis_test,
+    rfx_basis_1=rfx_basis_test,
+    type="posterior",
+    scale="probability",
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 = bcf_model.predict(
+    X=X_test,
+    Z=np.zeros((n_test, 1)),
+    rfx_group_ids=group_ids_test,
+    rfx_basis=rfx_basis_test,
+    type="posterior",
+    terms="y_hat",
+    scale="probability",
+)
+y_hat_posterior_test_1 = bcf_model.predict(
+    X=X_test,
+    Z=np.ones((n_test, 1)),
+    rfx_group_ids=group_ids_test,
+    rfx_basis=rfx_basis_test,
+    type="posterior",
+    terms="y_hat",
+    scale="probability",
+)
+contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
+
+# Compare results
+contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test
+np.allclose(contrast_diff, 0, atol=0.001)
diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py
new file mode 100644
index 00000000..2257684a
--- /dev/null
+++ b/demo/debug/bcf_predict_debug.py
@@ -0,0 +1,139 @@
+# Demo of updated predict method for BCF
+
+# Load library
+from stochtree import BCFModel
+import numpy as np
+from sklearn.model_selection import train_test_split
+from scipy.stats import norm
+import matplotlib.pyplot as plt
+
+# Generate data
+rng = np.random.default_rng()
+n = 1000
+p = 5
+X = rng.normal(loc=0.0, scale=1.0, size=(n, p))
+mu_X = X[:, 0]
+tau_X = 0.25 * X[:, 1]
+pi_X = norm.cdf(0.5 * X[:, 1])
+Z = rng.binomial(n=1, p=pi_X, size=(n,))
+E_XZ = mu_X + tau_X * Z
+snr = 2.0
+noise_sd = np.std(E_XZ) / snr
+y = E_XZ + rng.normal(loc=0.0, scale=noise_sd, size=(n,))
+
+# Train-test split
+sample_inds = np.arange(n)
+test_set_pct = 0.2
+train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct)
+X_train = X[train_inds, :]
+X_test = X[test_inds, :]
+Z_train = Z[train_inds]
+Z_test = Z[test_inds]
+pi_train = pi_X[train_inds]
+pi_test = pi_X[test_inds]
+tau_train = tau_X[train_inds]
+tau_test = tau_X[test_inds]
+mu_train = mu_X[train_inds]
+mu_test = mu_X[test_inds]
+y_train = y[train_inds]
+y_test = y[test_inds]
+E_XZ_train = E_XZ[train_inds]
+E_XZ_test = E_XZ[test_inds]
+
+# Fit simple BCF model
+bcf_model = BCFModel()
+bcf_model.sample(
+    X_train=X_train,
+    Z_train=Z_train,
+    pi_train=pi_train,
+    y_train=y_train,
+    num_gfr=10,
+    num_burnin=0,
+    num_mcmc=1000,
+)
+
+# Check several predict approaches
+bcf_preds = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)
+y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)[
+    "y_hat"
+]
+y_hat_mean_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["y_hat"]
+)
+tau_hat_mean_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["cate"]
+)
+# Check that this raises a warning
+y_hat_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["rfx", "variance"]
+)
+
+# Plot predicted versus actual
+plt.scatter(y_hat_mean_test, y_test, color="black")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("Y hat")
+plt.show()
+
+# Plot predicted versus actual
+plt.clf()
+plt.scatter(tau_hat_mean_test, tau_test, color="black")
+plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
+plt.xlabel("Predicted")
+plt.ylabel("Actual")
+plt.title("CATE function")
+plt.show()
+
+# Compute posterior interval
+intervals = bcf_model.compute_posterior_interval(
+    terms="all",
+    scale="linear",
+    level=0.95,
+    covariates=X_test,
+    treatment=Z_test,
+    propensity=pi_test,
+)
+
+# Check coverage of E[Y | X, Z]
+mean_coverage = np.mean(
+    (intervals["y_hat"]["lower"] <= E_XZ_test)
+    & (E_XZ_test <= intervals["y_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for E[Y|X,Z]: {mean_coverage:.3f}")
+
+# Check coverage of tau(X)
+tau_coverage = np.mean(
+    (intervals["tau_hat"]["lower"] <= tau_test)
+    & (tau_test <= intervals["tau_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for tau(X): {tau_coverage:.3f}")
+
+# Check coverage of mu(X)
+mu_coverage = np.mean(
+    (intervals["mu_hat"]["lower"] <= mu_test)
+    & (mu_test <= intervals["mu_hat"]["upper"])
+)
+print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}")
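+
+# (Editor's sketch, not part of the original demo) A posterior for the average
+# treatment effect can be built by averaging CATE draws over the test set; the
+# (n_obs, n_samples) orientation of the "cate" posterior matrix is an
+# assumption of this sketch.
+tau_hat_posterior_test = bcf_model.predict(
+    X=X_test, Z=Z_test, propensity=pi_test, type="posterior", terms="cate"
+)
+ate_draws = np.mean(tau_hat_posterior_test, axis=0)
+print(f"ATE posterior mean: {np.mean(ate_draws):.3f}")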
bcf_model.sample_posterior_predictive( + covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10 +) + +# Plot PPD mean vs actual +ppd_mean = np.mean(bcf_ppd_samples, axis=(0, 2)) +plt.clf() +plt.scatter(ppd_mean, y_test, color="blue") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Posterior Predictive Mean Comparison") +plt.show() + +# Check coverage of posterior predictive distribution +ppd_intervals = np.percentile(bcf_ppd_samples, [2.5, 97.5], axis=(0, 2)) +ppd_coverage = np.mean( + (ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :]) +) +print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}") diff --git a/demo/debug/probit_bart_rfx_debug.py b/demo/debug/probit_bart_rfx_debug.py new file mode 100644 index 00000000..ae2e8c10 --- /dev/null +++ b/demo/debug/probit_bart_rfx_debug.py @@ -0,0 +1,124 @@ +# Debuggin probit BCF with random for BART + +# Load libraries +from stochtree import BARTModel +from sklearn.model_selection import train_test_split +import numpy as np +import matplotlib.pyplot as plt + +# Generate data for a probit BART model with random effects +n = 500 +p = 5 +rng = np.random.default_rng(1234) +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +W = rng.normal(loc=0.0, scale=1.0, size=(n, 1)) +f_XW = np.where( + ((0 <= X[:, 0]) & (X[:, 0] < 0.25)), + -7.5 * W[:, 0], + np.where( + ((0.25 <= X[:, 0]) & (X[:, 0] < 0.5)), + -2.5 * W[:, 0], + np.where( + ((0.5 <= X[:, 0]) & (X[:, 0] < 0.75)), + 2.5 * W[:, 0], + 7.5 * W[:, 0], + ), + ), +) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_Y = f_XW + rfx_term +Z = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) +y = (Z > 0) * 1.0 + +# Train-test split +test_set_pct = 0.2 +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +W_train = W[train_inds, :] +W_test = W[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BART model +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + leaf_basis_train=W_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={"probit_outcome_model": True}, +) + +# Compute contrast posterior +contrast_posterior_test = bart_model.compute_contrast( + covariates_0=X_test, + covariates_1=X_test, + basis_0=np.zeros((n_test, 1)), + basis_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="linear", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bart_model.predict( + covariates=X_test, + basis=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +y_hat_posterior_test_1 = bart_model.predict( + covariates=X_test, + 
basis=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +print(np.allclose(contrast_diff, 0, atol=0.001)) + +# Plot predicted versus actual latent outcome +Z_hat_test = bart_model.predict( + covariates=X_test, + basis=W_test, + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="y_hat", + scale="linear", +) +plt.scatter(Z_hat_test, Z_test, alpha=0.5) +plt.axline((0, 0), slope=1, color="red", linestyle="--") +plt.show() diff --git a/demo/debug/probit_bcf_rfx_debug.py b/demo/debug/probit_bcf_rfx_debug.py new file mode 100644 index 00000000..9c2dbdfb --- /dev/null +++ b/demo/debug/probit_bcf_rfx_debug.py @@ -0,0 +1,115 @@ +# Debugging probit BCF with random effects + +# Load libraries +from stochtree import BCFModel +from scipy.stats import norm +from sklearn.model_selection import train_test_split +import numpy as np +import matplotlib.pyplot as plt + +# Generate data for a probit BCF model with random effects +n = 1000 +p = 5 +rng = np.random.default_rng(1234) +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +mu_x = X[:, 0] +tau_x = 0.25 * X[:, 1] +pi_x = norm.cdf(0.5 * X[:, 0]) +Z = rng.binomial(n=1, p=pi_x, size=(n,)) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-1, 1], [0, 1], [1, 1]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_XZ = mu_x + Z * tau_x + rfx_term +W = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) +y = (W > 0) * 1.0 + +# Train-test split +test_set_pct = 0.2 +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +W_train = W[train_inds] +W_test = W[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BCF model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={"probit_outcome_model": True}, +) + +# Compute contrast posterior +contrast_posterior_test = bcf_model.compute_contrast( + X_0=X_test, + X_1=X_test, + Z_0=np.zeros((n_test, 1)), + Z_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="probability", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bcf_model.predict( + X=X_test, + Z=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="probability", +) +y_hat_posterior_test_1 = bcf_model.predict( + X=X_test, + Z=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="probability", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0
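+
+# Illustrative sketch: summarize the posterior average treatment effect on the
+# probability scale from the contrast draws (assumes, for illustration, that rows
+# index test observations and columns index posterior draws)
+ate_draws = np.mean(contrast_posterior_test, axis=0)
+print(f"Posterior mean ATE (probability scale): {np.mean(ate_draws):.3f}")
+print("95% credible interval:", np.percentile(ate_draws, [2.5, 97.5]))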
+ +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +print(np.allclose(contrast_diff, 0, atol=0.001)) + +# Plot predicted versus actual latent outcome +W_hat_test = bcf_model.predict( + X=X_test, + Z=Z_test, + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="y_hat", + scale="linear", +) +plt.scatter(W_hat_test, W_test, alpha=0.5) +plt.axline((0, 0), slope=1, color="red", linestyle="--") +plt.show() diff --git a/man/RandomEffectSamples.Rd b/man/RandomEffectSamples.Rd index ae5e9ac0..beb85a8b 100644 --- a/man/RandomEffectSamples.Rd +++ b/man/RandomEffectSamples.Rd @@ -240,7 +240,7 @@ and sigma (group-independent prior variance for each component of xi). \subsection{Returns}{ List of arrays. The alpha array has dimension (\code{num_components}, \code{num_samples}) and is simply a vector if \code{num_components = 1}. -The xi and beta arrays have dimension (\code{num_components}, \code{num_groups}, \code{num_samples}) and is simply a matrix if \code{num_components = 1}. +The xi and beta arrays have dimension (\code{num_components}, \code{num_groups}, \code{num_samples}) and are simply matrices if \code{num_components = 1}. The sigma array has dimension (\code{num_components}, \code{num_samples}) and is simply a vector if \code{num_components = 1}. } } diff --git a/man/bart.Rd b/man/bart.Rd index 66a9b9ad..c76ec963 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -21,7 +21,8 @@ bart( previous_model_warmstart_sample_num = NULL, general_params = list(), mean_forest_params = list(), - variance_forest_params = list() + variance_forest_params = list(), + random_effects_params = list() ) } \arguments{ @@ -84,12 +85,6 @@ that were not in the training set.} \item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. \item \code{probit_outcome_model} Whether or not the outcome should be modeled as explicitly binary via a probit link. If \code{TRUE}, \code{y} must only contain the values \code{0} and \code{1}. Default: \code{FALSE}. -\item \code{rfx_working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}.
Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. -\item \code{rfx_variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{num_threads} Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads. }} @@ -124,6 +119,17 @@ that were not in the training set.} \item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. \item \code{num_features_subsample} How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. }} + +\item{random_effects_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +\itemize{ +\item \code{model_spec} Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through \code{rfx_basis_train}. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, \code{rfx_basis_train} and \code{rfx_basis_test} (if provided) will be ignored. +\item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +\item \code{variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}.
+}} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). @@ -136,9 +142,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -153,6 +159,6 @@ X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/bcf.Rd b/man/bcf.Rd index 01e5fab8..55e5e181 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -24,7 +24,8 @@ bcf( general_params = list(), prognostic_forest_params = list(), treatment_effect_forest_params = list(), - variance_forest_params = list() + variance_forest_params = list(), + random_effects_params = list() ) } \arguments{ @@ -91,12 +92,6 @@ that were not in the training set.} \item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. \item \code{probit_outcome_model} Whether or not the outcome should be modeled as explicitly binary via a probit link. If \code{TRUE}, \code{y} must only contain the values \code{0} and \code{1}. Default: \code{FALSE}. -\item \code{rfx_working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. -\item \code{rfx_variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{num_threads} Number of threads to use in the GFR and MCMC algorithms, as well as prediction. 
If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads. }} @@ -150,6 +145,17 @@ that were not in the training set.} \item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. \item \code{num_features_subsample} How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. }} + +\item{random_effects_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +\itemize{ +\item \code{model_spec} Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through \code{rfx_basis_train}. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (\code{Z_train}) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, \code{rfx_basis_train} and \code{rfx_basis_test} (if provided) will be ignored. +\item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +\item \code{variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +}} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). 
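A minimal sketch of passing the random_effects_params list documented above to bcf() — the settings here are hypothetical, reusing training objects named as in the random effects examples elsewhere in these pages; with "intercept_only", no rfx_basis_train is needed:

rfx_params <- list(model_spec = "intercept_only",
                   variance_prior_shape = 2,
                   variance_prior_scale = 0.5)
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 propensity_train = pi_train,
                 rfx_group_ids_train = rfx_group_ids_train,
                 num_gfr = 10, num_burnin = 0, num_mcmc = 10,
                 random_effects_params = rfx_params)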
@@ -162,21 +168,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -199,8 +205,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/compute_bart_posterior_interval.Rd b/man/compute_bart_posterior_interval.Rd new file mode 100644 index 00000000..59a0a895 --- /dev/null +++ b/man/compute_bart_posterior_interval.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_bart_posterior_interval} +\alias{compute_bart_posterior_interval} +\title{Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.} +\usage{ +compute_bart_posterior_interval( + model_object, + terms, + level = 0.95, + scale = "linear", + covariates = NULL, + basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL +) +} +\arguments{ +\item{model_object}{A fitted BART model object of class \code{bartmodel}.} + +\item{terms}{A character string specifying the model term(s) for which to compute intervals. Options for BART models are \code{"mean_forest"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} + +\item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + +\item{covariates}{A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).} + +\item{basis}{An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves.
Required for "leaf regression" models.} + +\item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} + +\item{rfx_basis}{An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.} +} +\value{ +A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. +} +\description{ +Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +y <- 2 * X[,1] + rnorm(n) +bart_model <- bart(y_train = y, X_train = X) +intervals <- compute_bart_posterior_interval( + model_object = bart_model, + terms = c("mean_forest", "y_hat"), + covariates = X, + level = 0.90 +) +} diff --git a/man/compute_bcf_posterior_interval.Rd b/man/compute_bcf_posterior_interval.Rd new file mode 100644 index 00000000..880226e6 --- /dev/null +++ b/man/compute_bcf_posterior_interval.Rd @@ -0,0 +1,63 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_bcf_posterior_interval} +\alias{compute_bcf_posterior_interval} +\title{Compute posterior credible intervals for BCF model terms} +\usage{ +compute_bcf_posterior_interval( + model_object, + terms, + level = 0.95, + scale = "linear", + covariates = NULL, + treatment = NULL, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL +) +} +\arguments{ +\item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} + +\item{terms}{A character string specifying the model term(s) for which to compute intervals. Options for BCF models are \code{"prognostic_function"}, \code{"cate"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} + +\item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + +\item{covariates}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} + +\item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} + +\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} + +\item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} + +\item{rfx_basis}{An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.} +} +\value{ +A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. 
+} +\description{ +This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +pi_X <- pnorm(0.5 * X[,1]) +Z <- rbinom(n, 1, pi_X) +mu_X <- X[,1] +tau_X <- 0.25 * X[,2] +y <- mu_X + tau_X * Z + rnorm(n) +bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, + propensity_train = pi_X) +intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + terms = c("prognostic_function", "cate"), + covariates = X, + treatment = Z, + propensity = pi_X, + level = 0.90 +) +} diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd new file mode 100644 index 00000000..8a0c3096 --- /dev/null +++ b/man/compute_contrast_bart_model.Rd @@ -0,0 +1,95 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_contrast_bart_model} +\alias{compute_contrast_bart_model} +\title{Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference. +This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects +bases / IDs for both sides of a two-term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or +\code{Y0} term and the minuend of the contrast as the \code{Y1} term, though the requested contrast need not match the "control vs treatment" +terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the \code{predict.bartmodel} +function, labeling each prediction data term with a \code{1} to denote its contribution to the treatment prediction of a contrast and +\code{0} to denote inclusion in the control prediction.} +\usage{ +compute_contrast_bart_model( + object, + covariates_0, + covariates_1, + leaf_basis_0 = NULL, + leaf_basis_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear" +) +} +\arguments{ +\item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} + +\item{covariates_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} + +\item{covariates_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} + +\item{leaf_basis_0}{(Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: \code{NULL}.} + +\item{leaf_basis_1}{(Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: \code{NULL}.} + +\item{rfx_group_ids_0}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "control" case. We do not currently support (but plan to in the near future) test set evaluation +for group labels that were not in the training set. Must be a vector.} + +\item{rfx_group_ids_1}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "treatment" case. We do not currently support (but plan to in the near future) test set evaluation +for group labels that were not in the training set.
Must be a vector.} + +\item{rfx_basis_0}{(Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.} + +\item{rfx_basis_1}{(Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.} + +\item{type}{(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} + +\item{scale}{(Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} +} +\value{ +Contrast matrix or vector, depending on whether type = "mean" or "posterior". +} +\description{ +This function is only valid when there is either a mean forest or a random effects term in the BART model. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +W <- matrix(runif(n*1), ncol = 1) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +W_test <- W[test_inds,] +W_train <- W[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train, + num_gfr = 10, num_burnin = 0, num_mcmc = 10) +contrast_test <- compute_contrast_bart_model( + bart_model, + covariates_0 = X_test, + covariates_1 = X_test, + leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), + leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), + type = "posterior", + scale = "linear" +) +} diff --git a/man/compute_contrast_bcf_model.Rd b/man/compute_contrast_bcf_model.Rd new file mode 100644 index 00000000..d28e77b0 --- /dev/null +++ b/man/compute_contrast_bcf_model.Rd @@ -0,0 +1,127 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_contrast_bcf_model} +\alias{compute_contrast_bcf_model} +\title{Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference. +For simple BCF models with binary treatment, this will yield the same prediction as requesting \code{terms = "cate"} +in the \code{predict.bcfmodel} function. For more general models, such as models with continuous / multivariate treatments or +an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute +any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two-term +contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or \code{Y0} term and the minuend of the +contrast as the \code{Y1} term, though the requested contrast need not match the "control vs treatment" terminology of a classic +two-arm experiment.
We mirror the function calls and terminology of the \code{predict.bcfmodel} function, labeling each prediction +data term with a \code{1} to denote its contribution to the treatment prediction of a contrast and \code{0} to denote inclusion in the +control prediction.} +\usage{ +compute_contrast_bcf_model( + object, + X_0, + X_1, + Z_0, + Z_1, + propensity_0 = NULL, + propensity_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear" +) +} +\arguments{ +\item{object}{Object of type \code{bcfmodel} containing draws of a Bayesian causal forest model and associated sampling outputs.} + +\item{X_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} + +\item{X_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} + +\item{Z_0}{Treatments used for prediction in the "control" case. Must be a matrix or vector.} + +\item{Z_1}{Treatments used for prediction in the "treatment" case. Must be a matrix or vector.} + +\item{propensity_0}{(Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector.} + +\item{propensity_1}{(Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector.} + +\item{rfx_group_ids_0}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "control" case. We do not currently support (but plan to in the near future) test set evaluation +for group labels that were not in the training set. Must be a vector.} + +\item{rfx_group_ids_1}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "treatment" case. We do not currently support (but plan to in the near future) test set evaluation +for group labels that were not in the training set. Must be a vector.} + +\item{rfx_basis_0}{(Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.} + +\item{rfx_basis_1}{(Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.} + +\item{type}{(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} + +\item{scale}{(Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} +} +\value{ +Contrast matrix or vector, depending on whether type = "mean" or "posterior". +} +\description{ +Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference. +For simple BCF models with binary treatment, this will yield the same prediction as requesting \code{terms = "cate"} +in the \code{predict.bcfmodel} function.
For more general models, such as models with continuous / multivariate treatments or +an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute +any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two-term +contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or \code{Y0} term and the minuend of the +contrast as the \code{Y1} term, though the requested contrast need not match the "control vs treatment" terminology of a classic +two-arm experiment. We mirror the function calls and terminology of the \code{predict.bcfmodel} function, labeling each prediction +data term with a \code{1} to denote its contribution to the treatment prediction of a contrast and \code{0} to denote inclusion in the +control prediction. +} +\examples{ +n <- 500 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +mu_x <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +pi_x <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) +) +tau_x <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) +) +Z <- rbinom(n, 1, pi_x) +noise_sd <- 1 +y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, num_gfr = 10, + num_burnin = 0, num_mcmc = 10) +tau_hat_test <- compute_contrast_bcf_model( + bcf_model, X_0=X_test, X_1=X_test, Z_0=rep(0, n_test), Z_1=rep(1, n_test), + propensity_0 = pi_test, propensity_1 = pi_test +) +} diff --git a/man/createBARTModelFromCombinedJson.Rd b/man/createBARTModelFromCombinedJson.Rd index 35d185c3..83d61d0d 100644 --- a/man/createBARTModelFromCombinedJson.Rd +++ b/man/createBARTModelFromCombinedJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- list(saveBARTModelToJson(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) diff --git a/man/createBARTModelFromCombinedJsonString.Rd
b/man/createBARTModelFromCombinedJsonString.Rd index a8470dee..7a17484a 100644 --- a/man/createBARTModelFromCombinedJsonString.Rd +++ b/man/createBARTModelFromCombinedJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) diff --git a/man/createBARTModelFromJson.Rd b/man/createBARTModelFromJson.Rd index 57686122..68a02f0e 100644 --- a/man/createBARTModelFromJson.Rd +++ b/man/createBARTModelFromJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) bart_model_roundtrip <- createBARTModelFromJson(bart_json) diff --git a/man/createBARTModelFromJsonFile.Rd b/man/createBARTModelFromJsonFile.Rd index f714a94a..7608d8d2 100644 --- a/man/createBARTModelFromJsonFile.Rd +++ b/man/createBARTModelFromJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 67068fd0..0748d97a 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) 
& (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index 6f29569e..24c82e4f 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_list <- list(saveBCFModelToJson(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) diff --git a/man/createBCFModelFromCombinedJsonString.Rd b/man/createBCFModelFromCombinedJsonString.Rd index bd7e63f2..e0522f75 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= 
X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index a579b140..35cff7ce 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + 
num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2661d4de..a2496797 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 5f34724c..cc944f85 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= 
X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json <- saveBCFModelToJsonString(bcf_model) bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index d9000925..d7a1adae 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -30,10 +30,10 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, leaf_model_type=1) global_model_config <- createGlobalModelConfig(global_error_variance=1.0) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 0da1eb98..149586a8 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -24,9 +24,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) snr <- 3 @@ -51,11 +51,11 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) rfx_samples <- getRandomEffectSamples(bart_model) 
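# Illustrative check (assumed shapes, per the RandomEffectSamples documentation
# above: alpha and sigma draws are component-by-sample arrays, while xi and beta
# add a group dimension); inspect the returned list with str()
str(rfx_samples)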
} diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index 6769de62..08a8eae4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -24,21 +24,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -74,15 +74,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 2afccbf6..0cb82678 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -6,17 +6,20 @@ \usage{ \method{predict}{bartmodel}( object, - X, + covariates, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, + type = "posterior", + terms = "all", + scale = "linear", ... ) } \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} -\item{X}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} +\item{covariates}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} \item{leaf_basis}{(Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: \code{NULL}.} @@ -26,11 +29,16 @@ that were not in the training set.} \item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{type}{(Optional) Type of prediction to return. 
Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} + +\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + \item{...}{(Optional) Other prediction parameters.} } \value{ -List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. -If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (\code{y_hat}). +List of prediction matrices or single prediction matrix / vector, depending on the terms requested. } \description{ Predict from a sampled BART model on new data @@ -40,9 +48,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -56,7 +64,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) y_hat_test <- predict(bart_model, X_test)$y_hat } diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index ff315808..bda63aa5 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -11,6 +11,9 @@ propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, + type = "posterior", + terms = "all", + scale = "linear", ... ) } @@ -27,12 +30,18 @@ We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set.} -\item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects \code{model_spec} of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used.} + +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} + +\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. 
Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} \item{...}{(Optional) Other prediction parameters.} } \value{ -List of 3-5 \code{nrow(X)} by \code{object$num_samples} matrices: prognostic function estimates, treatment effect estimates, (optionally) random effects predictions, (optionally) variance forest predictions, and outcome predictions. +List of prediction matrices or single prediction matrix / vector, depending on the terms requested. } \description{ Predict from a sampled BCF model on new data @@ -42,21 +51,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -79,8 +88,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) preds <- predict(bcf_model, X_test, Z_test, pi_test) } diff --git a/man/preprocessPredictionData.Rd b/man/preprocessPredictionData.Rd index f881fda8..a6382e69 100644 --- a/man/preprocessPredictionData.Rd +++ b/man/preprocessPredictionData.Rd @@ -22,7 +22,7 @@ types. Matrices will be passed through assuming all columns are numeric. 
} \examples{ cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) X_preprocessed <- preprocessPredictionData(cov_df, metadata) } diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index f0fec6ca..b02158d4 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -48,23 +48,23 @@ y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_observations=n, - num_features=p, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, - max_depth=max_depth, - variable_weights=variable_weights, - cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, +forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) 
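# Draw one MCMC iteration of the forest sampler (gfr = FALSE below), keeping the sampled forest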
forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, + forest_dataset, outcome, forest_samples, active_forest, + rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) diff --git a/man/resetRandomEffectsModel.Rd b/man/resetRandomEffectsModel.Rd index fec99b77..b032ccc2 100644 --- a/man/resetRandomEffectsModel.Rd +++ b/man/resetRandomEffectsModel.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/resetRandomEffectsTracker.Rd b/man/resetRandomEffectsTracker.Rd index 5249ca96..c57af16a 100644 --- a/man/resetRandomEffectsTracker.Rd +++ b/man/resetRandomEffectsTracker.Rd @@ -57,8 +57,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/rootResetRandomEffectsModel.Rd b/man/rootResetRandomEffectsModel.Rd index c58a09e9..4c3cc2f7 100644 --- a/man/rootResetRandomEffectsModel.Rd +++ b/man/rootResetRandomEffectsModel.Rd @@ -63,8 +63,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/rootResetRandomEffectsTracker.Rd b/man/rootResetRandomEffectsTracker.Rd index 8de2c514..6f2dc843 100644 --- a/man/rootResetRandomEffectsTracker.Rd +++ b/man/rootResetRandomEffectsTracker.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/sample_bart_posterior_predictive.Rd b/man/sample_bart_posterior_predictive.Rd new file mode 100644 index 00000000..5bce8442 --- /dev/null +++ b/man/sample_bart_posterior_predictive.Rd @@ 
-0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{sample_bart_posterior_predictive} +\alias{sample_bart_posterior_predictive} +\title{Sample from the posterior predictive distribution for outcomes modeled by BART} +\usage{ +sample_bart_posterior_predictive( + model_object, + covariates = NULL, + basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + num_draws_per_sample = NULL +) +} +\arguments{ +\item{model_object}{A fitted BART model object of class \code{bartmodel}.} + +\item{covariates}{A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} + +\item{basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} + +\item{rfx_group_ids}{A vector of group IDs for random effects model. Required if the BART model includes random effects.} + +\item{rfx_basis}{A matrix of bases for random effects model. Required if the BART model includes random effects.} + +\item{num_draws_per_sample}{The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).} +} +\value{ +Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples). +} +\description{ +Sample from the posterior predictive distribution for outcomes modeled by BART +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +y <- 2 * X[,1] + rnorm(n) +bart_model <- bart(y_train = y, X_train = X) +ppd_samples <- sample_bart_posterior_predictive( + model_object = bart_model, covariates = X +) +} diff --git a/man/sample_bcf_posterior_predictive.Rd b/man/sample_bcf_posterior_predictive.Rd new file mode 100644 index 00000000..0c77d7c1 --- /dev/null +++ b/man/sample_bcf_posterior_predictive.Rd @@ -0,0 +1,50 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{sample_bcf_posterior_predictive} +\alias{sample_bcf_posterior_predictive} +\title{Sample from the posterior predictive distribution for outcomes modeled by BCF} +\usage{ +sample_bcf_posterior_predictive( + model_object, + covariates = NULL, + treatment = NULL, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + num_draws_per_sample = NULL +) +} +\arguments{ +\item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} + +\item{covariates}{A matrix or data frame of covariates.} + +\item{treatment}{A vector or matrix of treatment assignments.} + +\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} + +\item{rfx_group_ids}{(Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.} + +\item{rfx_basis}{(Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.} + +\item{num_draws_per_sample}{(Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in a BCF model (i.e. 
if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).} +} +\value{ +Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples). +} +\description{ +Sample from the posterior predictive distribution for outcomes modeled by BCF +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +pi_X <- pnorm(X[,1] / 2) +Z <- rbinom(n, 1, pi_X) +y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) +bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) +ppd_samples <- sample_bcf_posterior_predictive( + model_object = bcf_model, covariates = X, + treatment = Z, propensity = pi_X +) +} diff --git a/man/saveBARTModelToJson.Rd b/man/saveBARTModelToJson.Rd index a617532e..054af24e 100644 --- a/man/saveBARTModelToJson.Rd +++ b/man/saveBARTModelToJson.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) } diff --git a/man/saveBARTModelToJsonFile.Rd b/man/saveBARTModelToJsonFile.Rd index 46a3110e..62ef6ad7 100644 --- a/man/saveBARTModelToJsonFile.Rd +++ b/man/saveBARTModelToJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/saveBARTModelToJsonString.Rd b/man/saveBARTModelToJsonString.Rd index c83f9e5d..10927c20 100644 --- a/man/saveBARTModelToJsonString.Rd +++ b/man/saveBARTModelToJsonString.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = 
y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string <- saveBARTModelToJsonString(bart_model) } diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index ae2c286d..2c04d76c 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) } diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index e6a9f0aa..584bbbba 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + 
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 4328e525..2182bbe3 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) saveBCFModelToJsonString(bcf_model) } diff --git 
a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 66621d52..f5c30d7d 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1418,6 +1418,64 @@ class RandomEffectsContainerCpp { int NumGroups() { return rfx_container_->NumGroups(); } + py::array_t<double> GetBeta() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + int num_groups = rfx_container_->NumGroups(); + std::vector<double> beta_raw = rfx_container_->GetBeta(); + auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_groups, num_samples})); + auto accessor = result.mutable_unchecked<3>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_groups; j++) { + for (int k = 0; k < num_samples; k++) { + accessor(i,j,k) = beta_raw[k*num_groups*num_components + j*num_components + i]; + } + } + } + return result; + } + py::array_t<double> GetXi() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + int num_groups = rfx_container_->NumGroups(); + std::vector<double> xi_raw = rfx_container_->GetXi(); + auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_groups, num_samples})); + auto accessor = result.mutable_unchecked<3>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_groups; j++) { + for (int k = 0; k < num_samples; k++) { + accessor(i,j,k) = xi_raw[k*num_groups*num_components + j*num_components + i]; + } + } + } + return result; + } + py::array_t<double> GetAlpha() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + std::vector<double> alpha_raw = rfx_container_->GetAlpha(); + auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_samples})); + auto accessor = result.mutable_unchecked<2>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_samples; j++) { + accessor(i,j) = alpha_raw[j*num_components + i]; + } + } + return result; + } + py::array_t<double> GetSigma() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + std::vector<double> sigma_raw = rfx_container_->GetSigma(); + auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({num_components, num_samples})); + auto accessor = result.mutable_unchecked<2>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_samples; j++) { + accessor(i,j) = sigma_raw[j*num_components + i]; + } + } + return result; + } void DeleteSample(int sample_num) { rfx_container_->DeleteSample(sample_num); } @@ -2294,6 +2352,10 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("NumSamples", &RandomEffectsContainerCpp::NumSamples) .def("NumComponents", &RandomEffectsContainerCpp::NumComponents) .def("NumGroups", &RandomEffectsContainerCpp::NumGroups) + .def("GetBeta", &RandomEffectsContainerCpp::GetBeta) + .def("GetXi", &RandomEffectsContainerCpp::GetXi) + .def("GetAlpha", &RandomEffectsContainerCpp::GetAlpha) + .def("GetSigma", &RandomEffectsContainerCpp::GetSigma) .def("DeleteSample", &RandomEffectsContainerCpp::DeleteSample) .def("Predict", &RandomEffectsContainerCpp::Predict) .def("SaveToJsonFile", &RandomEffectsContainerCpp::SaveToJsonFile) diff --git a/stochtree/bart.py b/stochtree/bart.py index 1f65ea17..3f81c531 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1,7 +1,3 @@ -""" -Bayesian Additive Regression Trees (BART) module -""" - import warnings from math import log from numbers import Integral @@ -23,7 +19,14 @@ ) from .sampler import RNG, ForestSampler, GlobalVarianceModel,
LeafVarianceModel from .serialization import JSONSerializer -from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag +from .utils import ( + NotSampledError, + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, + _posterior_predictive_heuristic_multiplier, + _summarize_interval, +) class BARTModel: @@ -82,6 +85,7 @@ def sample( general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, + random_effects_params: Optional[Dict[str, Any]] = None, previous_model_json: Optional[str] = None, previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: @@ -132,12 +136,6 @@ def sample( * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. - * `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. - * `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. * `num_threads`: Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. mean_forest_params : dict, optional @@ -157,7 +155,7 @@ def sample( * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. 
variance_forest_params : dict, optional - Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. + Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. * `num_trees` (`int`): Number of trees in the conditional variance model. Defaults to `0`. Variance is only modeled using a tree / forest if `num_trees > 0`. * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.95`. @@ -172,6 +170,17 @@ * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. + random_effects_params : dict, optional + Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. + + * `model_spec`: Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time, and `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. Default: "custom". + * `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + * `variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + previous_model_json : str, optional JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`.
previous_model_warmstart_sample_num : int, optional @@ -197,12 +206,6 @@ def sample( "keep_every": 1, "num_chains": 1, "probit_outcome_model": False, - "rfx_working_parameter_prior_mean": None, - "rfx_group_parameter_prior_mean": None, - "rfx_working_parameter_prior_cov": None, - "rfx_group_parameter_prior_cov": None, - "rfx_variance_prior_shape": 1.0, - "rfx_variance_prior_scale": 1.0, "num_threads": -1, } general_params_updated = _preprocess_params( @@ -247,6 +250,20 @@ def sample( variance_forest_params_default, variance_forest_params ) + # Update random effects parameters + rfx_params_default = { + "model_spec": "custom", + "working_parameter_prior_mean": None, + "group_parameter_prior_mean": None, + "working_parameter_prior_cov": None, + "group_parameter_prior_cov": None, + "variance_prior_shape": 1.0, + "variance_prior_scale": 1.0, + } + rfx_params_updated = _preprocess_params( + rfx_params_default, random_effects_params + ) + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size = general_params_updated["cutpoint_grid_size"] @@ -262,12 +279,6 @@ def sample( keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] - rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"] - rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"] - rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"] - rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"] - rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] - rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] num_threads = general_params_updated["num_threads"] # 2. Mean forest parameters @@ -282,7 +293,9 @@ def sample( b_leaf = mean_forest_params_updated["sigma2_leaf_scale"] keep_vars_mean = mean_forest_params_updated["keep_vars"] drop_vars_mean = mean_forest_params_updated["drop_vars"] - num_features_subsample_mean = mean_forest_params_updated["num_features_subsample"] + num_features_subsample_mean = mean_forest_params_updated[ + "num_features_subsample" + ] # 3. Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -298,7 +311,30 @@ def sample( b_forest = variance_forest_params_updated["var_forest_prior_scale"] keep_vars_variance = variance_forest_params_updated["keep_vars"] drop_vars_variance = variance_forest_params_updated["drop_vars"] - num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"] + num_features_subsample_variance = variance_forest_params_updated[ + "num_features_subsample" + ] + + # 4. 
Random effects parameters + self.rfx_model_spec = rfx_params_updated["model_spec"] + rfx_working_parameter_prior_mean = rfx_params_updated[ + "working_parameter_prior_mean" + ] + rfx_group_parameter_prior_mean = rfx_params_updated[ + "group_parameter_prior_mean" + ] + rfx_working_parameter_prior_cov = rfx_params_updated[ + "working_parameter_prior_cov" + ] + rfx_group_parameter_prior_cov = rfx_params_updated["group_parameter_prior_cov"] + rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] + rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] + + # Check random effects specification + if not isinstance(self.rfx_model_spec, str): + raise ValueError("rfx_model_spec must be a string") + if self.rfx_model_spec not in ["custom", "intercept_only"]: + raise ValueError("model_spec must either be 'custom' or 'intercept_only'") # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: @@ -955,60 +991,74 @@ def sample( "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" ) - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx = False + # Handle the rfx basis matrices + self.has_rfx_basis = False + self.num_rfx_basis = 0 if self.has_rfx: - if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) - else: - has_basis_rfx = True + if self.rfx_model_spec == "custom": + if rfx_basis_train is None: + raise ValueError( + "rfx_basis_train must be provided when rfx_model_spec = 'custom'" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_train is None: + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) + self.has_rfx_basis = True + self.num_rfx_basis = rfx_basis_train.shape[1] num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] num_rfx_components = rfx_basis_train.shape[1] - # TODO warn if num_rfx_groups is 1 + if num_rfx_groups == 1: + warnings.warn( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) if has_rfx_test: - if rfx_basis_test is None: - if has_basis_rfx: + if self.rfx_model_spec == "custom": + if rfx_basis_test is None: raise ValueError( - "Random effects basis provided for training set, must also be provided for the test set" + "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" ) - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) - + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_test is None: + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) # Set up random effects structures if self.has_rfx: # Prior parameters if rfx_working_parameter_prior_mean is None: if num_rfx_components == 1: - alpha_init = np.array([1]) + alpha_init = np.array([0.0], dtype=float) elif num_rfx_components > 1: - alpha_init = np.concatenate( - ( - np.ones(1, dtype=float), - np.zeros(num_rfx_components - 1, dtype=float), - ) - ) + alpha_init = np.zeros(num_rfx_components, dtype=float) else: raise ValueError("There must be at least 1 random effect component") else: - alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) - + alpha_init = _expand_dims_1d( + rfx_working_parameter_prior_mean, num_rfx_components + ) + if rfx_group_parameter_prior_mean is None: xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) else: - xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) - + xi_init = _expand_dims_2d( +
rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups + ) + if rfx_working_parameter_prior_cov is None: sigma_alpha_init = np.identity(num_rfx_components) else: - sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) - + sigma_alpha_init = _expand_dims_2d_diag( + rfx_working_parameter_prior_cov, num_rfx_components + ) + if rfx_group_parameter_prior_cov is None: sigma_xi_init = np.identity(num_rfx_components) else: - sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) - + sigma_xi_init = _expand_dims_2d_diag( + rfx_group_parameter_prior_cov, num_rfx_components + ) + sigma_xi_shape = rfx_variance_prior_shape sigma_xi_scale = rfx_variance_prior_scale - + # Random effects sampling data structures rfx_dataset_train = RandomEffectsDataset() rfx_dataset_train.add_group_labels(rfx_group_ids_train) @@ -1046,9 +1096,13 @@ def sample( if sample_sigma2_leaf: self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64) if self.include_mean_forest: - yhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + yhat_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) if self.include_variance_forest: - sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + sigma2_x_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) sample_counter = -1 # Forest Dataset (covariates and optional basis) @@ -1104,8 +1158,8 @@ def sample( max_depth=max_depth_mean, leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, - cutpoint_grid_size=cutpoint_grid_size, - num_features_subsample=num_features_subsample_mean + cutpoint_grid_size=cutpoint_grid_size, + num_features_subsample=num_features_subsample_mean, ) forest_sampler_mean = ForestSampler( forest_dataset_train, @@ -1128,7 +1182,7 @@ def sample( cutpoint_grid_size=cutpoint_grid_size, variance_forest_shape=a_forest, variance_forest_scale=b_forest, - num_features_subsample=num_features_subsample_variance + num_features_subsample=num_features_subsample_variance, ) forest_sampler_variance = ForestSampler( forest_dataset_train, @@ -1196,9 +1250,12 @@ def sample( if self.include_mean_forest: if self.probit_outcome_model: # Sample latent probit variable z | - - forest_pred = active_forest_mean.predict(forest_dataset_train) - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + outcome_pred = active_forest_mean.predict(forest_dataset_train) + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1215,7 +1272,7 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) # Sample the mean forest @@ -1234,7 +1291,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions() + yhat_train_raw[:, sample_counter] = ( + forest_sampler_mean.get_cached_forest_predictions() + ) # Sample the variance forest if self.include_variance_forest: @@ -1253,7 +1312,9 @@ def sample( # Cache train set predictions 
since they are already computed during sampling if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1396,11 +1457,16 @@ def sample( if self.include_mean_forest: if self.probit_outcome_model: # Sample latent probit variable z | - - forest_pred = active_forest_mean.predict( + outcome_pred = active_forest_mean.predict( forest_dataset_train ) - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + if self.has_rfx: + rfx_pred = rfx_model.predict( + rfx_dataset_train, rfx_tracker + ) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1417,7 +1483,7 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) # Sample the mean forest @@ -1435,7 +1501,9 @@ def sample( ) if keep_sample: - yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions() + yhat_train_raw[:, sample_counter] = ( + forest_sampler_mean.get_cached_forest_predictions() + ) # Sample the variance forest if self.include_variance_forest: @@ -1453,7 +1521,9 @@ def sample( ) if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1504,9 +1574,9 @@ def sample( if self.sample_sigma2_leaf: self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] if self.include_mean_forest: - yhat_train_raw = yhat_train_raw[:,num_gfr:] + yhat_train_raw = yhat_train_raw[:, num_gfr:] if self.include_variance_forest: - sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:] + sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:] self.num_samples -= num_gfr # Store predictions @@ -1553,7 +1623,10 @@ def sample( ) else: self.sigma2_x_train = ( - np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std + np.exp(sigma2_x_train_raw) + * self.sigma2_init + * self.y_std + * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -1578,6 +1651,9 @@ def predict( basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, + type: str = "posterior", + terms: Union[list[str], str] = "all", + scale: str = "linear", ) -> Union[np.array, tuple]: """Return predictions from every forest sampled (either / both of mean and variance). Return type is either a single array of predictions, if a BART model only includes a @@ -1593,14 +1669,80 @@ def predict( Optional group labels used for an additive random effects model. rfx_basis : np.array, optional Optional basis for "random-slope" regression in an additive random effects model. + type : str, optional + Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". + terms : str, optional + Which model terms to include in the prediction. 
This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `None` along with a warning. Default: "all". + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". Returns ------- - mu_x : np.array, optional - Mean forest and / or random effects predictions. - sigma2_x : np.array, optional - Variance forest predictions. + Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Handle prediction terms + rfx_model_spec = self.rfx_model_spec + rfx_intercept = rfx_model_spec == "intercept_only" + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("terms must be a string or list of strings") + num_terms = 1 if isinstance(terms, str) else len(terms) + has_mean_forest = self.include_mean_forest + has_variance_forest = self.include_variance_forest + has_rfx = self.has_rfx + has_y_hat = has_mean_forest or has_rfx + predict_y_hat = (has_y_hat and ("y_hat" in terms)) or ( + has_y_hat and ("all" in terms) + ) + predict_mean_forest = (has_mean_forest and ("mean_forest" in terms)) or ( + has_mean_forest and ("all" in terms) + ) + predict_rfx = (has_rfx and ("rfx" in terms)) or (has_rfx and ("all" in terms)) + predict_variance_forest = ( + has_variance_forest and ("variance_forest" in terms) + ) or (has_variance_forest and ("all" in terms)) + predict_count = ( + predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest + ) + if predict_count == 0: + term_list = ", ".join(terms) if isinstance(terms, list) else terms + warnings.warn( + f"None of the requested model terms, {term_list}, were fit in this model" + ) + return None + predict_rfx_intermediate = predict_y_hat and has_rfx + predict_mean_forest_intermediate = predict_y_hat and has_mean_forest + + # Check that we have at least one term to predict on probability scale + if ( + probability_scale + and not predict_y_hat + and not predict_mean_forest + and not predict_rfx + ): + raise ValueError( + "scale can only be 'probability' if at least one mean term is requested" + ) + + # Check the model is valid if not self.is_sampled(): msg = ( "This BARTModel instance is not fitted yet. 
Call 'fit' with " @@ -1656,68 +1798,225 @@ def predict( if basis is not None: pred_dataset.add_basis(basis) - # Forest predictions - if self.include_mean_forest: - mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - mean_pred = mean_pred_raw * self.y_std + self.y_bar - - if self.has_rfx: - rfx_preds = ( - self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std - ) - if self.include_mean_forest: - mean_pred = mean_pred + rfx_preds - else: - mean_pred = rfx_preds + self.y_bar - - if self.include_variance_forest: + # Variance forest predictions + if predict_variance_forest: variance_pred_raw = ( self.forest_container_variance.forest_container_cpp.Predict( pred_dataset.dataset_cpp ) ) if self.sample_sigma2_global: - variance_pred = np.empty_like(variance_pred_raw) + variance_forest_predictions = np.empty_like(variance_pred_raw) for i in range(self.num_samples): - variance_pred[:, i] = ( + variance_forest_predictions[:, i] = ( variance_pred_raw[:, i] * self.global_var_samples[i] ) else: - variance_pred = ( + variance_forest_predictions = ( variance_pred_raw * self.sigma2_init * self.y_std * self.y_std ) + if predict_mean: + variance_forest_predictions = np.mean( + variance_forest_predictions, axis=1 + ) + + # Forest predictions + if predict_mean_forest or predict_mean_forest_intermediate: + mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( + pred_dataset.dataset_cpp + ) + mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar - has_mean_predictions = self.include_mean_forest or self.has_rfx - if has_mean_predictions and self.include_variance_forest: - return {"y_hat": mean_pred, "variance_forest_predictions": variance_pred} - elif has_mean_predictions and not self.include_variance_forest: - return {"y_hat": mean_pred, "variance_forest_predictions": None} - elif not has_mean_predictions and self.include_variance_forest: - return {"y_hat": None, "variance_forest_predictions": variance_pred} + # Random effects data checks + if has_rfx: + if rfx_group_ids is None: + raise ValueError( + "rfx_group_ids must be provided if rfx_basis is provided" + ) + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError("X and rfx_basis must have the same number of rows") + if rfx_basis.shape[1] != self.num_rfx_basis: + raise ValueError( + "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + ) - def predict_mean( + # Random effects predictions + if predict_rfx or predict_rfx_intermediate: + if rfx_basis is not None: + rfx_predictions = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) + else: + # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only" + if not rfx_intercept: + raise ValueError( + "rfx_basis must be provided for random effects models with random slopes" + ) + + # Extract the raw RFX samples and scale by train set outcome standard deviation + rfx_samples_raw = self.rfx_container.extract_parameter_samples() + rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std + + # Construct an array with the appropriate group random effects arranged for each observation + n_train = covariates.shape[0] + if rfx_beta_draws.ndim != 2: + raise ValueError( + "BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices" + ) + else: + rfx_predictions_raw = 
np.empty( + shape=(n_train, 1, rfx_beta_draws.shape[1]) + ) + for i in range(n_train): + rfx_predictions_raw[i, 0, :] = rfx_beta_draws[ + rfx_group_ids[i], : + ] + rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :]) + + # Combine into y hat predictions + if probability_scale: + if predict_y_hat and has_mean_forest and has_rfx: + y_hat = norm.cdf(mean_forest_predictions + rfx_predictions) + mean_forest_predictions = norm.cdf(mean_forest_predictions) + rfx_predictions = norm.cdf(rfx_predictions) + elif predict_y_hat and has_mean_forest: + y_hat = norm.cdf(mean_forest_predictions) + mean_forest_predictions = norm.cdf(mean_forest_predictions) + elif predict_y_hat and has_rfx: + y_hat = norm.cdf(rfx_predictions) + rfx_predictions = norm.cdf(rfx_predictions) + else: + if predict_y_hat and has_mean_forest and has_rfx: + y_hat = mean_forest_predictions + rfx_predictions + elif predict_y_hat and has_mean_forest: + y_hat = mean_forest_predictions + elif predict_y_hat and has_rfx: + y_hat = rfx_predictions + + # Collapse to posterior mean predictions if requested + if predict_mean: + if predict_mean_forest: + mean_forest_predictions = np.mean(mean_forest_predictions, axis=1) + if predict_rfx: + rfx_predictions = np.mean(rfx_predictions, axis=1) + if predict_y_hat: + y_hat = np.mean(y_hat, axis=1) + + if predict_count == 1: + if predict_y_hat: + return y_hat + elif predict_mean_forest: + return mean_forest_predictions + elif predict_rfx: + return rfx_predictions + elif predict_variance_forest: + return variance_forest_predictions + else: + result = dict() + if predict_y_hat: + result["y_hat"] = y_hat + else: + result["y_hat"] = None + if predict_mean_forest: + result["mean_forest_predictions"] = mean_forest_predictions + else: + result["mean_forest_predictions"] = None + if predict_rfx: + result["rfx_predictions"] = rfx_predictions + else: + result["rfx_predictions"] = None + if predict_variance_forest: + result["variance_forest_predictions"] = variance_forest_predictions + else: + result["variance_forest_predictions"] = None + return result + + def compute_contrast( self, - covariates: np.array, - basis: np.array = None, - rfx_group_ids: np.array = None, - rfx_basis: np.array = None, - ) -> np.array: - """Predict expected conditional outcome from a BART model. + covariates_0: Union[np.array, pd.DataFrame], + covariates_1: Union[np.array, pd.DataFrame], + basis_0: np.array = None, + basis_1: np.array = None, + rfx_group_ids_0: np.array = None, + rfx_group_ids_1: np.array = None, + rfx_basis_0: np.array = None, + rfx_basis_1: np.array = None, + type: str = "posterior", + scale: str = "linear", + ) -> Union[np.array, tuple]: + """Compute a contrast using a BART model by making two sets of outcome predictions and taking their + difference. This function provides the flexibility to compute any contrast of interest by specifying + covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. + For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend + of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" + terminology of a classic two-treatment causal inference problem. We mirror the function calls and + terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote + its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction. 
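As context for reviewers, here is a minimal usage sketch of the new contrast API. This is a sketch under stated assumptions, not package documentation: `bart_model`, `X0`, and `X1` are hypothetical, the model is assumed to be fit without random effects or a leaf basis, and the `sample` call is assumed to follow the package's existing training entry point.

```python
import numpy as np
from stochtree import BARTModel

# Simulate a small training set (hypothetical data)
rng = np.random.default_rng(2024)
X = rng.uniform(size=(200, 5))
y = np.sin(4.0 * X[:, 0]) + rng.normal(scale=0.1, size=200)

# Fit a BART model (argument names assumed from the package's existing sampler)
bart_model = BARTModel()
bart_model.sample(X_train=X, y_train=y, num_gfr=10, num_mcmc=100)

# Contrast: move the first covariate from 0.2 to 0.8, holding the rest fixed
X0, X1 = X.copy(), X.copy()
X0[:, 0], X1[:, 0] = 0.2, 0.8
contrast_draws = bart_model.compute_contrast(
    covariates_0=X0, covariates_1=X1, type="posterior", scale="linear"
)
print(contrast_draws.shape)  # one row per observation, one column per retained draw
```

With `type="mean"` the same call would instead return a 1d array of posterior-averaged contrasts, one entry per observation.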
Parameters ---------- - covariates : np.array - Test set covariates. - basis : np.array, optional - Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. + covariates_0 : np.array or pd.DataFrame + Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. + covariates_1 : np.array or pd.DataFrame + Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. + basis_0 : np.array, optional + Bases used for prediction in the "control" case (by e.g. dot product with leaf values). + basis_1 : np.array, optional + Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). + rfx_group_ids_0 : np.array, optional + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (though we plan to in the near future) test set evaluation for group labels that + were not in the training set. Must be a numpy array. + rfx_group_ids_1 : np.array, optional + Test set group labels used for prediction from an additive random effects model in the "treatment" case. + We do not currently support (though we plan to in the near future) test set evaluation for group labels that + were not in the training set. Must be a numpy array. + rfx_basis_0 : np.array, optional + Test set basis used for prediction from an additive random effects model in the "control" case. + rfx_basis_1 : np.array, optional + Test set basis used for prediction from an additive random effects model in the "treatment" case. + type : str, optional + Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". + scale : str, optional + Scale of the contrast. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". Returns ------- - np.array - Mean forest predictions. + Array, either 1d or 2d depending on whether type = "mean" or "posterior". """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Handle prediction terms + has_mean_forest = self.include_mean_forest + has_rfx = self.has_rfx + + # Check that the model has at least one mean term from which to compute a contrast + if not has_mean_forest and not has_rfx: + raise ValueError( + "Contrast cannot be computed as the model does not have a mean forest or random effects term" + ) + + # Check the model is valid if not self.is_sampled(): msg = ( "This BARTModel instance is not fitted yet. 
Call 'fit' with " @@ -1725,160 +2024,335 @@ def predict_mean( ) raise NotSampledError(msg) - has_mean_predictions = self.include_mean_forest or self.has_rfx - if not has_mean_predictions: - msg = ( - "This BARTModel instance was not sampled with a mean forest or random effects. " - "Call 'fit' with appropriate arguments before using this model." - ) - raise NotSampledError(msg) - # Data checks - if not isinstance(covariates, pd.DataFrame) and not isinstance( - covariates, np.ndarray + if not isinstance(covariates_0, pd.DataFrame) and not isinstance( + covariates_0, np.ndarray ): - raise ValueError("covariates must be a pandas dataframe or numpy array") - if basis is not None: - if not isinstance(basis, np.ndarray): - raise ValueError("basis must be a numpy array") - if basis.shape[0] != covariates.shape[0]: + raise ValueError("covariates_0 must be a pandas dataframe or numpy array") + if not isinstance(covariates_1, pd.DataFrame) and not isinstance( + covariates_1, np.ndarray + ): + raise ValueError("covariates_1 must be a pandas dataframe or numpy array") + if basis_0 is not None: + if not isinstance(basis_0, np.ndarray): + raise ValueError("basis_0 must be a numpy array") + if basis_0.shape[0] != covariates_0.shape[0]: raise ValueError( - "covariates and basis must have the same number of rows" + "covariates_0 and basis_0 must have the same number of rows" ) - - # Convert everything to standard shape (2-dimensional) - if isinstance(covariates, np.ndarray): - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - if basis is not None: - if basis.ndim == 1: - basis = np.expand_dims(basis, 1) - - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): + if basis_1 is not None: + if not isinstance(basis_1, np.ndarray): + raise ValueError("basis_1 must be a numpy array") + if basis_1.shape[0] != covariates_1.shape[0]: raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." + "covariates_1 and basis_1 must have the same number of rows" ) - else: - warnings.warn( - "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, - ) - if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): - raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." 
- ) - covariates_processed = covariates - else: - covariates_processed = self._covariate_preprocessor.transform(covariates) - # Dataset construction - pred_dataset = Dataset() - pred_dataset.add_covariates(covariates_processed) - if basis is not None: - pred_dataset.add_basis(basis) + # Predict for the control arm + control_preds = self.predict( + covariates=covariates_0, + basis=basis_0, + rfx_group_ids=rfx_group_ids_0, + rfx_basis=rfx_basis_0, + type="posterior", + terms="y_hat", + scale="linear", + ) - # Mean forest predictions - if self.include_mean_forest: - mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - mean_pred = mean_pred_raw * self.y_std + self.y_bar + # Predict for the treatment arm + treatment_preds = self.predict( + covariates=covariates_1, + basis=basis_1, + rfx_group_ids=rfx_group_ids_1, + rfx_basis=rfx_basis_1, + type="posterior", + terms="y_hat", + scale="linear", + ) - # RFX predictions - if self.has_rfx: - rfx_preds = ( - self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std - ) - if self.include_mean_forest: - mean_pred = mean_pred + rfx_preds - else: - mean_pred = rfx_preds + self.y_bar + # Transform to probability scale if requested + if probability_scale: + treatment_preds = norm.cdf(treatment_preds) + control_preds = norm.cdf(control_preds) - return mean_pred + # Compute and return contrast + if predict_mean: + return np.mean(treatment_preds - control_preds, axis=1) + else: + return treatment_preds - control_preds - def predict_variance(self, covariates: np.array) -> np.array: - """Predict expected conditional variance from a BART model. + def compute_posterior_interval( + self, + terms: Union[list[str], str] = "all", + scale: str = "linear", + level: float = 0.95, + covariates: np.array = None, + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + ) -> dict: + """ + Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. Parameters ---------- - covariates : np.array - Test set covariates. + terms : str, optional + Character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Defaults to `"all"`. + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. + level : float, optional + A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. + covariates : np.array, optional + Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). + basis : np.array, optional + Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. + rfx_group_ids : np.array, optional + Optional vector of group IDs for random effects. Required if the requested term includes random effects. 
+ rfx_basis : np.array, optional + Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. Returns ------- - np.array - Variance forest predictions. + dict + A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. """ + # Check the provided model object and requested terms if not self.is_sampled(): - msg = ( - "This BARTModel instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this model." - ) - raise NotSampledError(msg) + raise ValueError("Model has not yet been sampled") + term_list = [terms] if isinstance(terms, str) else terms + for term in term_list: + if not self.has_term(term): + warnings.warn( + f"Term {term} was not sampled in this model and its intervals will not be returned." + ) - if not self.include_variance_forest: - msg = ( - "This BARTModel instance was not sampled with a variance forest. " - "Call 'fit' with appropriate arguments before using this model." + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" ) - raise NotSampledError(msg) - - # Data checks - if not isinstance(covariates, pd.DataFrame) and not isinstance( - covariates, np.ndarray - ): - raise ValueError("covariates must be a pandas dataframe or numpy array") - - # Convert everything to standard shape (2-dimensional) - if isinstance(covariates, np.ndarray): - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): + # Check that all the necessary inputs were provided for interval computation + needs_covariates_intermediate = ( + ("y_hat" in terms) or ("all" in terms) + ) and self.include_mean_forest + needs_covariates = ( + ("mean_forest" in terms) + or ("variance_forest" in terms) + or needs_covariates_intermediate + ) + if needs_covariates: + if covariates is None: raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." + "'covariates' must be provided in order to compute the requested intervals" ) - else: - warnings.warn( - "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, + if not isinstance(covariates, np.ndarray) and not isinstance( + covariates, pd.DataFrame + ): + raise ValueError("'covariates' must be a matrix or data frame") + needs_basis = needs_covariates and self.has_basis + if needs_basis: + if basis is None: + raise ValueError( + "'basis' must be provided in order to compute the requested intervals" ) - if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): - raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. 
Please refit your model by passing non-numeric covariate data as a Pandas dataframe." + if not isinstance(basis, np.ndarray): + raise ValueError("'basis' must be a numpy array") + if basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'basis' must have the same number of rows as 'covariates'" + ) + needs_rfx_data_intermediate = ( + ("y_hat" in terms) or ("all" in terms) + ) and self.has_rfx + needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior matrices for the requested model terms + predictions = self.predict( + covariates=covariates, + basis=basis, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms=terms, + scale=scale, + ) + has_multiple_terms = True if isinstance(predictions, dict) else False + + # Compute posterior intervals + if has_multiple_terms: + result = dict() + for term in predictions.keys(): + if predictions[term] is not None: + result[term] = _summarize_interval( + predictions[term], 1, level=level ) - covariates_processed = covariates + return result else: - covariates_processed = self._covariate_preprocessor.transform(covariates) + return _summarize_interval(predictions, 1, level=level) - # Dataset construction - pred_dataset = Dataset() - pred_dataset.add_covariates(covariates_processed) + def sample_posterior_predictive( + self, + covariates: np.array = None, + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + num_draws_per_sample: int = None, + ) -> np.array: + """ + Sample from the posterior predictive distribution for outcomes modeled by BART - # Variance forest predictions - variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - if self.sample_sigma2_global: - variance_pred = np.empty_like(variance_pred_raw) - for i in range(self.num_samples): - variance_pred[:, i] = ( - variance_pred_raw[:, i] * self.global_var_samples[i] + Parameters + ---------- + covariates : np.array, optional + An array or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). + basis : np.array, optional + An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. + rfx_group_ids : np.array, optional + An array of group IDs for random effects. Required if the BART model includes random effects. + rfx_basis : np.array, optional + An array of basis function evaluations for random effects. Required if the BART model includes random effects. + num_draws_per_sample : int, optional + The number of posterior predictive samples to draw for each posterior sample. 
Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). + + Returns + ------- + np.array + An array of posterior predictive samples, with shape `(num_observations, num_posterior_draws)` when `num_draws_per_sample = 1` and `(num_draws_per_sample, num_observations, num_posterior_draws)` otherwise. + """ + # Check the provided model object + if not self.is_sampled(): + raise ValueError("Model has not yet been sampled") + + # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) + is_probit = self.probit_outcome_model + + # Check that all the necessary inputs were provided for posterior predictive sampling + needs_covariates = self.include_mean_forest or self.include_variance_forest + if needs_covariates: + if covariates is None: + raise ValueError( + "'covariates' must be provided in order to sample from the posterior predictive distribution" + ) + if not isinstance(covariates, np.ndarray) and not isinstance( + covariates, pd.DataFrame + ): + raise ValueError("'covariates' must be a matrix or data frame") + needs_basis = needs_covariates and self.has_basis + if needs_basis: + if basis is None: + raise ValueError( + "'basis' must be provided in order to sample from the posterior predictive distribution" + ) + if not isinstance(basis, np.ndarray): + raise ValueError("'basis' must be a numpy array") + if basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'basis' must have the same number of rows as 'covariates'" + ) + needs_rfx_data = self.has_rfx + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to sample from the posterior predictive distribution" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to sample from the posterior predictive distribution" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior predictive samples + bart_preds = self.predict( + covariates=covariates, + basis=basis, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms="all", + ) + + # Compute outcome mean and variance for posterior predictive distribution + has_mean_term = self.include_mean_forest or self.has_rfx + has_variance_forest = self.include_variance_forest + samples_global_variance = self.sample_sigma2_global + num_posterior_draws = self.num_samples + num_observations = covariates.shape[0] + if has_mean_term: + ppd_mean = bart_preds["y_hat"] else: - ppd_mean = 0.0 + if has_variance_forest: + ppd_variance = bart_preds["variance_forest_predictions"] + else: + if samples_global_variance: + ppd_variance = np.tile(self.global_var_samples, (num_observations, 1)) + else: + ppd_variance = self.sigma2_init + + # Sample from the posterior predictive distribution + if num_draws_per_sample is None: + ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier( + num_posterior_draws, num_observations + ) + else: + ppd_draw_multiplier = num_draws_per_sample + if ppd_draw_multiplier > 1: + ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1)) + ppd_variance 
= np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1)) + ppd_array = np.random.normal( + loc=ppd_mean, + scale=np.sqrt(ppd_variance), + size=(ppd_draw_multiplier, num_observations, num_posterior_draws), + ) + else: + ppd_array = np.random.normal( + loc=ppd_mean, + scale=np.sqrt(ppd_variance), + size=(num_observations, num_posterior_draws), ) - return variance_pred + # Binarize outcome for probit models + if is_probit: + ppd_array = (ppd_array > 0.0) * 1 + + return ppd_array def to_json(self) -> str: """ @@ -1920,6 +2394,8 @@ def to_json(self) -> str: bart_json.add_boolean("include_mean_forest", self.include_mean_forest) bart_json.add_boolean("include_variance_forest", self.include_variance_forest) bart_json.add_boolean("has_rfx", self.has_rfx) + bart_json.add_boolean("has_rfx_basis", self.has_rfx_basis) + bart_json.add_scalar("num_rfx_basis", self.num_rfx_basis) bart_json.add_integer("num_gfr", self.num_gfr) bart_json.add_integer("num_burnin", self.num_burnin) bart_json.add_integer("num_mcmc", self.num_mcmc) @@ -1927,6 +2403,7 @@ def to_json(self) -> str: bart_json.add_integer("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) + bart_json.add_string("rfx_model_spec", self.rfx_model_spec) # Add parameter samples if self.sample_sigma2_global: @@ -1961,6 +2438,8 @@ def from_json(self, json_string: str) -> None: self.include_mean_forest = bart_json.get_boolean("include_mean_forest") self.include_variance_forest = bart_json.get_boolean("include_variance_forest") self.has_rfx = bart_json.get_boolean("has_rfx") + self.has_rfx_basis = bart_json.get_boolean("has_rfx_basis") + self.num_rfx_basis = bart_json.get_scalar("num_rfx_basis") if self.include_mean_forest: # TODO: don't just make this a placeholder that we overwrite self.forest_container_mean = ForestContainer(0, 0, False, False) @@ -1999,6 +2478,7 @@ def from_json(self, json_string: str) -> None: self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") + self.rfx_model_spec = bart_json.get_string("rfx_model_spec") # Unpack parameter samples if self.sample_sigma2_global: @@ -2084,6 +2564,8 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # Unpack random effects self.has_rfx = json_object_default.get_boolean("has_rfx") + self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") + self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -2109,6 +2591,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.probit_outcome_model = json_object_default.get_boolean( "probit_outcome_model" ) + self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") # Unpack number of samples for i in range(len(json_object_list)): @@ -2165,3 +2648,30 @@ def is_sampled(self) -> bool: `True` if a BART model has been sampled, `False` otherwise """ return self.sampled + + def has_term(self, term: str) -> bool: + """ + Whether or not a model includes a term. + + Parameters + ---------- + term : str + Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. 
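A companion sketch for the two posterior summary methods added above, continuing the hypothetical `bart_model` and `X` from the earlier contrast example; the shapes noted in the comments follow the code in this diff, and the exact return container for intervals follows the internal `_summarize_interval` helper.

```python
# 95% credible intervals for the overall mean function (lower and upper bounds)
intervals = bart_model.compute_posterior_interval(
    terms="y_hat", level=0.95, covariates=X
)

# Posterior predictive draws; with num_draws_per_sample = 1 the result has shape
# (num_observations, num_posterior_draws), otherwise the upsampling factor is
# prepended as a leading dimension
ppd = bart_model.sample_posterior_predictive(covariates=X, num_draws_per_sample=1)
print(ppd.shape)
```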
+ + Returns + ------- + bool + `True` if the model includes the specified term, `False` otherwise + """ + if term == "mean_forest": + return self.include_mean_forest + elif term == "variance_forest": + return self.include_variance_forest + elif term == "rfx": + return self.has_rfx + elif term == "y_hat": + return self.include_mean_forest or self.has_rfx + elif term == "all": + return True + else: + return False diff --git a/stochtree/bcf.py b/stochtree/bcf.py index bfe9cc34..be4410cc 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1,7 +1,3 @@ -""" -Bayesian Causal Forests (BCF) module -""" - import warnings from typing import Any, Dict, Optional, Union @@ -23,7 +19,14 @@ ) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer -from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag +from .utils import ( + NotSampledError, + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, + _posterior_predictive_heuristic_multiplier, + _summarize_interval, +) class BCFModel: @@ -94,6 +97,7 @@ def sample( prognostic_forest_params: Optional[Dict[str, Any]] = None, treatment_effect_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, + random_effects_params: Optional[Dict[str, Any]] = None, ) -> None: """Runs a BCF sampler on provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set. @@ -152,12 +156,6 @@ def sample( * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. - * `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. 
Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. - * `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. * `num_threads`: Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. prognostic_forest_params : dict, optional @@ -210,6 +208,17 @@ def sample( * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. + random_effects_params : dict, optional + Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. + + * `model_spec`: Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. + * `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + * `variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. 
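For reviewers, a sketch of how the new `random_effects_params` argument is intended to be used with BCF. Hedged: `X`, `Z`, `y`, `propensity`, and `group_ids` are hypothetical arrays, and the training-data argument names (`pi_train`, `rfx_group_ids_train`, etc.) are assumed to follow the existing sampler signature.

```python
import numpy as np
from stochtree import BCFModel

# Hypothetical training data: covariates, binary treatment, outcome,
# propensity scores, and random effects group labels
rng = np.random.default_rng(0)
n = 500
X = rng.uniform(size=(n, 4))
Z = rng.binomial(1, 0.5, size=(n, 1))
group_ids = rng.integers(0, 5, size=n)
y = X[:, 0] + 0.5 * Z[:, 0] + 0.2 * group_ids + rng.normal(size=n)
propensity = np.full((n, 1), 0.5)

# Random-intercept model: the all-ones basis is dispatched internally,
# so no rfx_basis_train is required
bcf_model = BCFModel()
bcf_model.sample(
    X_train=X,
    Z_train=Z,
    y_train=y,
    pi_train=propensity,
    rfx_group_ids_train=group_ids,
    random_effects_params={"model_spec": "intercept_only"},
)
```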
+ Returns ------- self : BCFModel @@ -234,12 +243,6 @@ def sample( "keep_every": 1, "num_chains": 1, "probit_outcome_model": False, - "rfx_working_parameter_prior_mean": None, - "rfx_group_parameter_prior_mean": None, - "rfx_working_parameter_prior_cov": None, - "rfx_group_parameter_prior_cov": None, - "rfx_variance_prior_shape": 1.0, - "rfx_variance_prior_scale": 1.0, "num_threads": -1, } general_params_updated = _preprocess_params( @@ -304,6 +307,20 @@ def sample( variance_forest_params_default, variance_forest_params ) + # Update random effects parameters + rfx_params_default = { + "model_spec": "custom", + "working_parameter_prior_mean": None, + "group_parameter_prior_mean": None, + "working_parameter_prior_cov": None, + "group_parameter_prior_cov": None, + "variance_prior_shape": 1.0, + "variance_prior_scale": 1.0, + } + rfx_params_updated = _preprocess_params( + rfx_params_default, random_effects_params + ) + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size = general_params_updated["cutpoint_grid_size"] @@ -323,12 +340,6 @@ def sample( keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] - rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"] - rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"] - rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"] - rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"] - rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] - rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] num_threads = general_params_updated["num_threads"] # 2. Mu forest parameters @@ -343,7 +354,9 @@ def sample( b_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_scale"] keep_vars_mu = prognostic_forest_params_updated["keep_vars"] drop_vars_mu = prognostic_forest_params_updated["drop_vars"] - num_features_subsample_mu = prognostic_forest_params_updated["num_features_subsample"] + num_features_subsample_mu = prognostic_forest_params_updated[ + "num_features_subsample" + ] # 3. Tau forest parameters num_trees_tau = treatment_effect_forest_params_updated["num_trees"] @@ -362,7 +375,9 @@ def sample( delta_max = treatment_effect_forest_params_updated["delta_max"] keep_vars_tau = treatment_effect_forest_params_updated["keep_vars"] drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"] - num_features_subsample_tau = treatment_effect_forest_params_updated["num_features_subsample"] + num_features_subsample_tau = treatment_effect_forest_params_updated[ + "num_features_subsample" + ] # 4. Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -378,7 +393,36 @@ def sample( b_forest = variance_forest_params_updated["var_forest_prior_scale"] keep_vars_variance = variance_forest_params_updated["keep_vars"] drop_vars_variance = variance_forest_params_updated["drop_vars"] - num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"] + num_features_subsample_variance = variance_forest_params_updated[ + "num_features_subsample" + ] + + # 5. 
Random effects parameters + self.rfx_model_spec = rfx_params_updated["model_spec"] + rfx_working_parameter_prior_mean = rfx_params_updated[ + "working_parameter_prior_mean" + ] + rfx_group_parameter_prior_mean = rfx_params_updated[ + "group_parameter_prior_mean" + ] + rfx_working_parameter_prior_cov = rfx_params_updated[ + "working_parameter_prior_cov" + ] + rfx_group_parameter_prior_cov = rfx_params_updated["group_parameter_prior_cov"] + rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] + rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] + + # Check random effects specification + if not isinstance(self.rfx_model_spec, str): + raise ValueError("rfx_model_spec must be a string") + if self.rfx_model_spec not in [ + "custom", + "intercept_only", + "intercept_plus_treatment", + ]: + raise ValueError( + "model_spec must be one of 'custom', 'intercept_only', or 'intercept_plus_treatment'" + ) # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: @@ -1193,34 +1237,24 @@ def sample( current_sigma2 = sigma2_init self.sigma2_init = sigma2_init # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf_mu = ( - 1.0 / num_trees_mu - if b_leaf_mu is None - else b_leaf_mu - ) - b_leaf_tau = ( - 1.0 / (2 * num_trees_tau) - if b_leaf_tau is None - else b_leaf_tau - ) + b_leaf_mu = 1.0 / num_trees_mu if b_leaf_mu is None else b_leaf_mu + b_leaf_tau = 1.0 / (2 * num_trees_tau) if b_leaf_tau is None else b_leaf_tau sigma2_leaf_mu = ( - 1 / num_trees_mu - if sigma2_leaf_mu is None - else sigma2_leaf_mu + 1 / num_trees_mu if sigma2_leaf_mu is None else sigma2_leaf_mu ) if isinstance(sigma2_leaf_mu, float): current_leaf_scale_mu = np.array([[sigma2_leaf_mu]]) else: raise ValueError("sigma2_leaf_mu must be a scalar") # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p - # Use p = 0.9 as an internal default rather than adding another - # user-facing "parameter" of the binary outcome BCF prior. - # Can be overriden by specifying `sigma2_leaf_init` in + # Use p = 0.6827 as an internal default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overridden by specifying `sigma2_leaf_init` in 
p = 0.6827 q_quantile = norm.ppf((p + 1) / 2.0) sigma2_leaf_tau = ( - ((delta_max / (q_quantile*norm.pdf(0)))**2) / num_trees_tau + ((delta_max / (q_quantile * norm.pdf(0))) ** 2) / num_trees_tau if sigma2_leaf_tau is None else sigma2_leaf_tau ) @@ -1231,7 +1265,9 @@ def sample( ) if isinstance(sigma2_leaf_tau, float): if Z_train.shape[1] > 1: - current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) + current_leaf_scale_tau = np.zeros( + (Z_train.shape[1], Z_train.shape[1]), dtype=float + ) np.fill_diagonal(current_leaf_scale_tau, sigma2_leaf_tau) else: current_leaf_scale_tau = np.array([[sigma2_leaf_tau]]) @@ -1304,7 +1340,9 @@ def sample( ) if isinstance(sigma2_leaf_tau, float): if Z_train.shape[1] > 1: - current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) + current_leaf_scale_tau = np.zeros( + (Z_train.shape[1], Z_train.shape[1]), dtype=float + ) np.fill_diagonal(current_leaf_scale_tau, sigma2_leaf_tau) else: current_leaf_scale_tau = np.array([[sigma2_leaf_tau]]) @@ -1333,7 +1371,7 @@ def sample( if not a_forest: a_forest = 1.0 if not b_forest: - b_forest = 1.0 + b_forest = 1.0 # Runtime checks on RFX group ids self.has_rfx = False @@ -1347,57 +1385,82 @@ def sample( "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" ) - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx = False + # Handle the rfx basis matrices + self.has_rfx_basis = False + self.num_rfx_basis = 0 if self.has_rfx: - if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) - else: - has_basis_rfx = True + if self.rfx_model_spec == "custom": + if rfx_basis_train is None: + raise ValueError( + "rfx_basis_train must be provided when rfx_model_spec = 'custom'" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_train is None: + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis_train is None: + rfx_basis_train = np.concatenate( + (np.ones((rfx_group_ids_train.shape[0], 1)), Z_train), axis=1 + ) + self.has_rfx_basis = True + self.num_rfx_basis = rfx_basis_train.shape[1] num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] num_rfx_components = rfx_basis_train.shape[1] - # TODO warn if num_rfx_groups is 1 + if num_rfx_groups == 1: + warnings.warn( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) if has_rfx_test: - if rfx_basis_test is None: - if has_basis_rfx: + if self.rfx_model_spec == "custom": + if rfx_basis_test is None: raise ValueError( - "Random effects basis provided for training set, must also be provided for the test set" + "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_test is None: + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis_test is None: + rfx_basis_test = np.concatenate( + (np.ones((rfx_group_ids_test.shape[0], 1)), Z_test), axis=1 ) - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) # Set up random effects structures if self.has_rfx: # Prior parameters if rfx_working_parameter_prior_mean is None: if num_rfx_components == 1: - alpha_init = np.array([1]) + alpha_init = np.array([0.0], dtype=float) elif num_rfx_components > 1: - alpha_init = 
np.concatenate( - ( - np.ones(1, dtype=float), - np.zeros(num_rfx_components - 1, dtype=float), - ) - ) + alpha_init = np.zeros(num_rfx_components, dtype=float) else: raise ValueError("There must be at least 1 random effect component") else: - alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) - + alpha_init = _expand_dims_1d( + rfx_working_parameter_prior_mean, num_rfx_components + ) + if rfx_group_parameter_prior_mean is None: xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) else: - xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) - + xi_init = _expand_dims_2d( + rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups + ) + if rfx_working_parameter_prior_cov is None: sigma_alpha_init = np.identity(num_rfx_components) else: - sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) - + sigma_alpha_init = _expand_dims_2d_diag( + rfx_working_parameter_prior_cov, num_rfx_components + ) + if rfx_group_parameter_prior_cov is None: sigma_xi_init = np.identity(num_rfx_components) else: - sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) - + sigma_xi_init = _expand_dims_2d_diag( + rfx_group_parameter_prior_cov, num_rfx_components + ) + sigma_xi_shape = rfx_variance_prior_shape sigma_xi_scale = rfx_variance_prior_scale @@ -1526,7 +1589,9 @@ def sample( muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) tauhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) if self.include_variance_forest: - sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + sigma2_x_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) sample_counter = -1 # Prepare adaptive coding structure @@ -1699,14 +1764,17 @@ def sample( keep_sample = True if keep_sample: sample_counter += 1 - + if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) forest_pred_tau = active_forest_tau.predict(forest_dataset_train) - forest_pred = forest_pred_mu + forest_pred_tau - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + outcome_pred = forest_pred_mu + forest_pred_tau + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1723,9 +1791,9 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) - + # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( self.forest_container_mu, @@ -1742,7 +1810,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions() + muhat_train_raw[:, sample_counter] = ( + forest_sampler_mu.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1778,7 +1848,7 @@ def sample( num_threads, ) - # Cannot cache train set predictions for tau because the cached predictions in the + # Cannot cache train set predictions for 
tau because the cached predictions in the # tracking data structures are pre-multiplied by the basis (treatment) # ... @@ -1788,14 +1858,19 @@ def sample( tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) + partial_resid_train = np.squeeze(resid_train - mu_x) + if self.has_rfx: + rfx_pred = np.squeeze( + rfx_model.predict(rfx_dataset_train, rfx_tracker) + ) + partial_resid_train = partial_resid_train - rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) - partial_resid_mu = np.squeeze(resid_train - mu_x) s_ty0 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 0) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 1) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -1842,7 +1917,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1895,14 +1972,17 @@ def sample( keep_sample = False if keep_sample: sample_counter += 1 - + if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) forest_pred_tau = active_forest_tau.predict(forest_dataset_train) - forest_pred = forest_pred_mu + forest_pred_tau - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + outcome_pred = forest_pred_mu + forest_pred_tau + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1919,9 +1999,9 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) - + # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( self.forest_container_mu, @@ -1938,7 +2018,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions() + muhat_train_raw[:, sample_counter] = ( + forest_sampler_mu.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1974,7 +2056,7 @@ def sample( num_threads, ) - # Cannot cache train set predictions for tau because the cached predictions in the + # Cannot cache train set predictions for tau because the cached predictions in the # tracking data structures are pre-multiplied by the basis (treatment) # ... 
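The adaptive-coding update in the surrounding hunks draws the two coding parameters from conjugate normal conditionals built from the sufficient statistics `s_tt*` and `s_ty*`, now computed against a partial residual that also removes the random effects term. A standalone numpy sketch of that step follows. Hedged: the variance expression shown here is the standard conjugate form implied by the `+ 2 * sigma2` term (i.e. a N(0, sigma2 / 2) prior); the sampler's exact `scale` argument is truncated in this hunk, so treat it as our reading rather than a verbatim excerpt.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100
Z = rng.binomial(1, 0.5, size=n)        # binary treatment indicator
tau_x = rng.normal(size=n)              # treatment forest predictions
partial_resid = rng.normal(size=n)      # residual after removing mu(X) (and rfx)
sigma2 = 1.0                            # current global error variance

# Sufficient statistics by arm, as in the sampler above
s_tt0 = np.sum(tau_x * tau_x * (Z == 0))
s_tt1 = np.sum(tau_x * tau_x * (Z == 1))
s_ty0 = np.sum(tau_x * partial_resid * (Z == 0))
s_ty1 = np.sum(tau_x * partial_resid * (Z == 1))

# Conditional normal draws; the means match the diff, the variances are the
# assumed conjugate form sigma2 / (s_tt + 2 * sigma2)
b_0 = rng.normal(s_ty0 / (s_tt0 + 2 * sigma2), np.sqrt(sigma2 / (s_tt0 + 2 * sigma2)))
b_1 = rng.normal(s_ty1 / (s_tt1 + 2 * sigma2), np.sqrt(sigma2 / (s_tt1 + 2 * sigma2)))
```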
@@ -1984,14 +2066,19 @@ def sample( tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) + partial_resid_train = np.squeeze(resid_train - mu_x) + if self.has_rfx: + rfx_pred = np.squeeze( + rfx_model.predict(rfx_dataset_train, rfx_tracker) + ) + partial_resid_train = partial_resid_train - rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) - partial_resid_mu = np.squeeze(resid_train - mu_x) s_ty0 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 0) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 1) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -2038,7 +2125,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -2095,9 +2184,9 @@ def sample( self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[num_gfr:] if self.sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[num_gfr:] - muhat_train_raw = muhat_train_raw[:,num_gfr:] + muhat_train_raw = muhat_train_raw[:, num_gfr:] if self.include_variance_forest: - sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:] + sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:] self.num_samples -= num_gfr # Store predictions @@ -2179,7 +2268,10 @@ def sample( ) else: self.sigma2_x_train = ( - np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std + np.exp(sigma2_x_train_raw) + * self.sigma2_init + * self.y_std + * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -2198,10 +2290,21 @@ def sample( sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) - def predict_tau( - self, X: np.array, Z: np.array, propensity: np.array = None - ) -> np.array: - """Predict CATE function for every provided observation. + def predict( + self, + X: np.array, + Z: np.array, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + type: str = "posterior", + terms: Union[list[str], str] = "all", + scale: str = "linear", + ) -> Union[dict[str, np.array], np.array]: + """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. + Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. + When random effects are present, they are included in yhat additively if `rfx_model_spec == "custom"`, included in mu_x if `rfx_model_spec == "intercept_only"`, or + partially included in mu_x and partially included in tau_x if `rfx_model_spec == "intercept_plus_treatment"`. Parameters ---------- @@ -2209,14 +2312,88 @@ def predict_tau( Test set covariates. Z : np.array Test set treatment indicators. - propensity : np.array, optional + propensity : `np.array`, optional Optional test set propensities. Must be provided if propensities were provided when the model was sampled. + rfx_group_ids : np.array, optional + Optional group labels used for an additive random effects model. 
+ rfx_basis : np.array, optional + Optional basis for "random-slope" regression in an additive random effects model. Not necessary if `rfx_model_spec` is "intercept_only" or "intercept_plus_treatment", but if rfx_basis is provided, it will supersede the basis implied by `rfx_model_spec`. + type : str, optional + Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". + terms : str, optional + Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `None` along with a warning. Default: "all". + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the prognostic / CATE / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". Returns ------- - np.array - Array with as many rows as in `X` and as many columns as retained samples of the algorithm. + Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Handle prediction terms + rfx_model_spec = self.rfx_model_spec + rfx_intercept_only = rfx_model_spec == "intercept_only" + rfx_intercept_plus_treatment = rfx_model_spec == "intercept_plus_treatment" + rfx_intercept = rfx_intercept_only or rfx_intercept_plus_treatment + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("terms must be a string or list of strings") + num_terms = 1 if isinstance(terms, str) else len(terms) + has_mu_forest = True + has_tau_forest = True + has_variance_forest = self.include_variance_forest + has_rfx = self.has_rfx + has_y_hat = has_mu_forest or has_tau_forest or has_rfx + predict_y_hat = (has_y_hat and ("y_hat" in terms)) or ( + has_y_hat and ("all" in terms) + ) + predict_mu_forest = (has_mu_forest and ("prognostic_function" in terms)) or ( + has_mu_forest and ("all" in terms) + ) + predict_tau_forest = (has_tau_forest and ("cate" in terms)) or ( + has_tau_forest and ("all" in terms) + ) + predict_rfx = (has_rfx and ("rfx" in terms)) or (has_rfx and ("all" in terms)) + predict_variance_forest = ( + has_variance_forest and ("variance_forest" in terms) + ) or (has_variance_forest and ("all" in terms)) + predict_count = ( + predict_y_hat + + predict_mu_forest + + predict_tau_forest + + predict_rfx + 
+        )
+        if predict_count == 0:
+            term_list = terms if isinstance(terms, str) else ", ".join(terms)
+            warnings.warn(
+                f"None of the requested model terms, {term_list}, were fit in this model"
+            )
+            return None
+        predict_rfx_intermediate = predict_y_hat and has_rfx
+        predict_rfx_raw = (predict_mu_forest and has_rfx and rfx_intercept) or (
+            predict_tau_forest and has_rfx and rfx_intercept_plus_treatment
+        )
+        predict_mu_forest_intermediate = predict_y_hat and has_mu_forest
+        predict_tau_forest_intermediate = predict_y_hat and has_tau_forest
+
         if not self.is_sampled():
             msg = (
                 "This BCFModel instance is not fitted yet. Call 'fit' with "
@@ -2260,7 +2437,7 @@ def predict_tau(
                 covariates_processed = X
         else:
             covariates_processed = self._covariate_preprocessor.transform(X)
-        
+
         # Handle propensities
         if propensity is not None:
             if propensity.shape[0] != X.shape[0]:
@@ -2272,9 +2449,11 @@ def predict_tau(
                         "Propensity scores not provided, but no propensity model was trained during sampling"
                     )
                 else:
-                    internal_propensity_preds = self.bart_propensity_model.predict(covariates_processed)
+                    internal_propensity_preds = self.bart_propensity_model.predict(
+                        covariates_processed
+                    )
                     propensity = np.mean(
-                        internal_propensity_preds['y_hat'], axis=1, keepdims=True
+                        internal_propensity_preds["y_hat"], axis=1, keepdims=True
                     )
 
         # Update covariates to include propensities if requested
@@ -2288,166 +2467,265 @@ def predict_tau(
         forest_dataset_test.add_covariates(X_combined)
         forest_dataset_test.add_basis(Z)
 
-        # Estimate treatment effect
-        tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw(
-            forest_dataset_test.dataset_cpp
-        )
-        tau_raw = tau_raw
-        if self.adaptive_coding:
-            adaptive_coding_weights = np.expand_dims(
-                self.b1_samples - self.b0_samples, axis=(0, 2)
+        # Compute predictions from the variance forest (if included)
+        if predict_variance_forest:
+            sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict(
+                forest_dataset_test.dataset_cpp
             )
-            tau_raw = tau_raw * adaptive_coding_weights
-        tau_x = np.squeeze(tau_raw * self.y_std)
-
-        # Return result matrix
-        return tau_x
-
-    def predict_variance(
-        self, covariates: np.array, propensity: np.array = None
-    ) -> np.array:
-        """Predict expected conditional variance from a BART model.
-
-        Parameters
-        ----------
-        covariates : np.array
-            Test set covariates.
-        propensity : np.array, optional
-            Test set propensity scores. Optional (not currently used in variance forests).
+            if self.sample_sigma2_global:
+                sigma2_x = np.empty_like(sigma2_x_raw)
+                for i in range(self.num_samples):
+                    sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i]
+            else:
+                sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std
+            if predict_mean:
+                sigma2_x = np.mean(sigma2_x, axis=1)
 
-        Returns
-        -------
-        np.array
-            Array of predictions corresponding to the variance forest. Each array will contain as many rows as in `covariates` and as many columns as retained samples of the algorithm.
-        """
-        if not self.is_sampled():
-            msg = (
-                "This BARTModel instance is not fitted yet. Call 'fit' with "
+        # Prognostic forest predictions
+        if predict_mu_forest or predict_mu_forest_intermediate:
+            mu_raw = self.forest_container_mu.forest_container_cpp.Predict(
+                forest_dataset_test.dataset_cpp
             )
-            raise NotSampledError(msg)
+            mu_x_forest = mu_raw * self.y_std + self.y_bar
 
-        if not self.include_variance_forest:
-            msg = (
-                "This BARTModel instance was not sampled with a variance forest. "
" - "Call 'fit' with appropriate arguments before using this model." + # Treatment effect forest predictions + if predict_tau_forest or predict_tau_forest_intermediate: + tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( + forest_dataset_test.dataset_cpp ) - raise NotSampledError(msg) - - # Convert everything to standard shape (2-dimensional) - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - if propensity is not None: - if propensity.ndim == 1: - propensity = np.expand_dims(propensity, 1) - - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): - raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." + if self.adaptive_coding: + adaptive_coding_weights = np.expand_dims( + self.b1_samples - self.b0_samples, axis=(0, 2) ) + tau_raw = tau_raw * adaptive_coding_weights + tau_x_forest = np.squeeze(tau_raw * self.y_std) + if Z.shape[1] > 1: + treatment_term = np.multiply( + np.atleast_3d(Z).swapaxes(1, 2), tau_x_forest + ).sum(axis=2) else: - warnings.warn( - "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, + treatment_term = Z * np.squeeze(tau_x_forest) + + # Random effects data checks + if has_rfx: + if rfx_group_ids is None: + raise ValueError( + "rfx_group_ids must be provided if rfx_basis is provided" ) - if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) + if rfx_basis.shape[0] != X.shape[0]: raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." 
+ "X and rfx_basis must have the same number of rows" ) - covariates_processed = covariates - else: - covariates_processed = self._covariate_preprocessor.transform(covariates) - - # Handle propensities - if propensity is not None: - if propensity.shape[0] != covariates.shape[0]: - raise ValueError("X and propensity must have the same number of rows") - else: - if self.propensity_covariate != "none": - if not self.internal_propensity_model: + if rfx_basis.shape[1] != self.num_rfx_basis: raise ValueError( - "Propensity scores not provided, but no propensity model was trained during sampling" - ) - else: - internal_propensity_preds = self.bart_propensity_model.predict(covariates_processed) - propensity = np.mean( - internal_propensity_preds['y_hat'], axis=1, keepdims=True + "rfx_basis must have the same number of columns as the random effects basis used to sample this model" ) - # Update covariates to include propensities if requested - if self.propensity_covariate == "none": - X_combined = covariates_processed - else: - if propensity is not None: - X_combined = np.c_[covariates_processed, propensity] + # Random effects predictions + if predict_rfx or predict_rfx_intermediate: + rfx_preds = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) + + # Extract "raw" rfx predictions for each rfx basis term if needed + if predict_rfx_raw: + # Extract the raw RFX samples and scale by train set outcome standard deviation + rfx_samples_raw = self.rfx_container.extract_parameter_samples() + rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std + + # Construct an array with the appropriate group random effects arranged for each observation + if rfx_beta_draws.ndim == 3: + rfx_predictions_raw = np.empty( + shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2]) + ) + for i in range(X.shape[0]): + rfx_predictions_raw[i, :, :] = rfx_beta_draws[ + :, rfx_group_ids[i], : + ] + elif rfx_beta_draws.ndim == 2: + rfx_predictions_raw = np.empty( + shape=(X.shape[0], 1, rfx_beta_draws.shape[1]) + ) + for i in range(X.shape[0]): + rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_ids[i], :] else: - # Dummy propensities if not provided but also not needed - propensity = np.ones(covariates_processed.shape[0]) - propensity = np.expand_dims(propensity, 1) - X_combined = np.c_[covariates_processed, propensity] + raise ValueError( + "Unexpected number of dimensions in extracted random effects samples" + ) - # Forest dataset - pred_dataset = Dataset() - pred_dataset.add_covariates(X_combined) + # Add raw RFX predictions to mu and tau if warranted by the RFX model spec + if predict_mu_forest or predict_mu_forest_intermediate: + if rfx_intercept and predict_rfx_raw: + mu_x = mu_x_forest + np.squeeze(rfx_predictions_raw[:, 0, :]) + else: + mu_x = mu_x_forest + if predict_tau_forest or predict_tau_forest_intermediate: + if rfx_intercept_plus_treatment and predict_rfx_raw: + tau_x = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) + else: + tau_x = tau_x_forest - variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict( - pred_dataset.dataset_cpp + # Combine into y hat predictions + needs_mean_term_preds = ( + predict_y_hat or predict_mu_forest or predict_tau_forest or predict_rfx ) - if self.sample_sigma2_global: - variance_pred = np.empty_like(variance_pred_raw) - for i in range(self.num_samples): - variance_pred[:, i] = ( - variance_pred_raw[:, i] * self.global_var_samples[i] - ) + if needs_mean_term_preds: + if probability_scale: + if has_rfx: + 
if predict_y_hat:
+                        y_hat = norm.cdf(mu_x_forest + treatment_term + rfx_preds)
+                    if predict_rfx:
+                        rfx_preds = norm.cdf(rfx_preds)
+                else:
+                    if predict_y_hat:
+                        y_hat = norm.cdf(mu_x_forest + treatment_term)
+                if predict_mu_forest:
+                    mu_x = norm.cdf(mu_x)
+                if predict_tau_forest:
+                    tau_x = norm.cdf(tau_x)
+            else:
+                if has_rfx:
+                    if predict_y_hat:
+                        y_hat = mu_x_forest + treatment_term + rfx_preds
+                else:
+                    if predict_y_hat:
+                        y_hat = mu_x_forest + treatment_term
+                if predict_mu_forest:
+                    mu_x = mu_x
+                if predict_tau_forest:
+                    tau_x = tau_x
+
+        # Collapse to posterior mean predictions if requested
+        if predict_mean:
+            if predict_mu_forest:
+                mu_x = np.mean(mu_x, axis=1)
+            if predict_tau_forest:
+                if Z.shape[1] > 1:
+                    tau_x = np.mean(tau_x, axis=2)
+                else:
+                    tau_x = np.mean(tau_x, axis=1)
+            if predict_rfx:
+                rfx_preds = np.mean(rfx_preds, axis=1)
+            if predict_y_hat:
+                y_hat = np.mean(y_hat, axis=1)
+
+        if predict_count == 1:
+            if predict_y_hat:
+                return y_hat
+            elif predict_mu_forest:
+                return mu_x
+            elif predict_tau_forest:
+                return tau_x
+            elif predict_rfx:
+                return rfx_preds
+            elif predict_variance_forest:
+                return sigma2_x
         else:
-            variance_pred = (
-                variance_pred_raw * self.sigma2_init * self.y_std * self.y_std
-            )
-
-        return variance_pred
+            result = dict()
+            if predict_y_hat:
+                result["y_hat"] = y_hat
+            else:
+                result["y_hat"] = None
+            if predict_mu_forest:
+                result["mu_hat"] = mu_x
+            else:
+                result["mu_hat"] = None
+            if predict_tau_forest:
+                result["tau_hat"] = tau_x
+            else:
+                result["tau_hat"] = None
+            if predict_rfx:
+                result["rfx_predictions"] = rfx_preds
+            else:
+                result["rfx_predictions"] = None
+            if predict_variance_forest:
+                result["variance_forest_predictions"] = sigma2_x
+            else:
+                result["variance_forest_predictions"] = None
+            return result
 
-    def predict(
+    def compute_contrast(
         self,
-        X: np.array,
-        Z: np.array,
-        propensity: np.array = None,
-        rfx_group_ids: np.array = None,
-        rfx_basis: np.array = None,
+        X_0: np.array,
+        X_1: np.array,
+        Z_0: np.array,
+        Z_1: np.array,
+        propensity_0: np.array = None,
+        propensity_1: np.array = None,
+        rfx_group_ids_0: np.array = None,
+        rfx_group_ids_1: np.array = None,
+        rfx_basis_0: np.array = None,
+        rfx_basis_1: np.array = None,
+        type: str = "posterior",
+        scale: str = "linear",
     ) -> dict:
-        """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation.
-        Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function.
+        """Compute a contrast using a BCF model by making two sets of outcome predictions and taking their
+        difference. This function provides the flexibility to compute any contrast of interest by specifying
+        covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast.
+        For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend
+        of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment"
+        terminology of a classic two-treatment causal inference problem. We mirror the function calls and
+        terminology of this model's `predict` method, labeling each prediction data term with a `1` to denote
+        its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction.
 
         Parameters
         ----------
-        X : np.array or pd.DataFrame
-            Test set covariates.
-        Z : np.array
-            Test set treatment indicators.
-        propensity : `np.array`, optional
-            Optional test set propensities. Must be provided if propensities were provided when the model was sampled.
-        rfx_group_ids : np.array, optional
-            Optional group labels used for an additive random effects model.
-        rfx_basis : np.array, optional
-            Optional basis for "random-slope" regression in an additive random effects model.
+        X_0 : np.array or pd.DataFrame
+            Covariates used for prediction in the "control" case. Must be a numpy array or dataframe.
+        X_1 : np.array or pd.DataFrame
+            Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe.
+        Z_0 : np.array
+            Treatments used for prediction in the "control" case. Must be a numpy array or vector.
+        Z_1 : np.array
+            Treatments used for prediction in the "treatment" case. Must be a numpy array or vector.
+        propensity_0 : `np.array`, optional
+            Propensities used for prediction in the "control" case. Must be a numpy array or vector.
+        propensity_1 : `np.array`, optional
+            Propensities used for prediction in the "treatment" case. Must be a numpy array or vector.
+        rfx_group_ids_0 : np.array, optional
+            Test set group labels used for prediction from an additive random effects model in the "control" case.
+            We do not currently support (but plan to in the near future) test set evaluation for group labels that
+            were not in the training set. Must be a numpy array.
+        rfx_group_ids_1 : np.array, optional
+            Test set group labels used for prediction from an additive random effects model in the "treatment" case.
+            We do not currently support (but plan to in the near future) test set evaluation for group labels that
+            were not in the training set. Must be a numpy array.
+        rfx_basis_0 : np.array, optional
+            Test set basis used for prediction from an additive random effects model in the "control" case. Must be a numpy array.
+        rfx_basis_1 : np.array, optional
+            Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a numpy array.
+        type : str, optional
+            Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".
+        scale : str, optional
+            Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
 
         Returns
         -------
-        tau_x : np.array
-            Conditional average treatment effect (CATE) samples for every observation provided.
-        mu_x : np.array
-            Prognostic effect samples for every observation provided.
-        rfx : np.array, optional
-            Random effect samples for every observation provided, if the model includes a random effects term.
-        yhat_x : np.array
-            Outcome prediction samples for every observation provided.
-        sigma2_x : np.array, optional
-            Variance forest samples for every observation provided. Only returned if the
-            model includes a heteroskedasticity forest.
+        Array, either 1d or 2d depending on whether type = "mean" or "posterior".
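+
+        Examples
+        --------
+        A minimal sketch with hypothetical names (assumes a fitted model `bcf_model`, a
+        test covariate array `X_test`, and conformable all-zero / all-one treatment
+        arrays `Z0` and `Z1`):
+
+        >>> # Posterior draws of the "treat everyone vs. treat no one" contrast
+        >>> contrast_draws = bcf_model.compute_contrast(
+        ...     X_0=X_test, X_1=X_test, Z_0=Z0, Z_1=Z1, type="posterior"
+        ... )
+        >>> # Posterior mean of the same contrast, one value per observation
+        >>> contrast_mean = bcf_model.compute_contrast(
+        ...     X_0=X_test, X_1=X_test, Z_0=Z0, Z_1=Z1, type="mean"
+        ... )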
""" + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Check the model is valid if not self.is_sampled(): msg = ( "This BCFModel instance is not fitted yet. Call 'fit' with " @@ -2455,119 +2733,355 @@ def predict( ) raise NotSampledError(msg) - # Convert everything to standard shape (2-dimensional) - if X.ndim == 1: - X = np.expand_dims(X, 1) - if Z.ndim == 1: - Z = np.expand_dims(Z, 1) + # Data checks + if Z_0.shape[0] != X_0.shape[0]: + raise ValueError("X_0 and Z_0 must have the same number of rows") + if Z_1.shape[0] != X_1.shape[0]: + raise ValueError("X_1 and Z_1 must have the same number of rows") + + # Predict for the control arm + control_preds = self.predict( + X=X_0, + Z=Z_0, + propensity=propensity_0, + rfx_group_ids=rfx_group_ids_0, + rfx_basis=rfx_basis_0, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Predict for the treatment arm + treatment_preds = self.predict( + X=X_1, + Z=Z_1, + propensity=propensity_1, + rfx_group_ids=rfx_group_ids_1, + rfx_basis=rfx_basis_1, + type="posterior", + terms="y_hat", + scale="linear", + ) + + # Transform to probability scale if requested + if probability_scale: + treatment_preds = norm.cdf(treatment_preds) + control_preds = norm.cdf(control_preds) + + # Compute and return contrast + if predict_mean: + return np.mean(treatment_preds - control_preds, axis=1) else: - if Z.ndim != 2: - raise ValueError("treatment must have 1 or 2 dimensions") - if propensity is not None: - if propensity.ndim == 1: - propensity = np.expand_dims(propensity, 1) + return treatment_preds - control_preds - # Data checks - if Z.shape[0] != X.shape[0]: - raise ValueError("X and Z must have the same number of rows") + def compute_posterior_interval( + self, + terms: Union[list[str], str] = "all", + scale: str = "linear", + level: float = 0.95, + covariates: np.array = None, + treatment: np.array = None, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + ) -> dict: + """ + Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(X, np.ndarray): + Parameters + ---------- + terms : str, optional + Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. 
+        level : float, optional
+            A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
+        covariates : np.array, optional
+            Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, treatment effect forest, variance forest, or overall predictions).
+        treatment : np.array, optional
+            Optional array of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions).
+        propensity : np.array, optional
+            Optional array of propensity scores. Required if the underlying model depends on user-provided propensities.
+        rfx_group_ids : np.array, optional
+            Optional vector of group IDs for random effects. Required if the requested term includes random effects.
+        rfx_basis : np.array, optional
+            Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
+
+        Returns
+        -------
+        dict
+            A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned.
+        """
+        # Check the provided model object and requested term
+        if not self.is_sampled():
+            raise ValueError("Model has not yet been sampled")
+        terms_list = [terms] if isinstance(terms, str) else terms
+        for term in terms_list:
+            if not self.has_term(term):
+                warnings.warn(
+                    f"Term {term} was not sampled in this model and its intervals will not be returned."
+                )
+
+        # Handle mean function scale
+        if not isinstance(scale, str):
+            raise ValueError("scale must be a string")
+        if scale not in ["linear", "probability"]:
+            raise ValueError("scale must either be 'linear' or 'probability'")
+        is_probit = self.probit_outcome_model
+        if (scale == "probability") and (not is_probit):
+            raise ValueError(
+                "scale cannot be 'probability' for models not fit with a probit outcome model"
+            )
+
+        # Check that all the necessary inputs were provided for interval computation
+        needs_covariates_intermediate = ("y_hat" in terms) or ("all" in terms)
+        needs_covariates = (
+            ("prognostic_function" in terms)
+            or ("cate" in terms)
+            or ("variance_forest" in terms)
+            or needs_covariates_intermediate
+        )
+        if needs_covariates:
+            if covariates is None:
                 raise ValueError(
-                    "Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
+                    "'covariates' must be provided in order to compute the requested intervals"
                 )
-            else:
-                warnings.warn(
-                    "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
-                    RuntimeWarning,
+            if not isinstance(covariates, np.ndarray) and not isinstance(
+                covariates, pd.DataFrame
+            ):
+                raise ValueError("'covariates' must be a matrix or data frame")
+        needs_treatment = needs_covariates
+        if needs_treatment:
+            if treatment is None:
+                raise ValueError(
+                    "'treatment' must be provided in order to compute the requested intervals"
                 )
-                if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype(
-                    X.dtype, np.integer
-                ):
-                    raise ValueError(
-                        "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. 
Please refit your model by passing non-numeric covariate data as a Pandas dataframe." + if not isinstance(treatment, np.ndarray): + raise ValueError("'treatment' must be a numpy array") + if treatment.shape[0] != covariates.shape[0]: + raise ValueError( + "'treatment' must have the same number of rows as 'covariates'" + ) + uses_propensity = self.propensity_covariate != "none" + internal_propensity_model = self.internal_propensity_model + needs_propensity = ( + needs_covariates and uses_propensity and not internal_propensity_model + ) + if needs_propensity: + if propensity is None: + raise ValueError( + "'propensity' must be provided in order to compute the requested intervals" + ) + if not isinstance(propensity, np.ndarray): + raise ValueError("'propensity' must be a numpy array") + if propensity.shape[0] != covariates.shape[0]: + raise ValueError( + "'propensity' must have the same number of rows as 'covariates'" + ) + needs_rfx_data_intermediate = ( + ("y_hat" in terms) or ("all" in terms) + ) and self.has_rfx + needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior matrices for the requested model terms + predictions = self.predict( + X=covariates, + Z=treatment, + propensity=propensity, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms=terms, + scale=scale, + ) + has_multiple_terms = True if isinstance(predictions, dict) else False + + # Compute posterior intervals + if has_multiple_terms: + result = dict() + for term in predictions.keys(): + if predictions[term] is not None: + result[term] = _summarize_interval( + predictions[term], 1, level=level ) - covariates_processed = X + return result else: - covariates_processed = self._covariate_preprocessor.transform(X) - - # Handle propensities - if propensity is not None: - if propensity.shape[0] != X.shape[0]: - raise ValueError("X and propensity must have the same number of rows") - else: - if self.propensity_covariate != "none": - if not self.internal_propensity_model: - raise ValueError( - "Propensity scores not provided, but no propensity model was trained during sampling" - ) - else: - internal_propensity_preds = self.bart_propensity_model.predict(covariates_processed) - propensity = np.mean( - internal_propensity_preds['y_hat'], axis=1, keepdims=True - ) + return _summarize_interval(predictions, 1, level=level) - # Update covariates to include propensities if requested - if self.propensity_covariate == "none": - X_combined = covariates_processed - else: - X_combined = np.c_[covariates_processed, propensity] + def sample_posterior_predictive( + self, + covariates: np.array, + treatment: np.array, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + 
num_draws_per_sample: int = None,
+    ) -> np.array:
+        """
+        Sample from the posterior predictive distribution for outcomes modeled by BCF.
 
-        # Forest dataset
-        forest_dataset_test = Dataset()
-        forest_dataset_test.add_covariates(X_combined)
-        forest_dataset_test.add_basis(Z)
+        Parameters
+        ----------
+        covariates : np.array
+            An array or data frame of covariates.
+        treatment : np.array
+            An array of treatment assignments.
+        propensity : np.array, optional
+            Optional array of propensity scores. Required if the underlying model depends on user-provided propensities.
+        rfx_group_ids : np.array, optional
+            Optional vector of group IDs for random effects. Required if the model was sampled with random effects.
+        rfx_basis : np.array, optional
+            Optional matrix of basis function evaluations for random effects. Required if the model was sampled with random effects.
+        num_draws_per_sample : int, optional
+            The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).
 
-        # Compute predicted outcome and decomposed outcome model terms
-        mu_raw = self.forest_container_mu.forest_container_cpp.Predict(
-            forest_dataset_test.dataset_cpp
+        Returns
+        -------
+        np.array
+            An array of posterior predictive samples. If `num_draws_per_sample = 1`, the result is a matrix with one row per observation and one column per retained posterior sample; otherwise, the replicated draws are stacked along a leading third dimension.
+        """
+        # Check the provided model object
+        if not self.is_sampled():
+            raise ValueError("Model has not yet been sampled")
+
+        # Determine whether the outcome is continuous (Gaussian) or binary (probit-link)
+        is_probit = self.probit_outcome_model
+
+        # Check that all the necessary inputs were provided for posterior predictive sampling
+        needs_covariates = True
+        if needs_covariates:
+            if covariates is None:
+                raise ValueError(
+                    "'covariates' must be provided in order to sample from the posterior predictive distribution"
+                )
+            if not isinstance(covariates, np.ndarray) and not isinstance(
+                covariates, pd.DataFrame
+            ):
+                raise ValueError("'covariates' must be a matrix or data frame")
+        needs_treatment = needs_covariates
+        if needs_treatment:
+            if treatment is None:
+                raise ValueError(
+                    "'treatment' must be provided in order to sample from the posterior predictive distribution"
                )
+            if not isinstance(treatment, np.ndarray):
+                raise ValueError("'treatment' must be a numpy array")
+            if treatment.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'treatment' must have the same number of rows as 'covariates'"
+                )
+        uses_propensity = self.propensity_covariate != "none"
+        internal_propensity_model = self.internal_propensity_model
+        needs_propensity = (
+            needs_covariates and uses_propensity and not internal_propensity_model
         )
-        mu_x = mu_raw * self.y_std + self.y_bar
-        tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw(
-            forest_dataset_test.dataset_cpp
+        if needs_propensity:
+            if propensity is None:
+                raise ValueError(
+                    "'propensity' must be provided in order to sample from the posterior predictive distribution"
+                )
+            if not isinstance(propensity, np.ndarray):
+                raise ValueError("'propensity' must be a numpy array")
+            if propensity.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'propensity' must have the same number of rows as 'covariates'"
+                )
+        needs_rfx_data = self.has_rfx
+        if needs_rfx_data:
+            if rfx_group_ids is None:
+                raise ValueError(
+                    "'rfx_group_ids' must be provided in order to sample from the posterior predictive distribution"
+                )
+            if not isinstance(rfx_group_ids, np.ndarray):
+                raise ValueError("'rfx_group_ids' must be a numpy array")
+            if rfx_group_ids.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'rfx_group_ids' must have the same length as the number of rows in 'covariates'"
+                )
+            if rfx_basis is None:
+                raise ValueError(
+                    "'rfx_basis' must be provided in order to sample from the posterior predictive distribution"
+                )
+            if not isinstance(rfx_basis, np.ndarray):
+                raise ValueError("'rfx_basis' must be a numpy array")
+            if rfx_basis.shape[0] != covariates.shape[0]:
+                raise ValueError(
+                    "'rfx_basis' must have the same number of rows as 'covariates'"
+                )
+
+        # Compute posterior predictive samples
+        bcf_preds = self.predict(
+            X=covariates,
+            Z=treatment,
+            propensity=propensity,
+            rfx_group_ids=rfx_group_ids,
+            rfx_basis=rfx_basis,
+            type="posterior",
+            terms="all",
+            scale="linear",
        )
-        if self.adaptive_coding:
-            adaptive_coding_weights = np.expand_dims(
-                self.b1_samples - self.b0_samples, axis=(0, 2)
+
+        # Compute outcome mean and variance for posterior predictive distribution
+        has_variance_forest = self.include_variance_forest
+        samples_global_variance = self.sample_sigma2_global
+        num_posterior_draws = self.num_samples
+        num_observations = covariates.shape[0]
+        ppd_mean = bcf_preds["y_hat"]
+        if has_variance_forest:
+            ppd_variance = bcf_preds["variance_forest_predictions"]
+        else:
+            if samples_global_variance:
+                ppd_variance = np.tile(self.global_var_samples, (num_observations, 1))
+            else:
+                ppd_variance = self.sigma2_init
+
+        # Sample from the posterior predictive distribution
+        if num_draws_per_sample is None:
+            ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier(
+                num_posterior_draws, num_observations
            )
-            tau_raw = tau_raw * adaptive_coding_weights
-        tau_x = np.squeeze(tau_raw * self.y_std)
-        if Z.shape[1] > 1:
-            treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum(
-                axis=2
+        else:
+            ppd_draw_multiplier = num_draws_per_sample
+        if ppd_draw_multiplier > 1:
+            ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1))
+            ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1))
+            ppd_array = np.random.normal(
+                loc=ppd_mean,
+                scale=np.sqrt(ppd_variance),
+                size=(ppd_draw_multiplier, num_observations, num_posterior_draws),
            )
        else:
-            treatment_term = Z * np.squeeze(tau_x)
-        yhat_x = mu_x + treatment_term
-
-        if self.has_rfx:
-            rfx_preds = (
-                self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std
+            ppd_array = np.random.normal(
+                loc=ppd_mean,
+                scale=np.sqrt(ppd_variance),
+                size=(num_observations, num_posterior_draws),
            )
-            yhat_x = yhat_x + rfx_preds
 
-        # Compute predictions from the variance forest (if included)
-        if self.include_variance_forest:
-            sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict(
-                forest_dataset_test.dataset_cpp
-            )
-            if self.sample_sigma2_global:
-                sigma2_x = np.empty_like(sigma2_x_raw)
-                for i in range(self.num_samples):
-                    sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i]
-            else:
-                sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std
+        # Binarize outcome for probit models
+        if is_probit:
+            ppd_array = (ppd_array > 0.0) * 1
 
-        # Return result matrices as a tuple
-        if self.has_rfx and self.include_variance_forest:
-            return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": rfx_preds, "variance_forest_predictions": sigma2_x}
-        elif not self.has_rfx and self.include_variance_forest:
-            return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": None, 
"variance_forest_predictions": sigma2_x} - elif self.has_rfx and not self.include_variance_forest: - return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": rfx_preds, "variance_forest_predictions": None} - else: - return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": None, "variance_forest_predictions": None} + return ppd_array def to_json(self) -> str: """ @@ -2609,6 +3123,9 @@ def to_json(self) -> str: bcf_json.add_boolean("sample_sigma2_leaf_tau", self.sample_sigma2_leaf_tau) bcf_json.add_boolean("include_variance_forest", self.include_variance_forest) bcf_json.add_boolean("has_rfx", self.has_rfx) + bcf_json.add_boolean("has_rfx_basis", self.has_rfx_basis) + bcf_json.add_scalar("num_rfx_basis", self.num_rfx_basis) + bcf_json.add_boolean("multivariate_treatment", self.multivariate_treatment) bcf_json.add_scalar("num_gfr", self.num_gfr) bcf_json.add_scalar("num_burnin", self.num_burnin) bcf_json.add_scalar("num_mcmc", self.num_mcmc) @@ -2618,9 +3135,8 @@ def to_json(self) -> str: bcf_json.add_boolean( "internal_propensity_model", self.internal_propensity_model ) - bcf_json.add_boolean( - "probit_outcome_model", self.probit_outcome_model - ) + bcf_json.add_boolean("probit_outcome_model", self.probit_outcome_model) + bcf_json.add_string("rfx_model_spec", self.rfx_model_spec) # Add parameter samples if self.sample_sigma2_global: @@ -2666,6 +3182,9 @@ def from_json(self, json_string: str) -> None: # Unpack forests self.include_variance_forest = bcf_json.get_boolean("include_variance_forest") self.has_rfx = bcf_json.get_boolean("has_rfx") + self.has_rfx_basis = bcf_json.get_boolean("has_rfx_basis") + self.num_rfx_basis = bcf_json.get_scalar("num_rfx_basis") + self.multivariate_treatment = bcf_json.get_boolean("multivariate_treatment") # TODO: don't just make this a placeholder that we overwrite self.forest_container_mu = ForestContainer(0, 0, False, False) self.forest_container_mu.forest_container_cpp.LoadFromJson( @@ -2705,9 +3224,8 @@ def from_json(self, json_string: str) -> None: self.internal_propensity_model = bcf_json.get_boolean( "internal_propensity_model" ) - self.probit_outcome_model = bcf_json.get_boolean( - "probit_outcome_model" - ) + self.probit_outcome_model = bcf_json.get_boolean("probit_outcome_model") + self.rfx_model_spec = bcf_json.get_string("rfx_model_spec") # Unpack parameter samples if self.sample_sigma2_global: @@ -2801,6 +3319,11 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # Unpack random effects self.has_rfx = json_object_default.get_boolean("has_rfx") + self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") + self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") + self.multivariate_treatment = json_object_default.get_boolean( + "multivariate_treatment" + ) if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -2833,7 +3356,11 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.internal_propensity_model = json_object_default.get_boolean( "internal_propensity_model" ) - + self.probit_outcome_model = json_object_default.get_boolean( + "probit_outcome_model" + ) + self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") + # Unpack number of samples for i in range(len(json_object_list)): if i == 0: @@ -2873,9 +3400,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: if self.sample_sigma2_leaf_tau: for i in range(len(json_object_list)): if i == 0: - 
self.sample_sigma2_leaf_tau = json_object_list[i].get_numeric_vector( - "sigma2_leaf_tau_samples", "parameters" - ) + self.sample_sigma2_leaf_tau = json_object_list[ + i + ].get_numeric_vector("sigma2_leaf_tau_samples", "parameters") else: sample_sigma2_leaf_tau = json_object_list[i].get_numeric_vector( "sigma2_leaf_tau_samples", "parameters" @@ -2911,3 +3438,32 @@ def is_sampled(self) -> bool: `True` if a BCF model has been sampled, `False` otherwise """ return self.sampled + + def has_term(self, term: str) -> bool: + """ + Whether or not a model includes a term. + + Parameters + ---------- + term : str + Character string specifying the model term to check for. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. + + Returns + ------- + bool + `True` if the model includes the specified term, `False` otherwise + """ + if term == "prognostic_function": + return True + if term == "cate": + return True + elif term == "variance_forest": + return self.include_variance_forest + elif term == "rfx": + return self.has_rfx + elif term == "y_hat": + return True + elif term == "all": + return True + else: + return False diff --git a/stochtree/config.py b/stochtree/config.py index 72cae512..84e851c3 100644 --- a/stochtree/config.py +++ b/stochtree/config.py @@ -119,7 +119,9 @@ def __init__( raise ValueError("`leaf_dimension` must be an integer greater than 0") if leaf_model_scale is None: diag_value = 1.0 / num_trees - leaf_model_scale_array = np.zeros((leaf_dimension, leaf_dimension), dtype=float) + leaf_model_scale_array = np.zeros( + (leaf_dimension, leaf_dimension), dtype=float + ) np.fill_diagonal(leaf_model_scale_array, diag_value) else: if isinstance(leaf_model_scale, np.ndarray): @@ -432,7 +434,7 @@ def get_feature_types(self) -> np.ndarray: """ return self.feature_types - def get_sweep_update_indices(self) -> Union[np.ndarray,None]: + def get_sweep_update_indices(self) -> Union[np.ndarray, None]: """ Query vector of (0-indexed) indices of trees to update in a sweep diff --git a/stochtree/data.py b/stochtree/data.py index 4e40a282..da3ca735 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -59,7 +59,9 @@ def update_basis(self, basis: np.array): Numpy array of basis vectors. """ if not self.has_basis(): - raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.") + raise ValueError( + "This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix." + ) if not isinstance(basis, np.ndarray): raise ValueError("basis must be a numpy array.") if np.ndim(basis) == 1: @@ -71,9 +73,13 @@ def update_basis(self, basis: np.array): n, p = basis_.shape basis_rowmajor = np.ascontiguousarray(basis_) if self.num_basis() != p: - raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).") + raise ValueError( + f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()})." + ) if self.num_observations() != n: - raise ValueError(f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()}).") + raise ValueError( + f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()})." 
+ ) self.dataset_cpp.UpdateBasis(basis_rowmajor, n, p, True) def add_variance_weights(self, variance_weights: np.array): @@ -91,12 +97,14 @@ def add_variance_weights(self, variance_weights: np.array): n = variance_weights_.size if variance_weights_.ndim != 1: raise ValueError("variance_weights must be a 1-dimensional numpy array.") - + self.dataset_cpp.AddVarianceWeights(variance_weights_, n) - - def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): + + def update_variance_weights( + self, variance_weights: np.array, exponentiate: bool = False + ): """ - Update variance weights in a dataset. Allows users to build an ensemble that depends on + Update variance weights in a dataset. Allows users to build an ensemble that depends on variance weights that are updated throughout the sampler. Parameters @@ -107,7 +115,9 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool Whether to exponentiate the variance weights before storing them in the dataset. """ if not self.has_variance_weights(): - raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + raise ValueError( + "This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector." + ) if not isinstance(variance_weights, np.ndarray): raise ValueError("variance_weights must be a numpy array.") variance_weights_ = np.squeeze(variance_weights) @@ -115,7 +125,9 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool if variance_weights_.ndim != 1: raise ValueError("variance_weights must be a 1-dimensional numpy array.") if self.num_observations() != n: - raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + raise ValueError( + f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()})." + ) self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate) def num_observations(self) -> int: @@ -150,7 +162,7 @@ def num_basis(self) -> int: Dimension of the basis vector in the dataset, returning 0 if the dataset does not have a basis """ return self.dataset_cpp.NumBasis() - + def get_covariates(self) -> np.array: """ Return the covariates in a Dataset as a numpy array @@ -161,7 +173,7 @@ def get_covariates(self) -> np.array: Covariate data """ return self.dataset_cpp.GetCovariates() - + def get_basis(self) -> np.array: """ Return the bases in a Dataset as a numpy array @@ -172,7 +184,7 @@ def get_basis(self) -> np.array: Basis data """ return self.dataset_cpp.GetBasis() - + def get_variance_weights(self) -> np.array: """ Return the variance weights in a Dataset as a numpy array diff --git a/stochtree/forest.py b/stochtree/forest.py index d672ae5c..3cfde245 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -1,7 +1,3 @@ -""" -Python classes wrapping C++ forest container object -""" - from typing import Union import numpy as np @@ -161,13 +157,13 @@ def set_root_leaves( def collapse(self, batch_size: int) -> None: """ - Collapse forests in this container by a pre-specified batch size. 
- For example, if we have a container of twenty 10-tree forests, and we - specify a `batch_size` of 5, then this method will yield four 50-tree - forests. "Excess" forests remaining after the size of a forest container - is divided by `batch_size` will be pruned from the beginning of the - container (i.e. earlier sampled forests will be deleted). This method - has no effect if `batch_size` is larger than the number of forests + Collapse forests in this container by a pre-specified batch size. + For example, if we have a container of twenty 10-tree forests, and we + specify a `batch_size` of 5, then this method will yield four 50-tree + forests. "Excess" forests remaining after the size of a forest container + is divided by `batch_size` will be pruned from the beginning of the + container (i.e. earlier sampled forests will be deleted). This method + has no effect if `batch_size` is larger than the number of forests in a container. Parameters @@ -177,12 +173,23 @@ def collapse(self, batch_size: int) -> None: """ container_size = self.num_samples() if batch_size <= container_size and batch_size > 1: - reverse_container_inds = np.linspace(start=container_size, stop=1, num=container_size, dtype=int) + reverse_container_inds = np.linspace( + start=container_size, stop=1, num=container_size, dtype=int + ) num_clean_batches = container_size // batch_size - batch_inds = (reverse_container_inds - (container_size - ((container_size // num_clean_batches) * num_clean_batches)) - 1) // batch_size + batch_inds = ( + reverse_container_inds + - ( + container_size + - ((container_size // num_clean_batches) * num_clean_batches) + ) + - 1 + ) // batch_size batch_inds = batch_inds.astype(int) for batch_ind in np.flip(np.unique(batch_inds[batch_inds >= 0])): - merge_forest_inds = np.sort(reverse_container_inds[batch_inds == batch_ind] - 1) + merge_forest_inds = np.sort( + reverse_container_inds[batch_inds == batch_ind] - 1 + ) num_merge_forests = len(merge_forest_inds) self.combine_forests(merge_forest_inds) for i in range(num_merge_forests - 1, 0, -1): @@ -194,10 +201,8 @@ def collapse(self, batch_size: int) -> None: num_delete_forests = len(delete_forest_inds) for i in range(num_delete_forests - 1, -1, -1): self.delete_sample(delete_forest_inds[i]) - - def combine_forests( - self, forest_inds: np.array - ) -> None: + + def combine_forests(self, forest_inds: np.array) -> None: """ Collapse specified forests into a single forest @@ -214,42 +219,56 @@ def combine_forests( forest_inds_sorted = forest_inds_sorted.astype(int) self.forest_container_cpp.CombineForests(forest_inds_sorted) - def add_to_forest( - self, forest_index: int, constant_value : float - ) -> None: + def add_to_forest(self, forest_index: int, constant_value: float) -> None: """ Add a constant value to every leaf of every tree of a given forest Parameters ---------- - forest_index : int + forest_index : int Index of forest whose leaves will be modified (0-indexed) - constant_value : float + constant_value : float Value to add to every leaf of every tree of the forest at `forest_index` """ - if not isinstance(forest_index, int) and not isinstance(constant_value, (int, float)): - raise ValueError("forest_index must be an integer and constant_multiple must be a float or int") - if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples(): - raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container") + if not isinstance(forest_index, int) and not isinstance( + constant_value, 
(int, float)
+        ):
+            raise ValueError(
+                "forest_index must be an integer and constant_value must be a float or int"
+            )
+        if (
+            not forest_index >= 0
+            or not forest_index < self.forest_container_cpp.NumSamples()
+        ):
+            raise ValueError(
+                "forest_index must be >= 0 and less than the total number of samples in a forest container"
+            )
         self.forest_container_cpp.AddToForest(forest_index, constant_value)
 
-    def multiply_forest(
-        self, forest_index: int, constant_multiple : float
-    ) -> None:
+    def multiply_forest(self, forest_index: int, constant_multiple: float) -> None:
         """
         Multiply every leaf of every tree of a given forest by a constant value
 
         Parameters
         ----------
-        forest_index : int 
+        forest_index : int
             Index of forest whose leaves will be modified (0-indexed)
-        constant_multiple : float 
+        constant_multiple : float
             Value to multiply through by every leaf of every tree of the forest at `forest_index`
         """
-        if not isinstance(forest_index, int) and not isinstance(constant_multiple, (int, float)):
-            raise ValueError("forest_index must be an integer and constant_multiple must be a float or int")
-        if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples():
-            raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container")
+        if not isinstance(forest_index, int) and not isinstance(
+            constant_multiple, (int, float)
+        ):
+            raise ValueError(
+                "forest_index must be an integer and constant_multiple must be a float or int"
+            )
+        if (
+            not forest_index >= 0
+            or not forest_index < self.forest_container_cpp.NumSamples()
+        ):
+            raise ValueError(
+                "forest_index must be >= 0 and less than the total number of samples in a forest container"
+            )
         self.forest_container_cpp.MultiplyForest(forest_index, constant_multiple)
 
     def save_to_json_file(self, json_filename: str) -> None:
@@ -1021,7 +1040,7 @@ def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None:
         else:
             self.forest_cpp.SetRootValue(leaf_value)
         self.internal_forest_is_empty = False
-    
+
     def merge_forest(self, other_forest):
         """
         Create a larger forest by merging the trees of this forest with those of another forest
@@ -1034,13 +1053,19 @@ def merge_forest(self, other_forest):
         if not isinstance(other_forest, Forest):
             raise ValueError("other_forest must be an instance of the Forest class")
         if self.leaf_constant != other_forest.leaf_constant:
-            raise ValueError("Forests must have matching leaf dimensions in order to be merged")
+            raise ValueError(
+                "Forests must have matching leaf dimensions in order to be merged"
+            )
         if self.output_dimension != other_forest.output_dimension:
-            raise ValueError("Forests must have matching leaf dimensions in order to be merged")
+            raise ValueError(
+                "Forests must have matching leaf dimensions in order to be merged"
+            )
         if self.is_exponentiated != other_forest.is_exponentiated:
-            raise ValueError("Forests must have matching leaf dimensions in order to be merged")
+            raise ValueError(
+                "Forests must have matching exponentiation status in order to be merged"
+            )
         self.forest_cpp.MergeForest(other_forest.forest_cpp)
-    
+
     def add_constant(self, constant_value):
         """
         Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves.
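 
         A minimal usage sketch (hypothetical values; assumes an existing `Forest` instance `forest`):
 
         >>> forest.add_constant(0.5)       # shift every leaf value up by 0.5
         >>> forest.multiply_constant(2.0)  # rescale every leaf value by a factor of 2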
@@ -1051,7 +1076,7 @@ def add_constant(self, constant_value):
             Value that will be added to every leaf of every tree
         """
         self.forest_cpp.AddConstant(constant_value)
-    
+
     def multiply_constant(self, constant_multiple):
         """
         Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves.
diff --git a/stochtree/kernel.py b/stochtree/kernel.py
index 86d9d8a5..58986f7f 100644
--- a/stochtree/kernel.py
+++ b/stochtree/kernel.py
@@ -2,20 +2,29 @@
 import pandas as pd
 import numpy as np
 
-from stochtree_cpp import cppComputeForestContainerLeafIndices, cppComputeForestMaxLeafIndex
+from stochtree_cpp import (
+    cppComputeForestContainerLeafIndices,
+    cppComputeForestMaxLeafIndex,
+)
 
 from .bart import BARTModel
 from .bcf import BCFModel
 from .forest import ForestContainer
 
 
-def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, propensity: np.array = None, forest_inds: Union[int, np.ndarray] = None):
+def compute_forest_leaf_indices(
+    model_object: Union[BARTModel, BCFModel, ForestContainer],
+    covariates: Union[np.array, pd.DataFrame],
+    forest_type: str = None,
+    propensity: np.array = None,
+    forest_inds: Union[int, np.ndarray] = None,
+):
     """
     Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset.
 
-    The vector has a "row-major" format that can be easily re-represented as as a CSR sparse matrix: elements are organized so that the first `n` elements
-    correspond to leaf predictions for all `n` observations in a dataset for the first tree in an ensemble, the next `n` elements correspond to predictions for
-    the second tree and so on. The "data" for each element corresponds to a uniquely mapped column index that corresponds to a single leaf of a single tree (i.e.
+    The vector has a "row-major" format that can be easily re-represented as a CSR sparse matrix: elements are organized so that the first `n` elements
+    correspond to leaf predictions for all `n` observations in a dataset for the first tree in an ensemble, the next `n` elements correspond to predictions for
+    the second tree and so on. The "data" for each element corresponds to a uniquely mapped column index that corresponds to a single leaf of a single tree (i.e.
     if tree 1 has 3 leaves, its column indices range from 0 to 2, and then tree 2's leaf indices begin at 3, etc...).
 
     Parameters
@@ -36,38 +45,52 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
         * `'variance'`: Extracts leaf indices for the variance forest
     * **ForestContainer**
         * `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this
-    
+
    propensity : `np.array`, optional
        Optional test set propensities. Must be provided if propensities were provided when the model was sampled.
    forest_inds : int or np.ndarray
-        Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest.
+        Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest.
+        This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
-    
+
     Returns
     -------
-    Numpy array with dimensions `num_obs` by `num_trees`, where `num_obs` is the number of rows in `covaritates` and `num_trees` is the number of trees in the relevant forest of `model_object`.
+    Numpy array with dimensions `num_obs` by `num_trees`, where `num_obs` is the number of rows in `covariates` and `num_trees` is the number of trees in the relevant forest of `model_object`.
     """
     # Extract relevant forest container
-    if not isinstance(model_object, BARTModel) and not isinstance(model_object, BCFModel) and not isinstance(model_object, ForestContainer):
-        raise ValueError("model_object must be one of BARTModel, BCFModel, or ForestContainer")
+    if (
+        not isinstance(model_object, BARTModel)
+        and not isinstance(model_object, BCFModel)
+        and not isinstance(model_object, ForestContainer)
+    ):
+        raise ValueError(
+            "model_object must be one of BARTModel, BCFModel, or ForestContainer"
+        )
     if isinstance(model_object, BARTModel):
         model_type = "bart"
         if forest_type is None:
-            raise ValueError("forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')")
+            raise ValueError(
+                "forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')"
+            )
     elif isinstance(model_object, BCFModel):
         model_type = "bcf"
         if forest_type is None:
-            raise ValueError("forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')")
+            raise ValueError(
+                "forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')"
+            )
     else:
         model_type = "forest"
     if model_type == "bart":
         if forest_type == "mean":
             if not model_object.include_mean_forest:
-                raise ValueError("Mean forest was not sampled for model_object, but requested by forest_type")
+                raise ValueError(
+                    "Mean forest was not sampled for model_object, but requested by forest_type"
+                )
             forest_container = model_object.forest_container_mean
         else:
             if not model_object.include_variance_forest:
-                raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type")
+                raise ValueError(
+                    "Variance forest was not sampled for model_object, but requested by forest_type"
+                )
             forest_container = model_object.forest_container_variance
     elif model_type == "bcf":
         if forest_type == "prognostic":
@@ -76,17 +99,23 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
             forest_container = model_object.forest_container_tau
         else:
             if not model_object.include_variance_forest:
-                raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type")
+                raise ValueError(
+                    "Variance forest was not sampled for model_object, but requested by forest_type"
+                )
             forest_container = model_object.forest_container_variance
     else:
         forest_container = model_object
-    
-    if not isinstance(covariates, pd.DataFrame) and not isinstance(covariates, np.ndarray):
+
+    if not isinstance(covariates, pd.DataFrame) and not isinstance(
+        covariates, np.ndarray
+    ):
         raise ValueError("covariates must be a matrix or dataframe")
-    
+
     # Preprocess covariates
     if model_type == "bart" or model_type == "bcf":
-        covariates_processed = model_object._covariate_preprocessor.transform(covariates)
+        covariates_processed = model_object._covariate_preprocessor.transform(
+            covariates
+        )
     else:
         covariates_processed = covariates
     covariates_processed = np.asfortranarray(covariates_processed)
@@ -100,17 +129,21 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, 
BCFModel, ForestC "Propensity scores not provided, but no propensity model was trained during sampling" ) propensity = np.mean( - model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True + model_object.bart_propensity_model.predict(covariates), + axis=1, + keepdims=True, ) covariates_processed = np.c_[covariates_processed, propensity] - + # Preprocess forest indices num_forests = forest_container.num_samples() if forest_inds is None: forest_inds = np.arange(num_forests) elif isinstance(forest_inds, int): if not forest_inds >= 0 or not forest_inds < num_forests: - raise ValueError("The index in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The index in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) forest_inds = np.array([forest_inds]) elif isinstance(forest_inds, np.ndarray): if forest_inds.size > 1: @@ -118,13 +151,22 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC if forest_inds.ndim > 1: raise ValueError("forest_inds must be a one-dimensional numpy array") if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests): - raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The indices in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) else: raise ValueError("forest_inds must be a one-dimensional numpy array") - - return cppComputeForestContainerLeafIndices(forest_container.forest_container_cpp, covariates_processed, forest_inds) -def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, ForestContainer], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None): + return cppComputeForestContainerLeafIndices( + forest_container.forest_container_cpp, covariates_processed, forest_inds + ) + + +def compute_forest_max_leaf_index( + model_object: Union[BARTModel, BCFModel, ForestContainer], + forest_type: str = None, + forest_inds: Union[int, np.ndarray] = None, +): """ Compute and return the largest possible leaf index computable by `compute_forest_leaf_indices` for the forests in a designated forest sample container. @@ -144,36 +186,50 @@ def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, Fores * `'variance'`: Extracts leaf indices for the variance forest * **ForestContainer** * `None`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this parameter. - + forest_inds : int or np.ndarray - Indices of the forest sample(s) for which to compute max leaf indices. If not provided, this function will return max leaf indices for every sample of a forest. + Indices of the forest sample(s) for which to compute max leaf indices. If not provided, this function will return max leaf indices for every sample of a forest. This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. - + Returns ------- Numpy array containing the largest possible leaf index computable by `compute_forest_leaf_indices` for the forests in a designated forest sample container. 
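Examples
--------
A minimal sketch of assembling the output of `compute_forest_leaf_indices` into a one-hot CSR basis matrix. It assumes `scipy` is available and that `bart_model` and `X` are stand-ins (defined elsewhere) for a fitted `BARTModel` with a mean forest and its covariate matrix:

>>> import numpy as np
>>> from scipy.sparse import csr_matrix
>>> # Leaf indices for the first retained forest sample (documented shape: num_obs by num_trees)
>>> leaf_inds = compute_forest_leaf_indices(
...     bart_model, covariates=X, forest_type="mean", forest_inds=0
... )
>>> max_ind = compute_forest_max_leaf_index(bart_model, forest_type="mean", forest_inds=0)
>>> n, num_trees = leaf_inds.shape
>>> # One nonzero entry per (observation, tree) pair; leaf indices are already offset across trees
>>> basis = csr_matrix(
...     (np.ones(n * num_trees), leaf_inds.ravel(), np.arange(0, n * num_trees + 1, num_trees)),
...     shape=(n, int(max_ind[0]) + 1),
... )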
""" # Extract relevant forest container - if not isinstance(model_object, BARTModel) and not isinstance(model_object, BCFModel) and not isinstance(model_object, ForestContainer): - raise ValueError("model_object must be one of BARTModel, BCFModel, or ForestContainer") + if ( + not isinstance(model_object, BARTModel) + and not isinstance(model_object, BCFModel) + and not isinstance(model_object, ForestContainer) + ): + raise ValueError( + "model_object must be one of BARTModel, BCFModel, or ForestContainer" + ) if isinstance(model_object, BARTModel): model_type = "bart" if forest_type is None: - raise ValueError("forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')") + raise ValueError( + "forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')" + ) elif isinstance(model_object, BCFModel): model_type = "bcf" if forest_type is None: - raise ValueError("forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')") + raise ValueError( + "forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')" + ) else: model_type = "forest" if model_type == "bart": if forest_type == "mean": if not model_object.include_mean_forest: - raise ValueError("Mean forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Mean forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_mean else: if not model_object.include_variance_forest: - raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Variance forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_variance elif model_type == "bcf": if forest_type == "prognostic": @@ -182,18 +238,22 @@ def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, Fores forest_container = model_object.forest_container_tau else: if not model_object.include_variance_forest: - raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Variance forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_variance else: forest_container = model_object - + # Preprocess forest indices num_forests = forest_container.num_samples() if forest_inds is None: forest_inds = np.arange(num_forests) elif isinstance(forest_inds, int): if not forest_inds >= 0 or not forest_inds < num_forests: - raise ValueError("The index in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The index in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) forest_inds = np.array([forest_inds]) elif isinstance(forest_inds, np.ndarray): if forest_inds.size > 1: @@ -201,16 +261,19 @@ def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, Fores if forest_inds.ndim > 1: raise ValueError("forest_inds must be a one-dimensional numpy array") if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests): - raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The indices in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) else: 
raise ValueError("forest_inds must be a one-dimensional numpy array") - + # Compute max index output_size = len(forest_inds) output = np.empty(output_size) for i in np.arange(output_size): - output[i] = cppComputeForestMaxLeafIndex(forest_container.forest_container_cpp, forest_inds[i]) - - # Return result - return output + output[i] = cppComputeForestMaxLeafIndex( + forest_container.forest_container_cpp, forest_inds[i] + ) + # Return result + return output diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 6c4093d4..f3e1a187 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -110,10 +110,12 @@ def add_variance_weights(self, variance_weights: np.array): ) n = variance_weights_.shape[0] self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n) - - def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): + + def update_variance_weights( + self, variance_weights: np.array, exponentiate: bool = False + ): """ - Update variance weights in a dataset. Allows users to build an ensemble that depends on + Update variance weights in a dataset. Allows users to build an ensemble that depends on variance weights that are updated throughout the sampler. Parameters @@ -124,7 +126,9 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool Whether to exponentiate the variance weights before storing them in the dataset. """ if not self.has_variance_weights(): - raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + raise ValueError( + "This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector." + ) if not isinstance(variance_weights, np.ndarray): raise ValueError("variance_weights must be a numpy array.") variance_weights_ = np.squeeze(variance_weights) @@ -134,9 +138,11 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool ) n = variance_weights_.shape[0] if self.num_observations() != n: - raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + raise ValueError( + f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()})." + ) self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n, exponentiate) - + def get_group_labels(self) -> np.array: """ Return the group labels in a RandomEffectsDataset as a numpy array @@ -147,7 +153,7 @@ def get_group_labels(self) -> np.array: One-dimensional numpy array of group labels. """ return self.rfx_dataset_cpp.GetGroupLabels() - + def get_basis(self) -> np.array: """ Return the bases in a RandomEffectsDataset as a numpy array @@ -158,7 +164,7 @@ def get_basis(self) -> np.array: Two-dimensional numpy array of basis vectors. """ return self.rfx_dataset_cpp.GetBasis() - + def get_variance_weights(self) -> np.array: """ Return the variance weights in a RandomEffectsDataset as a numpy array @@ -361,6 +367,36 @@ def predict(self, group_labels: np.array, basis: np.array) -> np.ndarray: return self.rfx_container_cpp.Predict( rfx_dataset.rfx_dataset_cpp, self.rfx_label_mapper_cpp ) + + def extract_parameter_samples(self) -> dict[str, np.ndarray]: + """ + Extract the random effects parameters sampled. 
With the "redundant parameterization" of Gelman et al (2008), + this includes four parameters: alpha (the "working parameter" shared across every group), xi + (the "group parameter" sampled separately for each group), beta (the product of alpha and xi, + which corresponds to the overall group-level random effects), and sigma (group-independent prior + variance for each component of xi). + + Returns + ------- + dict[str, np.ndarray] + dict of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. + The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`. + The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. + """ + # num_samples = self.rfx_container_cpp.NumSamples() + # num_components = self.rfx_container_cpp.NumComponents() + # num_groups = self.rfx_container_cpp.NumGroups() + beta_samples = np.squeeze(self.rfx_container_cpp.GetBeta()) + xi_samples = np.squeeze(self.rfx_container_cpp.GetXi()) + alpha_samples = np.squeeze(self.rfx_container_cpp.GetAlpha()) + sigma_samples = np.squeeze(self.rfx_container_cpp.GetSigma()) + output = { + "beta_samples": beta_samples, + "xi_samples": xi_samples, + "alpha_samples": alpha_samples, + "sigma_samples": sigma_samples + } + return output class RandomEffectsModel: @@ -420,6 +456,28 @@ def sample( rng.rng_cpp, ) + def predict( + self, rfx_dataset: RandomEffectsDataset, rfx_tracker: RandomEffectsTracker + ) -> np.ndarray: + """ + Predict random effects for each observation in `rfx_dataset` + + Parameters + ---------- + rfx_dataset: RandomEffectsDataset + Object of type `RandomEffectsDataset` + rfx_tracker: RandomEffectsTracker + Object of type `RandomEffectsTracker` + + Returns + ------- + np.ndarray + Numpy array with as many rows as observations in `rfx_dataset` and as many columns as samples in the container + """ + return self.rfx_model_cpp.Predict( + rfx_dataset.rfx_dataset_cpp, rfx_tracker.rfx_tracker_cpp + ) + def set_working_parameter(self, working_parameter: np.ndarray) -> None: """ Set values for the "working parameter." This is typically used for initialization, diff --git a/stochtree/sampler.py b/stochtree/sampler.py index cbab9ce6..cc5ad54a 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -151,7 +151,7 @@ def sample_one_iteration( ) if self.forest_sampler_cpp.GetMaxDepth() != forest_config.get_max_depth(): self.forest_sampler_cpp.SetMaxDepth(forest_config.get_max_depth()) - + # Unpack sweep update indices (initializing empty numpy array if None) sweep_update_indices = forest_config.get_sweep_update_indices() if sweep_update_indices is None: @@ -176,7 +176,7 @@ def sample_one_iteration( forest_config.get_num_features_subsample(), keep_forest, gfr, - num_threads, + num_threads, ) def prepare_for_sampler( @@ -270,7 +270,7 @@ def propagate_basis_update( self.forest_sampler_cpp.PropagateBasisUpdate( dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp ) - + def get_cached_forest_predictions(self) -> np.array: """ Extract an internally-cached prediction of a forest on the training dataset in a sampler. 
diff --git a/stochtree/serialization.py b/stochtree/serialization.py index 844ee54d..64f0d842 100644 --- a/stochtree/serialization.py +++ b/stochtree/serialization.py @@ -62,7 +62,9 @@ def add_random_effects(self, rfx_container: RandomEffectsContainer) -> None: Samples of a random effects model """ _ = self.json_cpp.AddRandomEffectsContainer(rfx_container.rfx_container_cpp) - _ = self.json_cpp.AddRandomEffectsLabelMapper(rfx_container.rfx_label_mapper_cpp) + _ = self.json_cpp.AddRandomEffectsLabelMapper( + rfx_container.rfx_label_mapper_cpp + ) _ = self.json_cpp.AddRandomEffectsGroupIDs(rfx_container.rfx_group_ids) self.json_cpp.IncrementRandomEffectsCount() self.num_rfx += 1 diff --git a/stochtree/utils.py b/stochtree/utils.py index da25cb9c..214beb2f 100644 --- a/stochtree/utils.py +++ b/stochtree/utils.py @@ -1,4 +1,5 @@ -from typing import Union, Optional +from typing import Union +import math import numpy as np @@ -190,17 +191,17 @@ def _check_matrix_square(input: np.ndarray) -> bool: def _expand_dims_1d(input: Union[int, float, np.array], output_size: int) -> np.array: """ - Convert scalar input to 1D numpy array of dimension `output_size`, + Convert scalar input to 1D numpy array of dimension `output_size`, or check that input array is equivalent to a 1D array of dimension `output_size`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars. - + Parameters ---------- input : int, float, np.array Input to be converted to a 1D array (or passed through as-is) output_size : int Intended size of the output vector - + Returns ------- np.array @@ -209,30 +210,38 @@ def _expand_dims_1d(input: Union[int, float, np.array], output_size: int) -> np. if isinstance(input, np.ndarray): input = np.squeeze(input) if input.ndim > 1: - raise ValueError("`input` must be convertible to a 1D numpy array or scalar") + raise ValueError( + "`input` must be convertible to a 1D numpy array or scalar" + ) if input.ndim == 0: output = np.repeat(input, output_size) else: if input.shape[0] != output_size: - raise ValueError("`input` must be a 1D numpy array with `output_size` elements") + raise ValueError( + "`input` must be a 1D numpy array with `output_size` elements" + ) output = input elif isinstance(input, (int, float)): output = np.repeat(input, output_size) else: - raise ValueError("`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times") + raise ValueError( + "`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times" + ) return output -def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output_cols: int) -> np.array: +def _expand_dims_2d( + input: Union[int, float, np.array], output_rows: int, output_cols: int +) -> np.array: """ - Ensures that input is propagated appropriately to a 2D numpy array of dimension `output_rows` x `output_cols`. + Ensures that input is propagated appropriately to a 2D numpy array of dimension `output_rows` x `output_cols`. Handles the following cases: 1. `input` is a scalar: output is simply a (`output_rows`, `output_cols`) array with `input` repeated for each element 2. `input` is a 1D array of length `output_rows`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_cols` columns 3. `input` is a 1D array of length `output_cols`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_rows` rows 4. 
`input` is a 2D array of dimension (`output_rows`, `output_cols`): input is passed through as-is All other cases raise a `ValueError`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars. - + Parameters ---------- input : int, float, np.array @@ -241,7 +250,7 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output Intended number of rows in the output array output_cols : int Intended number of columns in the output array - + Returns ------- np.array @@ -253,9 +262,13 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output raise ValueError("`input` must be a 1D or 2D numpy array") elif input.ndim == 2: if input.shape[0] != output_rows: - raise ValueError("If `input` is passed as a 2D numpy array, it must contain `output_rows` rows") + raise ValueError( + "If `input` is passed as a 2D numpy array, it must contain `output_rows` rows" + ) if input.shape[1] != output_cols: - raise ValueError("If `input` is passed as a 2D numpy array, it must contain `output_cols` columns") + raise ValueError( + "If `input` is passed as a 2D numpy array, it must contain `output_cols` columns" + ) output = input elif input.ndim == 1: if input.shape[0] == output_cols: @@ -263,7 +276,9 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output elif input.shape[0] == output_rows: output = np.tile(input, (output_cols, 1)).T else: - raise ValueError("If `input` is a 1D numpy array, it must either contain `output_rows` or `output_cols` elements") + raise ValueError( + "If `input` is a 1D numpy array, it must either contain `output_rows` or `output_cols` elements" + ) elif input.ndim == 0: output = np.tile(input, (output_rows, output_cols)) elif isinstance(input, (int, float)): @@ -273,19 +288,21 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output return output -def _expand_dims_2d_diag(input: Union[int, float, np.array], output_size: int) -> np.array: +def _expand_dims_2d_diag( + input: Union[int, float, np.array], output_size: int +) -> np.array: """ - Convert scalar input to 2D square numpy array of dimension `output_size` x `output_size` with `input` along the diagonal, + Convert scalar input to 2D square numpy array of dimension `output_size` x `output_size` with `input` along the diagonal, or check that input array is equivalent to a 2D square array of dimension `output_size` x `output_size`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars. 
- + Parameters ---------- input : int, float, np.array Input to be converted to a 2D square array (or passed through as-is) output_size : int Intended row and column dimension of the square output matrix - + Returns ------- np.array @@ -294,23 +311,57 @@ def _expand_dims_2d_diag(input: Union[int, float, np.array], output_size: int) - if isinstance(input, np.ndarray): input = np.squeeze(input) if (input.ndim != 2) and (input.ndim != 0): - raise ValueError("`input` must be convertible to a 2D numpy array or scalar") - if input.ndim == 0: - output = np.zeros( - (output_size, output_size), dtype=float + raise ValueError( + "`input` must be convertible to a 2D numpy array or scalar" ) + if input.ndim == 0: + output = np.zeros((output_size, output_size), dtype=float) np.fill_diagonal(output, input) else: if input.shape[0] != input.shape[1]: raise ValueError("`input` must be a 2D square numpy array") if input.shape[0] != output_size: - raise ValueError("`input` must be a 2D square numpy array with exactly `output_size` rows and columns") + raise ValueError( + "`input` must be a 2D square numpy array with exactly `output_size` rows and columns" + ) output = input elif isinstance(input, (int, float)): - output = np.zeros( - (output_size, output_size), dtype=float - ) + output = np.zeros((output_size, output_size), dtype=float) np.fill_diagonal(output, input) else: - raise ValueError("`input` must be either a 2D square numpy array or a scalar that can be propagated along the diagonal of a square matrix") + raise ValueError( + "`input` must be either a 2D square numpy array or a scalar that can be propagated along the diagonal of a square matrix" + ) return output + + +def _posterior_predictive_heuristic_multiplier(num_samples: int, num_observations: int) -> int: + """Heuristic multiplier for the number of posterior predictive draws per retained posterior sample, chosen so that at least 1000 total draws are generated; the multiplier currently depends only on `num_samples`.""" + if num_samples >= 1000: + return 1 + else: + return math.ceil(1000 / num_samples) + + +def _summarize_interval(array: np.ndarray, sample_dim: int = 2, level: float = 0.95) -> dict: + """Compute pointwise lower and upper quantiles of the posterior draws in `array`, treating dimension `sample_dim` as the sample dimension, and return a centered interval at coverage `level` as a dict with "lower" and "upper" keys.""" + # Check that the array is numeric and at least 2 dimensional + if not isinstance(array, np.ndarray): + raise ValueError("`array` must be a numpy array") + if not _check_array_numeric(array): + raise ValueError("`array` must be a numeric numpy array") + if not len(array.shape) >= 2: + raise ValueError("`array` must be at least a 2-dimensional numpy array") + if not _check_is_int(sample_dim) or (sample_dim < 0) or (sample_dim >= len(array.shape)): + raise ValueError( + "`sample_dim` must be an integer between 0 and the number of dimensions of `array` - 1" + ) + if not isinstance(level, float) or (level <= 0) or (level >= 1): + raise ValueError("`level` must be a float between 0 and 1") + + # Compute lower and upper quantiles based on the requested interval + quantile_lb = (1 - level) / 2 + quantile_ub = 1 - quantile_lb + + # Calculate the interval + result_lb = np.quantile(array, q=quantile_lb, axis=sample_dim) + result_ub = np.quantile(array, q=quantile_ub, axis=sample_dim) + + # Return results as a dictionary + return {"lower": result_lb, "upper": result_ub} diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 748ff96c..23013ec2 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -1,433 +1,594 @@ test_that("MCMC BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) 
- noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) - ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # Generate simulated data with a leaf basis - n <- 100 - p <- 5 - p_w <- 2 - X <- matrix(runif(n*p), ncol = p) - W <- matrix(runif(n*p_w), ncol = p_w) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + 
general_params = general_param_list ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - W_test <- W[test_inds,] - W_train <- W[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 3 chains, thinning, leaf regression - general_param_list <- list(num_chains = 3, keep_every = 5) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Generate simulated data with a leaf basis + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + W_train <- W[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 3 chains, thinning, leaf regression + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - - # 3 chains, thinning, leaf regression with a scalar leaf scale - general_param_list <- list(num_chains = 3, keep_every = 5) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE, sigma2_leaf_init = 0.5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # 3 chains, thinning, leaf regression with a scalar leaf scale + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list( + sample_sigma2_leaf = FALSE, + sigma2_leaf_init = 0.5 + ) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - - # 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale - general_param_list <- list(num_chains = 3, keep_every = 5) - mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5) - expect_warning( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = 
W_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5) + expect_warning( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) + ) }) test_that("GFR BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = 
y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 1) - expect_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 5) - expect_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) + ) }) test_that("Warmstart BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BART model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 0, - general_params = general_param_list) - - # Save to JSON string - bart_model_json_string <- saveBARTModelToJsonString(bart_model) - - # Run a new BART chain from the existing (X)BART model - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - previous_model_json = bart_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) - - ) - - # Generate simulated data with random effects - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & 
(0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list ) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- rep(1, n) - rfx_coefs <- c(-5, 5) - rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis - noise_sd <- 1 - y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds] - rfx_basis_train <- rfx_basis[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BART model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 0, - general_params = general_param_list) - - # Save to JSON string - bart_model_json_string <- saveBARTModelToJsonString(bart_model) - - # Run a new BART chain from the existing (X)BART model - general_param_list <- list(num_chains = 4, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - previous_model_json = bart_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 
1]) & (1 > X[, 1])) * (7.5)) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 4, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list ) + ) }) test_that("BART Predictions", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BART model with only GFR - general_params <- list(num_chains = 1) - variance_forest_params <- list(num_trees = 50) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_params, - variance_forest_params = variance_forest_params) - - # Check that cached predictions agree with results of predict() function - train_preds <- predict(bart_model, X = X_train) - train_preds_mean_cached <- bart_model$y_hat_train - train_preds_mean_recomputed <- train_preds$mean_forest_predictions - train_preds_variance_cached <- bart_model$sigma2_x_hat_train - train_preds_variance_recomputed <- train_preds$variance_forest_predictions - - # Assertion - expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) - expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * 
(-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_params <- list(num_chains = 1) + variance_forest_params <- list(num_trees = 50) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_params, + variance_forest_params = variance_forest_params + ) + + # Check that cached predictions agree with results of predict() function + train_preds <- predict(bart_model, covariates = X_train) + train_preds_mean_cached <- bart_model$y_hat_train + train_preds_mean_recomputed <- train_preds$mean_forest_predictions + train_preds_variance_cached <- bart_model$sigma2_x_hat_train + train_preds_variance_recomputed <- train_preds$variance_forest_predictions + + # Assertion + expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) + expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) }) test_that("Random Effects BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - p_w <- 2 - X <- matrix(runif(n*p), ncol = p) - W <- matrix(runif(n*p_w), ncol = p_w) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + W_train <- W[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx parameters + general_param_list <- list() + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + 
rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- cbind(rep(1, n), runif(n)) - num_rfx_components <- ncol(rfx_basis) - num_rfx_groups <- length(unique(rfx_group_ids)) - rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) - rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) - noise_sd <- 1 - y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - W_test <- W[test_inds,] - W_train <- W[train_inds,] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds,] - rfx_basis_train <- rfx_basis[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Specify no rfx parameters - general_param_list <- list() - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Specify all rfx parameters as scalars + rfx_param_list <- list( + working_parameter_prior_mean = 1., + group_parameter_prior_mean = 1., + working_parameter_prior_cov = 1., + group_parameter_prior_cov = 1., + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list ) - - # Specify all rfx parameters as scalars - general_param_list <- list(rfx_working_parameter_prior_mean = 1., - rfx_group_parameter_prior_mean = 1., - rfx_working_parameter_prior_cov = 1., - rfx_group_parameter_prior_cov = 1., - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Specify all relevant rfx parameters as vectors + rfx_param_list <- list( + working_parameter_prior_mean = c(1., 1.), + group_parameter_prior_mean = c(1., 1.), + working_parameter_prior_cov = diag(1., 2), + group_parameter_prior_cov = diag(1., 2), + variance_prior_shape = 1, + variance_prior_scale = 1 + ) 
+ mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list ) - - # Specify all relevant rfx parameters as vectors - general_param_list <- list(rfx_working_parameter_prior_mean = c(1.,1.), - rfx_group_parameter_prior_mean = c(1.,1.), - rfx_working_parameter_prior_cov = diag(1.,2), - rfx_group_parameter_prior_cov = diag(1.,2), - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Specify simpler intercept-only RFX model + rfx_param_list <- list( + model_spec = "intercept_only" + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list ) -}) \ No newline at end of file + preds <- predict( + bart_model, + covariates = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) + }) +}) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 531320f6..221c333f 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -1,607 +1,785 @@ test_that("MCMC BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - 
tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 1 chain, no thinning, matrix leaf scale parameter provided - general_param_list <- list(num_chains = 1, keep_every = 1) - mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) - tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list, - prognostic_forest_params = mu_forest_param_list, - treatment_effect_forest_params = tau_forest_param_list) - ) - - # 1 chain, no thinning, scalar leaf scale parameter provided - general_param_list <- list(num_chains = 1, keep_every = 1) - mu_forest_param_list <- list(sigma2_leaf_init = 0.5) - tau_forest_param_list <- list(sigma2_leaf_init = 0.5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list, - prognostic_forest_params = mu_forest_param_list, - treatment_effect_forest_params = tau_forest_param_list) - ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * + (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y 
<- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 1 chain, no thinning, matrix leaf scale parameter provided + general_param_list <- list(num_chains = 1, keep_every = 1) + mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) + tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + prognostic_forest_params = mu_forest_param_list, + treatment_effect_forest_params = tau_forest_param_list + ) + ) + + # 1 chain, no thinning, scalar leaf scale parameter provided + general_param_list <- list(num_chains = 1, keep_every = 1) + mu_forest_param_list <- list(sigma2_leaf_init = 0.5) + tau_forest_param_list <- list(sigma2_leaf_init = 0.5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + prognostic_forest_params = mu_forest_param_list, + treatment_effect_forest_params = tau_forest_param_list + ) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) }) test_that("GFR BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X 
<- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 1) - expect_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 5) - expect_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + mu_X <- 
(((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * + (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = 
general_param_list + ) + ) }) test_that("Warmstart BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BCF model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 0, - num_mcmc = 0, general_params = general_param_list) - - # Save to JSON string - bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) - - # Run a new BCF chain from the existing (X)BCF model - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, previous_model_json = bcf_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) - ) - - # Generate simulated data with random effects - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- rep(1, n) - rfx_coefs <- c(-5, 5) - rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis - noise_sd <- 1 - y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - 
pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds] - rfx_basis_train <- rfx_basis[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BCF model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 0, - num_mcmc = 0, general_params = general_param_list) - - # Save to JSON string - bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) - - # Run a new BCF chain from the existing (X)BCF model - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, previous_model_json = bcf_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) - ) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, 
keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list + ) + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- mu_X + tau_X * Z + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list + ) + ) }) test_that("Multivariate Treatment MCMC BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) 
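[Editor's note: the warm-start workflow exercised by the test above can be distilled into a short standalone sketch. It uses only the interface visible in this diff (bcf(), saveBCFModelToJsonString(), previous_model_json, previous_model_warmstart_sample_num); the data-generating process is a toy stand-in, not the one from the tests.]

# Sketch: warm-starting MCMC chains from a GFR-only BCF fit (assumed
# interface, based on the calls in the surrounding tests)
library(stochtree)
n <- 100
X <- matrix(runif(n * 5), ncol = 5)
pi_X <- 0.2 + 0.6 * X[, 1]          # toy propensity in (0.2, 0.8)
Z <- rbinom(n, 1, pi_X)
y <- 5 * X[, 2] + 1.5 * Z + rnorm(n)

# Step 1: fit a grow-from-root (GFR) ensemble only, with no MCMC
xbcf_fit <- bcf(
  X_train = X, y_train = y, Z_train = Z, propensity_train = pi_X,
  num_gfr = 10, num_burnin = 0, num_mcmc = 0
)

# Step 2: serialize the GFR fit to a JSON string
xbcf_json <- saveBCFModelToJsonString(xbcf_fit)

# Step 3: run fresh MCMC chains initialized from a stored GFR sample
bcf_fit <- bcf(
  X_train = X, y_train = y, Z_train = Z, propensity_train = pi_X,
  num_gfr = 0, num_burnin = 10, num_mcmc = 10,
  previous_model_json = xbcf_json,
  previous_model_warmstart_sample_num = 1,
  general_params = list(num_chains = 3, keep_every = 5)
)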
- mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X_1 <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - pi_X_2 <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.8) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (0.4) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (0.6) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (0.2) - ) - pi_X <- cbind(pi_X_1, pi_X_2) - tau_X_1 <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - tau_X_2 <- ( - ((0 <= X[,3]) & (0.25 > X[,3])) * (-0.5) + - ((0.25 <= X[,3]) & (0.5 > X[,3])) * (-1.5) + - ((0.5 <= X[,3]) & (0.75 > X[,3])) * (-1.0) + - ((0.75 <= X[,3]) & (1 > X[,3])) * (0.0) - ) - tau_X <- cbind(tau_X_1, tau_X_2) - Z_1 <- as.numeric(rbinom(n, 1, pi_X_1)) - Z_2 <- as.numeric(rbinom(n, 1, pi_X_2)) - Z <- cbind(Z_1, Z_2) - noise_sd <- 1 - y <- mu_X + rowSums(tau_X*Z) + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds,] - Z_train <- Z[train_inds,] - pi_test <- pi_X[test_inds,] - pi_train <- pi_X[train_inds,] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds,] - tau_train <- tau_X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1, adaptive_coding = F) - expect_no_error({ - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test) - }) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X_1 <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + pi_X_2 <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.8) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (0.4) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (0.6) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (0.2)) + pi_X <- cbind(pi_X_1, pi_X_2) + # fmt: skip + tau_X_1 <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + # fmt: skip + tau_X_2 <- (((0 <= X[, 3]) & (0.25 > X[, 3])) * (-0.5) + + ((0.25 <= X[, 3]) & (0.5 > X[, 3])) * (-1.5) + + ((0.5 <= X[, 3]) & (0.75 > X[, 3])) * (-1.0) + + ((0.75 <= X[, 3]) & (1 > X[, 3])) * (0.0)) + tau_X <- cbind(tau_X_1, tau_X_2) + Z_1 <- as.numeric(rbinom(n, 1, pi_X_1)) + Z_2 <- as.numeric(rbinom(n, 1, pi_X_2)) + Z <- 
cbind(Z_1, Z_2)
+  noise_sd <- 1
+  y <- mu_X + rowSums(tau_X * Z) + rnorm(n, 0, noise_sd)
+  test_set_pct <- 0.2
+  n_test <- round(test_set_pct * n)
+  n_train <- n - n_test
+  test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+  train_inds <- (1:n)[!((1:n) %in% test_inds)]
+  X_test <- X[test_inds, ]
+  X_train <- X[train_inds, ]
+  Z_test <- Z[test_inds, ]
+  Z_train <- Z[train_inds, ]
+  pi_test <- pi_X[test_inds, ]
+  pi_train <- pi_X[train_inds, ]
+  mu_test <- mu_X[test_inds]
+  mu_train <- mu_X[train_inds]
+  tau_test <- tau_X[test_inds, ]
+  tau_train <- tau_X[train_inds, ]
+  y_test <- y[test_inds]
+  y_train <- y[train_inds]
+
+  # 1 chain, no thinning
+  general_param_list <- list(
+    num_chains = 1,
+    keep_every = 1,
+    adaptive_coding = FALSE
+  )
+  expect_no_error({
+    bcf_model <- bcf(
+      X_train = X_train,
+      y_train = y_train,
+      Z_train = Z_train,
+      propensity_train = pi_train,
+      X_test = X_test,
+      Z_test = Z_test,
+      propensity_test = pi_test,
+      num_gfr = 0,
+      num_burnin = 10,
+      num_mcmc = 10,
+      general_params = general_param_list
+    )
+    predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test)
+  })
 })
 
 test_that("BCF Predictions", {
-    skip_on_cran()
-
-    # Generate simulated data
-    n <- 100
-    p <- 5
-    X <- matrix(runif(n*p), ncol = p)
-    mu_X <- (
-        ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
-        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
-        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
-        ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
-    )
-    pi_X <- (
-        ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
-        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
-        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
-        ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
-    )
-    tau_X <- (
-        ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
-        ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
-        ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
-        ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
-    )
-    Z <- rbinom(n, 1, pi_X)
-    noise_sd <- 1
-    y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd)
-    test_set_pct <- 0.2
-    n_test <- round(test_set_pct*n)
-    n_train <- n - n_test
-    test_inds <- sort(sample(1:n, n_test, replace = FALSE))
-    train_inds <- (1:n)[!((1:n) %in% test_inds)]
-    X_test <- X[test_inds,]
-    X_train <- X[train_inds,]
-    Z_test <- Z[test_inds]
-    Z_train <- Z[train_inds]
-    pi_test <- pi_X[test_inds]
-    pi_train <- pi_X[train_inds]
-    mu_test <- mu_X[test_inds]
-    mu_train <- mu_X[train_inds]
-    tau_test <- tau_X[test_inds]
-    tau_train <- tau_X[train_inds]
-    y_test <- y[test_inds]
-    y_train <- y[train_inds]
-
-    # Run a BCF model with only GFR
-    general_params <- list(num_chains = 1, keep_every = 1)
-    variance_forest_params <- list(num_trees = 50)
-    bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
-                     propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
-                     propensity_test = pi_test, num_gfr = 10, num_burnin = 0,
-                     num_mcmc = 10, general_params = general_params,
-                     variance_forest_params = variance_forest_params)
-
-    # Check that cached predictions agree with results of predict() function
-    train_preds <- predict(bcf_model, X = X_train, Z = Z_train, propensity = pi_train)
-    train_preds_mean_cached <- bcf_model$y_hat_train
-    train_preds_mean_recomputed <- train_preds$y_hat
-    train_preds_variance_cached <- bcf_model$sigma2_x_hat_train
-    train_preds_variance_recomputed <- train_preds$variance_forest_predictions
-
-    # Assertion
-    expect_equal(train_preds_mean_cached, train_preds_mean_recomputed)
-    expect_equal(train_preds_variance_cached, train_preds_variance_recomputed)
+  skip_on_cran()
+
+  # Generate simulated data
+  n <- 100
+  p <- 5
+  X <- matrix(runif(n *
p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_params <- list(num_chains = 1, keep_every = 1) + variance_forest_params <- list(num_trees = 50) + expect_warning( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_params, + variance_forest_params = variance_forest_params + ) + ) + + # Check that cached predictions agree with results of predict() function + train_preds <- predict( + bcf_model, + X = X_train, + Z = Z_train, + propensity = pi_train + ) + train_preds_mean_cached <- bcf_model$y_hat_train + train_preds_mean_recomputed <- train_preds$y_hat + train_preds_variance_cached <- bcf_model$sigma2_x_hat_train + train_preds_variance_recomputed <- train_preds$variance_forest_predictions + + # Assertion + expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) + expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) }) test_that("Random Effects BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- cbind(rep(1, n), runif(n)) - num_rfx_components <- ncol(rfx_basis) - num_rfx_groups <- length(unique(rfx_group_ids)) - rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) - rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd) - - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - 
test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds,] - rfx_basis_train <- rfx_basis[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Specify no rfx parameters - general_param_list <- list() - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, - Z_train = Z_train, propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list) - ) + skip_on_cran() - # Specify all rfx parameters as scalars - general_param_list <- list(rfx_working_parameter_prior_mean = 1., - rfx_group_parameter_prior_mean = 1., - rfx_working_parameter_prior_cov = 1., - rfx_group_parameter_prior_cov = 1., - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, - Z_train = Z_train, propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list) - ) - - # Specify all relevant rfx parameters as vectors - general_param_list <- list(rfx_working_parameter_prior_mean = c(1.,1.), - rfx_group_parameter_prior_mean = c(1.,1.), - rfx_working_parameter_prior_cov = diag(1.,2), - rfx_group_parameter_prior_cov = diag(1.,2), - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, - Z_train = Z_train, propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list) - ) + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- 
length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rfx_term + rnorm(n, 0, noise_sd) + + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx parameters + general_param_list <- list() + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Specify all rfx parameters as scalars + rfx_param_list <- list( + working_parameter_prior_mean = 1., + group_parameter_prior_mean = 1., + working_parameter_prior_cov = 1., + group_parameter_prior_cov = 1., + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + random_effects_params = rfx_param_list + ) + ) + + # Specify all relevant rfx parameters as vectors + rfx_param_list <- list( + working_parameter_prior_mean = c(1., 1.), + group_parameter_prior_mean = c(1., 1.), + working_parameter_prior_cov = diag(1., 2), + group_parameter_prior_cov = diag(1., 2), + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + random_effects_params = rfx_param_list + ) + ) }) diff --git a/test/R/testthat/test-categorical.R b/test/R/testthat/test-categorical.R index 92d7459f..19a7468e 100644 --- a/test/R/testthat/test-categorical.R +++ b/test/R/testthat/test-categorical.R @@ -1,7 +1,8 @@ test_that("In-sample one-hot encoding works for unordered categorical variables", { - x1 <- c(3,2,1,4,3,2,3,2,4,2) - x1_onehot <- oneHotInitializeAndEncode(x1) - x1_expected <- matrix( + x1 <- c(3, 2, 1, 4, 3, 2, 3, 2, 4, 2) + x1_onehot <- oneHotInitializeAndEncode(x1) + # fmt: skip + x1_expected <- matrix( c(0,0,1,0,0, 0,1,0,0,0, 1,0,0,0,0, @@ -13,11 +14,12 @@ 
test_that("In-sample one-hot encoding works for unordered categorical variables" 0,0,0,1,0, 0,1,0,0,0), byrow = TRUE, ncol = 5) - x1_levels_expected <- c("1","2","3","4") - - x2 <- c("a","c","b","c","d","a","c","a","b","d") - x2_onehot <- oneHotInitializeAndEncode(x2) - x2_expected <- matrix( + x1_levels_expected <- c("1", "2", "3", "4") + + x2 <- c("a", "c", "b", "c", "d", "a", "c", "a", "b", "d") + x2_onehot <- oneHotInitializeAndEncode(x2) + # fmt: skip + x2_expected <- matrix( c(1,0,0,0,0, 0,0,1,0,0, 0,1,0,0,0, @@ -29,11 +31,12 @@ test_that("In-sample one-hot encoding works for unordered categorical variables" 0,1,0,0,0, 0,0,0,1,0), byrow = TRUE, ncol = 5) - x2_levels_expected <- c("a","b","c","d") - - x3 <- c(3.2,2.4,1.5,4.6,3.2,2.4,3.2,2.4,4.6,2.4) - x3_onehot <- oneHotInitializeAndEncode(x3) - x3_expected <- matrix( + x2_levels_expected <- c("a", "b", "c", "d") + + x3 <- c(3.2, 2.4, 1.5, 4.6, 3.2, 2.4, 3.2, 2.4, 4.6, 2.4) + x3_onehot <- oneHotInitializeAndEncode(x3) + # fmt: skip + x3_expected <- matrix( c(0,0,1,0,0, 0,1,0,0,0, 1,0,0,0,0, @@ -45,111 +48,159 @@ test_that("In-sample one-hot encoding works for unordered categorical variables" 0,0,0,1,0, 0,1,0,0,0), byrow = TRUE, ncol = 5) - x3_levels_expected <- c("1.5","2.4","3.2","4.6") - - expect_equal(x1_onehot$Xtilde, x1_expected) - expect_equal(x2_onehot$Xtilde, x2_expected) - expect_equal(x3_onehot$Xtilde, x3_expected) - expect_equal(x1_onehot$unique_levels, x1_levels_expected) - expect_equal(x2_onehot$unique_levels, x2_levels_expected) - expect_equal(x3_onehot$unique_levels, x3_levels_expected) + x3_levels_expected <- c("1.5", "2.4", "3.2", "4.6") + + expect_equal(x1_onehot$Xtilde, x1_expected) + expect_equal(x2_onehot$Xtilde, x2_expected) + expect_equal(x3_onehot$Xtilde, x3_expected) + expect_equal(x1_onehot$unique_levels, x1_levels_expected) + expect_equal(x2_onehot$unique_levels, x2_levels_expected) + expect_equal(x3_onehot$unique_levels, x3_levels_expected) }) test_that("Out-of-sample one-hot encoding works for unordered categorical variables", { - x1 <- c(3,2,1,4,3,2,3,2,4,2) - x1_test <- c(1,2,4,3,5) - x1_test_onehot <- oneHotEncode(x1_test, levels(factor(x1))) - x1_test_expected <- matrix( + x1 <- c(3, 2, 1, 4, 3, 2, 3, 2, 4, 2) + x1_test <- c(1, 2, 4, 3, 5) + x1_test_onehot <- oneHotEncode(x1_test, levels(factor(x1))) + # fmt: skip + x1_test_expected <- matrix( c(1,0,0,0,0, 0,1,0,0,0, 0,0,0,1,0, 0,0,1,0,0, 0,0,0,0,1), byrow = TRUE, ncol = 5) - - x2 <- c("a","c","b","c","d","a","c","a","b","d") - x2_test <- c("a","c","g","b","f") - x2_test_onehot <- oneHotEncode(x2_test, levels(factor(x2))) - x2_test_expected <- matrix( + + x2 <- c("a", "c", "b", "c", "d", "a", "c", "a", "b", "d") + x2_test <- c("a", "c", "g", "b", "f") + x2_test_onehot <- oneHotEncode(x2_test, levels(factor(x2))) + # fmt: skip + x2_test_expected <- matrix( c(1,0,0,0,0, 0,0,1,0,0, 0,0,0,0,1, 0,1,0,0,0, 0,0,0,0,1), byrow = TRUE, ncol = 5) - - x3 <- c(3.2,2.4,1.5,4.6,3.2,2.4,3.2,2.4,4.6,2.4) - x3_test <- c(10.3,-0.5,4.6,3.2,1.8) - x3_test_onehot <- oneHotEncode(x3_test, levels(factor(x3))) - x3_test_expected <- matrix( + + x3 <- c(3.2, 2.4, 1.5, 4.6, 3.2, 2.4, 3.2, 2.4, 4.6, 2.4) + x3_test <- c(10.3, -0.5, 4.6, 3.2, 1.8) + x3_test_onehot <- oneHotEncode(x3_test, levels(factor(x3))) + # fmt: skip + x3_test_expected <- matrix( c(0,0,0,0,1, 0,0,0,0,1, 0,0,0,1,0, 0,0,1,0,0, 0,0,0,0,1), byrow = TRUE, ncol = 5) - - expect_equal(x1_test_onehot, x1_test_expected) - expect_equal(x2_test_onehot, x2_test_expected) - expect_equal(x3_test_onehot, x3_test_expected) + + 
expect_equal(x1_test_onehot, x1_test_expected) + expect_equal(x2_test_onehot, x2_test_expected) + expect_equal(x3_test_onehot, x3_test_expected) }) test_that("In-sample preprocessing for ordered categorical variables", { - string_var_response_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") - - x1 <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") - x1_preprocessing <- orderedCatInitializeAndPreprocess(x1) - x1_vector_expected <- c(1,3,2,4,3,5,4) - x1_levels_expected <- string_var_response_levels - - x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) - x2_preprocessing <- orderedCatInitializeAndPreprocess(x2) - x2_vector_expected <- c(1,3,2,4,3,5,4) - x2_levels_expected <- string_var_response_levels - - string_var_levels_reordered <- c("5. Strongly agree", "4. Agree", "3. Neither agree nor disagree", "2. Disagree", "1. Strongly disagree") - x3 <- factor(x1, levels = string_var_levels_reordered, ordered = TRUE) - x3_preprocessing <- orderedCatInitializeAndPreprocess(x3) - x3_vector_expected <- c(5,3,4,2,3,1,2) - x3_levels_expected <- string_var_levels_reordered - - x4 <- c(3,2,4,6,5,2,3,1,3,4,6) - x4_preprocessing <- orderedCatInitializeAndPreprocess(x4) - x4_vector_expected <- c(3,2,4,6,5,2,3,1,3,4,6) - x4_levels_expected <- c("1","2","3","4","5","6") - - expect_equal(x1_preprocessing$x_preprocessed, x1_vector_expected) - expect_equal(x2_preprocessing$x_preprocessed, x2_vector_expected) - expect_equal(x3_preprocessing$x_preprocessed, x3_vector_expected) - expect_equal(x4_preprocessing$x_preprocessed, x4_vector_expected) - expect_equal(x1_preprocessing$unique_levels, x1_levels_expected) - expect_equal(x2_preprocessing$unique_levels, x2_levels_expected) - expect_equal(x3_preprocessing$unique_levels, x3_levels_expected) - expect_equal(x4_preprocessing$unique_levels, x4_levels_expected) + string_var_response_levels <- c( + "1. Strongly disagree", + "2. Disagree", + "3. Neither agree nor disagree", + "4. Agree", + "5. Strongly agree" + ) + + x1 <- c( + "1. Strongly disagree", + "3. Neither agree nor disagree", + "2. Disagree", + "4. Agree", + "3. Neither agree nor disagree", + "5. Strongly agree", + "4. Agree" + ) + x1_preprocessing <- orderedCatInitializeAndPreprocess(x1) + x1_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + x1_levels_expected <- string_var_response_levels + + x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) + x2_preprocessing <- orderedCatInitializeAndPreprocess(x2) + x2_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + x2_levels_expected <- string_var_response_levels + + string_var_levels_reordered <- c( + "5. Strongly agree", + "4. Agree", + "3. Neither agree nor disagree", + "2. Disagree", + "1. 
Strongly disagree" + ) + x3 <- factor(x1, levels = string_var_levels_reordered, ordered = TRUE) + x3_preprocessing <- orderedCatInitializeAndPreprocess(x3) + x3_vector_expected <- c(5, 3, 4, 2, 3, 1, 2) + x3_levels_expected <- string_var_levels_reordered + + x4 <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + x4_preprocessing <- orderedCatInitializeAndPreprocess(x4) + x4_vector_expected <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + x4_levels_expected <- c("1", "2", "3", "4", "5", "6") + + expect_equal(x1_preprocessing$x_preprocessed, x1_vector_expected) + expect_equal(x2_preprocessing$x_preprocessed, x2_vector_expected) + expect_equal(x3_preprocessing$x_preprocessed, x3_vector_expected) + expect_equal(x4_preprocessing$x_preprocessed, x4_vector_expected) + expect_equal(x1_preprocessing$unique_levels, x1_levels_expected) + expect_equal(x2_preprocessing$unique_levels, x2_levels_expected) + expect_equal(x3_preprocessing$unique_levels, x3_levels_expected) + expect_equal(x4_preprocessing$unique_levels, x4_levels_expected) }) test_that("Out-of-sample preprocessing for ordered categorical variables", { - string_var_response_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") - - x1 <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") - x1_preprocessing <- orderedCatPreprocess(x1, string_var_response_levels) - x1_vector_expected <- c(1,3,2,4,3,5,4) - - x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) - x2_preprocessing <- orderedCatPreprocess(x2, string_var_response_levels) - x2_vector_expected <- c(1,3,2,4,3,5,4) - - x3 <- c("1. Strongly disagree", "6. Other", "7. Also other", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") - expected_warning_message <- "Variable includes ordered categorical levels not included in the original training set" - expect_warning(x3_preprocessing <- orderedCatPreprocess(x3, string_var_response_levels), expected_warning_message) - x3_vector_expected <- c(1,6,6,4,3,5,4) - - x4 <- c(3,2,4,6,5,2,3,1,3,4,6) - x4_preprocessing <- orderedCatPreprocess(x4, c("1","2","3","4","5","6")) - x4_vector_expected <- c(3,2,4,6,5,2,3,1,3,4,6) - - expect_equal(x1_preprocessing, x1_vector_expected) - expect_equal(x2_preprocessing, x2_vector_expected) - expect_equal(x3_preprocessing, x3_vector_expected) - expect_equal(x4_preprocessing, x4_vector_expected) + string_var_response_levels <- c( + "1. Strongly disagree", + "2. Disagree", + "3. Neither agree nor disagree", + "4. Agree", + "5. Strongly agree" + ) + + x1 <- c( + "1. Strongly disagree", + "3. Neither agree nor disagree", + "2. Disagree", + "4. Agree", + "3. Neither agree nor disagree", + "5. Strongly agree", + "4. Agree" + ) + x1_preprocessing <- orderedCatPreprocess(x1, string_var_response_levels) + x1_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + + x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) + x2_preprocessing <- orderedCatPreprocess(x2, string_var_response_levels) + x2_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + + x3 <- c( + "1. Strongly disagree", + "6. Other", + "7. Also other", + "4. Agree", + "3. Neither agree nor disagree", + "5. Strongly agree", + "4. 
Agree" + ) + expected_warning_message <- "Variable includes ordered categorical levels not included in the original training set" + expect_warning( + x3_preprocessing <- orderedCatPreprocess(x3, string_var_response_levels), + expected_warning_message + ) + x3_vector_expected <- c(1, 6, 6, 4, 3, 5, 4) + + x4 <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + x4_preprocessing <- orderedCatPreprocess(x4, c("1", "2", "3", "4", "5", "6")) + x4_vector_expected <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + + expect_equal(x1_preprocessing, x1_vector_expected) + expect_equal(x2_preprocessing, x2_vector_expected) + expect_equal(x3_preprocessing, x3_vector_expected) + expect_equal(x4_preprocessing, x4_vector_expected) }) diff --git a/test/R/testthat/test-data-preprocessing.R b/test/R/testthat/test-data-preprocessing.R index 61752aa7..ced792a0 100644 --- a/test/R/testthat/test-data-preprocessing.R +++ b/test/R/testthat/test-data-preprocessing.R @@ -13,7 +13,7 @@ # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$numeric_vars, c("x1","x2","x3")) # }) -# +# # test_that("Preprocessing of all-unordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -30,13 +30,13 @@ # expect_equal(preprocess_list$metadata$num_ordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 3) # expect_equal(preprocess_list$metadata$unordered_cat_vars, c("x1","x2","x3")) -# expect_equal(preprocess_list$metadata$unordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$unordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of all-ordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -51,13 +51,13 @@ # expect_equal(preprocess_list$metadata$num_ordered_cat_vars, 3) # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$ordered_cat_vars, c("x1","x2","x3")) -# expect_equal(preprocess_list$metadata$ordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$ordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of mixed-covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -78,7 +78,7 @@ # expect_equal(preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) # }) -# +# # test_that("Preprocessing of mixed-covariate matrix works", { # cov_input <- matrix(c(1:5,5:1,6:10),ncol=3,byrow=FALSE) # cov_mat <- matrix(c( @@ -98,7 +98,7 @@ # expect_equal(preprocess_list$metadata$unordered_cat_vars, c("x3")) # expect_equal(preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) -# +# # alt_preprocess_list <- createForestCovariates(cov_input, ordered_cat_vars = "x2", unordered_cat_vars = "x3") # expect_equal(alt_preprocess_list$data, cov_mat) # expect_equal(alt_preprocess_list$metadata$feature_types, c(0, rep(1,7))) @@ -110,16 +110,16 @@ # 
expect_equal(alt_preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(alt_preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) # }) -# +# # test_that("Preprocessing of out-of-sample mixed-covariate dataset works", { # metadata <- list( -# num_numeric_vars = 1, -# num_ordered_cat_vars = 1, -# num_unordered_cat_vars = 1, -# numeric_vars = c("x1"), -# ordered_cat_vars = c("x2"), -# unordered_cat_vars = c("x3"), -# ordered_unique_levels = list(x2=c("1","2","3","4","5")), +# num_numeric_vars = 1, +# num_ordered_cat_vars = 1, +# num_unordered_cat_vars = 1, +# numeric_vars = c("x1"), +# ordered_cat_vars = c("x2"), +# unordered_cat_vars = c("x3"), +# ordered_unique_levels = list(x2=c("1","2","3","4","5")), # unordered_unique_levels = list(x3=c("6","7","8","9","10")) # ) # cov_df <- data.frame(x1 = c(1:5,1), x2 = c(5:1,5), x3 = 6:11) @@ -134,7 +134,7 @@ # X_preprocessed <- createForestCovariatesFromMetadata(cov_df, metadata) # expect_equal(X_preprocessed, cov_mat) # }) -# +# # test_that("Preprocessing of all-numeric covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -151,7 +151,7 @@ # expect_equal(preprocess_list$metadata$original_var_indices, 1:3) # expect_equal(preprocess_list$metadata$numeric_vars, c("x1","x2","x3")) # }) -# +# # test_that("Preprocessing of all-unordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_df$x1 <- factor(cov_df$x1) @@ -173,13 +173,13 @@ # expect_equal(preprocess_list$metadata$unordered_cat_vars, c("x1","x2","x3")) # expected_var_indices <- c(rep(1,6),rep(2,6),rep(3,6)) # expect_equal(preprocess_list$metadata$original_var_indices, expected_var_indices) -# expect_equal(preprocess_list$metadata$unordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$unordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of all-ordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_df$x1 <- factor(cov_df$x1, ordered = TRUE) @@ -198,13 +198,13 @@ # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$ordered_cat_vars, c("x1","x2","x3")) # expect_equal(preprocess_list$metadata$original_var_indices, 1:3) -# expect_equal(preprocess_list$metadata$ordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$ordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of mixed-covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_df$x2 <- factor(cov_df$x2, ordered = TRUE) @@ -229,17 +229,17 @@ # expect_equal(preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) # }) -# +# # test_that("Preprocessing of out-of-sample mixed-covariate dataset works", { # metadata <- list( -# num_numeric_vars = 1, -# num_ordered_cat_vars = 1, -# num_unordered_cat_vars = 1, +# num_numeric_vars = 1, +# num_ordered_cat_vars = 1, +# num_unordered_cat_vars = 1, # original_var_indices = c(1, 2, 3, 3, 3, 3, 3, 3), -# numeric_vars = c("x1"), 
-# ordered_cat_vars = c("x2"), -# unordered_cat_vars = c("x3"), -# ordered_unique_levels = list(x2=c("1","2","3","4","5")), +# numeric_vars = c("x1"), +# ordered_cat_vars = c("x2"), +# unordered_cat_vars = c("x3"), +# ordered_unique_levels = list(x2=c("1","2","3","4","5")), # unordered_unique_levels = list(x3=c("6","7","8","9","10")) # ) # cov_df <- data.frame(x1 = c(1:5,1), x2 = c(5:1,5), x3 = 6:11) diff --git a/test/R/testthat/test-dataset.R b/test/R/testthat/test-dataset.R index 80f8d8e0..d9ecadf9 100644 --- a/test/R/testthat/test-dataset.R +++ b/test/R/testthat/test-dataset.R @@ -1,66 +1,73 @@ test_that("ForestDataset can be constructed and updated", { - # Generate data - n <- 20 - num_covariates <- 10 - num_basis <- 5 - covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) - basis <- matrix(runif(n * num_basis), ncol = num_basis) - variance_weights <- runif(n) - - # Copy data to a ForestDataset object - forest_dataset <- createForestDataset(covariates, basis, variance_weights) - - # Run first round of expectations - expect_equal(forest_dataset$num_observations(), n) - expect_equal(forest_dataset$num_covariates(), num_covariates) - expect_equal(forest_dataset$num_basis(), num_basis) - expect_equal(forest_dataset$has_variance_weights(), T) - - # Update data - new_basis <- matrix(runif(n * num_basis), ncol = num_basis) - new_variance_weights <- runif(n) - expect_no_error( - forest_dataset$update_basis(new_basis) - ) - expect_no_error( - forest_dataset$update_variance_weights(new_variance_weights) - ) - - # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights - expect_equal(covariates, forest_dataset$get_covariates()) - expect_equal(new_basis, forest_dataset$get_basis()) - expect_equal(new_variance_weights, forest_dataset$get_variance_weights()) + # Generate data + n <- 20 + num_covariates <- 10 + num_basis <- 5 + covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) + basis <- matrix(runif(n * num_basis), ncol = num_basis) + variance_weights <- runif(n) + + # Copy data to a ForestDataset object + forest_dataset <- createForestDataset(covariates, basis, variance_weights) + + # Run first round of expectations + expect_equal(forest_dataset$num_observations(), n) + expect_equal(forest_dataset$num_covariates(), num_covariates) + expect_equal(forest_dataset$num_basis(), num_basis) + expect_equal(forest_dataset$has_variance_weights(), T) + + # Update data + new_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + forest_dataset$update_basis(new_basis) + ) + expect_no_error( + forest_dataset$update_variance_weights(new_variance_weights) + ) + + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights + expect_equal(covariates, forest_dataset$get_covariates()) + expect_equal(new_basis, forest_dataset$get_basis()) + expect_equal(new_variance_weights, forest_dataset$get_variance_weights()) }) test_that("RandomEffectsDataset can be constructed and updated", { - # Generate data - n <- 20 - num_groups <- 4 - num_basis <- 5 - group_ids <- sample(as.integer(1:num_groups), size = n, replace = T) - rfx_basis <- cbind(1, matrix(runif(n*(num_basis-1)), ncol = (num_basis-1))) - variance_weights <- runif(n) - - # Copy data to a RandomEffectsDataset object - rfx_dataset <- createRandomEffectsDataset(group_ids, rfx_basis, variance_weights) - - # Run first round of expectations - expect_equal(rfx_dataset$num_observations(), n) - 
expect_equal(rfx_dataset$num_basis(), num_basis) - expect_equal(rfx_dataset$has_variance_weights(), T) - - # Update data - new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) - new_variance_weights <- runif(n) - expect_no_error( - rfx_dataset$update_basis(new_rfx_basis) - ) - expect_no_error( - rfx_dataset$update_variance_weights(new_variance_weights) - ) - - # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights - expect_equal(group_ids, rfx_dataset$get_group_labels()) - expect_equal(new_rfx_basis, rfx_dataset$get_basis()) - expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) + # Generate data + n <- 20 + num_groups <- 4 + num_basis <- 5 + group_ids <- sample(as.integer(1:num_groups), size = n, replace = T) + rfx_basis <- cbind( + 1, + matrix(runif(n * (num_basis - 1)), ncol = (num_basis - 1)) + ) + variance_weights <- runif(n) + + # Copy data to a RandomEffectsDataset object + rfx_dataset <- createRandomEffectsDataset( + group_ids, + rfx_basis, + variance_weights + ) + + # Run first round of expectations + expect_equal(rfx_dataset$num_observations(), n) + expect_equal(rfx_dataset$num_basis(), num_basis) + expect_equal(rfx_dataset$has_variance_weights(), T) + + # Update data + new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + rfx_dataset$update_basis(new_rfx_basis) + ) + expect_no_error( + rfx_dataset$update_variance_weights(new_variance_weights) + ) + + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights + expect_equal(group_ids, rfx_dataset$get_group_labels()) + expect_equal(new_rfx_basis, rfx_dataset$get_basis()) + expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) }) diff --git a/test/R/testthat/test-forest-container.R b/test/R/testthat/test-forest-container.R index 9b399a96..adbfa423 100644 --- a/test/R/testthat/test-forest-container.R +++ b/test/R/testthat/test-forest-container.R @@ -1,251 +1,297 @@ test_that("Univariate constant forest container", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest_samples <- createForestSamples(num_trees, 1, TRUE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) 
- - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Multiply first forest by 2.0 - forest_samples$multiply_forest(0, 2.0) - - # Check that predictions are all double - pred_orig <- pred - pred_expected <- pred * 2.0 - pred <- forest_samples$predict(forest_dataset) - - # Assertion - expect_equal(pred, pred_expected) - - # Add 1.0 to every tree in first forest - forest_samples$add_to_forest(0, 1.0) - - # Check that predictions are += num_trees - pred_expected <- pred + num_trees - pred <- forest_samples$predict(forest_dataset) - - # Assertion - expect_equal(pred, pred_expected) - - # Initialize a new forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Split the second forest as the first forest was split - forest_samples$add_numeric_split_tree(1, 0, 0, 0, 4.0, -5., 5.) - forest_samples$add_numeric_split_tree(1, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- cbind(pred_expected, pred_orig) - - # Assertion - expect_equal(pred, pred_expected_new) - - # Combine second forest with the first forest - forest_samples$combine_forests(c(0,1)) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- cbind(pred_expected + pred_orig, pred_orig) - - # Assertion - expect_equal(pred, pred_expected_new) - - # Divide first forest predictions by 2 - forest_samples$multiply_forest(0, 0.5) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- cbind((pred_expected + pred_orig)/2.0, pred_orig) - - # Assertion - expect_equal(pred, pred_expected_new) - - # Delete second forest - forest_samples$delete_sample(1) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- (pred_expected + pred_orig)/2.0 - - # Assertion - expect_equal(pred, pred_expected_new) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest_samples <- createForestSamples(num_trees, 1, TRUE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) 
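+ + # NOTE: judging from the calls in these tests, the arguments to + # add_numeric_split_tree on a ForestSamples container are, in order: sample + # index, tree index, leaf id, split feature, cutpoint, left leaf value, and + # right leaf value (all 0-indexed); rows with X[,1] <= 4.0 fall in the new + # left leaf (here -5) and the remaining rows in the right leaf (5).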
+ + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Multiply first forest by 2.0 + forest_samples$multiply_forest(0, 2.0) + + # Check that predictions are all doubled + pred_orig <- pred + pred_expected <- pred * 2.0 + pred <- forest_samples$predict(forest_dataset) + + # Assertion + expect_equal(pred, pred_expected) + + # Add 1.0 to every tree in first forest + forest_samples$add_to_forest(0, 1.0) + + # Check that predictions increase by num_trees (1.0 added to each of the num_trees trees) + pred_expected <- pred + num_trees + pred <- forest_samples$predict(forest_dataset) + + # Assertion + expect_equal(pred, pred_expected) + + # Initialize a new forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Split the second forest as the first forest was split + forest_samples$add_numeric_split_tree(1, 0, 0, 0, 4.0, -5., 5.) + forest_samples$add_numeric_split_tree(1, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- cbind(pred_expected, pred_orig) + + # Assertion + expect_equal(pred, pred_expected_new) + + # Combine second forest with the first forest + forest_samples$combine_forests(c(0, 1)) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- cbind(pred_expected + pred_orig, pred_orig) + + # Assertion + expect_equal(pred, pred_expected_new) + + # Divide first forest predictions by 2 + forest_samples$multiply_forest(0, 0.5) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- cbind((pred_expected + pred_orig) / 2.0, pred_orig) + + # Assertion + expect_equal(pred, pred_expected_new) + + # Delete second forest + forest_samples$delete_sample(1) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- (pred_expected + pred_orig) / 2.0 + + # Assertion + expect_equal(pred, pred_expected_new) }) test_that("Collapse forests", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Create forest dataset - forest_dataset_test <- createForestDataset(covariates = X_test) - - # Run BART for 50 iterations - num_mcmc <- 50 - general_param_list <- list(num_chains = 1, 
keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - batch_inds <- (seq(1,num_mcmc,1) - (num_mcmc - (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - 1) %/% batch_size + 1 - pred_orig_collapsed <- matrix(NA, nrow = nrow(pred_orig), ncol = max(batch_inds)) - for (i in 1:max(batch_inds)) { - pred_orig_collapsed[,i] <- rowSums(pred_orig[,batch_inds == i]) / sum(batch_inds == i) - } - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) - - # Now run BART for 52 iterations - num_mcmc <- 52 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - batch_inds <- (seq(1,num_mcmc,1) - (num_mcmc - (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - 1) %/% batch_size + 1 - pred_orig_collapsed <- matrix(NA, nrow = nrow(pred_orig), ncol = max(batch_inds)) - for (i in 1:max(batch_inds)) { - pred_orig_collapsed[,i] <- rowSums(pred_orig[,batch_inds == i]) / sum(batch_inds == i) - } - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) - - # Now run BART for 5 iterations - num_mcmc <- 5 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - pred_orig_collapsed <- as.matrix(rowSums(pred_orig) / batch_size) - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) - - # Now run BART for 4 iterations - num_mcmc <- 4 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- 
mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - pred_orig_collapsed <- pred_orig - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Create forest dataset + forest_dataset_test <- createForestDataset(covariates = X_test) + + # Run BART for 50 iterations + num_mcmc <- 50 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + batch_inds <- (seq(1, num_mcmc, 1) - + (num_mcmc - + (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - + 1) %/% + batch_size + + 1 + pred_orig_collapsed <- matrix( + NA, + nrow = nrow(pred_orig), + ncol = max(batch_inds) + ) + for (i in 1:max(batch_inds)) { + pred_orig_collapsed[, i] <- rowSums(pred_orig[, batch_inds == i]) / + sum(batch_inds == i) + } + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) + + # Now run BART for 52 iterations + num_mcmc <- 52 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + batch_inds <- (seq(1, num_mcmc, 1) - + (num_mcmc - + (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - + 1) %/% + batch_size + + 1 + pred_orig_collapsed <- matrix( + NA, + nrow = nrow(pred_orig), + ncol = max(batch_inds) + ) + for (i in 1:max(batch_inds)) { + pred_orig_collapsed[, i] <- rowSums(pred_orig[, batch_inds == i]) / + sum(batch_inds == i) + } + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) + + # Now run BART for 5 iterations + num_mcmc <- 5 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = 
general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + pred_orig_collapsed <- as.matrix(rowSums(pred_orig) / batch_size) + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) + + # Now run BART for 4 iterations + num_mcmc <- 4 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + pred_orig_collapsed <- pred_orig + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) }) diff --git a/test/R/testthat/test-forest.R b/test/R/testthat/test-forest.R index 3c9409dd..063c64b6 100644 --- a/test/R/testthat/test-forest.R +++ b/test/R/testthat/test-forest.R @@ -1,162 +1,164 @@ test_that("Univariate forest construction", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest <- createForest(num_trees, 1, TRUE) - - # Initialize forest with 0.0 root predictions - forest$set_root_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest$predict(forest_dataset) - pred_raw <- forest$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) 
- - # Check that predictions are the same (since the leaf is constant) - pred <- forest$predict(forest_dataset) - pred_raw <- forest$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest$predict(forest_dataset) - pred_raw <- forest$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Check the split count for the first tree in the ensemble - split_counts <- forest$get_tree_split_counts(0,p) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest <- createForest(num_trees, 1, TRUE) + + # Initialize forest with 0.0 root predictions + forest$set_root_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest$predict(forest_dataset) + pred_raw <- forest$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) + + # Check that predictions are the same (since the leaf is constant) + pred <- forest$predict(forest_dataset) + pred_raw <- forest$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest$predict(forest_dataset) + pred_raw <- forest$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Check the split count for the first tree in the ensemble + split_counts <- forest$get_tree_split_counts(0, p) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Univariate forest construction and low-level merge / arithmetic ops", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest1 <- createForest(num_trees, 1, TRUE) - forest2 <- createForest(num_trees, 1, TRUE) - - # Initialize forests with 0.0 root predictions - forest1$set_root_leaves(0.) - forest2$set_root_leaves(0.) - - # Check that predictions are as expected - pred1 <- forest1$predict(forest_dataset) - pred2 <- forest2$predict(forest_dataset) - pred_exp1 <- c(0,0,0,0,0,0) - pred_exp2 <- c(0,0,0,0,0,0) - - # Assertion - expect_equal(pred1, pred_exp1) - expect_equal(pred2, pred_exp2) - - # Split the root of the first tree of the first forest in the ensemble at X[,1] > 4.0 - forest1$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) - - # Split the root of the first tree of the second forest in the ensemble at X[,1] > 3.0 - forest2$add_numeric_split_tree(0, 0, 0, 3.0, -1., 1.) 
- - # Check that predictions are as expected - pred1 <- forest1$predict(forest_dataset) - pred2 <- forest2$predict(forest_dataset) - pred_exp1 <- c(-5,-5,-5,5,5,5) - pred_exp2 <- c(-1,-1,1,1,1,1) - - # Assertion - expect_equal(pred1, pred_exp1) - expect_equal(pred2, pred_exp2) - - # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 - forest1$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) - - # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 - forest2$add_numeric_split_tree(0, 1, 1, 4.0, -1.5, -0.5) - - # Check that predictions are as expected - pred1 <- forest1$predict(forest_dataset) - pred2 <- forest2$predict(forest_dataset) - pred_exp1 <- c(-2.5,-7.5,-7.5,5,5,5) - pred_exp2 <- c(-0.5,-1.5,1,1,1,1) - - # Assertion - expect_equal(pred1, pred_exp1) - expect_equal(pred2, pred_exp2) - - # Merge forests - forest1$merge_forest(forest2) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(-3.0,-9.0,-6.5,6.0,6.0,6.0) - - # Assertion - expect_equal(pred, pred_exp) - - # Add constant to every value of the combined forest - forest1$add_constant(0.5) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(7.0,1.0,3.5,16.0,16.0,16.0) - - # Assertion - expect_equal(pred, pred_exp) - - # Check that "old" forest is still intact - pred <- forest2$predict(forest_dataset) - pred_exp <- c(-0.5,-1.5,1,1,1,1) - - # Assertion - expect_equal(pred, pred_exp) - - # Subtract constant back off of every value of the combined forest - forest1$add_constant(-0.5) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(-3.0,-9.0,-6.5,6.0,6.0,6.0) - - # Assertion - expect_equal(pred, pred_exp) - - # Multiply every value of the combined forest by a constant - forest1$multiply_constant(2.0) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(-6.0,-18.0,-13.0,12.0,12.0,12.0) - - # Assertion - expect_equal(pred, pred_exp) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest1 <- createForest(num_trees, 1, TRUE) + forest2 <- createForest(num_trees, 1, TRUE) + + # Initialize forests with 0.0 root predictions + forest1$set_root_leaves(0.) + forest2$set_root_leaves(0.) + + # Check that predictions are as expected + pred1 <- forest1$predict(forest_dataset) + pred2 <- forest2$predict(forest_dataset) + pred_exp1 <- c(0, 0, 0, 0, 0, 0) + pred_exp2 <- c(0, 0, 0, 0, 0, 0) + + # Assertion + expect_equal(pred1, pred_exp1) + expect_equal(pred2, pred_exp2) + + # Split the root of the first tree of the first forest in the ensemble at X[,1] > 4.0 + forest1$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) + + # Split the root of the first tree of the second forest in the ensemble at X[,1] > 3.0 + forest2$add_numeric_split_tree(0, 0, 0, 3.0, -1., 1.) 
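+ + # NOTE: on a single Forest (rather than a ForestSamples container), + # add_numeric_split_tree appears to drop the leading sample index, so the + # arguments above are: tree index, leaf id, split feature, cutpoint, and the + # left / right leaf values.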
+ + # Check that predictions are as expected + pred1 <- forest1$predict(forest_dataset) + pred2 <- forest2$predict(forest_dataset) + pred_exp1 <- c(-5, -5, -5, 5, 5, 5) + pred_exp2 <- c(-1, -1, 1, 1, 1, 1) + + # Assertion + expect_equal(pred1, pred_exp1) + expect_equal(pred2, pred_exp2) + + # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 + forest1$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) + + # Split the left leaf of the first tree of the second forest in the ensemble at X[,2] > 4.0 + forest2$add_numeric_split_tree(0, 1, 1, 4.0, -1.5, -0.5) + + # Check that predictions are as expected + pred1 <- forest1$predict(forest_dataset) + pred2 <- forest2$predict(forest_dataset) + pred_exp1 <- c(-2.5, -7.5, -7.5, 5, 5, 5) + pred_exp2 <- c(-0.5, -1.5, 1, 1, 1, 1) + + # Assertion + expect_equal(pred1, pred_exp1) + expect_equal(pred2, pred_exp2) + + # Merge forests + forest1$merge_forest(forest2) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(-3.0, -9.0, -6.5, 6.0, 6.0, 6.0) + + # Assertion + expect_equal(pred, pred_exp) + + # Add constant to every value of the combined forest + forest1$add_constant(0.5) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(7.0, 1.0, 3.5, 16.0, 16.0, 16.0) + + # Assertion + expect_equal(pred, pred_exp) + + # Check that "old" forest is still intact + pred <- forest2$predict(forest_dataset) + pred_exp <- c(-0.5, -1.5, 1, 1, 1, 1) + + # Assertion + expect_equal(pred, pred_exp) + + # Subtract constant back off of every value of the combined forest + forest1$add_constant(-0.5) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(-3.0, -9.0, -6.5, 6.0, 6.0, 6.0) + + # Assertion + expect_equal(pred, pred_exp) + + # Multiply every value of the combined forest by a constant + forest1$multiply_constant(2.0) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(-6.0, -18.0, -13.0, 12.0, 12.0, 12.0) + + # Assertion + expect_equal(pred, pred_exp) }) diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index 88628f6d..bdd9d66b 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -1,174 +1,445 @@ test_that("Prediction from trees with constant leaf", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest_samples <- createForestSamples(num_trees, 1, TRUE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) 
- - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Check the split count for the first tree in the ensemble - split_counts <- forest_samples$get_tree_split_counts(0,0,p) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest_samples <- createForestSamples(num_trees, 1, TRUE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Check the split count for the first tree in the ensemble + split_counts <- forest_samples$get_tree_split_counts(0, 0, p) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Prediction from trees with univariate leaf basis", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - W = as.matrix(c(-1,-1,-1,1,1,1)) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X,W) - forest_samples <- createForestSamples(num_trees, 1, FALSE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) 
- - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_manual <- pred_raw*W - - # Assertion - expect_equal(pred, pred_manual) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_manual <- pred_raw*W - - # Assertion - expect_equal(pred, pred_manual) - - # Check the split count for the first tree in the ensemble - split_counts <- forest_samples$get_tree_split_counts(0,0,p) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + W = as.matrix(c(-1, -1, -1, 1, 1, 1)) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X, W) + forest_samples <- createForestSamples(num_trees, 1, FALSE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Check that regular and "raw" predictions agree (all leaves are zero at initialization) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) + + # Check that regular predictions match the raw leaf predictions scaled by the basis + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_manual <- pred_raw * W + + # Assertion + expect_equal(pred, pred_manual) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular predictions match the raw leaf predictions scaled by the basis + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_manual <- pred_raw * W + + # Assertion + expect_equal(pred, pred_manual) + + # Check the split count for the first tree in the ensemble + split_counts <- forest_samples$get_tree_split_counts(0, 0, p) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Prediction from trees with multivariate leaf basis", { - # Create dataset and forest container - num_trees <- 10 - output_dim <- 2 - num_samples <- 0 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + output_dim <- 2 + num_samples <- 0 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - W = matrix(c(1,1,1,1,1,1,-1,-1,-1,1,1,1), byrow=FALSE, nrow=6) - forest_dataset = createForestDataset(X,W) - forest_samples <- createForestSamples(num_trees, output_dim, FALSE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(c(1.,1.)) - num_samples <- num_samples + 1 - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_intermediate <- as.numeric(pred_raw) * 
as.numeric(W) - dim(pred_intermediate) <- c(n, output_dim, num_samples) - pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) - - # Assertion - expect_equal(pred, pred_manual) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, c(-5.,-1.), c(5.,1.)) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) - dim(pred_intermediate) <- c(n, output_dim, num_samples) - pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) - - # Assertion - expect_equal(pred, pred_manual) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, c(-7.5,2.5), c(-2.5,7.5)) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) - dim(pred_intermediate) <- c(n, output_dim, num_samples) - pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) - - # Assertion - expect_equal(pred, pred_manual) - - # Check the split count for the first tree in the ensemble - split_counts <- forest_samples$get_tree_split_counts(0,0,3) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + n <- nrow(X) + p <- ncol(X) + W = matrix(c(1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1), byrow = FALSE, nrow = 6) + forest_dataset = createForestDataset(X, W) + forest_samples <- createForestSamples(num_trees, output_dim, FALSE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(c(1., 1.)) + num_samples <- num_samples + 1 + + # Check that regular predictions match the raw leaf predictions weighted by the basis + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) + dim(pred_intermediate) <- c(n, output_dim, num_samples) + pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) + + # Assertion + expect_equal(pred, pred_manual) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, c(-5., -1.), c(5., 1.)) + + # Check that regular predictions match the raw leaf predictions weighted by the basis + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) + dim(pred_intermediate) <- c(n, output_dim, num_samples) + pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) + + # Assertion + expect_equal(pred, pred_manual) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree( + 0, + 0, + 1, + 1, + 4.0, + c(-7.5, 2.5), + c(-2.5, 7.5) + ) + + # Check that regular predictions match the raw leaf predictions weighted by the basis + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) + dim(pred_intermediate) <- c(n, output_dim, num_samples) + pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) + + # 
Assertion + expect_equal(pred, pred_manual) + + # Check the split count for the first tree in the ensemble + split_counts <- forest_samples$get_tree_split_counts(0, 0, 3) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) +}) + +test_that("BART predictions with pre-summarization", { + # Generate data and test-train split + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a "classic" BART model + bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred <- predict(bart_model, X_test) + y_hat_posterior_test <- pred$y_hat + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict(bart_model, X_test, type = "mean") + y_hat_mean_test <- pred_mean$y_hat + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + + # Check that we warn and return a NULL when requesting terms that weren't fit + expect_warning({ + pred_mean <- predict( + bart_model, + X_test, + type = "mean", + terms = c("rfx", "variance_forest") + ) + }) + expect_equal(NULL, pred_mean) + + # Fit a heteroskedastic BART model + var_params <- list(num_trees = 20) + het_bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) + + # Check that the default predict method returns a list + pred <- predict(het_bart_model, X_test) + y_hat_posterior_test <- pred$y_hat + sigma2_hat_posterior_test <- pred$variance_forest_predictions + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict(het_bart_model, X_test, type = "mean") + y_hat_mean_test <- pred_mean$y_hat + sigma2_hat_mean_test <- pred_mean$variance_forest_predictions + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a list + y_hat_mean_test_single_term <- predict( + het_bart_model, + X_test, + type = "mean", + terms = "y_hat" + ) + sigma2_hat_mean_test_single_term <- predict( + het_bart_model, + X_test, + type = "mean", + terms = "variance_forest" + ) + + # Assertion + expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) + expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) +}) + +test_that("BCF predictions with pre-summarization", { + # Generate data and test-train split + n <- 100 + g <- function(x) { + ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) + } + x1 <- rnorm(n) + x2 <- rnorm(n) + x3 <- rnorm(n) + x4 
<- as.numeric(rbinom(n, 1, 0.5)) + x5 <- as.numeric(sample(1:3, n, replace = TRUE)) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + mu_x <- 1 + g(X) + X[, 1] * X[, 3] + tau_x <- 1 + 2 * X[, 2] * X[, 4] + pi_x <- 0.8 * + pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + + 0.05 + + runif(n) / 10 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + snr <- 2 + y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + X <- as.data.frame(X) + X$x4 <- factor(X$x4, ordered = TRUE) + X$x5 <- factor(X$x5, ordered = TRUE) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a "classic" BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test + ) + y_hat_posterior_test <- pred$y_hat + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean" + ) + y_hat_mean_test <- pred_mean$y_hat + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + + # Check that we warn and return a NULL when requesting terms that weren't fit + expect_warning({ + pred_mean <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("rfx", "variance_forest") + ) + }) + expect_equal(NULL, pred_mean) + + # Fit a heteroskedastic BCF model + var_params <- list(num_trees = 20) + expect_warning( + het_bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) + ) + + # Check that the default predict method returns a list + pred <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test + ) + y_hat_posterior_test <- pred$y_hat + sigma2_hat_posterior_test <- pred$variance_forest_predictions + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean" + ) + y_hat_mean_test <- pred_mean$y_hat + sigma2_hat_mean_test <- pred_mean$variance_forest_predictions + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a list + y_hat_mean_test_single_term <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = "y_hat" + ) + sigma2_hat_mean_test_single_term <- predict( + 
het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = "variance_forest" + ) + + # Assertion + expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) + expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) }) diff --git a/test/R/testthat/test-residual.R b/test/R/testthat/test-residual.R index 4f663465..64bbb30b 100644 --- a/test/R/testthat/test-residual.R +++ b/test/R/testthat/test-residual.R @@ -1,85 +1,120 @@ test_that("Residual updates correctly propagated after forest sampling step", { - # Setup - # Create dataset - X = matrix(c(1.5, 8.7, 1.2, + # Setup + # Create dataset + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - W = matrix(c(1, 1, 1, 1, 1, 1), nrow = 6) - n = nrow(X) - p = ncol(X) - y = as.matrix(ifelse(X[,1]>4,-5,5) + rnorm(n,0,1)) - y_bar = mean(y) - y_std = sd(y) - resid = (y-y_bar)/y_std - forest_dataset = createForestDataset(X, W) - residual = createOutcome(resid) - variable_weights = rep(1.0/p, p) - feature_types = as.integer(rep(0, p)) - - # Forest parameters - num_trees = 50 - alpha = 0.95 - beta = 2.0 - min_samples_leaf = 1 - current_sigma2 = 1. - current_leaf_scale = as.matrix(1./num_trees,nrow=1,ncol=1) - cutpoint_grid_size = 100 - max_depth = 10 - a_forest = 0 - b_forest = 0 - - # RNG - cpp_rng = createCppRNG(-1) - - # Create forest sampler and forest container - global_model_config = createGlobalModelConfig(global_error_variance=current_sigma2) - forest_model_config = createForestModelConfig(feature_types=feature_types, num_trees=num_trees, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, max_depth=max_depth, - leaf_model_type=0, leaf_model_scale=current_leaf_scale, - variable_weights=variable_weights, variance_forest_shape=a_forest, - variance_forest_scale=b_forest, cutpoint_grid_size=cutpoint_grid_size) - forest_model = createForestModel(forest_dataset, forest_model_config, global_model_config) - forest_samples = createForestSamples(num_trees, 1, FALSE) - active_forest = createForest(num_trees, 1, FALSE) - - # Initialize the leaves of each tree in the prognostic forest - active_forest$prepare_for_sampler(forest_dataset, residual, forest_model, 0, mean(resid)) - active_forest$adjust_residual(forest_dataset, residual, forest_model, FALSE, FALSE) - - # Run the forest sampling algorithm for a single iteration - forest_model$sample_one_iteration( - forest_dataset, residual, forest_samples, active_forest, - cpp_rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = TRUE - ) + W = matrix(c(1, 1, 1, 1, 1, 1), nrow = 6) + n = nrow(X) + p = ncol(X) + y = as.matrix(ifelse(X[, 1] > 4, -5, 5) + rnorm(n, 0, 1)) + y_bar = mean(y) + y_std = sd(y) + resid = (y - y_bar) / y_std + forest_dataset = createForestDataset(X, W) + residual = createOutcome(resid) + variable_weights = rep(1.0 / p, p) + feature_types = as.integer(rep(0, p)) - # Get the current residual after running the sampler - initial_resid = residual$get_data() - - # Get initial prediction from the tree ensemble - initial_yhat = as.numeric(forest_samples$predict(forest_dataset)) - - # Update the basis vector - scalar = 2.0 - W_update = W*scalar - forest_dataset$update_basis(W_update) - - # Update residual to reflect adjusted basis - forest_model$propagate_basis_update(forest_dataset, residual, active_forest) - - # Get updated prediction from the tree ensemble - updated_yhat = 
as.numeric(forest_samples$predict(forest_dataset)) - - # Compute the expected residual - expected_resid = initial_resid + initial_yhat - updated_yhat - - # Get the current residual after running the sampler - updated_resid = residual$get_data() - - # Assertion - expect_equal(expected_resid, updated_resid) + # Forest parameters + num_trees = 50 + alpha = 0.95 + beta = 2.0 + min_samples_leaf = 1 + current_sigma2 = 1. + current_leaf_scale = as.matrix(1. / num_trees, nrow = 1, ncol = 1) + cutpoint_grid_size = 100 + max_depth = 10 + a_forest = 0 + b_forest = 0 + + # RNG + cpp_rng = createCppRNG(-1) + + # Create forest sampler and forest container + global_model_config = createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + forest_model_config = createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees, + num_observations = n, + alpha = alpha, + beta = beta, + min_samples_leaf = min_samples_leaf, + max_depth = max_depth, + leaf_model_type = 0, + leaf_model_scale = current_leaf_scale, + variable_weights = variable_weights, + variance_forest_shape = a_forest, + variance_forest_scale = b_forest, + cutpoint_grid_size = cutpoint_grid_size + ) + forest_model = createForestModel( + forest_dataset, + forest_model_config, + global_model_config + ) + forest_samples = createForestSamples(num_trees, 1, FALSE) + active_forest = createForest(num_trees, 1, FALSE) + + # Initialize the leaves of each tree in the prognostic forest + active_forest$prepare_for_sampler( + forest_dataset, + residual, + forest_model, + 0, + mean(resid) + ) + active_forest$adjust_residual( + forest_dataset, + residual, + forest_model, + FALSE, + FALSE + ) + + # Run the forest sampling algorithm for a single iteration + forest_model$sample_one_iteration( + forest_dataset, + residual, + forest_samples, + active_forest, + cpp_rng, + forest_model_config, + global_model_config, + keep_forest = TRUE, + gfr = TRUE + ) + + # Get the current residual after running the sampler + initial_resid = residual$get_data() + + # Get initial prediction from the tree ensemble + initial_yhat = as.numeric(forest_samples$predict(forest_dataset)) + + # Update the basis vector + scalar = 2.0 + W_update = W * scalar + forest_dataset$update_basis(W_update) + + # Update residual to reflect adjusted basis + forest_model$propagate_basis_update(forest_dataset, residual, active_forest) + + # Get updated prediction from the tree ensemble + updated_yhat = as.numeric(forest_samples$predict(forest_dataset)) + + # Compute the expected residual + expected_resid = initial_resid + initial_yhat - updated_yhat + + # Get the current residual after running the sampler + updated_resid = residual$get_data() + + # Assertion + expect_equal(expected_resid, updated_resid) }) diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index 0d78957f..fe50af5f 100644 --- a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -1,162 +1,181 @@ test_that("BART Serialization", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- 
round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Sample a BART model - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) - y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) - - # Save to JSON - bart_json_string <- saveBARTModelToJsonString(bart_model) - - # Reload as a BART model - bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) - - # Predict from the roundtrip BART model - y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) - - # Assertion - expect_equal(y_hat_orig, y_hat_reloaded) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Sample a BART model + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) + + # Save to JSON + bart_json_string <- saveBARTModelToJsonString(bart_model) + + # Reload as a BART model + bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) + + # Predict from the roundtrip BART model + y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) }) test_that("BCF Serialization", { - skip_on_cran() - - n <- 500 - x1 <- runif(n) - x2 <- runif(n) - x3 <- runif(n) - x4 <- runif(n) - x5 <- runif(n) - X <- cbind(x1,x2,x3,x4,x5) - p <- ncol(X) - pi_x <- 0.25 + 0.5*X[,1] - mu_x <- pi_x * 5 - tau_x <- X[,2] * 2 - Z <- rbinom(n,1,pi_x) - E_XZ <- mu_x + Z*tau_x - y <- E_XZ + rnorm(n, 0, 1) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - pi_test <- pi_x[test_inds] - pi_train <- pi_x[train_inds] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - mu_test <- mu_x[test_inds] - mu_train <- mu_x[train_inds] - tau_test <- tau_x[test_inds] - tau_train <- tau_x[train_inds] - - # Sample a BCF model - bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 100, num_burnin = 0, num_mcmc = 100) - bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) - mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) - tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) - y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) - - # Save to JSON - bcf_json_string <- saveBCFModelToJsonString(bcf_model) - - # Reload as a BCF model - bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) - - # Predict from the roundtrip BCF model - bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test) - mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) - 
tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) - y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) - - # Assertion - expect_equal(y_hat_orig, y_hat_reloaded) + skip_on_cran() + + n <- 500 + x1 <- runif(n) + x2 <- runif(n) + x3 <- runif(n) + x4 <- runif(n) + x5 <- runif(n) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + pi_x <- 0.25 + 0.5 * X[, 1] + mu_x <- pi_x * 5 + tau_x <- X[, 2] * 2 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + y <- E_XZ + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + + # Sample a BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + num_gfr = 100, + num_burnin = 0, + num_mcmc = 100 + ) + bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) + mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) + tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) + y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) + + # Save to JSON + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + + # Reload as a BCF model + bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) + + # Predict from the roundtrip BCF model + bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test) + mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) + tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) + y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) }) test_that("BCF Serialization (no propensity)", { - skip_on_cran() - - n <- 500 - x1 <- runif(n) - x2 <- runif(n) - x3 <- runif(n) - x4 <- runif(n) - x5 <- runif(n) - X <- cbind(x1,x2,x3,x4,x5) - p <- ncol(X) - pi_x <- 0.25 + 0.5*X[,1] - mu_x <- pi_x * 5 - tau_x <- X[,2] * 2 - Z <- rbinom(n,1,pi_x) - E_XZ <- mu_x + Z*tau_x - y <- E_XZ + rnorm(n, 0, 1) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - pi_test <- pi_x[test_inds] - pi_train <- pi_x[train_inds] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - mu_test <- mu_x[test_inds] - mu_train <- mu_x[train_inds] - tau_test <- tau_x[test_inds] - tau_train <- tau_x[train_inds] - - # Sample a BCF model - bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - num_gfr = 100, num_burnin = 0, num_mcmc = 100) - bcf_preds_orig <- predict(bcf_model, X_test, Z_test) - mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) - tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) - y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) - - # Save to JSON - bcf_json_string <- saveBCFModelToJsonString(bcf_model) - - # Reload as a BCF model - bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) - - # Predict from the roundtrip BCF model - bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test) - mu_hat_reloaded <- 
rowMeans(bcf_preds_reloaded[["mu_hat"]]) - tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) - y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) - - # Assertion - expect_equal(y_hat_orig, y_hat_reloaded) -}) \ No newline at end of file + skip_on_cran() + + n <- 500 + x1 <- runif(n) + x2 <- runif(n) + x3 <- runif(n) + x4 <- runif(n) + x5 <- runif(n) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + pi_x <- 0.25 + 0.5 * X[, 1] + mu_x <- pi_x * 5 + tau_x <- X[, 2] * 2 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + y <- E_XZ + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + + # Sample a BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + num_gfr = 100, + num_burnin = 0, + num_mcmc = 100 + ) + bcf_preds_orig <- predict(bcf_model, X_test, Z_test) + mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) + tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) + y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) + + # Save to JSON + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + + # Reload as a BCF model + bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) + + # Predict from the roundtrip BCF model + bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test) + mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) + tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) + y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) +}) diff --git a/test/R/testthat/test-utils.R b/test/R/testthat/test-utils.R index 60991ad7..72434368 100644 --- a/test/R/testthat/test-utils.R +++ b/test/R/testthat/test-utils.R @@ -1,70 +1,86 @@ test_that("Array conversion", { - skip_on_cran() - - # Test data - scalar_1 <- 1.5 - scalar_2 <- -2.5 - scalar_3 <- 4 - array_1d_1 <- c(1.6, 3.4, 7.6, 8.7) - array_1d_2 <- c(2.5, 3.1, 5.6) - array_1d_3 <- c(2.5) - array_2d_1 <- matrix( - c(2.5,1.2,4.3,7.4,1.7,2.9,3.6,9.1,7.2,4.5,6.7,1.4), - nrow = 3, ncol = 4, byrow = T - ) - array_2d_2 <- matrix( - c(2.5,1.2,4.3,7.4,1.7,2.9,3.6,9.1), - nrow = 2, ncol = 4, byrow = T - ) - array_square_1 <- matrix( - c(2.5,1.2,1.7,2.9), - nrow = 2, ncol = 2, byrow = T - ) - array_square_2 <- matrix( - c(2.5,0.0,0.0,2.9), - nrow = 2, ncol = 2, byrow = T - ) - array_square_3 <- matrix( - c(2.5,0.0,0.0,0.0,2.9,0.0,0.0,0.0,5.6), - nrow = 3, ncol = 3, byrow = T - ) - - # Error cases - expect_error(expand_dims_1d(array_1d_1, 5)) - expect_error(expand_dims_1d(array_1d_2, 4)) - expect_error(expand_dims_2d(array_2d_1, 2, 4)) - expect_error(expand_dims_2d(array_2d_2, 3, 4)) - expect_error(expand_dims_2d_diag(array_square_1, 4)) - expect_error(expand_dims_2d_diag(array_square_2, 3)) - expect_error(expand_dims_2d_diag(array_square_3, 2)) - - # Working cases - expect_equal(c(scalar_1,scalar_1,scalar_1), expand_dims_1d(scalar_1, 3)) - expect_equal(c(scalar_2,scalar_2,scalar_2,scalar_2), expand_dims_1d(scalar_2, 4)) - expect_equal(c(scalar_3,scalar_3), expand_dims_1d(scalar_3, 2)) - 
expect_equal(c(array_1d_3,array_1d_3,array_1d_3), expand_dims_1d(array_1d_3, 3)) - - output_exp <- matrix(rep(scalar_1, 6), nrow = 2, byrow = T) - expect_equal(output_exp, expand_dims_2d(scalar_1, 2, 3)) - output_exp <- matrix(rep(scalar_2, 8), nrow = 2, byrow = T) - expect_equal(output_exp, expand_dims_2d(scalar_2, 2, 4)) - output_exp <- matrix(rep(scalar_3, 6), nrow = 3, byrow = T) - expect_equal(output_exp, expand_dims_2d(scalar_3, 3, 2)) - output_exp <- matrix(rep(array_1d_3, 6), nrow = 3, byrow = T) - expect_equal(output_exp, expand_dims_2d(array_1d_3, 3, 2)) - output_exp <- unname(rbind(array_1d_1, array_1d_1)) - expect_equal(output_exp, expand_dims_2d(array_1d_1, 2, 4)) - output_exp <- unname(rbind(array_1d_2, array_1d_2, array_1d_2)) - expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 3)) - output_exp <- unname(cbind(array_1d_2, array_1d_2, array_1d_2, array_1d_2)) - expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 4)) - output_exp <- unname(cbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) - expect_equal(output_exp, expand_dims_2d(array_1d_3, 1, 4)) - output_exp <- unname(rbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) - expect_equal(output_exp, expand_dims_2d(array_1d_3, 4, 1)) - - expect_equal(diag(scalar_1, 3), expand_dims_2d_diag(scalar_1, 3)) - expect_equal(diag(scalar_2, 2), expand_dims_2d_diag(scalar_2, 2)) - expect_equal(diag(scalar_3, 4), expand_dims_2d_diag(scalar_3, 4)) - expect_equal(diag(array_1d_3, 2), expand_dims_2d_diag(array_1d_3, 2)) -}) \ No newline at end of file + skip_on_cran() + + # Test data + scalar_1 <- 1.5 + scalar_2 <- -2.5 + scalar_3 <- 4 + array_1d_1 <- c(1.6, 3.4, 7.6, 8.7) + array_1d_2 <- c(2.5, 3.1, 5.6) + array_1d_3 <- c(2.5) + array_2d_1 <- matrix( + c(2.5, 1.2, 4.3, 7.4, 1.7, 2.9, 3.6, 9.1, 7.2, 4.5, 6.7, 1.4), + nrow = 3, + ncol = 4, + byrow = T + ) + array_2d_2 <- matrix( + c(2.5, 1.2, 4.3, 7.4, 1.7, 2.9, 3.6, 9.1), + nrow = 2, + ncol = 4, + byrow = T + ) + array_square_1 <- matrix( + c(2.5, 1.2, 1.7, 2.9), + nrow = 2, + ncol = 2, + byrow = T + ) + array_square_2 <- matrix( + c(2.5, 0.0, 0.0, 2.9), + nrow = 2, + ncol = 2, + byrow = T + ) + array_square_3 <- matrix( + c(2.5, 0.0, 0.0, 0.0, 2.9, 0.0, 0.0, 0.0, 5.6), + nrow = 3, + ncol = 3, + byrow = T + ) + + # Error cases + expect_error(expand_dims_1d(array_1d_1, 5)) + expect_error(expand_dims_1d(array_1d_2, 4)) + expect_error(expand_dims_2d(array_2d_1, 2, 4)) + expect_error(expand_dims_2d(array_2d_2, 3, 4)) + expect_error(expand_dims_2d_diag(array_square_1, 4)) + expect_error(expand_dims_2d_diag(array_square_2, 3)) + expect_error(expand_dims_2d_diag(array_square_3, 2)) + + # Working cases + expect_equal(c(scalar_1, scalar_1, scalar_1), expand_dims_1d(scalar_1, 3)) + expect_equal( + c(scalar_2, scalar_2, scalar_2, scalar_2), + expand_dims_1d(scalar_2, 4) + ) + expect_equal(c(scalar_3, scalar_3), expand_dims_1d(scalar_3, 2)) + expect_equal( + c(array_1d_3, array_1d_3, array_1d_3), + expand_dims_1d(array_1d_3, 3) + ) + + output_exp <- matrix(rep(scalar_1, 6), nrow = 2, byrow = T) + expect_equal(output_exp, expand_dims_2d(scalar_1, 2, 3)) + output_exp <- matrix(rep(scalar_2, 8), nrow = 2, byrow = T) + expect_equal(output_exp, expand_dims_2d(scalar_2, 2, 4)) + output_exp <- matrix(rep(scalar_3, 6), nrow = 3, byrow = T) + expect_equal(output_exp, expand_dims_2d(scalar_3, 3, 2)) + output_exp <- matrix(rep(array_1d_3, 6), nrow = 3, byrow = T) + expect_equal(output_exp, expand_dims_2d(array_1d_3, 3, 2)) + output_exp <- unname(rbind(array_1d_1, array_1d_1)) + 
expect_equal(output_exp, expand_dims_2d(array_1d_1, 2, 4)) + output_exp <- unname(rbind(array_1d_2, array_1d_2, array_1d_2)) + expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 3)) + output_exp <- unname(cbind(array_1d_2, array_1d_2, array_1d_2, array_1d_2)) + expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 4)) + output_exp <- unname(cbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) + expect_equal(output_exp, expand_dims_2d(array_1d_3, 1, 4)) + output_exp <- unname(rbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) + expect_equal(output_exp, expand_dims_2d(array_1d_3, 4, 1)) + + expect_equal(diag(scalar_1, 3), expand_dims_2d_diag(scalar_1, 3)) + expect_equal(diag(scalar_2, 2), expand_dims_2d_diag(scalar_2, 2)) + expect_equal(diag(scalar_3, 4), expand_dims_2d_diag(scalar_3, 4)) + expect_equal(diag(array_1d_3, 2), expand_dims_2d_diag(array_1d_3, 2)) +}) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 4b22ab7b..3243b86a 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -84,7 +84,7 @@ def outcome_mean(X): # Assertions bart_preds_combined = bart_model_3.predict(covariates=X_train) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -192,7 +192,7 @@ def outcome_mean(X, W): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -300,7 +300,7 @@ def outcome_mean(X, W): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -411,7 +411,10 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict(covariates=X_train) - y_hat_train_combined, sigma2_x_train_combined = bart_preds_combined['y_hat'], bart_preds_combined['variance_forest_predictions'] + y_hat_train_combined, sigma2_x_train_combined = ( + bart_preds_combined["y_hat"], + bart_preds_combined["variance_forest_predictions"], + ) assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) assert sigma2_x_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( @@ -424,7 +427,8 @@ def conditional_stddev(X): sigma2_x_train_combined[:, 0:num_mcmc], bart_model.sigma2_x_train ) np.testing.assert_allclose( - sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train + sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], + bart_model_2.sigma2_x_train, ) np.testing.assert_allclose( bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples @@ -543,7 +547,7 @@ def conditional_stddev(X): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -668,7 
+672,7 @@ def conditional_stddev(X): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -825,7 +829,7 @@ def rfx_term(group_labels, basis): rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -999,7 +1003,7 @@ def conditional_stddev(X): rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -1059,7 +1063,7 @@ def outcome_mean(X, W): # Define the group rfx function def rfx_term(group_labels, basis): return np.where( - group_labels == 0, -5 + 1. * basis[:,1], 5 - 1. * basis[:,1] + group_labels == 0, -5 + 1.0 * basis[:, 1], 5 - 1.0 * basis[:, 1] ) # Define the conditional standard deviation function @@ -1122,13 +1126,14 @@ def conditional_stddev(X): ) # Specify scalar rfx parameters - general_params = { - "rfx_working_parameter_prior_mean": 1., - "rfx_group_parameter_prior_mean": 1., - "rfx_working_parameter_prior_cov": 1., - "rfx_group_parameter_prior_cov": 1., - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "model_spec": "custom", + "working_parameter_prior_mean": 1.0, + "group_parameter_prior_mean": 1.0, + "working_parameter_prior_cov": 1.0, + "group_parameter_prior_cov": 1.0, + "variance_prior_shape": 1, + "variance_prior_scale": 1, } bart_model_2 = BARTModel() bart_model_2.sample( @@ -1144,17 +1149,18 @@ def conditional_stddev(X): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params, + random_effects_params=rfx_params, ) # Specify all relevant rfx parameters as vectors - general_params = { - "rfx_working_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_group_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_working_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_group_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "model_spec": "custom", + "working_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), + "group_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), + "working_parameter_prior_cov": np.identity(num_rfx_basis), + "group_parameter_prior_cov": np.identity(num_rfx_basis), + "variance_prior_shape": 1, + "variance_prior_scale": 1, } bart_model_3 = BARTModel() bart_model_3.sample( @@ -1170,5 +1176,30 @@ def conditional_stddev(X): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params, + random_effects_params=rfx_params, + ) + + # Fit a simpler intercept-only RFX model + rfx_params = {"model_spec": "intercept_only"} + bart_model_4 = BARTModel() + bart_model_4.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + rfx_group_ids_train=group_labels_train, + X_test=X_test, + leaf_basis_test=basis_test, + 
rfx_group_ids_test=group_labels_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + random_effects_params=rfx_params, + ) + preds = bart_model_4.predict( + covariates=X_test, + basis=basis_test, + rfx_group_ids=group_labels_test, + type="posterior", + terms="rfx", ) + assert preds.shape == (n_test, num_mcmc) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index 2e2f7fbf..dac1ea25 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -71,13 +71,19 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) - # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + # Check that we can predict just treatment effects + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and with propensity score @@ -101,12 +107,18 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -136,13 +148,17 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -166,13 +182,17 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") def test_continuous_univariate_bcf(self): # RNG @@ -239,13 +259,19 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, 
y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run second BCF model with test set and propensity score @@ -275,13 +301,19 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds_2 = bcf_model_2.predict(X_test, Z_test, pi_test) - tau_hat_2, mu_hat_2, y_hat_2 = bcf_preds_2['tau_hat'], bcf_preds_2['mu_hat'], bcf_preds_2['y_hat'] + tau_hat_2, mu_hat_2, y_hat_2 = ( + bcf_preds_2["tau_hat"], + bcf_preds_2["mu_hat"], + bcf_preds_2["y_hat"], + ) assert tau_hat_2.shape == (n_test, num_mcmc) assert mu_hat_2.shape == (n_test, num_mcmc) assert y_hat_2.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test, pi_test) + tau_hat_2 = bcf_model_2.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat_2.shape == (n_test, num_mcmc) # Combine into a single model @@ -291,7 +323,11 @@ def test_continuous_univariate_bcf(self): # Assertions bcf_preds_3 = bcf_model_3.predict(X_test, Z_test, pi_test) - tau_hat_3, mu_hat_3, y_hat_3 = bcf_preds_3['tau_hat'], bcf_preds_3['mu_hat'], bcf_preds_3['y_hat'] + tau_hat_3, mu_hat_3, y_hat_3 = ( + bcf_preds_3["tau_hat"], + bcf_preds_3["mu_hat"], + bcf_preds_3["y_hat"], + ) assert tau_hat_3.shape == (n_train, num_mcmc * 2) assert mu_hat_3.shape == (n_train, num_mcmc * 2) assert y_hat_3.shape == (n_train, num_mcmc * 2) @@ -330,13 +366,19 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -366,13 +408,17 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -396,13 +442,17 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], 
bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") # Run second BCF model with test set and propensity score bcf_model_2 = BCFModel() @@ -424,13 +474,17 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds_2 = bcf_model_2.predict(X_test, Z_test) - tau_hat_2, mu_hat_2, y_hat_2 = bcf_preds_2['tau_hat'], bcf_preds_2['mu_hat'], bcf_preds_2['y_hat'] + tau_hat_2, mu_hat_2, y_hat_2 = ( + bcf_preds_2["tau_hat"], + bcf_preds_2["mu_hat"], + bcf_preds_2["y_hat"], + ) assert tau_hat_2.shape == (n_test, num_mcmc) assert mu_hat_2.shape == (n_test, num_mcmc) assert y_hat_2.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test) + tau_hat_2 = bcf_model_2.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat_2.shape == (n_test, num_mcmc) # Combine into a single model @@ -440,7 +494,11 @@ def test_continuous_univariate_bcf(self): # Assertions bcf_preds_3 = bcf_model_3.predict(X_test, Z_test) - tau_hat_3, mu_hat_3, y_hat_3 = bcf_preds_3['tau_hat'], bcf_preds_3['mu_hat'], bcf_preds_3['y_hat'] + tau_hat_3, mu_hat_3, y_hat_3 = ( + bcf_preds_3["tau_hat"], + bcf_preds_3["mu_hat"], + bcf_preds_3["y_hat"], + ) assert tau_hat_3.shape == (n_train, num_mcmc * 2) assert mu_hat_3.shape == (n_train, num_mcmc * 2) assert y_hat_3.shape == (n_train, num_mcmc * 2) @@ -522,13 +580,19 @@ def test_multivariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF without test set and with propensity score @@ -552,13 +616,19 @@ def test_multivariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF with test set and without propensity score @@ -658,14 +728,21 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat, sigma2_x_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'], 
bcf_preds['variance_forest_predictions'] + tau_hat, mu_hat, y_hat, sigma2_x_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + bcf_preds["variance_forest_predictions"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) assert sigma2_x_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and with propensity score @@ -690,32 +767,28 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - assert bcf_preds['tau_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['variance_forest_predictions'].shape == (n_test, num_mcmc) + assert bcf_preds["tau_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["variance_forest_predictions"].shape == (n_test, num_mcmc) # Check predictions match bcf_preds = bcf_model.predict(X_train, Z_train, pi_train) - assert bcf_preds['tau_hat'].shape == (n_train, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_train, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_train, num_mcmc) - assert bcf_preds['variance_forest_predictions'].shape == (n_train, num_mcmc) - np.testing.assert_allclose( - bcf_preds['y_hat'], bcf_model.y_hat_train - ) - np.testing.assert_allclose( - bcf_preds['mu_hat'], bcf_model.mu_hat_train - ) - np.testing.assert_allclose( - bcf_preds['tau_hat'], bcf_model.tau_hat_train - ) + assert bcf_preds["tau_hat"].shape == (n_train, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_train, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_train, num_mcmc) + assert bcf_preds["variance_forest_predictions"].shape == (n_train, num_mcmc) + np.testing.assert_allclose(bcf_preds["y_hat"], bcf_model.y_hat_train) + np.testing.assert_allclose(bcf_preds["mu_hat"], bcf_model.mu_hat_train) + np.testing.assert_allclose(bcf_preds["tau_hat"], bcf_model.tau_hat_train) np.testing.assert_allclose( - bcf_preds['variance_forest_predictions'], bcf_model.sigma2_x_train + bcf_preds["variance_forest_predictions"], bcf_model.sigma2_x_train ) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -746,13 +819,13 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - assert bcf_preds['tau_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['variance_forest_predictions'].shape == (n_test, num_mcmc) + assert bcf_preds["tau_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["variance_forest_predictions"].shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, 
Z_test) + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -776,13 +849,13 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - assert bcf_preds['tau_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) + assert bcf_preds["tau_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) - + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") + def test_bcf_rfx_parameters(self): # RNG random_seed = 101 @@ -811,9 +884,9 @@ def test_bcf_rfx_parameters(self): # Define the group rfx function def rfx_term(group_labels, basis): return np.where( - group_labels == 0, -5 + 1. * basis[:,1], 5 - 1. * basis[:,1] + group_labels == 0, -5 + 1.0 * basis[:, 1], 5 - 1.0 * basis[:, 1] ) - + # Generate outcome epsilon = rng.normal(0, 1, n) y = mu_X + tau_X * Z + rfx_term(group_labels, rfx_basis) + epsilon @@ -856,17 +929,17 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params + general_params=general_params, ) # Specify scalar rfx parameters - general_params = { - "rfx_working_parameter_prior_mean": 1., - "rfx_group_parameter_prior_mean": 1., - "rfx_working_parameter_prior_cov": 1., - "rfx_group_parameter_prior_cov": 1., - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "working_parameter_prior_mean": 1.0, + "group_parameter_prior_mean": 1.0, + "working_parameter_prior_cov": 1.0, + "group_parameter_prior_cov": 1.0, + "variance_prior_shape": 1, + "variance_prior_scale": 1, } bcf_model_2 = BCFModel() bcf_model_2.sample( @@ -884,17 +957,17 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params + random_effects_params=rfx_params, ) # Specify all relevant rfx parameters as vectors - general_params = { - "rfx_working_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_group_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_working_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_group_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "working_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), + "group_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), + "working_parameter_prior_cov": np.identity(num_rfx_basis), + "group_parameter_prior_cov": np.identity(num_rfx_basis), + "variance_prior_shape": 1, + "variance_prior_scale": 1, } bcf_model_3 = BCFModel() bcf_model_3.sample( @@ -912,5 +985,5 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params + random_effects_params=rfx_params, ) diff --git a/test/python/test_data.py b/test/python/test_data.py index 09c75154..cda8f58b 100644 --- a/test/python/test_data.py +++ b/test/python/test_data.py @@ -2,6 +2,7 @@ from stochtree import Dataset, RandomEffectsDataset + class TestDataset: def test_dataset_update(self): # Generate data @@ -12,7 +13,7 @@ def test_dataset_update(self): covariates = rng.uniform(0, 1, size=(n, num_covariates)) basis = 
rng.uniform(0, 1, size=(n, num_basis)) variance_weights = rng.uniform(0, 1, size=n) - + # Construct dataset forest_dataset = Dataset() forest_dataset.add_covariates(covariates) @@ -22,18 +23,21 @@ def test_dataset_update(self): assert forest_dataset.num_covariates() == num_covariates assert forest_dataset.num_basis() == num_basis assert forest_dataset.has_variance_weights() - + # Update dataset new_basis = rng.uniform(0, 1, size=(n, num_basis)) new_variance_weights = rng.uniform(0, 1, size=n) with np.testing.assert_no_warnings(): forest_dataset.update_basis(new_basis) forest_dataset.update_variance_weights(new_variance_weights) - + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights np.testing.assert_array_equal(forest_dataset.get_covariates(), covariates) np.testing.assert_array_equal(forest_dataset.get_basis(), new_basis) - np.testing.assert_array_equal(forest_dataset.get_variance_weights(), new_variance_weights) + np.testing.assert_array_equal( + forest_dataset.get_variance_weights(), new_variance_weights + ) + class TestRFXDataset: def test_rfx_dataset_update(self): @@ -48,7 +52,7 @@ def test_rfx_dataset_update(self): if num_basis > 1: basis[:, 1:] = rng.uniform(-1, 1, (n, num_basis - 1)) variance_weights = rng.uniform(0, 1, size=n) - + # Construct dataset rfx_dataset = RandomEffectsDataset() rfx_dataset.add_group_labels(group_labels) @@ -57,16 +61,17 @@ def test_rfx_dataset_update(self): assert rfx_dataset.num_observations() == n assert rfx_dataset.num_basis() == num_basis assert rfx_dataset.has_variance_weights() - + # Update dataset new_basis = rng.uniform(0, 1, size=(n, num_basis)) new_variance_weights = rng.uniform(0, 1, size=n) with np.testing.assert_no_warnings(): rfx_dataset.update_basis(new_basis) rfx_dataset.update_variance_weights(new_variance_weights) - + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights np.testing.assert_array_equal(rfx_dataset.get_group_labels(), group_labels) np.testing.assert_array_equal(rfx_dataset.get_basis(), new_basis) - np.testing.assert_array_equal(rfx_dataset.get_variance_weights(), new_variance_weights) - + np.testing.assert_array_equal( + rfx_dataset.get_variance_weights(), new_variance_weights + ) diff --git a/test/python/test_forest.py b/test/python/test_forest.py index 267b09b4..9d5e6b48 100644 --- a/test/python/test_forest.py +++ b/test/python/test_forest.py @@ -6,14 +6,14 @@ class TestPredict: def test_constant_forest_construction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -26,41 +26,41 @@ def test_constant_forest_construction(self): # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest.predict(forest_dataset) - pred_exp = np.array([0.,0.,0.,0.,0.,0.]) + pred_exp = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest.add_numeric_split(0, 0, 0, 4.0, -5., 5.) 
+ forest.add_numeric_split(0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest.predict(forest_dataset) - pred_exp = np.array([-5.,-5.,-5.,5.,5.,5.]) - + pred_exp = np.array([-5.0, -5.0, -5.0, 5.0, 5.0, 5.0]) + # Assertion np.testing.assert_almost_equal(pred, pred_exp) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest.add_numeric_split(0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest.predict(forest_dataset) - pred_exp = np.array([-2.5,-7.5,-7.5,5.,5.,5.]) - + pred_exp = np.array([-2.5, -7.5, -7.5, 5.0, 5.0, 5.0]) + # Assertion np.testing.assert_almost_equal(pred, pred_exp) - + def test_constant_forest_merge_arithmetic(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -78,25 +78,25 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred1 = forest1.predict(forest_dataset) pred2 = forest2.predict(forest_dataset) - pred_exp1 = np.array([0.,0.,0.,0.,0.,0.]) - pred_exp2 = np.array([0.,0.,0.,0.,0.,0.]) + pred_exp1 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + pred_exp2 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Assertion np.testing.assert_almost_equal(pred1, pred_exp1) np.testing.assert_almost_equal(pred2, pred_exp2) # Split the root of the first tree in the first forest at X[,1] > 4.0 - forest1.add_numeric_split(0, 0, 0, 4.0, -5., 5.) - + forest1.add_numeric_split(0, 0, 0, 4.0, -5.0, 5.0) + # Split the root of the first tree in the second forest at X[,1] > 3.0 - forest2.add_numeric_split(0, 0, 0, 3.0, -1., 1.) 
+ forest2.add_numeric_split(0, 0, 0, 3.0, -1.0, 1.0) # Check that predictions are as expected pred1 = forest1.predict(forest_dataset) pred2 = forest2.predict(forest_dataset) - pred_exp1 = np.array([-5.,-5.,-5.,5.,5.,5.]) - pred_exp2 = np.array([-1.,-1.,1.,1.,1.,1.]) - + pred_exp1 = np.array([-5.0, -5.0, -5.0, 5.0, 5.0, 5.0]) + pred_exp2 = np.array([-1.0, -1.0, 1.0, 1.0, 1.0, 1.0]) + # Assertion np.testing.assert_almost_equal(pred1, pred_exp1) np.testing.assert_almost_equal(pred2, pred_exp2) @@ -106,13 +106,13 @@ def test_constant_forest_merge_arithmetic(self): # Split the left leaf of the first tree in the first forest at X[,2] > 4.0 forest2.add_numeric_split(0, 1, 1, 4.0, -1.5, -0.5) - + # Check that predictions are as expected pred1 = forest1.predict(forest_dataset) pred2 = forest2.predict(forest_dataset) - pred_exp1 = np.array([-2.5,-7.5,-7.5,5.,5.,5.]) - pred_exp2 = np.array([-0.5,-1.5,1.,1.,1.,1.]) - + pred_exp1 = np.array([-2.5, -7.5, -7.5, 5.0, 5.0, 5.0]) + pred_exp2 = np.array([-0.5, -1.5, 1.0, 1.0, 1.0, 1.0]) + # Assertion np.testing.assert_almost_equal(pred1, pred_exp1) np.testing.assert_almost_equal(pred2, pred_exp2) @@ -122,7 +122,7 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([-3.0,-9.0,-6.5,6.0,6.0,6.0]) + pred_exp = np.array([-3.0, -9.0, -6.5, 6.0, 6.0, 6.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) @@ -132,14 +132,14 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([7.0,1.0,3.5,16.0,16.0,16.0]) + pred_exp = np.array([7.0, 1.0, 3.5, 16.0, 16.0, 16.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) # Check that "old" forest is still intact pred = forest2.predict(forest_dataset) - pred_exp = np.array([-0.5,-1.5,1.,1.,1.,1.]) + pred_exp = np.array([-0.5, -1.5, 1.0, 1.0, 1.0, 1.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) @@ -149,7 +149,7 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([-3.0,-9.0,-6.5,6.0,6.0,6.0]) + pred_exp = np.array([-3.0, -9.0, -6.5, 6.0, 6.0, 6.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) @@ -159,7 +159,7 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([-6.0,-18.0,-13.0,12.0,12.0,12.0]) + pred_exp = np.array([-6.0, -18.0, -13.0, 12.0, 12.0, 12.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) diff --git a/test/python/test_forest_container.py b/test/python/test_forest_container.py index 9eda8569..4368f9b3 100644 --- a/test/python/test_forest_container.py +++ b/test/python/test_forest_container.py @@ -7,14 +7,14 @@ class TestPredict: def test_constant_leaf_forest_container(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -23,11 +23,11 @@ def test_constant_leaf_forest_container(self): forest_samples = ForestContainer(num_trees, output_dim, True, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) 
+ forest_samples.add_sample(0.0) # Split the root of the first tree in the ensemble at X[,1] > 4.0 # and then split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) # Store the predictions of the "original" forest before modifications @@ -35,48 +35,48 @@ def test_constant_leaf_forest_container(self): # Multiply first forest by 2.0 forest_samples.multiply_forest(0, 2.0) - + # Check that predictions are all double pred = forest_samples.predict(forest_dataset) pred_expected = pred_orig * 2.0 - + # Assertion np.testing.assert_almost_equal(pred, pred_expected) # Add 1.0 to every tree in first forest forest_samples.add_to_forest(0, 1.0) - + # Check that predictions are += num_trees pred_expected = pred + num_trees pred = forest_samples.predict(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_expected) # Initialize a new forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Split the second forest as the first forest was split - forest_samples.add_numeric_split(1, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(1, 0, 0, 0, 4.0, -5.0, 5.0) forest_samples.add_numeric_split(1, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that predictions are as expected pred_expected_new = np.c_[pred_expected, pred_orig] pred = forest_samples.predict(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_expected_new) # Combine second forest with the first forest - forest_samples.combine_forests(np.array([0,1])) + forest_samples.combine_forests(np.array([0, 1])) # Check that predictions are as expected pred_expected_new = np.c_[pred_expected + pred_orig, pred_orig] pred = forest_samples.predict(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_expected_new) - + def test_collapse_forest_container(self): # RNG rng = np.random.default_rng() @@ -143,13 +143,22 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match container_inds = np.linspace(start=1, stop=num_mcmc, num=num_mcmc) - batch_inds = (container_inds - (num_mcmc - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size))) - 1) // batch_size + batch_inds = ( + container_inds + - ( + num_mcmc + - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size)) + ) + - 1 + ) // batch_size batch_inds = batch_inds.astype(int) num_batches = np.max(batch_inds) + 1 pred_orig_collapsed = np.empty((n_test, num_batches)) for i in range(num_batches): - pred_orig_collapsed[:,i] = np.sum(pred_orig[:,batch_inds == i], axis=1) / np.sum(batch_inds == i) - + pred_orig_collapsed[:, i] = np.sum( + pred_orig[:, batch_inds == i], axis=1 + ) / np.sum(batch_inds == i) + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) @@ -180,13 +189,22 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match container_inds = np.linspace(start=1, stop=num_mcmc, num=num_mcmc) - batch_inds = (container_inds - (num_mcmc - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size))) - 1) // batch_size + batch_inds = ( + container_inds + - ( + num_mcmc + - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size)) + ) + - 1 + ) // batch_size batch_inds = batch_inds.astype(int) num_batches = np.max(batch_inds) + 1 pred_orig_collapsed = np.empty((n_test, num_batches)) for i in 
range(num_batches): - pred_orig_collapsed[:,i] = np.sum(pred_orig[:,batch_inds == i], axis=1) / np.sum(batch_inds == i) - + pred_orig_collapsed[:, i] = np.sum( + pred_orig[:, batch_inds == i], axis=1 + ) / np.sum(batch_inds == i) + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) @@ -218,8 +236,8 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match num_batches = 1 pred_orig_collapsed = np.empty((n_test, num_batches)) - pred_orig_collapsed[:,0] = np.sum(pred_orig, axis=1) / batch_size - + pred_orig_collapsed[:, 0] = np.sum(pred_orig, axis=1) / batch_size + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) @@ -250,6 +268,6 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match pred_orig_collapsed = pred_orig - + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) diff --git a/test/python/test_json.py b/test/python/test_json.py index 1d0dc66d..48d4845b 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -14,7 +14,7 @@ JSONSerializer, Residual, ForestModelConfig, - GlobalModelConfig + GlobalModelConfig, ) @@ -41,17 +41,15 @@ def test_array(self): assert b == json_test.get_string_vector("b") def test_preprocessor(self): - df = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=False, - categories=["c", "b", "a"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) + df = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=False, + categories=["c", "b", "a"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) cov_transformer = CovariatePreprocessor() df_transformed_orig = cov_transformer.fit_transform(df) cov_transformer_json = cov_transformer.to_json() @@ -60,27 +58,25 @@ def test_preprocessor(self): df_transformed_reloaded = cov_transformer_reloaded.transform(df) np.testing.assert_array_equal(df_transformed_orig, df_transformed_reloaded) - df_2 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=False, - categories=["c", "b", "a"], - ), - "x3": pd.Categorical( - ["a", "c", "d", "b", "d", "b"], - ordered=False, - categories=["c", "b", "a", "d"], - ), - "x4": pd.Categorical( - ["a", "b", "f", "f", "c", "a"], - ordered=True, - categories=["c", "b", "a", "f"], - ), - "x5": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) + df_2 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=False, + categories=["c", "b", "a"], + ), + "x3": pd.Categorical( + ["a", "c", "d", "b", "d", "b"], + ordered=False, + categories=["c", "b", "a", "d"], + ), + "x4": pd.Categorical( + ["a", "b", "f", "f", "c", "a"], + ordered=True, + categories=["c", "b", "a", "f"], + ), + "x5": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) cov_transformer_2 = CovariatePreprocessor() df_transformed_orig_2 = cov_transformer_2.fit_transform(df_2) cov_transformer_json_2 = cov_transformer_2.to_json() @@ -89,9 +85,14 @@ def test_preprocessor(self): df_transformed_reloaded_2 = cov_transformer_reloaded_2.transform(df_2) np.testing.assert_array_equal(df_transformed_orig_2, df_transformed_reloaded_2) - np_3 = np.array( - [[1.5, 1.2], [2.7, 5.4], [3.6, 9.3], [4.4, 10.4], [5.3, 3.6], [6.1, 4.4]] - ) + np_3 = np.array([ + [1.5, 1.2], + [2.7, 5.4], + [3.6, 9.3], + [4.4, 10.4], + [5.3, 3.6], + [6.1, 4.4], + ]) cov_transformer_3 = 
CovariatePreprocessor() df_transformed_orig_3 = cov_transformer_3.fit_transform(np_3) cov_transformer_json_3 = cov_transformer_3.to_json() @@ -131,7 +132,7 @@ def outcome_mean(X): # Extract original predictions bart_preds = bart_model.predict(X) - forest_preds_y_mcmc_retrieved = bart_preds['y_hat'] + forest_preds_y_mcmc_retrieved = bart_preds["y_hat"] # Roundtrip to / from JSON json_test = JSONSerializer() @@ -213,7 +214,7 @@ def outcome_mean(X, W): residual = Residual(resid) # Forest samplers and temporary tracking data structures - leaf_model_type = 0 if p_W == 0 else 1 + 1*(p_W > 1) + leaf_model_type = 0 if p_W == 0 else 1 + 1 * (p_W > 1) forest_config = ForestModelConfig( num_trees=num_trees, num_features=p_X, @@ -231,9 +232,7 @@ def outcome_mean(X, W): global_config = GlobalModelConfig(global_error_variance=global_variance_init) forest_container = ForestContainer(num_trees, W.shape[1], False, False) active_forest = Forest(num_trees, W.shape[1], False, False) - forest_sampler = ForestSampler( - dataset, global_config, forest_config - ) + forest_sampler = ForestSampler(dataset, global_config, forest_config) cpp_rng = RNG(random_seed) global_var_model = GlobalVarianceModel() @@ -241,9 +240,10 @@ def outcome_mean(X, W): num_warmstart = 10 num_mcmc = 100 num_samples = num_warmstart + num_mcmc - global_var_samples = np.concatenate( - (np.array([global_variance_init]), np.repeat(0, num_samples)) - ) + global_var_samples = np.concatenate(( + np.array([global_variance_init]), + np.repeat(0, num_samples), + )) if p_W > 0: init_val = np.repeat(0.0, W.shape[1]) else: @@ -264,8 +264,8 @@ def outcome_mean(X, W): dataset, residual, cpp_rng, - global_config, - forest_config, + global_config, + forest_config, True, True, ) @@ -281,8 +281,8 @@ def outcome_mean(X, W): dataset, residual, cpp_rng, - global_config, - forest_config, + global_config, + forest_config, True, True, ) @@ -334,18 +334,20 @@ def outcome_mean(X, W): # Run BART bart_orig = BARTModel() - bart_orig.sample(X_train=X, y_train=y, leaf_basis_train=W, num_gfr=10, num_mcmc=10) + bart_orig.sample( + X_train=X, y_train=y, leaf_basis_train=W, num_gfr=10, num_mcmc=10 + ) # Extract predictions from the sampler bart_preds_orig = bart_orig.predict(X, W) - y_hat_orig = bart_preds_orig['y_hat'] + y_hat_orig = bart_preds_orig["y_hat"] # "Round-trip" the model to JSON string and back and check that the predictions agree bart_json_string = bart_orig.to_json() bart_reloaded = BARTModel() bart_reloaded.from_json(bart_json_string) bart_preds_reloaded = bart_reloaded.predict(X, W) - y_hat_reloaded = bart_preds_reloaded['y_hat'] + y_hat_reloaded = bart_preds_reloaded["y_hat"] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) def test_bart_rfx_string(self): @@ -407,19 +409,26 @@ def rfx_mean(group_labels, basis): # Run BART bart_orig = BARTModel() - bart_orig.sample(X_train=X, y_train=y, leaf_basis_train=W, rfx_group_ids_train=group_labels, - rfx_basis_train=basis, num_gfr=10, num_mcmc=10) + bart_orig.sample( + X_train=X, + y_train=y, + leaf_basis_train=W, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis, + num_gfr=10, + num_mcmc=10, + ) # Extract predictions from the sampler bart_preds_orig = bart_orig.predict(X, W, group_labels, basis) - y_hat_orig = bart_preds_orig['y_hat'] + y_hat_orig = bart_preds_orig["y_hat"] # "Round-trip" the model to JSON string and back and check that the predictions agree bart_json_string = bart_orig.to_json() bart_reloaded = BARTModel() bart_reloaded.from_json(bart_json_string) bart_preds_reloaded = 
bart_reloaded.predict(X, W, group_labels, basis) - y_hat_reloaded = bart_preds_reloaded['y_hat'] + y_hat_reloaded = bart_preds_reloaded["y_hat"] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) def test_bcf_string(self): @@ -450,16 +459,22 @@ def test_bcf_string(self): # Extract predictions from the sampler bcf_preds_orig = bcf_orig.predict(X, Z, pi_X) - mu_hat_orig, tau_hat_orig, y_hat_orig = bcf_preds_orig['mu_hat'], bcf_preds_orig['tau_hat'], bcf_preds_orig['y_hat'] + mu_hat_orig, tau_hat_orig, y_hat_orig = ( + bcf_preds_orig["mu_hat"], + bcf_preds_orig["tau_hat"], + bcf_preds_orig["y_hat"], + ) # "Round-trip" the model to JSON string and back and check that the predictions agree bcf_json_string = bcf_orig.to_json() bcf_reloaded = BCFModel() bcf_reloaded.from_json(bcf_json_string) - bcf_preds_reloaded = bcf_reloaded.predict( - X, Z, pi_X + bcf_preds_reloaded = bcf_reloaded.predict(X, Z, pi_X) + mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = ( + bcf_preds_reloaded["mu_hat"], + bcf_preds_reloaded["tau_hat"], + bcf_preds_reloaded["y_hat"], ) - mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = bcf_preds_reloaded['mu_hat'], bcf_preds_reloaded['tau_hat'], bcf_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) @@ -511,21 +526,36 @@ def rfx_mean(group_labels, basis): # Run BCF bcf_orig = BCFModel() bcf_orig.sample( - X_train=X, Z_train=Z, y_train=y, pi_train=pi_X, rfx_group_ids_train=group_labels, rfx_basis_train=basis, num_gfr=10, num_mcmc=10 + X_train=X, + Z_train=Z, + y_train=y, + pi_train=pi_X, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis, + num_gfr=10, + num_mcmc=10, ) # Extract predictions from the sampler bcf_preds_orig = bcf_orig.predict(X, Z, pi_X, group_labels, basis) - mu_hat_orig, tau_hat_orig, rfx_hat_orig, y_hat_orig = bcf_preds_orig['mu_hat'], bcf_preds_orig['tau_hat'], bcf_preds_orig['rfx_predictions'], bcf_preds_orig['y_hat'] + mu_hat_orig, tau_hat_orig, rfx_hat_orig, y_hat_orig = ( + bcf_preds_orig["mu_hat"], + bcf_preds_orig["tau_hat"], + bcf_preds_orig["rfx_predictions"], + bcf_preds_orig["y_hat"], + ) # "Round-trip" the model to JSON string and back and check that the predictions agree bcf_json_string = bcf_orig.to_json() bcf_reloaded = BCFModel() bcf_reloaded.from_json(bcf_json_string) - bcf_preds_reloaded = bcf_reloaded.predict( - X, Z, pi_X, group_labels, basis + bcf_preds_reloaded = bcf_reloaded.predict(X, Z, pi_X, group_labels, basis) + mu_hat_reloaded, tau_hat_reloaded, rfx_hat_reloaded, y_hat_reloaded = ( + bcf_preds_reloaded["mu_hat"], + bcf_preds_reloaded["tau_hat"], + bcf_preds_reloaded["rfx_predictions"], + bcf_preds_reloaded["y_hat"], ) - mu_hat_reloaded, tau_hat_reloaded, rfx_hat_reloaded, y_hat_reloaded = bcf_preds_reloaded['mu_hat'], bcf_preds_reloaded['tau_hat'], bcf_preds_reloaded['rfx_predictions'], bcf_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) @@ -557,16 +587,22 @@ def test_bcf_propensity_string(self): # Extract predictions from the sampler bcf_preds_orig = bcf_orig.predict(X, Z, pi_X) - mu_hat_orig, tau_hat_orig, y_hat_orig = bcf_preds_orig['mu_hat'], bcf_preds_orig['tau_hat'], bcf_preds_orig['y_hat'] + mu_hat_orig, tau_hat_orig, y_hat_orig = ( + bcf_preds_orig["mu_hat"], + bcf_preds_orig["tau_hat"], + 
bcf_preds_orig["y_hat"], + ) # "Round-trip" the model to JSON string and back and check that the predictions agree bcf_json_string = bcf_orig.to_json() bcf_reloaded = BCFModel() bcf_reloaded.from_json(bcf_json_string) - bcf_preds_reloaded = bcf_reloaded.predict( - X, Z, pi_X + bcf_preds_reloaded = bcf_reloaded.predict(X, Z, pi_X) + mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = ( + bcf_preds_reloaded["mu_hat"], + bcf_preds_reloaded["tau_hat"], + bcf_preds_reloaded["y_hat"], ) - mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = bcf_preds_reloaded['mu_hat'], bcf_preds_reloaded['tau_hat'], bcf_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) diff --git a/test/python/test_kernel.py b/test/python/test_kernel.py index 6d630874..6a6bff09 100644 --- a/test/python/test_kernel.py +++ b/test/python/test_kernel.py @@ -5,22 +5,22 @@ Dataset, Forest, ForestContainer, - compute_forest_leaf_indices, - compute_forest_max_leaf_index + compute_forest_leaf_indices, + compute_forest_max_leaf_index, ) class TestKernel: def test_forest(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 2 output_dim = 1 @@ -29,54 +29,54 @@ def test_forest(self): forest_samples = ForestContainer(num_trees, output_dim, True, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) 
+ forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) computed = compute_forest_leaf_indices(forest_samples, X) max_leaf_index = compute_forest_max_leaf_index(forest_samples) expected = np.array([ - [0], - [0], - [0], - [1], - [1], + [0], + [0], + [0], [1], - [2], - [2], - [2], - [2], - [2], - [2] + [1], + [1], + [2], + [2], + [2], + [2], + [2], + [2], ]) - + # Assertion np.testing.assert_almost_equal(computed, expected) assert max_leaf_index == [2] # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) computed = compute_forest_leaf_indices(forest_samples, X) max_leaf_index = compute_forest_max_leaf_index(forest_samples) expected = np.array([ - [2], - [1], - [1], - [0], - [0], + [2], + [1], + [1], [0], - [3], - [3], - [3], - [3], - [3], - [3] + [0], + [0], + [3], + [3], + [3], + [3], + [3], + [3], ]) - + # Assertion np.testing.assert_almost_equal(computed, expected) assert max_leaf_index == [3] diff --git a/test/python/test_predict.py b/test/python/test_predict.py index d180fd67..03f36cb2 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -1,19 +1,23 @@ +import pytest import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from scipy.stats import norm -from stochtree import Dataset, ForestContainer +from stochtree import Dataset, ForestContainer, BARTModel, BCFModel class TestPredict: def test_constant_leaf_prediction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -22,7 +26,7 @@ def test_constant_leaf_prediction(self): forest_samples = ForestContainer(num_trees, output_dim, True, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) @@ -32,50 +36,43 @@ def test_constant_leaf_prediction(self): np.testing.assert_almost_equal(pred, pred_raw) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) 
+ forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_raw) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_raw) - + # Check the split count for the first tree in the ensemble split_counts = forest_samples.get_tree_split_counts(0, 0, p) - split_counts_expected = np.array([1,1,0]) - + split_counts_expected = np.array([1, 1, 0]) + # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) - + def test_univariate_regression_leaf_prediction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) - W = np.array( - [[-1], - [-1], - [-1], - [1], - [1], - [1]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) + W = np.array([[-1], [-1], [-1], [1], [1], [1]]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -85,7 +82,7 @@ def test_univariate_regression_leaf_prediction(self): forest_samples = ForestContainer(num_trees, output_dim, False, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) @@ -95,52 +92,45 @@ def test_univariate_regression_leaf_prediction(self): np.testing.assert_almost_equal(pred, pred_raw) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) 
+ forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - pred_manual = pred_raw*W - + pred_manual = pred_raw * W + # Assertion np.testing.assert_almost_equal(pred, pred_manual) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - pred_manual = pred_raw*W - + pred_manual = pred_raw * W + # Assertion np.testing.assert_almost_equal(pred, pred_manual) - + # Check the split count for the first tree in the ensemble split_counts = forest_samples.get_tree_split_counts(0, 0, p) split_counts_expected = np.array([1, 1, 0]) - + # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) - + def test_multivariate_regression_leaf_prediction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) - W = np.array( - [[1,-1], - [1,-1], - [1,-1], - [1, 1], - [1, 1], - [1, 1]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) + W = np.array([[1, -1], [1, -1], [1, -1], [1, 1], [1, 1], [1, 1]]) n, p = X.shape num_trees = 10 output_dim = 2 @@ -151,45 +141,280 @@ def test_multivariate_regression_leaf_prediction(self): forest_samples = ForestContainer(num_trees, output_dim, False, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(np.array([1.,1.])) + forest_samples.add_sample(np.array([1.0, 1.0])) num_samples += 1 # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) pred_intermediate = pred_raw * W - pred_manual = pred_intermediate.sum(axis=1, keepdims = True) + pred_manual = pred_intermediate.sum(axis=1, keepdims=True) # Assertion np.testing.assert_almost_equal(pred, pred_manual) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, np.array([-5.,-1.]), np.array([5.,1.])) + forest_samples.add_numeric_split( + 0, 0, 0, 0, 4.0, np.array([-5.0, -1.0]), np.array([5.0, 1.0]) + ) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) pred_intermediate = pred_raw * W - pred_manual = pred_intermediate.sum(axis=1, keepdims = True) - + pred_manual = pred_intermediate.sum(axis=1, keepdims=True) + # Assertion np.testing.assert_almost_equal(pred, pred_manual) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, np.array([-7.5,2.5]), np.array([-2.5,7.5])) - + forest_samples.add_numeric_split( + 0, 0, 1, 1, 4.0, np.array([-7.5, 2.5]), np.array([-2.5, 7.5]) + ) + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) pred_intermediate = pred_raw * W - pred_manual = pred_intermediate.sum(axis=1, keepdims = 
True) - + pred_manual = pred_intermediate.sum(axis=1, keepdims=True) + # Assertion np.testing.assert_almost_equal(pred, pred_manual) - + # Check the split count for the first tree in the ensemble split_counts = forest_samples.get_tree_split_counts(0, 0, p) split_counts_expected = np.array([1, 1, 0]) - + # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) + + def test_bart_prediction(self): + # Generate data and test/train split + rng = np.random.default_rng(1234) + n = 100 + p = 5 + X = rng.uniform(size=(n, p)) + f_XW = np.where( + (0 <= X[:, 0]) & (X[:, 0] < 0.25), + -7.5, + np.where( + (0.25 <= X[:, 0]) & (X[:, 0] < 0.5), + -2.5, + np.where((0.5 <= X[:, 0]) & (X[:, 0] < 0.75), 2.5, 7.5), + ), + ) + noise_sd = 1 + y = f_XW + rng.normal(0, noise_sd, size=n) + test_set_pct = 0.2 + train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 + ) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + y_train = y[train_inds] + y_test = y[test_inds] + + # Fit a "classic" BART model + bart_model = BARTModel() + bart_model.sample( + X_train=X_train, y_train=y_train, num_gfr=10, num_burnin=0, num_mcmc=10 + ) + + # Check that the default predict method returns a dictionary + pred = bart_model.predict(covariates=X_test) + y_hat_posterior_test = pred["y_hat"] + assert y_hat_posterior_test.shape == (20, 10) + + # Check that the pre-aggregated predictions match with those computed by np.mean + pred_mean = bart_model.predict(covariates=X_test, type="mean") + y_hat_mean_test = pred_mean["y_hat"] + np.testing.assert_almost_equal( + y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) + ) + + # Fit a heteroskedastic BART model + var_params = {"num_trees": 20} + het_bart_model = BARTModel() + het_bart_model.sample( + X_train=X_train, + y_train=y_train, + num_gfr=10, + num_burnin=0, + num_mcmc=10, + variance_forest_params=var_params, + ) + + # Check that the default predict method returns a dictionary + pred = het_bart_model.predict(covariates=X_test) + y_hat_posterior_test = pred["y_hat"] + sigma2_hat_posterior_test = pred["variance_forest_predictions"] + assert y_hat_posterior_test.shape == (20, 10) + assert sigma2_hat_posterior_test.shape == (20, 10) + + # Check that the pre-aggregated predictions match with those computed by np.mean + pred_mean = het_bart_model.predict(covariates=X_test, type="mean") + y_hat_mean_test = pred_mean["y_hat"] + sigma2_hat_mean_test = pred_mean["variance_forest_predictions"] + np.testing.assert_almost_equal( + y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) + ) + np.testing.assert_almost_equal( + sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1) + ) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a dictionary + y_hat_mean_test_single_term = het_bart_model.predict( + covariates=X_test, type="mean", terms="y_hat" + ) + sigma2_hat_mean_test_single_term = het_bart_model.predict( + covariates=X_test, type="mean", terms="variance_forest" + ) + np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) + np.testing.assert_almost_equal( + sigma2_hat_mean_test, sigma2_hat_mean_test_single_term + ) + + def test_bcf_prediction(self): + # Generate data and test/train split + rng = np.random.default_rng(1234) + n = 100 + g = lambda x: np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) + x1 = rng.normal(size=n) + x2 = rng.normal(size=n) + x3 = rng.normal(size=n) + x4 = rng.binomial(n=1, p=0.5, 
size=(n,))
+        x5 = rng.choice(a=[0, 1, 2], size=(n,), replace=True)
+        x4_cat = pd.Categorical(x4, categories=[0, 1], ordered=True)
+        x5_cat = pd.Categorical(x5, categories=[0, 1, 2], ordered=True)
+        p = 5
+        X = pd.DataFrame(
+            data={
+                "x1": pd.Series(x1),
+                "x2": pd.Series(x2),
+                "x3": pd.Series(x3),
+                "x4": pd.Series(x4_cat),
+                "x5": pd.Series(x5_cat),
+            }
+        )
+
+        def g(x5):
+            return np.where(x5 == 0, 2.0, np.where(x5 == 1, -1.0, -4.0))
+
+        p = X.shape[1]
+        mu_x = 1.0 + g(x5) + x1 * x3
+        tau_x = 1.0 + 2 * x2 * x4
+        pi_x = (
+            0.8 * norm.cdf(3.0 * mu_x / np.squeeze(np.std(mu_x)) - 0.5 * x1)
+            + 0.05
+            + rng.uniform(low=0.0, high=0.1, size=(n,))
+        )
+        Z = rng.binomial(n=1, p=pi_x, size=(n,))
+        E_XZ = mu_x + tau_x * Z
+        snr = 2
+        y = E_XZ + rng.normal(loc=0.0, scale=np.std(E_XZ) / snr, size=(n,))
+        test_set_pct = 0.2
+        train_inds, test_inds = train_test_split(
+            np.arange(n), test_size=test_set_pct, random_state=1234
+        )
+        X_train = X.iloc[train_inds, :]
+        X_test = X.iloc[test_inds, :]
+        Z_train = Z[train_inds]
+        Z_test = Z[test_inds]
+        pi_x_train = pi_x[train_inds]
+        pi_x_test = pi_x[test_inds]
+        y_train = y[train_inds]
+        y_test = y[test_inds]
+
+        # Fit a "classic" BCF model
+        bcf_model = BCFModel()
+        bcf_model.sample(
+            X_train=X_train,
+            Z_train=Z_train,
+            y_train=y_train,
+            pi_train=pi_x_train,
+            X_test=X_test,
+            Z_test=Z_test,
+            pi_test=pi_x_test,
+            num_gfr=10,
+            num_burnin=0,
+            num_mcmc=10,
+        )
+
+        # Check that the default predict method returns a dictionary
+        pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test)
+        y_hat_posterior_test = pred["y_hat"]
+        assert y_hat_posterior_test.shape == (20, 10)
+
+        # Check that the pre-aggregated predictions match with those computed by np.mean
+        pred_mean = bcf_model.predict(
+            X=X_test, Z=Z_test, propensity=pi_x_test, type="mean"
+        )
+        y_hat_mean_test = pred_mean["y_hat"]
+        np.testing.assert_almost_equal(
+            y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
+        )
+
+        # Check that we warn and return None when requesting terms that weren't fit
+        with pytest.warns(UserWarning):
+            pred_mean = bcf_model.predict(
+                X=X_test,
+                Z=Z_test,
+                propensity=pi_x_test,
+                type="mean",
+                terms=["rfx", "variance_forest"],
+            )
+
+        # Fit a heteroskedastic BCF model
+        var_params = {"num_trees": 20}
+        het_bcf_model = BCFModel()
+        het_bcf_model.sample(
+            X_train=X_train,
+            Z_train=Z_train,
+            y_train=y_train,
+            pi_train=pi_x_train,
+            X_test=X_test,
+            Z_test=Z_test,
+            pi_test=pi_x_test,
+            num_gfr=10,
+            num_burnin=0,
+            num_mcmc=10,
+            variance_forest_params=var_params,
+        )
+
+        # Check that the default predict method returns a dictionary
+        pred = het_bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test)
+        y_hat_posterior_test = pred["y_hat"]
+        sigma2_hat_posterior_test = pred["variance_forest_predictions"]
+        assert y_hat_posterior_test.shape == (20, 10)
+        assert sigma2_hat_posterior_test.shape == (20, 10)
+
+        # Check that the pre-aggregated predictions match with those computed by np.mean
+        pred_mean = het_bcf_model.predict(
+            X=X_test, Z=Z_test, propensity=pi_x_test, type="mean"
+        )
+        y_hat_mean_test = pred_mean["y_hat"]
+        sigma2_hat_mean_test = pred_mean["variance_forest_predictions"]
+        np.testing.assert_almost_equal(
+            y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)
+        )
+        np.testing.assert_almost_equal(
+            sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1)
+        )
+
+        # Check that the "single-term" pre-aggregated predictions
+        # match those computed by pre-aggregated predictions returned in a dictionary
+        y_hat_mean_test_single_term = 
het_bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_x_test, type="mean", terms="y_hat" + ) + sigma2_hat_mean_test_single_term = het_bcf_model.predict( + X=X_test, + Z=Z_test, + propensity=pi_x_test, + type="mean", + terms="variance_forest", + ) + np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) + np.testing.assert_almost_equal( + sigma2_hat_mean_test, sigma2_hat_mean_test_single_term + ) diff --git a/test/python/test_preprocessor.py b/test/python/test_preprocessor.py index f40ef204..748fa18a 100644 --- a/test/python/test_preprocessor.py +++ b/test/python/test_preprocessor.py @@ -7,16 +7,14 @@ class TestPreprocessor: def test_numpy(self): cov_transformer = CovariatePreprocessor() - np_1 = np.array( - [ - [1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4], - ] - ) + np_1 = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) np_1_transformed = cov_transformer.fit_transform(np_1) np.testing.assert_array_equal(np_1, np_1_transformed) np.testing.assert_array_equal( @@ -24,23 +22,19 @@ def test_numpy(self): ) def test_pandas(self): - df_1 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": [8.7, 3.4, 1.2, 5.4, 9.3, 10.4], - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) - np_1 = np.array( - [ - [1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4], - ] - ) + df_1 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": [8.7, 3.4, 1.2, 5.4, 9.3, 10.4], + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) + np_1 = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) cov_transformer = CovariatePreprocessor() df_1_transformed = cov_transformer.fit_transform(df_1) np.testing.assert_array_equal(np_1, df_1_transformed) @@ -48,27 +42,23 @@ def test_pandas(self): cov_transformer._processed_feature_types, np.array([0, 0, 0]) ) - df_2 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=True, - categories=["c", "b", "a"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) - np_2 = np.array( - [ - [1.5, 2, 1.2], - [2.7, 1, 5.4], - [3.6, 0, 9.3], - [4.4, 2, 10.4], - [5.3, 1, 3.6], - [6.1, 0, 4.4], - ] - ) + df_2 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=True, + categories=["c", "b", "a"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) + np_2 = np.array([ + [1.5, 2, 1.2], + [2.7, 1, 5.4], + [3.6, 0, 9.3], + [4.4, 2, 10.4], + [5.3, 1, 3.6], + [6.1, 0, 4.4], + ]) cov_transformer = CovariatePreprocessor() df_2_transformed = cov_transformer.fit_transform(df_2) np.testing.assert_array_equal(np_2, df_2_transformed) @@ -76,27 +66,23 @@ def test_pandas(self): cov_transformer._processed_feature_types, np.array([0, 1, 0]) ) - df_3 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=False, - categories=["c", "b", "a"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) - np_3 = np.array( - [ - [1.5, 0, 0, 1, 1.2], - [2.7, 0, 1, 0, 5.4], - [3.6, 1, 0, 0, 9.3], - [4.4, 0, 0, 1, 10.4], - [5.3, 0, 1, 0, 3.6], - [6.1, 1, 0, 0, 4.4], - ] - ) + df_3 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], 
+ ordered=False, + categories=["c", "b", "a"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) + np_3 = np.array([ + [1.5, 0, 0, 1, 1.2], + [2.7, 0, 1, 0, 5.4], + [3.6, 1, 0, 0, 9.3], + [4.4, 0, 0, 1, 10.4], + [5.3, 0, 1, 0, 3.6], + [6.1, 1, 0, 0, 4.4], + ]) cov_transformer = CovariatePreprocessor() df_3_transformed = cov_transformer.fit_transform(df_3) np.testing.assert_array_equal(np_3, df_3_transformed) @@ -104,28 +90,24 @@ def test_pandas(self): cov_transformer._processed_feature_types, np.array([0, 1, 1, 1, 0]) ) - df_4 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1, 7.6], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c", "c"], - ordered=False, - categories=["c", "b", "a", "d"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4, 3.4], - } - ) - np_4 = np.array( - [ - [1.5, 0, 0, 1, 0, 1.2], - [2.7, 0, 1, 0, 0, 5.4], - [3.6, 1, 0, 0, 0, 9.3], - [4.4, 0, 0, 1, 0, 10.4], - [5.3, 0, 1, 0, 0, 3.6], - [6.1, 1, 0, 0, 0, 4.4], - [7.6, 1, 0, 0, 0, 3.4], - ] - ) + df_4 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1, 7.6], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c", "c"], + ordered=False, + categories=["c", "b", "a", "d"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4, 3.4], + }) + np_4 = np.array([ + [1.5, 0, 0, 1, 0, 1.2], + [2.7, 0, 1, 0, 0, 5.4], + [3.6, 1, 0, 0, 0, 9.3], + [4.4, 0, 0, 1, 0, 10.4], + [5.3, 0, 1, 0, 0, 3.6], + [6.1, 1, 0, 0, 0, 4.4], + [7.6, 1, 0, 0, 0, 3.4], + ]) cov_transformer = CovariatePreprocessor() df_4_transformed = cov_transformer.fit_transform(df_4) np.testing.assert_array_equal(np_4, df_4_transformed) diff --git a/test/python/test_residual.py b/test/python/test_residual.py index a0dd1c09..9be4eaa9 100644 --- a/test/python/test_residual.py +++ b/test/python/test_residual.py @@ -15,16 +15,14 @@ class TestResidual: def test_basis_update(self): # Create dataset - X = np.array( - [ - [1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4], - ] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) W = np.array([[1], [1], [1], [1], [1], [1]]) n = X.shape[0] p = X.shape[1] diff --git a/test/python/test_utils.py b/test/python/test_utils.py index 5394802f..9b6da787 100644 --- a/test/python/test_utils.py +++ b/test/python/test_utils.py @@ -9,9 +9,9 @@ _check_matrix_square, _standardize_array_to_list, _standardize_array_to_np, - _expand_dims_1d, - _expand_dims_2d, - _expand_dims_2d_diag + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, ) @@ -90,9 +90,11 @@ def test_standardize(self): array_np3 = np.array([8.2, 4.5, 3.8]) array_np4 = np.array([[8.2, 4.5, 3.8]]) nonconforming_array_np1 = np.array([[8.2, 4.5, 3.8], [1.6, 3.4, 7.6]]) - nonconforming_array_np2 = np.array( - [[8.2, 4.5, 3.8], [1.6, 3.4, 7.6], [1.6, 3.4, 7.6]] - ) + nonconforming_array_np2 = np.array([ + [8.2, 4.5, 3.8], + [1.6, 3.4, 7.6], + [1.6, 3.4, 7.6], + ]) non_array_1 = 100000000 non_array_2 = "a" @@ -117,7 +119,7 @@ def test_standardize(self): _ = _standardize_array_to_np(non_array_2) _ = _standardize_array_to_np(nonconforming_array_np1) _ = _standardize_array_to_np(nonconforming_array_np2) - + def test_array_conversion(self): scalar_1 = 1.5 scalar_2 = -2.5 @@ -125,11 +127,15 @@ def test_array_conversion(self): array_1d_1 = np.array([1.6, 3.4, 7.6, 8.7]) array_1d_2 = np.array([2.5, 3.1, 5.6]) array_1d_3 = np.array([2.5]) - array_2d_1 = np.array([[2.5,1.2,4.3,7.4],[1.7,2.9,3.6,9.1],[7.2,4.5,6.7,1.4]]) - array_2d_2 = 
np.array([[2.5,1.2,4.3,7.4],[1.7,2.9,3.6,9.1]]) - array_square_1 = np.array([[2.5,1.2],[1.7,2.9]]) - array_square_2 = np.array([[2.5,0.0],[0.0,2.9]]) - array_square_3 = np.array([[2.5,0.0,0.0],[0.0,2.9,0.0],[0.0,0.0,5.6]]) + array_2d_1 = np.array([ + [2.5, 1.2, 4.3, 7.4], + [1.7, 2.9, 3.6, 9.1], + [7.2, 4.5, 6.7, 1.4], + ]) + array_2d_2 = np.array([[2.5, 1.2, 4.3, 7.4], [1.7, 2.9, 3.6, 9.1]]) + array_square_1 = np.array([[2.5, 1.2], [1.7, 2.9]]) + array_square_2 = np.array([[2.5, 0.0], [0.0, 2.9]]) + array_square_3 = np.array([[2.5, 0.0, 0.0], [0.0, 2.9, 0.0], [0.0, 0.0, 5.6]]) with pytest.raises(ValueError): _ = _expand_dims_1d(array_1d_1, 5) _ = _expand_dims_1d(array_1d_2, 4) @@ -139,23 +145,91 @@ def test_array_conversion(self): _ = _expand_dims_2d_diag(array_square_1, 4) _ = _expand_dims_2d_diag(array_square_2, 3) _ = _expand_dims_2d_diag(array_square_3, 2) - - np.testing.assert_array_equal(np.array([scalar_1,scalar_1,scalar_1]), _expand_dims_1d(scalar_1, 3)) - np.testing.assert_array_equal(np.array([scalar_2,scalar_2,scalar_2,scalar_2]), _expand_dims_1d(scalar_2, 4)) - np.testing.assert_array_equal(np.array([scalar_3,scalar_3]), _expand_dims_1d(scalar_3, 2)) - np.testing.assert_array_equal(np.array([array_1d_3[0],array_1d_3[0],array_1d_3[0]]), _expand_dims_1d(array_1d_3, 3)) - - np.testing.assert_array_equal(np.array([[scalar_1,scalar_1,scalar_1],[scalar_1,scalar_1,scalar_1]]), _expand_dims_2d(scalar_1, 2, 3)) - np.testing.assert_array_equal(np.array([[scalar_2,scalar_2,scalar_2,scalar_2],[scalar_2,scalar_2,scalar_2,scalar_2]]), _expand_dims_2d(scalar_2, 2, 4)) - np.testing.assert_array_equal(np.array([[scalar_3,scalar_3],[scalar_3,scalar_3],[scalar_3,scalar_3]]), _expand_dims_2d(scalar_3, 3, 2)) - np.testing.assert_array_equal(np.array([[array_1d_3[0],array_1d_3[0]],[array_1d_3[0],array_1d_3[0]],[array_1d_3[0],array_1d_3[0]]]), _expand_dims_2d(array_1d_3, 3, 2)) - np.testing.assert_array_equal(np.vstack((array_1d_1, array_1d_1)), _expand_dims_2d(array_1d_1, 2, 4)) - np.testing.assert_array_equal(np.vstack((array_1d_2, array_1d_2, array_1d_2)), _expand_dims_2d(array_1d_2, 3, 3)) - np.testing.assert_array_equal(np.column_stack((array_1d_2, array_1d_2, array_1d_2, array_1d_2)), _expand_dims_2d(array_1d_2, 3, 4)) - np.testing.assert_array_equal(np.column_stack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)), _expand_dims_2d(array_1d_3, 1, 4)) - np.testing.assert_array_equal(np.vstack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)), _expand_dims_2d(array_1d_3, 4, 1)) - - np.testing.assert_array_equal(np.array([[scalar_1,0.0,0.0],[0.0,scalar_1,0.0],[0.0,0.0,scalar_1]]), _expand_dims_2d_diag(scalar_1, 3)) - np.testing.assert_array_equal(np.array([[scalar_2,0.0],[0.0,scalar_2]]), _expand_dims_2d_diag(scalar_2, 2)) - np.testing.assert_array_equal(np.array([[scalar_3,0.0,0.0,0.0],[0.0,scalar_3,0.0,0.0],[0.0,0.0,scalar_3,0.0],[0.0,0.0,0.0,scalar_3]]), _expand_dims_2d_diag(scalar_3, 4)) - np.testing.assert_array_equal(np.array([[array_1d_3[0],0.0],[0.0,array_1d_3[0]]]), _expand_dims_2d_diag(array_1d_3, 2)) + + np.testing.assert_array_equal( + np.array([scalar_1, scalar_1, scalar_1]), _expand_dims_1d(scalar_1, 3) + ) + np.testing.assert_array_equal( + np.array([scalar_2, scalar_2, scalar_2, scalar_2]), + _expand_dims_1d(scalar_2, 4), + ) + np.testing.assert_array_equal( + np.array([scalar_3, scalar_3]), _expand_dims_1d(scalar_3, 2) + ) + np.testing.assert_array_equal( + np.array([array_1d_3[0], array_1d_3[0], array_1d_3[0]]), + _expand_dims_1d(array_1d_3, 3), + ) + + 
np.testing.assert_array_equal(
+            np.array([[scalar_1, scalar_1, scalar_1], [scalar_1, scalar_1, scalar_1]]),
+            _expand_dims_2d(scalar_1, 2, 3),
+        )
+        np.testing.assert_array_equal(
+            np.array([
+                [scalar_2, scalar_2, scalar_2, scalar_2],
+                [scalar_2, scalar_2, scalar_2, scalar_2],
+            ]),
+            _expand_dims_2d(scalar_2, 2, 4),
+        )
+        np.testing.assert_array_equal(
+            np.array([
+                [scalar_3, scalar_3],
+                [scalar_3, scalar_3],
+                [scalar_3, scalar_3],
+            ]),
+            _expand_dims_2d(scalar_3, 3, 2),
+        )
+        np.testing.assert_array_equal(
+            np.array([
+                [array_1d_3[0], array_1d_3[0]],
+                [array_1d_3[0], array_1d_3[0]],
+                [array_1d_3[0], array_1d_3[0]],
+            ]),
+            _expand_dims_2d(array_1d_3, 3, 2),
+        )
+        np.testing.assert_array_equal(
+            np.vstack((array_1d_1, array_1d_1)), _expand_dims_2d(array_1d_1, 2, 4)
+        )
+        np.testing.assert_array_equal(
+            np.vstack((array_1d_2, array_1d_2, array_1d_2)),
+            _expand_dims_2d(array_1d_2, 3, 3),
+        )
+        np.testing.assert_array_equal(
+            np.column_stack((array_1d_2, array_1d_2, array_1d_2, array_1d_2)),
+            _expand_dims_2d(array_1d_2, 3, 4),
+        )
+        np.testing.assert_array_equal(
+            np.column_stack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)),
+            _expand_dims_2d(array_1d_3, 1, 4),
+        )
+        np.testing.assert_array_equal(
+            np.vstack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)),
+            _expand_dims_2d(array_1d_3, 4, 1),
+        )
+
+        np.testing.assert_array_equal(
+            np.array([
+                [scalar_1, 0.0, 0.0],
+                [0.0, scalar_1, 0.0],
+                [0.0, 0.0, scalar_1],
+            ]),
+            _expand_dims_2d_diag(scalar_1, 3),
+        )
+        np.testing.assert_array_equal(
+            np.array([[scalar_2, 0.0], [0.0, scalar_2]]),
+            _expand_dims_2d_diag(scalar_2, 2),
+        )
+        np.testing.assert_array_equal(
+            np.array([
+                [scalar_3, 0.0, 0.0, 0.0],
+                [0.0, scalar_3, 0.0, 0.0],
+                [0.0, 0.0, scalar_3, 0.0],
+                [0.0, 0.0, 0.0, scalar_3],
+            ]),
+            _expand_dims_2d_diag(scalar_3, 4),
+        )
+        np.testing.assert_array_equal(
+            np.array([[array_1d_3[0], 0.0], [0.0, array_1d_3[0]]]),
+            _expand_dims_2d_diag(array_1d_3, 2),
+        )
diff --git a/tools/debug/bart_contrast_debug.R b/tools/debug/bart_contrast_debug.R
new file mode 100644
index 00000000..647d12b0
--- /dev/null
+++ b/tools/debug/bart_contrast_debug.R
@@ -0,0 +1,171 @@
+# Demo of the contrast computation function for BART
+
+# Load library
+library(stochtree)
+
+# Generate data
+n <- 500
+p <- 5
+X <- matrix(rnorm(n * p), ncol = p)
+W <- matrix(rnorm(n * 1), ncol = 1)
+# fmt: skip
+f_XW <- (
+  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5 * W[,1]) +
+  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5 * W[,1]) +
+  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5 * W[,1]) +
+  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5 * W[,1])
+)
+E_Y <- f_XW
+snr <- 2
+y <- E_Y + rnorm(n, 0, 1) * (sd(E_Y) / snr)
+
+# Train-test split
+test_set_pct <- 0.2
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds, ]
+X_train <- X[train_inds, ]
+W_test <- W[test_inds]
+W_train <- W[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+# Fit BART model
+bart_model <- bart(
+  X_train = X_train,
+  leaf_basis_train = W_train,
+  y_train = y_train,
+  num_gfr = 10,
+  num_burnin = 0,
+  num_mcmc = 1000
+)
+
+# Compute contrast posterior
+contrast_posterior_test <- compute_contrast_bart_model(
+  bart_model,
+  covariates_0 = X_test,
+  covariates_1 = X_test,
+  leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
+  leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
+  type = "posterior",
+  scale = "linear"
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 <- predict(
+  bart_model,
+  covariates = X_test,
+  leaf_basis = matrix(0, nrow = n_test, ncol = 1),
+  type = "posterior",
+  term = "y_hat",
+  scale = "linear"
+)
+y_hat_posterior_test_1 <- predict(
+  bart_model,
+  covariates = X_test,
+  leaf_basis = matrix(1, nrow = n_test, ncol = 1),
+  type = "posterior",
+  term = "y_hat",
+  scale = "linear"
+)
+contrast_posterior_test_comparison <- (y_hat_posterior_test_1 -
+  y_hat_posterior_test_0)
+
+# Compare results
+contrast_diff <- contrast_posterior_test_comparison - contrast_posterior_test
+all(
+  abs(contrast_diff) < 0.001
+)
+
+# Generate data for a BART model with random effects
+X <- matrix(rnorm(n * p), ncol = p)
+W <- matrix(rnorm(n * 1), ncol = 1)
+# fmt: skip
+f_XW <- (
+  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5 * W[,1]) +
+  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5 * W[,1]) +
+  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5 * W[,1]) +
+  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5 * W[,1])
+)
+group_ids <- rep(c(1, 2), n %/% 2)
+rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE)
+rfx_basis <- cbind(1, runif(n))
+rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis)
+E_Y <- f_XW + rfx_term
+snr <- 2
+y <- E_Y + rnorm(n, 0, 1) * (sd(E_Y) / snr)
+
+# Train-test split
+n_test <- round(test_set_pct * n)
+n_train <- n - n_test
+test_inds <- sort(sample(1:n, n_test, replace = FALSE))
+train_inds <- (1:n)[!((1:n) %in% test_inds)]
+X_test <- X[test_inds, ]
+X_train <- X[train_inds, ]
+W_test <- W[test_inds]
+W_train <- W[train_inds]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+group_ids_test <- group_ids[test_inds]
+group_ids_train <- group_ids[train_inds]
+rfx_basis_test <- rfx_basis[test_inds, ]
+rfx_basis_train <- rfx_basis[train_inds, ]
+
+# Fit BART model
+bart_model <- bart(
+  X_train = X_train,
+  leaf_basis_train = W_train,
+  y_train = y_train,
+  rfx_group_ids_train = group_ids_train,
+  rfx_basis_train = rfx_basis_train,
+  num_gfr = 10,
+  num_burnin = 0,
+  num_mcmc = 1000
+)
+
+# Compute contrast posterior
+contrast_posterior_test <- compute_contrast_bart_model(
+  bart_model,
+  covariates_0 = X_test,
+  covariates_1 = X_test,
+  leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
+  leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
+  rfx_group_ids_0 = group_ids_test,
+  rfx_group_ids_1 = group_ids_test,
+  rfx_basis_0 = rfx_basis_test,
+  rfx_basis_1 = rfx_basis_test,
+  type = "posterior",
+  scale = "linear"
+)
+
+# Compute the same quantity via two predict calls
+y_hat_posterior_test_0 <- predict(
+  bart_model,
+  covariates = X_test,
+  leaf_basis = matrix(0, nrow = n_test, ncol = 1),
+  rfx_group_ids = group_ids_test,
+  rfx_basis = rfx_basis_test,
+  type = "posterior",
+  term = "y_hat",
+  scale = "linear"
+)
+y_hat_posterior_test_1 <- predict(
+  bart_model,
+  covariates = X_test,
+  leaf_basis = matrix(1, nrow = n_test, ncol = 1),
+  rfx_group_ids = group_ids_test,
+  rfx_basis = rfx_basis_test,
+  type = "posterior",
+  term = "y_hat",
+  scale = "linear"
+)
+contrast_posterior_test_comparison <- (y_hat_posterior_test_1 -
+  y_hat_posterior_test_0)
+
+# Compare results
+contrast_diff <- contrast_posterior_test_comparison - contrast_posterior_test
+all(
+  abs(contrast_diff) < 0.001
+)
diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R
new file mode 100644
index 00000000..89766a74
--- /dev/null
+++ b/tools/debug/bart_predict_debug.R
@@ -0,0 +1,224 @@
+# Demo of updated predict method for BART
+
+# Load library
+library(stochtree)
+
+# Generate data
+n <- 1000
+p <- 5
+X <- matrix(runif(n * 
p), ncol = p) +# fmt: skip +f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +y_test <- y[test_inds] +y_train <- y[train_inds] +E_y_test <- f_XW[test_inds] +E_y_train <- f_XW[train_inds] + +# Fit simple BART model +bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Check several predict approaches +y_hat_posterior_test <- predict(bart_model, X_test)$y_hat +y_hat_mean_test <- predict( + bart_model, + X_test, + type = "mean", + terms = c("y_hat") +) +y_hat_test <- predict( + bart_model, + X_test, + type = "mean", + terms = c("rfx", "variance") +) + +y_hat_intervals <- compute_bart_posterior_interval( + model_object = bart_model, + transform = function(x) x, + terms = c("y_hat", "mean_forest"), + covariates = X_test, + level = 0.95 +) + +(coverage <- mean( + (y_hat_intervals$mean_forest_predictions$lower <= E_y_test) & + (y_hat_intervals$mean_forest_predictions$upper >= E_y_test) +)) + +pred_intervals <- sample_bart_posterior_predictive( + model_object = bart_model, + covariates = X_test, + level = 0.95 +) + +(coverage_pred <- mean( + (pred_intervals$lower <= y_test) & + (pred_intervals$upper >= y_test) +)) + +# Generate probit data +n <- 1000 +p <- 5 +X <- matrix(runif(n * p), ncol = p) +# fmt: skip +f_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-2.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-1.25) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (1.25) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (2.5)) +noise_sd <- 1 +W <- f_X + rnorm(n, 0, noise_sd) +y <- as.numeric(W > 0) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +W_test <- W[test_inds] +W_train <- W[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +E_y_test <- f_X[test_inds] +E_y_train <- f_X[train_inds] + +# Fit simple BART model +bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + general_params = list(probit_outcome_model = TRUE) +) + +# Predict on latent scale +y_hat_post <- predict( + object = bart_model, + type = "posterior", + terms = c("y_hat"), + covariates = X_test, + scale = "linear" +) + +# Predict on probability scale +y_hat_post_prob <- predict( + object = bart_model, + type = "posterior", + terms = c("y_hat"), + covariates = X_test, + scale = "probability" +) + +# Compute intervals on latent scale +y_hat_intervals <- compute_bart_posterior_interval( + model_object = bart_model, + scale = "linear", + terms = c("y_hat"), + covariates = X_test, + level = 0.95 +) + +# Compute intervals on probability scale +y_hat_prob_intervals <- compute_bart_posterior_interval( + model_object = bart_model, + scale = "probability", + terms = c("y_hat"), + covariates = X_test, + level = 0.95 +) + +# Compute posterior means +y_hat_mean_latent <- rowMeans(y_hat_post) +y_hat_mean_prob <- rowMeans(y_hat_post_prob) + +# Plot on latent scale +sort_inds <- 
order(y_hat_mean_latent)
+plot(y_hat_mean_latent[sort_inds], ylim = range(c(y_hat_intervals$lower, y_hat_intervals$upper)))
+lines(y_hat_intervals$lower[sort_inds])
+lines(y_hat_intervals$upper[sort_inds])
+
+# Plot on probability scale
+sort_inds <- order(y_hat_mean_prob)
+plot(y_hat_mean_prob[sort_inds], ylim = range(c(y_hat_prob_intervals$lower, y_hat_prob_intervals$upper)))
+lines(y_hat_prob_intervals$lower[sort_inds])
+lines(y_hat_prob_intervals$upper[sort_inds])
+
+# Draw from posterior predictive for covariates in the test set
+ppd_samples <- sample_bart_posterior_predictive(
+  model_object = bart_model,
+  covariates = X_test,
+  num_draws = 10
+)
+
+# Compute histogram of PPD probabilities for both outcome classes
+ppd_samples_prob <- apply(ppd_samples, 1, mean)
+ppd_outcome_0 <- ppd_samples_prob[y_test == 0]
+ppd_outcome_1 <- ppd_samples_prob[y_test == 1]
+hist(ppd_outcome_0, breaks = 50, xlim = c(0, 1))
+hist(ppd_outcome_1, breaks = 50, xlim = c(0, 1))
+
+# Compute posterior ROC
+num_mcmc <- 1000
+num_thresholds <- 1000
+thresholds <- seq(0.001, 0.999, length.out = num_thresholds)
+tpr_mean <- rep(NA, num_thresholds)
+fpr_mean <- rep(NA, num_thresholds)
+tpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+fpr_samples <- matrix(NA, num_thresholds, num_mcmc)
+for (i in 1:num_thresholds) {
+  is_above_threshold_samples <- y_hat_post > qnorm(thresholds[i])
+  is_above_threshold_mean <- y_hat_mean_latent > qnorm(thresholds[i])
+  n_positive <- sum(y_test)
+  n_negative <- sum(y_test == 0)
+  y_above_threshold_mean <- y_test[is_above_threshold_mean]
+  for (j in 1:num_mcmc) {
+    y_above_threshold <- y_test[is_above_threshold_samples[, j]]
+    tpr_samples[i, j] <- sum(y_above_threshold) / n_positive
+    fpr_samples[i, j] <- sum(y_above_threshold == 0) / n_negative
+  }
+  # tpr_mean[i] <- sum(y_above_threshold_mean) / n_positive
+  # fpr_mean[i] <- sum(y_above_threshold_mean == 0) / n_negative
+  tpr_mean[i] <- mean(tpr_samples[i, ])
+  fpr_mean[i] <- mean(fpr_samples[i, ])
+}
+
+for (i in 1:num_mcmc) {
+  if (i == 1) {
+    plot(
+      fpr_samples[, i],
+      tpr_samples[, i],
+      type = "l",
+      col = "blue",
+      lwd = 1,
+      lty = 1,
+      xlab = "False positive rate",
+      ylab = "True positive rate"
+    )
+  } else {
+    lines(fpr_samples[, i], tpr_samples[, i], col = "blue", lwd = 1, lty = 1)
+  }
+}
+lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3)
diff --git a/tools/debug/bart_prior_draws.R b/tools/debug/bart_prior_draws.R
new file mode 100644
index 00000000..ba3d1182
--- /dev/null
+++ b/tools/debug/bart_prior_draws.R
@@ -0,0 +1,169 @@
+library(stochtree)
+
+# Generate the data
+n <- 500
+p_X <- 10
+p_W <- 1
+X <- matrix(runif(n * p_X), ncol = p_X)
+W <- matrix(runif(n * p_W), ncol = p_W)
+f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) *
+  (-3 * W[, 1]) +
+  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-1 * W[, 1]) +
+  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (1 * W[, 1]) +
+  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3 * W[, 1]))
+# y <- f_XW + rnorm(n, 0, 1)
+y <- rep(0, n)
+# wgt <- rep(0, n)
+
+# Standardize outcome
+# y_bar <- mean(y)
+# y_std <- sd(y)
+# resid <- (y - y_bar) / y_std
+resid <- y
+
+# Sampling parameters
+alpha <- 0.99
+beta <- 1
+min_samples_leaf <- 1
+max_depth <- 10
+num_trees <- 100
+cutpoint_grid_size <- 100
+global_variance_init <- 1.
+current_sigma2 <- global_variance_init
+tau_init <- 1 / num_trees
+leaf_prior_scale <- as.matrix(if (p_W >= 1) {
+  diag(tau_init, p_W)
+} else {
+  diag(tau_init, 1)
+})
+nu <- 4
+lambda <- 10
+a_leaf <- 2.
+b_leaf <- 0.5 +leaf_regression <- T +feature_types <- as.integer(rep(0, p_X)) # 0 = numeric +var_weights <- rep(1 / p_X, p_X) + +# Sampling data structures +# Data +# forest_dataset <- createForestDataset(X, W, wgt) +forest_dataset <- createForestDataset(X, W) +outcome_model_type <- 1 +leaf_dimension <- p_W +outcome <- createOutcome(resid) + +# Random number generator (std::mt19937) +rng <- createCppRNG() + +# Sampling data structures +forest_model_config <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees, + num_features = p_X, + num_observations = n, + variable_weights = var_weights, + leaf_dimension = leaf_dimension, + alpha = alpha, + beta = beta, + min_samples_leaf = min_samples_leaf, + max_depth = max_depth, + leaf_model_type = outcome_model_type, + leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig( + global_error_variance = global_variance_init +) +forest_model <- createForestModel( + forest_dataset, + forest_model_config, + global_model_config +) + +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) +if (leaf_regression) { + forest_samples <- createForestSamples(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) +} else { + forest_samples <- createForestSamples(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) +} + +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler( + forest_dataset, + outcome, + forest_model, + outcome_model_type, + mean(resid) +) +active_forest$adjust_residual( + forest_dataset, + outcome, + forest_model, + ifelse(outcome_model_type == 1, T, F), + F +) + +# Prepare to run sampler +num_mcmc <- 2000 +global_var_samples <- rep(0, num_mcmc) + +# Run MCMC +for (i in 1:num_mcmc) { + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, + outcome, + forest_samples, + active_forest, + rng, + forest_model_config, + global_model_config, + keep_forest = T, + gfr = F, + num_threads = 1 + ) + + # Sample global variance parameter + # current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + # outcome, + # forest_dataset, + # rng, + # nu, + # lambda + # ) + current_sigma2 <- 1 / rgamma(1, shape = nu / 2, rate = lambda / 2) + global_var_samples[i] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) +} + +plot(global_var_samples, type = "l") +hist(global_var_samples, breaks = 30) +hist( + 1 / rgamma(num_mcmc, shape = nu / 2, rate = lambda / 2), + breaks = 30 +) +mean(global_var_samples) +mean(1 / rgamma(num_mcmc, shape = nu / 2, rate = lambda / 2)) +sd(global_var_samples) +sd(1 / rgamma(num_mcmc, shape = nu / 2, rate = lambda / 2)) + +# Extract forest predictions +forest_preds <- forest_samples$predict( + forest_dataset +) + +y_hat_prior <- rowMeans(forest_preds) +y_hat_prior_lb <- apply(forest_preds, 1, quantile, probs = 0.025) +y_hat_prior_ub <- apply(forest_preds, 1, quantile, probs = 0.975) +plot( + 1:length(y_hat_prior), + y_hat_prior, + type = "l", + ylim = range(c(y_hat_prior_lb, y_hat_prior_ub)) +) +lines(1:length(y_hat_prior), y_hat_prior_lb, col = "blue") +lines(1:length(y_hat_prior), y_hat_prior_ub, col = "blue") diff --git a/tools/debug/bcf_cate_debug.R b/tools/debug/bcf_cate_debug.R new file mode 100644 index 00000000..cf4f6b97 --- /dev/null +++ b/tools/debug/bcf_cate_debug.R @@ -0,0 +1,382 @@ +# Demo of CATE computation 
function for BCF + +# Load library +library(stochtree) + +# Generate data +n <- 500 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] + +# Fit BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Compute CATE posterior +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via predict +tau_hat_posterior_test_comparison <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate", + scale = "linear" +) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) + +# Generate data for a BCF model with random effects +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +group_ids <- rep(c(1, 2), n %/% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE) +rfx_basis <- cbind(1, Z) +rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) +E_XZ <- mu_x + Z * tau_x + rfx_term +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +group_ids_test <- group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds, ] +rfx_basis_train <- rfx_basis[train_inds, ] + +# Fit BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Compute CATE posterior +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + rfx_group_ids_0 = group_ids_test, + 
rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = cbind(1, rep(0, n_test)), + rfx_basis_1 = cbind(1, rep(1, n_test)), + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via predict +tau_forest_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "cate", + scale = "linear" +) +rfx_term_posterior_treated <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(1, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +rfx_term_posterior_control <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(0, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +tau_hat_posterior_test_comparison <- (tau_forest_posterior_test + + (rfx_term_posterior_treated - rfx_term_posterior_control)) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) + +# Now repeat the same process but via random effects model spec +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = group_ids_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + random_effects_params = list( + model_spec = "intercept_plus_treatment" + ) +) + +# Compute CATE posterior +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + rfx_group_ids_0 = group_ids_test, + rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = cbind(1, rep(0, n_test)), + rfx_basis_1 = cbind(1, rep(1, n_test)), + type = "posterior", + scale = "linear" +) + +# Compute the same quantity directly via predict +tau_hat_posterior_test_comparison <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + type = "posterior", + terms = "cate", + scale = "linear" +) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) + +# Generate data for a probit BCF model with random effects +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +group_ids <- rep(c(1, 2), n %/% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE) +rfx_basis <- cbind(1, Z) +rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) +E_XZ <- mu_x + Z * tau_x + rfx_term +# E_XZ <- mu_x + Z * tau_x + rfx_term +W <- E_XZ + rnorm(n, 0, 1) +y <- as.numeric(W > 0) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +W_test <- W[test_inds] +W_train <- W[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +group_ids_test <- 
group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds, ] +rfx_basis_train <- rfx_basis[train_inds, ] + +# Fit BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + general_params = list(probit_outcome_model = T) +) + +# Compute CATE posterior on probability scale +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + rfx_group_ids_0 = group_ids_test, + rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = cbind(1, rep(0, n_test)), + rfx_basis_1 = cbind(1, rep(1, n_test)), + type = "posterior", + scale = "probability" +) + +# Compute the same quantity via predict +mu_forest_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "prognostic_function", + scale = "linear" +) +tau_forest_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "cate", + scale = "linear" +) +rfx_term_posterior_treated <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(1, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +rfx_term_posterior_control <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(0, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +w_hat_0 <- mu_forest_posterior_test + + rfx_term_posterior_control +w_hat_1 <- mu_forest_posterior_test + + tau_forest_posterior_test + + rfx_term_posterior_treated +tau_hat_posterior_test_comparison <- pnorm(w_hat_1) - pnorm(w_hat_0) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R new file mode 100644 index 00000000..7a854707 --- /dev/null +++ b/tools/debug/bcf_predict_debug.R @@ -0,0 +1,229 @@ +# Demo of updated predict method for BCF + +# Load library +library(stochtree) + +# Generate data +n <- 500 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] + +# Fit a simple BCF model +bcf_model <- bcf( + X_train = X_train, + 
Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Check several predict approaches +y_hat_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test +)$y_hat +pred <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("all") +) +# Check that this throws a warning +y_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("rfx", "variance") +) + +# Compute intervals around model terms (including E[y | X, Z]) +y_hat_intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = c("all"), + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + level = 0.95 +) + +# Estimate coverage of intervals on tau(X) +(tau_coverage <- mean( + (y_hat_intervals$tau_hat$upper >= tau_test) & + (y_hat_intervals$tau_hat$lower <= tau_test) +)) + +# Posterior predictive coverage and MSE checks +quantiles <- c(0.05, 0.95) +ppd_samples <- sample_bcf_posterior_predictive( + model_object = bcf_model, + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + num_draws = 1 +) +yhat_ppd <- apply(ppd_samples, 1, mean) +yhat_interval_ppd <- apply(ppd_samples, 1, quantile, probs = quantiles) +mean((yhat_interval_ppd[1, ] <= y_test) & (yhat_interval_ppd[2, ] >= y_test)) +sqrt(mean((yhat_ppd - y_test)^2)) + +# Generate probit outcome data +n <- 1000 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +W <- E_XZ + rnorm(n, 0, 1) +y <- (W > 0) * 1 + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +W_test <- W[test_inds] +W_train <- W[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Fit a simple BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + general_params = list(probit_outcome_model = TRUE) +) + +# Predict on latent scale +y_hat_post <- predict( + object = bcf_model, + type = "posterior", + terms = c("y_hat"), + X = X_test, + Z = Z_test, + propensity = pi_test, + scale = "linear" +) + +# Predict on probability scale +y_hat_post_prob <- predict( + object = bcf_model, + type = "posterior", + terms = c("y_hat"), + X = X_test, + Z = Z_test, + propensity = pi_test, + scale = "probability" +) + +# Compute intervals on latent scale +y_hat_intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = c("y_hat"), + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + level = 0.95 +) + +# Compute intervals on probability scale +y_hat_prob_intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "probability", + terms = c("y_hat"), + covariates = X_test, 
+ treatment = Z_test, + propensity = pi_test, + level = 0.95 +) + +# Compute posterior means +y_hat_mean_latent <- rowMeans(y_hat_post) +y_hat_mean_prob <- rowMeans(y_hat_post_prob) + +# Plot on latent scale +sort_inds <- order(y_hat_mean_latent) +plot(y_hat_mean_latent[sort_inds], ylim = range(y_hat_intervals)) +lines(y_hat_intervals$lower[sort_inds]) +lines(y_hat_intervals$upper[sort_inds]) + +# Plot on probability scale +sort_inds <- order(y_hat_mean_prob) +plot(y_hat_mean_prob[sort_inds], ylim = range(y_hat_prob_intervals)) +lines(y_hat_prob_intervals$lower[sort_inds]) +lines(y_hat_prob_intervals$upper[sort_inds]) + +# Draw from posterior predictive for covariates / treatment values in the test set +ppd_samples <- sample_bcf_posterior_predictive( + model_object = bcf_model, + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + num_draws = 10 +) + +# Compute histogram of PPD probabilities for both outcome classes +ppd_samples_prob <- apply(ppd_samples, 1, mean) +ppd_outcome_0 <- ppd_samples_prob[y_test == 0] +ppd_outcome_1 <- ppd_samples_prob[y_test == 1] +hist(ppd_outcome_0, breaks = 50, xlim = c(0, 1)) +hist(ppd_outcome_1, breaks = 50, xlim = c(0, 1))
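+
+# Illustrative addition (not part of the original demo): a rough sanity check
+# of how well the PPD class probabilities separate the two outcome classes.
+# The 0.5 threshold below is an assumption chosen purely for illustration,
+# not a recommended decision rule.
+y_class_ppd <- as.numeric(ppd_samples_prob > 0.5)
+(ppd_confusion <- table(predicted = y_class_ppd, observed = y_test))
+(ppd_accuracy <- mean(y_class_ppd == y_test))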