From 7f0e5ec28ffc0831a72b6a9e0ab143443c7a2a66 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 6 Oct 2025 17:57:53 -0500 Subject: [PATCH 01/53] Initial overhaul of predict method --- R/bart.R | 4739 +++++++++--------- man/bart.Rd | 8 +- man/bcf.Rd | 24 +- man/createBARTModelFromCombinedJson.Rd | 8 +- man/createBARTModelFromCombinedJsonString.Rd | 8 +- man/createBARTModelFromJson.Rd | 8 +- man/createBARTModelFromJsonFile.Rd | 8 +- man/createBARTModelFromJsonString.Rd | 8 +- man/createBCFModelFromCombinedJson.Rd | 30 +- man/createBCFModelFromCombinedJsonString.Rd | 30 +- man/createBCFModelFromJson.Rd | 34 +- man/createBCFModelFromJsonFile.Rd | 34 +- man/createBCFModelFromJsonString.Rd | 30 +- man/createForestModel.Rd | 8 +- man/getRandomEffectSamples.bartmodel.Rd | 16 +- man/getRandomEffectSamples.bcfmodel.Rd | 34 +- man/predict.bartmodel.Rd | 14 +- man/predict.bcfmodel.Rd | 22 +- man/preprocessPredictionData.Rd | 2 +- man/resetForestModel.Rd | 22 +- man/resetRandomEffectsModel.Rd | 4 +- man/resetRandomEffectsTracker.Rd | 4 +- man/rootResetRandomEffectsModel.Rd | 4 +- man/rootResetRandomEffectsTracker.Rd | 4 +- man/saveBARTModelToJson.Rd | 8 +- man/saveBARTModelToJsonFile.Rd | 8 +- man/saveBARTModelToJsonString.Rd | 8 +- man/saveBCFModelToJson.Rd | 34 +- man/saveBCFModelToJsonFile.Rd | 34 +- man/saveBCFModelToJsonString.Rd | 34 +- test/R/testthat/test-predict.R | 84 + tools/debug/predict_debug.R | 31 + 32 files changed, 2760 insertions(+), 2584 deletions(-) create mode 100644 tools/debug/predict_debug.R diff --git a/R/bart.R b/R/bart.R index 83e2f828..829d016a 100644 --- a/R/bart.R +++ b/R/bart.R @@ -111,1655 +111,1651 @@ #' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart <- function( - X_train, - y_train, - leaf_basis_train = NULL, - rfx_group_ids_train = NULL, - rfx_basis_train = NULL, - X_test = NULL, - leaf_basis_test = NULL, - rfx_group_ids_test = NULL, - rfx_basis_test = NULL, - num_gfr = 5, - num_burnin = 0, - num_mcmc = 100, - previous_model_json = NULL, - previous_model_warmstart_sample_num = NULL, - general_params = list(), - mean_forest_params = list(), - variance_forest_params = list() + X_train, + y_train, + leaf_basis_train = NULL, + rfx_group_ids_train = NULL, + rfx_basis_train = NULL, + X_test = NULL, + leaf_basis_test = NULL, + rfx_group_ids_test = NULL, + rfx_basis_test = NULL, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + previous_model_json = NULL, + previous_model_warmstart_sample_num = NULL, + general_params = list(), + mean_forest_params = list(), + variance_forest_params = list() ) { - # Update general BART parameters - general_params_default <- list( - cutpoint_grid_size = 100, - standardize = TRUE, - sample_sigma2_global = TRUE, - sigma2_global_init = NULL, - sigma2_global_shape = 0, - sigma2_global_scale = 0, - variable_weights = NULL, - random_seed = -1, - keep_burnin = FALSE, - keep_gfr = FALSE, - keep_every = 1, - num_chains = 1, - verbose = FALSE, - probit_outcome_model = FALSE, - rfx_working_parameter_prior_mean = NULL, - rfx_group_parameter_prior_mean = NULL, - rfx_working_parameter_prior_cov = NULL, - rfx_group_parameter_prior_cov = NULL, - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1, - num_threads = -1 + # Update general BART parameters + general_params_default <- list( + cutpoint_grid_size = 100, + standardize = TRUE, + sample_sigma2_global = TRUE, + sigma2_global_init = NULL, + sigma2_global_shape = 0, + sigma2_global_scale = 0, + variable_weights = NULL, + random_seed = -1, + keep_burnin = FALSE, + keep_gfr = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, + probit_outcome_model = FALSE, + rfx_working_parameter_prior_mean = NULL, + rfx_group_parameter_prior_mean = NULL, + rfx_working_parameter_prior_cov = NULL, + rfx_group_parameter_prior_cov = NULL, + rfx_variance_prior_shape = 1, + rfx_variance_prior_scale = 1, + num_threads = -1 + ) + general_params_updated <- preprocessParams( + general_params_default, + general_params + ) + + # Update mean forest BART parameters + mean_forest_params_default <- list( + num_trees = 200, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + sample_sigma2_leaf = TRUE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + mean_forest_params_updated <- preprocessParams( + mean_forest_params_default, + mean_forest_params + ) + + # Update variance forest BART parameters + variance_forest_params_default <- list( + num_trees = 0, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + leaf_prior_calibration_param = 1.5, + var_forest_leaf_init = NULL, + var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + variance_forest_params_updated <- preprocessParams( + variance_forest_params_default, + variance_forest_params + ) + + ### Unpack all parameter values + # 1. General parameters + cutpoint_grid_size <- general_params_updated$cutpoint_grid_size + standardize <- general_params_updated$standardize + sample_sigma2_global <- general_params_updated$sample_sigma2_global + sigma2_init <- general_params_updated$sigma2_global_init + a_global <- general_params_updated$sigma2_global_shape + b_global <- general_params_updated$sigma2_global_scale + variable_weights <- general_params_updated$variable_weights + random_seed <- general_params_updated$random_seed + keep_burnin <- general_params_updated$keep_burnin + keep_gfr <- general_params_updated$keep_gfr + keep_every <- general_params_updated$keep_every + num_chains <- general_params_updated$num_chains + verbose <- general_params_updated$verbose + probit_outcome_model <- general_params_updated$probit_outcome_model + rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean + rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean + rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov + rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov + rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape + rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale + num_threads <- general_params_updated$num_threads + + # 2. Mean forest parameters + num_trees_mean <- mean_forest_params_updated$num_trees + alpha_mean <- mean_forest_params_updated$alpha + beta_mean <- mean_forest_params_updated$beta + min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf + max_depth_mean <- mean_forest_params_updated$max_depth + sample_sigma2_leaf <- mean_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_init <- mean_forest_params_updated$sigma2_leaf_init + a_leaf <- mean_forest_params_updated$sigma2_leaf_shape + b_leaf <- mean_forest_params_updated$sigma2_leaf_scale + keep_vars_mean <- mean_forest_params_updated$keep_vars + drop_vars_mean <- mean_forest_params_updated$drop_vars + num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample + + # 3. Variance forest parameters + num_trees_variance <- variance_forest_params_updated$num_trees + alpha_variance <- variance_forest_params_updated$alpha + beta_variance <- variance_forest_params_updated$beta + min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf + max_depth_variance <- variance_forest_params_updated$max_depth + a_0 <- variance_forest_params_updated$leaf_prior_calibration_param + variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init + a_forest <- variance_forest_params_updated$var_forest_prior_shape + b_forest <- variance_forest_params_updated$var_forest_prior_scale + keep_vars_variance <- variance_forest_params_updated$keep_vars + drop_vars_variance <- variance_forest_params_updated$drop_vars + num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample + + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop( + "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" + ) + } + } + + # Override keep_gfr if there are no MCMC samples + if (num_mcmc == 0) { + keep_gfr <- TRUE + } + + # Check if previous model JSON is provided and parse it if so + has_prev_model <- !is.null(previous_model_json) + if (has_prev_model) { + previous_bart_model <- createBARTModelFromJsonString( + previous_model_json ) - general_params_updated <- preprocessParams( - general_params_default, - general_params - ) - - # Update mean forest BART parameters - mean_forest_params_default <- list( - num_trees = 200, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - sample_sigma2_leaf = TRUE, - sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, - sigma2_leaf_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - mean_forest_params_updated <- preprocessParams( - mean_forest_params_default, - mean_forest_params - ) - - # Update variance forest BART parameters - variance_forest_params_default <- list( - num_trees = 0, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - leaf_prior_calibration_param = 1.5, - var_forest_leaf_init = NULL, - var_forest_prior_shape = NULL, - var_forest_prior_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL + previous_y_bar <- previous_bart_model$model_params$outcome_mean + previous_y_scale <- previous_bart_model$model_params$outcome_scale + if (previous_bart_model$model_params$include_mean_forest) { + previous_forest_samples_mean <- previous_bart_model$mean_forests + } else { + previous_forest_samples_mean <- NULL + } + if (previous_bart_model$model_params$include_variance_forest) { + previous_forest_samples_variance <- previous_bart_model$variance_forests + } else { + previous_forest_samples_variance <- NULL + } + if (previous_bart_model$model_params$sample_sigma2_global) { + previous_global_var_samples <- previous_bart_model$sigma2_global_samples / + (previous_y_scale * previous_y_scale) + } else { + previous_global_var_samples <- NULL + } + if (previous_bart_model$model_params$sample_sigma2_leaf) { + previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples + } else { + previous_leaf_var_samples <- NULL + } + if (previous_bart_model$model_params$has_rfx) { + previous_rfx_samples <- previous_bart_model$rfx_samples + } else { + previous_rfx_samples <- NULL + } + previous_model_num_samples <- previous_bart_model$model_params$num_samples + if (previous_model_warmstart_sample_num > previous_model_num_samples) { + stop( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) + } + } else { + previous_y_bar <- NULL + previous_y_scale <- NULL + previous_global_var_samples <- NULL + previous_leaf_var_samples <- NULL + previous_rfx_samples <- NULL + previous_forest_samples_mean <- NULL + previous_forest_samples_variance <- NULL + previous_model_num_samples <- 0 + } + + # Determine whether conditional mean, variance, or both will be modeled + if (num_trees_variance > 0) { + include_variance_forest = TRUE + } else { + include_variance_forest = FALSE + } + if (num_trees_mean > 0) { + include_mean_forest = TRUE + } else { + include_mean_forest = FALSE + } + + # Set the variance forest priors if not set + if (include_variance_forest) { + if (is.null(a_forest)) { + a_forest <- num_trees_variance / (a_0^2) + 0.5 + } + if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + } else { + a_forest <- 1. + b_forest <- 1. + } + + # Override tau sampling if there is no mean forest + if (!include_mean_forest) { + sample_sigma2_leaf <- FALSE + } + + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1 / ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Check covariates are matrix or dataframe + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)) { + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + num_cov_orig <- ncol(X_train) + + # Standardize the keep variable lists to numeric indices + if (!is.null(keep_vars_mean)) { + if (is.character(keep_vars_mean)) { + if (!all(keep_vars_mean %in% names(X_train))) { + stop( + "keep_vars_mean includes some variable names that are not in X_train" + ) + } + variable_subset_mu <- unname(which( + names(X_train) %in% keep_vars_mean + )) + } else { + if (any(keep_vars_mean > ncol(X_train))) { + stop( + "keep_vars_mean includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_mean < 0)) { + stop("keep_vars_mean includes some negative variable indices") + } + variable_subset_mu <- keep_vars_mean + } + } else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) { + if (is.character(drop_vars_mean)) { + if (!all(drop_vars_mean %in% names(X_train))) { + stop( + "drop_vars_mean includes some variable names that are not in X_train" + ) + } + variable_subset_mean <- unname(which( + !(names(X_train) %in% drop_vars_mean) + )) + } else { + if (any(drop_vars_mean > ncol(X_train))) { + stop( + "drop_vars_mean includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_mean < 0)) { + stop("drop_vars_mean includes some negative variable indices") + } + variable_subset_mean <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_mean) + ] + } + } else { + variable_subset_mean <- 1:ncol(X_train) + } + if (!is.null(keep_vars_variance)) { + if (is.character(keep_vars_variance)) { + if (!all(keep_vars_variance %in% names(X_train))) { + stop( + "keep_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + names(X_train) %in% keep_vars_variance + )) + } else { + if (any(keep_vars_variance > ncol(X_train))) { + stop( + "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_variance < 0)) { + stop( + "keep_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- keep_vars_variance + } + } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) { + if (is.character(drop_vars_variance)) { + if (!all(drop_vars_variance %in% names(X_train))) { + stop( + "drop_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + !(names(X_train) %in% drop_vars_variance) + )) + } else { + if (any(drop_vars_variance > ncol(X_train))) { + stop( + "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_variance < 0)) { + stop( + "drop_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_variance) + ] + } + } else { + variable_subset_variance <- 1:ncol(X_train) + } + + # Preprocess covariates + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)) { + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train <- train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + } + + # Update variable weights + variable_weights_mean <- variable_weights_variance <- variable_weights + variable_weights_adj <- 1 / + sapply(original_var_indices, function(x) sum(original_var_indices == x)) + if (include_mean_forest) { + variable_weights_mean <- variable_weights_mean[original_var_indices] * + variable_weights_adj + variable_weights_mean[ + !(original_var_indices %in% variable_subset_mean) + ] <- 0 + } + if (include_variance_forest) { + variable_weights_variance <- variable_weights_variance[ + original_var_indices + ] * + variable_weights_adj + variable_weights_variance[ + !(original_var_indices %in% variable_subset_variance) + ] <- 0 + } + + # Set num_features_subsample to default, ncol(X_train), if not already set + if (is.null(num_features_subsample_mean)) { + num_features_subsample_mean <- ncol(X_train) + } + if (is.null(num_features_subsample_variance)) { + num_features_subsample_variance <- ncol(X_train) + } + + # Convert all input data to matrices if not already converted + if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) { + leaf_basis_train <- as.matrix(leaf_basis_train) + } + if ((is.null(dim(leaf_basis_test))) && (!is.null(leaf_basis_test))) { + leaf_basis_test <- as.matrix(leaf_basis_test) + } + if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { + rfx_basis_train <- as.matrix(rfx_basis_train) + } + if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { + rfx_basis_test <- as.matrix(rfx_basis_test) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) + has_rfx <- FALSE + has_rfx_test <- FALSE + if (!is.null(rfx_group_ids_train)) { + group_ids_factor <- factor(rfx_group_ids_train) + rfx_group_ids_train <- as.integer(group_ids_factor) + has_rfx <- TRUE + if (!is.null(rfx_group_ids_test)) { + group_ids_factor_test <- factor( + rfx_group_ids_test, + levels = levels(group_ids_factor) + ) + if (sum(is.na(group_ids_factor_test)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + } + rfx_group_ids_test <- as.integer(group_ids_factor_test) + has_rfx_test <- TRUE + } + } + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if ( + (!is.null(leaf_basis_test)) && + (ncol(leaf_basis_test) != ncol(leaf_basis_train)) + ) { + stop( + "leaf_basis_train and leaf_basis_test must have the same number of columns" ) - variance_forest_params_updated <- preprocessParams( - variance_forest_params_default, - variance_forest_params + } + if ( + (!is.null(leaf_basis_train)) && + (nrow(leaf_basis_train) != nrow(X_train)) + ) { + stop("leaf_basis_train and X_train must have the same number of rows") + } + if ((!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test))) { + stop("leaf_basis_test and X_test must have the same number of rows") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + if ( + (!is.null(rfx_basis_test)) && + (ncol(rfx_basis_test) != ncol(rfx_basis_train)) + ) { + stop( + "rfx_basis_train and rfx_basis_test must have the same number of columns" ) - - ### Unpack all parameter values - # 1. General parameters - cutpoint_grid_size <- general_params_updated$cutpoint_grid_size - standardize <- general_params_updated$standardize - sample_sigma2_global <- general_params_updated$sample_sigma2_global - sigma2_init <- general_params_updated$sigma2_global_init - a_global <- general_params_updated$sigma2_global_shape - b_global <- general_params_updated$sigma2_global_scale - variable_weights <- general_params_updated$variable_weights - random_seed <- general_params_updated$random_seed - keep_burnin <- general_params_updated$keep_burnin - keep_gfr <- general_params_updated$keep_gfr - keep_every <- general_params_updated$keep_every - num_chains <- general_params_updated$num_chains - verbose <- general_params_updated$verbose - probit_outcome_model <- general_params_updated$probit_outcome_model - rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean - rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean - rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov - rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov - rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape - rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale - num_threads <- general_params_updated$num_threads - - # 2. Mean forest parameters - num_trees_mean <- mean_forest_params_updated$num_trees - alpha_mean <- mean_forest_params_updated$alpha - beta_mean <- mean_forest_params_updated$beta - min_samples_leaf_mean <- mean_forest_params_updated$min_samples_leaf - max_depth_mean <- mean_forest_params_updated$max_depth - sample_sigma2_leaf <- mean_forest_params_updated$sample_sigma2_leaf - sigma2_leaf_init <- mean_forest_params_updated$sigma2_leaf_init - a_leaf <- mean_forest_params_updated$sigma2_leaf_shape - b_leaf <- mean_forest_params_updated$sigma2_leaf_scale - keep_vars_mean <- mean_forest_params_updated$keep_vars - drop_vars_mean <- mean_forest_params_updated$drop_vars - num_features_subsample_mean <- mean_forest_params_updated$num_features_subsample - - # 3. Variance forest parameters - num_trees_variance <- variance_forest_params_updated$num_trees - alpha_variance <- variance_forest_params_updated$alpha - beta_variance <- variance_forest_params_updated$beta - min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf - max_depth_variance <- variance_forest_params_updated$max_depth - a_0 <- variance_forest_params_updated$leaf_prior_calibration_param - variance_forest_init <- variance_forest_params_updated$var_forest_leaf_init - a_forest <- variance_forest_params_updated$var_forest_prior_shape - b_forest <- variance_forest_params_updated$var_forest_prior_scale - keep_vars_variance <- variance_forest_params_updated$keep_vars - drop_vars_variance <- variance_forest_params_updated$drop_vars - num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - - # Check if there are enough GFR samples to seed num_chains samplers - if (num_gfr > 0) { - if (num_chains > num_gfr) { - stop( - "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" - ) - } + } + if (!is.null(rfx_group_ids_train)) { + if (!is.null(rfx_group_ids_test)) { + if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { + stop( + "rfx_basis_train is provided but rfx_basis_test is not provided" + ) + } + } + } + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + has_basis_rfx <- FALSE + num_basis_rfx <- 0 + if (has_rfx) { + if (is.null(rfx_basis_train)) { + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) + } else { + has_basis_rfx <- TRUE + num_basis_rfx <- ncol(rfx_basis_train) + } + num_rfx_groups <- length(unique(rfx_group_ids_train)) + num_rfx_components <- ncol(rfx_basis_train) + if (num_rfx_groups == 1) { + warning( + "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" + ) + } + } + if (has_rfx_test) { + if (is.null(rfx_basis_test)) { + if (has_basis_rfx) { + stop( + "Random effects basis provided for training set, must also be provided for the test set" + ) + } + rfx_basis_test <- matrix( + rep(1, nrow(X_test)), + nrow = nrow(X_test), + ncol = 1 + ) + } + } + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Determine whether a basis vector is provided + has_basis = !is.null(leaf_basis_train) + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Preliminary runtime checks for probit link + if (!include_mean_forest) { + probit_outcome_model <- FALSE + } + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0, 1)))) { + stop( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) } - - # Override keep_gfr if there are no MCMC samples - if (num_mcmc == 0) { - keep_gfr <- TRUE + if (include_variance_forest) { + stop("We do not support heteroskedasticity with a probit link") } - - # Check if previous model JSON is provided and parse it if so - has_prev_model <- !is.null(previous_model_json) - if (has_prev_model) { - previous_bart_model <- createBARTModelFromJsonString( - previous_model_json - ) - previous_y_bar <- previous_bart_model$model_params$outcome_mean - previous_y_scale <- previous_bart_model$model_params$outcome_scale - if (previous_bart_model$model_params$include_mean_forest) { - previous_forest_samples_mean <- previous_bart_model$mean_forests - } else { - previous_forest_samples_mean <- NULL - } - if (previous_bart_model$model_params$include_variance_forest) { - previous_forest_samples_variance <- previous_bart_model$variance_forests - } else { - previous_forest_samples_variance <- NULL + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global <- F + } + } + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- qnorm(mean(y_train)) + y_std_train <- 1 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train <- y_train - mean(y_train) + + # Set initial values of root nodes to 0.0 (in probit scale) + init_val_mean <- 0.0 + + # Calibrate priors for sigma^2 and tau + # Set sigma2_init to 1, ignoring default provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + b_leaf <- 1 / (num_trees_mean) + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 / (num_trees_mean), + ncol(leaf_basis_train) + ) } - if (previous_bart_model$model_params$sample_sigma2_global) { - previous_global_var_samples <- previous_bart_model$sigma2_global_samples / - (previous_y_scale * previous_y_scale) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) } else { - previous_global_var_samples <- NULL + current_leaf_scale <- sigma2_leaf_init } - if (previous_bart_model$model_params$sample_sigma2_leaf) { - previous_leaf_var_samples <- previous_bart_model$sigma2_leaf_samples - } else { - previous_leaf_var_samples <- NULL + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) } - if (previous_bart_model$model_params$has_rfx) { - previous_rfx_samples <- previous_bart_model$rfx_samples + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - previous_rfx_samples <- NULL - } - previous_model_num_samples <- previous_bart_model$model_params$num_samples - if (previous_model_warmstart_sample_num > previous_model_num_samples) { - stop( - "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" - ) + current_leaf_scale <- sigma2_leaf_init } + } } else { - previous_y_bar <- NULL - previous_y_scale <- NULL - previous_global_var_samples <- NULL - previous_leaf_var_samples <- NULL - previous_rfx_samples <- NULL - previous_forest_samples_mean <- NULL - previous_forest_samples_variance <- NULL - previous_model_num_samples <- 0 - } - - # Determine whether conditional mean, variance, or both will be modeled - if (num_trees_variance > 0) { - include_variance_forest = TRUE - } else { - include_variance_forest = FALSE - } - if (num_trees_mean > 0) { - include_mean_forest = TRUE + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init + } + } + current_sigma2 <- sigma2_init + } else { + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) } else { - include_mean_forest = FALSE + y_bar_train <- 0 + y_std_train <- 1 } - # Set the variance forest priors if not set - if (include_variance_forest) { - if (is.null(a_forest)) { - a_forest <- num_trees_variance / (a_0^2) + 0.5 - } - if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) - } else { - a_forest <- 1. - b_forest <- 1. - } + # Compute standardized outcome + resid_train <- (y_train - y_bar_train) / y_std_train - # Override tau sampling if there is no mean forest - if (!include_mean_forest) { - sample_sigma2_leaf <- FALSE - } + # Compute initial value of root nodes in mean forest + init_val_mean <- mean(resid_train) - # Variable weight preprocessing (and initialization if necessary) - if (is.null(variable_weights)) { - variable_weights = rep(1 / ncol(X_train), ncol(X_train)) + # Calibrate priors for sigma^2 and tau + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var(resid_train) } - if (any(variable_weights < 0)) { - stop("variable_weights cannot have any negative weights") + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var(resid_train) } - - # Check covariates are matrix or dataframe - if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { - stop("X_train must be a matrix or dataframe") + if (is.null(b_leaf)) { + b_leaf <- var(resid_train) / (2 * num_trees_mean) } - if (!is.null(X_test)) { - if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { - stop("X_test must be a matrix or dataframe") + if (has_basis) { + if (ncol(leaf_basis_train) > 1) { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- diag( + 2 * var(resid_train) / (num_trees_mean), + ncol(leaf_basis_train) + ) } - } - num_cov_orig <- ncol(X_train) - - # Standardize the keep variable lists to numeric indices - if (!is.null(keep_vars_mean)) { - if (is.character(keep_vars_mean)) { - if (!all(keep_vars_mean %in% names(X_train))) { - stop( - "keep_vars_mean includes some variable names that are not in X_train" - ) - } - variable_subset_mu <- unname(which( - names(X_train) %in% keep_vars_mean - )) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag( + sigma2_leaf_init, + ncol(leaf_basis_train) + )) } else { - if (any(keep_vars_mean > ncol(X_train))) { - stop( - "keep_vars_mean includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_mean < 0)) { - stop("keep_vars_mean includes some negative variable indices") - } - variable_subset_mu <- keep_vars_mean + current_leaf_scale <- sigma2_leaf_init } - } else if ((is.null(keep_vars_mean)) && (!is.null(drop_vars_mean))) { - if (is.character(drop_vars_mean)) { - if (!all(drop_vars_mean %in% names(X_train))) { - stop( - "drop_vars_mean includes some variable names that are not in X_train" - ) - } - variable_subset_mean <- unname(which( - !(names(X_train) %in% drop_vars_mean) - )) - } else { - if (any(drop_vars_mean > ncol(X_train))) { - stop( - "drop_vars_mean includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_mean < 0)) { - stop("drop_vars_mean includes some negative variable indices") - } - variable_subset_mean <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_mean) - ] - } - } else { - variable_subset_mean <- 1:ncol(X_train) - } - if (!is.null(keep_vars_variance)) { - if (is.character(keep_vars_variance)) { - if (!all(keep_vars_variance %in% names(X_train))) { - stop( - "keep_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - names(X_train) %in% keep_vars_variance - )) - } else { - if (any(keep_vars_variance > ncol(X_train))) { - stop( - "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_variance < 0)) { - stop( - "keep_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- keep_vars_variance + } else { + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var(resid_train) / (num_trees_mean) + ) } - } else if ( - (is.null(keep_vars_variance)) && (!is.null(drop_vars_variance)) - ) { - if (is.character(drop_vars_variance)) { - if (!all(drop_vars_variance %in% names(X_train))) { - stop( - "drop_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - !(names(X_train) %in% drop_vars_variance) - )) + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) } else { - if (any(drop_vars_variance > ncol(X_train))) { - stop( - "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_variance < 0)) { - stop( - "drop_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_variance) - ] + current_leaf_scale <- sigma2_leaf_init } + } } else { - variable_subset_variance <- 1:ncol(X_train) + if (is.null(sigma2_leaf_init)) { + sigma2_leaf_init <- as.matrix( + 2 * var(resid_train) / (num_trees_mean) + ) + } + if (!is.matrix(sigma2_leaf_init)) { + current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) + } else { + current_leaf_scale <- sigma2_leaf_init + } + } + current_sigma2 <- sigma2_init + } + + # Determine leaf model type + if (!has_basis) { + leaf_model_mean_forest <- 0 + } else if (ncol(leaf_basis_train) == 1) { + leaf_model_mean_forest <- 1 + } else if (ncol(leaf_basis_train) > 1) { + leaf_model_mean_forest <- 2 + } else { + stop("leaf_basis_train passed must be a matrix with at least 1 column") + } + + # Set variance leaf model type (currently only one option) + leaf_model_variance_forest <- 3 + + # Unpack model type info + if (leaf_model_mean_forest == 0) { + leaf_dimension = 1 + is_leaf_constant = TRUE + leaf_regression = FALSE + } else if (leaf_model_mean_forest == 1) { + stopifnot(has_basis) + stopifnot(ncol(leaf_basis_train) == 1) + leaf_dimension = 1 + is_leaf_constant = FALSE + leaf_regression = TRUE + } else if (leaf_model_mean_forest == 2) { + stopifnot(has_basis) + stopifnot(ncol(leaf_basis_train) > 1) + leaf_dimension = ncol(leaf_basis_train) + is_leaf_constant = FALSE + leaf_regression = TRUE + if (sample_sigma2_leaf) { + warning( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + ) + sample_sigma2_leaf <- FALSE } + } - # Preprocess covariates - if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { - stop("X_train must be a matrix or dataframe") - } - if (!is.null(X_test)) { - if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { - stop("X_test must be a matrix or dataframe") - } - } - if (ncol(X_train) != length(variable_weights)) { - stop("length(variable_weights) must equal ncol(X_train)") + # Data + if (leaf_regression) { + forest_dataset_train <- createForestDataset(X_train, leaf_basis_train) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) } - train_cov_preprocess_list <- preprocessTrainData(X_train) - X_train_metadata <- train_cov_preprocess_list$metadata - X_train <- train_cov_preprocess_list$data - original_var_indices <- X_train_metadata$original_var_indices - feature_types <- X_train_metadata$feature_types - if (!is.null(X_test)) { - X_test <- preprocessPredictionData(X_test, X_train_metadata) + requires_basis <- TRUE + } else { + forest_dataset_train <- createForestDataset(X_train) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test) + } + requires_basis <- FALSE + } + outcome_train <- createOutcome(resid_train) + + # Random number generator (std::mt19937) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } + rng <- createCppRNG(random_seed) + + # Sampling data structures + feature_types <- as.integer(feature_types) + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + if (include_mean_forest) { + forest_model_config_mean <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mean, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mean, + leaf_dimension = leaf_dimension, + alpha = alpha_mean, + beta = beta_mean, + min_samples_leaf = min_samples_leaf_mean, + max_depth = max_depth_mean, + leaf_model_type = leaf_model_mean_forest, + leaf_model_scale = current_leaf_scale, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mean + ) + forest_model_mean <- createForestModel( + forest_dataset_train, + forest_model_config_mean, + global_model_config + ) + } + if (include_variance_forest) { + forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = 1, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance + ) + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config + ) + } + + # Container of forest samples + if (include_mean_forest) { + forest_samples_mean <- createForestSamples( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + active_forest_mean <- createForest( + num_trees_mean, + leaf_dimension, + is_leaf_constant, + FALSE + ) + } + if (include_variance_forest) { + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE + ) + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE + ) + } + + # Random effects initialization + if (has_rfx) { + # Prior parameters + if (is.null(rfx_working_parameter_prior_mean)) { + if (num_rfx_components == 1) { + alpha_init <- c(0) + } else if (num_rfx_components > 1) { + alpha_init <- rep(0, num_rfx_components) + } else { + stop("There must be at least 1 random effect component") + } + } else { + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) + } + + if (is.null(rfx_group_parameter_prior_mean)) { + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) + } else { + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) } - # Update variable weights - variable_weights_mean <- variable_weights_variance <- variable_weights - variable_weights_adj <- 1 / - sapply(original_var_indices, function(x) sum(original_var_indices == x)) - if (include_mean_forest) { - variable_weights_mean <- variable_weights_mean[original_var_indices] * - variable_weights_adj - variable_weights_mean[ - !(original_var_indices %in% variable_subset_mean) - ] <- 0 - } - if (include_variance_forest) { - variable_weights_variance <- variable_weights_variance[ - original_var_indices - ] * - variable_weights_adj - variable_weights_variance[ - !(original_var_indices %in% variable_subset_variance) - ] <- 0 + if (is.null(rfx_working_parameter_prior_cov)) { + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) } - # Set num_features_subsample to default, ncol(X_train), if not already set - if (is.null(num_features_subsample_mean)) { - num_features_subsample_mean <- ncol(X_train) - } - if (is.null(num_features_subsample_variance)) { - num_features_subsample_variance <- ncol(X_train) + if (is.null(rfx_group_parameter_prior_cov)) { + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) } - # Convert all input data to matrices if not already converted - if ((is.null(dim(leaf_basis_train))) && (!is.null(leaf_basis_train))) { - leaf_basis_train <- as.matrix(leaf_basis_train) - } - if ((is.null(dim(leaf_basis_test))) && (!is.null(leaf_basis_test))) { - leaf_basis_test <- as.matrix(leaf_basis_test) - } - if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { - rfx_basis_train <- as.matrix(rfx_basis_train) - } - if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { - rfx_basis_test <- as.matrix(rfx_basis_test) - } + sigma_xi_shape <- rfx_variance_prior_shape + sigma_xi_scale <- rfx_variance_prior_scale - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE - has_rfx_test <- FALSE - if (!is.null(rfx_group_ids_train)) { - group_ids_factor <- factor(rfx_group_ids_train) - rfx_group_ids_train <- as.integer(group_ids_factor) - has_rfx <- TRUE - if (!is.null(rfx_group_ids_test)) { - group_ids_factor_test <- factor( - rfx_group_ids_test, - levels = levels(group_ids_factor) - ) - if (sum(is.na(group_ids_factor_test)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" - ) - } - rfx_group_ids_test <- as.integer(group_ids_factor_test) - has_rfx_test <- TRUE + # Random effects data structure and storage container + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train + ) + rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups + ) + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train + ) + } + + # Container of variance parameter samples + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf) { + leaf_scale_samples <- rep(NA, num_retained_samples) + } + if (include_mean_forest) { + mean_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + if (include_variance_forest) { + variance_forest_pred_train <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples + ) + } + sample_counter <- 0 + + # Initialize the leaves of each tree in the mean forest + if (include_mean_forest) { + if (requires_basis) { + init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) + } else { + init_values_mean_forest <- 0. + } + active_forest_mean$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mean, + leaf_model_mean_forest, + init_values_mean_forest + ) + active_forest_mean$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_mean, + requires_basis, + FALSE + ) + } + + # Initialize the leaves of each tree in the variance forest + if (include_variance_forest) { + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init + ) + } + + # Run GFR (warm start) if specified + if (num_gfr > 0) { + for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) + keep_sample <- TRUE + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if ((i %% 10 == 0) || (i == num_gfr)) { + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBART (grow-from-root) draws\n" + ) + } + } + + if (include_mean_forest) { + if (probit_outcome_model) { + # Sample latent probit variable, z | - + forest_pred <- active_forest_mean$predict( + forest_dataset_train + ) + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) } - } - # Data consistency checks - if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { - stop("X_train and X_test must have the same number of columns") - } - if ( - (!is.null(leaf_basis_test)) && - (ncol(leaf_basis_test) != ncol(leaf_basis_train)) - ) { - stop( - "leaf_basis_train and leaf_basis_test must have the same number of columns" - ) - } - if ( - (!is.null(leaf_basis_train)) && - (nrow(leaf_basis_train) != nrow(X_train)) - ) { - stop("leaf_basis_train and X_train must have the same number of rows") - } - if ( - (!is.null(leaf_basis_test)) && (nrow(leaf_basis_test) != nrow(X_test)) - ) { - stop("leaf_basis_test and X_test must have the same number of rows") - } - if (nrow(X_train) != length(y_train)) { - stop("X_train and y_train must have the same number of observations") - } - if ( - (!is.null(rfx_basis_test)) && - (ncol(rfx_basis_test) != ncol(rfx_basis_train)) - ) { - stop( - "rfx_basis_train and rfx_basis_test must have the same number of columns" + # Sample mean forest + forest_model_mean$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE ) - } - if (!is.null(rfx_group_ids_train)) { - if (!is.null(rfx_group_ids_test)) { - if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { - stop( - "rfx_basis_train is provided but rfx_basis_test is not provided" - ) - } + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() } - } + } + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = TRUE + ) - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx <- FALSE - num_basis_rfx <- 0 - if (has_rfx) { - if (is.null(rfx_basis_train)) { - rfx_basis_train <- matrix( - rep(1, nrow(X_train)), - nrow = nrow(X_train), - ncol = 1 - ) - } else { - has_basis_rfx <- TRUE - num_basis_rfx <- ncol(rfx_basis_train) + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } - num_rfx_groups <- length(unique(rfx_group_ids_train)) - num_rfx_components <- ncol(rfx_basis_train) - if (num_rfx_groups == 1) { - warning( - "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" - ) + } + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 } - } - if (has_rfx_test) { - if (is.null(rfx_basis_test)) { - if (has_basis_rfx) { - stop( - "Random effects basis provided for training set, must also be provided for the test set" - ) - } - rfx_basis_test <- matrix( - rep(1, nrow(X_test)), - nrow = nrow(X_test), - ncol = 1 - ) + global_model_config$update_global_error_variance(current_sigma2) + } + if (sample_sigma2_leaf) { + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf + ) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) + } } + } - # Convert y_train to numeric vector if not already converted - if (!is.null(dim(y_train))) { - y_train <- as.matrix(y_train) - } - - # Determine whether a basis vector is provided - has_basis = !is.null(leaf_basis_train) - - # Determine whether a test set is provided - has_test = !is.null(X_test) - - # Preliminary runtime checks for probit link - if (!include_mean_forest) { - probit_outcome_model <- FALSE - } - if (probit_outcome_model) { - if (!(length(unique(y_train)) == 2)) { - stop( - "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" - ) - } - unique_outcomes <- sort(unique(y_train)) - if (!(all(unique_outcomes == c(0, 1)))) { - stop( - "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + # Run MCMC + if (num_burnin + num_mcmc > 0) { + for (chain_num in 1:num_chains) { + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + if (include_mean_forest) { + resetActiveForest( + active_forest_mean, + forest_samples_mean, + forest_ind + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) + if (sample_sigma2_leaf) { + leaf_scale_double <- leaf_scale_samples[forest_ind + 1] + current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale ) + } } if (include_variance_forest) { - stop("We do not support heteroskedasticity with a probit link") - } - if (sample_sigma2_global) { - warning( - "Global error variance will not be sampled with a probit link as it is fixed at 1" - ) - sample_sigma2_global <- F - } - } - - # Handle standardization, prior calibration, and initialization of forest - # differently for binary and continuous outcomes - if (probit_outcome_model) { - # Compute a probit-scale offset and fix scale to 1 - y_bar_train <- qnorm(mean(y_train)) - y_std_train <- 1 - - # Set a pseudo outcome by subtracting mean(y_train) from y_train - resid_train <- y_train - mean(y_train) - - # Set initial values of root nodes to 0.0 (in probit scale) - init_val_mean <- 0.0 - - # Calibrate priors for sigma^2 and tau - # Set sigma2_init to 1, ignoring default provided - sigma2_init <- 1.0 - # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf <- 1 / (num_trees_mean) - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- diag( - 2 / (num_trees_mean), - ncol(leaf_basis_train) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag( - sigma2_leaf_init, - ncol(leaf_basis_train) - )) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix(2 / (num_trees_mean)) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - current_sigma2 <- sigma2_init - } else { - # Only standardize if user requested - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - } else { - y_bar_train <- 0 - y_std_train <- 1 - } - - # Compute standardized outcome - resid_train <- (y_train - y_bar_train) / y_std_train - - # Compute initial value of root nodes in mean forest - init_val_mean <- mean(resid_train) - - # Calibrate priors for sigma^2 and tau - if (is.null(sigma2_init)) { - sigma2_init <- 1.0 * var(resid_train) - } - if (is.null(variance_forest_init)) { - variance_forest_init <- 1.0 * var(resid_train) + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) } - if (is.null(b_leaf)) { - b_leaf <- var(resid_train) / (2 * num_trees_mean) + if (has_rfx) { + resetRandomEffectsModel( + rfx_model, + rfx_samples, + forest_ind, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } - if (has_basis) { - if (ncol(leaf_basis_train) > 1) { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- diag( - 2 * var(resid_train) / (num_trees_mean), - ncol(leaf_basis_train) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag( - sigma2_leaf_init, - ncol(leaf_basis_train) - )) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix( - 2 * var(resid_train) / (num_trees_mean) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } - } - } else { - if (is.null(sigma2_leaf_init)) { - sigma2_leaf_init <- as.matrix( - 2 * var(resid_train) / (num_trees_mean) - ) - } - if (!is.matrix(sigma2_leaf_init)) { - current_leaf_scale <- as.matrix(diag(sigma2_leaf_init, 1)) - } else { - current_leaf_scale <- sigma2_leaf_init - } + if (sample_sigma2_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + global_model_config$update_global_error_variance( + current_sigma2 + ) } - current_sigma2 <- sigma2_init - } - - # Determine leaf model type - if (!has_basis) { - leaf_model_mean_forest <- 0 - } else if (ncol(leaf_basis_train) == 1) { - leaf_model_mean_forest <- 1 - } else if (ncol(leaf_basis_train) > 1) { - leaf_model_mean_forest <- 2 - } else { - stop("leaf_basis_train passed must be a matrix with at least 1 column") - } - - # Set variance leaf model type (currently only one option) - leaf_model_variance_forest <- 3 - - # Unpack model type info - if (leaf_model_mean_forest == 0) { - leaf_dimension = 1 - is_leaf_constant = TRUE - leaf_regression = FALSE - } else if (leaf_model_mean_forest == 1) { - stopifnot(has_basis) - stopifnot(ncol(leaf_basis_train) == 1) - leaf_dimension = 1 - is_leaf_constant = FALSE - leaf_regression = TRUE - } else if (leaf_model_mean_forest == 2) { - stopifnot(has_basis) - stopifnot(ncol(leaf_basis_train) > 1) - leaf_dimension = ncol(leaf_basis_train) - is_leaf_constant = FALSE - leaf_regression = TRUE - if (sample_sigma2_leaf) { - warning( - "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model." + } else if (has_prev_model) { + if (include_mean_forest) { + resetActiveForest( + active_forest_mean, + previous_forest_samples_mean, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_mean, + active_forest_mean, + forest_dataset_train, + outcome_train, + TRUE + ) + if ( + sample_sigma2_leaf && + (!is.null(previous_leaf_var_samples)) + ) { + leaf_scale_double <- previous_leaf_var_samples[ + previous_model_warmstart_sample_num + ] + current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale ) - sample_sigma2_leaf <- FALSE - } - } - - # Data - if (leaf_regression) { - forest_dataset_train <- createForestDataset(X_train, leaf_basis_train) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test, leaf_basis_test) + } } - requires_basis <- TRUE - } else { - forest_dataset_train <- createForestDataset(X_train) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test) - } - requires_basis <- FALSE - } - outcome_train <- createOutcome(resid_train) - - # Random number generator (std::mt19937) - if (is.null(random_seed)) { - random_seed = sample(1:10000, 1, FALSE) - } - rng <- createCppRNG(random_seed) - - # Sampling data structures - feature_types <- as.integer(feature_types) - global_model_config <- createGlobalModelConfig( - global_error_variance = current_sigma2 - ) - if (include_mean_forest) { - forest_model_config_mean <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_mean, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_mean, - leaf_dimension = leaf_dimension, - alpha = alpha_mean, - beta = beta_mean, - min_samples_leaf = min_samples_leaf_mean, - max_depth = max_depth_mean, - leaf_model_type = leaf_model_mean_forest, - leaf_model_scale = current_leaf_scale, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_mean - ) - forest_model_mean <- createForestModel( - forest_dataset_train, - forest_model_config_mean, - global_model_config - ) - } - if (include_variance_forest) { - forest_model_config_variance <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_variance, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_variance, - leaf_dimension = 1, - alpha = alpha_variance, - beta = beta_variance, - min_samples_leaf = min_samples_leaf_variance, - max_depth = max_depth_variance, - leaf_model_type = leaf_model_variance_forest, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_variance - ) - forest_model_variance <- createForestModel( + if (include_variance_forest) { + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_variance, + active_forest_variance, forest_dataset_train, - forest_model_config_variance, - global_model_config - ) - } - - # Container of forest samples - if (include_mean_forest) { - forest_samples_mean <- createForestSamples( - num_trees_mean, - leaf_dimension, - is_leaf_constant, - FALSE - ) - active_forest_mean <- createForest( - num_trees_mean, - leaf_dimension, - is_leaf_constant, + outcome_train, FALSE - ) - } - if (include_variance_forest) { - forest_samples_variance <- createForestSamples( - num_trees_variance, - 1, - TRUE, - TRUE - ) - active_forest_variance <- createForest( - num_trees_variance, - 1, - TRUE, - TRUE - ) - } - - # Random effects initialization - if (has_rfx) { - # Prior parameters - if (is.null(rfx_working_parameter_prior_mean)) { - if (num_rfx_components == 1) { - alpha_init <- c(0) - } else if (num_rfx_components > 1) { - alpha_init <- rep(0, num_rfx_components) - } else { - stop("There must be at least 1 random effect component") - } - } else { - alpha_init <- expand_dims_1d( - rfx_working_parameter_prior_mean, - num_rfx_components - ) + ) } - - if (is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix( - rep(alpha_init, num_rfx_groups), - num_rfx_components, - num_rfx_groups + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" ) - } else { - xi_init <- expand_dims_2d( - rfx_group_parameter_prior_mean, - num_rfx_components, - num_rfx_groups + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale ) - } - - if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_alpha_init <- expand_dims_2d_diag( - rfx_working_parameter_prior_cov, - num_rfx_components + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train ) - } - - if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_xi_init <- expand_dims_2d_diag( - rfx_group_parameter_prior_cov, - num_rfx_components + } else { + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + previous_model_warmstart_sample_num - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples ) + } } - - sigma_xi_shape <- rfx_variance_prior_shape - sigma_xi_scale <- rfx_variance_prior_scale - - # Random effects data structure and storage container - rfx_dataset_train <- createRandomEffectsDataset( - rfx_group_ids_train, - rfx_basis_train - ) - rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel( - num_rfx_components, - num_rfx_groups - ) - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples( - num_rfx_components, - num_rfx_groups, - rfx_tracker_train - ) - } - - # Container of variance parameter samples - num_actual_mcmc_iter <- num_mcmc * keep_every - num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter - # Delete GFR samples from these containers after the fact if desired - # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + - ifelse(keep_burnin, num_burnin, 0) + - num_mcmc * num_chains - if (sample_sigma2_global) { - global_var_samples <- rep(NA, num_retained_samples) - } - if (sample_sigma2_leaf) { - leaf_scale_samples <- rep(NA, num_retained_samples) - } - if (include_mean_forest) { - mean_forest_pred_train <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples - ) - } - if (include_variance_forest) { - variance_forest_pred_train <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples - ) - } - sample_counter <- 0 - - # Initialize the leaves of each tree in the mean forest - if (include_mean_forest) { - if (requires_basis) { - init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) - } else { - init_values_mean_forest <- 0. + if (sample_sigma2_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[ + previous_model_warmstart_sample_num + ] + global_model_config$update_global_error_variance( + current_sigma2 + ) + } } - active_forest_mean$prepare_for_sampler( - forest_dataset_train, - outcome_train, + } else { + if (include_mean_forest) { + resetActiveForest(active_forest_mean) + active_forest_mean$set_root_leaves( + init_values_mean_forest / num_trees_mean + ) + resetForestModel( forest_model_mean, - leaf_model_mean_forest, - init_values_mean_forest - ) - active_forest_mean$adjust_residual( + active_forest_mean, forest_dataset_train, outcome_train, - forest_model_mean, - requires_basis, - FALSE - ) - } - - # Initialize the leaves of each tree in the variance forest - if (include_variance_forest) { - active_forest_variance$prepare_for_sampler( + TRUE + ) + if (sample_sigma2_leaf) { + current_leaf_scale <- as.matrix(sigma2_leaf_init) + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) + } + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance) + active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, forest_dataset_train, outcome_train, - forest_model_variance, - leaf_model_variance_forest, - variance_forest_init - ) - } - - # Run GFR (warm start) if specified - if (num_gfr > 0) { - for (i in 1:num_gfr) { - # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC - # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) + FALSE + ) + } + if (has_rfx) { + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) + } + if (sample_sigma2_global) { + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } + for (i in (num_gfr + 1):num_samples) { + is_mcmc <- i > (num_gfr + num_burnin) + if (is_mcmc) { + mcmc_counter <- i - (num_gfr + num_burnin) + if (mcmc_counter %% keep_every == 0) { keep_sample <- TRUE - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if ((i %% 10 == 0) || (i == num_gfr)) { - cat( - "Sampling", - i, - "out of", - num_gfr, - "XBART (grow-from-root) draws\n" - ) - } - } - - if (include_mean_forest) { - if (probit_outcome_model) { - # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict( - forest_dataset_train - ) - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } - - # Sample mean forest - forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mean, - active_forest = active_forest_mean, - rng = rng, - forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - mean_forest_pred_train[, - sample_counter - ] <- forest_model_mean$get_cached_forest_predictions() - } - } - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - variance_forest_pred_train[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance(current_sigma2) - } - if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration( - active_forest_mean, - rng, - a_leaf, - b_leaf - ) - current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) { - leaf_scale_samples[sample_counter] <- leaf_scale_double - } - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } + } else { + keep_sample <- FALSE + } + } else { + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } - } - - # Run MCMC - if (num_burnin + num_mcmc > 0) { - for (chain_num in 1:num_chains) { - if (num_gfr > 0) { - # Reset state of active_forest and forest_model based on a previous GFR sample - forest_ind <- num_gfr - chain_num - if (include_mean_forest) { - resetActiveForest( - active_forest_mean, - forest_samples_mean, - forest_ind - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf) { - leaf_scale_double <- leaf_scale_samples[forest_ind + 1] - current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - forest_samples_variance, - forest_ind - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - resetRandomEffectsModel( - rfx_model, - rfx_samples, - forest_ind, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - if (sample_sigma2_global) { - current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } else if (has_prev_model) { - if (include_mean_forest) { - resetActiveForest( - active_forest_mean, - previous_forest_samples_mean, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if ( - sample_sigma2_leaf && - (!is.null(previous_leaf_var_samples)) - ) { - leaf_scale_double <- previous_leaf_var_samples[ - previous_model_warmstart_sample_num - ] - current_leaf_scale <- as.matrix(leaf_scale_double) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - previous_forest_samples_variance, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - if (is.null(previous_rfx_samples)) { - warning( - "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" - ) - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } else { - resetRandomEffectsModel( - rfx_model, - previous_rfx_samples, - previous_model_warmstart_sample_num - 1, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - } - if (sample_sigma2_global) { - if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[ - previous_model_warmstart_sample_num - ] - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - } else { - if (include_mean_forest) { - resetActiveForest(active_forest_mean) - active_forest_mean$set_root_leaves( - init_values_mean_forest / num_trees_mean - ) - resetForestModel( - forest_model_mean, - active_forest_mean, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf) { - current_leaf_scale <- as.matrix(sigma2_leaf_init) - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - } - if (include_variance_forest) { - resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves( - log(variance_forest_init) / num_trees_variance - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } - if (sample_sigma2_global) { - current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - for (i in (num_gfr + 1):num_samples) { - is_mcmc <- i > (num_gfr + num_burnin) - if (is_mcmc) { - mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } else { - if (keep_burnin) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if (num_burnin > 0) { - if ( - ((i - num_gfr) %% 100 == 0) || - ((i - num_gfr) == num_burnin) - ) { - cat( - "Sampling", - i - num_gfr, - "out of", - num_burnin, - "BART burn-in draws; Chain number ", - chain_num, - "\n" - ) - } - } - if (num_mcmc > 0) { - if ( - ((i - num_gfr - num_burnin) %% 100 == 0) || - (i == num_samples) - ) { - cat( - "Sampling", - i - num_burnin - num_gfr, - "out of", - num_mcmc, - "BART MCMC draws; Chain number ", - chain_num, - "\n" - ) - } - } - } - - if (include_mean_forest) { - if (probit_outcome_model) { - # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict( - forest_dataset_train - ) - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } - - forest_model_mean$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mean, - active_forest = active_forest_mean, - rng = rng, - forest_model_config = forest_model_config_mean, - global_model_config = global_model_config, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - mean_forest_pred_train[, - sample_counter - ] <- forest_model_mean$get_cached_forest_predictions() - } - } - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - variance_forest_pred_train[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf) { - leaf_scale_double <- sampleLeafVarianceOneIteration( - active_forest_mean, - rng, - a_leaf, - b_leaf - ) - current_leaf_scale <- as.matrix(leaf_scale_double) - if (keep_sample) { - leaf_scale_samples[sample_counter] <- leaf_scale_double - } - forest_model_config_mean$update_leaf_model_scale( - current_leaf_scale - ) - } - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } - } + if (keep_sample) { + sample_counter <- sample_counter + 1 } - } - - # Remove GFR samples if they are not to be retained - if ((!keep_gfr) && (num_gfr > 0)) { - for (i in 1:num_gfr) { - if (include_mean_forest) { - forest_samples_mean$delete_sample(0) + # Print progress + if (verbose) { + if (num_burnin > 0) { + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_burnin, + "BART burn-in draws; Chain number ", + chain_num, + "\n" + ) } - if (include_variance_forest) { - forest_samples_variance$delete_sample(0) - } - if (has_rfx) { - rfx_samples$delete_sample(0) + } + if (num_mcmc > 0) { + if ( + ((i - num_gfr - num_burnin) %% 100 == 0) || + (i == num_samples) + ) { + cat( + "Sampling", + i - num_burnin - num_gfr, + "out of", + num_mcmc, + "BART MCMC draws; Chain number ", + chain_num, + "\n" + ) } + } } + if (include_mean_forest) { - mean_forest_pred_train <- mean_forest_pred_train[, - (num_gfr + 1):ncol(mean_forest_pred_train) - ] + if (probit_outcome_model) { + # Sample latent probit variable, z | - + forest_pred <- active_forest_mean$predict( + forest_dataset_train + ) + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + + forest_model_mean$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mean, + active_forest = active_forest_mean, + rng = rng, + forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = FALSE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + mean_forest_pred_train[, + sample_counter + ] <- forest_model_mean$get_cached_forest_predictions() + } } if (include_variance_forest) { - variance_forest_pred_train <- variance_forest_pred_train[, - (num_gfr + 1):ncol(variance_forest_pred_train) - ] + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + keep_forest = keep_sample, + gfr = FALSE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + variance_forest_pred_train[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } } if (sample_sigma2_global) { - global_var_samples <- global_var_samples[ - (num_gfr + 1):length(global_var_samples) - ] + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) } if (sample_sigma2_leaf) { - leaf_scale_samples <- leaf_scale_samples[ - (num_gfr + 1):length(leaf_scale_samples) - ] - } - num_retained_samples <- num_retained_samples - num_gfr - } - - # Mean forest predictions - if (include_mean_forest) { - # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train - y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train - if (has_test) { - y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * - y_std_train + - y_bar_train - } - } - - # Variance forest predictions - if (include_variance_forest) { - # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) - sigma2_x_hat_train <- exp(variance_forest_pred_train) - if (has_test) { - sigma2_x_hat_test <- forest_samples_variance$predict( - forest_dataset_test - ) + leaf_scale_double <- sampleLeafVarianceOneIteration( + active_forest_mean, + rng, + a_leaf, + b_leaf + ) + current_leaf_scale <- as.matrix(leaf_scale_double) + if (keep_sample) { + leaf_scale_samples[sample_counter] <- leaf_scale_double + } + forest_model_config_mean$update_leaf_model_scale( + current_leaf_scale + ) } - } - - # Random effects predictions - if (has_rfx) { - rfx_preds_train <- rfx_samples$predict( - rfx_group_ids_train, - rfx_basis_train - ) * - y_std_train - y_hat_train <- y_hat_train + rfx_preds_train - } - if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict( - rfx_group_ids_test, - rfx_basis_test - ) * - y_std_train - y_hat_test <- y_hat_test + rfx_preds_test - } - - # Global error variance - if (sample_sigma2_global) { - sigma2_global_samples <- global_var_samples * (y_std_train^2) - } - - # Leaf parameter variance - if (sample_sigma2_leaf) { - tau_samples <- leaf_scale_samples - } - - # Rescale variance forest prediction by global sigma2 (sampled or constant) - if (include_variance_forest) { - if (sample_sigma2_global) { - sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { - sigma2_x_hat_train[, i] * sigma2_global_samples[i] - }) - if (has_test) { - sigma2_x_hat_test <- sapply( - 1:num_retained_samples, - function(i) { - sigma2_x_hat_test[, i] * sigma2_global_samples[i] - } - ) - } - } else { - sigma2_x_hat_train <- sigma2_x_hat_train * - sigma2_init * - y_std_train * - y_std_train - if (has_test) { - sigma2_x_hat_test <- sigma2_x_hat_test * - sigma2_init * - y_std_train * - y_std_train - } + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) } + } + } + } + + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && (num_gfr > 0)) { + for (i in 1:num_gfr) { + if (include_mean_forest) { + forest_samples_mean$delete_sample(0) + } + if (include_variance_forest) { + forest_samples_variance$delete_sample(0) + } + if (has_rfx) { + rfx_samples$delete_sample(0) + } } - - # Return results as a list - model_params <- list( - "sigma2_init" = sigma2_init, - "sigma2_leaf_init" = sigma2_leaf_init, - "a_global" = a_global, - "b_global" = b_global, - "a_leaf" = a_leaf, - "b_leaf" = b_leaf, - "a_forest" = a_forest, - "b_forest" = b_forest, - "outcome_mean" = y_bar_train, - "outcome_scale" = y_std_train, - "standardize" = standardize, - "leaf_dimension" = leaf_dimension, - "is_leaf_constant" = is_leaf_constant, - "leaf_regression" = leaf_regression, - "requires_basis" = requires_basis, - "num_covariates" = ncol(X_train), - "num_basis" = ifelse( - is.null(leaf_basis_train), - 0, - ncol(leaf_basis_train) - ), - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, - "keep_every" = keep_every, - "num_chains" = num_chains, - "has_basis" = !is.null(leaf_basis_train), - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, - "sample_sigma2_global" = sample_sigma2_global, - "sample_sigma2_leaf" = sample_sigma2_leaf, - "include_mean_forest" = include_mean_forest, - "include_variance_forest" = include_variance_forest, - "probit_outcome_model" = probit_outcome_model - ) - result <- list( - "model_params" = model_params, - "train_set_metadata" = X_train_metadata - ) if (include_mean_forest) { - result[["mean_forests"]] = forest_samples_mean - result[["y_hat_train"]] = y_hat_train - if (has_test) result[["y_hat_test"]] = y_hat_test + mean_forest_pred_train <- mean_forest_pred_train[, + (num_gfr + 1):ncol(mean_forest_pred_train) + ] } if (include_variance_forest) { - result[["variance_forests"]] = forest_samples_variance - result[["sigma2_x_hat_train"]] = sigma2_x_hat_train - if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + variance_forest_pred_train <- variance_forest_pred_train[, + (num_gfr + 1):ncol(variance_forest_pred_train) + ] } if (sample_sigma2_global) { - result[["sigma2_global_samples"]] = sigma2_global_samples + global_var_samples <- global_var_samples[ + (num_gfr + 1):length(global_var_samples) + ] } if (sample_sigma2_leaf) { - result[["sigma2_leaf_samples"]] = tau_samples - } - if (has_rfx) { - result[["rfx_samples"]] = rfx_samples - result[["rfx_preds_train"]] = rfx_preds_train - result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + leaf_scale_samples <- leaf_scale_samples[ + (num_gfr + 1):length(leaf_scale_samples) + ] } - if ((has_rfx_test) && (has_test)) { - result[["rfx_preds_test"]] = rfx_preds_test - } - class(result) <- "bartmodel" + num_retained_samples <- num_retained_samples - num_gfr + } - # Clean up classes with external pointers to C++ data structures - if (include_mean_forest) { - rm(forest_model_mean) - } - if (include_variance_forest) { - rm(forest_model_variance) - } - rm(forest_dataset_train) + # Mean forest predictions + if (include_mean_forest) { + # y_hat_train <- forest_samples_mean$predict(forest_dataset_train)*y_std_train + y_bar_train + y_hat_train <- mean_forest_pred_train * y_std_train + y_bar_train if (has_test) { - rm(forest_dataset_test) - } - if (has_rfx) { - rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + y_hat_test <- forest_samples_mean$predict(forest_dataset_test) * + y_std_train + + y_bar_train } - rm(outcome_train) - rm(rng) + } - return(result) + # Variance forest predictions + if (include_variance_forest) { + # sigma2_x_hat_train <- forest_samples_variance$predict(forest_dataset_train) + sigma2_x_hat_train <- exp(variance_forest_pred_train) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } + } + + # Random effects predictions + if (has_rfx) { + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train + y_hat_train <- y_hat_train + rfx_preds_train + } + if ((has_rfx_test) && (has_test)) { + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train + y_hat_test <- y_hat_test + rfx_preds_test + } + + # Global error variance + if (sample_sigma2_global) { + sigma2_global_samples <- global_var_samples * (y_std_train^2) + } + + # Leaf parameter variance + if (sample_sigma2_leaf) { + tau_samples <- leaf_scale_samples + } + + # Rescale variance forest prediction by global sigma2 (sampled or constant) + if (include_variance_forest) { + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } + } else { + sigma2_x_hat_train <- sigma2_x_hat_train * + sigma2_init * + y_std_train * + y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } + } + } + + # Return results as a list + model_params <- list( + "sigma2_init" = sigma2_init, + "sigma2_leaf_init" = sigma2_leaf_init, + "a_global" = a_global, + "b_global" = b_global, + "a_leaf" = a_leaf, + "b_leaf" = b_leaf, + "a_forest" = a_forest, + "b_forest" = b_forest, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "standardize" = standardize, + "leaf_dimension" = leaf_dimension, + "is_leaf_constant" = is_leaf_constant, + "leaf_regression" = leaf_regression, + "requires_basis" = requires_basis, + "num_covariates" = ncol(X_train), + "num_basis" = ifelse( + is.null(leaf_basis_train), + 0, + ncol(leaf_basis_train) + ), + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, + "has_basis" = !is.null(leaf_basis_train), + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf" = sample_sigma2_leaf, + "include_mean_forest" = include_mean_forest, + "include_variance_forest" = include_variance_forest, + "probit_outcome_model" = probit_outcome_model + ) + result <- list( + "model_params" = model_params, + "train_set_metadata" = X_train_metadata + ) + if (include_mean_forest) { + result[["mean_forests"]] = forest_samples_mean + result[["y_hat_train"]] = y_hat_train + if (has_test) result[["y_hat_test"]] = y_hat_test + } + if (include_variance_forest) { + result[["variance_forests"]] = forest_samples_variance + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + } + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf) { + result[["sigma2_leaf_samples"]] = tau_samples + } + if (has_rfx) { + result[["rfx_samples"]] = rfx_samples + result[["rfx_preds_train"]] = rfx_preds_train + result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + } + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } + class(result) <- "bartmodel" + + # Clean up classes with external pointers to C++ data structures + if (include_mean_forest) { + rm(forest_model_mean) + } + if (include_variance_forest) { + rm(forest_model_variance) + } + rm(forest_dataset_train) + if (has_test) { + rm(forest_dataset_test) + } + if (has_rfx) { + rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + } + rm(outcome_train) + rm(rng) + + return(result) } #' Predict from a sampled BART model on new data @@ -1771,6 +1767,8 @@ bart <- function( #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL`. Default: "all". #' @param ... (Optional) Other prediction parameters. #' #' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. @@ -1802,162 +1800,219 @@ bart <- function( #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' y_hat_test <- predict(bart_model, X_test)$y_hat predict.bartmodel <- function( - object, - X, - leaf_basis = NULL, - rfx_group_ids = NULL, - rfx_basis = NULL, - ... + object, + X, + leaf_basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + type = "posterior", + terms = "all", + ... ) { - # Preprocess covariates - if ((!is.data.frame(X)) && (!is.matrix(X))) { - stop("X must be a matrix or dataframe") - } - train_set_metadata <- object$train_set_metadata - X <- preprocessPredictionData(X, train_set_metadata) - - # Convert all input data to matrices if not already converted - if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) { - leaf_basis <- as.matrix(leaf_basis) - } - if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) { - rfx_basis <- as.matrix(rfx_basis) - } + # Handle prediction type + if (!is.character(type)) { + stop("type must be a string or character vector") + } + if (!(type %in% c("mean", "posterior"))) { + stop("type must either be 'mean' or 'posterior") + } + predict_mean <- type == "mean" + + # Handle prediction terms + if (!is.character(terms)) { + stop("type must be a string or character vector") + } + num_terms <- length(terms) + has_mean_forest <- object$model_params$include_mean_forest + has_variance_forest <- object$model_params$include_variance_forest + has_rfx <- object$model_params$has_rfx + has_y_hat <- has_mean_forest || has_rfx + predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) || + ((has_y_hat) && ("all" %in% terms))) + predict_mean_forest <- (((has_mean_forest) && ("mean_forest" %in% terms)) || + ((has_mean_forest) && ("all" %in% terms))) + predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) || + ((has_rfx) && ("all" %in% terms))) + predict_variance_forest <- (((has_variance_forest) && + ("variance_forest" %in% terms)) || + ((has_variance_forest) && ("all" %in% terms))) + predict_count <- sum(c( + predict_y_hat, + predict_mean_forest, + predict_rfx, + predict_variance_forest + )) + if (predict_count == 0) { + warning(paste0( + "None of the requested model terms, ", + paste(terms, collapse = ", "), + ", were fit in this model" + )) + return(NULL) + } + predict_rfx_intermediate <- (predict_y_hat && has_rfx) + predict_mean_forest_intermediate <- (predict_y_hat && has_mean_forest) + + # Preprocess covariates + if ((!is.data.frame(X)) && (!is.matrix(X))) { + stop("X must be a matrix or dataframe") + } + train_set_metadata <- object$train_set_metadata + X <- preprocessPredictionData(X, train_set_metadata) + + # Convert all input data to matrices if not already converted + if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) { + leaf_basis <- as.matrix(leaf_basis) + } + if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) { + if (predict_rfx) rfx_basis <- as.matrix(rfx_basis) + } + + # Data checks + if ((object$model_params$requires_basis) && (is.null(leaf_basis))) { + stop("Basis (leaf_basis) must be provided for this model") + } + if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) { + stop("X and leaf_basis must have the same number of rows") + } + if (object$model_params$num_covariates != ncol(X)) { + stop("X and leaf_basis must have the same number of rows") + } + if ((predict_rfx) && (is.null(rfx_group_ids))) { + stop( + "Random effect group labels (rfx_group_ids) must be provided for this model" + ) + } + if ((predict_rfx) && (is.null(rfx_basis))) { + stop("Random effects basis (rfx_basis) must be provided for this model") + } + if ( + (object$model_params$num_rfx_basis > 0) && + (ncol(rfx_basis) != object$model_params$num_rfx_basis) + ) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) + } - # Data checks - if ((object$model_params$requires_basis) && (is.null(leaf_basis))) { - stop("Basis (leaf_basis) must be provided for this model") - } - if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) { - stop("X and leaf_basis must have the same number of rows") - } - if (object$model_params$num_covariates != ncol(X)) { - stop("X and leaf_basis must have the same number of rows") - } - if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) { - stop( - "Random effect group labels (rfx_group_ids) must be provided for this model" - ) - } - if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { - stop("Random effects basis (rfx_basis) must be provided for this model") - } - if ( - (object$model_params$num_rfx_basis > 0) && - (ncol(rfx_basis) != object$model_params$num_rfx_basis) - ) { + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) + has_rfx <- FALSE + if (predict_rfx) { + if (!is.null(rfx_group_ids)) { + rfx_unique_group_ids <- object$rfx_unique_group_ids + group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) + if (sum(is.na(group_ids_factor)) > 0) { stop( - "Random effects basis has a different dimension than the basis used to train this model" + "All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train" ) - } - - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE - if (!is.null(rfx_group_ids)) { - rfx_unique_group_ids <- object$rfx_unique_group_ids - group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) - if (sum(is.na(group_ids_factor)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" - ) - } - rfx_group_ids <- as.integer(group_ids_factor) - has_rfx <- TRUE - } - - # Produce basis for the "intercept-only" random effects case - if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) - } - - # Create prediction dataset - if (!is.null(leaf_basis)) { - prediction_dataset <- createForestDataset(X, leaf_basis) + } + rfx_group_ids <- as.integer(group_ids_factor) + has_rfx <- TRUE + } + } + + # Produce basis for the "intercept-only" random effects case + if ((predict_rfx) && (is.null(rfx_basis))) { + rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) + } + + # Create prediction dataset + if (!is.null(leaf_basis)) { + prediction_dataset <- createForestDataset(X, leaf_basis) + } else { + prediction_dataset <- createForestDataset(X) + } + + # Compute mean forest predictions + num_samples <- object$model_params$num_samples + y_std <- object$model_params$outcome_scale + y_bar <- object$model_params$outcome_mean + sigma2_init <- object$model_params$sigma2_init + if (predict_mean_forest || predict_mean_forest_intermediate) { + mean_forest_predictions <- object$mean_forests$predict( + prediction_dataset + ) * + y_std + + y_bar + if (predict_mean) { + mean_forest_predictions <- rowMeans(mean_forest_predictions) + } + } + + # Compute variance forest predictions + if (predict_variance_forest) { + s_x_raw <- object$variance_forests$predict(prediction_dataset) + } + + # Compute rfx predictions (if needed) + if (predict_rfx || predict_rfx_intermediate) { + rfx_predictions <- object$rfx_samples$predict( + rfx_group_ids, + rfx_basis + ) * + y_std + if (predict_mean) { + rfx_predictions <- rowMeans(rfx_predictions) + } + } + + # Scale variance forest predictions + if (predict_variance_forest) { + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) } else { - prediction_dataset <- createForestDataset(X) - } - - # Compute mean forest predictions - num_samples <- object$model_params$num_samples - y_std <- object$model_params$outcome_scale - y_bar <- object$model_params$outcome_mean - sigma2_init <- object$model_params$sigma2_init - if (object$model_params$include_mean_forest) { - mean_forest_predictions <- object$mean_forests$predict( - prediction_dataset - ) * - y_std + - y_bar - } - - # Compute variance forest predictions - if (object$model_params$include_variance_forest) { - s_x_raw <- object$variance_forests$predict(prediction_dataset) - } - - # Compute rfx predictions (if needed) - if (object$model_params$has_rfx) { - rfx_predictions <- object$rfx_samples$predict( - rfx_group_ids, - rfx_basis - ) * - y_std - } - - # Scale variance forest predictions - if (object$model_params$include_variance_forest) { - if (object$model_params$sample_sigma2_global) { - sigma2_global_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) { - s_x_raw[, i] * sigma2_global_samples[i] - }) - } else { - variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std - } - } - - if ( - (object$model_params$include_mean_forest) && - (object$model_params$has_rfx) - ) { - y_hat <- mean_forest_predictions + rfx_predictions - } else if ( - (object$model_params$include_mean_forest) && - (!object$model_params$has_rfx) - ) { - y_hat <- mean_forest_predictions - } else if ( - (!object$model_params$include_mean_forest) && - (object$model_params$has_rfx) - ) { - y_hat <- rfx_predictions - } - + variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std + } + if (predict_mean) { + variance_forest_predictions <- rowMeans(variance_forest_predictions) + } + } + + if (predict_y_hat && has_mean_forest && has_rfx) { + y_hat <- mean_forest_predictions + rfx_predictions + } else if (predict_y_hat && has_mean_forest) { + y_hat <- mean_forest_predictions + } else if (predict_y_hat && has_rfx) { + y_hat <- rfx_predictions + } + + if (predict_count == 1) { + if (predict_y_hat) { + return(y_hat) + } else if (predict_mean_forest) { + return(mean_forest_predictions) + } else if (predict_rfx) { + return(rfx_predictions) + } else if (predict_variance_forest) { + return(variance_forest_predictions) + } + } else { result <- list() - if ( - (object$model_params$has_rfx) || - (object$model_params$include_mean_forest) - ) { - result[["y_hat"]] = y_hat + if (predict_y_hat) { + result[["y_hat"]] = y_hat } else { - result[["y_hat"]] <- NULL + result[["y_hat"]] <- NULL } - if (object$model_params$include_mean_forest) { - result[["mean_forest_predictions"]] = mean_forest_predictions + if (predict_mean_forest) { + result[["mean_forest_predictions"]] = mean_forest_predictions } else { - result[["mean_forest_predictions"]] <- NULL + result[["mean_forest_predictions"]] <- NULL } - if (object$model_params$has_rfx) { - result[["rfx_predictions"]] = rfx_predictions + if (predict_rfx) { + result[["rfx_predictions"]] = rfx_predictions } else { - result[["rfx_predictions"]] <- NULL + result[["rfx_predictions"]] <- NULL } - if (object$model_params$include_variance_forest) { - result[["variance_forest_predictions"]] = variance_forest_predictions + if (predict_variance_forest) { + result[["variance_forest_predictions"]] = variance_forest_predictions } else { - result[["variance_forest_predictions"]] <- NULL + result[["variance_forest_predictions"]] <- NULL } return(result) + } } #' Extract raw sample values for each of the random effect parameter terms. @@ -2009,26 +2064,26 @@ predict.bartmodel <- function( #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' rfx_samples <- getRandomEffectSamples(bart_model) getRandomEffectSamples.bartmodel <- function(object, ...) { - result = list() + result = list() - if (!object$model_params$has_rfx) { - warning("This model has no RFX terms, returning an empty list") - return(result) - } + if (!object$model_params$has_rfx) { + warning("This model has no RFX terms, returning an empty list") + return(result) + } - # Extract the samples - result <- object$rfx_samples$extract_parameter_samples() + # Extract the samples + result <- object$rfx_samples$extract_parameter_samples() - # Scale by sd(y_train) - result$beta_samples <- result$beta_samples * - object$model_params$outcome_scale - result$xi_samples <- result$xi_samples * object$model_params$outcome_scale - result$alpha_samples <- result$alpha_samples * - object$model_params$outcome_scale - result$sigma_samples <- result$sigma_samples * - (object$model_params$outcome_scale^2) + # Scale by sd(y_train) + result$beta_samples <- result$beta_samples * + object$model_params$outcome_scale + result$xi_samples <- result$xi_samples * object$model_params$outcome_scale + result$alpha_samples <- result$alpha_samples * + object$model_params$outcome_scale + result$sigma_samples <- result$sigma_samples * + (object$model_params$outcome_scale^2) - return(result) + return(result) } #' Convert the persistent aspects of a BART model to (in-memory) JSON @@ -2063,132 +2118,132 @@ getRandomEffectSamples.bartmodel <- function(object, ...) { #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json <- saveBARTModelToJson(bart_model) saveBARTModelToJson <- function(object) { - jsonobj <- createCppJson() - - if (!inherits(object, "bartmodel")) { - stop("`object` must be a BART model") - } - - if (is.null(object$model_params)) { - stop("This BCF model has not yet been sampled") - } - - # Add the forests - if (object$model_params$include_mean_forest) { - jsonobj$add_forest(object$mean_forests) - } - if (object$model_params$include_variance_forest) { - jsonobj$add_forest(object$variance_forests) - } - - # Add metadata - jsonobj$add_scalar( - "num_numeric_vars", - object$train_set_metadata$num_numeric_vars + jsonobj <- createCppJson() + + if (!inherits(object, "bartmodel")) { + stop("`object` must be a BART model") + } + + if (is.null(object$model_params)) { + stop("This BCF model has not yet been sampled") + } + + # Add the forests + if (object$model_params$include_mean_forest) { + jsonobj$add_forest(object$mean_forests) + } + if (object$model_params$include_variance_forest) { + jsonobj$add_forest(object$variance_forests) + } + + # Add metadata + jsonobj$add_scalar( + "num_numeric_vars", + object$train_set_metadata$num_numeric_vars + ) + jsonobj$add_scalar( + "num_ordered_cat_vars", + object$train_set_metadata$num_ordered_cat_vars + ) + jsonobj$add_scalar( + "num_unordered_cat_vars", + object$train_set_metadata$num_unordered_cat_vars + ) + if (object$train_set_metadata$num_numeric_vars > 0) { + jsonobj$add_string_vector( + "numeric_vars", + object$train_set_metadata$numeric_vars ) - jsonobj$add_scalar( - "num_ordered_cat_vars", - object$train_set_metadata$num_ordered_cat_vars + } + if (object$train_set_metadata$num_ordered_cat_vars > 0) { + jsonobj$add_string_vector( + "ordered_cat_vars", + object$train_set_metadata$ordered_cat_vars ) - jsonobj$add_scalar( - "num_unordered_cat_vars", - object$train_set_metadata$num_unordered_cat_vars + jsonobj$add_string_list( + "ordered_unique_levels", + object$train_set_metadata$ordered_unique_levels ) - if (object$train_set_metadata$num_numeric_vars > 0) { - jsonobj$add_string_vector( - "numeric_vars", - object$train_set_metadata$numeric_vars - ) - } - if (object$train_set_metadata$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector( - "ordered_cat_vars", - object$train_set_metadata$ordered_cat_vars - ) - jsonobj$add_string_list( - "ordered_unique_levels", - object$train_set_metadata$ordered_unique_levels - ) - } - if (object$train_set_metadata$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector( - "unordered_cat_vars", - object$train_set_metadata$unordered_cat_vars - ) - jsonobj$add_string_list( - "unordered_unique_levels", - object$train_set_metadata$unordered_unique_levels - ) - } - - # Add global parameters - jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) - jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) - jsonobj$add_boolean("standardize", object$model_params$standardize) - jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) - jsonobj$add_boolean( - "sample_sigma2_global", - object$model_params$sample_sigma2_global + } + if (object$train_set_metadata$num_unordered_cat_vars > 0) { + jsonobj$add_string_vector( + "unordered_cat_vars", + object$train_set_metadata$unordered_cat_vars ) - jsonobj$add_boolean( - "sample_sigma2_leaf", - object$model_params$sample_sigma2_leaf + jsonobj$add_string_list( + "unordered_unique_levels", + object$train_set_metadata$unordered_unique_levels ) - jsonobj$add_boolean( - "include_mean_forest", - object$model_params$include_mean_forest + } + + # Add global parameters + jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) + jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) + jsonobj$add_boolean("standardize", object$model_params$standardize) + jsonobj$add_scalar("sigma2_init", object$model_params$sigma2_init) + jsonobj$add_boolean( + "sample_sigma2_global", + object$model_params$sample_sigma2_global + ) + jsonobj$add_boolean( + "sample_sigma2_leaf", + object$model_params$sample_sigma2_leaf + ) + jsonobj$add_boolean( + "include_mean_forest", + object$model_params$include_mean_forest + ) + jsonobj$add_boolean( + "include_variance_forest", + object$model_params$include_variance_forest + ) + jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) + jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) + jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) + jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) + jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) + jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) + jsonobj$add_scalar("num_samples", object$model_params$num_samples) + jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) + jsonobj$add_scalar("num_basis", object$model_params$num_basis) + jsonobj$add_scalar("num_chains", object$model_params$num_chains) + jsonobj$add_scalar("keep_every", object$model_params$keep_every) + jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) + jsonobj$add_boolean( + "probit_outcome_model", + object$model_params$probit_outcome_model + ) + if (object$model_params$sample_sigma2_global) { + jsonobj$add_vector( + "sigma2_global_samples", + object$sigma2_global_samples, + "parameters" ) - jsonobj$add_boolean( - "include_variance_forest", - object$model_params$include_variance_forest + } + if (object$model_params$sample_sigma2_leaf) { + jsonobj$add_vector( + "sigma2_leaf_samples", + object$sigma2_leaf_samples, + "parameters" ) - jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) - jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) - jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) - jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) - jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) - jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) - jsonobj$add_scalar("num_samples", object$model_params$num_samples) - jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) - jsonobj$add_scalar("num_basis", object$model_params$num_basis) - jsonobj$add_scalar("num_chains", object$model_params$num_chains) - jsonobj$add_scalar("keep_every", object$model_params$keep_every) - jsonobj$add_boolean("requires_basis", object$model_params$requires_basis) - jsonobj$add_boolean( - "probit_outcome_model", - object$model_params$probit_outcome_model + } + + # Add random effects (if present) + if (object$model_params$has_rfx) { + jsonobj$add_random_effects(object$rfx_samples) + jsonobj$add_string_vector( + "rfx_unique_group_ids", + object$rfx_unique_group_ids ) - if (object$model_params$sample_sigma2_global) { - jsonobj$add_vector( - "sigma2_global_samples", - object$sigma2_global_samples, - "parameters" - ) - } - if (object$model_params$sample_sigma2_leaf) { - jsonobj$add_vector( - "sigma2_leaf_samples", - object$sigma2_leaf_samples, - "parameters" - ) - } - - # Add random effects (if present) - if (object$model_params$has_rfx) { - jsonobj$add_random_effects(object$rfx_samples) - jsonobj$add_string_vector( - "rfx_unique_group_ids", - object$rfx_unique_group_ids - ) - } + } - # Add covariate preprocessor metadata - preprocessor_metadata_string <- savePreprocessorToJsonString( - object$train_set_metadata - ) - jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) + # Add covariate preprocessor metadata + preprocessor_metadata_string <- savePreprocessorToJsonString( + object$train_set_metadata + ) + jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) - return(jsonobj) + return(jsonobj) } #' Convert the persistent aspects of a BART model to (in-memory) JSON and save to a file @@ -2226,11 +2281,11 @@ saveBARTModelToJson <- function(object) { #' saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) #' unlink(tmpjson) saveBARTModelToJsonFile <- function(object, filename) { - # Convert to Json - jsonobj <- saveBARTModelToJson(object) + # Convert to Json + jsonobj <- saveBARTModelToJson(object) - # Save to file - jsonobj$save_file(filename) + # Save to file + jsonobj$save_file(filename) } #' Convert the persistent aspects of a BART model to (in-memory) JSON string @@ -2264,11 +2319,11 @@ saveBARTModelToJsonFile <- function(object, filename) { #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) #' bart_json_string <- saveBARTModelToJsonString(bart_model) saveBARTModelToJsonString <- function(object) { - # Convert to Json - jsonobj <- saveBARTModelToJson(object) + # Convert to Json + jsonobj <- saveBARTModelToJson(object) - # Dump to string - return(jsonobj$return_json_string()) + # Dump to string + return(jsonobj$return_json_string()) } #' Convert an (in-memory) JSON representation of a BART model to a BART model object @@ -2305,138 +2360,138 @@ saveBARTModelToJsonString <- function(object) { #' bart_json <- saveBARTModelToJson(bart_model) #' bart_model_roundtrip <- createBARTModelFromJson(bart_json) createBARTModelFromJson <- function(json_object) { - # Initialize the BCF model - output <- list() - - # Unpack the forests - include_mean_forest <- json_object$get_boolean("include_mean_forest") - include_variance_forest <- json_object$get_boolean( - "include_variance_forest" + # Initialize the BCF model + output <- list() + + # Unpack the forests + include_mean_forest <- json_object$get_boolean("include_mean_forest") + include_variance_forest <- json_object$get_boolean( + "include_variance_forest" + ) + if (include_mean_forest) { + output[["mean_forests"]] <- loadForestContainerJson( + json_object, + "forest_0" ) - if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerJson( - json_object, - "forest_0" - ) - if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerJson( - json_object, - "forest_1" - ) - } - } else { - output[["variance_forests"]] <- loadForestContainerJson( - json_object, - "forest_0" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( - "num_numeric_vars" - ) - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( - "num_ordered_cat_vars" + if (include_variance_forest) { + output[["variance_forests"]] <- loadForestContainerJson( + json_object, + "forest_1" + ) + } + } else { + output[["variance_forests"]] <- loadForestContainerJson( + json_object, + "forest_0" ) - train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( - "num_unordered_cat_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( + "num_ordered_cat_vars" + ) + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( + "num_unordered_cat_vars" + ) + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" ) - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( - "numeric_vars" - ) - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] - ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] - ) - } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object$get_boolean("standardize") - model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") - model_params[["sample_sigma2_global"]] <- json_object$get_boolean( - "sample_sigma2_global" + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) - model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean( - "sample_sigma2_leaf" + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] ) - model_params[["include_mean_forest"]] <- include_mean_forest - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - model_params[["num_basis"]] <- json_object$get_scalar("num_basis") - model_params[["num_chains"]] <- json_object$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object$get_boolean( - "requires_basis" + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["standardize"]] <- json_object$get_boolean("standardize") + model_params[["sigma2_init"]] <- json_object$get_scalar("sigma2_init") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object$get_boolean( + "sample_sigma2_leaf" + ) + model_params[["include_mean_forest"]] <- include_mean_forest + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + model_params[["num_basis"]] <- json_object$get_scalar("num_basis") + model_params[["num_chains"]] <- json_object$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object$get_scalar("keep_every") + model_params[["requires_basis"]] <- json_object$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) + + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - model_params[["probit_outcome_model"]] <- json_object$get_boolean( - "probit_outcome_model" + } + if (model_params[["sample_sigma2_leaf"]]) { + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" ) + } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } - if (model_params[["sample_sigma2_leaf"]]) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector( - "sigma2_leaf_samples", - "parameters" - ) - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( - "rfx_unique_group_ids" - ) - output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) - } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( + "rfx_unique_group_ids" ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string - ) - - class(output) <- "bartmodel" - return(output) + output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bartmodel" + return(output) } #' Convert a JSON file containing sample information on a trained BART model @@ -2475,13 +2530,13 @@ createBARTModelFromJson <- function(json_object) { #' bart_model_roundtrip <- createBARTModelFromJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createBARTModelFromJsonFile <- function(json_filename) { - # Load a `CppJson` object from file - bart_json <- createCppJsonFile(json_filename) + # Load a `CppJson` object from file + bart_json <- createCppJsonFile(json_filename) - # Create and return the BART object - bart_object <- createBARTModelFromJson(bart_json) + # Create and return the BART object + bart_object <- createBARTModelFromJson(bart_json) - return(bart_object) + return(bart_object) } #' Convert a JSON string containing sample information on a trained BART model @@ -2519,13 +2574,13 @@ createBARTModelFromJsonFile <- function(json_filename) { #' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) #' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat) createBARTModelFromJsonString <- function(json_string) { - # Load a `CppJson` object from string - bart_json <- createCppJsonString(json_string) + # Load a `CppJson` object from string + bart_json <- createCppJsonString(json_string) - # Create and return the BART object - bart_object <- createBARTModelFromJson(bart_json) + # Create and return the BART object + bart_object <- createBARTModelFromJson(bart_json) - return(bart_object) + return(bart_object) } #' Convert a list of (in-memory) JSON representations of a BART model to a single combined BART model object @@ -2562,202 +2617,202 @@ createBARTModelFromJsonString <- function(json_string) { #' bart_json <- list(saveBARTModelToJson(bart_model)) #' bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) createBARTModelFromCombinedJson <- function(json_object_list) { - # Initialize the BCF model - output <- list() - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - include_mean_forest <- json_object_default$get_boolean( - "include_mean_forest" + # Initialize the BCF model + output <- list() + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + include_mean_forest <- json_object_default$get_boolean( + "include_mean_forest" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_mean_forest) { + output[["mean_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" + if (include_variance_forest) { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + } + } else { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" - ) - } - } else { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["sigma2_init"]] <- json_object_default$get_scalar( + "sigma2_init" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf" + ) + model_params[["include_mean_forest"]] <- include_mean_forest + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["requires_basis"]] <- json_object_default$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["sigma2_init"]] <- json_object_default$get_scalar( - "sigma2_init" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf" - ) - model_params[["include_mean_forest"]] <- include_mean_forest - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") - model_params[["requires_basis"]] <- json_object_default$get_boolean( - "requires_basis" - ) - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector( - "sigma2_leaf_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_samples"]] <- c( - output[["sigma2_leaf_samples"]], - json_object$get_vector("sigma2_leaf_samples", "parameters") - ) - } - } - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" ) - } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + } else { + output[["sigma2_leaf_samples"]] <- c( + output[["sigma2_leaf_samples"]], + json_object$get_vector("sigma2_leaf_samples", "parameters") + ) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 ) - - class(output) <- "bartmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bartmodel" + return(output) } #' Convert a list of (in-memory) JSON strings that represent BART models to a single combined BART model object @@ -2794,207 +2849,207 @@ createBARTModelFromCombinedJson <- function(json_object_list) { #' bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) #' bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) createBARTModelFromCombinedJsonString <- function(json_string_list) { - # Initialize the BCF model - output <- list() - - # Convert JSON strings - json_object_list <- list() - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - json_object_list[[i]] <- createCppJsonString(json_string) - } - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - include_mean_forest <- json_object_default$get_boolean( - "include_mean_forest" + # Initialize the BCF model + output <- list() + + # Convert JSON strings + json_object_list <- list() + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + json_object_list[[i]] <- createCppJsonString(json_string) + } + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + include_mean_forest <- json_object_default$get_boolean( + "include_mean_forest" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_mean_forest) { + output[["mean_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" + if (include_variance_forest) { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + } + } else { + output[["variance_forests"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" ) - if (include_mean_forest) { - output[["mean_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - if (include_variance_forest) { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" - ) - } - } else { - output[["variance_forests"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["sigma2_init"]] <- json_object_default$get_scalar( + "sigma2_init" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf" + ) + model_params[["include_mean_forest"]] <- include_mean_forest + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["requires_basis"]] <- json_object_default$get_boolean( + "requires_basis" + ) + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["sigma2_init"]] <- json_object_default$get_scalar( - "sigma2_init" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf" - ) - model_params[["include_mean_forest"]] <- include_mean_forest - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_basis"]] <- json_object_default$get_scalar("num_basis") - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["requires_basis"]] <- json_object_default$get_boolean( - "requires_basis" - ) - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_samples"]] <- json_object$get_vector( - "sigma2_leaf_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_samples"]] <- c( - output[["sigma2_leaf_samples"]], - json_object$get_vector("sigma2_leaf_samples", "parameters") - ) - } - } - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_samples"]] <- json_object$get_vector( + "sigma2_leaf_samples", + "parameters" ) - } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + } else { + output[["sigma2_leaf_samples"]] <- c( + output[["sigma2_leaf_samples"]], + json_object$get_vector("sigma2_leaf_samples", "parameters") + ) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 ) - - class(output) <- "bartmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bartmodel" + return(output) } diff --git a/man/bart.Rd b/man/bart.Rd index 66a9b9ad..c11c619b 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -136,9 +136,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -153,6 +153,6 @@ X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/bcf.Rd b/man/bcf.Rd index 01e5fab8..f7d42e93 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -162,21 +162,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -199,8 +199,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) } diff --git a/man/createBARTModelFromCombinedJson.Rd b/man/createBARTModelFromCombinedJson.Rd index 35d185c3..83d61d0d 100644 --- a/man/createBARTModelFromCombinedJson.Rd +++ b/man/createBARTModelFromCombinedJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- list(saveBARTModelToJson(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJson(bart_json) diff --git a/man/createBARTModelFromCombinedJsonString.Rd b/man/createBARTModelFromCombinedJsonString.Rd index a8470dee..7a17484a 100644 --- a/man/createBARTModelFromCombinedJsonString.Rd +++ b/man/createBARTModelFromCombinedJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string_list <- list(saveBARTModelToJsonString(bart_model)) bart_model_roundtrip <- createBARTModelFromCombinedJsonString(bart_json_string_list) diff --git a/man/createBARTModelFromJson.Rd b/man/createBARTModelFromJson.Rd index 57686122..68a02f0e 100644 --- a/man/createBARTModelFromJson.Rd +++ b/man/createBARTModelFromJson.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) bart_model_roundtrip <- createBARTModelFromJson(bart_json) diff --git a/man/createBARTModelFromJsonFile.Rd b/man/createBARTModelFromJsonFile.Rd index f714a94a..7608d8d2 100644 --- a/man/createBARTModelFromJsonFile.Rd +++ b/man/createBARTModelFromJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/createBARTModelFromJsonString.Rd b/man/createBARTModelFromJsonString.Rd index 67068fd0..0748d97a 100644 --- a/man/createBARTModelFromJsonString.Rd +++ b/man/createBARTModelFromJsonString.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJsonString(bart_model) bart_model_roundtrip <- createBARTModelFromJsonString(bart_json) diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index 6f29569e..24c82e4f 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_list <- list(saveBCFModelToJson(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) diff --git a/man/createBCFModelFromCombinedJsonString.Rd b/man/createBCFModelFromCombinedJsonString.Rd index bd7e63f2..e0522f75 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index a579b140..35cff7ce 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2661d4de..a2496797 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 5f34724c..cc944f85 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,13 +70,13 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bcf_json <- saveBCFModelToJsonString(bcf_model) bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index d9000925..d7a1adae 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -30,10 +30,10 @@ max_depth <- 10 feature_types <- as.integer(rep(0, p)) X <- matrix(runif(n*p), ncol = p) forest_dataset <- createForestDataset(X) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_features=p, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_features=p, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, leaf_model_type=1) global_model_config <- createGlobalModelConfig(global_error_variance=1.0) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 0da1eb98..149586a8 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -24,9 +24,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) snr <- 3 @@ -51,11 +51,11 @@ rfx_basis_test <- rfx_basis[test_inds,] rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, +bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 10, num_burnin = 0, num_mcmc = 10) rfx_samples <- getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index 6769de62..08a8eae4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -24,21 +24,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -74,15 +74,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) rfx_samples <- getRandomEffectSamples(bcf_model) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 2afccbf6..a3197f7b 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -10,6 +10,8 @@ leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, + type = "posterior", + terms = "all", ... ) } @@ -26,6 +28,10 @@ that were not in the training set.} \item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} + +\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL}. Default: "all".} + \item{...}{(Optional) Other prediction parameters.} } \value{ @@ -40,9 +46,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -56,7 +62,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) y_hat_test <- predict(bart_model, X_test)$y_hat } diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index ff315808..907e5308 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -42,21 +42,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -79,8 +79,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 10, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) preds <- predict(bcf_model, X_test, Z_test, pi_test) } diff --git a/man/preprocessPredictionData.Rd b/man/preprocessPredictionData.Rd index f881fda8..a6382e69 100644 --- a/man/preprocessPredictionData.Rd +++ b/man/preprocessPredictionData.Rd @@ -22,7 +22,7 @@ types. Matrices will be passed through assuming all columns are numeric. } \examples{ cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) -metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, +metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0, num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) X_preprocessed <- preprocessPredictionData(cov_df, metadata) } diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index f0fec6ca..b02158d4 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -48,23 +48,23 @@ y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) -forest_model_config <- createForestModelConfig(feature_types=feature_types, - num_trees=num_trees, num_observations=n, - num_features=p, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, - max_depth=max_depth, - variable_weights=variable_weights, - cutpoint_grid_size=cutpoint_grid_size, - leaf_model_type=leaf_model, +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, leaf_model_scale=leaf_scale) forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, +forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, - rng, forest_model_config, global_model_config, + forest_dataset, outcome, forest_samples, active_forest, + rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) diff --git a/man/resetRandomEffectsModel.Rd b/man/resetRandomEffectsModel.Rd index fec99b77..b032ccc2 100644 --- a/man/resetRandomEffectsModel.Rd +++ b/man/resetRandomEffectsModel.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/resetRandomEffectsTracker.Rd b/man/resetRandomEffectsTracker.Rd index 5249ca96..c57af16a 100644 --- a/man/resetRandomEffectsTracker.Rd +++ b/man/resetRandomEffectsTracker.Rd @@ -57,8 +57,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) diff --git a/man/rootResetRandomEffectsModel.Rd b/man/rootResetRandomEffectsModel.Rd index c58a09e9..4c3cc2f7 100644 --- a/man/rootResetRandomEffectsModel.Rd +++ b/man/rootResetRandomEffectsModel.Rd @@ -63,8 +63,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/rootResetRandomEffectsTracker.Rd b/man/rootResetRandomEffectsTracker.Rd index 8de2c514..6f2dc843 100644 --- a/man/rootResetRandomEffectsTracker.Rd +++ b/man/rootResetRandomEffectsTracker.Rd @@ -49,8 +49,8 @@ rfx_model$set_group_parameter_cov(sigma_xi_init) rfx_model$set_variance_prior_shape(sigma_xi_shape) rfx_model$set_variance_prior_scale(sigma_xi_scale) for (i in 1:3) { - rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, - rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, + rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome, + rfx_tracker=rfx_tracker, rfx_samples=rfx_samples, keep_sample=TRUE, global_variance=1.0, rng=rng) } rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, diff --git a/man/saveBARTModelToJson.Rd b/man/saveBARTModelToJson.Rd index a617532e..054af24e 100644 --- a/man/saveBARTModelToJson.Rd +++ b/man/saveBARTModelToJson.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json <- saveBARTModelToJson(bart_model) } diff --git a/man/saveBARTModelToJsonFile.Rd b/man/saveBARTModelToJsonFile.Rd index 46a3110e..62ef6ad7 100644 --- a/man/saveBARTModelToJsonFile.Rd +++ b/man/saveBARTModelToJsonFile.Rd @@ -22,9 +22,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -38,7 +38,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tmpjson <- tempfile(fileext = ".json") saveBARTModelToJsonFile(bart_model, file.path(tmpjson)) diff --git a/man/saveBARTModelToJsonString.Rd b/man/saveBARTModelToJsonString.Rd index c83f9e5d..10927c20 100644 --- a/man/saveBARTModelToJsonString.Rd +++ b/man/saveBARTModelToJsonString.Rd @@ -20,9 +20,9 @@ n <- 100 p <- 5 X <- matrix(runif(n*p), ncol = p) f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 @@ -36,7 +36,7 @@ X_test <- X[test_inds,] X_train <- X[train_inds,] y_test <- y[test_inds] y_train <- y[train_inds] -bart_model <- bart(X_train = X_train, y_train = y_train, +bart_model <- bart(X_train = X_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) bart_json_string <- saveBARTModelToJsonString(bart_model) } diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index ae2c286d..2c04d76c 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) bcf_json <- saveBCFModelToJson(bcf_model) } diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index e6a9f0aa..584bbbba 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -22,21 +22,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -72,15 +72,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) tmpjson <- tempfile(fileext = ".json") saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) diff --git a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index 4328e525..2182bbe3 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -20,21 +20,21 @@ n <- 500 p <- 5 X <- matrix(runif(n*p), ncol = p) mu_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) pi_x <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) ) tau_x <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) ) Z <- rbinom(n, 1, pi_x) @@ -70,15 +70,15 @@ rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma2_leaf = TRUE) tau_params <- list(sample_sigma2_leaf = FALSE) -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, X_test = X_test, + Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - prognostic_forest_params = mu_params, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 10, + prognostic_forest_params = mu_params, treatment_effect_forest_params = tau_params) saveBCFModelToJsonString(bcf_model) } diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index 88628f6d..92a12114 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -172,3 +172,87 @@ test_that("Prediction from trees with multivariate leaf basis", { # Assertion expect_equal(split_counts, split_counts_expected) }) + +test_that("Predictions with pre-summarization", { + # Generate data and test-train split + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a "classic" BART model + bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred <- predict(bart_model, X_test) + y_hat_posterior_test <- pred$y_hat + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict(bart_model, X_test, type = "mean") + y_hat_mean_test <- pred_mean$y_hat + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + + # Fit a heteroskedastic BART model + var_params <- list(num_trees = 20) + het_bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) + + # Check that the default predict method returns a list + pred <- predict(het_bart_model, X_test) + y_hat_posterior_test <- pred$y_hat + sigma2_hat_posterior_test <- pred$variance_forest + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict(het_bart_model, X_test, type = "mean") + y_hat_mean_test <- pred_mean$y_hat + sigma2_hat_mean_test <- pred_mean$variance_forest + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a list + y_hat_mean_test_single_term <- predict(het_bart_model, X_test, type = "mean", terms = "y_hat") + sigma2_hat_mean_test_single_term <- predict(het_bart_model, X_test, type = "mean", terms = "variance_forest") + + # Assertion + expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) + expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) +}) \ No newline at end of file diff --git a/tools/debug/predict_debug.R b/tools/debug/predict_debug.R new file mode 100644 index 00000000..ec0ee75c --- /dev/null +++ b/tools/debug/predict_debug.R @@ -0,0 +1,31 @@ +# Demo of updated predict method +library(stochtree) +n <- 100 +p <- 5 +X <- matrix(runif(n * p), ncol = p) +f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 +) + +y_hat_posterior_test <- predict(bart_model, X_test)$y_hat +y_hat_test <- predict(bart_model, X_test, type = "mean", terms = c("rfx", "variance")) From 7a4f5adf9a2266fd0766296f29bf3a023b1cba3b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 6 Oct 2025 22:20:08 -0500 Subject: [PATCH 02/53] Updated tests and demo scripts --- test/R/testthat/test-predict.R | 480 +++++++++--------- .../{predict_debug.R => bart_predict_debug.R} | 28 +- 2 files changed, 278 insertions(+), 230 deletions(-) rename tools/debug/{predict_debug.R => bart_predict_debug.R} (63%) diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index 92a12114..a6f56b5b 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -1,258 +1,286 @@ test_that("Prediction from trees with constant leaf", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest_samples <- createForestSamples(num_trees, 1, TRUE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Check the split count for the first tree in the ensemble - split_counts <- forest_samples$get_tree_split_counts(0,0,p) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest_samples <- createForestSamples(num_trees, 1, TRUE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Check the split count for the first tree in the ensemble + split_counts <- forest_samples$get_tree_split_counts(0, 0, p) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Prediction from trees with univariate leaf basis", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - W = as.matrix(c(-1,-1,-1,1,1,1)) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X,W) - forest_samples <- createForestSamples(num_trees, 1, FALSE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_manual <- pred_raw*W - - # Assertion - expect_equal(pred, pred_manual) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_manual <- pred_raw*W - - # Assertion - expect_equal(pred, pred_manual) - - # Check the split count for the first tree in the ensemble - split_counts <- forest_samples$get_tree_split_counts(0,0,p) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + W = as.matrix(c(-1, -1, -1, 1, 1, 1)) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X, W) + forest_samples <- createForestSamples(num_trees, 1, FALSE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_manual <- pred_raw * W + + # Assertion + expect_equal(pred, pred_manual) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_manual <- pred_raw * W + + # Assertion + expect_equal(pred, pred_manual) + + # Check the split count for the first tree in the ensemble + split_counts <- forest_samples$get_tree_split_counts(0, 0, p) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Prediction from trees with multivariate leaf basis", { - # Create dataset and forest container - num_trees <- 10 - output_dim <- 2 - num_samples <- 0 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + output_dim <- 2 + num_samples <- 0 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - W = matrix(c(1,1,1,1,1,1,-1,-1,-1,1,1,1), byrow=FALSE, nrow=6) - forest_dataset = createForestDataset(X,W) - forest_samples <- createForestSamples(num_trees, output_dim, FALSE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(c(1.,1.)) - num_samples <- num_samples + 1 - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) - dim(pred_intermediate) <- c(n, output_dim, num_samples) - pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) - - # Assertion - expect_equal(pred, pred_manual) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, c(-5.,-1.), c(5.,1.)) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) - dim(pred_intermediate) <- c(n, output_dim, num_samples) - pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) - - # Assertion - expect_equal(pred, pred_manual) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, c(-7.5,2.5), c(-2.5,7.5)) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) - dim(pred_intermediate) <- c(n, output_dim, num_samples) - pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) - - # Assertion - expect_equal(pred, pred_manual) - - # Check the split count for the first tree in the ensemble - split_counts <- forest_samples$get_tree_split_counts(0,0,3) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + n <- nrow(X) + p <- ncol(X) + W = matrix(c(1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1), byrow = FALSE, nrow = 6) + forest_dataset = createForestDataset(X, W) + forest_samples <- createForestSamples(num_trees, output_dim, FALSE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(c(1., 1.)) + num_samples <- num_samples + 1 + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) + dim(pred_intermediate) <- c(n, output_dim, num_samples) + pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) + + # Assertion + expect_equal(pred, pred_manual) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, c(-5., -1.), c(5., 1.)) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) + dim(pred_intermediate) <- c(n, output_dim, num_samples) + pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) + + # Assertion + expect_equal(pred, pred_manual) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree( + 0, + 0, + 1, + 1, + 4.0, + c(-7.5, 2.5), + c(-2.5, 7.5) + ) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + pred_intermediate <- as.numeric(pred_raw) * as.numeric(W) + dim(pred_intermediate) <- c(n, output_dim, num_samples) + pred_manual <- apply(pred_intermediate, 3, function(x) rowSums(x)) + + # Assertion + expect_equal(pred, pred_manual) + + # Check the split count for the first tree in the ensemble + split_counts <- forest_samples$get_tree_split_counts(0, 0, 3) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Predictions with pre-summarization", { - # Generate data and test-train split - n <- 100 - p <- 5 - X <- matrix(runif(n * p), ncol = p) - f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * - (-7.5) + + # Generate data and test-train split + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct * n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds, ] - X_train <- X[train_inds, ] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Fit a "classic" BART model - bart_model <- bart( - X_train = X_train, - y_train = y_train, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10 - ) - - # Check that the default predict method returns a list - pred <- predict(bart_model, X_test) - y_hat_posterior_test <- pred$y_hat - - # Assertion - expect_equal(dim(y_hat_posterior_test), c(20, 10)) - - # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict(bart_model, X_test, type = "mean") - y_hat_mean_test <- pred_mean$y_hat - - # Assertion - expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) - - # Fit a heteroskedastic BART model - var_params <- list(num_trees = 20) - het_bart_model <- bart( - X_train = X_train, - y_train = y_train, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10, - variance_forest_params = var_params + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a "classic" BART model + bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred <- predict(bart_model, X_test) + y_hat_posterior_test <- pred$y_hat + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict(bart_model, X_test, type = "mean") + y_hat_mean_test <- pred_mean$y_hat + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + + # Check that we warn and return a NULL when requesting terms that weren't fit + expect_warning({ + pred_mean <- predict( + bart_model, + X_test, + type = "mean", + terms = c("rfx", "variance_forest") ) - - # Check that the default predict method returns a list - pred <- predict(het_bart_model, X_test) - y_hat_posterior_test <- pred$y_hat - sigma2_hat_posterior_test <- pred$variance_forest - - # Assertion - expect_equal(dim(y_hat_posterior_test), c(20, 10)) - expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) - - # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict(het_bart_model, X_test, type = "mean") - y_hat_mean_test <- pred_mean$y_hat - sigma2_hat_mean_test <- pred_mean$variance_forest - - # Assertion - expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) - expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) - - # Check that the "single-term" pre-aggregated predictions - # match those computed by pre-aggregated predictions returned in a list - y_hat_mean_test_single_term <- predict(het_bart_model, X_test, type = "mean", terms = "y_hat") - sigma2_hat_mean_test_single_term <- predict(het_bart_model, X_test, type = "mean", terms = "variance_forest") - - # Assertion - expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) - expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) -}) \ No newline at end of file + }) + expect_equal(NULL, pred_mean) + + # Fit a heteroskedastic BART model + var_params <- list(num_trees = 20) + het_bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) + + # Check that the default predict method returns a list + pred <- predict(het_bart_model, X_test) + y_hat_posterior_test <- pred$y_hat + sigma2_hat_posterior_test <- pred$variance_forest + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict(het_bart_model, X_test, type = "mean") + y_hat_mean_test <- pred_mean$y_hat + sigma2_hat_mean_test <- pred_mean$variance_forest + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a list + y_hat_mean_test_single_term <- predict( + het_bart_model, + X_test, + type = "mean", + terms = "y_hat" + ) + sigma2_hat_mean_test_single_term <- predict( + het_bart_model, + X_test, + type = "mean", + terms = "variance_forest" + ) + + # Assertion + expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) + expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) +}) diff --git a/tools/debug/predict_debug.R b/tools/debug/bart_predict_debug.R similarity index 63% rename from tools/debug/predict_debug.R rename to tools/debug/bart_predict_debug.R index ec0ee75c..17cb5aa4 100644 --- a/tools/debug/predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -1,15 +1,21 @@ -# Demo of updated predict method +# Demo of updated predict method for BART + +# Load library library(stochtree) + +# Generate data n <- 100 p <- 5 X <- matrix(runif(n * p), ncol = p) -f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * - (-7.5) + +# fmt: skip +f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) noise_sd <- 1 y <- f_XW + rnorm(n, 0, noise_sd) + +# Train-test split test_set_pct <- 0.2 n_test <- round(test_set_pct * n) n_train <- n - n_test @@ -19,6 +25,8 @@ X_test <- X[test_inds, ] X_train <- X[train_inds, ] y_test <- y[test_inds] y_train <- y[train_inds] + +# Fit simple BART model bart_model <- bart( X_train = X_train, y_train = y_train, @@ -27,5 +35,17 @@ bart_model <- bart( num_mcmc = 10 ) +# Check several predict approaches y_hat_posterior_test <- predict(bart_model, X_test)$y_hat -y_hat_test <- predict(bart_model, X_test, type = "mean", terms = c("rfx", "variance")) +y_hat_mean_test <- predict( + bart_model, + X_test, + type = "mean", + terms = c("y_hat") +) +y_hat_test <- predict( + bart_model, + X_test, + type = "mean", + terms = c("rfx", "variance") +) From 1d1c147dfa7d67ef703b2ab6d26cb3c454ad5c22 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 6 Oct 2025 23:54:27 -0500 Subject: [PATCH 03/53] Updated predict function for BCF --- R/bart.R | 2 +- R/bcf.R | 186 +++++++++++++++++++++++++------- man/predict.bartmodel.Rd | 2 +- man/predict.bcfmodel.Rd | 6 ++ test/R/testthat/test-predict.R | 166 +++++++++++++++++++++++++++- tools/debug/bcf_predict_debug.R | 87 +++++++++++++++ 6 files changed, 404 insertions(+), 45 deletions(-) create mode 100644 tools/debug/bcf_predict_debug.R diff --git a/R/bart.R b/R/bart.R index 829d016a..c64ea3d8 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1768,7 +1768,7 @@ bart <- function( #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". -#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL`. Default: "all". +#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param ... (Optional) Other prediction parameters. #' #' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. diff --git a/R/bcf.R b/R/bcf.R index bd5fd5b1..6dfcdae3 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2558,6 +2558,8 @@ bcf <- function( #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param ... (Optional) Other prediction parameters. #' #' @return List of 3-5 `nrow(X)` by `object$num_samples` matrices: prognostic function estimates, treatment effect estimates, (optionally) random effects predictions, (optionally) variance forest predictions, and outcome predictions. @@ -2616,8 +2618,59 @@ predict.bcfmodel <- function( propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, + type = "posterior", + terms = "all", ... ) { + # Handle prediction type + if (!is.character(type)) { + stop("type must be a string or character vector") + } + if (!(type %in% c("mean", "posterior"))) { + stop("type must either be 'mean' or 'posterior") + } + predict_mean <- type == "mean" + + # Handle prediction terms + if (!is.character(terms)) { + stop("type must be a string or character vector") + } + num_terms <- length(terms) + has_mu_forest <- T + has_tau_forest <- T + has_variance_forest <- object$model_params$include_variance_forest + has_rfx <- object$model_params$has_rfx + has_y_hat <- T + predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) || + ((has_y_hat) && ("all" %in% terms))) + predict_mu_forest <- (((has_mu_forest) && ("prognostic_function" %in% terms)) || + ((has_mu_forest) && ("all" %in% terms))) + predict_tau_forest <- (((has_tau_forest) && ("cate" %in% terms)) || + ((has_tau_forest) && ("all" %in% terms))) + predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) || + ((has_rfx) && ("all" %in% terms))) + predict_variance_forest <- (((has_variance_forest) && + ("variance_forest" %in% terms)) || + ((has_variance_forest) && ("all" %in% terms))) + predict_count <- sum(c( + predict_y_hat, + predict_mu_forest, + predict_tau_forest, + predict_rfx, + predict_variance_forest + )) + if (predict_count == 0) { + warning(paste0( + "None of the requested model terms, ", + paste(terms, collapse = ", "), + ", were fit in this model" + )) + return(NULL) + } + predict_rfx_intermediate <- (predict_y_hat && has_rfx) + predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) + predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) + # Preprocess covariates if ((!is.data.frame(X)) && (!is.matrix(X))) { stop("X must be a matrix or dataframe") @@ -2699,53 +2752,77 @@ predict.bcfmodel <- function( # Create prediction datasets forest_dataset_pred <- createForestDataset(X_combined, Z) - # Compute forest predictions + # Compute mu forest predictions num_samples <- object$model_params$num_samples y_std <- object$model_params$outcome_scale y_bar <- object$model_params$outcome_mean initial_sigma2 <- object$model_params$initial_sigma2 - mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar - if (object$model_params$adaptive_coding) { - tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) - tau_hat <- t( - t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) - ) * - y_std - } else { - tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std - } - if (object$model_params$multivariate_treatment) { - tau_dim <- dim(tau_hat) - tau_num_obs <- tau_dim[1] - tau_num_samples <- tau_dim[3] - treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) - for (i in 1:nrow(Z)) { - treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) + if (predict_mu_forest || predict_mu_forest_intermediate) { + mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar + if (predict_mean) { + mu_hat <- rowMeans(mu_hat) } - } else { - treatment_term <- tau_hat * as.numeric(Z) } - if (object$model_params$include_variance_forest) { + + # Compute CATE forest predictions + if (predict_tau_forest || predict_tau_forest_intermediate) { + if (object$model_params$adaptive_coding) { + tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) + tau_hat <- t( + t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) + ) * + y_std + } else { + tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std + } + if (object$model_params$multivariate_treatment) { + tau_dim <- dim(tau_hat) + tau_num_obs <- tau_dim[1] + tau_num_samples <- tau_dim[3] + treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) + for (i in 1:nrow(Z)) { + treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) + } + if (predict_mean) { + tau_hat <- apply(tau_hat, c(1,2), mean) + treatment_term <- rowMeans(treatment_term) + } + } else { + treatment_term <- tau_hat * as.numeric(Z) + if (predict_mean) { + tau_hat <- rowMeans(tau_hat) + treatment_term <- rowMeans(treatment_term) + } + } + } + + # Compute variance forest predictions + if (predict_variance_forest) { s_x_raw <- object$forests_variance$predict(forest_dataset_pred) } - # Compute rfx predictions (if needed) - if (object$model_params$has_rfx) { + # Compute rfx predictions + if (predict_rfx) { rfx_predictions <- object$rfx_samples$predict( rfx_group_ids, rfx_basis ) * y_std + if (predict_mean) { + rfx_predictions <- rowMeans(rfx_predictions) + } } # Compute overall "y_hat" predictions - y_hat <- mu_hat + treatment_term - if (object$model_params$has_rfx) { - y_hat <- y_hat + rfx_predictions + if (predict_y_hat) { + y_hat <- mu_hat + treatment_term + if (has_rfx) { + y_hat <- y_hat + rfx_predictions + } } # Scale variance forest predictions - if (object$model_params$include_variance_forest) { + if (predict_variance_forest) { if (object$model_params$sample_sigma2_global) { sigma2_global_samples <- object$sigma2_global_samples variance_forest_predictions <- sapply(1:num_samples, function(i) { @@ -2757,22 +2834,51 @@ predict.bcfmodel <- function( y_std * y_std } + if (predict_mean) { + variance_forest_predictions <- rowMeans(variance_forest_predictions) + } } - result <- list( - "mu_hat" = mu_hat, - "tau_hat" = tau_hat, - "y_hat" = y_hat - ) - if (object$model_params$has_rfx) { - result[["rfx_predictions"]] <- rfx_predictions - } else { - result[["rfx_predictions"]] <- NULL - } - if (object$model_params$include_variance_forest) { - result[["variance_forest_predictions"]] <- variance_forest_predictions + # Return results + if (predict_count == 1) { + if (predict_y_hat) { + return(y_hat) + } else if (predict_mu_forest) { + return(mu_hat) + } else if (predict_tau_forest) { + return(tau_hat) + } else if (predict_rfx) { + return(rfx_predictions) + } else if (predict_variance_forest) { + return(variance_forest_predictions) + } } else { - result[["variance_forest_predictions"]] <- NULL + result <- list() + if (predict_y_hat) { + result[["y_hat"]] = y_hat + } else { + result[["y_hat"]] <- NULL + } + if (predict_mu_forest) { + result[["mu_hat"]] = mu_hat + } else { + result[["mu_hat"]] <- NULL + } + if (predict_tau_forest) { + result[["tau_hat"]] = tau_hat + } else { + result[["tau_hat"]] <- NULL + } + if (predict_rfx) { + result[["rfx_predictions"]] = rfx_predictions + } else { + result[["rfx_predictions"]] <- NULL + } + if (predict_variance_forest) { + result[["variance_forest_predictions"]] = variance_forest_predictions + } else { + result[["variance_forest_predictions"]] <- NULL + } } return(result) } diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index a3197f7b..99e9d878 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -30,7 +30,7 @@ that were not in the training set.} \item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} -\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL}. Default: "all".} +\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} \item{...}{(Optional) Other prediction parameters.} } diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index 907e5308..dd7f14a0 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -11,6 +11,8 @@ propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, + type = "posterior", + terms = "all", ... ) } @@ -29,6 +31,10 @@ that were not in the training set.} \item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} + +\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} + \item{...}{(Optional) Other prediction parameters.} } \value{ diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index a6f56b5b..1f507ac7 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -184,7 +184,7 @@ test_that("Prediction from trees with multivariate leaf basis", { expect_equal(split_counts, split_counts_expected) }) -test_that("Predictions with pre-summarization", { +test_that("BART predictions with pre-summarization", { # Generate data and test-train split n <- 100 p <- 5 @@ -250,7 +250,7 @@ test_that("Predictions with pre-summarization", { # Check that the default predict method returns a list pred <- predict(het_bart_model, X_test) y_hat_posterior_test <- pred$y_hat - sigma2_hat_posterior_test <- pred$variance_forest + sigma2_hat_posterior_test <- pred$variance_forest_predictions # Assertion expect_equal(dim(y_hat_posterior_test), c(20, 10)) @@ -259,7 +259,7 @@ test_that("Predictions with pre-summarization", { # Check that the pre-aggregated predictions match with those computed by rowMeans pred_mean <- predict(het_bart_model, X_test, type = "mean") y_hat_mean_test <- pred_mean$y_hat - sigma2_hat_mean_test <- pred_mean$variance_forest + sigma2_hat_mean_test <- pred_mean$variance_forest_predictions # Assertion expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) @@ -284,3 +284,163 @@ test_that("Predictions with pre-summarization", { expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) }) + +test_that("BCF predictions with pre-summarization", { + # Generate data and test-train split + n <- 100 + g <- function(x) { + ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) + } + x1 <- rnorm(n) + x2 <- rnorm(n) + x3 <- rnorm(n) + x4 <- as.numeric(rbinom(n, 1, 0.5)) + x5 <- as.numeric(sample(1:3, n, replace = TRUE)) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + mu_x <- 1 + g(X) + X[, 1] * X[, 3] + tau_x <- 1 + 2 * X[, 2] * X[, 4] + pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + snr <- 2 + y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + X <- as.data.frame(X) + X$x4 <- factor(X$x4, ordered = TRUE) + X$x5 <- factor(X$x5, ordered = TRUE) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a "classic" BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test + ) + y_hat_posterior_test <- pred$y_hat + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean" + ) + y_hat_mean_test <- pred_mean$y_hat + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + + # Check that we warn and return a NULL when requesting terms that weren't fit + expect_warning({ + pred_mean <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("rfx", "variance_forest") + ) + }) + expect_equal(NULL, pred_mean) + + # Fit a heteroskedastic BART model + var_params <- list(num_trees = 20) + het_bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) + + # Check that the default predict method returns a list + pred <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test + ) + y_hat_posterior_test <- pred$y_hat + sigma2_hat_posterior_test <- pred$variance_forest_predictions + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean" + ) + y_hat_mean_test <- pred_mean$y_hat + sigma2_hat_mean_test <- pred_mean$variance_forest_predictions + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a list + y_hat_mean_test_single_term <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = "y_hat" + ) + sigma2_hat_mean_test_single_term <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = "variance_forest" + ) + + # Assertion + expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) + expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) +}) diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R new file mode 100644 index 00000000..30d93852 --- /dev/null +++ b/tools/debug/bcf_predict_debug.R @@ -0,0 +1,87 @@ +# Demo of updated predict method for BCF + +# Load library +library(stochtree) + +# Generate data +n <- 100 +g <- function(x) { + ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) +} +x1 <- rnorm(n) +x2 <- rnorm(n) +x3 <- rnorm(n) +x4 <- as.numeric(rbinom(n, 1, 0.5)) +x5 <- as.numeric(sample(1:3, n, replace = TRUE)) +X <- cbind(x1, x2, x3, x4, x5) +p <- ncol(X) +mu_x <- 1 + g(X) + X[, 1] * X[, 3] +tau_x <- 1 + 2 * X[, 2] * X[, 4] +pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) +X <- as.data.frame(X) +X$x4 <- factor(X$x4, ordered = TRUE) +X$x5 <- factor(X$x5, ordered = TRUE) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Fit a simple BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 +) + +# Check several predict approaches +y_hat_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test +)$y_hat +pred <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("all") +) +y_hat_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("rfx", "variance") +) + From f977c19836f1bb5b392af0b8f6b167c3691a6423 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 7 Oct 2025 23:44:09 -0500 Subject: [PATCH 04/53] Updated python BART predict method --- R/bart.R | 2 +- demo/debug/bart_predict_debug.py | 68 +++++++++++++++++++ stochtree/bart.py | 112 ++++++++++++++++++++++++------- 3 files changed, 157 insertions(+), 25 deletions(-) create mode 100644 demo/debug/bart_predict_debug.py diff --git a/R/bart.R b/R/bart.R index c64ea3d8..7fa7f92d 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1814,7 +1814,7 @@ predict.bartmodel <- function( stop("type must be a string or character vector") } if (!(type %in% c("mean", "posterior"))) { - stop("type must either be 'mean' or 'posterior") + stop("type must either be 'mean' or 'posterior'") } predict_mean <- type == "mean" diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py new file mode 100644 index 00000000..6b477053 --- /dev/null +++ b/demo/debug/bart_predict_debug.py @@ -0,0 +1,68 @@ +# Demo of updated predict method for BART + +# Load library +from stochtree import BARTModel +import numpy as np +from sklearn.model_selection import train_test_split +import matplotlib.pyplot as plt + +# Generate data +rng = np.random.default_rng() +n = 100 +p = 5 +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +f_X = np.where( + ((0 <= X[:, 0]) & (X[:, 0] <= 0.25)), + -7.5, + np.where( + ((0.25 <= X[:, 0]) & (X[:, 0] <= 0.5)), + -2.5, + np.where(((0.5 <= X[:, 0]) & (X[:, 0] <= 0.75)), 2.5, 7.5), + ), +) +noise_sd = 1.0 +y = f_X + rng.normal(loc=0.0, scale=1.0, size=(n,)) + +# Train-test split +sample_inds = np.arange(n) +test_set_pct = 0.2 +train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] +f_X_train = f_X[train_inds] +f_X_test = f_X[test_inds] + +# Fit simple BART model +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + y_train=y_train, + X_test=X_test, + num_gfr=10, + num_burnin=0, + num_mcmc=10, +) + +# # Check several predict approaches +bart_preds = bart_model.predict(covariates=X_test) +y_hat_posterior_test = bart_model.predict(covariates=X_test)['y_hat'] +y_hat_mean_test = bart_model.predict( + covariates=X_test, + type = "mean", + terms = ["y_hat"] +) +y_hat_test = bart_model.predict( + covariates=X_test, + type = "mean", + terms = ["rfx", "variance"] +) + +# Plot predicted versus actual +plt.scatter(y_hat_mean_test, y_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Y hat") +plt.show() diff --git a/stochtree/bart.py b/stochtree/bart.py index 1f65ea17..3a5566a9 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1577,7 +1577,9 @@ def predict( covariates: Union[np.array, pd.DataFrame], basis: np.array = None, rfx_group_ids: np.array = None, - rfx_basis: np.array = None, + rfx_basis: np.array = None, + type: str = "posterior", + terms: Union[list[str], str] = "all" ) -> Union[np.array, tuple]: """Return predictions from every forest sampled (either / both of mean and variance). Return type is either a single array of predictions, if a BART model only includes a @@ -1593,14 +1595,46 @@ def predict( Optional group labels used for an additive random effects model. rfx_basis : np.array, optional Optional basis for "random-slope" regression in an additive random effects model. + type : str, optional + Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". + terms : str, optional + Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". Returns ------- - mu_x : np.array, optional - Mean forest and / or random effects predictions. - sigma2_x : np.array, optional - Variance forest predictions. + Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested """ + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if not type in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Handle prediction terms + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("type must be a string or list of strings") + num_terms = 1 if isinstance(terms, str) else len(terms) + has_mean_forest = self.include_mean_forest + has_variance_forest = self.include_variance_forest + has_rfx = self.has_rfx + has_y_hat = has_mean_forest or has_rfx + predict_y_hat = ((has_y_hat and ("y_hat" in terms)) or + (has_y_hat and ("all" in terms))) + predict_mean_forest = ((has_mean_forest and ("mean_forest" in terms)) or + (has_mean_forest and ("all" in terms))) + predict_rfx = ((has_rfx and ("rfx" in terms)) or + (has_rfx and ("all" in terms))) + predict_variance_forest = ((has_variance_forest and ("variance_forest" in terms)) or + (has_variance_forest and ("all" in terms))) + predict_count = (predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest) + if predict_count == 0: + term_list = ", ".join(terms) + warnings.warn(f"None of the requested model terms, {term_list}, were fit in this model") + return None + predict_rfx_intermediate = predict_y_hat and has_rfx + predict_mean_forest_intermediate = predict_y_hat and has_mean_forest + if not self.is_sampled(): msg = ( "This BARTModel instance is not fitted yet. Call 'fit' with " @@ -1657,22 +1691,22 @@ def predict( pred_dataset.add_basis(basis) # Forest predictions - if self.include_mean_forest: + if predict_mean_forest or predict_mean_forest_intermediate: mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( pred_dataset.dataset_cpp ) - mean_pred = mean_pred_raw * self.y_std + self.y_bar + mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar + if predict_mean: + mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1) - if self.has_rfx: - rfx_preds = ( + if predict_rfx or predict_rfx_intermediate: + rfx_predictions = ( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std ) - if self.include_mean_forest: - mean_pred = mean_pred + rfx_preds - else: - mean_pred = rfx_preds + self.y_bar + if predict_mean: + rfx_predictions = np.mean(rfx_predictions, axis = 1) - if self.include_variance_forest: + if predict_variance_forest: variance_pred_raw = ( self.forest_container_variance.forest_container_cpp.Predict( pred_dataset.dataset_cpp @@ -1685,18 +1719,48 @@ def predict( variance_pred_raw[:, i] * self.global_var_samples[i] ) else: - variance_pred = ( + variance_forest_predictions = ( variance_pred_raw * self.sigma2_init * self.y_std * self.y_std ) - - has_mean_predictions = self.include_mean_forest or self.has_rfx - if has_mean_predictions and self.include_variance_forest: - return {"y_hat": mean_pred, "variance_forest_predictions": variance_pred} - elif has_mean_predictions and not self.include_variance_forest: - return {"y_hat": mean_pred, "variance_forest_predictions": None} - elif not has_mean_predictions and self.include_variance_forest: - return {"y_hat": None, "variance_forest_predictions": variance_pred} - + if predict_mean: + variance_forest_predictions = np.mean(variance_forest_predictions, axis = 1) + + if predict_y_hat and has_mean_forest and has_rfx: + y_hat = mean_forest_predictions + rfx_predictions + elif predict_y_hat and has_mean_forest: + y_hat = mean_forest_predictions + elif predict_y_hat and has_rfx: + y_hat = rfx_predictions + + if predict_count == 1: + if predict_y_hat: + return y_hat + elif predict_mean_forest: + return mean_forest_predictions + elif predict_rfx: + return rfx_predictions + elif predict_variance_forest: + return variance_forest_predictions + else: + result = dict() + if predict_y_hat: + result["y_hat"] = y_hat + else: + result["y_hat"] = None + if predict_mean_forest: + result["mean_forest_predictions"] = mean_forest_predictions + else: + result["mean_forest_predictions"] = None + if predict_rfx: + result["rfx_predictions"] = rfx_predictions + else: + result["rfx_predictions"] = None + if predict_variance_forest: + result["variance_forest_predictions"] = variance_forest_predictions + else: + result["variance_forest_predictions"] = None + return result + def predict_mean( self, covariates: np.array, From 11fbf7640b4f191cf17ccd8537830340f7c771e7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 8 Oct 2025 08:22:23 -0500 Subject: [PATCH 05/53] Updated python BCF predict method --- R/bart.R | 3 +- R/bcf.R | 2 +- stochtree/bart.py | 8 +-- stochtree/bcf.py | 146 +++++++++++++++++++++++++++++++++------------- 4 files changed, 111 insertions(+), 48 deletions(-) diff --git a/R/bart.R b/R/bart.R index 7fa7f92d..d2488674 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1771,8 +1771,7 @@ bart <- function( #' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param ... (Optional) Other prediction parameters. #' -#' @return List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. -#' If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (`y_hat`). +#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. #' @export #' #' @examples diff --git a/R/bcf.R b/R/bcf.R index 6dfcdae3..28bc7b42 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2562,7 +2562,7 @@ bcf <- function( #' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param ... (Optional) Other prediction parameters. #' -#' @return List of 3-5 `nrow(X)` by `object$num_samples` matrices: prognostic function estimates, treatment effect estimates, (optionally) random effects predictions, (optionally) variance forest predictions, and outcome predictions. +#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. #' @export #' #' @examples diff --git a/stochtree/bart.py b/stochtree/bart.py index 3a5566a9..a1ffc317 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1602,12 +1602,12 @@ def predict( Returns ------- - Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested + Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. """ # Handle prediction type if not isinstance(type, str): raise ValueError("type must be a string") - if not type in ["mean", "posterior"]: + if type not in ["mean", "posterior"]: raise ValueError("type must either be 'mean' or 'posterior'") predict_mean = type == "mean" @@ -1713,9 +1713,9 @@ def predict( ) ) if self.sample_sigma2_global: - variance_pred = np.empty_like(variance_pred_raw) + variance_forest_predictions = np.empty_like(variance_pred_raw) for i in range(self.num_samples): - variance_pred[:, i] = ( + variance_forest_predictions[:, i] = ( variance_pred_raw[:, i] * self.global_var_samples[i] ) else: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index bfe9cc34..a59d3239 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2416,7 +2416,9 @@ def predict( Z: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, - rfx_basis: np.array = None, + rfx_basis: np.array = None, + type: str = "posterior", + terms: Union[list[str], str] = "all" ) -> dict: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. @@ -2433,21 +2435,51 @@ def predict( Optional group labels used for an additive random effects model. rfx_basis : np.array, optional Optional basis for "random-slope" regression in an additive random effects model. + + type : str, optional + Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". + terms : str, optional + Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". Returns ------- - tau_x : np.array - Conditional average treatment effect (CATE) samples for every observation provided. - mu_x : np.array - Prognostic effect samples for every observation provided. - rfx : np.array, optional - Random effect samples for every observation provided, if the model includes a random effects term. - yhat_x : np.array - Outcome prediction samples for every observation provided. - sigma2_x : np.array, optional - Variance forest samples for every observation provided. Only returned if the - model includes a heteroskedasticity forest. + Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. """ + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Handle prediction terms + if not isinstance(terms, str) and not isinstance(terms, list): + raise ValueError("type must be a string or list of strings") + num_terms = 1 if isinstance(terms, str) else len(terms) + has_mu_forest = True + has_tau_forest = True + has_variance_forest = self.include_variance_forest + has_rfx = self.has_rfx + has_y_hat = has_mu_forest or has_tau_forest or has_rfx + predict_y_hat = ((has_y_hat and ("y_hat" in terms)) or + (has_y_hat and ("all" in terms))) + predict_mu_forest = ((has_mu_forest and ("prognostic_function" in terms)) or + (has_mu_forest and ("all" in terms))) + predict_tau_forest = ((has_tau_forest and ("cate" in terms)) or + (has_tau_forest and ("all" in terms))) + predict_rfx = ((has_rfx and ("rfx" in terms)) or + (has_rfx and ("all" in terms))) + predict_variance_forest = ((has_variance_forest and ("variance_forest" in terms)) or + (has_variance_forest and ("all" in terms))) + predict_count = (predict_y_hat + predict_mu_forest + predict_tau_forest + predict_rfx + predict_variance_forest) + if predict_count == 0: + term_list = ", ".join(terms) + warnings.warn(f"None of the requested model terms, {term_list}, were fit in this model") + return None + predict_rfx_intermediate = predict_y_hat and has_rfx + predict_mu_forest_intermediate = predict_y_hat and has_mu_forest + predict_tau_forest_intermediate = predict_y_hat and has_tau_forest + if not self.is_sampled(): msg = ( "This BCFModel instance is not fitted yet. Call 'fit' with " @@ -2520,35 +2552,42 @@ def predict( forest_dataset_test.add_basis(Z) # Compute predicted outcome and decomposed outcome model terms - mu_raw = self.forest_container_mu.forest_container_cpp.Predict( - forest_dataset_test.dataset_cpp - ) - mu_x = mu_raw * self.y_std + self.y_bar - tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( - forest_dataset_test.dataset_cpp - ) - if self.adaptive_coding: - adaptive_coding_weights = np.expand_dims( - self.b1_samples - self.b0_samples, axis=(0, 2) + if predict_mu_forest or predict_mu_forest_intermediate: + mu_raw = self.forest_container_mu.forest_container_cpp.Predict( + forest_dataset_test.dataset_cpp ) - tau_raw = tau_raw * adaptive_coding_weights - tau_x = np.squeeze(tau_raw * self.y_std) - if Z.shape[1] > 1: - treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum( - axis=2 + mu_x = mu_raw * self.y_std + self.y_bar + if predict_tau_forest or predict_tau_forest_intermediate: + tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( + forest_dataset_test.dataset_cpp ) - else: - treatment_term = Z * np.squeeze(tau_x) - yhat_x = mu_x + treatment_term + if self.adaptive_coding: + adaptive_coding_weights = np.expand_dims( + self.b1_samples - self.b0_samples, axis=(0, 2) + ) + tau_raw = tau_raw * adaptive_coding_weights + tau_x = np.squeeze(tau_raw * self.y_std) + if Z.shape[1] > 1: + treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum( + axis=2 + ) + else: + treatment_term = Z * np.squeeze(tau_x) - if self.has_rfx: + if predict_rfx or predict_rfx_intermediate: rfx_preds = ( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std ) - yhat_x = yhat_x + rfx_preds + + if predict_y_hat and has_mu_forest and has_rfx: + y_hat = mu_x + treatment_term + rfx_preds + elif predict_y_hat and has_mu_forest: + y_hat = mu_x + treatment_term + elif predict_y_hat and has_rfx: + y_hat = rfx_preds # Compute predictions from the variance forest (if included) - if self.include_variance_forest: + if predict_variance_forest: sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) @@ -2559,15 +2598,40 @@ def predict( else: sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std - # Return result matrices as a tuple - if self.has_rfx and self.include_variance_forest: - return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": rfx_preds, "variance_forest_predictions": sigma2_x} - elif not self.has_rfx and self.include_variance_forest: - return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": None, "variance_forest_predictions": sigma2_x} - elif self.has_rfx and not self.include_variance_forest: - return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": rfx_preds, "variance_forest_predictions": None} + if predict_count == 1: + if predict_y_hat: + return y_hat + elif predict_mu_forest: + return mu_x + elif predict_tau_forest: + return tau_x + elif predict_rfx: + return rfx_preds + elif predict_variance_forest: + return sigma2_x else: - return {"mu_hat": mu_x, "tau_hat": tau_x, "y_hat": yhat_x, "rfx_predictions": None, "variance_forest_predictions": None} + result = dict() + if predict_y_hat: + result["y_hat"] = y_hat + else: + result["y_hat"] = None + if predict_mu_forest: + result["mu_hat"] = mu_x + else: + result["mu_hat"] = None + if predict_tau_forest: + result["tau_hat"] = tau_x + else: + result["tau_hat"] = None + if predict_rfx: + result["rfx_predictions"] = rfx_preds + else: + result["rfx_predictions"] = None + if predict_variance_forest: + result["variance_forest_predictions"] = sigma2_x + else: + result["variance_forest_predictions"] = None + return result def to_json(self) -> str: """ From 8e7ab6d90e6c2706b08c9e194e3f3866a91d7604 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 8 Oct 2025 20:07:02 -0500 Subject: [PATCH 06/53] Updated predict method for BCF --- stochtree/bcf.py | 12 ++ test/R/testthat/test-predict.R | 307 ++++++++++++++++----------------- test/python/test_predict.py | 199 ++++++++++++++++++++- 3 files changed, 362 insertions(+), 156 deletions(-) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index a59d3239..7d384fa3 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2557,6 +2557,8 @@ def predict( forest_dataset_test.dataset_cpp ) mu_x = mu_raw * self.y_std + self.y_bar + if predict_mean: + mu_x = np.mean(mu_x, axis=1) if predict_tau_forest or predict_tau_forest_intermediate: tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( forest_dataset_test.dataset_cpp @@ -2571,13 +2573,21 @@ def predict( treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum( axis=2 ) + if predict_mean: + treatment_term = np.mean(treatment_term, axis=1) + tau_x = np.mean(tau_x, axis=2) else: treatment_term = Z * np.squeeze(tau_x) + if predict_mean: + treatment_term = np.mean(treatment_term, axis=1) + tau_x = np.mean(tau_x, axis=1) if predict_rfx or predict_rfx_intermediate: rfx_preds = ( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std ) + if predict_mean: + rfx_preds = np.mean(rfx_preds, axis=1) if predict_y_hat and has_mu_forest and has_rfx: y_hat = mu_x + treatment_term + rfx_preds @@ -2597,6 +2607,8 @@ def predict( sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] else: sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std + if predict_mean: + sigma2_x = np.mean(sigma2_x, axis=1) if predict_count == 1: if predict_y_hat: diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index 1f507ac7..e64c85a6 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -286,161 +286,158 @@ test_that("BART predictions with pre-summarization", { }) test_that("BCF predictions with pre-summarization", { - # Generate data and test-train split - n <- 100 - g <- function(x) { - ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) - } - x1 <- rnorm(n) - x2 <- rnorm(n) - x3 <- rnorm(n) - x4 <- as.numeric(rbinom(n, 1, 0.5)) - x5 <- as.numeric(sample(1:3, n, replace = TRUE)) - X <- cbind(x1, x2, x3, x4, x5) - p <- ncol(X) - mu_x <- 1 + g(X) + X[, 1] * X[, 3] - tau_x <- 1 + 2 * X[, 2] * X[, 4] - pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 - Z <- rbinom(n, 1, pi_x) - E_XZ <- mu_x + Z * tau_x - snr <- 2 - y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) - X <- as.data.frame(X) - X$x4 <- factor(X$x4, ordered = TRUE) - X$x5 <- factor(X$x5, ordered = TRUE) - test_set_pct <- 0.2 - n_test <- round(test_set_pct * n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds, ] - X_train <- X[train_inds, ] - pi_test <- pi_x[test_inds] - pi_train <- pi_x[train_inds] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - mu_test <- mu_x[test_inds] - mu_train <- mu_x[train_inds] - tau_test <- tau_x[test_inds] - tau_train <- tau_x[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Fit a "classic" BCF model - bcf_model <- bcf( - X_train = X_train, - Z_train = Z_train, - y_train = y_train, - propensity_train = pi_train, - X_test = X_test, - Z_test = Z_test, - propensity_test = pi_test, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10 - ) - - # Check that the default predict method returns a list - pred <- predict( - bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test - ) - y_hat_posterior_test <- pred$y_hat - expect_equal(dim(y_hat_posterior_test), c(20, 10)) - - # Check that the pre-aggregated predictions match with those computed by rowMeans - pred_mean <- predict( - bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test, - type = "mean" - ) - y_hat_mean_test <- pred_mean$y_hat - expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) - - # Check that we warn and return a NULL when requesting terms that weren't fit - expect_warning({ - pred_mean <- predict( - bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test, - type = "mean", - terms = c("rfx", "variance_forest") - ) - }) - expect_equal(NULL, pred_mean) - - # Fit a heteroskedastic BART model - var_params <- list(num_trees = 20) - het_bcf_model <- bcf( - X_train = X_train, - Z_train = Z_train, - y_train = y_train, - propensity_train = pi_train, - X_test = X_test, - Z_test = Z_test, - propensity_test = pi_test, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10, - variance_forest_params = var_params - ) - - # Check that the default predict method returns a list - pred <- predict( - het_bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test - ) - y_hat_posterior_test <- pred$y_hat - sigma2_hat_posterior_test <- pred$variance_forest_predictions - - # Assertion - expect_equal(dim(y_hat_posterior_test), c(20, 10)) - expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) - - # Check that the pre-aggregated predictions match with those computed by rowMeans + # Generate data and test-train split + n <- 100 + g <- function(x) { + ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) + } + x1 <- rnorm(n) + x2 <- rnorm(n) + x3 <- rnorm(n) + x4 <- as.numeric(rbinom(n, 1, 0.5)) + x5 <- as.numeric(sample(1:3, n, replace = TRUE)) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + mu_x <- 1 + g(X) + X[, 1] * X[, 3] + tau_x <- 1 + 2 * X[, 2] * X[, 4] + pi_x <- 0.8 * + pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + + 0.05 + + runif(n) / 10 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + snr <- 2 + y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + X <- as.data.frame(X) + X$x4 <- factor(X$x4, ordered = TRUE) + X$x5 <- factor(X$x5, ordered = TRUE) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Fit a "classic" BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test + ) + y_hat_posterior_test <- pred$y_hat + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean" + ) + y_hat_mean_test <- pred_mean$y_hat + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + + # Check that we warn and return a NULL when requesting terms that weren't fit + expect_warning({ pred_mean <- predict( - het_bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test, - type = "mean" - ) - y_hat_mean_test <- pred_mean$y_hat - sigma2_hat_mean_test <- pred_mean$variance_forest_predictions - - # Assertion - expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) - expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) - - # Check that the "single-term" pre-aggregated predictions - # match those computed by pre-aggregated predictions returned in a list - y_hat_mean_test_single_term <- predict( - het_bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test, - type = "mean", - terms = "y_hat" - ) - sigma2_hat_mean_test_single_term <- predict( - het_bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test, - type = "mean", - terms = "variance_forest" + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("rfx", "variance_forest") ) - - # Assertion - expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) - expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) + }) + expect_equal(NULL, pred_mean) + + # Fit a heteroskedastic BCF model + var_params <- list(num_trees = 20) + het_bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) + + # Check that the default predict method returns a list + pred <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test + ) + y_hat_posterior_test <- pred$y_hat + sigma2_hat_posterior_test <- pred$variance_forest_predictions + + # Assertion + expect_equal(dim(y_hat_posterior_test), c(20, 10)) + expect_equal(dim(sigma2_hat_posterior_test), c(20, 10)) + + # Check that the pre-aggregated predictions match with those computed by rowMeans + pred_mean <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean" + ) + y_hat_mean_test <- pred_mean$y_hat + sigma2_hat_mean_test <- pred_mean$variance_forest_predictions + + # Assertion + expect_equal(y_hat_mean_test, rowMeans(y_hat_posterior_test)) + expect_equal(sigma2_hat_mean_test, rowMeans(sigma2_hat_posterior_test)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a list + y_hat_mean_test_single_term <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = "y_hat" + ) + sigma2_hat_mean_test_single_term <- predict( + het_bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = "variance_forest" + ) + + # Assertion + expect_equal(y_hat_mean_test, y_hat_mean_test_single_term) + expect_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) }) diff --git a/test/python/test_predict.py b/test/python/test_predict.py index d180fd67..0dda27d4 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -1,6 +1,10 @@ +import pytest import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from scipy.stats import norm -from stochtree import Dataset, ForestContainer +from stochtree import Dataset, ForestContainer, BARTModel, BCFModel class TestPredict: @@ -193,3 +197,196 @@ def test_multivariate_regression_leaf_prediction(self): # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) + + def test_bart_prediction(self): + # Generate data and test/train split + rng = np.random.default_rng(1234) + n = 100 + p = 5 + X = rng.uniform(size=(n, p)) + f_XW = np.where((0 <= X[:, 0]) & (X[:, 0] < 0.25), -7.5, + np.where((0.25 <= X[:, 0]) & (X[:, 0] < 0.5), -2.5, + np.where((0.5 <= X[:, 0]) & (X[:, 0] < 0.75), 2.5, 7.5))) + noise_sd = 1 + y = f_XW + rng.normal(0, noise_sd, size=n) + test_set_pct = 0.2 + train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=1234) + X_train = X[train_inds, :] + X_test = X[test_inds, :] + y_train = y[train_inds] + y_test = y[test_inds] + + # Fit a "classic" BART model + bart_model = BARTModel() + bart_model.sample(X_train = X_train, y_train = y_train, num_gfr=10, num_burnin=0, num_mcmc=10) + + # Check that the default predict method returns a dictionary + pred = bart_model.predict(covariates=X_test) + y_hat_posterior_test = pred['y_hat'] + assert y_hat_posterior_test.shape == (20, 10) + + # Check that the pre-aggregated predictions match with those computed by np.mean + pred_mean = bart_model.predict(covariates=X_test, type="mean") + y_hat_mean_test = pred_mean['y_hat'] + np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) + + # Fit a heteroskedastic BART model + var_params = {'num_trees': 20} + het_bart_model = BARTModel() + het_bart_model.sample( + X_train = X_train, y_train = y_train, + num_gfr=10, num_burnin=0, num_mcmc=10, + variance_forest_params=var_params + ) + + # Check that the default predict method returns a dictionary + pred = het_bart_model.predict(covariates=X_test) + y_hat_posterior_test = pred['y_hat'] + sigma2_hat_posterior_test = pred['variance_forest_predictions'] + assert y_hat_posterior_test.shape == (20, 10) + assert sigma2_hat_posterior_test.shape == (20, 10) + + # Check that the pre-aggregated predictions match with those computed by np.mean + pred_mean = het_bart_model.predict(covariates=X_test, type="mean") + y_hat_mean_test = pred_mean['y_hat'] + sigma2_hat_mean_test = pred_mean['variance_forest_predictions'] + np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) + np.testing.assert_almost_equal(sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a dictionary + y_hat_mean_test_single_term = het_bart_model.predict(covariates=X_test, type="mean", terms="y_hat") + sigma2_hat_mean_test_single_term = het_bart_model.predict(covariates=X_test, type="mean", terms="variance_forest") + np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) + np.testing.assert_almost_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) + + def test_bcf_prediction(self): + # Generate data and test/train split + rng = np.random.default_rng(1234) + + + # Convert the R code down below to Python + rng = np.random.default_rng(1234) + n = 100 + g = lambda x: np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) + x1 = rng.normal(size=n) + x2 = rng.normal(size=n) + x3 = rng.normal(size=n) + x4 = rng.binomial(n=1,p=0.5,size=(n,)) + x5 = rng.choice(a=[0,1,2], size=(n,), replace=True) + x4_cat = pd.Categorical(x4, categories=[0,1], ordered=True) + x5_cat = pd.Categorical(x4, categories=[0,1,2], ordered=True) + p = 5 + X = pd.DataFrame(data={ + "x1": pd.Series(x1), + "x2": pd.Series(x2), + "x3": pd.Series(x3), + "x4": pd.Series(x4_cat), + "x5": pd.Series(x5_cat) + }) + def g(x5): + return np.where( + x5 == 0, 2.0, + np.where( + x5 == 1, -1.0, -4.0 + ) + ) + p = X.shape[1] + mu_x = 1.0 + g(x5) + x1*x3 + tau_x = 1.0 + 2*x2*x4 + pi_x = ( + 0.8 * norm.cdf(3.0 * mu_x / np.squeeze(np.std(mu_x)) - 0.5 * x1) + + 0.05 + rng.uniform(low=0., high=0.1, size=(n,)) + ) + Z = rng.binomial(n=1, p=pi_x, size=(n,)) + E_XZ = mu_x + tau_x*Z + snr = 2 + y = E_XZ + rng.normal(loc=0., scale=np.std(E_XZ) / snr, size=(n,)) + test_set_pct = 0.2 + train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=1234) + X_train = X.iloc[train_inds,:] + X_test = X.iloc[test_inds,:] + Z_train = Z[train_inds] + Z_test = Z[test_inds] + pi_x_train = pi_x[train_inds] + pi_x_test = pi_x[test_inds] + y_train = y[train_inds] + y_test = y[test_inds] + + # Fit a "classic" BCF model + bcf_model = BCFModel() + bcf_model.sample( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_x_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_x_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10 + ) + + # Check that the default predict method returns a list + pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test) + y_hat_posterior_test = pred['y_hat'] + assert y_hat_posterior_test.shape == (20, 10) + + # Check that the pre-aggregated predictions match with those computed by np.mean + pred_mean = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, type="mean") + y_hat_mean_test = pred_mean['y_hat'] + np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) + + # Check that we warn and return None when requesting terms that weren't fit + with pytest.warns(UserWarning): + pred_mean = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_x_test, + type="mean", terms=["rfx", "variance_forest"] + ) + + # Fit a heteroskedastic BCF model + var_params = {'num_trees': 20} + het_bcf_model = BCFModel() + het_bcf_model.sample( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + pi_train = pi_x_train, + X_test = X_test, + Z_test = Z_test, + pi_test = pi_x_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params=var_params + ) + + # Check that the default predict method returns a dictionary + pred = het_bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test) + y_hat_posterior_test = pred['y_hat'] + sigma2_hat_posterior_test = pred['variance_forest_predictions'] + assert y_hat_posterior_test.shape == (20, 10) + assert sigma2_hat_posterior_test.shape == (20, 10) + + # Check that the pre-aggregated predictions match with those computed by np.mean + pred_mean = het_bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, type="mean") + y_hat_mean_test = pred_mean['y_hat'] + sigma2_hat_mean_test = pred_mean['variance_forest_predictions'] + np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) + np.testing.assert_almost_equal(sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1)) + + # Check that the "single-term" pre-aggregated predictions + # match those computed by pre-aggregated predictions returned in a dictionary + y_hat_mean_test_single_term = het_bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_x_test, + type="mean", terms="y_hat" + ) + sigma2_hat_mean_test_single_term = het_bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_x_test, + type="mean", terms="variance_forest" + ) + np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) + np.testing.assert_almost_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) + + From 763876eba5a45f0c918596179b847d8d3c06ef09 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 15:55:09 -0500 Subject: [PATCH 07/53] Updated posterior transformation functions for R interface --- NAMESPACE | 4 + NEWS.md | 17 + R/bart.R | 117 +- R/bcf.R | 6771 ++++++++++++----------- R/posterior_transformation.R | 874 +++ man/compute_bart_posterior_interval.Rd | 53 + man/compute_bcf_posterior_interval.Rd | 63 + man/predict.bartmodel.Rd | 12 +- man/predict.bcfmodel.Rd | 5 +- man/sample_bart_posterior_predictive.Rd | 44 + man/sample_bcf_posterior_predictive.Rd | 47 + tools/debug/bart_predict_debug.R | 177 +- tools/debug/bart_prior_draws.R | 169 + tools/debug/bcf_predict_debug.R | 177 +- 14 files changed, 5093 insertions(+), 3437 deletions(-) create mode 100644 R/posterior_transformation.R create mode 100644 man/compute_bart_posterior_interval.Rd create mode 100644 man/compute_bcf_posterior_interval.Rd create mode 100644 man/sample_bart_posterior_predictive.Rd create mode 100644 man/sample_bcf_posterior_predictive.Rd create mode 100644 tools/debug/bart_prior_draws.R diff --git a/NAMESPACE b/NAMESPACE index 2f4103c0..ab6544bc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,8 @@ export(calibrateInverseGammaErrorVariance) export(computeForestLeafIndices) export(computeForestLeafVariances) export(computeForestMaxLeafIndex) +export(compute_bart_posterior_interval) +export(compute_bcf_posterior_interval) export(convertPreprocessorToJson) export(createBARTModelFromCombinedJson) export(createBARTModelFromCombinedJsonString) @@ -60,6 +62,8 @@ export(rootResetRandomEffectsModel) export(rootResetRandomEffectsTracker) export(sampleGlobalErrorVarianceOneIteration) export(sampleLeafVarianceOneIteration) +export(sample_bart_posterior_predictive) +export(sample_bcf_posterior_predictive) export(sample_without_replacement) export(saveBARTModelToJson) export(saveBARTModelToJsonFile) diff --git a/NEWS.md b/NEWS.md index fc00c05e..c794b856 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,13 +2,30 @@ ## New Features +* Support for multithreading in various elements of the GFR and MCMC algorithms ([#182](https://github.com/StochasticTree/stochtree/pull/182)) * Support for binary outcomes in BART and BCF with a probit link ([#164](https://github.com/StochasticTree/stochtree/pull/164)) +* Enable "restricted sweep" of tree algorithms over a handful of trees ([#173](https://github.com/StochasticTree/stochtree/pull/173)) +* Support for multivariate treatment in R ([#183](https://github.com/StochasticTree/stochtree/pull/183)) +* Enable modification of dataset variables (weights, etc...) via low-level interface ([#194](https://github.com/StochasticTree/stochtree/pull/194)) + +## Computational Improvements + +* Modified default random effects initialization ([#190](https://github.com/StochasticTree/stochtree/pull/190)) +* Avoid double prediction on training set ([#178](https://github.com/StochasticTree/stochtree/pull/178)) ## Bug Fixes * Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models * Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided (rather than a `bartmodel` or `bcfmodel` object) +## Other Changes + +* Standardized naming conventions for out of sample data in prediction and posterior computation routines (we raise warnings when data are passed through `y`, `X`, `Z`, etc... arguments) + * Covariates / features are always referred to as "covariates" rather than "X" + * Treatment is referred to as "treatment" rather than "Z" + * Propensity scores are referred to as "propensity" rather than "pi_X" + * Outcomes are referred to as "outcome" rather than "Y" + # stochtree 0.1.1 * Fixed initialization bug in several R package code examples for random effects models diff --git a/R/bart.R b/R/bart.R index d2488674..30e0f8f3 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1769,6 +1769,7 @@ bart <- function( #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". #' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. #' #' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. @@ -1800,14 +1801,30 @@ bart <- function( #' y_hat_test <- predict(bart_model, X_test)$y_hat predict.bartmodel <- function( object, - X, + covariates, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, type = "posterior", terms = "all", + scale = "linear", ... ) { + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + # Handle prediction type if (!is.character(type)) { stop("type must be a string or character vector") @@ -1852,12 +1869,24 @@ predict.bartmodel <- function( predict_rfx_intermediate <- (predict_y_hat && has_rfx) predict_mean_forest_intermediate <- (predict_y_hat && has_mean_forest) + # Check that we have at least one term to predict on probability scale + if ( + probability_scale && + !predict_y_hat && + !predict_mean_forest && + !predict_rfx + ) { + stop( + "scale can only be 'probability' if at least one mean term is requested" + ) + } + # Preprocess covariates - if ((!is.data.frame(X)) && (!is.matrix(X))) { - stop("X must be a matrix or dataframe") + if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { + stop("covariates must be a matrix or dataframe") } train_set_metadata <- object$train_set_metadata - X <- preprocessPredictionData(X, train_set_metadata) + X <- preprocessPredictionData(covariates, train_set_metadata) # Convert all input data to matrices if not already converted if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) { @@ -1922,25 +1951,37 @@ predict.bartmodel <- function( prediction_dataset <- createForestDataset(X) } + # Compute variance forest predictions + if (predict_variance_forest) { + s_x_raw <- object$variance_forests$predict(prediction_dataset) + } + + # Scale variance forest predictions + sigma2_init <- object$model_params$sigma2_init + if (predict_variance_forest) { + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) + } else { + variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std + } + if (predict_mean) { + variance_forest_predictions <- rowMeans(variance_forest_predictions) + } + } + # Compute mean forest predictions num_samples <- object$model_params$num_samples y_std <- object$model_params$outcome_scale y_bar <- object$model_params$outcome_mean - sigma2_init <- object$model_params$sigma2_init if (predict_mean_forest || predict_mean_forest_intermediate) { mean_forest_predictions <- object$mean_forests$predict( prediction_dataset ) * y_std + y_bar - if (predict_mean) { - mean_forest_predictions <- rowMeans(mean_forest_predictions) - } - } - - # Compute variance forest predictions - if (predict_variance_forest) { - s_x_raw <- object$variance_forests$predict(prediction_dataset) } # Compute rfx predictions (if needed) @@ -1950,32 +1991,42 @@ predict.bartmodel <- function( rfx_basis ) * y_std - if (predict_mean) { - rfx_predictions <- rowMeans(rfx_predictions) - } } - # Scale variance forest predictions - if (predict_variance_forest) { - if (object$model_params$sample_sigma2_global) { - sigma2_global_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) { - s_x_raw[, i] * sigma2_global_samples[i] - }) - } else { - variance_forest_predictions <- s_x_raw * sigma2_init * y_std * y_std + # Combine into y hat predictions + if (probability_scale) { + if (predict_y_hat && has_mean_forest && has_rfx) { + y_hat <- pnorm(mean_forest_predictions + rfx_predictions) + mean_forest_predictions <- pnorm(mean_forest_predictions) + rfx_predictions <- pnorm(rfx_predictions) + } else if (predict_y_hat && has_mean_forest) { + y_hat <- pnorm(mean_forest_predictions) + mean_forest_predictions <- pnorm(mean_forest_predictions) + } else if (predict_y_hat && has_rfx) { + y_hat <- pnorm(rfx_predictions) + rfx_predictions <- pnorm(rfx_predictions) } - if (predict_mean) { - variance_forest_predictions <- rowMeans(variance_forest_predictions) + } else { + if (predict_y_hat && has_mean_forest && has_rfx) { + y_hat <- mean_forest_predictions + rfx_predictions + } else if (predict_y_hat && has_mean_forest) { + y_hat <- mean_forest_predictions + } else if (predict_y_hat && has_rfx) { + y_hat <- rfx_predictions } } - if (predict_y_hat && has_mean_forest && has_rfx) { - y_hat <- mean_forest_predictions + rfx_predictions - } else if (predict_y_hat && has_mean_forest) { - y_hat <- mean_forest_predictions - } else if (predict_y_hat && has_rfx) { - y_hat <- rfx_predictions + # Collapse to posterior mean predictions if requested + if (predict_mean) { + if (predict_mean_forest) { + mean_forest_predictions <- rowMeans(mean_forest_predictions) + } + if (predict_rfx) { + rfx_predictions <- rowMeans(rfx_predictions) + } + if (predict_y_hat) { + y_hat <- rowMeans(y_hat) + } } if (predict_count == 1) { diff --git a/R/bcf.R b/R/bcf.R index 28bc7b42..b69e0ad2 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -153,2399 +153,2393 @@ #' propensity_test = pi_test, num_gfr = 10, #' num_burnin = 0, num_mcmc = 10) bcf <- function( - X_train, - Z_train, - y_train, - propensity_train = NULL, - rfx_group_ids_train = NULL, - rfx_basis_train = NULL, - X_test = NULL, - Z_test = NULL, - propensity_test = NULL, - rfx_group_ids_test = NULL, - rfx_basis_test = NULL, - num_gfr = 5, - num_burnin = 0, - num_mcmc = 100, - previous_model_json = NULL, - previous_model_warmstart_sample_num = NULL, - general_params = list(), - prognostic_forest_params = list(), - treatment_effect_forest_params = list(), - variance_forest_params = list() + X_train, + Z_train, + y_train, + propensity_train = NULL, + rfx_group_ids_train = NULL, + rfx_basis_train = NULL, + X_test = NULL, + Z_test = NULL, + propensity_test = NULL, + rfx_group_ids_test = NULL, + rfx_basis_test = NULL, + num_gfr = 5, + num_burnin = 0, + num_mcmc = 100, + previous_model_json = NULL, + previous_model_warmstart_sample_num = NULL, + general_params = list(), + prognostic_forest_params = list(), + treatment_effect_forest_params = list(), + variance_forest_params = list() ) { - # Update general BCF parameters - general_params_default <- list( - cutpoint_grid_size = 100, - standardize = TRUE, - sample_sigma2_global = TRUE, - sigma2_global_init = NULL, - sigma2_global_shape = 0, - sigma2_global_scale = 0, - variable_weights = NULL, - propensity_covariate = "mu", - adaptive_coding = TRUE, - control_coding_init = -0.5, - treated_coding_init = 0.5, - rfx_prior_var = NULL, - random_seed = -1, - keep_burnin = FALSE, - keep_gfr = FALSE, - keep_every = 1, - num_chains = 1, - verbose = FALSE, - probit_outcome_model = FALSE, - rfx_working_parameter_prior_mean = NULL, - rfx_group_parameter_prior_mean = NULL, - rfx_working_parameter_prior_cov = NULL, - rfx_group_parameter_prior_cov = NULL, - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1, - num_threads = -1 - ) - general_params_updated <- preprocessParams( - general_params_default, - general_params - ) - - # Update mu forest BCF parameters - prognostic_forest_params_default <- list( - num_trees = 250, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - sample_sigma2_leaf = TRUE, - sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, - sigma2_leaf_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - prognostic_forest_params_updated <- preprocessParams( - prognostic_forest_params_default, - prognostic_forest_params - ) - - # Update tau forest BCF parameters - treatment_effect_forest_params_default <- list( - num_trees = 50, - alpha = 0.25, - beta = 3.0, - min_samples_leaf = 5, - max_depth = 5, - sample_sigma2_leaf = FALSE, - sigma2_leaf_init = NULL, - sigma2_leaf_shape = 3, - sigma2_leaf_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - delta_max = 0.9, - num_features_subsample = NULL - ) - treatment_effect_forest_params_updated <- preprocessParams( - treatment_effect_forest_params_default, - treatment_effect_forest_params - ) - - # Update variance forest BCF parameters - variance_forest_params_default <- list( - num_trees = 0, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = 10, - leaf_prior_calibration_param = 1.5, - variance_forest_init = NULL, - var_forest_prior_shape = NULL, - var_forest_prior_scale = NULL, - keep_vars = NULL, - drop_vars = NULL, - num_features_subsample = NULL - ) - variance_forest_params_updated <- preprocessParams( - variance_forest_params_default, - variance_forest_params - ) - - ### Unpack all parameter values - # 1. General parameters - cutpoint_grid_size <- general_params_updated$cutpoint_grid_size - standardize <- general_params_updated$standardize - sample_sigma2_global <- general_params_updated$sample_sigma2_global - sigma2_init <- general_params_updated$sigma2_global_init - a_global <- general_params_updated$sigma2_global_shape - b_global <- general_params_updated$sigma2_global_scale - variable_weights <- general_params_updated$variable_weights - propensity_covariate <- general_params_updated$propensity_covariate - adaptive_coding <- general_params_updated$adaptive_coding - b_0 <- general_params_updated$control_coding_init - b_1 <- general_params_updated$treated_coding_init - rfx_prior_var <- general_params_updated$rfx_prior_var - random_seed <- general_params_updated$random_seed - keep_burnin <- general_params_updated$keep_burnin - keep_gfr <- general_params_updated$keep_gfr - keep_every <- general_params_updated$keep_every - num_chains <- general_params_updated$num_chains - verbose <- general_params_updated$verbose - probit_outcome_model <- general_params_updated$probit_outcome_model - rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean - rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean - rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov - rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov - rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape - rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale - num_threads <- general_params_updated$num_threads - - # 2. Mu forest parameters - num_trees_mu <- prognostic_forest_params_updated$num_trees - alpha_mu <- prognostic_forest_params_updated$alpha - beta_mu <- prognostic_forest_params_updated$beta - min_samples_leaf_mu <- prognostic_forest_params_updated$min_samples_leaf - max_depth_mu <- prognostic_forest_params_updated$max_depth - sample_sigma2_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf - sigma2_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init - a_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_shape - b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale - keep_vars_mu <- prognostic_forest_params_updated$keep_vars - drop_vars_mu <- prognostic_forest_params_updated$drop_vars - num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample - - # 3. Tau forest parameters - num_trees_tau <- treatment_effect_forest_params_updated$num_trees - alpha_tau <- treatment_effect_forest_params_updated$alpha - beta_tau <- treatment_effect_forest_params_updated$beta - min_samples_leaf_tau <- treatment_effect_forest_params_updated$min_samples_leaf - max_depth_tau <- treatment_effect_forest_params_updated$max_depth - sample_sigma2_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf - sigma2_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init - a_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_shape - b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale - keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars - drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars - delta_max <- treatment_effect_forest_params_updated$delta_max - num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample - - # 4. Variance forest parameters - num_trees_variance <- variance_forest_params_updated$num_trees - alpha_variance <- variance_forest_params_updated$alpha - beta_variance <- variance_forest_params_updated$beta - min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf - max_depth_variance <- variance_forest_params_updated$max_depth - a_0 <- variance_forest_params_updated$leaf_prior_calibration_param - variance_forest_init <- variance_forest_params_updated$init_root_val - a_forest <- variance_forest_params_updated$var_forest_prior_shape - b_forest <- variance_forest_params_updated$var_forest_prior_scale - keep_vars_variance <- variance_forest_params_updated$keep_vars - drop_vars_variance <- variance_forest_params_updated$drop_vars - num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - - # Check if there are enough GFR samples to seed num_chains samplers - if (num_gfr > 0) { - if (num_chains > num_gfr) { - stop( - "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" - ) - } - } - - # Override keep_gfr if there are no MCMC samples - if (num_mcmc == 0) { - keep_gfr <- TRUE - } - - # Check if previous model JSON is provided and parse it if so - has_prev_model <- !is.null(previous_model_json) - if (has_prev_model) { - previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) - previous_y_bar <- previous_bcf_model$model_params$outcome_mean - previous_y_scale <- previous_bcf_model$model_params$outcome_scale - previous_forest_samples_mu <- previous_bcf_model$forests_mu - previous_forest_samples_tau <- previous_bcf_model$forests_tau - if (previous_bcf_model$model_params$include_variance_forest) { - previous_forest_samples_variance <- previous_bcf_model$forests_variance - } else { - previous_forest_samples_variance <- NULL - } - if (previous_bcf_model$model_params$sample_sigma2_global) { - previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / - (previous_y_scale * previous_y_scale) - } else { - previous_global_var_samples <- NULL - } - if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) { - previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples - } else { - previous_leaf_var_mu_samples <- NULL - } - if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) { - previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples - } else { - previous_leaf_var_tau_samples <- NULL - } - if (previous_bcf_model$model_params$has_rfx) { - previous_rfx_samples <- previous_bcf_model$rfx_samples - } else { - previous_rfx_samples <- NULL - } - if (previous_bcf_model$model_params$adaptive_coding) { - previous_b_1_samples <- previous_bcf_model$b_1_samples - previous_b_0_samples <- previous_bcf_model$b_0_samples - } else { - previous_b_1_samples <- NULL - previous_b_0_samples <- NULL - } - previous_model_num_samples <- previous_bcf_model$model_params$num_samples - if (previous_model_warmstart_sample_num > previous_model_num_samples) { - stop( - "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" - ) - } + # Update general BCF parameters + general_params_default <- list( + cutpoint_grid_size = 100, + standardize = TRUE, + sample_sigma2_global = TRUE, + sigma2_global_init = NULL, + sigma2_global_shape = 0, + sigma2_global_scale = 0, + variable_weights = NULL, + propensity_covariate = "mu", + adaptive_coding = TRUE, + control_coding_init = -0.5, + treated_coding_init = 0.5, + rfx_prior_var = NULL, + random_seed = -1, + keep_burnin = FALSE, + keep_gfr = FALSE, + keep_every = 1, + num_chains = 1, + verbose = FALSE, + probit_outcome_model = FALSE, + rfx_working_parameter_prior_mean = NULL, + rfx_group_parameter_prior_mean = NULL, + rfx_working_parameter_prior_cov = NULL, + rfx_group_parameter_prior_cov = NULL, + rfx_variance_prior_shape = 1, + rfx_variance_prior_scale = 1, + num_threads = -1 + ) + general_params_updated <- preprocessParams( + general_params_default, + general_params + ) + + # Update mu forest BCF parameters + prognostic_forest_params_default <- list( + num_trees = 250, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + sample_sigma2_leaf = TRUE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + prognostic_forest_params_updated <- preprocessParams( + prognostic_forest_params_default, + prognostic_forest_params + ) + + # Update tau forest BCF parameters + treatment_effect_forest_params_default <- list( + num_trees = 50, + alpha = 0.25, + beta = 3.0, + min_samples_leaf = 5, + max_depth = 5, + sample_sigma2_leaf = FALSE, + sigma2_leaf_init = NULL, + sigma2_leaf_shape = 3, + sigma2_leaf_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + delta_max = 0.9, + num_features_subsample = NULL + ) + treatment_effect_forest_params_updated <- preprocessParams( + treatment_effect_forest_params_default, + treatment_effect_forest_params + ) + + # Update variance forest BCF parameters + variance_forest_params_default <- list( + num_trees = 0, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = 10, + leaf_prior_calibration_param = 1.5, + variance_forest_init = NULL, + var_forest_prior_shape = NULL, + var_forest_prior_scale = NULL, + keep_vars = NULL, + drop_vars = NULL, + num_features_subsample = NULL + ) + variance_forest_params_updated <- preprocessParams( + variance_forest_params_default, + variance_forest_params + ) + + ### Unpack all parameter values + # 1. General parameters + cutpoint_grid_size <- general_params_updated$cutpoint_grid_size + standardize <- general_params_updated$standardize + sample_sigma2_global <- general_params_updated$sample_sigma2_global + sigma2_init <- general_params_updated$sigma2_global_init + a_global <- general_params_updated$sigma2_global_shape + b_global <- general_params_updated$sigma2_global_scale + variable_weights <- general_params_updated$variable_weights + propensity_covariate <- general_params_updated$propensity_covariate + adaptive_coding <- general_params_updated$adaptive_coding + b_0 <- general_params_updated$control_coding_init + b_1 <- general_params_updated$treated_coding_init + rfx_prior_var <- general_params_updated$rfx_prior_var + random_seed <- general_params_updated$random_seed + keep_burnin <- general_params_updated$keep_burnin + keep_gfr <- general_params_updated$keep_gfr + keep_every <- general_params_updated$keep_every + num_chains <- general_params_updated$num_chains + verbose <- general_params_updated$verbose + probit_outcome_model <- general_params_updated$probit_outcome_model + rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean + rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean + rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov + rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov + rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape + rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale + num_threads <- general_params_updated$num_threads + + # 2. Mu forest parameters + num_trees_mu <- prognostic_forest_params_updated$num_trees + alpha_mu <- prognostic_forest_params_updated$alpha + beta_mu <- prognostic_forest_params_updated$beta + min_samples_leaf_mu <- prognostic_forest_params_updated$min_samples_leaf + max_depth_mu <- prognostic_forest_params_updated$max_depth + sample_sigma2_leaf_mu <- prognostic_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_init + a_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_shape + b_leaf_mu <- prognostic_forest_params_updated$sigma2_leaf_scale + keep_vars_mu <- prognostic_forest_params_updated$keep_vars + drop_vars_mu <- prognostic_forest_params_updated$drop_vars + num_features_subsample_mu <- prognostic_forest_params_updated$num_features_subsample + + # 3. Tau forest parameters + num_trees_tau <- treatment_effect_forest_params_updated$num_trees + alpha_tau <- treatment_effect_forest_params_updated$alpha + beta_tau <- treatment_effect_forest_params_updated$beta + min_samples_leaf_tau <- treatment_effect_forest_params_updated$min_samples_leaf + max_depth_tau <- treatment_effect_forest_params_updated$max_depth + sample_sigma2_leaf_tau <- treatment_effect_forest_params_updated$sample_sigma2_leaf + sigma2_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_init + a_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_shape + b_leaf_tau <- treatment_effect_forest_params_updated$sigma2_leaf_scale + keep_vars_tau <- treatment_effect_forest_params_updated$keep_vars + drop_vars_tau <- treatment_effect_forest_params_updated$drop_vars + delta_max <- treatment_effect_forest_params_updated$delta_max + num_features_subsample_tau <- treatment_effect_forest_params_updated$num_features_subsample + + # 4. Variance forest parameters + num_trees_variance <- variance_forest_params_updated$num_trees + alpha_variance <- variance_forest_params_updated$alpha + beta_variance <- variance_forest_params_updated$beta + min_samples_leaf_variance <- variance_forest_params_updated$min_samples_leaf + max_depth_variance <- variance_forest_params_updated$max_depth + a_0 <- variance_forest_params_updated$leaf_prior_calibration_param + variance_forest_init <- variance_forest_params_updated$init_root_val + a_forest <- variance_forest_params_updated$var_forest_prior_shape + b_forest <- variance_forest_params_updated$var_forest_prior_scale + keep_vars_variance <- variance_forest_params_updated$keep_vars + drop_vars_variance <- variance_forest_params_updated$drop_vars + num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample + + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop( + "num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains" + ) + } + } + + # Override keep_gfr if there are no MCMC samples + if (num_mcmc == 0) { + keep_gfr <- TRUE + } + + # Check if previous model JSON is provided and parse it if so + has_prev_model <- !is.null(previous_model_json) + if (has_prev_model) { + previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) + previous_y_bar <- previous_bcf_model$model_params$outcome_mean + previous_y_scale <- previous_bcf_model$model_params$outcome_scale + previous_forest_samples_mu <- previous_bcf_model$forests_mu + previous_forest_samples_tau <- previous_bcf_model$forests_tau + if (previous_bcf_model$model_params$include_variance_forest) { + previous_forest_samples_variance <- previous_bcf_model$forests_variance } else { - previous_y_bar <- NULL - previous_y_scale <- NULL - previous_global_var_samples <- NULL - previous_leaf_var_mu_samples <- NULL - previous_leaf_var_tau_samples <- NULL - previous_rfx_samples <- NULL - previous_forest_samples_mu <- NULL - previous_forest_samples_tau <- NULL - previous_forest_samples_variance <- NULL - previous_b_1_samples <- NULL - previous_b_0_samples <- NULL + previous_forest_samples_variance <- NULL } - - # Determine whether conditional variance will be modeled - if (num_trees_variance > 0) { - include_variance_forest = TRUE + if (previous_bcf_model$model_params$sample_sigma2_global) { + previous_global_var_samples <- previous_bcf_model$sigma2_global_samples / + (previous_y_scale * previous_y_scale) } else { - include_variance_forest = FALSE + previous_global_var_samples <- NULL } - - # Set the variance forest priors if not set - if (include_variance_forest) { - if (is.null(a_forest)) { - a_forest <- num_trees_variance / (a_0^2) + 0.5 - } - if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + if (previous_bcf_model$model_params$sample_sigma2_leaf_mu) { + previous_leaf_var_mu_samples <- previous_bcf_model$sigma2_leaf_mu_samples } else { - a_forest <- 1. - b_forest <- 1. - } - - # Variable weight preprocessing (and initialization if necessary) - if (is.null(variable_weights)) { - variable_weights = rep(1 / ncol(X_train), ncol(X_train)) - } - if (any(variable_weights < 0)) { - stop("variable_weights cannot have any negative weights") - } - - # Check covariates are matrix or dataframe - if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { - stop("X_train must be a matrix or dataframe") + previous_leaf_var_mu_samples <- NULL } - if (!is.null(X_test)) { - if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { - stop("X_test must be a matrix or dataframe") - } - } - num_cov_orig <- ncol(X_train) - - # Check delta_max is valid - if ((delta_max <= 0) || (delta_max >= 1)) { - stop("delta_max must be > 0 and < 1") - } - - # Standardize the keep variable lists to numeric indices - if (!is.null(keep_vars_mu)) { - if (is.character(keep_vars_mu)) { - if (!all(keep_vars_mu %in% names(X_train))) { - stop( - "keep_vars_mu includes some variable names that are not in X_train" - ) - } - variable_subset_mu <- unname(which( - names(X_train) %in% keep_vars_mu - )) - } else { - if (any(keep_vars_mu > ncol(X_train))) { - stop( - "keep_vars_mu includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_mu < 0)) { - stop("keep_vars_mu includes some negative variable indices") - } - variable_subset_mu <- keep_vars_mu - } - } else if ((is.null(keep_vars_mu)) && (!is.null(drop_vars_mu))) { - if (is.character(drop_vars_mu)) { - if (!all(drop_vars_mu %in% names(X_train))) { - stop( - "drop_vars_mu includes some variable names that are not in X_train" - ) - } - variable_subset_mu <- unname(which( - !(names(X_train) %in% drop_vars_mu) - )) - } else { - if (any(drop_vars_mu > ncol(X_train))) { - stop( - "drop_vars_mu includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_mu < 0)) { - stop("drop_vars_mu includes some negative variable indices") - } - variable_subset_mu <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_mu) - ] - } + if (previous_bcf_model$model_params$sample_sigma2_leaf_tau) { + previous_leaf_var_tau_samples <- previous_bcf_model$sigma2_leaf_tau_samples } else { - variable_subset_mu <- 1:ncol(X_train) + previous_leaf_var_tau_samples <- NULL } - if (!is.null(keep_vars_tau)) { - if (is.character(keep_vars_tau)) { - if (!all(keep_vars_tau %in% names(X_train))) { - stop( - "keep_vars_tau includes some variable names that are not in X_train" - ) - } - variable_subset_tau <- unname(which( - names(X_train) %in% keep_vars_tau - )) - } else { - if (any(keep_vars_tau > ncol(X_train))) { - stop( - "keep_vars_tau includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_tau < 0)) { - stop("keep_vars_tau includes some negative variable indices") - } - variable_subset_tau <- keep_vars_tau - } - } else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) { - if (is.character(drop_vars_tau)) { - if (!all(drop_vars_tau %in% names(X_train))) { - stop( - "drop_vars_tau includes some variable names that are not in X_train" - ) - } - variable_subset_tau <- unname(which( - !(names(X_train) %in% drop_vars_tau) - )) - } else { - if (any(drop_vars_tau > ncol(X_train))) { - stop( - "drop_vars_tau includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_tau < 0)) { - stop("drop_vars_tau includes some negative variable indices") - } - variable_subset_tau <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_tau) - ] - } + if (previous_bcf_model$model_params$has_rfx) { + previous_rfx_samples <- previous_bcf_model$rfx_samples } else { - variable_subset_tau <- 1:ncol(X_train) + previous_rfx_samples <- NULL } - if (!is.null(keep_vars_variance)) { - if (is.character(keep_vars_variance)) { - if (!all(keep_vars_variance %in% names(X_train))) { - stop( - "keep_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - names(X_train) %in% keep_vars_variance - )) - } else { - if (any(keep_vars_variance > ncol(X_train))) { - stop( - "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(keep_vars_variance < 0)) { - stop( - "keep_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- keep_vars_variance - } - } else if ( - (is.null(keep_vars_variance)) && (!is.null(drop_vars_variance)) - ) { - if (is.character(drop_vars_variance)) { - if (!all(drop_vars_variance %in% names(X_train))) { - stop( - "drop_vars_variance includes some variable names that are not in X_train" - ) - } - variable_subset_variance <- unname(which( - !(names(X_train) %in% drop_vars_variance) - )) - } else { - if (any(drop_vars_variance > ncol(X_train))) { - stop( - "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" - ) - } - if (any(drop_vars_variance < 0)) { - stop( - "drop_vars_variance includes some negative variable indices" - ) - } - variable_subset_variance <- (1:ncol(X_train))[ - !(1:ncol(X_train) %in% drop_vars_variance) - ] - } + if (previous_bcf_model$model_params$adaptive_coding) { + previous_b_1_samples <- previous_bcf_model$b_1_samples + previous_b_0_samples <- previous_bcf_model$b_0_samples } else { - variable_subset_variance <- 1:ncol(X_train) - } - - # Preprocess covariates - if (ncol(X_train) != length(variable_weights)) { - stop("length(variable_weights) must equal ncol(X_train)") - } - train_cov_preprocess_list <- preprocessTrainData(X_train) - X_train_metadata <- train_cov_preprocess_list$metadata - X_train_raw <- X_train - X_train <- train_cov_preprocess_list$data - original_var_indices <- X_train_metadata$original_var_indices - feature_types <- X_train_metadata$feature_types - X_test_raw <- X_test - if (!is.null(X_test)) { - X_test <- preprocessPredictionData(X_test, X_train_metadata) - } - - # Convert all input data to matrices if not already converted - Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train)) - Z_train <- matrix(as.numeric(Z_train), ncol = Z_col) - if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { - propensity_train <- as.matrix(propensity_train) - } - if (!is.null(Z_test)) { - Z_test <- matrix(as.numeric(Z_test), ncol = Z_col) - } - if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { - propensity_test <- as.matrix(propensity_test) - } - if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { - rfx_basis_train <- as.matrix(rfx_basis_train) + previous_b_1_samples <- NULL + previous_b_0_samples <- NULL + } + previous_model_num_samples <- previous_bcf_model$model_params$num_samples + if (previous_model_warmstart_sample_num > previous_model_num_samples) { + stop( + "`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`" + ) + } + } else { + previous_y_bar <- NULL + previous_y_scale <- NULL + previous_global_var_samples <- NULL + previous_leaf_var_mu_samples <- NULL + previous_leaf_var_tau_samples <- NULL + previous_rfx_samples <- NULL + previous_forest_samples_mu <- NULL + previous_forest_samples_tau <- NULL + previous_forest_samples_variance <- NULL + previous_b_1_samples <- NULL + previous_b_0_samples <- NULL + } + + # Determine whether conditional variance will be modeled + if (num_trees_variance > 0) { + include_variance_forest = TRUE + } else { + include_variance_forest = FALSE + } + + # Set the variance forest priors if not set + if (include_variance_forest) { + if (is.null(a_forest)) { + a_forest <- num_trees_variance / (a_0^2) + 0.5 + } + if (is.null(b_forest)) b_forest <- num_trees_variance / (a_0^2) + } else { + a_forest <- 1. + b_forest <- 1. + } + + # Variable weight preprocessing (and initialization if necessary) + if (is.null(variable_weights)) { + variable_weights = rep(1 / ncol(X_train), ncol(X_train)) + } + if (any(variable_weights < 0)) { + stop("variable_weights cannot have any negative weights") + } + + # Check covariates are matrix or dataframe + if ((!is.data.frame(X_train)) && (!is.matrix(X_train))) { + stop("X_train must be a matrix or dataframe") + } + if (!is.null(X_test)) { + if ((!is.data.frame(X_test)) && (!is.matrix(X_test))) { + stop("X_test must be a matrix or dataframe") + } + } + num_cov_orig <- ncol(X_train) + + # Check delta_max is valid + if ((delta_max <= 0) || (delta_max >= 1)) { + stop("delta_max must be > 0 and < 1") + } + + # Standardize the keep variable lists to numeric indices + if (!is.null(keep_vars_mu)) { + if (is.character(keep_vars_mu)) { + if (!all(keep_vars_mu %in% names(X_train))) { + stop( + "keep_vars_mu includes some variable names that are not in X_train" + ) + } + variable_subset_mu <- unname(which( + names(X_train) %in% keep_vars_mu + )) + } else { + if (any(keep_vars_mu > ncol(X_train))) { + stop( + "keep_vars_mu includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_mu < 0)) { + stop("keep_vars_mu includes some negative variable indices") + } + variable_subset_mu <- keep_vars_mu + } + } else if ((is.null(keep_vars_mu)) && (!is.null(drop_vars_mu))) { + if (is.character(drop_vars_mu)) { + if (!all(drop_vars_mu %in% names(X_train))) { + stop( + "drop_vars_mu includes some variable names that are not in X_train" + ) + } + variable_subset_mu <- unname(which( + !(names(X_train) %in% drop_vars_mu) + )) + } else { + if (any(drop_vars_mu > ncol(X_train))) { + stop( + "drop_vars_mu includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_mu < 0)) { + stop("drop_vars_mu includes some negative variable indices") + } + variable_subset_mu <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_mu) + ] + } + } else { + variable_subset_mu <- 1:ncol(X_train) + } + if (!is.null(keep_vars_tau)) { + if (is.character(keep_vars_tau)) { + if (!all(keep_vars_tau %in% names(X_train))) { + stop( + "keep_vars_tau includes some variable names that are not in X_train" + ) + } + variable_subset_tau <- unname(which( + names(X_train) %in% keep_vars_tau + )) + } else { + if (any(keep_vars_tau > ncol(X_train))) { + stop( + "keep_vars_tau includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_tau < 0)) { + stop("keep_vars_tau includes some negative variable indices") + } + variable_subset_tau <- keep_vars_tau + } + } else if ((is.null(keep_vars_tau)) && (!is.null(drop_vars_tau))) { + if (is.character(drop_vars_tau)) { + if (!all(drop_vars_tau %in% names(X_train))) { + stop( + "drop_vars_tau includes some variable names that are not in X_train" + ) + } + variable_subset_tau <- unname(which( + !(names(X_train) %in% drop_vars_tau) + )) + } else { + if (any(drop_vars_tau > ncol(X_train))) { + stop( + "drop_vars_tau includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_tau < 0)) { + stop("drop_vars_tau includes some negative variable indices") + } + variable_subset_tau <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_tau) + ] + } + } else { + variable_subset_tau <- 1:ncol(X_train) + } + if (!is.null(keep_vars_variance)) { + if (is.character(keep_vars_variance)) { + if (!all(keep_vars_variance %in% names(X_train))) { + stop( + "keep_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + names(X_train) %in% keep_vars_variance + )) + } else { + if (any(keep_vars_variance > ncol(X_train))) { + stop( + "keep_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(keep_vars_variance < 0)) { + stop( + "keep_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- keep_vars_variance } - if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { - rfx_basis_test <- as.matrix(rfx_basis_test) + } else if ((is.null(keep_vars_variance)) && (!is.null(drop_vars_variance))) { + if (is.character(drop_vars_variance)) { + if (!all(drop_vars_variance %in% names(X_train))) { + stop( + "drop_vars_variance includes some variable names that are not in X_train" + ) + } + variable_subset_variance <- unname(which( + !(names(X_train) %in% drop_vars_variance) + )) + } else { + if (any(drop_vars_variance > ncol(X_train))) { + stop( + "drop_vars_variance includes some variable indices that exceed the number of columns in X_train" + ) + } + if (any(drop_vars_variance < 0)) { + stop( + "drop_vars_variance includes some negative variable indices" + ) + } + variable_subset_variance <- (1:ncol(X_train))[ + !(1:ncol(X_train) %in% drop_vars_variance) + ] + } + } else { + variable_subset_variance <- 1:ncol(X_train) + } + + # Preprocess covariates + if (ncol(X_train) != length(variable_weights)) { + stop("length(variable_weights) must equal ncol(X_train)") + } + train_cov_preprocess_list <- preprocessTrainData(X_train) + X_train_metadata <- train_cov_preprocess_list$metadata + X_train_raw <- X_train + X_train <- train_cov_preprocess_list$data + original_var_indices <- X_train_metadata$original_var_indices + feature_types <- X_train_metadata$feature_types + X_test_raw <- X_test + if (!is.null(X_test)) { + X_test <- preprocessPredictionData(X_test, X_train_metadata) + } + + # Convert all input data to matrices if not already converted + Z_col <- ifelse(is.null(dim(Z_train)), 1, ncol(Z_train)) + Z_train <- matrix(as.numeric(Z_train), ncol = Z_col) + if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { + propensity_train <- as.matrix(propensity_train) + } + if (!is.null(Z_test)) { + Z_test <- matrix(as.numeric(Z_test), ncol = Z_col) + } + if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- as.matrix(propensity_test) + } + if ((is.null(dim(rfx_basis_train))) && (!is.null(rfx_basis_train))) { + rfx_basis_train <- as.matrix(rfx_basis_train) + } + if ((is.null(dim(rfx_basis_test))) && (!is.null(rfx_basis_test))) { + rfx_basis_test <- as.matrix(rfx_basis_test) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) + has_rfx <- FALSE + has_rfx_test <- FALSE + if (!is.null(rfx_group_ids_train)) { + group_ids_factor <- factor(rfx_group_ids_train) + rfx_group_ids_train <- as.integer(group_ids_factor) + has_rfx <- TRUE + if (!is.null(rfx_group_ids_test)) { + group_ids_factor_test <- factor( + rfx_group_ids_test, + levels = levels(group_ids_factor) + ) + if (sum(is.na(group_ids_factor_test)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" + ) + } + rfx_group_ids_test <- as.integer(group_ids_factor_test) + has_rfx_test <- TRUE + } + } + + # Check that outcome and treatment are numeric + if (!is.numeric(y_train)) { + stop("y_train must be numeric") + } + if (!is.numeric(Z_train)) { + stop("Z_train must be numeric") + } + if (!is.null(Z_test)) { + if (!is.numeric(Z_test)) stop("Z_test must be numeric") + } + + # Data consistency checks + if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { + stop("X_train and X_test must have the same number of columns") + } + if ((!is.null(Z_test)) && (ncol(Z_test) != ncol(Z_train))) { + stop("Z_train and Z_test must have the same number of columns") + } + if ((!is.null(Z_train)) && (nrow(Z_train) != nrow(X_train))) { + stop("Z_train and X_train must have the same number of rows") + } + if ( + (!is.null(propensity_train)) && + (nrow(propensity_train) != nrow(X_train)) + ) { + stop("propensity_train and X_train must have the same number of rows") + } + if ((!is.null(Z_test)) && (nrow(Z_test) != nrow(X_test))) { + stop("Z_test and X_test must have the same number of rows") + } + if ((!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test))) { + stop("propensity_test and X_test must have the same number of rows") + } + if (nrow(X_train) != length(y_train)) { + stop("X_train and y_train must have the same number of observations") + } + if ( + (!is.null(rfx_basis_test)) && + (ncol(rfx_basis_test) != ncol(rfx_basis_train)) + ) { + stop( + "rfx_basis_train and rfx_basis_test must have the same number of columns" + ) + } + if (!is.null(rfx_group_ids_train)) { + if (!is.null(rfx_group_ids_test)) { + if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { + stop( + "rfx_basis_train is provided but rfx_basis_test is not provided" + ) + } } + } - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE - has_rfx_test <- FALSE - if (!is.null(rfx_group_ids_train)) { - group_ids_factor <- factor(rfx_group_ids_train) - rfx_group_ids_train <- as.integer(group_ids_factor) - has_rfx <- TRUE - if (!is.null(rfx_group_ids_test)) { - group_ids_factor_test <- factor( - rfx_group_ids_test, - levels = levels(group_ids_factor) - ) - if (sum(is.na(group_ids_factor_test)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" - ) - } - rfx_group_ids_test <- as.integer(group_ids_factor_test) - has_rfx_test <- TRUE - } - } + # # Stop if multivariate treatment is provided + # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported") - # Check that outcome and treatment are numeric - if (!is.numeric(y_train)) { - stop("y_train must be numeric") - } - if (!is.numeric(Z_train)) { - stop("Z_train must be numeric") + # Handle multivariate treatment + has_multivariate_treatment <- ncol(Z_train) > 1 + if (has_multivariate_treatment) { + # Disable adaptive coding, internal propensity model, and + # leaf scale sampling if treatment is multivariate + if (adaptive_coding) { + warning( + "Adaptive coding is incompatible with multivariate treatment and will be ignored" + ) + adaptive_coding <- FALSE + } + if (is.null(propensity_train)) { + if (propensity_covariate != "none") { + warning( + "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" + ) + propensity_covariate <- "none" + } } - if (!is.null(Z_test)) { - if (!is.numeric(Z_test)) stop("Z_test must be numeric") + if (sample_sigma2_leaf_tau) { + warning( + "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." + ) + sample_sigma2_leaf_tau <- FALSE } + } - # Data consistency checks - if ((!is.null(X_test)) && (ncol(X_test) != ncol(X_train))) { - stop("X_train and X_test must have the same number of columns") - } - if ((!is.null(Z_test)) && (ncol(Z_test) != ncol(Z_train))) { - stop("Z_train and Z_test must have the same number of columns") - } - if ((!is.null(Z_train)) && (nrow(Z_train) != nrow(X_train))) { - stop("Z_train and X_train must have the same number of rows") - } - if ( - (!is.null(propensity_train)) && - (nrow(propensity_train) != nrow(X_train)) - ) { - stop("propensity_train and X_train must have the same number of rows") - } - if ((!is.null(Z_test)) && (nrow(Z_test) != nrow(X_test))) { - stop("Z_test and X_test must have the same number of rows") - } - if ( - (!is.null(propensity_test)) && (nrow(propensity_test) != nrow(X_test)) - ) { - stop("propensity_test and X_test must have the same number of rows") - } - if (nrow(X_train) != length(y_train)) { - stop("X_train and y_train must have the same number of observations") - } - if ( - (!is.null(rfx_basis_test)) && - (ncol(rfx_basis_test) != ncol(rfx_basis_train)) - ) { + # Random effects covariance prior + if (has_rfx) { + if (is.null(rfx_prior_var)) { + rfx_prior_var <- rep(1, ncol(rfx_basis_train)) + } else { + if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { + stop("rfx_prior_var must be a numeric vector") + } + if (length(rfx_prior_var) != ncol(rfx_basis_train)) { + stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") + } + } + } + + # Update variable weights + variable_weights_adj <- 1 / + sapply(original_var_indices, function(x) sum(original_var_indices == x)) + variable_weights <- variable_weights[original_var_indices] * + variable_weights_adj + + # Create mu and tau (and variance) specific variable weights with weights zeroed out for excluded variables + variable_weights_variance <- variable_weights_tau <- variable_weights_mu <- variable_weights + variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0 + variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0 + if (include_variance_forest) { + variable_weights_variance[ + !(original_var_indices %in% variable_subset_variance) + ] <- 0 + } + + # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + has_basis_rfx <- FALSE + num_basis_rfx <- 0 + if (has_rfx) { + if (is.null(rfx_basis_train)) { + rfx_basis_train <- matrix( + rep(1, nrow(X_train)), + nrow = nrow(X_train), + ncol = 1 + ) + } else { + has_basis_rfx <- TRUE + num_basis_rfx <- ncol(rfx_basis_train) + } + num_rfx_groups <- length(unique(rfx_group_ids_train)) + num_rfx_components <- ncol(rfx_basis_train) + if (num_rfx_groups == 1) { + warning( + "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" + ) + } + } + if (has_rfx_test) { + if (is.null(rfx_basis_test)) { + if (!is.null(rfx_basis_train)) { stop( - "rfx_basis_train and rfx_basis_test must have the same number of columns" + "Random effects basis provided for training set, must also be provided for the test set" ) + } + rfx_basis_test <- matrix( + rep(1, nrow(X_test)), + nrow = nrow(X_test), + ncol = 1 + ) + } + } + + # Check that number of samples are all nonnegative + stopifnot(num_gfr >= 0) + stopifnot(num_burnin >= 0) + stopifnot(num_mcmc >= 0) + + # Determine whether a test set is provided + has_test = !is.null(X_test) + + # Convert y_train to numeric vector if not already converted + if (!is.null(dim(y_train))) { + y_train <- as.matrix(y_train) + } + + # Check whether treatment is binary (specifically 0-1 binary) + binary_treatment <- length(unique(Z_train)) == 2 + if (binary_treatment) { + unique_treatments <- sort(unique(Z_train)) + if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + } + + # Adaptive coding will be ignored for continuous / ordered categorical treatments + if ((!binary_treatment) && (adaptive_coding)) { + adaptive_coding <- FALSE + } + + # Check if propensity_covariate is one of the required inputs + if (!(propensity_covariate %in% c("mu", "tau", "both", "none"))) { + stop( + "propensity_covariate must equal one of 'none', 'mu', 'tau', or 'both'" + ) + } + + # Estimate if pre-estimated propensity score is not provided + internal_propensity_model <- FALSE + if ((is.null(propensity_train)) && (propensity_covariate != "none")) { + internal_propensity_model <- TRUE + # Estimate using the last of several iterations of GFR BART + num_burnin <- 10 + num_total <- 50 + bart_model_propensity <- bart( + X_train = X_train, + y_train = as.numeric(Z_train), + X_test = X_test_raw, + num_gfr = num_total, + num_burnin = 0, + num_mcmc = 0 + ) + propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, + (num_burnin + 1):num_total + ]) + if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { + propensity_train <- as.matrix(propensity_train) } - if (!is.null(rfx_group_ids_train)) { - if (!is.null(rfx_group_ids_test)) { - if ((!is.null(rfx_basis_train)) && (is.null(rfx_basis_test))) { - stop( - "rfx_basis_train is provided but rfx_basis_test is not provided" - ) - } - } + if (has_test) { + propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, + (num_burnin + 1):num_total + ]) + if ((is.null(dim(propensity_test))) && (!is.null(propensity_test))) { + propensity_test <- as.matrix(propensity_test) + } } + } - # # Stop if multivariate treatment is provided - # if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported") - - # Handle multivariate treatment - has_multivariate_treatment <- ncol(Z_train) > 1 - if (has_multivariate_treatment) { - # Disable adaptive coding, internal propensity model, and - # leaf scale sampling if treatment is multivariate - if (adaptive_coding) { - warning( - "Adaptive coding is incompatible with multivariate treatment and will be ignored" - ) - adaptive_coding <- FALSE - } - if (is.null(propensity_train)) { - if (propensity_covariate != "none") { - warning( - "No propensities were provided for the multivariate treatment; an internal propensity model will not be fitted to the multivariate treatment and propensity_covariate will be set to 'none'" - ) - propensity_covariate <- "none" - } - } - if (sample_sigma2_leaf_tau) { - warning( - "Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model." - ) - sample_sigma2_leaf_tau <- FALSE - } + if (has_test) { + if (is.null(propensity_test)) { + stop( + "Propensity score must be provided for the test set if provided for the training set" + ) } + } - # Random effects covariance prior - if (has_rfx) { - if (is.null(rfx_prior_var)) { - rfx_prior_var <- rep(1, ncol(rfx_basis_train)) - } else { - if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { - stop("rfx_prior_var must be a numeric vector") - } - if (length(rfx_prior_var) != ncol(rfx_basis_train)) { - stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") - } - } + # Update feature_types and covariates + feature_types <- as.integer(feature_types) + if (propensity_covariate != "none") { + feature_types <- as.integer(c( + feature_types, + rep(0, ncol(propensity_train)) + )) + X_train <- cbind(X_train, propensity_train) + if (propensity_covariate == "mu") { + variable_weights_mu <- c( + variable_weights_mu, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(0, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } + } else if (propensity_covariate == "tau") { + variable_weights_mu <- c( + variable_weights_mu, + rep(0, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } + } else if (propensity_covariate == "both") { + variable_weights_mu <- c( + variable_weights_mu, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + variable_weights_tau <- c( + variable_weights_tau, + rep(1. / num_cov_orig, ncol(propensity_train)) + ) + if (include_variance_forest) { + variable_weights_variance <- c( + variable_weights_variance, + rep(0, ncol(propensity_train)) + ) + } + } + if (has_test) X_test <- cbind(X_test, propensity_test) + } + + # Renormalize variable weights + variable_weights_mu <- variable_weights_mu / sum(variable_weights_mu) + variable_weights_tau <- variable_weights_tau / sum(variable_weights_tau) + if (include_variance_forest) { + variable_weights_variance <- variable_weights_variance / + sum(variable_weights_variance) + } + + # Set num_features_subsample to default, ncol(X_train), if not already set + if (is.null(num_features_subsample_mu)) { + num_features_subsample_mu <- ncol(X_train) + } + if (is.null(num_features_subsample_tau)) { + num_features_subsample_tau <- ncol(X_train) + } + if (is.null(num_features_subsample_variance)) { + num_features_subsample_variance <- ncol(X_train) + } + + # Preliminary runtime checks for probit link + if (probit_outcome_model) { + if (!(length(unique(y_train)) == 2)) { + stop( + "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" + ) + } + unique_outcomes <- sort(unique(y_train)) + if (!(all(unique_outcomes == c(0, 1)))) { + stop( + "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" + ) } - - # Update variable weights - variable_weights_adj <- 1 / - sapply(original_var_indices, function(x) sum(original_var_indices == x)) - variable_weights <- variable_weights[original_var_indices] * - variable_weights_adj - - # Create mu and tau (and variance) specific variable weights with weights zeroed out for excluded variables - variable_weights_variance <- variable_weights_tau <- variable_weights_mu <- variable_weights - variable_weights_mu[!(original_var_indices %in% variable_subset_mu)] <- 0 - variable_weights_tau[!(original_var_indices %in% variable_subset_tau)] <- 0 if (include_variance_forest) { - variable_weights_variance[ - !(original_var_indices %in% variable_subset_variance) - ] <- 0 + stop("We do not support heteroskedasticity with a probit link") } - - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx <- FALSE - num_basis_rfx <- 0 - if (has_rfx) { - if (is.null(rfx_basis_train)) { - rfx_basis_train <- matrix( - rep(1, nrow(X_train)), - nrow = nrow(X_train), - ncol = 1 - ) - } else { - has_basis_rfx <- TRUE - num_basis_rfx <- ncol(rfx_basis_train) - } - num_rfx_groups <- length(unique(rfx_group_ids_train)) - num_rfx_components <- ncol(rfx_basis_train) - if (num_rfx_groups == 1) { - warning( - "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" - ) + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled with a probit link as it is fixed at 1" + ) + sample_sigma2_global <- F + } + } + + # Handle standardization, prior calibration, and initialization of forest + # differently for binary and continuous outcomes + if (probit_outcome_model) { + # Compute a probit-scale offset and fix scale to 1 + y_bar_train <- qnorm(mean(y_train)) + y_std_train <- 1 + + # Set a pseudo outcome by subtracting mean(y_train) from y_train + resid_train <- y_train - mean(y_train) + + # Set initial value for the mu forest + init_mu <- 0.0 + + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau + # Set sigma2_init to 1, ignoring any defaults provided + sigma2_init <- 1.0 + # Skip variance_forest_init, since variance forests are not supported with probit link + if (is.null(b_leaf_mu)) { + b_leaf_mu <- 1 / num_trees_mu + } + if (is.null(b_leaf_tau)) { + b_leaf_tau <- 1 / (2 * num_trees_tau) + } + if (is.null(sigma2_leaf_mu)) { + sigma2_leaf_mu <- 2 / (num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + if (!is.matrix(sigma2_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + current_leaf_scale_mu <- sigma2_leaf_mu + } + } + if (is.null(sigma2_leaf_tau)) { + # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p + # Use p = 0.9 as an internal default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overriden by specifying `sigma2_leaf_init` in + # treatment_effect_forest_params. + p <- 0.6827 + q_quantile <- qnorm((p + 1) / 2) + sigma2_leaf_tau <- ((delta_max / (q_quantile * dnorm(0)))^2) / + num_trees_tau + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (!is.matrix(sigma2_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) } - } - if (has_rfx_test) { - if (is.null(rfx_basis_test)) { - if (!is.null(rfx_basis_train)) { - stop( - "Random effects basis provided for training set, must also be provided for the test set" - ) - } - rfx_basis_test <- matrix( - rep(1, nrow(X_test)), - nrow = nrow(X_test), - ncol = 1 - ) + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) } + current_leaf_scale_tau <- sigma2_leaf_tau + } + } + current_sigma2 <- sigma2_init + } else { + # Only standardize if user requested + if (standardize) { + y_bar_train <- mean(y_train) + y_std_train <- sd(y_train) + } else { + y_bar_train <- 0 + y_std_train <- 1 } - # Check that number of samples are all nonnegative - stopifnot(num_gfr >= 0) - stopifnot(num_burnin >= 0) - stopifnot(num_mcmc >= 0) - - # Determine whether a test set is provided - has_test = !is.null(X_test) + # Compute standardized outcome + resid_train <- (y_train - y_bar_train) / y_std_train - # Convert y_train to numeric vector if not already converted - if (!is.null(dim(y_train))) { - y_train <- as.matrix(y_train) - } + # Set initial value for the mu forest + init_mu <- mean(resid_train) - # Check whether treatment is binary (specifically 0-1 binary) - binary_treatment <- length(unique(Z_train)) == 2 - if (binary_treatment) { - unique_treatments <- sort(unique(Z_train)) - if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE + # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau + if (is.null(sigma2_init)) { + sigma2_init <- 1.0 * var(resid_train) } - - # Adaptive coding will be ignored for continuous / ordered categorical treatments - if ((!binary_treatment) && (adaptive_coding)) { - adaptive_coding <- FALSE + if (is.null(variance_forest_init)) { + variance_forest_init <- 1.0 * var(resid_train) } - - # Check if propensity_covariate is one of the required inputs - if (!(propensity_covariate %in% c("mu", "tau", "both", "none"))) { - stop( - "propensity_covariate must equal one of 'none', 'mu', 'tau', or 'both'" - ) - } - - # Estimate if pre-estimated propensity score is not provided - internal_propensity_model <- FALSE - if ((is.null(propensity_train)) && (propensity_covariate != "none")) { - internal_propensity_model <- TRUE - # Estimate using the last of several iterations of GFR BART - num_burnin <- 10 - num_total <- 50 - bart_model_propensity <- bart( - X_train = X_train, - y_train = as.numeric(Z_train), - X_test = X_test_raw, - num_gfr = num_total, - num_burnin = 0, - num_mcmc = 0 - ) - propensity_train <- rowMeans(bart_model_propensity$y_hat_train[, - (num_burnin + 1):num_total - ]) - if ((is.null(dim(propensity_train))) && (!is.null(propensity_train))) { - propensity_train <- as.matrix(propensity_train) - } - if (has_test) { - propensity_test <- rowMeans(bart_model_propensity$y_hat_test[, - (num_burnin + 1):num_total - ]) - if ( - (is.null(dim(propensity_test))) && (!is.null(propensity_test)) - ) { - propensity_test <- as.matrix(propensity_test) - } - } + if (is.null(b_leaf_mu)) { + b_leaf_mu <- var(resid_train) / (num_trees_mu) } - - if (has_test) { - if (is.null(propensity_test)) { - stop( - "Propensity score must be provided for the test set if provided for the training set" - ) - } + if (is.null(b_leaf_tau)) { + b_leaf_tau <- var(resid_train) / (2 * num_trees_tau) } - - # Update feature_types and covariates - feature_types <- as.integer(feature_types) - if (propensity_covariate != "none") { - feature_types <- as.integer(c( - feature_types, - rep(0, ncol(propensity_train)) + if (is.null(sigma2_leaf_mu)) { + sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu) + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + if (!is.matrix(sigma2_leaf_mu)) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + } else { + current_leaf_scale_mu <- sigma2_leaf_mu + } + } + if (is.null(sigma2_leaf_tau)) { + sigma2_leaf_tau <- var(resid_train) / (num_trees_tau) + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) + )) + } else { + if (!is.matrix(sigma2_leaf_tau)) { + current_leaf_scale_tau <- as.matrix(diag( + sigma2_leaf_tau, + ncol(Z_train) )) - X_train <- cbind(X_train, propensity_train) - if (propensity_covariate == "mu") { - variable_weights_mu <- c( - variable_weights_mu, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - variable_weights_tau <- c( - variable_weights_tau, - rep(0, ncol(propensity_train)) - ) - if (include_variance_forest) { - variable_weights_variance <- c( - variable_weights_variance, - rep(0, ncol(propensity_train)) - ) - } - } else if (propensity_covariate == "tau") { - variable_weights_mu <- c( - variable_weights_mu, - rep(0, ncol(propensity_train)) - ) - variable_weights_tau <- c( - variable_weights_tau, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - if (include_variance_forest) { - variable_weights_variance <- c( - variable_weights_variance, - rep(0, ncol(propensity_train)) - ) - } - } else if (propensity_covariate == "both") { - variable_weights_mu <- c( - variable_weights_mu, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - variable_weights_tau <- c( - variable_weights_tau, - rep(1. / num_cov_orig, ncol(propensity_train)) - ) - if (include_variance_forest) { - variable_weights_variance <- c( - variable_weights_variance, - rep(0, ncol(propensity_train)) - ) - } + } else { + if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) } - if (has_test) X_test <- cbind(X_test, propensity_test) - } - - # Renormalize variable weights - variable_weights_mu <- variable_weights_mu / sum(variable_weights_mu) - variable_weights_tau <- variable_weights_tau / sum(variable_weights_tau) - if (include_variance_forest) { - variable_weights_variance <- variable_weights_variance / - sum(variable_weights_variance) - } - - # Set num_features_subsample to default, ncol(X_train), if not already set - if (is.null(num_features_subsample_mu)) { - num_features_subsample_mu <- ncol(X_train) - } - if (is.null(num_features_subsample_tau)) { - num_features_subsample_tau <- ncol(X_train) - } - if (is.null(num_features_subsample_variance)) { - num_features_subsample_variance <- ncol(X_train) - } - - # Preliminary runtime checks for probit link - if (probit_outcome_model) { - if (!(length(unique(y_train)) == 2)) { - stop( - "You specified a probit outcome model, but supplied an outcome with more than 2 unique values" - ) - } - unique_outcomes <- sort(unique(y_train)) - if (!(all(unique_outcomes == c(0, 1)))) { - stop( - "You specified a probit outcome model, but supplied an outcome with 2 unique values other than 0 and 1" - ) - } - if (include_variance_forest) { - stop("We do not support heteroskedasticity with a probit link") - } - if (sample_sigma2_global) { - warning( - "Global error variance will not be sampled with a probit link as it is fixed at 1" - ) - sample_sigma2_global <- F + if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { + stop( + "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" + ) } - } - - # Handle standardization, prior calibration, and initialization of forest - # differently for binary and continuous outcomes - if (probit_outcome_model) { - # Compute a probit-scale offset and fix scale to 1 - y_bar_train <- qnorm(mean(y_train)) - y_std_train <- 1 - - # Set a pseudo outcome by subtracting mean(y_train) from y_train - resid_train <- y_train - mean(y_train) - - # Set initial value for the mu forest - init_mu <- 0.0 - - # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau - # Set sigma2_init to 1, ignoring any defaults provided - sigma2_init <- 1.0 - # Skip variance_forest_init, since variance forests are not supported with probit link - if (is.null(b_leaf_mu)) { - b_leaf_mu <- 1 / num_trees_mu - } - if (is.null(b_leaf_tau)) { - b_leaf_tau <- 1 / (2 * num_trees_tau) - } - if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2 / (num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - if (!is.matrix(sigma2_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - current_leaf_scale_mu <- sigma2_leaf_mu - } - } - if (is.null(sigma2_leaf_tau)) { - # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p - # Use p = 0.9 as an internal default rather than adding another - # user-facing "parameter" of the binary outcome BCF prior. - # Can be overriden by specifying `sigma2_leaf_init` in - # treatment_effect_forest_params. - p <- 0.6827 - q_quantile <- qnorm((p + 1) / 2) - sigma2_leaf_tau <- ((delta_max / (q_quantile * dnorm(0)))^2) / - num_trees_tau - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - current_leaf_scale_tau <- sigma2_leaf_tau - } - } - current_sigma2 <- sigma2_init + current_leaf_scale_tau <- sigma2_leaf_tau + } + } + current_sigma2 <- sigma2_init + } + + # Set mu and tau leaf models / dimensions + leaf_model_mu_forest <- 0 + leaf_dimension_mu_forest <- 1 + if (has_multivariate_treatment) { + leaf_model_tau_forest <- 2 + leaf_dimension_tau_forest <- ncol(Z_train) + } else { + leaf_model_tau_forest <- 1 + leaf_dimension_tau_forest <- 1 + } + + # Set variance leaf model type (currently only one option) + leaf_model_variance_forest <- 3 + leaf_dimension_variance_forest <- 1 + + # Random effects prior parameters + if (has_rfx) { + # Prior parameters + if (is.null(rfx_working_parameter_prior_mean)) { + if (num_rfx_components == 1) { + alpha_init <- c(0) + } else if (num_rfx_components > 1) { + alpha_init <- rep(0, num_rfx_components) + } else { + stop("There must be at least 1 random effect component") + } } else { - # Only standardize if user requested - if (standardize) { - y_bar_train <- mean(y_train) - y_std_train <- sd(y_train) - } else { - y_bar_train <- 0 - y_std_train <- 1 - } - - # Compute standardized outcome - resid_train <- (y_train - y_bar_train) / y_std_train - - # Set initial value for the mu forest - init_mu <- mean(resid_train) - - # Calibrate priors for global sigma^2 and sigma2_leaf_mu / sigma2_leaf_tau - if (is.null(sigma2_init)) { - sigma2_init <- 1.0 * var(resid_train) - } - if (is.null(variance_forest_init)) { - variance_forest_init <- 1.0 * var(resid_train) - } - if (is.null(b_leaf_mu)) { - b_leaf_mu <- var(resid_train) / (num_trees_mu) - } - if (is.null(b_leaf_tau)) { - b_leaf_tau <- var(resid_train) / (2 * num_trees_tau) - } - if (is.null(sigma2_leaf_mu)) { - sigma2_leaf_mu <- 2.0 * var(resid_train) / (num_trees_mu) - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - if (!is.matrix(sigma2_leaf_mu)) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - } else { - current_leaf_scale_mu <- sigma2_leaf_mu - } - } - if (is.null(sigma2_leaf_tau)) { - sigma2_leaf_tau <- var(resid_train) / (num_trees_tau) - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (!is.matrix(sigma2_leaf_tau)) { - current_leaf_scale_tau <- as.matrix(diag( - sigma2_leaf_tau, - ncol(Z_train) - )) - } else { - if (ncol(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - if (nrow(sigma2_leaf_tau) != ncol(Z_train)) { - stop( - "sigma2_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix" - ) - } - current_leaf_scale_tau <- sigma2_leaf_tau - } - } - current_sigma2 <- sigma2_init - } - - # Set mu and tau leaf models / dimensions - leaf_model_mu_forest <- 0 - leaf_dimension_mu_forest <- 1 - if (has_multivariate_treatment) { - leaf_model_tau_forest <- 2 - leaf_dimension_tau_forest <- ncol(Z_train) + alpha_init <- expand_dims_1d( + rfx_working_parameter_prior_mean, + num_rfx_components + ) + } + + if (is.null(rfx_group_parameter_prior_mean)) { + xi_init <- matrix( + rep(alpha_init, num_rfx_groups), + num_rfx_components, + num_rfx_groups + ) } else { - leaf_model_tau_forest <- 1 - leaf_dimension_tau_forest <- 1 - } - - # Set variance leaf model type (currently only one option) - leaf_model_variance_forest <- 3 - leaf_dimension_variance_forest <- 1 - - # Random effects prior parameters - if (has_rfx) { - # Prior parameters - if (is.null(rfx_working_parameter_prior_mean)) { - if (num_rfx_components == 1) { - alpha_init <- c(0) - } else if (num_rfx_components > 1) { - alpha_init <- rep(0, num_rfx_components) - } else { - stop("There must be at least 1 random effect component") - } - } else { - alpha_init <- expand_dims_1d( - rfx_working_parameter_prior_mean, - num_rfx_components - ) - } - - if (is.null(rfx_group_parameter_prior_mean)) { - xi_init <- matrix( - rep(alpha_init, num_rfx_groups), - num_rfx_components, - num_rfx_groups - ) - } else { - xi_init <- expand_dims_2d( - rfx_group_parameter_prior_mean, - num_rfx_components, - num_rfx_groups - ) - } - - if (is.null(rfx_working_parameter_prior_cov)) { - sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_alpha_init <- expand_dims_2d_diag( - rfx_working_parameter_prior_cov, - num_rfx_components - ) - } - - if (is.null(rfx_group_parameter_prior_cov)) { - sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) - } else { - sigma_xi_init <- expand_dims_2d_diag( - rfx_group_parameter_prior_cov, - num_rfx_components - ) - } - - sigma_xi_shape <- rfx_variance_prior_shape - sigma_xi_scale <- rfx_variance_prior_scale - } - - # Random effects data structure and storage container - if (has_rfx) { - rfx_dataset_train <- createRandomEffectsDataset( - rfx_group_ids_train, - rfx_basis_train - ) - rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) - rfx_model <- createRandomEffectsModel( - num_rfx_components, - num_rfx_groups - ) - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) - rfx_samples <- createRandomEffectSamples( - num_rfx_components, - num_rfx_groups, - rfx_tracker_train - ) + xi_init <- expand_dims_2d( + rfx_group_parameter_prior_mean, + num_rfx_components, + num_rfx_groups + ) } - # Container of variance parameter samples - num_actual_mcmc_iter <- num_mcmc * keep_every - num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter - # Delete GFR samples from these containers after the fact if desired - # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc - num_retained_samples <- num_gfr + - ifelse(keep_burnin, num_burnin, 0) + - num_mcmc * num_chains - if (sample_sigma2_global) { - global_var_samples <- rep(NA, num_retained_samples) - } - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_samples <- rep(NA, num_retained_samples) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_samples <- rep(NA, num_retained_samples) - } - muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) - if (include_variance_forest) { - sigma2_x_train_raw <- matrix( - NA_real_, - nrow(X_train), - num_retained_samples - ) - } - sample_counter <- 0 - - # Prepare adaptive coding structure - if ( - (!is.numeric(b_0)) || - (!is.numeric(b_1)) || - (length(b_0) > 1) || - (length(b_1) > 1) - ) { - stop("b_0 and b_1 must be single numeric values") - } - if (adaptive_coding) { - b_0_samples <- rep(NA, num_retained_samples) - b_1_samples <- rep(NA, num_retained_samples) - current_b_0 <- b_0 - current_b_1 <- b_1 - tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 - if (has_test) { - tau_basis_test <- (1 - Z_test) * current_b_0 + Z_test * current_b_1 - } + if (is.null(rfx_working_parameter_prior_cov)) { + sigma_alpha_init <- diag(1, num_rfx_components, num_rfx_components) } else { - tau_basis_train <- Z_train - if (has_test) tau_basis_test <- Z_test + sigma_alpha_init <- expand_dims_2d_diag( + rfx_working_parameter_prior_cov, + num_rfx_components + ) } - # Data - forest_dataset_train <- createForestDataset(X_train, tau_basis_train) - if (has_test) { - forest_dataset_test <- createForestDataset(X_test, tau_basis_test) + if (is.null(rfx_group_parameter_prior_cov)) { + sigma_xi_init <- diag(1, num_rfx_components, num_rfx_components) + } else { + sigma_xi_init <- expand_dims_2d_diag( + rfx_group_parameter_prior_cov, + num_rfx_components + ) } - outcome_train <- createOutcome(resid_train) - # Random number generator (std::mt19937) - if (is.null(random_seed)) { - random_seed = sample(1:10000, 1, FALSE) - } - rng <- createCppRNG(random_seed) + sigma_xi_shape <- rfx_variance_prior_shape + sigma_xi_scale <- rfx_variance_prior_scale + } - # Sampling data structures - global_model_config <- createGlobalModelConfig( - global_error_variance = current_sigma2 - ) - forest_model_config_mu <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_mu, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_mu, - leaf_dimension = leaf_dimension_mu_forest, - alpha = alpha_mu, - beta = beta_mu, - min_samples_leaf = min_samples_leaf_mu, - max_depth = max_depth_mu, - leaf_model_type = leaf_model_mu_forest, - leaf_model_scale = current_leaf_scale_mu, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_mu + # Random effects data structure and storage container + if (has_rfx) { + rfx_dataset_train <- createRandomEffectsDataset( + rfx_group_ids_train, + rfx_basis_train ) - forest_model_config_tau <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_tau, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_tau, - leaf_dimension = leaf_dimension_tau_forest, - alpha = alpha_tau, - beta = beta_tau, - min_samples_leaf = min_samples_leaf_tau, - max_depth = max_depth_tau, - leaf_model_type = leaf_model_tau_forest, - leaf_model_scale = current_leaf_scale_tau, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_tau + rfx_tracker_train <- createRandomEffectsTracker(rfx_group_ids_train) + rfx_model <- createRandomEffectsModel( + num_rfx_components, + num_rfx_groups ) - forest_model_mu <- createForestModel( - forest_dataset_train, - forest_model_config_mu, - global_model_config + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_samples <- createRandomEffectSamples( + num_rfx_components, + num_rfx_groups, + rfx_tracker_train ) - forest_model_tau <- createForestModel( - forest_dataset_train, - forest_model_config_tau, - global_model_config + } + + # Container of variance parameter samples + num_actual_mcmc_iter <- num_mcmc * keep_every + num_samples <- num_gfr + num_burnin + num_actual_mcmc_iter + # Delete GFR samples from these containers after the fact if desired + # num_retained_samples <- ifelse(keep_gfr, num_gfr, 0) + ifelse(keep_burnin, num_burnin, 0) + num_mcmc + num_retained_samples <- num_gfr + + ifelse(keep_burnin, num_burnin, 0) + + num_mcmc * num_chains + if (sample_sigma2_global) { + global_var_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_samples <- rep(NA, num_retained_samples) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_samples <- rep(NA, num_retained_samples) + } + muhat_train_raw <- matrix(NA_real_, nrow(X_train), num_retained_samples) + if (include_variance_forest) { + sigma2_x_train_raw <- matrix( + NA_real_, + nrow(X_train), + num_retained_samples ) - if (include_variance_forest) { - forest_model_config_variance <- createForestModelConfig( - feature_types = feature_types, - num_trees = num_trees_variance, - num_features = ncol(X_train), - num_observations = nrow(X_train), - variable_weights = variable_weights_variance, - leaf_dimension = leaf_dimension_variance_forest, - alpha = alpha_variance, - beta = beta_variance, - min_samples_leaf = min_samples_leaf_variance, - max_depth = max_depth_variance, - leaf_model_type = leaf_model_variance_forest, - cutpoint_grid_size = cutpoint_grid_size, - num_features_subsample = num_features_subsample_variance - ) - forest_model_variance <- createForestModel( - forest_dataset_train, - forest_model_config_variance, - global_model_config - ) - } - - # Container of forest samples - forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE) - forest_samples_tau <- createForestSamples( - num_trees_tau, - ncol(Z_train), - FALSE + } + sample_counter <- 0 + + # Prepare adaptive coding structure + if ( + (!is.numeric(b_0)) || + (!is.numeric(b_1)) || + (length(b_0) > 1) || + (length(b_1) > 1) + ) { + stop("b_0 and b_1 must be single numeric values") + } + if (adaptive_coding) { + b_0_samples <- rep(NA, num_retained_samples) + b_1_samples <- rep(NA, num_retained_samples) + current_b_0 <- b_0 + current_b_1 <- b_1 + tau_basis_train <- (1 - Z_train) * current_b_0 + Z_train * current_b_1 + if (has_test) { + tau_basis_test <- (1 - Z_test) * current_b_0 + Z_test * current_b_1 + } + } else { + tau_basis_train <- Z_train + if (has_test) tau_basis_test <- Z_test + } + + # Data + forest_dataset_train <- createForestDataset(X_train, tau_basis_train) + if (has_test) { + forest_dataset_test <- createForestDataset(X_test, tau_basis_test) + } + outcome_train <- createOutcome(resid_train) + + # Random number generator (std::mt19937) + if (is.null(random_seed)) { + random_seed = sample(1:10000, 1, FALSE) + } + rng <- createCppRNG(random_seed) + + # Sampling data structures + global_model_config <- createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + forest_model_config_mu <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_mu, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_mu, + leaf_dimension = leaf_dimension_mu_forest, + alpha = alpha_mu, + beta = beta_mu, + min_samples_leaf = min_samples_leaf_mu, + max_depth = max_depth_mu, + leaf_model_type = leaf_model_mu_forest, + leaf_model_scale = current_leaf_scale_mu, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_mu + ) + forest_model_config_tau <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_tau, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_tau, + leaf_dimension = leaf_dimension_tau_forest, + alpha = alpha_tau, + beta = beta_tau, + min_samples_leaf = min_samples_leaf_tau, + max_depth = max_depth_tau, + leaf_model_type = leaf_model_tau_forest, + leaf_model_scale = current_leaf_scale_tau, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_tau + ) + forest_model_mu <- createForestModel( + forest_dataset_train, + forest_model_config_mu, + global_model_config + ) + forest_model_tau <- createForestModel( + forest_dataset_train, + forest_model_config_tau, + global_model_config + ) + if (include_variance_forest) { + forest_model_config_variance <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees_variance, + num_features = ncol(X_train), + num_observations = nrow(X_train), + variable_weights = variable_weights_variance, + leaf_dimension = leaf_dimension_variance_forest, + alpha = alpha_variance, + beta = beta_variance, + min_samples_leaf = min_samples_leaf_variance, + max_depth = max_depth_variance, + leaf_model_type = leaf_model_variance_forest, + cutpoint_grid_size = cutpoint_grid_size, + num_features_subsample = num_features_subsample_variance ) - active_forest_mu <- createForest(num_trees_mu, 1, TRUE) - active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE) - if (include_variance_forest) { - forest_samples_variance <- createForestSamples( - num_trees_variance, - 1, - TRUE, - TRUE - ) - active_forest_variance <- createForest( - num_trees_variance, - 1, - TRUE, - TRUE - ) - } - - # Initialize the leaves of each tree in the prognostic forest - active_forest_mu$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_mu, - leaf_model_mu_forest, - init_mu + forest_model_variance <- createForestModel( + forest_dataset_train, + forest_model_config_variance, + global_model_config ) - active_forest_mu$adjust_residual( - forest_dataset_train, - outcome_train, - forest_model_mu, - FALSE, - FALSE + } + + # Container of forest samples + forest_samples_mu <- createForestSamples(num_trees_mu, 1, TRUE) + forest_samples_tau <- createForestSamples( + num_trees_tau, + ncol(Z_train), + FALSE + ) + active_forest_mu <- createForest(num_trees_mu, 1, TRUE) + active_forest_tau <- createForest(num_trees_tau, ncol(Z_train), FALSE) + if (include_variance_forest) { + forest_samples_variance <- createForestSamples( + num_trees_variance, + 1, + TRUE, + TRUE ) - - # Initialize the leaves of each tree in the treatment effect forest - init_tau <- rep(0., ncol(Z_train)) - active_forest_tau$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_tau, - leaf_model_tau_forest, - init_tau + active_forest_variance <- createForest( + num_trees_variance, + 1, + TRUE, + TRUE ) - active_forest_tau$adjust_residual( - forest_dataset_train, - outcome_train, - forest_model_tau, - TRUE, - FALSE + } + + # Initialize the leaves of each tree in the prognostic forest + active_forest_mu$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_mu, + leaf_model_mu_forest, + init_mu + ) + active_forest_mu$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_mu, + FALSE, + FALSE + ) + + # Initialize the leaves of each tree in the treatment effect forest + init_tau <- rep(0., ncol(Z_train)) + active_forest_tau$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_tau, + leaf_model_tau_forest, + init_tau + ) + active_forest_tau$adjust_residual( + forest_dataset_train, + outcome_train, + forest_model_tau, + TRUE, + FALSE + ) + + # Initialize the leaves of each tree in the variance forest + if (include_variance_forest) { + active_forest_variance$prepare_for_sampler( + forest_dataset_train, + outcome_train, + forest_model_variance, + leaf_model_variance_forest, + variance_forest_init ) + } + + # Run GFR (warm start) if specified + if (num_gfr > 0) { + for (i in 1:num_gfr) { + # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC + # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) + keep_sample <- TRUE + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if ((i %% 10 == 0) || (i == num_gfr)) { + cat( + "Sampling", + i, + "out of", + num_gfr, + "XBCF (grow-from-root) draws\n" + ) + } + } - # Initialize the leaves of each tree in the variance forest - if (include_variance_forest) { - active_forest_variance$prepare_for_sampler( - forest_dataset_train, - outcome_train, - forest_model_variance, - leaf_model_variance_forest, - variance_forest_init + if (probit_outcome_model) { + # Sample latent probit variable, z | - + mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) + tau_forest_pred <- active_forest_tau$predict( + forest_dataset_train ) - } - - # Run GFR (warm start) if specified - if (num_gfr > 0) { - for (i in 1:num_gfr) { - # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC - # keep_sample <- ifelse(keep_gfr, TRUE, FALSE) - keep_sample <- TRUE - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if ((i %% 10 == 0) || (i == num_gfr)) { - cat( - "Sampling", - i, - "out of", - num_gfr, - "XBCF (grow-from-root) draws\n" - ) - } - } - - if (probit_outcome_model) { - # Sample latent probit variable, z | - - mu_forest_pred <- active_forest_mu$predict(forest_dataset_train) - tau_forest_pred <- active_forest_tau$predict( - forest_dataset_train - ) - forest_pred <- mu_forest_pred + tau_forest_pred - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } - - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mu, - active_forest = active_forest_mu, - rng = rng, - forest_model_config = forest_model_config_mu, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - muhat_train_raw[, - sample_counter - ] <- forest_model_mu$get_cached_forest_predictions() - } - - # Sample variance parameters (if requested) - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - global_model_config$update_global_error_variance(current_sigma2) - } - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- sampleLeafVarianceOneIteration( - active_forest_mu, - rng, - a_leaf_mu, - b_leaf_mu - ) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - if (keep_sample) { - leaf_scale_mu_samples[ - sample_counter - ] <- leaf_scale_mu_double - } - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_tau, - active_forest = active_forest_tau, - rng = rng, - forest_model_config = forest_model_config_tau, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cannot cache train set predictions for tau because the cached predictions in the - # tracking data structures are pre-multiplied by the basis (treatment) - # ... - - # Sample coding parameters (if requested) - if (adaptive_coding) { - # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- active_forest_mu$predict_raw( - forest_dataset_train - ) - tau_x_raw_train <- active_forest_tau$predict_raw( - forest_dataset_train - ) - partial_resid_mu_train <- resid_train - mu_x_raw_train - if (has_rfx) { - rfx_preds_train <- rfx_model$predict( - rfx_dataset_train, - rfx_tracker_train - ) - partial_resid_mu_train <- partial_resid_mu_train - - rfx_preds_train - } - - # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0)) - s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1)) - s_ty0 <- sum( - tau_x_raw_train * partial_resid_mu_train * (Z_train == 0) - ) - s_ty1 <- sum( - tau_x_raw_train * partial_resid_mu_train * (Z_train == 1) - ) - - # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) - current_b_0 <- rnorm( - 1, - (s_ty0 / (s_tt0 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) - ) - current_b_1 <- rnorm( - 1, - (s_ty1 / (s_tt1 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) - ) - - # Update basis for the leaf regression - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (keep_sample) { - b_0_samples[sample_counter] <- current_b_0 - b_1_samples[sample_counter] <- current_b_1 - } - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - - # Update leaf predictions and residual - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } + forest_pred <- mu_forest_pred + tau_forest_pred + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) + } + + # Sample the prognostic forest + forest_model_mu$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mu, + active_forest = active_forest_mu, + rng = rng, + forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + muhat_train_raw[, + sample_counter + ] <- forest_model_mu$get_cached_forest_predictions() + } + + # Sample variance parameters (if requested) + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + global_model_config$update_global_error_variance(current_sigma2) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_double <- sampleLeafVarianceOneIteration( + active_forest_mu, + rng, + a_leaf_mu, + b_leaf_mu + ) + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + if (keep_sample) { + leaf_scale_mu_samples[ + sample_counter + ] <- leaf_scale_mu_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + + # Sample the treatment forest + forest_model_tau$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_tau, + active_forest = active_forest_tau, + rng = rng, + forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... + + # Sample coding parameters (if requested) + if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) + mu_x_raw_train <- active_forest_mu$predict_raw( + forest_dataset_train + ) + tau_x_raw_train <- active_forest_tau$predict_raw( + forest_dataset_train + ) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + partial_resid_mu_train <- partial_resid_mu_train - + rfx_preds_train + } - # Sample variance parameters (if requested) - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = TRUE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - sigma2_x_train_raw[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance(current_sigma2) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- sampleLeafVarianceOneIteration( - active_forest_tau, - rng, - a_leaf_tau, - b_leaf_tau - ) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - if (keep_sample) { - leaf_scale_tau_samples[ - sample_counter - ] <- leaf_scale_tau_double - } - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] + s_tt0 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 0)) + s_tt1 <- sum(tau_x_raw_train * tau_x_raw_train * (Z_train == 1)) + s_ty0 <- sum( + tau_x_raw_train * partial_resid_mu_train * (Z_train == 0) + ) + s_ty1 <- sum( + tau_x_raw_train * partial_resid_mu_train * (Z_train == 1) + ) - # Sample random effects parameters (if requested) - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } - } - } + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + current_b_0 <- rnorm( + 1, + (s_ty0 / (s_tt0 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) + ) + current_b_1 <- rnorm( + 1, + (s_ty1 / (s_tt1 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) + ) - # Run MCMC - if (num_burnin + num_mcmc > 0) { - for (chain_num in 1:num_chains) { - if (num_gfr > 0) { - # Reset state of active_forest and forest_model based on a previous GFR sample - forest_ind <- num_gfr - chain_num - resetActiveForest( - active_forest_mu, - forest_samples_mu, - forest_ind - ) - resetForestModel( - forest_model_mu, - active_forest_mu, - forest_dataset_train, - outcome_train, - TRUE - ) - resetActiveForest( - active_forest_tau, - forest_samples_tau, - forest_ind - ) - resetForestModel( - forest_model_tau, - active_forest_tau, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- leaf_scale_mu_samples[ - forest_ind + 1 - ] - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- leaf_scale_tau_samples[ - forest_ind + 1 - ] - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - forest_samples_variance, - forest_ind - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - resetRandomEffectsModel( - rfx_model, - rfx_samples, - forest_ind, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - if (adaptive_coding) { - current_b_1 <- b_1_samples[forest_ind + 1] - current_b_0 <- b_0_samples[forest_ind + 1] - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - if (sample_sigma2_global) { - current_sigma2 <- global_var_samples[forest_ind + 1] - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } else if (has_prev_model) { - resetActiveForest( - active_forest_mu, - previous_forest_samples_mu, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_mu, - active_forest_mu, - forest_dataset_train, - outcome_train, - TRUE - ) - resetActiveForest( - active_forest_tau, - previous_forest_samples_tau, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_tau, - active_forest_tau, - forest_dataset_train, - outcome_train, - TRUE - ) - if (include_variance_forest) { - resetActiveForest( - active_forest_variance, - previous_forest_samples_variance, - previous_model_warmstart_sample_num - 1 - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if ( - sample_sigma2_leaf_mu && - (!is.null(previous_leaf_var_mu_samples)) - ) { - leaf_scale_mu_double <- previous_leaf_var_mu_samples[ - previous_model_warmstart_sample_num - ] - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - if ( - sample_sigma2_leaf_tau && - (!is.null(previous_leaf_var_tau_samples)) - ) { - leaf_scale_tau_double <- previous_leaf_var_tau_samples[ - previous_model_warmstart_sample_num - ] - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - if (adaptive_coding) { - if (!is.null(previous_b_1_samples)) { - current_b_1 <- previous_b_1_samples[ - previous_model_warmstart_sample_num - ] - } - if (!is.null(previous_b_0_samples)) { - current_b_0 <- previous_b_0_samples[ - previous_model_warmstart_sample_num - ] - } - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - if (has_rfx) { - if (is.null(previous_rfx_samples)) { - warning( - "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" - ) - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } else { - resetRandomEffectsModel( - rfx_model, - previous_rfx_samples, - previous_model_warmstart_sample_num - 1, - sigma_alpha_init - ) - resetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train, - rfx_samples - ) - } - } - if (sample_sigma2_global) { - if (!is.null(previous_global_var_samples)) { - current_sigma2 <- previous_global_var_samples[ - previous_model_warmstart_sample_num - ] - } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } else { - resetActiveForest(active_forest_mu) - active_forest_mu$set_root_leaves(init_mu / num_trees_mu) - resetForestModel( - forest_model_mu, - active_forest_mu, - forest_dataset_train, - outcome_train, - TRUE - ) - resetActiveForest(active_forest_tau) - active_forest_tau$set_root_leaves(init_tau / num_trees_tau) - resetForestModel( - forest_model_tau, - active_forest_tau, - forest_dataset_train, - outcome_train, - TRUE - ) - if (sample_sigma2_leaf_mu) { - current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - if (sample_sigma2_leaf_tau) { - current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau) - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - if (include_variance_forest) { - resetActiveForest(active_forest_variance) - active_forest_variance$set_root_leaves( - log(variance_forest_init) / num_trees_variance - ) - resetForestModel( - forest_model_variance, - active_forest_variance, - forest_dataset_train, - outcome_train, - FALSE - ) - } - if (has_rfx) { - rootResetRandomEffectsModel( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale - ) - rootResetRandomEffectsTracker( - rfx_tracker_train, - rfx_model, - rfx_dataset_train, - outcome_train - ) - } - if (adaptive_coding) { - current_b_1 <- b_1 - current_b_0 <- b_0 - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - if (sample_sigma2_global) { - current_sigma2 <- sigma2_init - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - } - for (i in (num_gfr + 1):num_samples) { - is_mcmc <- i > (num_gfr + num_burnin) - if (is_mcmc) { - mcmc_counter <- i - (num_gfr + num_burnin) - if (mcmc_counter %% keep_every == 0) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } else { - if (keep_burnin) { - keep_sample <- TRUE - } else { - keep_sample <- FALSE - } - } - if (keep_sample) { - sample_counter <- sample_counter + 1 - } - # Print progress - if (verbose) { - if (num_burnin > 0) { - if ( - ((i - num_gfr) %% 100 == 0) || - ((i - num_gfr) == num_burnin) - ) { - cat( - "Sampling", - i - num_gfr, - "out of", - num_gfr, - "BCF burn-in draws\n" - ) - } - } - if (num_mcmc > 0) { - if ( - ((i - num_gfr - num_burnin) %% 100 == 0) || - (i == num_samples) - ) { - cat( - "Sampling", - i - num_burnin - num_gfr, - "out of", - num_mcmc, - "BCF MCMC draws\n" - ) - } - } - } - - if (probit_outcome_model) { - # Sample latent probit variable, z | - - mu_forest_pred <- active_forest_mu$predict( - forest_dataset_train - ) - tau_forest_pred <- active_forest_tau$predict( - forest_dataset_train - ) - forest_pred <- mu_forest_pred + tau_forest_pred - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] - u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) - u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) - resid_train[y_train == 0] <- mu0 + qnorm(u0) - resid_train[y_train == 1] <- mu1 + qnorm(u1) - - # Update outcome - outcome_train$update_data(resid_train - forest_pred) - } - - # Sample the prognostic forest - forest_model_mu$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_mu, - active_forest = active_forest_mu, - rng = rng, - forest_model_config = forest_model_config_mu, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - muhat_train_raw[, - sample_counter - ] <- forest_model_mu$get_cached_forest_predictions() - } - - # Sample variance parameters (if requested) - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf_mu) { - leaf_scale_mu_double <- sampleLeafVarianceOneIteration( - active_forest_mu, - rng, - a_leaf_mu, - b_leaf_mu - ) - current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) - if (keep_sample) { - leaf_scale_mu_samples[ - sample_counter - ] <- leaf_scale_mu_double - } - forest_model_config_mu$update_leaf_model_scale( - current_leaf_scale_mu - ) - } - - # Sample the treatment forest - forest_model_tau$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_tau, - active_forest = active_forest_tau, - rng = rng, - forest_model_config = forest_model_config_tau, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cannot cache train set predictions for tau because the cached predictions in the - # tracking data structures are pre-multiplied by the basis (treatment) - # ... - - # Sample coding parameters (if requested) - if (adaptive_coding) { - # Estimate mu(X) and tau(X) and compute y - mu(X) - mu_x_raw_train <- active_forest_mu$predict_raw( - forest_dataset_train - ) - tau_x_raw_train <- active_forest_tau$predict_raw( - forest_dataset_train - ) - partial_resid_mu_train <- resid_train - mu_x_raw_train - if (has_rfx) { - rfx_preds_train <- rfx_model$predict( - rfx_dataset_train, - rfx_tracker_train - ) - partial_resid_mu_train <- partial_resid_mu_train - - rfx_preds_train - } - - # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] - s_tt0 <- sum( - tau_x_raw_train * tau_x_raw_train * (Z_train == 0) - ) - s_tt1 <- sum( - tau_x_raw_train * tau_x_raw_train * (Z_train == 1) - ) - s_ty0 <- sum( - tau_x_raw_train * - partial_resid_mu_train * - (Z_train == 0) - ) - s_ty1 <- sum( - tau_x_raw_train * - partial_resid_mu_train * - (Z_train == 1) - ) - - # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) - current_b_0 <- rnorm( - 1, - (s_ty0 / (s_tt0 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) - ) - current_b_1 <- rnorm( - 1, - (s_ty1 / (s_tt1 + 2 * current_sigma2)), - sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) - ) - - # Update basis for the leaf regression - tau_basis_train <- (1 - Z_train) * - current_b_0 + - Z_train * current_b_1 - forest_dataset_train$update_basis(tau_basis_train) - if (keep_sample) { - b_0_samples[sample_counter] <- current_b_0 - b_1_samples[sample_counter] <- current_b_1 - } - if (has_test) { - tau_basis_test <- (1 - Z_test) * - current_b_0 + - Z_test * current_b_1 - forest_dataset_test$update_basis(tau_basis_test) - } - - # Update leaf predictions and residual - forest_model_tau$propagate_basis_update( - forest_dataset_train, - outcome_train, - active_forest_tau - ) - } - - # Sample variance parameters (if requested) - if (include_variance_forest) { - forest_model_variance$sample_one_iteration( - forest_dataset = forest_dataset_train, - residual = outcome_train, - forest_samples = forest_samples_variance, - active_forest = active_forest_variance, - rng = rng, - forest_model_config = forest_model_config_variance, - global_model_config = global_model_config, - num_threads = num_threads, - keep_forest = keep_sample, - gfr = FALSE - ) - - # Cache train set predictions since they are already computed during sampling - if (keep_sample) { - sigma2_x_train_raw[, - sample_counter - ] <- forest_model_variance$get_cached_forest_predictions() - } - } - if (sample_sigma2_global) { - current_sigma2 <- sampleGlobalErrorVarianceOneIteration( - outcome_train, - forest_dataset_train, - rng, - a_global, - b_global - ) - if (keep_sample) { - global_var_samples[sample_counter] <- current_sigma2 - } - global_model_config$update_global_error_variance( - current_sigma2 - ) - } - if (sample_sigma2_leaf_tau) { - leaf_scale_tau_double <- sampleLeafVarianceOneIteration( - active_forest_tau, - rng, - a_leaf_tau, - b_leaf_tau - ) - current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) - if (keep_sample) { - leaf_scale_tau_samples[ - sample_counter - ] <- leaf_scale_tau_double - } - forest_model_config_tau$update_leaf_model_scale( - current_leaf_scale_tau - ) - } - - # Sample random effects parameters (if requested) - if (has_rfx) { - rfx_model$sample_random_effect( - rfx_dataset_train, - outcome_train, - rfx_tracker_train, - rfx_samples, - keep_sample, - current_sigma2, - rng - ) - } - } + # Update basis for the leaf regression + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (keep_sample) { + b_0_samples[sample_counter] <- current_b_0 + b_1_samples[sample_counter] <- current_b_1 + } + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) } - } - # Remove GFR samples if they are not to be retained - if ((!keep_gfr) && (num_gfr > 0)) { - for (i in 1:num_gfr) { - forest_samples_mu$delete_sample(0) - forest_samples_tau$delete_sample(0) - if (include_variance_forest) { - forest_samples_variance$delete_sample(0) - } - if (has_rfx) { - rfx_samples$delete_sample(0) - } + # Update leaf predictions and residual + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + + # Sample variance parameters (if requested) + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = TRUE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + sigma2_x_train_raw[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() } - if (sample_sigma2_global) { - global_var_samples <- global_var_samples[ - (num_gfr + 1):length(global_var_samples) - ] + } + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance(current_sigma2) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_double <- sampleLeafVarianceOneIteration( + active_forest_tau, + rng, + a_leaf_tau, + b_leaf_tau + ) + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + if (keep_sample) { + leaf_scale_tau_samples[ + sample_counter + ] <- leaf_scale_tau_double } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + + # Sample random effects parameters (if requested) + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) + } + } + } + + # Run MCMC + if (num_burnin + num_mcmc > 0) { + for (chain_num in 1:num_chains) { + if (num_gfr > 0) { + # Reset state of active_forest and forest_model based on a previous GFR sample + forest_ind <- num_gfr - chain_num + resetActiveForest( + active_forest_mu, + forest_samples_mu, + forest_ind + ) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE + ) + resetActiveForest( + active_forest_tau, + forest_samples_tau, + forest_ind + ) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (sample_sigma2_leaf_mu) { - leaf_scale_mu_samples <- leaf_scale_mu_samples[ - (num_gfr + 1):length(leaf_scale_mu_samples) - ] + leaf_scale_mu_double <- leaf_scale_mu_samples[ + forest_ind + 1 + ] + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) } if (sample_sigma2_leaf_tau) { - leaf_scale_tau_samples <- leaf_scale_tau_samples[ - (num_gfr + 1):length(leaf_scale_tau_samples) - ] + leaf_scale_tau_double <- leaf_scale_tau_samples[ + forest_ind + 1 + ] + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + if (include_variance_forest) { + resetActiveForest( + active_forest_variance, + forest_samples_variance, + forest_ind + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + resetRandomEffectsModel( + rfx_model, + rfx_samples, + forest_ind, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) } if (adaptive_coding) { - b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)] - b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)] + current_b_1 <- b_1_samples[forest_ind + 1] + current_b_0 <- b_0_samples[forest_ind + 1] + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + if (sample_sigma2_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + global_model_config$update_global_error_variance( + current_sigma2 + ) } - muhat_train_raw <- muhat_train_raw[, - (num_gfr + 1):ncol(muhat_train_raw) - ] + } else if (has_prev_model) { + resetActiveForest( + active_forest_mu, + previous_forest_samples_mu, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE + ) + resetActiveForest( + active_forest_tau, + previous_forest_samples_tau, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE + ) if (include_variance_forest) { - sigma2_x_train_raw <- sigma2_x_train_raw[, - (num_gfr + 1):ncol(sigma2_x_train_raw) + resetActiveForest( + active_forest_variance, + previous_forest_samples_variance, + previous_model_warmstart_sample_num - 1 + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if ( + sample_sigma2_leaf_mu && + (!is.null(previous_leaf_var_mu_samples)) + ) { + leaf_scale_mu_double <- previous_leaf_var_mu_samples[ + previous_model_warmstart_sample_num + ] + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + if ( + sample_sigma2_leaf_tau && + (!is.null(previous_leaf_var_tau_samples)) + ) { + leaf_scale_tau_double <- previous_leaf_var_tau_samples[ + previous_model_warmstart_sample_num + ] + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + if (adaptive_coding) { + if (!is.null(previous_b_1_samples)) { + current_b_1 <- previous_b_1_samples[ + previous_model_warmstart_sample_num ] + } + if (!is.null(previous_b_0_samples)) { + current_b_0 <- previous_b_0_samples[ + previous_model_warmstart_sample_num + ] + } + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) } - num_retained_samples <- num_retained_samples - num_gfr - } - - # Forest predictions - mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train - if (adaptive_coding) { - tau_hat_train_raw <- forest_samples_tau$predict_raw( - forest_dataset_train + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning( + "`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started" + ) + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) + } else { + resetRandomEffectsModel( + rfx_model, + previous_rfx_samples, + previous_model_warmstart_sample_num - 1, + sigma_alpha_init + ) + resetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train, + rfx_samples + ) + } + } + if (sample_sigma2_global) { + if (!is.null(previous_global_var_samples)) { + current_sigma2 <- previous_global_var_samples[ + previous_model_warmstart_sample_num + ] + } + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } else { + resetActiveForest(active_forest_mu) + active_forest_mu$set_root_leaves(init_mu / num_trees_mu) + resetForestModel( + forest_model_mu, + active_forest_mu, + forest_dataset_train, + outcome_train, + TRUE ) - tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) * - y_std_train - } else { - tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * - y_std_train - } - if (has_multivariate_treatment) { - tau_train_dim <- dim(tau_hat_train) - tau_num_obs <- tau_train_dim[1] - tau_num_samples <- tau_train_dim[3] - treatment_term_train <- matrix( - NA_real_, - nrow = tau_num_obs, - tau_num_samples + resetActiveForest(active_forest_tau) + active_forest_tau$set_root_leaves(init_tau / num_trees_tau) + resetForestModel( + forest_model_tau, + active_forest_tau, + forest_dataset_train, + outcome_train, + TRUE ) - for (i in 1:nrow(Z_train)) { - treatment_term_train[i, ] <- colSums( - tau_hat_train[i, , ] * Z_train[i, ] - ) + if (sample_sigma2_leaf_mu) { + current_leaf_scale_mu <- as.matrix(sigma2_leaf_mu) + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } + if (sample_sigma2_leaf_tau) { + current_leaf_scale_tau <- as.matrix(sigma2_leaf_tau) + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) + } + if (include_variance_forest) { + resetActiveForest(active_forest_variance) + active_forest_variance$set_root_leaves( + log(variance_forest_init) / num_trees_variance + ) + resetForestModel( + forest_model_variance, + active_forest_variance, + forest_dataset_train, + outcome_train, + FALSE + ) + } + if (has_rfx) { + rootResetRandomEffectsModel( + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale + ) + rootResetRandomEffectsTracker( + rfx_tracker_train, + rfx_model, + rfx_dataset_train, + outcome_train + ) } - } else { - treatment_term_train <- tau_hat_train * as.numeric(Z_train) - } - y_hat_train <- mu_hat_train + treatment_term_train - if (has_test) { - mu_hat_test <- forest_samples_mu$predict(forest_dataset_test) * - y_std_train + - y_bar_train if (adaptive_coding) { - tau_hat_test_raw <- forest_samples_tau$predict_raw( - forest_dataset_test - ) - tau_hat_test <- t( - t(tau_hat_test_raw) * (b_1_samples - b_0_samples) - ) * - y_std_train + current_b_1 <- b_1 + current_b_0 <- b_0 + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + if (sample_sigma2_global) { + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + } + for (i in (num_gfr + 1):num_samples) { + is_mcmc <- i > (num_gfr + num_burnin) + if (is_mcmc) { + mcmc_counter <- i - (num_gfr + num_burnin) + if (mcmc_counter %% keep_every == 0) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } else { - tau_hat_test <- forest_samples_tau$predict_raw( - forest_dataset_test - ) * - y_std_train + if (keep_burnin) { + keep_sample <- TRUE + } else { + keep_sample <- FALSE + } } - if (has_multivariate_treatment) { - tau_test_dim <- dim(tau_hat_test) - tau_num_obs <- tau_test_dim[1] - tau_num_samples <- tau_test_dim[3] - treatment_term_test <- matrix( - NA_real_, - nrow = tau_num_obs, - tau_num_samples - ) - for (i in 1:nrow(Z_test)) { - treatment_term_test[i, ] <- colSums( - tau_hat_test[i, , ] * Z_test[i, ] - ) + if (keep_sample) { + sample_counter <- sample_counter + 1 + } + # Print progress + if (verbose) { + if (num_burnin > 0) { + if ( + ((i - num_gfr) %% 100 == 0) || + ((i - num_gfr) == num_burnin) + ) { + cat( + "Sampling", + i - num_gfr, + "out of", + num_gfr, + "BCF burn-in draws\n" + ) } - } else { - treatment_term_test <- tau_hat_test * as.numeric(Z_test) + } + if (num_mcmc > 0) { + if ( + ((i - num_gfr - num_burnin) %% 100 == 0) || + (i == num_samples) + ) { + cat( + "Sampling", + i - num_burnin - num_gfr, + "out of", + num_mcmc, + "BCF MCMC draws\n" + ) + } + } } - y_hat_test <- mu_hat_test + treatment_term_test - } - if (include_variance_forest) { - sigma2_x_hat_train <- exp(sigma2_x_train_raw) - if (has_test) { - sigma2_x_hat_test <- forest_samples_variance$predict( - forest_dataset_test - ) + + if (probit_outcome_model) { + # Sample latent probit variable, z | - + mu_forest_pred <- active_forest_mu$predict( + forest_dataset_train + ) + tau_forest_pred <- active_forest_tau$predict( + forest_dataset_train + ) + forest_pred <- mu_forest_pred + tau_forest_pred + mu0 <- forest_pred[y_train == 0] + mu1 <- forest_pred[y_train == 1] + u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) + u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) + resid_train[y_train == 0] <- mu0 + qnorm(u0) + resid_train[y_train == 1] <- mu1 + qnorm(u1) + + # Update outcome + outcome_train$update_data(resid_train - forest_pred) } - } - # Random effects predictions - if (has_rfx) { - rfx_preds_train <- rfx_samples$predict( - rfx_group_ids_train, - rfx_basis_train - ) * - y_std_train - y_hat_train <- y_hat_train + rfx_preds_train - } - if ((has_rfx_test) && (has_test)) { - rfx_preds_test <- rfx_samples$predict( - rfx_group_ids_test, - rfx_basis_test - ) * - y_std_train - y_hat_test <- y_hat_test + rfx_preds_test - } + # Sample the prognostic forest + forest_model_mu$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_mu, + active_forest = active_forest_mu, + rng = rng, + forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) - # Global error variance - if (sample_sigma2_global) { - sigma2_global_samples <- global_var_samples * (y_std_train^2) - } + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + muhat_train_raw[, + sample_counter + ] <- forest_model_mu$get_cached_forest_predictions() + } - # Leaf parameter variance for prognostic forest - if (sample_sigma2_leaf_mu) { - sigma2_leaf_mu_samples <- leaf_scale_mu_samples - } + # Sample variance parameters (if requested) + if (sample_sigma2_global) { + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + if (sample_sigma2_leaf_mu) { + leaf_scale_mu_double <- sampleLeafVarianceOneIteration( + active_forest_mu, + rng, + a_leaf_mu, + b_leaf_mu + ) + current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + if (keep_sample) { + leaf_scale_mu_samples[ + sample_counter + ] <- leaf_scale_mu_double + } + forest_model_config_mu$update_leaf_model_scale( + current_leaf_scale_mu + ) + } - # Leaf parameter variance for treatment effect forest - if (sample_sigma2_leaf_tau) { - sigma2_leaf_tau_samples <- leaf_scale_tau_samples - } + # Sample the treatment forest + forest_model_tau$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_tau, + active_forest = active_forest_tau, + rng = rng, + forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) - # Rescale variance forest prediction by global sigma2 (sampled or constant) - if (include_variance_forest) { + # Cannot cache train set predictions for tau because the cached predictions in the + # tracking data structures are pre-multiplied by the basis (treatment) + # ... + + # Sample coding parameters (if requested) + if (adaptive_coding) { + # Estimate mu(X) and tau(X) and compute y - mu(X) + mu_x_raw_train <- active_forest_mu$predict_raw( + forest_dataset_train + ) + tau_x_raw_train <- active_forest_tau$predict_raw( + forest_dataset_train + ) + partial_resid_mu_train <- resid_train - mu_x_raw_train + if (has_rfx) { + rfx_preds_train <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + partial_resid_mu_train <- partial_resid_mu_train - + rfx_preds_train + } + + # Compute sufficient statistics for regression of y - mu(X) on [tau(X)(1-Z), tau(X)Z] + s_tt0 <- sum( + tau_x_raw_train * tau_x_raw_train * (Z_train == 0) + ) + s_tt1 <- sum( + tau_x_raw_train * tau_x_raw_train * (Z_train == 1) + ) + s_ty0 <- sum( + tau_x_raw_train * + partial_resid_mu_train * + (Z_train == 0) + ) + s_ty1 <- sum( + tau_x_raw_train * + partial_resid_mu_train * + (Z_train == 1) + ) + + # Sample b0 (coefficient on tau(X)(1-Z)) and b1 (coefficient on tau(X)Z) + current_b_0 <- rnorm( + 1, + (s_ty0 / (s_tt0 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt0 + 2 * current_sigma2)) + ) + current_b_1 <- rnorm( + 1, + (s_ty1 / (s_tt1 + 2 * current_sigma2)), + sqrt(current_sigma2 / (s_tt1 + 2 * current_sigma2)) + ) + + # Update basis for the leaf regression + tau_basis_train <- (1 - Z_train) * + current_b_0 + + Z_train * current_b_1 + forest_dataset_train$update_basis(tau_basis_train) + if (keep_sample) { + b_0_samples[sample_counter] <- current_b_0 + b_1_samples[sample_counter] <- current_b_1 + } + if (has_test) { + tau_basis_test <- (1 - Z_test) * + current_b_0 + + Z_test * current_b_1 + forest_dataset_test$update_basis(tau_basis_test) + } + + # Update leaf predictions and residual + forest_model_tau$propagate_basis_update( + forest_dataset_train, + outcome_train, + active_forest_tau + ) + } + + # Sample variance parameters (if requested) + if (include_variance_forest) { + forest_model_variance$sample_one_iteration( + forest_dataset = forest_dataset_train, + residual = outcome_train, + forest_samples = forest_samples_variance, + active_forest = active_forest_variance, + rng = rng, + forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, + num_threads = num_threads, + keep_forest = keep_sample, + gfr = FALSE + ) + + # Cache train set predictions since they are already computed during sampling + if (keep_sample) { + sigma2_x_train_raw[, + sample_counter + ] <- forest_model_variance$get_cached_forest_predictions() + } + } if (sample_sigma2_global) { - sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { - sigma2_x_hat_train[, i] * sigma2_global_samples[i] - }) - if (has_test) { - sigma2_x_hat_test <- sapply( - 1:num_retained_samples, - function(i) { - sigma2_x_hat_test[, i] * sigma2_global_samples[i] - } - ) - } - } else { - sigma2_x_hat_train <- sigma2_x_hat_train * - sigma2_init * - y_std_train * - y_std_train - if (has_test) { - sigma2_x_hat_test <- sigma2_x_hat_test * - sigma2_init * - y_std_train * - y_std_train - } + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome_train, + forest_dataset_train, + rng, + a_global, + b_global + ) + if (keep_sample) { + global_var_samples[sample_counter] <- current_sigma2 + } + global_model_config$update_global_error_variance( + current_sigma2 + ) + } + if (sample_sigma2_leaf_tau) { + leaf_scale_tau_double <- sampleLeafVarianceOneIteration( + active_forest_tau, + rng, + a_leaf_tau, + b_leaf_tau + ) + current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + if (keep_sample) { + leaf_scale_tau_samples[ + sample_counter + ] <- leaf_scale_tau_double + } + forest_model_config_tau$update_leaf_model_scale( + current_leaf_scale_tau + ) } - } - # Return results as a list - if (include_variance_forest) { - num_variance_covariates <- sum(variable_weights_variance > 0) - } else { - num_variance_covariates <- 0 - } - model_params <- list( - "initial_sigma2" = sigma2_init, - "initial_sigma2_leaf_mu" = sigma2_leaf_mu, - "initial_sigma2_leaf_tau" = sigma2_leaf_tau, - "initial_b_0" = b_0, - "initial_b_1" = b_1, - "a_global" = a_global, - "b_global" = b_global, - "a_leaf_mu" = a_leaf_mu, - "b_leaf_mu" = b_leaf_mu, - "a_leaf_tau" = a_leaf_tau, - "b_leaf_tau" = b_leaf_tau, - "a_forest" = a_forest, - "b_forest" = b_forest, - "outcome_mean" = y_bar_train, - "outcome_scale" = y_std_train, - "standardize" = standardize, - "num_covariates" = num_cov_orig, - "num_prognostic_covariates" = sum(variable_weights_mu > 0), - "num_treatment_covariates" = sum(variable_weights_tau > 0), - "num_variance_covariates" = num_variance_covariates, - "treatment_dim" = ncol(Z_train), - "propensity_covariate" = propensity_covariate, - "binary_treatment" = binary_treatment, - "multivariate_treatment" = has_multivariate_treatment, - "adaptive_coding" = adaptive_coding, - "internal_propensity_model" = internal_propensity_model, - "num_samples" = num_retained_samples, - "num_gfr" = num_gfr, - "num_burnin" = num_burnin, - "num_mcmc" = num_mcmc, - "keep_every" = keep_every, - "num_chains" = num_chains, - "has_rfx" = has_rfx, - "has_rfx_basis" = has_basis_rfx, - "num_rfx_basis" = num_basis_rfx, - "include_variance_forest" = include_variance_forest, - "sample_sigma2_global" = sample_sigma2_global, - "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu, - "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, - "probit_outcome_model" = probit_outcome_model - ) - result <- list( - "forests_mu" = forest_samples_mu, - "forests_tau" = forest_samples_tau, - "model_params" = model_params, - "mu_hat_train" = mu_hat_train, - "tau_hat_train" = tau_hat_train, - "y_hat_train" = y_hat_train, - "train_set_metadata" = X_train_metadata - ) - if (has_test) { - result[["mu_hat_test"]] = mu_hat_test - } - if (has_test) { - result[["tau_hat_test"]] = tau_hat_test - } - if (has_test) { - result[["y_hat_test"]] = y_hat_test - } - if (include_variance_forest) { - result[["forests_variance"]] = forest_samples_variance - result[["sigma2_x_hat_train"]] = sigma2_x_hat_train - if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + # Sample random effects parameters (if requested) + if (has_rfx) { + rfx_model$sample_random_effect( + rfx_dataset_train, + outcome_train, + rfx_tracker_train, + rfx_samples, + keep_sample, + current_sigma2, + rng + ) + } + } + } + } + + # Remove GFR samples if they are not to be retained + if ((!keep_gfr) && (num_gfr > 0)) { + for (i in 1:num_gfr) { + forest_samples_mu$delete_sample(0) + forest_samples_tau$delete_sample(0) + if (include_variance_forest) { + forest_samples_variance$delete_sample(0) + } + if (has_rfx) { + rfx_samples$delete_sample(0) + } } if (sample_sigma2_global) { - result[["sigma2_global_samples"]] = sigma2_global_samples + global_var_samples <- global_var_samples[ + (num_gfr + 1):length(global_var_samples) + ] } if (sample_sigma2_leaf_mu) { - result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples + leaf_scale_mu_samples <- leaf_scale_mu_samples[ + (num_gfr + 1):length(leaf_scale_mu_samples) + ] } if (sample_sigma2_leaf_tau) { - result[["sigma2_leaf_tau_samples"]] = sigma2_leaf_tau_samples + leaf_scale_tau_samples <- leaf_scale_tau_samples[ + (num_gfr + 1):length(leaf_scale_tau_samples) + ] } if (adaptive_coding) { - result[["b_0_samples"]] = b_0_samples - result[["b_1_samples"]] = b_1_samples - } - if (has_rfx) { - result[["rfx_samples"]] = rfx_samples - result[["rfx_preds_train"]] = rfx_preds_train - result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + b_1_samples <- b_1_samples[(num_gfr + 1):length(b_1_samples)] + b_0_samples <- b_0_samples[(num_gfr + 1):length(b_0_samples)] } - if ((has_rfx_test) && (has_test)) { - result[["rfx_preds_test"]] = rfx_preds_test + muhat_train_raw <- muhat_train_raw[, + (num_gfr + 1):ncol(muhat_train_raw) + ] + if (include_variance_forest) { + sigma2_x_train_raw <- sigma2_x_train_raw[, + (num_gfr + 1):ncol(sigma2_x_train_raw) + ] + } + num_retained_samples <- num_retained_samples - num_gfr + } + + # Forest predictions + mu_hat_train <- muhat_train_raw * y_std_train + y_bar_train + if (adaptive_coding) { + tau_hat_train_raw <- forest_samples_tau$predict_raw( + forest_dataset_train + ) + tau_hat_train <- t(t(tau_hat_train_raw) * (b_1_samples - b_0_samples)) * + y_std_train + } else { + tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * + y_std_train + } + if (has_multivariate_treatment) { + tau_train_dim <- dim(tau_hat_train) + tau_num_obs <- tau_train_dim[1] + tau_num_samples <- tau_train_dim[3] + treatment_term_train <- matrix( + NA_real_, + nrow = tau_num_obs, + tau_num_samples + ) + for (i in 1:nrow(Z_train)) { + treatment_term_train[i, ] <- colSums( + tau_hat_train[i, , ] * Z_train[i, ] + ) + } + } else { + treatment_term_train <- tau_hat_train * as.numeric(Z_train) + } + y_hat_train <- mu_hat_train + treatment_term_train + if (has_test) { + mu_hat_test <- forest_samples_mu$predict(forest_dataset_test) * + y_std_train + + y_bar_train + if (adaptive_coding) { + tau_hat_test_raw <- forest_samples_tau$predict_raw( + forest_dataset_test + ) + tau_hat_test <- t( + t(tau_hat_test_raw) * (b_1_samples - b_0_samples) + ) * + y_std_train + } else { + tau_hat_test <- forest_samples_tau$predict_raw( + forest_dataset_test + ) * + y_std_train } - if (internal_propensity_model) { - result[["bart_propensity_model"]] = bart_model_propensity + if (has_multivariate_treatment) { + tau_test_dim <- dim(tau_hat_test) + tau_num_obs <- tau_test_dim[1] + tau_num_samples <- tau_test_dim[3] + treatment_term_test <- matrix( + NA_real_, + nrow = tau_num_obs, + tau_num_samples + ) + for (i in 1:nrow(Z_test)) { + treatment_term_test[i, ] <- colSums( + tau_hat_test[i, , ] * Z_test[i, ] + ) + } + } else { + treatment_term_test <- tau_hat_test * as.numeric(Z_test) } - class(result) <- "bcfmodel" - - return(result) + y_hat_test <- mu_hat_test + treatment_term_test + } + if (include_variance_forest) { + sigma2_x_hat_train <- exp(sigma2_x_train_raw) + if (has_test) { + sigma2_x_hat_test <- forest_samples_variance$predict( + forest_dataset_test + ) + } + } + + # Random effects predictions + if (has_rfx) { + rfx_preds_train <- rfx_samples$predict( + rfx_group_ids_train, + rfx_basis_train + ) * + y_std_train + y_hat_train <- y_hat_train + rfx_preds_train + } + if ((has_rfx_test) && (has_test)) { + rfx_preds_test <- rfx_samples$predict( + rfx_group_ids_test, + rfx_basis_test + ) * + y_std_train + y_hat_test <- y_hat_test + rfx_preds_test + } + + # Global error variance + if (sample_sigma2_global) { + sigma2_global_samples <- global_var_samples * (y_std_train^2) + } + + # Leaf parameter variance for prognostic forest + if (sample_sigma2_leaf_mu) { + sigma2_leaf_mu_samples <- leaf_scale_mu_samples + } + + # Leaf parameter variance for treatment effect forest + if (sample_sigma2_leaf_tau) { + sigma2_leaf_tau_samples <- leaf_scale_tau_samples + } + + # Rescale variance forest prediction by global sigma2 (sampled or constant) + if (include_variance_forest) { + if (sample_sigma2_global) { + sigma2_x_hat_train <- sapply(1:num_retained_samples, function(i) { + sigma2_x_hat_train[, i] * sigma2_global_samples[i] + }) + if (has_test) { + sigma2_x_hat_test <- sapply( + 1:num_retained_samples, + function(i) { + sigma2_x_hat_test[, i] * sigma2_global_samples[i] + } + ) + } + } else { + sigma2_x_hat_train <- sigma2_x_hat_train * + sigma2_init * + y_std_train * + y_std_train + if (has_test) { + sigma2_x_hat_test <- sigma2_x_hat_test * + sigma2_init * + y_std_train * + y_std_train + } + } + } + + # Return results as a list + if (include_variance_forest) { + num_variance_covariates <- sum(variable_weights_variance > 0) + } else { + num_variance_covariates <- 0 + } + model_params <- list( + "initial_sigma2" = sigma2_init, + "initial_sigma2_leaf_mu" = sigma2_leaf_mu, + "initial_sigma2_leaf_tau" = sigma2_leaf_tau, + "initial_b_0" = b_0, + "initial_b_1" = b_1, + "a_global" = a_global, + "b_global" = b_global, + "a_leaf_mu" = a_leaf_mu, + "b_leaf_mu" = b_leaf_mu, + "a_leaf_tau" = a_leaf_tau, + "b_leaf_tau" = b_leaf_tau, + "a_forest" = a_forest, + "b_forest" = b_forest, + "outcome_mean" = y_bar_train, + "outcome_scale" = y_std_train, + "standardize" = standardize, + "num_covariates" = num_cov_orig, + "num_prognostic_covariates" = sum(variable_weights_mu > 0), + "num_treatment_covariates" = sum(variable_weights_tau > 0), + "num_variance_covariates" = num_variance_covariates, + "treatment_dim" = ncol(Z_train), + "propensity_covariate" = propensity_covariate, + "binary_treatment" = binary_treatment, + "multivariate_treatment" = has_multivariate_treatment, + "adaptive_coding" = adaptive_coding, + "internal_propensity_model" = internal_propensity_model, + "num_samples" = num_retained_samples, + "num_gfr" = num_gfr, + "num_burnin" = num_burnin, + "num_mcmc" = num_mcmc, + "keep_every" = keep_every, + "num_chains" = num_chains, + "has_rfx" = has_rfx, + "has_rfx_basis" = has_basis_rfx, + "num_rfx_basis" = num_basis_rfx, + "include_variance_forest" = include_variance_forest, + "sample_sigma2_global" = sample_sigma2_global, + "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu, + "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, + "probit_outcome_model" = probit_outcome_model + ) + result <- list( + "forests_mu" = forest_samples_mu, + "forests_tau" = forest_samples_tau, + "model_params" = model_params, + "mu_hat_train" = mu_hat_train, + "tau_hat_train" = tau_hat_train, + "y_hat_train" = y_hat_train, + "train_set_metadata" = X_train_metadata + ) + if (has_test) { + result[["mu_hat_test"]] = mu_hat_test + } + if (has_test) { + result[["tau_hat_test"]] = tau_hat_test + } + if (has_test) { + result[["y_hat_test"]] = y_hat_test + } + if (include_variance_forest) { + result[["forests_variance"]] = forest_samples_variance + result[["sigma2_x_hat_train"]] = sigma2_x_hat_train + if (has_test) result[["sigma2_x_hat_test"]] = sigma2_x_hat_test + } + if (sample_sigma2_global) { + result[["sigma2_global_samples"]] = sigma2_global_samples + } + if (sample_sigma2_leaf_mu) { + result[["sigma2_leaf_mu_samples"]] = sigma2_leaf_mu_samples + } + if (sample_sigma2_leaf_tau) { + result[["sigma2_leaf_tau_samples"]] = sigma2_leaf_tau_samples + } + if (adaptive_coding) { + result[["b_0_samples"]] = b_0_samples + result[["b_1_samples"]] = b_1_samples + } + if (has_rfx) { + result[["rfx_samples"]] = rfx_samples + result[["rfx_preds_train"]] = rfx_preds_train + result[["rfx_unique_group_ids"]] = levels(group_ids_factor) + } + if ((has_rfx_test) && (has_test)) { + result[["rfx_preds_test"]] = rfx_preds_test + } + if (internal_propensity_model) { + result[["bart_propensity_model"]] = bart_model_propensity + } + class(result) <- "bcfmodel" + + return(result) } #' Predict from a sampled BCF model on new data @@ -2560,6 +2554,7 @@ bcf <- function( #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". #' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. #' #' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. @@ -2612,275 +2607,311 @@ bcf <- function( #' num_burnin = 0, num_mcmc = 10) #' preds <- predict(bcf_model, X_test, Z_test, pi_test) predict.bcfmodel <- function( - object, - X, - Z, - propensity = NULL, - rfx_group_ids = NULL, - rfx_basis = NULL, - type = "posterior", - terms = "all", - ... + object, + X, + Z, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + type = "posterior", + terms = "all", + scale = "linear", + ... ) { - # Handle prediction type - if (!is.character(type)) { - stop("type must be a string or character vector") - } - if (!(type %in% c("mean", "posterior"))) { - stop("type must either be 'mean' or 'posterior") - } - predict_mean <- type == "mean" - - # Handle prediction terms - if (!is.character(terms)) { - stop("type must be a string or character vector") - } - num_terms <- length(terms) - has_mu_forest <- T - has_tau_forest <- T - has_variance_forest <- object$model_params$include_variance_forest - has_rfx <- object$model_params$has_rfx - has_y_hat <- T - predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) || - ((has_y_hat) && ("all" %in% terms))) - predict_mu_forest <- (((has_mu_forest) && ("prognostic_function" %in% terms)) || - ((has_mu_forest) && ("all" %in% terms))) - predict_tau_forest <- (((has_tau_forest) && ("cate" %in% terms)) || - ((has_tau_forest) && ("all" %in% terms))) - predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) || - ((has_rfx) && ("all" %in% terms))) - predict_variance_forest <- (((has_variance_forest) && - ("variance_forest" %in% terms)) || - ((has_variance_forest) && ("all" %in% terms))) - predict_count <- sum(c( - predict_y_hat, - predict_mu_forest, - predict_tau_forest, - predict_rfx, - predict_variance_forest + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + + # Handle prediction type + if (!is.character(type)) { + stop("type must be a string or character vector") + } + if (!(type %in% c("mean", "posterior"))) { + stop("type must either be 'mean' or 'posterior") + } + predict_mean <- type == "mean" + + # Handle prediction terms + if (!is.character(terms)) { + stop("type must be a string or character vector") + } + num_terms <- length(terms) + has_mu_forest <- T + has_tau_forest <- T + has_variance_forest <- object$model_params$include_variance_forest + has_rfx <- object$model_params$has_rfx + has_y_hat <- T + predict_y_hat <- (((has_y_hat) && ("y_hat" %in% terms)) || + ((has_y_hat) && ("all" %in% terms))) + predict_mu_forest <- (((has_mu_forest) && + ("prognostic_function" %in% terms)) || + ((has_mu_forest) && ("all" %in% terms))) + predict_tau_forest <- (((has_tau_forest) && ("cate" %in% terms)) || + ((has_tau_forest) && ("all" %in% terms))) + predict_rfx <- (((has_rfx) && ("rfx" %in% terms)) || + ((has_rfx) && ("all" %in% terms))) + predict_variance_forest <- (((has_variance_forest) && + ("variance_forest" %in% terms)) || + ((has_variance_forest) && ("all" %in% terms))) + predict_count <- sum(c( + predict_y_hat, + predict_mu_forest, + predict_tau_forest, + predict_rfx, + predict_variance_forest + )) + if (predict_count == 0) { + warning(paste0( + "None of the requested model terms, ", + paste(terms, collapse = ", "), + ", were fit in this model" )) - if (predict_count == 0) { - warning(paste0( - "None of the requested model terms, ", - paste(terms, collapse = ", "), - ", were fit in this model" - )) - return(NULL) + return(NULL) + } + predict_rfx_intermediate <- (predict_y_hat && has_rfx) + predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) + predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) + + # Preprocess covariates + if ((!is.data.frame(X)) && (!is.matrix(X))) { + stop("X must be a matrix or dataframe") + } + train_set_metadata <- object$train_set_metadata + X <- preprocessPredictionData(X, train_set_metadata) + + # Convert all input data to matrices if not already converted + if ((is.null(dim(Z))) && (!is.null(Z))) { + Z <- as.matrix(as.numeric(Z)) + } + if ((is.null(dim(propensity))) && (!is.null(propensity))) { + propensity <- as.matrix(propensity) + } + if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) { + rfx_basis <- as.matrix(rfx_basis) + } + + # Data checks + if ( + (object$model_params$propensity_covariate != "none") && + (is.null(propensity)) + ) { + if (!object$model_params$internal_propensity_model) { + stop("propensity must be provided for this model") + } + # Compute propensity score using the internal bart model + propensity <- rowMeans(predict(object$bart_propensity_model, X)$y_hat) + } + if (nrow(X) != nrow(Z)) { + stop("X and Z must have the same number of rows") + } + if (object$model_params$num_covariates != ncol(X)) { + stop( + "X and must have the same number of columns as the covariates used to train the model" + ) + } + if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) { + stop( + "Random effect group labels (rfx_group_ids) must be provided for this model" + ) + } + if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { + stop("Random effects basis (rfx_basis) must be provided for this model") + } + if ( + (object$model_params$num_rfx_basis > 0) && + (ncol(rfx_basis) != object$model_params$num_rfx_basis) + ) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) + } + + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) + has_rfx <- FALSE + if (!is.null(rfx_group_ids)) { + rfx_unique_group_ids <- object$rfx_unique_group_ids + group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) + if (sum(is.na(group_ids_factor)) > 0) { + stop( + "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" + ) + } + rfx_group_ids <- as.integer(group_ids_factor) + has_rfx <- TRUE + } + + # Produce basis for the "intercept-only" random effects case + if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { + rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) + } + + # Add propensities to covariate set if necessary + if (object$model_params$propensity_covariate != "none") { + X_combined <- cbind(X, propensity) + } + + # Create prediction datasets + forest_dataset_pred <- createForestDataset(X_combined, Z) + + # Compute variance forest predictions + if (predict_variance_forest) { + s_x_raw <- object$forests_variance$predict(forest_dataset_pred) + } + + # Scale variance forest predictions + num_samples <- object$model_params$num_samples + y_std <- object$model_params$outcome_scale + y_bar <- object$model_params$outcome_mean + initial_sigma2 <- object$model_params$initial_sigma2 + if (predict_variance_forest) { + if (object$model_params$sample_sigma2_global) { + sigma2_global_samples <- object$sigma2_global_samples + variance_forest_predictions <- sapply(1:num_samples, function(i) { + s_x_raw[, i] * sigma2_global_samples[i] + }) + } else { + variance_forest_predictions <- s_x_raw * + initial_sigma2 * + y_std * + y_std } - predict_rfx_intermediate <- (predict_y_hat && has_rfx) - predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) - predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) - - # Preprocess covariates - if ((!is.data.frame(X)) && (!is.matrix(X))) { - stop("X must be a matrix or dataframe") + if (predict_mean) { + variance_forest_predictions <- rowMeans(variance_forest_predictions) } - train_set_metadata <- object$train_set_metadata - X <- preprocessPredictionData(X, train_set_metadata) + } - # Convert all input data to matrices if not already converted - if ((is.null(dim(Z))) && (!is.null(Z))) { - Z <- as.matrix(as.numeric(Z)) - } - if ((is.null(dim(propensity))) && (!is.null(propensity))) { - propensity <- as.matrix(propensity) - } - if ((is.null(dim(rfx_basis))) && (!is.null(rfx_basis))) { - rfx_basis <- as.matrix(rfx_basis) - } + # Compute mu forest predictions + if (predict_mu_forest || predict_mu_forest_intermediate) { + mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar + } - # Data checks - if ( - (object$model_params$propensity_covariate != "none") && - (is.null(propensity)) - ) { - if (!object$model_params$internal_propensity_model) { - stop("propensity must be provided for this model") - } - # Compute propensity score using the internal bart model - propensity <- rowMeans(predict(object$bart_propensity_model, X)$y_hat) - } - if (nrow(X) != nrow(Z)) { - stop("X and Z must have the same number of rows") - } - if (object$model_params$num_covariates != ncol(X)) { - stop( - "X and must have the same number of columns as the covariates used to train the model" - ) - } - if ((object$model_params$has_rfx) && (is.null(rfx_group_ids))) { - stop( - "Random effect group labels (rfx_group_ids) must be provided for this model" - ) - } - if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { - stop("Random effects basis (rfx_basis) must be provided for this model") - } - if ( - (object$model_params$num_rfx_basis > 0) && - (ncol(rfx_basis) != object$model_params$num_rfx_basis) - ) { - stop( - "Random effects basis has a different dimension than the basis used to train this model" - ) + # Compute CATE forest predictions + if (predict_tau_forest || predict_tau_forest_intermediate) { + if (object$model_params$adaptive_coding) { + tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) + tau_hat <- t( + t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) + ) * + y_std + } else { + tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std + } + if (object$model_params$multivariate_treatment) { + tau_dim <- dim(tau_hat) + tau_num_obs <- tau_dim[1] + tau_num_samples <- tau_dim[3] + treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) + for (i in 1:nrow(Z)) { + treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) + } + } else { + treatment_term <- tau_hat * as.numeric(Z) } + } - # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) - has_rfx <- FALSE - if (!is.null(rfx_group_ids)) { - rfx_unique_group_ids <- object$rfx_unique_group_ids - group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) - if (sum(is.na(group_ids_factor)) > 0) { - stop( - "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" - ) - } - rfx_group_ids <- as.integer(group_ids_factor) - has_rfx <- TRUE + # Compute rfx predictions + if (predict_rfx) { + rfx_predictions <- object$rfx_samples$predict( + rfx_group_ids, + rfx_basis + ) * + y_std + if (predict_mean) { + rfx_predictions <- rowMeans(rfx_predictions) } + } - # Produce basis for the "intercept-only" random effects case - if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) + # Combine into y hat predictions + if (probability_scale) { + if (has_rfx) { + y_hat <- pnorm(mu_hat + treatment_term + rfx_predictions) + rfx_predictions <- pnorm(rfx_predictions) + } else { + y_hat <- pnorm(mu_hat + treatment_term) } - - # Add propensities to covariate set if necessary - if (object$model_params$propensity_covariate != "none") { - X_combined <- cbind(X, propensity) + mu_hat <- pnorm(mu_hat) + tau_hat <- pnorm(tau_hat) + } else { + if (has_rfx) { + y_hat <- mu_hat + treatment_term + rfx_predictions + } else { + y_hat <- mu_hat + treatment_term } + } - # Create prediction datasets - forest_dataset_pred <- createForestDataset(X_combined, Z) - - # Compute mu forest predictions - num_samples <- object$model_params$num_samples - y_std <- object$model_params$outcome_scale - y_bar <- object$model_params$outcome_mean - initial_sigma2 <- object$model_params$initial_sigma2 - if (predict_mu_forest || predict_mu_forest_intermediate) { - mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar - if (predict_mean) { - mu_hat <- rowMeans(mu_hat) - } - } - - # Compute CATE forest predictions - if (predict_tau_forest || predict_tau_forest_intermediate) { - if (object$model_params$adaptive_coding) { - tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) - tau_hat <- t( - t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) - ) * - y_std - } else { - tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std - } - if (object$model_params$multivariate_treatment) { - tau_dim <- dim(tau_hat) - tau_num_obs <- tau_dim[1] - tau_num_samples <- tau_dim[3] - treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) - for (i in 1:nrow(Z)) { - treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) - } - if (predict_mean) { - tau_hat <- apply(tau_hat, c(1,2), mean) - treatment_term <- rowMeans(treatment_term) - } - } else { - treatment_term <- tau_hat * as.numeric(Z) - if (predict_mean) { - tau_hat <- rowMeans(tau_hat) - treatment_term <- rowMeans(treatment_term) - } - } + # Collapse to posterior mean predictions if requested + if (predict_mean) { + if (predict_mu_forest) { + mu_hat <- rowMeans(mu_hat) } - - # Compute variance forest predictions - if (predict_variance_forest) { - s_x_raw <- object$forests_variance$predict(forest_dataset_pred) + if (predict_tau_forest) { + if (object$model_params$multivariate_treatment) { + tau_hat <- apply(tau_hat, c(1, 2), mean) + } else { + tau_hat <- rowMeans(tau_hat) + } } - - # Compute rfx predictions if (predict_rfx) { - rfx_predictions <- object$rfx_samples$predict( - rfx_group_ids, - rfx_basis - ) * - y_std - if (predict_mean) { - rfx_predictions <- rowMeans(rfx_predictions) - } + rfx_predictions <- rowMeans(rfx_predictions) } - - # Compute overall "y_hat" predictions if (predict_y_hat) { - y_hat <- mu_hat + treatment_term - if (has_rfx) { - y_hat <- y_hat + rfx_predictions - } + y_hat <- rowMeans(y_hat) } + } - # Scale variance forest predictions - if (predict_variance_forest) { - if (object$model_params$sample_sigma2_global) { - sigma2_global_samples <- object$sigma2_global_samples - variance_forest_predictions <- sapply(1:num_samples, function(i) { - s_x_raw[, i] * sigma2_global_samples[i] - }) - } else { - variance_forest_predictions <- s_x_raw * - initial_sigma2 * - y_std * - y_std - } - if (predict_mean) { - variance_forest_predictions <- rowMeans(variance_forest_predictions) - } + # Return results + if (predict_count == 1) { + if (predict_y_hat) { + return(y_hat) + } else if (predict_mu_forest) { + return(mu_hat) + } else if (predict_tau_forest) { + return(tau_hat) + } else if (predict_rfx) { + return(rfx_predictions) + } else if (predict_variance_forest) { + return(variance_forest_predictions) + } + } else { + result <- list() + if (predict_y_hat) { + result[["y_hat"]] = y_hat + } else { + result[["y_hat"]] <- NULL } - - # Return results - if (predict_count == 1) { - if (predict_y_hat) { - return(y_hat) - } else if (predict_mu_forest) { - return(mu_hat) - } else if (predict_tau_forest) { - return(tau_hat) - } else if (predict_rfx) { - return(rfx_predictions) - } else if (predict_variance_forest) { - return(variance_forest_predictions) - } + if (predict_mu_forest) { + result[["mu_hat"]] = mu_hat } else { - result <- list() - if (predict_y_hat) { - result[["y_hat"]] = y_hat - } else { - result[["y_hat"]] <- NULL - } - if (predict_mu_forest) { - result[["mu_hat"]] = mu_hat - } else { - result[["mu_hat"]] <- NULL - } - if (predict_tau_forest) { - result[["tau_hat"]] = tau_hat - } else { - result[["tau_hat"]] <- NULL - } - if (predict_rfx) { - result[["rfx_predictions"]] = rfx_predictions - } else { - result[["rfx_predictions"]] <- NULL - } - if (predict_variance_forest) { - result[["variance_forest_predictions"]] = variance_forest_predictions - } else { - result[["variance_forest_predictions"]] <- NULL - } + result[["mu_hat"]] <- NULL + } + if (predict_tau_forest) { + result[["tau_hat"]] = tau_hat + } else { + result[["tau_hat"]] <- NULL + } + if (predict_rfx) { + result[["rfx_predictions"]] = rfx_predictions + } else { + result[["rfx_predictions"]] <- NULL } - return(result) + if (predict_variance_forest) { + result[["variance_forest_predictions"]] = variance_forest_predictions + } else { + result[["variance_forest_predictions"]] <- NULL + } + } + return(result) } #' Extract raw sample values for each of the random effect parameter terms. @@ -2959,26 +2990,26 @@ predict.bcfmodel <- function( #' treatment_effect_forest_params = tau_params) #' rfx_samples <- getRandomEffectSamples(bcf_model) getRandomEffectSamples.bcfmodel <- function(object, ...) { - result = list() + result = list() - if (!object$model_params$has_rfx) { - warning("This model has no RFX terms, returning an empty list") - return(result) - } + if (!object$model_params$has_rfx) { + warning("This model has no RFX terms, returning an empty list") + return(result) + } - # Extract the samples - result <- object$rfx_samples$extract_parameter_samples() + # Extract the samples + result <- object$rfx_samples$extract_parameter_samples() - # Scale by sd(y_train) - result$beta_samples <- result$beta_samples * - object$model_params$outcome_scale - result$xi_samples <- result$xi_samples * object$model_params$outcome_scale - result$alpha_samples <- result$alpha_samples * - object$model_params$outcome_scale - result$sigma_samples <- result$sigma_samples * - (object$model_params$outcome_scale^2) + # Scale by sd(y_train) + result$beta_samples <- result$beta_samples * + object$model_params$outcome_scale + result$xi_samples <- result$xi_samples * object$model_params$outcome_scale + result$alpha_samples <- result$alpha_samples * + object$model_params$outcome_scale + result$sigma_samples <- result$sigma_samples * + (object$model_params$outcome_scale^2) - return(result) + return(result) } #' Convert the persistent aspects of a BCF model to (in-memory) JSON @@ -3055,161 +3086,161 @@ getRandomEffectSamples.bcfmodel <- function(object, ...) { #' treatment_effect_forest_params = tau_params) #' bcf_json <- saveBCFModelToJson(bcf_model) saveBCFModelToJson <- function(object) { - jsonobj <- createCppJson() - - if (!inherits(object, "bcfmodel")) { - stop("`object` must be a BCF model") - } - - if (is.null(object$model_params)) { - stop("This BCF model has not yet been sampled") - } - - # Add the forests - jsonobj$add_forest(object$forests_mu) - jsonobj$add_forest(object$forests_tau) - if (object$model_params$include_variance_forest) { - jsonobj$add_forest(object$forests_variance) - } - - # Add metadata - jsonobj$add_scalar( - "num_numeric_vars", - object$train_set_metadata$num_numeric_vars - ) - jsonobj$add_scalar( - "num_ordered_cat_vars", - object$train_set_metadata$num_ordered_cat_vars - ) - jsonobj$add_scalar( - "num_unordered_cat_vars", - object$train_set_metadata$num_unordered_cat_vars + jsonobj <- createCppJson() + + if (!inherits(object, "bcfmodel")) { + stop("`object` must be a BCF model") + } + + if (is.null(object$model_params)) { + stop("This BCF model has not yet been sampled") + } + + # Add the forests + jsonobj$add_forest(object$forests_mu) + jsonobj$add_forest(object$forests_tau) + if (object$model_params$include_variance_forest) { + jsonobj$add_forest(object$forests_variance) + } + + # Add metadata + jsonobj$add_scalar( + "num_numeric_vars", + object$train_set_metadata$num_numeric_vars + ) + jsonobj$add_scalar( + "num_ordered_cat_vars", + object$train_set_metadata$num_ordered_cat_vars + ) + jsonobj$add_scalar( + "num_unordered_cat_vars", + object$train_set_metadata$num_unordered_cat_vars + ) + if (object$train_set_metadata$num_numeric_vars > 0) { + jsonobj$add_string_vector( + "numeric_vars", + object$train_set_metadata$numeric_vars ) - if (object$train_set_metadata$num_numeric_vars > 0) { - jsonobj$add_string_vector( - "numeric_vars", - object$train_set_metadata$numeric_vars - ) - } - if (object$train_set_metadata$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector( - "ordered_cat_vars", - object$train_set_metadata$ordered_cat_vars - ) - jsonobj$add_string_list( - "ordered_unique_levels", - object$train_set_metadata$ordered_unique_levels - ) - } - if (object$train_set_metadata$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector( - "unordered_cat_vars", - object$train_set_metadata$unordered_cat_vars - ) - jsonobj$add_string_list( - "unordered_unique_levels", - object$train_set_metadata$unordered_unique_levels - ) - } - - # Add global parameters - jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) - jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) - jsonobj$add_boolean("standardize", object$model_params$standardize) - jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) - jsonobj$add_boolean( - "sample_sigma2_global", - object$model_params$sample_sigma2_global + } + if (object$train_set_metadata$num_ordered_cat_vars > 0) { + jsonobj$add_string_vector( + "ordered_cat_vars", + object$train_set_metadata$ordered_cat_vars ) - jsonobj$add_boolean( - "sample_sigma2_leaf_mu", - object$model_params$sample_sigma2_leaf_mu + jsonobj$add_string_list( + "ordered_unique_levels", + object$train_set_metadata$ordered_unique_levels ) - jsonobj$add_boolean( - "sample_sigma2_leaf_tau", - object$model_params$sample_sigma2_leaf_tau + } + if (object$train_set_metadata$num_unordered_cat_vars > 0) { + jsonobj$add_string_vector( + "unordered_cat_vars", + object$train_set_metadata$unordered_cat_vars ) - jsonobj$add_boolean( - "include_variance_forest", - object$model_params$include_variance_forest + jsonobj$add_string_list( + "unordered_unique_levels", + object$train_set_metadata$unordered_unique_levels ) - jsonobj$add_string( - "propensity_covariate", - object$model_params$propensity_covariate + } + + # Add global parameters + jsonobj$add_scalar("outcome_scale", object$model_params$outcome_scale) + jsonobj$add_scalar("outcome_mean", object$model_params$outcome_mean) + jsonobj$add_boolean("standardize", object$model_params$standardize) + jsonobj$add_scalar("initial_sigma2", object$model_params$initial_sigma2) + jsonobj$add_boolean( + "sample_sigma2_global", + object$model_params$sample_sigma2_global + ) + jsonobj$add_boolean( + "sample_sigma2_leaf_mu", + object$model_params$sample_sigma2_leaf_mu + ) + jsonobj$add_boolean( + "sample_sigma2_leaf_tau", + object$model_params$sample_sigma2_leaf_tau + ) + jsonobj$add_boolean( + "include_variance_forest", + object$model_params$include_variance_forest + ) + jsonobj$add_string( + "propensity_covariate", + object$model_params$propensity_covariate + ) + jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) + jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) + jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) + jsonobj$add_boolean( + "multivariate_treatment", + object$model_params$multivariate_treatment + ) + jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) + jsonobj$add_boolean( + "internal_propensity_model", + object$model_params$internal_propensity_model + ) + jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) + jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) + jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) + jsonobj$add_scalar("num_samples", object$model_params$num_samples) + jsonobj$add_scalar("keep_every", object$model_params$keep_every) + jsonobj$add_scalar("num_chains", object$model_params$num_chains) + jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) + jsonobj$add_boolean( + "probit_outcome_model", + object$model_params$probit_outcome_model + ) + if (object$model_params$sample_sigma2_global) { + jsonobj$add_vector( + "sigma2_global_samples", + object$sigma2_global_samples, + "parameters" ) - jsonobj$add_boolean("has_rfx", object$model_params$has_rfx) - jsonobj$add_boolean("has_rfx_basis", object$model_params$has_rfx_basis) - jsonobj$add_scalar("num_rfx_basis", object$model_params$num_rfx_basis) - jsonobj$add_boolean( - "multivariate_treatment", - object$model_params$multivariate_treatment + } + if (object$model_params$sample_sigma2_leaf_mu) { + jsonobj$add_vector( + "sigma2_leaf_mu_samples", + object$sigma2_leaf_mu_samples, + "parameters" ) - jsonobj$add_boolean("adaptive_coding", object$model_params$adaptive_coding) - jsonobj$add_boolean( - "internal_propensity_model", - object$model_params$internal_propensity_model + } + if (object$model_params$sample_sigma2_leaf_tau) { + jsonobj$add_vector( + "sigma2_leaf_tau_samples", + object$sigma2_leaf_tau_samples, + "parameters" ) - jsonobj$add_scalar("num_gfr", object$model_params$num_gfr) - jsonobj$add_scalar("num_burnin", object$model_params$num_burnin) - jsonobj$add_scalar("num_mcmc", object$model_params$num_mcmc) - jsonobj$add_scalar("num_samples", object$model_params$num_samples) - jsonobj$add_scalar("keep_every", object$model_params$keep_every) - jsonobj$add_scalar("num_chains", object$model_params$num_chains) - jsonobj$add_scalar("num_covariates", object$model_params$num_covariates) - jsonobj$add_boolean( - "probit_outcome_model", - object$model_params$probit_outcome_model + } + if (object$model_params$adaptive_coding) { + jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") + jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters") + } + + # Add random effects (if present) + if (object$model_params$has_rfx) { + jsonobj$add_random_effects(object$rfx_samples) + jsonobj$add_string_vector( + "rfx_unique_group_ids", + object$rfx_unique_group_ids ) - if (object$model_params$sample_sigma2_global) { - jsonobj$add_vector( - "sigma2_global_samples", - object$sigma2_global_samples, - "parameters" - ) - } - if (object$model_params$sample_sigma2_leaf_mu) { - jsonobj$add_vector( - "sigma2_leaf_mu_samples", - object$sigma2_leaf_mu_samples, - "parameters" - ) - } - if (object$model_params$sample_sigma2_leaf_tau) { - jsonobj$add_vector( - "sigma2_leaf_tau_samples", - object$sigma2_leaf_tau_samples, - "parameters" - ) - } - if (object$model_params$adaptive_coding) { - jsonobj$add_vector("b_1_samples", object$b_1_samples, "parameters") - jsonobj$add_vector("b_0_samples", object$b_0_samples, "parameters") - } + } - # Add random effects (if present) - if (object$model_params$has_rfx) { - jsonobj$add_random_effects(object$rfx_samples) - jsonobj$add_string_vector( - "rfx_unique_group_ids", - object$rfx_unique_group_ids - ) - } - - # Add propensity model (if it exists) - if (object$model_params$internal_propensity_model) { - bart_propensity_string <- saveBARTModelToJsonString( - object$bart_propensity_model - ) - jsonobj$add_string("bart_propensity_model", bart_propensity_string) - } - - # Add covariate preprocessor metadata - preprocessor_metadata_string <- savePreprocessorToJsonString( - object$train_set_metadata + # Add propensity model (if it exists) + if (object$model_params$internal_propensity_model) { + bart_propensity_string <- saveBARTModelToJsonString( + object$bart_propensity_model ) - jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) + jsonobj$add_string("bart_propensity_model", bart_propensity_string) + } - return(jsonobj) + # Add covariate preprocessor metadata + preprocessor_metadata_string <- savePreprocessorToJsonString( + object$train_set_metadata + ) + jsonobj$add_string("preprocessor_metadata", preprocessor_metadata_string) + + return(jsonobj) } #' Convert the persistent aspects of a BCF model to (in-memory) JSON and save to a file @@ -3289,11 +3320,11 @@ saveBCFModelToJson <- function(object) { #' saveBCFModelToJsonFile(bcf_model, file.path(tmpjson)) #' unlink(tmpjson) saveBCFModelToJsonFile <- function(object, filename) { - # Convert to Json - jsonobj <- saveBCFModelToJson(object) + # Convert to Json + jsonobj <- saveBCFModelToJson(object) - # Save to file - jsonobj$save_file(filename) + # Save to file + jsonobj$save_file(filename) } #' Convert the persistent aspects of a BCF model to (in-memory) JSON string @@ -3369,11 +3400,11 @@ saveBCFModelToJsonFile <- function(object, filename) { #' treatment_effect_forest_params = tau_params) #' saveBCFModelToJsonString(bcf_model) saveBCFModelToJsonString <- function(object) { - # Convert to Json - jsonobj <- saveBCFModelToJson(object) + # Convert to Json + jsonobj <- saveBCFModelToJson(object) - # Dump to string - return(jsonobj$return_json_string()) + # Dump to string + return(jsonobj$return_json_string()) } #' Convert an (in-memory) JSON representation of a BCF model to a BCF model object @@ -3452,161 +3483,161 @@ saveBCFModelToJsonString <- function(object) { #' bcf_json <- saveBCFModelToJson(bcf_model) #' bcf_model_roundtrip <- createBCFModelFromJson(bcf_json) createBCFModelFromJson <- function(json_object) { - # Initialize the BCF model - output <- list() - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") - output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") - include_variance_forest <- json_object$get_boolean( - "include_variance_forest" - ) - if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerJson( - json_object, - "forest_2" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( - "num_numeric_vars" + # Initialize the BCF model + output <- list() + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerJson(json_object, "forest_0") + output[["forests_tau"]] <- loadForestContainerJson(json_object, "forest_1") + include_variance_forest <- json_object$get_boolean( + "include_variance_forest" + ) + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerJson( + json_object, + "forest_2" ) - train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( - "num_ordered_cat_vars" - ) - train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( - "num_unordered_cat_vars" - ) - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( - "numeric_vars" - ) - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] - ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] - ) - } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") - model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") - model_params[["standardize"]] <- json_object$get_boolean("standardize") - model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") - model_params[["sample_sigma2_global"]] <- json_object$get_boolean( - "sample_sigma2_global" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[["num_ordered_cat_vars"]] <- json_object$get_scalar( + "num_ordered_cat_vars" + ) + train_set_metadata[["num_unordered_cat_vars"]] <- json_object$get_scalar( + "num_unordered_cat_vars" + ) + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" ) - model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean( - "sample_sigma2_leaf_mu" + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) - model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean( - "sample_sigma2_leaf_tau" + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { + train_set_metadata[[ + "unordered_cat_vars" + ]] <- json_object$get_string_vector("unordered_cat_vars") + train_set_metadata[[ + "unordered_unique_levels" + ]] <- json_object$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] ) - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object$get_string( - "propensity_covariate" + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object$get_scalar("outcome_scale") + model_params[["outcome_mean"]] <- json_object$get_scalar("outcome_mean") + model_params[["standardize"]] <- json_object$get_boolean("standardize") + model_params[["initial_sigma2"]] <- json_object$get_scalar("initial_sigma2") + model_params[["sample_sigma2_global"]] <- json_object$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object$get_boolean( + "sample_sigma2_leaf_tau" + ) + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object$get_string( + "propensity_covariate" + ) + model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") + model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") + model_params[["adaptive_coding"]] <- json_object$get_boolean( + "adaptive_coding" + ) + model_params[["multivariate_treatment"]] <- json_object$get_boolean( + "multivariate_treatment" + ) + model_params[["internal_propensity_model"]] <- json_object$get_boolean( + "internal_propensity_model" + ) + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar("num_samples") + model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") + model_params[["probit_outcome_model"]] <- json_object$get_boolean( + "probit_outcome_model" + ) + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - model_params[["has_rfx"]] <- json_object$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object$get_boolean("has_rfx_basis") - model_params[["num_rfx_basis"]] <- json_object$get_scalar("num_rfx_basis") - model_params[["adaptive_coding"]] <- json_object$get_boolean( - "adaptive_coding" + } + if (model_params[["sample_sigma2_leaf_mu"]]) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" ) - model_params[["multivariate_treatment"]] <- json_object$get_boolean( - "multivariate_treatment" + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" ) - model_params[["internal_propensity_model"]] <- json_object$get_boolean( - "internal_propensity_model" + } + if (model_params[["adaptive_coding"]]) { + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" ) - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar("num_samples") - model_params[["num_covariates"]] <- json_object$get_scalar("num_covariates") - model_params[["probit_outcome_model"]] <- json_object$get_boolean( - "probit_outcome_model" + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" ) - output[["model_params"]] <- model_params + } - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } - if (model_params[["sample_sigma2_leaf_mu"]]) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } - if (model_params[["adaptive_coding"]]) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) - } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( - "rfx_unique_group_ids" - ) - output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) - } - - # Unpack propensity model (if it exists) - if (model_params[["internal_propensity_model"]]) { - bart_propensity_string <- json_object$get_string( - "bart_propensity_model" - ) - output[["bart_propensity_model"]] <- createBARTModelFromJsonString( - bart_propensity_string - ) - } + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[["rfx_unique_group_ids"]] <- json_object$get_string_vector( + "rfx_unique_group_ids" + ) + output[["rfx_samples"]] <- loadRandomEffectSamplesJson(json_object, 0) + } - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object$get_string( - "preprocessor_metadata" + # Unpack propensity model (if it exists) + if (model_params[["internal_propensity_model"]]) { + bart_propensity_string <- json_object$get_string( + "bart_propensity_model" ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + output[["bart_propensity_model"]] <- createBARTModelFromJsonString( + bart_propensity_string ) - - class(output) <- "bcfmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bcfmodel" + return(output) } #' Convert a JSON file containing sample information on a trained BCF model @@ -3687,13 +3718,13 @@ createBCFModelFromJson <- function(json_object) { #' bcf_model_roundtrip <- createBCFModelFromJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createBCFModelFromJsonFile <- function(json_filename) { - # Load a `CppJson` object from file - bcf_json <- createCppJsonFile(json_filename) + # Load a `CppJson` object from file + bcf_json <- createCppJsonFile(json_filename) - # Create and return the BCF object - bcf_object <- createBCFModelFromJson(bcf_json) + # Create and return the BCF object + bcf_object <- createBCFModelFromJson(bcf_json) - return(bcf_object) + return(bcf_object) } #' Convert a JSON string containing sample information on a trained BCF model @@ -3768,13 +3799,13 @@ createBCFModelFromJsonFile <- function(json_filename) { #' bcf_json <- saveBCFModelToJsonString(bcf_model) #' bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json) createBCFModelFromJsonString <- function(json_string) { - # Load a `CppJson` object from string - bcf_json <- createCppJsonString(json_string) + # Load a `CppJson` object from string + bcf_json <- createCppJsonString(json_string) - # Create and return the BCF object - bcf_object <- createBCFModelFromJson(bcf_json) + # Create and return the BCF object + bcf_object <- createBCFModelFromJson(bcf_json) - return(bcf_object) + return(bcf_object) } #' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object @@ -3849,271 +3880,271 @@ createBCFModelFromJsonString <- function(json_string) { #' bcf_json_list <- list(saveBCFModelToJson(bcf_model)) #' bcf_model_roundtrip <- createBCFModelFromCombinedJson(bcf_json_list) createBCFModelFromCombinedJson <- function(json_object_list) { - # Initialize the BCF model - output <- list() - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - output[["forests_tau"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" + # Initialize the BCF model + output <- list() + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) + output[["forests_tau"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_2" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" - ) - if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_2" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_tau" + ) + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object_default$get_string( + "propensity_covariate" + ) + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["adaptive_coding"]] <- json_object_default$get_boolean( + "adaptive_coding" + ) + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["initial_sigma2"]] <- json_object_default$get_scalar( - "initial_sigma2" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_mu" - ) - model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_tau" - ) - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object_default$get_string( - "propensity_covariate" - ) - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["adaptive_coding"]] <- json_object_default$get_boolean( - "adaptive_coding" - ) - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( - "multivariate_treatment" - ) - model_params[[ - "internal_propensity_model" - ]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf_mu"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_mu"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_mu_samples"]] <- c( - output[["sigma2_leaf_mu_samples"]], - json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_mu_samples"]] <- c( + output[["sigma2_leaf_mu_samples"]], + json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + ) + } } - if (model_params[["adaptive_coding"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) - } else { - output[["b_1_samples"]] <- c( - output[["b_1_samples"]], - json_object$get_vector("b_1_samples", "parameters") - ) - output[["b_0_samples"]] <- c( - output[["b_0_samples"]], - json_object$get_vector("b_0_samples", "parameters") - ) - } - } + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) + } } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) ) + } } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + } + if (model_params[["adaptive_coding"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" + ) + } else { + output[["b_1_samples"]] <- c( + output[["b_1_samples"]], + json_object$get_vector("b_1_samples", "parameters") + ) + output[["b_0_samples"]] <- c( + output[["b_0_samples"]], + json_object$get_vector("b_0_samples", "parameters") + ) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 ) - - class(output) <- "bcfmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bcfmodel" + return(output) } #' Convert a list of (in-memory) JSON strings that represent BCF models to a single combined BCF model object @@ -4188,284 +4219,284 @@ createBCFModelFromCombinedJson <- function(json_object_list) { #' bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) #' bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(bcf_json_string_list) createBCFModelFromCombinedJsonString <- function(json_string_list) { - # Initialize the BCF model - output <- list() - - # Convert JSON strings - json_object_list <- list() - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - json_object_list[[i]] <- createCppJsonString(json_string) - # Add runtime check for separately serialized propensity models - # We don't support merging BCF models with independent propensity models - # this way at the moment - if (json_object_list[[i]]$get_boolean("internal_propensity_model")) { - stop( - "Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling." - ) - } - } - - # For scalar / preprocessing details which aren't sample-dependent, - # defer to the first json - json_object_default <- json_object_list[[1]] - - # Unpack the forests - output[["forests_mu"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_0" - ) - output[["forests_tau"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_1" + # Initialize the BCF model + output <- list() + + # Convert JSON strings + json_object_list <- list() + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + json_object_list[[i]] <- createCppJsonString(json_string) + # Add runtime check for separately serialized propensity models + # We don't support merging BCF models with independent propensity models + # this way at the moment + if (json_object_list[[i]]$get_boolean("internal_propensity_model")) { + stop( + "Combining separate BCF models with cached internal propensity models is currently unsupported. To make this work, please first train a propensity model and then pass the propensities as data to the separate BCF models before sampling." + ) + } + } + + # For scalar / preprocessing details which aren't sample-dependent, + # defer to the first json + json_object_default <- json_object_list[[1]] + + # Unpack the forests + output[["forests_mu"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_0" + ) + output[["forests_tau"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_1" + ) + include_variance_forest <- json_object_default$get_boolean( + "include_variance_forest" + ) + if (include_variance_forest) { + output[["forests_variance"]] <- loadForestContainerCombinedJson( + json_object_list, + "forest_2" ) - include_variance_forest <- json_object_default$get_boolean( - "include_variance_forest" - ) - if (include_variance_forest) { - output[["forests_variance"]] <- loadForestContainerCombinedJson( - json_object_list, - "forest_2" - ) - } - - # Unpack metadata - train_set_metadata = list() - train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( - "num_numeric_vars" + } + + # Unpack metadata + train_set_metadata = list() + train_set_metadata[["num_numeric_vars"]] <- json_object_default$get_scalar( + "num_numeric_vars" + ) + train_set_metadata[[ + "num_ordered_cat_vars" + ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + train_set_metadata[[ + "num_unordered_cat_vars" + ]] <- json_object_default$get_scalar("num_unordered_cat_vars") + if (train_set_metadata[["num_numeric_vars"]] > 0) { + train_set_metadata[[ + "numeric_vars" + ]] <- json_object_default$get_string_vector("numeric_vars") + } + if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { + train_set_metadata[[ + "ordered_cat_vars" + ]] <- json_object_default$get_string_vector("ordered_cat_vars") + train_set_metadata[[ + "ordered_unique_levels" + ]] <- json_object_default$get_string_list( + "ordered_unique_levels", + train_set_metadata[["ordered_cat_vars"]] ) + } + if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { train_set_metadata[[ - "num_ordered_cat_vars" - ]] <- json_object_default$get_scalar("num_ordered_cat_vars") + "unordered_cat_vars" + ]] <- json_object_default$get_string_vector("unordered_cat_vars") train_set_metadata[[ - "num_unordered_cat_vars" - ]] <- json_object_default$get_scalar("num_unordered_cat_vars") - if (train_set_metadata[["num_numeric_vars"]] > 0) { - train_set_metadata[[ - "numeric_vars" - ]] <- json_object_default$get_string_vector("numeric_vars") - } - if (train_set_metadata[["num_ordered_cat_vars"]] > 0) { - train_set_metadata[[ - "ordered_cat_vars" - ]] <- json_object_default$get_string_vector("ordered_cat_vars") - train_set_metadata[[ - "ordered_unique_levels" - ]] <- json_object_default$get_string_list( - "ordered_unique_levels", - train_set_metadata[["ordered_cat_vars"]] + "unordered_unique_levels" + ]] <- json_object_default$get_string_list( + "unordered_unique_levels", + train_set_metadata[["unordered_cat_vars"]] + ) + } + output[["train_set_metadata"]] <- train_set_metadata + + # Unpack model params + model_params = list() + model_params[["outcome_scale"]] <- json_object_default$get_scalar( + "outcome_scale" + ) + model_params[["outcome_mean"]] <- json_object_default$get_scalar( + "outcome_mean" + ) + model_params[["standardize"]] <- json_object_default$get_boolean( + "standardize" + ) + model_params[["initial_sigma2"]] <- json_object_default$get_scalar( + "initial_sigma2" + ) + model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( + "sample_sigma2_global" + ) + model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_mu" + ) + model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( + "sample_sigma2_leaf_tau" + ) + model_params[["include_variance_forest"]] <- include_variance_forest + model_params[["propensity_covariate"]] <- json_object_default$get_string( + "propensity_covariate" + ) + model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") + model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( + "has_rfx_basis" + ) + model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( + "num_rfx_basis" + ) + model_params[["num_covariates"]] <- json_object_default$get_scalar( + "num_covariates" + ) + model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") + model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") + model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( + "multivariate_treatment" + ) + model_params[["adaptive_coding"]] <- json_object_default$get_boolean( + "adaptive_coding" + ) + model_params[[ + "internal_propensity_model" + ]] <- json_object_default$get_boolean("internal_propensity_model") + model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( + "probit_outcome_model" + ) + + # Combine values that are sample-specific + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- json_object$get_scalar( + "num_samples" + ) + } else { + prev_json <- json_object_list[[i - 1]] + model_params[["num_gfr"]] <- model_params[["num_gfr"]] + + json_object$get_scalar("num_gfr") + model_params[["num_burnin"]] <- model_params[["num_burnin"]] + + json_object$get_scalar("num_burnin") + model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + + json_object$get_scalar("num_mcmc") + model_params[["num_samples"]] <- model_params[["num_samples"]] + + json_object$get_scalar("num_samples") + } + } + output[["model_params"]] <- model_params + + # Unpack sampled parameters + if (model_params[["sample_sigma2_global"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_global_samples"]] <- json_object$get_vector( + "sigma2_global_samples", + "parameters" ) - } - if (train_set_metadata[["num_unordered_cat_vars"]] > 0) { - train_set_metadata[[ - "unordered_cat_vars" - ]] <- json_object_default$get_string_vector("unordered_cat_vars") - train_set_metadata[[ - "unordered_unique_levels" - ]] <- json_object_default$get_string_list( - "unordered_unique_levels", - train_set_metadata[["unordered_cat_vars"]] + } else { + output[["sigma2_global_samples"]] <- c( + output[["sigma2_global_samples"]], + json_object$get_vector( + "sigma2_global_samples", + "parameters" + ) ) + } } - output[["train_set_metadata"]] <- train_set_metadata - - # Unpack model params - model_params = list() - model_params[["outcome_scale"]] <- json_object_default$get_scalar( - "outcome_scale" - ) - model_params[["outcome_mean"]] <- json_object_default$get_scalar( - "outcome_mean" - ) - model_params[["standardize"]] <- json_object_default$get_boolean( - "standardize" - ) - model_params[["initial_sigma2"]] <- json_object_default$get_scalar( - "initial_sigma2" - ) - model_params[["sample_sigma2_global"]] <- json_object_default$get_boolean( - "sample_sigma2_global" - ) - model_params[["sample_sigma2_leaf_mu"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_mu" - ) - model_params[["sample_sigma2_leaf_tau"]] <- json_object_default$get_boolean( - "sample_sigma2_leaf_tau" - ) - model_params[["include_variance_forest"]] <- include_variance_forest - model_params[["propensity_covariate"]] <- json_object_default$get_string( - "propensity_covariate" - ) - model_params[["has_rfx"]] <- json_object_default$get_boolean("has_rfx") - model_params[["has_rfx_basis"]] <- json_object_default$get_boolean( - "has_rfx_basis" - ) - model_params[["num_rfx_basis"]] <- json_object_default$get_scalar( - "num_rfx_basis" - ) - model_params[["num_covariates"]] <- json_object_default$get_scalar( - "num_covariates" - ) - model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") - model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") - model_params[["multivariate_treatment"]] <- json_object_default$get_boolean( - "multivariate_treatment" - ) - model_params[["adaptive_coding"]] <- json_object_default$get_boolean( - "adaptive_coding" - ) - model_params[[ - "internal_propensity_model" - ]] <- json_object_default$get_boolean("internal_propensity_model") - model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( - "probit_outcome_model" - ) - - # Combine values that are sample-specific + } + if (model_params[["sample_sigma2_leaf_mu"]]) { for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - model_params[["num_gfr"]] <- json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- json_object$get_scalar( - "num_samples" - ) - } else { - prev_json <- json_object_list[[i - 1]] - model_params[["num_gfr"]] <- model_params[["num_gfr"]] + - json_object$get_scalar("num_gfr") - model_params[["num_burnin"]] <- model_params[["num_burnin"]] + - json_object$get_scalar("num_burnin") - model_params[["num_mcmc"]] <- model_params[["num_mcmc"]] + - json_object$get_scalar("num_mcmc") - model_params[["num_samples"]] <- model_params[["num_samples"]] + - json_object$get_scalar("num_samples") - } - } - output[["model_params"]] <- model_params - - # Unpack sampled parameters - if (model_params[["sample_sigma2_global"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_global_samples"]] <- json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - } else { - output[["sigma2_global_samples"]] <- c( - output[["sigma2_global_samples"]], - json_object$get_vector( - "sigma2_global_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_mu"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_mu_samples"]] <- c( - output[["sigma2_leaf_mu_samples"]], - json_object$get_vector( - "sigma2_leaf_mu_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } - } - if (model_params[["sample_sigma2_leaf_tau"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - } else { - output[["sigma2_leaf_tau_samples"]] <- c( - output[["sigma2_leaf_tau_samples"]], - json_object$get_vector( - "sigma2_leaf_tau_samples", - "parameters" - ) - ) - } - } + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_mu_samples"]] <- json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_mu_samples"]] <- c( + output[["sigma2_leaf_mu_samples"]], + json_object$get_vector( + "sigma2_leaf_mu_samples", + "parameters" + ) + ) + } } - if (model_params[["adaptive_coding"]]) { - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output[["b_1_samples"]] <- json_object$get_vector( - "b_1_samples", - "parameters" - ) - output[["b_0_samples"]] <- json_object$get_vector( - "b_0_samples", - "parameters" - ) - } else { - output[["b_1_samples"]] <- c( - output[["b_1_samples"]], - json_object$get_vector("b_1_samples", "parameters") - ) - output[["b_0_samples"]] <- c( - output[["b_0_samples"]], - json_object$get_vector("b_0_samples", "parameters") - ) - } - } + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + ) + } } - - # Unpack random effects - if (model_params[["has_rfx"]]) { - output[[ - "rfx_unique_group_ids" - ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") - output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( - json_object_list, - 0 + } + if (model_params[["sample_sigma2_leaf_tau"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["sigma2_leaf_tau_samples"]] <- json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) + } else { + output[["sigma2_leaf_tau_samples"]] <- c( + output[["sigma2_leaf_tau_samples"]], + json_object$get_vector( + "sigma2_leaf_tau_samples", + "parameters" + ) ) + } } - - # Unpack covariate preprocessor - preprocessor_metadata_string <- json_object_default$get_string( - "preprocessor_metadata" - ) - output[["train_set_metadata"]] <- createPreprocessorFromJsonString( - preprocessor_metadata_string + } + if (model_params[["adaptive_coding"]]) { + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output[["b_1_samples"]] <- json_object$get_vector( + "b_1_samples", + "parameters" + ) + output[["b_0_samples"]] <- json_object$get_vector( + "b_0_samples", + "parameters" + ) + } else { + output[["b_1_samples"]] <- c( + output[["b_1_samples"]], + json_object$get_vector("b_1_samples", "parameters") + ) + output[["b_0_samples"]] <- c( + output[["b_0_samples"]], + json_object$get_vector("b_0_samples", "parameters") + ) + } + } + } + + # Unpack random effects + if (model_params[["has_rfx"]]) { + output[[ + "rfx_unique_group_ids" + ]] <- json_object_default$get_string_vector("rfx_unique_group_ids") + output[["rfx_samples"]] <- loadRandomEffectSamplesCombinedJson( + json_object_list, + 0 ) - - class(output) <- "bcfmodel" - return(output) + } + + # Unpack covariate preprocessor + preprocessor_metadata_string <- json_object_default$get_string( + "preprocessor_metadata" + ) + output[["train_set_metadata"]] <- createPreprocessorFromJsonString( + preprocessor_metadata_string + ) + + class(output) <- "bcfmodel" + return(output) } diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R new file mode 100644 index 00000000..2b0ae92a --- /dev/null +++ b/R/posterior_transformation.R @@ -0,0 +1,874 @@ +#' Sample from the posterior predictive distribution for outcomes modeled by BCF +#' +#' @param model_object A fitted BCF model object of class `bcfmodel`. +#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). +#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). +#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities. +#' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects. +#' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects. +#' @param num_draws (Optional) The number of samples to draw from the likelihood, for each draw of the posterior, in computing intervals. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws). +#' +#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +#' +#' @export +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(rnorm(n * p), nrow = n, ncol = p) +#' y <- 2 * X[,1] + rnorm(n) +#' bart_model <- bart(y_train = y, X_train = X) +#' ppd_samples <- sample_bart_posterior_predictive( +#' model_object = bart_model, covariates = X +#' ) +sample_bcf_posterior_predictive <- function( + model_object, + covariates = NULL, + treatment = NULL, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + num_draws = NULL +) { + # Check the provided model object and requested term + check_model_is_valid(model_object) + + # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) + is_probit <- model_object$model_params$probit_outcome_model + + # Check that all the necessary inputs were provided for interval computation + needs_covariates <- TRUE + if (needs_covariates) { + if (is.null(covariates)) { + stop( + "'covariates' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(covariates) && !is.data.frame(covariates)) { + stop("'covariates' must be a matrix or data frame") + } + } + needs_treatment <- needs_covariates + if (needs_treatment) { + if (is.null(treatment)) { + stop( + "'treatment' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(treatment) && !is.numeric(treatment)) { + stop("'treatment' must be a numeric vector or matrix") + } + if (is.matrix(treatment)) { + if (nrow(treatment) != nrow(covariates)) { + stop("'treatment' must have the same number of rows as 'covariates'") + } + } else { + if (length(treatment) != nrow(covariates)) { + stop( + "'treatment' must have the same number of elements as 'covariates'" + ) + } + } + } + uses_propensity <- model_object$model_params$propensity_covariate != "none" + internal_propensity_model <- model_object$model_params$internal_propensity_model + needs_propensity <- (needs_covariates && + uses_propensity && + (!internal_propensity_model)) + if (needs_propensity) { + if (is.null(propensity)) { + stop( + "'propensity' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(propensity) && !is.numeric(propensity)) { + stop("'propensity' must be a numeric vector or matrix") + } + if (is.matrix(propensity)) { + if (nrow(propensity) != nrow(covariates)) { + stop("'propensity' must have the same number of rows as 'covariates'") + } + } else { + if (length(propensity) != nrow(covariates)) { + stop( + "'propensity' must have the same number of elements as 'covariates'" + ) + } + } + } + needs_rfx_data <- model_object$model_params$has_rfx + if (needs_rfx_data) { + if (is.null(rfx_group_ids)) { + stop( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + } + if (length(rfx_group_ids) != nrow(covariates)) { + stop( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + } + if (is.null(rfx_basis)) { + stop( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(covariates)) { + stop("'rfx_basis' must have the same number of rows as 'covariates'") + } + } + + # Compute posterior predictive samples + bcf_preds <- predict( + model_object, + X = covariates, + Z = treatment, + propensity = propensity, + rfx_group_ids = rfx_group_ids, + rfx_basis = rfx_basis, + type = "posterior", + terms = c("all") + ) + has_rfx <- model_object$model_params$has_rfx + has_variance_forest <- model_object$model_params$include_variance_forest + samples_global_variance <- model_object$model_params$sample_sigma2_global + num_posterior_draws <- model_object$model_params$num_samples + num_observations <- nrow(covariates) + ppd_mean <- bcf_preds$y_hat + if (has_variance_forest) { + ppd_variance <- bcf_preds$variance_forest_predictions + } else { + if (samples_global_variance) { + ppd_variance <- matrix( + rep( + model_object$sigma2_global_samples, + each = num_observations + ), + nrow = num_observations + ) + } else { + ppd_variance <- model_object$model_params$initial_sigma2 + } + } + if (is.null(num_draws)) { + ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier( + num_posterior_draws, + num_observations + ) + } else { + ppd_draw_multiplier <- num_draws + } + num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations + ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance)) + if (ppd_draw_multiplier > 1) { + ppd_array <- array( + ppd_vector, + dim = c(num_observations, num_posterior_draws, ppd_draw_multiplier) + ) + } else { + ppd_array <- array( + ppd_vector, + dim = c(num_observations, num_posterior_draws) + ) + } + + if (is_probit) { + ppd_array <- (ppd_array > 0.0) * 1 + } + + return(ppd_array) +} + +#' Sample from the posterior predictive distribution for outcomes modeled by BART +#' +#' @param model_object A fitted BART model object of class `bartmodel`. +#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). +#' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. +#' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects. +#' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects. +#' @param num_draws The number of posterior predictive samples to draw in computing intervals. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). +#' +#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +#' +#' @export +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(rnorm(n * p), nrow = n, ncol = p) +#' y <- 2 * X[,1] + rnorm(n) +#' bart_model <- bart(y_train = y, X_train = X) +#' ppd_samples <- sample_bart_posterior_predictive( +#' model_object = bart_model, covariates = X +#' ) +sample_bart_posterior_predictive <- function( + model_object, + covariates = NULL, + basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + num_draws = NULL +) { + # Check the provided model object and requested term + check_model_is_valid(model_object) + + # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) + is_probit <- model_object$model_params$probit_outcome_model + + # Check that all the necessary inputs were provided for interval computation + needs_covariates <- model_object$model_params$include_mean_forest + if (needs_covariates) { + if (is.null(covariates)) { + stop( + "'covariates' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(covariates) && !is.data.frame(covariates)) { + stop("'covariates' must be a matrix or data frame") + } + } + needs_basis <- needs_covariates && model_object$model_params$has_basis + if (needs_basis) { + if (is.null(basis)) { + stop( + "'basis' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(basis)) { + stop("'basis' must be a matrix") + } + if (is.matrix(basis)) { + if (nrow(basis) != nrow(covariates)) { + stop("'basis' must have the same number of rows as 'covariates'") + } + } else { + if (length(basis) != nrow(covariates)) { + stop("'basis' must have the same number of elements as 'covariates'") + } + } + } + needs_rfx_data <- model_object$model_params$has_rfx + if (needs_rfx_data) { + if (is.null(rfx_group_ids)) { + stop( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + } + if (length(rfx_group_ids) != nrow(covariates)) { + stop( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + } + if (is.null(rfx_basis)) { + stop( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(covariates)) { + stop("'rfx_basis' must have the same number of rows as 'covariates'") + } + } + + # Compute posterior predictive samples + bart_preds <- predict( + model_object, + covariates = covariates, + leaf_basis = basis, + rfx_group_ids = rfx_group_ids, + rfx_basis = rfx_basis, + type = "posterior", + terms = c("all") + ) + has_mean_term <- (model_object$model_params$include_mean_forest || + model_object$model_params$has_rfx) + has_variance_forest <- model_object$model_params$include_variance_forest + samples_global_variance <- model_object$model_params$sample_sigma2_global + num_posterior_draws <- model_object$model_params$num_samples + num_observations <- nrow(covariates) + if (has_mean_term) { + ppd_mean <- bart_preds$y_hat + } else { + ppd_mean <- 0 + } + if (has_variance_forest) { + ppd_variance <- bart_preds$variance_forest_predictions + } else { + if (samples_global_variance) { + ppd_variance <- matrix( + rep( + model_object$sigma2_global_samples, + each = num_observations + ), + nrow = num_observations + ) + } else { + ppd_variance <- model_object$model_params$sigma2_init + } + } + if (is.null(num_draws)) { + ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier( + num_posterior_draws, + num_observations + ) + } else { + ppd_draw_multiplier <- num_draws + } + num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations + ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance)) + if (ppd_draw_multiplier > 1) { + ppd_array <- array( + ppd_vector, + dim = c(num_observations, num_posterior_draws, ppd_draw_multiplier) + ) + } else { + ppd_array <- array( + ppd_vector, + dim = c(num_observations, num_posterior_draws) + ) + } + + if (is_probit) { + ppd_array <- (ppd_array > 0.0) * 1 + } + + return(ppd_array) +} + +posterior_predictive_heuristic_multiplier <- function( + num_samples, + num_observations +) { + if (num_samples >= 1000) { + return(1) + } else { + return(ceiling(1000 / num_samples)) + } +} + +#' Compute posterior credible intervals for BCF model terms +#' +#' This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions. +#' @param model_object A fitted BCF model object of class `bcfmodel`. +#' @param term A character string specifying the model term for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. +#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". +#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). +#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). +#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities. +#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. +#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. +#' +#' @returns A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. +#' +#' @export +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(rnorm(n * p), nrow = n, ncol = p) +#' pi_X <- pnorm(0.5 * X[,1]) +#' Z <- rbinom(n, 1, pi_X) +#' mu_X <- X[,1] +#' tau_X <- 0.25 * X[,2] +#' y <- mu_X + tau_X * Z + rnorm(n) +#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, +#' propensity_train = pi_X) +#' intervals <- compute_bcf_posterior_interval( +#' model_object = bcf_model, +#' terms = c("prognostic_function", "cate"), +#' covariates = X, +#' treatment = Z, +#' propensity = pi_X, +#' level = 0.90 +#' ) +compute_bcf_posterior_interval <- function( + model_object, + terms, + level = 0.95, + scale = "linear", + covariates = NULL, + treatment = NULL, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL +) { + # Check the provided model object and requested term + check_model_is_valid(model_object) + for (term in terms) { + check_model_has_term(model_object, term) + } + + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- model_object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + + # Check that all the necessary inputs were provided for interval computation + needs_covariates_intermediate <- ((("y_hat" %in% terms) || + ("all" %in% terms))) + needs_covariates <- (("prognostic_function" %in% terms) || + ("cate" %in% terms) || + ("variance_forest" %in% terms) || + (needs_covariates_intermediate)) + if (needs_covariates) { + if (is.null(covariates)) { + stop( + "'covariates' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(covariates) && !is.data.frame(covariates)) { + stop("'covariates' must be a matrix or data frame") + } + } + needs_treatment <- needs_covariates + if (needs_treatment) { + if (is.null(treatment)) { + stop( + "'treatment' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(treatment) && !is.numeric(treatment)) { + stop("'treatment' must be a numeric vector or matrix") + } + if (is.matrix(treatment)) { + if (nrow(treatment) != nrow(covariates)) { + stop("'treatment' must have the same number of rows as 'covariates'") + } + } else { + if (length(treatment) != nrow(covariates)) { + stop( + "'treatment' must have the same number of elements as 'covariates'" + ) + } + } + } + uses_propensity <- model_object$model_params$propensity_covariate != "none" + internal_propensity_model <- model_object$model_params$internal_propensity_model + needs_propensity <- (needs_covariates && + uses_propensity && + (!internal_propensity_model)) + if (needs_propensity) { + if (is.null(propensity)) { + stop( + "'propensity' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(propensity) && !is.numeric(propensity)) { + stop("'propensity' must be a numeric vector or matrix") + } + if (is.matrix(propensity)) { + if (nrow(propensity) != nrow(covariates)) { + stop("'propensity' must have the same number of rows as 'covariates'") + } + } else { + if (length(propensity) != nrow(covariates)) { + stop( + "'propensity' must have the same number of elements as 'covariates'" + ) + } + } + } + needs_rfx_data_intermediate <- ((("y_hat" %in% terms) || + ("all" %in% terms)) && + model_object$model_params$has_rfx) + needs_rfx_data <- (("rfx" %in% terms) || + (needs_rfx_data_intermediate)) + if (needs_rfx_data) { + if (is.null(rfx_group_ids)) { + stop( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + } + if (length(rfx_group_ids) != nrow(covariates)) { + stop( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + } + if (is.null(rfx_basis)) { + stop( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(covariates)) { + stop("'rfx_basis' must have the same number of rows as 'covariates'") + } + } + + # Compute posterior matrices for the requested model terms + predictions <- predict( + model_object, + X = covariates, + Z = treatment, + propensity = propensity, + rfx_group_ids = rfx_group_ids, + rfx_basis = rfx_basis, + type = "posterior", + terms = terms, + scale = scale + ) + has_multiple_terms <- ifelse(is.list(predictions), TRUE, FALSE) + + # Compute the interval + if (has_multiple_terms) { + result <- list() + for (term_name in names(predictions)) { + result[[term_name]] <- summarize_interval( + predictions[[term_name]], + sample_dim = 2, + level = level + ) + } + return(result) + } else { + return(summarize_interval( + predictions, + sample_dim = 2, + level = level + )) + } +} + +#' Compute posterior credible intervals for BART model terms +#' +#' This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. +#' @param model_object A fitted BART or BCF model object of class `bartmodel`. +#' @param term A character string specifying the model term for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. +#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". +#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). +#' @param basis An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. +#' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. +#' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. +#' +#' @returns A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. +#' +#' @export +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(rnorm(n * p), nrow = n, ncol = p) +#' y <- 2 * X[,1] + rnorm(n) +#' bart_model <- bart(y_train = y, X_train = X) +#' intervals <- compute_bart_posterior_interval( +#' model_object = bart_model, +#' terms = c("mean_forest", "y_hat"), +#' covariates = X, +#' level = 0.90 +#' ) +#' @export +compute_bart_posterior_interval <- function( + model_object, + terms, + level = 0.95, + scale = "linear", + covariates = NULL, + basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL +) { + # Check the provided model object and requested term + check_model_is_valid(model_object) + for (term in terms) { + check_model_has_term(model_object, term) + } + + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- model_object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + + # Check that all the necessary inputs were provided for interval computation + needs_covariates_intermediate <- ((("y_hat" %in% terms) || + ("all" %in% terms)) && + model_object$model_params$include_mean_forest) + needs_covariates <- (("mean_forest" %in% terms) || + ("variance_forest" %in% terms) || + (needs_covariates_intermediate)) + if (needs_covariates) { + if (is.null(covariates)) { + stop( + "'covariates' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(covariates) && !is.data.frame(covariates)) { + stop("'covariates' must be a matrix or data frame") + } + } + needs_basis <- needs_covariates && model_object$model_params$has_basis + if (needs_basis) { + if (is.null(basis)) { + stop( + "'basis' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(basis)) { + stop("'basis' must be a matrix") + } + if (is.matrix(basis)) { + if (nrow(basis) != nrow(covariates)) { + stop("'basis' must have the same number of rows as 'covariates'") + } + } else { + if (length(basis) != nrow(covariates)) { + stop("'basis' must have the same number of elements as 'covariates'") + } + } + } + needs_rfx_data_intermediate <- ((("y_hat" %in% terms) || + ("all" %in% terms)) && + model_object$model_params$has_rfx) + needs_rfx_data <- (("rfx" %in% terms) || + (needs_rfx_data_intermediate)) + if (needs_rfx_data) { + if (is.null(rfx_group_ids)) { + stop( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + } + if (length(rfx_group_ids) != nrow(covariates)) { + stop( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + } + if (is.null(rfx_basis)) { + stop( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + } + if (!is.matrix(rfx_basis)) { + stop("'rfx_basis' must be a matrix") + } + if (nrow(rfx_basis) != nrow(covariates)) { + stop("'rfx_basis' must have the same number of rows as 'covariates'") + } + } + + # Compute posterior matrices for the requested model terms + predictions <- predict( + model_object, + covariates = covariates, + leaf_basis = basis, + rfx_group_ids = rfx_group_ids, + rfx_basis = rfx_basis, + type = "posterior", + terms = terms, + scale = scale + ) + has_multiple_terms <- ifelse(is.list(predictions), TRUE, FALSE) + + # Compute the interval + if (has_multiple_terms) { + result <- list() + for (term_name in names(predictions)) { + result[[term_name]] <- summarize_interval( + predictions[[term_name]], + sample_dim = 2, + level = level + ) + } + return(result) + } else { + return(summarize_interval( + predictions, + sample_dim = 2, + level = level + )) + } +} + +transform_power <- function(array, exponent) { + return(compute_transformation(array, fn = function(x) x**exponent)) +} + +transform_multiply <- function(array, multiple) { + return(compute_transformation(array, fn = function(x) x * multiple)) +} + +transform_add <- function(array, addend) { + return(compute_transformation(array, fn = function(x) x + addend)) +} + +transform_exp <- function(array) { + return(compute_transformation(array, fn = exp)) +} + +transform_log <- function(array) { + return(compute_transformation(array, fn = log)) +} + +transform_pnorm <- function(array) { + return(compute_transformation(array, fn = pnorm)) +} + +compute_transformation <- function(array, fn = NULL) { + # Check that the array is numeric and at least 1 dimensional + stopifnot(is.numeric(array) && length(dim(array)) >= 1) + + # Calculate the transformation + return(fn(array)) +} + +summarize_mean <- function(array, sample_dim = 2) { + return(compute_summary(array, sample_dim, mean)) +} + +summarize_median <- function(array, sample_dim = 2) { + return(compute_summary(array, sample_dim, median)) +} + +compute_summary <- function(array, sample_dim = 2, fn = mean) { + # Check that the array is numeric and at least 2 dimensional + stopifnot(is.numeric(array) && length(dim(array)) >= 2) + + # Determine the dimensions over which reduction is computed + apply_dim <- setdiff(1:length(dim(array)), sample_dim) + + # Compute the reduction + result <- apply(array, apply_dim, function(x) { + fn(x) + }) + + return(result) +} + +summarize_interval <- function(array, sample_dim = 2, level = 0.95) { + # Check that the array is numeric and at least 2 dimensional + stopifnot(is.numeric(array) && length(dim(array)) >= 2) + + # Compute lower and upper quantiles based on the requested interval + quantile_lb <- (1 - level) / 2 + quantile_ub <- 1 - quantile_lb + + # Determine the dimensions over which interval is computed + apply_dim <- setdiff(1:length(dim(array)), sample_dim) + + # Calculate the interval + result_lb <- apply(array, apply_dim, function(x) { + quantile(x, probs = quantile_lb, names = FALSE) + }) + result_ub <- apply(array, apply_dim, function(x) { + quantile(x, probs = quantile_ub, names = FALSE) + }) + + return(list(lower = result_lb, upper = result_ub)) +} + +check_model_is_valid <- function(model_object) { + if ( + (!inherits(model_object, "bartmodel")) && + (!inherits(model_object, "bcfmodel")) + ) { + stop("'model_object' must be a bartmodel or bcfmodel") + } +} + +check_model_has_term <- function(model_object, term) { + # Parse inputs + if (!is.character(term) || length(term) != 1) { + stop("'term' must be a single character string") + } + if ( + (!inherits(model_object, "bartmodel")) && + (!inherits(model_object, "bcfmodel")) + ) { + stop("'model_object' must be a bartmodel or bcfmodel") + } + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") + + # Check if the term was fitted as part of the provided model + if (model_type == "bart") { + validate_bart_term(term) + return(bart_model_has_term(model_object, term)) + } else { + validate_bcf_term(term) + return(bcf_model_has_term(model_object, term)) + } +} + +bart_model_has_term <- function(model_object, term) { + if (term == "mean_forest") { + return(model_object$model_params$include_mean_forest) + } else if (term == "variance_forest") { + return(model_object$model_params$include_variance_forest) + } else if (term == "rfx") { + return(model_object$model_params$has_rfx) + } else if (term == "y_hat") { + return( + model_object$model_params$include_mean_forest || + model_object$model_params$has_rfx + ) + } else if (term == "all") { + return(TRUE) + } else { + return(FALSE) + } +} + +bcf_model_has_term <- function(model_object, term) { + if (term == "prognostic_function") { + return(TRUE) + } else if (term == "cate") { + return(TRUE) + } else if (term == "variance_forest") { + return(model_object$model_params$include_variance_forest) + } else if (term == "rfx") { + return(model_object$model_params$has_rfx) + } else if (term == "y_hat") { + return(TRUE) + } else if (term == "all") { + return(TRUE) + } else { + return(FALSE) + } +} + +validate_bart_term <- function(term) { + model_terms <- c("mean_forest", "variance_forest", "rfx", "y_hat", "all") + if (!(term %in% model_terms)) { + stop( + "'term' must be one of 'mean_forest', 'variance_forest', 'rfx', 'y_hat', or 'all' for bartmodel objects" + ) + } +} + +validate_bcf_term <- function(term) { + model_terms <- c( + "prognostic_function", + "cate", + "variance_forest", + "rfx", + "y_hat", + "all" + ) + if (!(term %in% model_terms)) { + stop( + "'term' must be one of 'prognostic_function', 'cate', 'variance_forest', 'rfx', 'y_hat', or 'all' for bcfmodel objects" + ) + } +} diff --git a/man/compute_bart_posterior_interval.Rd b/man/compute_bart_posterior_interval.Rd new file mode 100644 index 00000000..66ddceaf --- /dev/null +++ b/man/compute_bart_posterior_interval.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_bart_posterior_interval} +\alias{compute_bart_posterior_interval} +\title{Compute posterior credible intervals for BART model terms} +\usage{ +compute_bart_posterior_interval( + model_object, + terms, + level = 0.95, + scale = "linear", + covariates = NULL, + basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL +) +} +\arguments{ +\item{model_object}{A fitted BART or BCF model object of class \code{bartmodel}.} + +\item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + +\item{covariates}{A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).} + +\item{basis}{An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} + +\item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} + +\item{rfx_basis}{An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.} + +\item{term}{A character string specifying the model term for which to compute intervals. Options for BART models are \code{"mean_forest"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} +} +\value{ +A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. +} +\description{ +This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +y <- 2 * X[,1] + rnorm(n) +bart_model <- bart(y_train = y, X_train = X) +intervals <- compute_bart_posterior_interval( + model_object = bart_model, + terms = c("mean_forest", "y_hat"), + covariates = X, + level = 0.90 +) +} diff --git a/man/compute_bcf_posterior_interval.Rd b/man/compute_bcf_posterior_interval.Rd new file mode 100644 index 00000000..f82849ed --- /dev/null +++ b/man/compute_bcf_posterior_interval.Rd @@ -0,0 +1,63 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_bcf_posterior_interval} +\alias{compute_bcf_posterior_interval} +\title{Compute posterior credible intervals for BCF model terms} +\usage{ +compute_bcf_posterior_interval( + model_object, + terms, + level = 0.95, + scale = "linear", + covariates = NULL, + treatment = NULL, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL +) +} +\arguments{ +\item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} + +\item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + +\item{covariates}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} + +\item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} + +\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the requested term is \code{"y_hat"} (overall predictions) and the underlying model depends on user-provided propensities.} + +\item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} + +\item{rfx_basis}{An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.} + +\item{term}{A character string specifying the model term for which to compute intervals. Options for BCF models are \code{"prognostic_function"}, \code{"cate"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} +} +\value{ +A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. +} +\description{ +This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +pi_X <- pnorm(0.5 * X[,1]) +Z <- rbinom(n, 1, pi_X) +mu_X <- X[,1] +tau_X <- 0.25 * X[,2] +y <- mu_X + tau_X * Z + rnorm(n) +bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, + propensity_train = pi_X) +intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + terms = c("prognostic_function", "cate"), + covariates = X, + treatment = Z, + propensity = pi_X, + level = 0.90 +) +} diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 99e9d878..6d60b0db 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -6,20 +6,19 @@ \usage{ \method{predict}{bartmodel}( object, - X, + covariates, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, type = "posterior", terms = "all", + scale = "linear", ... ) } \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} -\item{X}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} - \item{leaf_basis}{(Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: \code{NULL}.} \item{rfx_group_ids}{(Optional) Test set group labels used for an additive random effects model. @@ -32,11 +31,14 @@ that were not in the training set.} \item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + \item{...}{(Optional) Other prediction parameters.} + +\item{X}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} } \value{ -List of prediction matrices. If model does not have random effects, the list has one element -- the predictions from the forest. -If the model does have random effects, the list has three elements -- forest predictions, random effects predictions, and their sum (\code{y_hat}). +List of prediction matrices or single prediction matrix / vector, depending on the terms requested. } \description{ Predict from a sampled BART model on new data diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index dd7f14a0..71f275f6 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -13,6 +13,7 @@ rfx_basis = NULL, type = "posterior", terms = "all", + scale = "linear", ... ) } @@ -35,10 +36,12 @@ that were not in the training set.} \item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + \item{...}{(Optional) Other prediction parameters.} } \value{ -List of 3-5 \code{nrow(X)} by \code{object$num_samples} matrices: prognostic function estimates, treatment effect estimates, (optionally) random effects predictions, (optionally) variance forest predictions, and outcome predictions. +List of prediction matrices or single prediction matrix / vector, depending on the terms requested. } \description{ Predict from a sampled BCF model on new data diff --git a/man/sample_bart_posterior_predictive.Rd b/man/sample_bart_posterior_predictive.Rd new file mode 100644 index 00000000..d8989143 --- /dev/null +++ b/man/sample_bart_posterior_predictive.Rd @@ -0,0 +1,44 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{sample_bart_posterior_predictive} +\alias{sample_bart_posterior_predictive} +\title{Sample from the posterior predictive distribution for outcomes modeled by BART} +\usage{ +sample_bart_posterior_predictive( + model_object, + covariates = NULL, + basis = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + num_draws = NULL +) +} +\arguments{ +\item{model_object}{A fitted BART model object of class \code{bartmodel}.} + +\item{covariates}{A matrix or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} + +\item{basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} + +\item{rfx_group_ids}{A vector of group IDs for random effects model. Required if the BART model includes random effects.} + +\item{rfx_basis}{A matrix of bases for random effects model. Required if the BART model includes random effects.} + +\item{num_draws}{The number of posterior predictive samples to draw in computing intervals. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).} +} +\value{ +Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +} +\description{ +Sample from the posterior predictive distribution for outcomes modeled by BART +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +y <- 2 * X[,1] + rnorm(n) +bart_model <- bart(y_train = y, X_train = X) +ppd_samples <- sample_bart_posterior_predictive( + model_object = bart_model, covariates = X +) +} diff --git a/man/sample_bcf_posterior_predictive.Rd b/man/sample_bcf_posterior_predictive.Rd new file mode 100644 index 00000000..8c73effd --- /dev/null +++ b/man/sample_bcf_posterior_predictive.Rd @@ -0,0 +1,47 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{sample_bcf_posterior_predictive} +\alias{sample_bcf_posterior_predictive} +\title{Sample from the posterior predictive distribution for outcomes modeled by BCF} +\usage{ +sample_bcf_posterior_predictive( + model_object, + covariates = NULL, + treatment = NULL, + propensity = NULL, + rfx_group_ids = NULL, + rfx_basis = NULL, + num_draws = NULL +) +} +\arguments{ +\item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} + +\item{covariates}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} + +\item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} + +\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the requested term is \code{"y_hat"} (overall predictions) and the underlying model depends on user-provided propensities.} + +\item{rfx_group_ids}{(Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.} + +\item{rfx_basis}{(Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.} + +\item{num_draws}{(Optional) The number of samples to draw from the likelihood, for each draw of the posterior, in computing intervals. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).} +} +\value{ +Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +} +\description{ +Sample from the posterior predictive distribution for outcomes modeled by BCF +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(rnorm(n * p), nrow = n, ncol = p) +y <- 2 * X[,1] + rnorm(n) +bart_model <- bart(y_train = y, X_train = X) +ppd_samples <- sample_bart_posterior_predictive( + model_object = bart_model, covariates = X +) +} diff --git a/tools/debug/bart_predict_debug.R b/tools/debug/bart_predict_debug.R index 17cb5aa4..89766a74 100644 --- a/tools/debug/bart_predict_debug.R +++ b/tools/debug/bart_predict_debug.R @@ -4,7 +4,7 @@ library(stochtree) # Generate data -n <- 100 +n <- 1000 p <- 5 X <- matrix(runif(n * p), ncol = p) # fmt: skip @@ -25,6 +25,8 @@ X_test <- X[test_inds, ] X_train <- X[train_inds, ] y_test <- y[test_inds] y_train <- y[train_inds] +E_y_test <- f_XW[test_inds] +E_y_train <- f_XW[train_inds] # Fit simple BART model bart_model <- bart( @@ -32,7 +34,7 @@ bart_model <- bart( y_train = y_train, num_gfr = 10, num_burnin = 0, - num_mcmc = 10 + num_mcmc = 1000 ) # Check several predict approaches @@ -49,3 +51,174 @@ y_hat_test <- predict( type = "mean", terms = c("rfx", "variance") ) + +y_hat_intervals <- compute_bart_posterior_interval( + model_object = bart_model, + transform = function(x) x, + terms = c("y_hat", "mean_forest"), + covariates = X_test, + level = 0.95 +) + +(coverage <- mean( + (y_hat_intervals$mean_forest_predictions$lower <= E_y_test) & + (y_hat_intervals$mean_forest_predictions$upper >= E_y_test) +)) + +pred_intervals <- sample_bart_posterior_predictive( + model_object = bart_model, + covariates = X_test, + level = 0.95 +) + +(coverage_pred <- mean( + (pred_intervals$lower <= y_test) & + (pred_intervals$upper >= y_test) +)) + +# Generate probit data +n <- 1000 +p <- 5 +X <- matrix(runif(n * p), ncol = p) +# fmt: skip +f_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-2.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-1.25) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (1.25) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (2.5)) +noise_sd <- 1 +W <- f_X + rnorm(n, 0, noise_sd) +y <- as.numeric(W > 0) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +W_test <- W[test_inds] +W_train <- W[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +E_y_test <- f_X[test_inds] +E_y_train <- f_X[train_inds] + +# Fit simple BART model +bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + general_params = list(probit_outcome_model = TRUE) +) + +# Predict on latent scale +y_hat_post <- predict( + object = bart_model, + type = "posterior", + terms = c("y_hat"), + covariates = X_test, + scale = "linear" +) + +# Predict on probability scale +y_hat_post_prob <- predict( + object = bart_model, + type = "posterior", + terms = c("y_hat"), + covariates = X_test, + scale = "probability" +) + +# Compute intervals on latent scale +y_hat_intervals <- compute_bart_posterior_interval( + model_object = bart_model, + scale = "linear", + terms = c("y_hat"), + covariates = X_test, + level = 0.95 +) + +# Compute intervals on probability scale +y_hat_prob_intervals <- compute_bart_posterior_interval( + model_object = bart_model, + scale = "probability", + terms = c("y_hat"), + covariates = X_test, + level = 0.95 +) + +# Compute posterior means +y_hat_mean_latent <- rowMeans(y_hat_post) +y_hat_mean_prob <- rowMeans(y_hat_post_prob) + +# Plot on latent scale +sort_inds <- order(y_hat_mean_latent) +plot(y_hat_mean_latent[sort_inds]) +lines(y_hat_intervals$lower[sort_inds]) +lines(y_hat_intervals$upper[sort_inds]) + +# Plot on probability scale +sort_inds <- order(y_hat_mean_prob) +plot(y_hat_mean_prob[sort_inds]) +lines(y_hat_prob_intervals$lower[sort_inds]) +lines(y_hat_prob_intervals$upper[sort_inds]) + +# Draw from posterior predictive for covariates in the test set +ppd_samples <- sample_bart_posterior_predictive( + model_object = bart_model, + covariates = X_test, + num_draws = 10 +) + +# Compute histogram of PPD probabilities for both outcome classes +ppd_samples_prob <- apply(ppd_samples, 1, mean) +ppd_outcome_0 <- ppd_samples_prob[y_test == 0] +ppd_outcome_1 <- ppd_samples_prob[y_test == 1] +hist(ppd_outcome_0, breaks = 50, xlim = c(0, 1)) +hist(ppd_outcome_1, breaks = 50, xlim = c(0, 1)) + +# Compute posterior ROC +num_mcmc <- 1000 +num_thresholds <- 1000 +thresholds <- seq(0.001, 0.999, length.out = num_thresholds) +tpr_mean <- rep(NA, num_thresholds) +fpr_mean <- rep(NA, num_thresholds) +tpr_samples <- matrix(NA, num_thresholds, num_mcmc) +fpr_samples <- matrix(NA, num_thresholds, num_mcmc) +for (i in 1:num_thresholds) { + is_above_threshold_samples <- y_hat_post > qnorm(thresholds[i]) + is_above_threshold_mean <- y_hat_mean_latent > qnorm(thresholds[i]) + n_positive <- sum(y_test) + n_negative <- sum(y_test == 0) + y_above_threshold_mean <- y_test[is_above_threshold_mean] + for (j in 1:num_mcmc) { + y_above_threshold <- y_test[is_above_threshold_samples[, j]] + tpr_samples[i, j] <- sum(y_above_threshold) / n_positive + fpr_samples[i, j] <- sum(y_above_threshold == 0) / n_negative + } + # tpr_mean[i] <- sum(y_above_threshold_mean) / n_positive + # fpr_mean[i] <- sum(y_above_threshold_mean == 0) / n_negative + tpr_mean[i] <- mean(tpr_samples[i, ]) + fpr_mean[i] <- mean(fpr_samples[i, ]) +} + +for (i in 1:num_mcmc) { + if (i == 1) { + plot( + fpr_samples[, i], + tpr_samples[, i], + type = "line", + col = "blue", + lwd = 1, + lty = 1, + xlab = "False positive rate", + ylab = "True positive rate" + ) + } else { + lines(fpr_samples[, i], tpr_samples[, i], col = "blue", lwd = 1, lty = 1) + } +} +lines(fpr_mean, tpr_mean, col = "black", lwd = 3, lty = 3) diff --git a/tools/debug/bart_prior_draws.R b/tools/debug/bart_prior_draws.R new file mode 100644 index 00000000..ba3d1182 --- /dev/null +++ b/tools/debug/bart_prior_draws.R @@ -0,0 +1,169 @@ +library(stochtree) + +# Generate the data +n <- 500 +p_X <- 10 +p_W <- 1 +X <- matrix(runif(n * p_X), ncol = p_X) +W <- matrix(runif(n * p_W), ncol = p_W) +f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-3 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-1 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (1 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3 * W[, 1])) +# y <- f_XW + rnorm(n, 0, 1) +y <- rep(0, n) +# wgt <- rep(0, n) + +# Standardize outcome +# y_bar <- mean(y) +# y_std <- sd(y) +# resid <- (y - y_bar) / y_std +resid <- y + +# Sampling parameters +alpha <- 0.99 +beta <- 1 +min_samples_leaf <- 1 +max_depth <- 10 +num_trees <- 100 +cutpoint_grid_size <- 100 +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1 / num_trees +leaf_prior_scale <- as.matrix(ifelse( + p_W >= 1, + diag(tau_init, p_W), + diag(tau_init, 1) +)) +nu <- 4 +lambda <- 10 +a_leaf <- 2. +b_leaf <- 0.5 +leaf_regression <- T +feature_types <- as.integer(rep(0, p_X)) # 0 = numeric +var_weights <- rep(1 / p_X, p_X) + +# Sampling data structures +# Data +# forest_dataset <- createForestDataset(X, W, wgt) +forest_dataset <- createForestDataset(X, W) +outcome_model_type <- 1 +leaf_dimension <- p_W +outcome <- createOutcome(resid) + +# Random number generator (std::mt19937) +rng <- createCppRNG() + +# Sampling data structures +forest_model_config <- createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees, + num_features = p_X, + num_observations = n, + variable_weights = var_weights, + leaf_dimension = leaf_dimension, + alpha = alpha, + beta = beta, + min_samples_leaf = min_samples_leaf, + max_depth = max_depth, + leaf_model_type = outcome_model_type, + leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig( + global_error_variance = global_variance_init +) +forest_model <- createForestModel( + forest_dataset, + forest_model_config, + global_model_config +) + +# "Active forest" (which gets updated by the sample) and +# container of forest samples (which is written to when +# a sample is not discarded due to burn-in / thinning) +if (leaf_regression) { + forest_samples <- createForestSamples(num_trees, 1, F) + active_forest <- createForest(num_trees, 1, F) +} else { + forest_samples <- createForestSamples(num_trees, 1, T) + active_forest <- createForest(num_trees, 1, T) +} + +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler( + forest_dataset, + outcome, + forest_model, + outcome_model_type, + mean(resid) +) +active_forest$adjust_residual( + forest_dataset, + outcome, + forest_model, + ifelse(outcome_model_type == 1, T, F), + F +) + +# Prepare to run sampler +num_mcmc <- 2000 +global_var_samples <- rep(0, num_mcmc) + +# Run MCMC +for (i in 1:num_mcmc) { + # Sample forest + forest_model$sample_one_iteration( + forest_dataset, + outcome, + forest_samples, + active_forest, + rng, + forest_model_config, + global_model_config, + keep_forest = T, + gfr = F, + num_threads = 1 + ) + + # Sample global variance parameter + # current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + # outcome, + # forest_dataset, + # rng, + # nu, + # lambda + # ) + current_sigma2 <- 1 / rgamma(1, shape = nu / 2, rate = lambda / 2) + global_var_samples[i] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) +} + +plot(global_var_samples, type = "l") +hist(global_var_samples, breaks = 30) +hist( + 1 / rgamma(num_mcmc, shape = nu / 2, rate = lambda / 2), + breaks = 30 +) +mean(global_var_samples) +mean(1 / rgamma(num_mcmc, shape = nu / 2, rate = lambda / 2)) +sd(global_var_samples) +sd(1 / rgamma(num_mcmc, shape = nu / 2, rate = lambda / 2)) + +# Extract forest predictions +forest_preds <- forest_samples$predict( + forest_dataset +) + +y_hat_prior <- rowMeans(forest_preds) +y_hat_prior_lb <- apply(forest_preds, 1, quantile, probs = 0.025) +y_hat_prior_ub <- apply(forest_preds, 1, quantile, probs = 0.975) +plot( + 1:length(y_hat_prior), + y_hat_prior, + type = "l", + ylim = range(c(y_hat_prior_lb, y_hat_prior_ub)) +) +lines(1:length(y_hat_prior), y_hat_prior_lb, col = "blue") +lines(1:length(y_hat_prior), y_hat_prior_ub, col = "blue") diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R index 30d93852..88366551 100644 --- a/tools/debug/bcf_predict_debug.R +++ b/tools/debug/bcf_predict_debug.R @@ -4,27 +4,16 @@ library(stochtree) # Generate data -n <- 100 -g <- function(x) { - ifelse(x[, 5] == 1, 2, ifelse(x[, 5] == 2, -1, -4)) -} -x1 <- rnorm(n) -x2 <- rnorm(n) -x3 <- rnorm(n) -x4 <- as.numeric(rbinom(n, 1, 0.5)) -x5 <- as.numeric(sample(1:3, n, replace = TRUE)) -X <- cbind(x1, x2, x3, x4, x5) -p <- ncol(X) -mu_x <- 1 + g(X) + X[, 1] * X[, 3] -tau_x <- 1 + 2 * X[, 2] * X[, 4] -pi_x <- 0.8 * pnorm((3 * mu_x / sd(mu_x)) - 0.5 * X[, 1]) + 0.05 + runif(n) / 10 +n <- 500 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) Z <- rbinom(n, 1, pi_x) E_XZ <- mu_x + Z * tau_x snr <- 2 y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) -X <- as.data.frame(X) -X$x4 <- factor(X$x4, ordered = TRUE) -X$x5 <- factor(X$x5, ordered = TRUE) # Train-test split test_set_pct <- 0.2 @@ -44,8 +33,6 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -y_test <- y[test_inds] -y_train <- y[train_inds] # Fit a simple BCF model bcf_model <- bcf( @@ -58,7 +45,7 @@ bcf_model <- bcf( propensity_test = pi_test, num_gfr = 10, num_burnin = 0, - num_mcmc = 10 + num_mcmc = 1000 ) # Check several predict approaches @@ -69,12 +56,12 @@ y_hat_posterior_test <- predict( propensity = pi_test )$y_hat pred <- predict( - bcf_model, - X = X_test, - Z = Z_test, - propensity = pi_test, - type = "mean", - terms = c("all") + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "mean", + terms = c("all") ) y_hat_test <- predict( bcf_model, @@ -85,3 +72,141 @@ y_hat_test <- predict( terms = c("rfx", "variance") ) +y_hat_intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = c("all"), + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + level = 0.95 +) + +(tau_coverage <- mean( + (y_hat_intervals$tau_hat$upper >= tau_test) & + (y_hat_intervals$tau_hat$lower <= tau_test) +)) + +# Generate probit outcome data +n <- 1000 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +W <- E_XZ + rnorm(n, 0, 1) +y <- (W > 0) * 1 + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +W_test <- W[test_inds] +W_train <- W[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Fit a simple BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + general_params = list(probit_outcome_model = TRUE) +) + +# Predict on latent scale +y_hat_post <- predict( + object = bcf_model, + type = "posterior", + terms = c("y_hat"), + X = X_test, + Z = Z_test, + propensity = pi_test, + scale = "linear" +) + +# Predict on probability scale +y_hat_post_prob <- predict( + object = bcf_model, + type = "posterior", + terms = c("y_hat"), + X = X_test, + Z = Z_test, + propensity = pi_test, + scale = "probability" +) + +# Compute intervals on latent scale +y_hat_intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "linear", + terms = c("y_hat"), + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + level = 0.95 +) + +# Compute intervals on probability scale +y_hat_prob_intervals <- compute_bcf_posterior_interval( + model_object = bcf_model, + scale = "probability", + terms = c("y_hat"), + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + level = 0.95 +) + +# Compute posterior means +y_hat_mean_latent <- rowMeans(y_hat_post) +y_hat_mean_prob <- rowMeans(y_hat_post_prob) + +# Plot on latent scale +sort_inds <- order(y_hat_mean_latent) +plot(y_hat_mean_latent[sort_inds], ylim = range(y_hat_intervals)) +lines(y_hat_intervals$lower[sort_inds]) +lines(y_hat_intervals$upper[sort_inds]) + +# Plot on probability scale +sort_inds <- order(y_hat_mean_prob) +plot(y_hat_mean_prob[sort_inds], ylim = range(y_hat_prob_intervals)) +lines(y_hat_prob_intervals$lower[sort_inds]) +lines(y_hat_prob_intervals$upper[sort_inds]) + +# Draw from posterior predictive for covariates / treatment values in the test set +ppd_samples <- sample_bcf_posterior_predictive( + model_object = bcf_model, + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + num_draws = 10 +) + +# Compute histogram of PPD probabilities for both outcome classes +ppd_samples_prob <- apply(ppd_samples, 1, mean) +ppd_outcome_0 <- ppd_samples_prob[y_test == 0] +ppd_outcome_1 <- ppd_samples_prob[y_test == 1] +hist(ppd_outcome_0, breaks = 50, xlim = c(0, 1)) +hist(ppd_outcome_1, breaks = 50, xlim = c(0, 1)) From a8de2dab073c24daa590833514c47930921756c9 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 15:56:00 -0500 Subject: [PATCH 08/53] Removed unused functions --- R/posterior_transformation.R | 55 ------------------------------------ 1 file changed, 55 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 2b0ae92a..2ff58655 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -702,61 +702,6 @@ compute_bart_posterior_interval <- function( } } -transform_power <- function(array, exponent) { - return(compute_transformation(array, fn = function(x) x**exponent)) -} - -transform_multiply <- function(array, multiple) { - return(compute_transformation(array, fn = function(x) x * multiple)) -} - -transform_add <- function(array, addend) { - return(compute_transformation(array, fn = function(x) x + addend)) -} - -transform_exp <- function(array) { - return(compute_transformation(array, fn = exp)) -} - -transform_log <- function(array) { - return(compute_transformation(array, fn = log)) -} - -transform_pnorm <- function(array) { - return(compute_transformation(array, fn = pnorm)) -} - -compute_transformation <- function(array, fn = NULL) { - # Check that the array is numeric and at least 1 dimensional - stopifnot(is.numeric(array) && length(dim(array)) >= 1) - - # Calculate the transformation - return(fn(array)) -} - -summarize_mean <- function(array, sample_dim = 2) { - return(compute_summary(array, sample_dim, mean)) -} - -summarize_median <- function(array, sample_dim = 2) { - return(compute_summary(array, sample_dim, median)) -} - -compute_summary <- function(array, sample_dim = 2, fn = mean) { - # Check that the array is numeric and at least 2 dimensional - stopifnot(is.numeric(array) && length(dim(array)) >= 2) - - # Determine the dimensions over which reduction is computed - apply_dim <- setdiff(1:length(dim(array)), sample_dim) - - # Compute the reduction - result <- apply(array, apply_dim, function(x) { - fn(x) - }) - - return(result) -} - summarize_interval <- function(array, sample_dim = 2, level = 0.95) { # Check that the array is numeric and at least 2 dimensional stopifnot(is.numeric(array) && length(dim(array)) >= 2) From 5cc9faa7d75ff13de827ed04670d30681f0fabfe Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 18:29:30 -0500 Subject: [PATCH 09/53] Updated PPD sampling function and example script --- R/posterior_transformation.R | 11 +++++++---- tools/debug/bcf_predict_debug.R | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 2ff58655..064f9ed0 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -15,10 +15,13 @@ #' n <- 100 #' p <- 5 #' X <- matrix(rnorm(n * p), nrow = n, ncol = p) -#' y <- 2 * X[,1] + rnorm(n) -#' bart_model <- bart(y_train = y, X_train = X) -#' ppd_samples <- sample_bart_posterior_predictive( -#' model_object = bart_model, covariates = X +#' pi_X <- pnorm(X[,1] / 2) +#' Z <- rbinom(n, 1, pi_X) +#' y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) +#' bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) +#' ppd_samples <- sample_bcf_posterior_predictive( +#' model_object = bcf_model, covariates = X, +#' treatment = Z, propensity = pi_X #' ) sample_bcf_posterior_predictive <- function( model_object, diff --git a/tools/debug/bcf_predict_debug.R b/tools/debug/bcf_predict_debug.R index 88366551..7a854707 100644 --- a/tools/debug/bcf_predict_debug.R +++ b/tools/debug/bcf_predict_debug.R @@ -63,6 +63,7 @@ pred <- predict( type = "mean", terms = c("all") ) +# Check that this throws a warning y_hat_test <- predict( bcf_model, X = X_test, @@ -72,6 +73,7 @@ y_hat_test <- predict( terms = c("rfx", "variance") ) +# Compute intervals around model terms (including E[y | X, Z]) y_hat_intervals <- compute_bcf_posterior_interval( model_object = bcf_model, scale = "linear", @@ -82,11 +84,26 @@ y_hat_intervals <- compute_bcf_posterior_interval( level = 0.95 ) +# Estimate coverage of intervals on tau(X) (tau_coverage <- mean( (y_hat_intervals$tau_hat$upper >= tau_test) & (y_hat_intervals$tau_hat$lower <= tau_test) )) +# Posterior predictive coverage and MSE checks +quantiles <- c(0.05, 0.95) +ppd_samples <- sample_bcf_posterior_predictive( + model_object = bcf_model, + covariates = X_test, + treatment = Z_test, + propensity = pi_test, + num_draws = 1 +) +yhat_ppd <- apply(ppd_samples, 1, mean) +yhat_interval_ppd <- apply(ppd_samples, 1, quantile, probs = quantiles) +mean((yhat_interval_ppd[1, ] <= y_test) & (yhat_interval_ppd[2, ] >= y_test)) +sqrt(mean((yhat_ppd - y_test)^2)) + # Generate probit outcome data n <- 1000 p <- 5 From 0bf7df373b038f8eb75bc99c4661778a3271b5a5 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 18:39:47 -0500 Subject: [PATCH 10/53] Updated functions and docs --- R/bart.R | 6 +++--- man/sample_bcf_posterior_predictive.Rd | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/R/bart.R b/R/bart.R index 30e0f8f3..bd94b625 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1957,6 +1957,9 @@ predict.bartmodel <- function( } # Scale variance forest predictions + num_samples <- object$model_params$num_samples + y_std <- object$model_params$outcome_scale + y_bar <- object$model_params$outcome_mean sigma2_init <- object$model_params$sigma2_init if (predict_variance_forest) { if (object$model_params$sample_sigma2_global) { @@ -1973,9 +1976,6 @@ predict.bartmodel <- function( } # Compute mean forest predictions - num_samples <- object$model_params$num_samples - y_std <- object$model_params$outcome_scale - y_bar <- object$model_params$outcome_mean if (predict_mean_forest || predict_mean_forest_intermediate) { mean_forest_predictions <- object$mean_forests$predict( prediction_dataset diff --git a/man/sample_bcf_posterior_predictive.Rd b/man/sample_bcf_posterior_predictive.Rd index 8c73effd..d4e827db 100644 --- a/man/sample_bcf_posterior_predictive.Rd +++ b/man/sample_bcf_posterior_predictive.Rd @@ -39,9 +39,12 @@ Sample from the posterior predictive distribution for outcomes modeled by BCF n <- 100 p <- 5 X <- matrix(rnorm(n * p), nrow = n, ncol = p) -y <- 2 * X[,1] + rnorm(n) -bart_model <- bart(y_train = y, X_train = X) -ppd_samples <- sample_bart_posterior_predictive( - model_object = bart_model, covariates = X +pi_X <- pnorm(X[,1] / 2) +Z <- rbinom(n, 1, pi_X) +y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n) +bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X) +ppd_samples <- sample_bcf_posterior_predictive( + model_object = bcf_model, covariates = X, + treatment = Z, propensity = pi_X ) } From 2421c75996ffeddb7596c0e81cd8be1c66d136b6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 18:45:41 -0500 Subject: [PATCH 11/53] Format R tests with air --- test/R/testthat/test-bart.R | 935 ++++++++------ test/R/testthat/test-bcf.R | 1354 ++++++++++++--------- test/R/testthat/test-categorical.R | 237 ++-- test/R/testthat/test-data-preprocessing.R | 74 +- test/R/testthat/test-dataset.R | 129 +- test/R/testthat/test-forest-container.R | 516 ++++---- test/R/testthat/test-forest.R | 292 ++--- test/R/testthat/test-residual.R | 187 +-- test/R/testthat/test-serialization.R | 319 ++--- test/R/testthat/test-utils.R | 154 +-- 10 files changed, 2340 insertions(+), 1857 deletions(-) diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 748ff96c..92ab1f03 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -1,433 +1,564 @@ test_that("MCMC BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) - ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # Generate simulated data with a leaf basis - n <- 100 - p <- 5 - p_w <- 2 - X <- matrix(runif(n*p), ncol = p) - W <- matrix(runif(n*p_w), ncol = p_w) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - W_test <- W[test_inds,] - W_train <- W[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 3 chains, thinning, leaf regression - general_param_list <- list(num_chains = 3, keep_every = 5) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Generate simulated data with a leaf basis + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + W_train <- W[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 3 chains, thinning, leaf regression + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - - # 3 chains, thinning, leaf regression with a scalar leaf scale - general_param_list <- list(num_chains = 3, keep_every = 5) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE, sigma2_leaf_init = 0.5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # 3 chains, thinning, leaf regression with a scalar leaf scale + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list( + sample_sigma2_leaf = FALSE, + sigma2_leaf_init = 0.5 + ) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - - # 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale - general_param_list <- list(num_chains = 3, keep_every = 5) - mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5) - expect_warning( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale + general_param_list <- list(num_chains = 3, keep_every = 5) + mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5) + expect_warning( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) + ) }) test_that("GFR BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 1) - expect_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 5) - expect_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list ) + ) }) test_that("Warmstart BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BART model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 0, - general_params = general_param_list) - - # Save to JSON string - bart_model_json_string <- saveBARTModelToJsonString(bart_model) - - # Run a new BART chain from the existing (X)BART model - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - previous_model_json = bart_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) - - ) - - # Generate simulated data with random effects - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list ) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- rep(1, n) - rfx_coefs <- c(-5, 5) - rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis - noise_sd <- 1 - y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds] - rfx_basis_train <- rfx_basis[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BART model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 0, - general_params = general_param_list) - - # Save to JSON string - bart_model_json_string <- saveBARTModelToJsonString(bart_model) - - # Run a new BART chain from the existing (X)BART model - general_param_list <- list(num_chains = 4, keep_every = 5) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - previous_model_json = bart_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 4, keep_every = 5) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list ) + ) }) test_that("BART Predictions", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BART model with only GFR - general_params <- list(num_chains = 1) - variance_forest_params <- list(num_trees = 50) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_params, - variance_forest_params = variance_forest_params) - - # Check that cached predictions agree with results of predict() function - train_preds <- predict(bart_model, X = X_train) - train_preds_mean_cached <- bart_model$y_hat_train - train_preds_mean_recomputed <- train_preds$mean_forest_predictions - train_preds_variance_cached <- bart_model$sigma2_x_hat_train - train_preds_variance_recomputed <- train_preds$variance_forest_predictions - - # Assertion - expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) - expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_params <- list(num_chains = 1) + variance_forest_params <- list(num_trees = 50) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_params, + variance_forest_params = variance_forest_params + ) + + # Check that cached predictions agree with results of predict() function + train_preds <- predict(bart_model, X = X_train) + train_preds_mean_cached <- bart_model$y_hat_train + train_preds_mean_recomputed <- train_preds$mean_forest_predictions + train_preds_variance_cached <- bart_model$sigma2_x_hat_train + train_preds_variance_recomputed <- train_preds$variance_forest_predictions + + # Assertion + expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) + expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) }) test_that("Random Effects BART", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - p_w <- 2 - X <- matrix(runif(n*p), ncol = p) - W <- matrix(runif(n*p_w), ncol = p_w) - f_XW <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) - ) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- cbind(rep(1, n), runif(n)) - num_rfx_components <- ncol(rfx_basis) - num_rfx_groups <- length(unique(rfx_group_ids)) - rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) - rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) - noise_sd <- 1 - y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - W_test <- W[test_inds,] - W_train <- W[train_inds,] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds,] - rfx_basis_train <- rfx_basis[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Specify no rfx parameters - general_param_list <- list() - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + p_w <- 2 + X <- matrix(runif(n * p), ncol = p) + W <- matrix(runif(n * p_w), ncol = p_w) + # fmt: skip + f_XW <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5 * W[, 1]) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5 * W[, 1]) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5 * W[, 1]) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5 * W[, 1])) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + W_test <- W[test_inds, ] + W_train <- W[train_inds, ] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx parameters + general_param_list <- list() + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - - # Specify all rfx parameters as scalars - general_param_list <- list(rfx_working_parameter_prior_mean = 1., - rfx_group_parameter_prior_mean = 1., - rfx_working_parameter_prior_cov = 1., - rfx_group_parameter_prior_cov = 1., - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Specify all rfx parameters as scalars + general_param_list <- list( + rfx_working_parameter_prior_mean = 1., + rfx_group_parameter_prior_mean = 1., + rfx_working_parameter_prior_cov = 1., + rfx_group_parameter_prior_cov = 1., + rfx_variance_prior_shape = 1, + rfx_variance_prior_scale = 1 + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) - - # Specify all relevant rfx parameters as vectors - general_param_list <- list(rfx_working_parameter_prior_mean = c(1.,1.), - rfx_group_parameter_prior_mean = c(1.,1.), - rfx_working_parameter_prior_cov = diag(1.,2), - rfx_group_parameter_prior_cov = diag(1.,2), - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) - expect_no_error( - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - leaf_basis_train = W_train, leaf_basis_test = W_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list) + ) + + # Specify all relevant rfx parameters as vectors + general_param_list <- list( + rfx_working_parameter_prior_mean = c(1., 1.), + rfx_group_parameter_prior_mean = c(1., 1.), + rfx_working_parameter_prior_cov = diag(1., 2), + rfx_group_parameter_prior_cov = diag(1., 2), + rfx_variance_prior_shape = 1, + rfx_variance_prior_scale = 1 + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error( + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + mean_forest_params = mean_forest_param_list ) -}) \ No newline at end of file + ) +}) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 531320f6..a94fe3f3 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -1,607 +1,783 @@ test_that("MCMC BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 1 chain, no thinning, matrix leaf scale parameter provided - general_param_list <- list(num_chains = 1, keep_every = 1) - mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) - tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list, - prognostic_forest_params = mu_forest_param_list, - treatment_effect_forest_params = tau_forest_param_list) - ) - - # 1 chain, no thinning, scalar leaf scale parameter provided - general_param_list <- list(num_chains = 1, keep_every = 1) - mu_forest_param_list <- list(sigma2_leaf_init = 0.5) - tau_forest_param_list <- list(sigma2_leaf_init = 0.5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list, - prognostic_forest_params = mu_forest_param_list, - treatment_effect_forest_params = tau_forest_param_list) - ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * + (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 1 chain, no thinning, matrix leaf scale parameter provided + general_param_list <- list(num_chains = 1, keep_every = 1) + mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) + tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5)) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + prognostic_forest_params = mu_forest_param_list, + treatment_effect_forest_params = tau_forest_param_list + ) + ) + + # 1 chain, no thinning, scalar leaf scale parameter provided + general_param_list <- list(num_chains = 1, keep_every = 1) + mu_forest_param_list <- list(sigma2_leaf_init = 0.5) + tau_forest_param_list <- list(sigma2_leaf_init = 0.5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list, + prognostic_forest_params = mu_forest_param_list, + treatment_effect_forest_params = tau_forest_param_list + ) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) }) test_that("GFR BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 3 chains, no thinning - general_param_list <- list(num_chains = 3, keep_every = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 1 chain, thinning - general_param_list <- list(num_chains = 1, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # 3 chains, thinning - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 1) - expect_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) - - # Check for error when more chains than GFR forests - general_param_list <- list(num_chains = 11, keep_every = 5) - expect_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - ) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * + (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * + (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + ) }) test_that("Warmstart BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BCF model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 0, - num_mcmc = 0, general_params = general_param_list) - - # Save to JSON string - bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) - - # Run a new BCF chain from the existing (X)BCF model - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, previous_model_json = bcf_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) - ) - - # Generate simulated data with random effects - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- rep(1, n) - rfx_coefs <- c(-5, 5) - rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis - noise_sd <- 1 - y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds] - rfx_basis_train <- rfx_basis[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BCF model with only GFR - general_param_list <- list(num_chains = 1, keep_every = 1) - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 0, - num_mcmc = 0, general_params = general_param_list) - - # Save to JSON string - bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) - - # Run a new BCF chain from the existing (X)BCF model - general_param_list <- list(num_chains = 3, keep_every = 5) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - rfx_group_ids_train = rfx_group_ids_train, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, - rfx_basis_test = rfx_basis_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, previous_model_json = bcf_model_json_string, - previous_model_warmstart_sample_num = 1, - general_params = general_param_list) - ) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list + ) + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- mu_X + tau_X * Z + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 0, + general_params = general_param_list + ) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list + ) + ) }) test_that("Multivariate Treatment MCMC BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X_1 <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - pi_X_2 <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.8) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (0.4) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (0.6) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (0.2) - ) - pi_X <- cbind(pi_X_1, pi_X_2) - tau_X_1 <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - tau_X_2 <- ( - ((0 <= X[,3]) & (0.25 > X[,3])) * (-0.5) + - ((0.25 <= X[,3]) & (0.5 > X[,3])) * (-1.5) + - ((0.5 <= X[,3]) & (0.75 > X[,3])) * (-1.0) + - ((0.75 <= X[,3]) & (1 > X[,3])) * (0.0) - ) - tau_X <- cbind(tau_X_1, tau_X_2) - Z_1 <- as.numeric(rbinom(n, 1, pi_X_1)) - Z_2 <- as.numeric(rbinom(n, 1, pi_X_2)) - Z <- cbind(Z_1, Z_2) - noise_sd <- 1 - y <- mu_X + rowSums(tau_X*Z) + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds,] - Z_train <- Z[train_inds,] - pi_test <- pi_X[test_inds,] - pi_train <- pi_X[train_inds,] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds,] - tau_train <- tau_X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # 1 chain, no thinning - general_param_list <- list(num_chains = 1, keep_every = 1, adaptive_coding = F) - expect_no_error({ - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 0, num_burnin = 10, - num_mcmc = 10, general_params = general_param_list) - predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test) - }) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X_1 <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + pi_X_2 <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.8) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (0.4) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (0.6) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (0.2)) + pi_X <- cbind(pi_X_1, pi_X_2) + # fmt: skip + tau_X_1 <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + # fmt: skip + tau_X_2 <- (((0 <= X[, 3]) & (0.25 > X[, 3])) * (-0.5) + + ((0.25 <= X[, 3]) & (0.5 > X[, 3])) * (-1.5) + + ((0.5 <= X[, 3]) & (0.75 > X[, 3])) * (-1.0) + + ((0.75 <= X[, 3]) & (1 > X[, 3])) * (0.0)) + tau_X <- cbind(tau_X_1, tau_X_2) + Z_1 <- as.numeric(rbinom(n, 1, pi_X_1)) + Z_2 <- as.numeric(rbinom(n, 1, pi_X_2)) + Z <- cbind(Z_1, Z_2) + noise_sd <- 1 + y <- mu_X + rowSums(tau_X * Z) + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds, ] + Z_train <- Z[train_inds, ] + pi_test <- pi_X[test_inds, ] + pi_train <- pi_X[train_inds, ] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds, ] + tau_train <- tau_X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list( + num_chains = 1, + keep_every = 1, + adaptive_coding = F + ) + expect_no_error({ + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + predict(bcf_model, X = X_test, Z = Z_test, propensity = pi_test) + }) }) test_that("BCF Predictions", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Run a BCF model with only GFR - general_params <- list(num_chains = 1, keep_every = 1) - variance_forest_params <- list(num_trees = 50) - bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, - propensity_train = pi_train, X_test = X_test, Z_test = Z_test, - propensity_test = pi_test, num_gfr = 10, num_burnin = 0, - num_mcmc = 10, general_params = general_params, - variance_forest_params = variance_forest_params) - - # Check that cached predictions agree with results of predict() function - train_preds <- predict(bcf_model, X = X_train, Z = Z_train, propensity = pi_train) - train_preds_mean_cached <- bcf_model$y_hat_train - train_preds_mean_recomputed <- train_preds$y_hat - train_preds_variance_cached <- bcf_model$sigma2_x_hat_train - train_preds_variance_recomputed <- train_preds$variance_forest_predictions - - # Assertion - expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) - expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_params <- list(num_chains = 1, keep_every = 1) + variance_forest_params <- list(num_trees = 50) + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_params, + variance_forest_params = variance_forest_params + ) + + # Check that cached predictions agree with results of predict() function + train_preds <- predict( + bcf_model, + X = X_train, + Z = Z_train, + propensity = pi_train + ) + train_preds_mean_cached <- bcf_model$y_hat_train + train_preds_mean_recomputed <- train_preds$y_hat + train_preds_variance_cached <- bcf_model$sigma2_x_hat_train + train_preds_variance_recomputed <- train_preds$variance_forest_predictions + + # Assertion + expect_equal(train_preds_mean_cached, train_preds_mean_recomputed) + expect_equal(train_preds_variance_cached, train_preds_variance_recomputed) }) test_that("Random Effects BCF", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - mu_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) - ) - pi_X <- ( - ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) - ) - tau_X <- ( - ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + - ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + - ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + - ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) - ) - Z <- rbinom(n, 1, pi_X) - rfx_group_ids <- sample(1:2, size = n, replace = TRUE) - rfx_basis <- cbind(rep(1, n), runif(n)) - num_rfx_components <- ncol(rfx_basis) - num_rfx_groups <- length(unique(rfx_group_ids)) - rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) - rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis) - noise_sd <- 1 - y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd) - - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - pi_test <- pi_X[test_inds] - pi_train <- pi_X[train_inds] - mu_test <- mu_X[test_inds] - mu_train <- mu_X[train_inds] - tau_test <- tau_X[test_inds] - tau_train <- tau_X[train_inds] - rfx_group_ids_test <- rfx_group_ids[test_inds] - rfx_group_ids_train <- rfx_group_ids[train_inds] - rfx_basis_test <- rfx_basis[test_inds,] - rfx_basis_train <- rfx_basis[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Specify no rfx parameters - general_param_list <- list() - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, - Z_train = Z_train, propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list) - ) + skip_on_cran() - # Specify all rfx parameters as scalars - general_param_list <- list(rfx_working_parameter_prior_mean = 1., - rfx_group_parameter_prior_mean = 1., - rfx_working_parameter_prior_cov = 1., - rfx_group_parameter_prior_cov = 1., - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, - Z_train = Z_train, propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list) - ) - - # Specify all relevant rfx parameters as vectors - general_param_list <- list(rfx_working_parameter_prior_mean = c(1.,1.), - rfx_group_parameter_prior_mean = c(1.,1.), - rfx_working_parameter_prior_cov = diag(1.,2), - rfx_group_parameter_prior_cov = diag(1.,2), - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1) - expect_no_error( - bcf_model <- bcf(X_train = X_train, y_train = y_train, - Z_train = Z_train, propensity_train = pi_train, - rfx_group_ids_train = rfx_group_ids_train, - rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, - rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_test = rfx_basis_test, - num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list) - ) + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + mu_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (-7.5) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (-2.5) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2.5) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (7.5)) + # fmt: skip + pi_X <- (((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.2) + + ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (0.4) + + ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (0.6) + + ((0.75 <= X[, 1]) & (1 > X[, 1])) * (0.8)) + # fmt: skip + tau_X <- (((0 <= X[, 2]) & (0.25 > X[, 2])) * (0.5) + + ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (1.0) + + ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (1.5) + + ((0.75 <= X[, 2]) & (1 > X[, 2])) * (2.0)) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = TRUE) + rfx_basis <- cbind(rep(1, n), runif(n)) + num_rfx_components <- ncol(rfx_basis) + num_rfx_groups <- length(unique(rfx_group_ids)) + rfx_coefs <- matrix(c(-5, 5, 1, -1), ncol = 2, byrow = T) + rfx_term <- rowSums(rfx_coefs[rfx_group_ids, ] * rfx_basis) + noise_sd <- 1 + y <- mu_X + tau_X * Z + rfx_term + rnorm(n, 0, noise_sd) + + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds, ] + rfx_basis_train <- rfx_basis[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Specify no rfx parameters + general_param_list <- list() + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Specify all rfx parameters as scalars + general_param_list <- list( + rfx_working_parameter_prior_mean = 1., + rfx_group_parameter_prior_mean = 1., + rfx_working_parameter_prior_cov = 1., + rfx_group_parameter_prior_cov = 1., + rfx_variance_prior_shape = 1, + rfx_variance_prior_scale = 1 + ) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_param_list + ) + ) + + # Specify all relevant rfx parameters as vectors + general_param_list <- list( + rfx_working_parameter_prior_mean = c(1., 1.), + rfx_group_parameter_prior_mean = c(1., 1.), + rfx_working_parameter_prior_cov = diag(1., 2), + rfx_group_parameter_prior_cov = diag(1., 2), + rfx_variance_prior_shape = 1, + rfx_variance_prior_scale = 1 + ) + expect_no_error( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_param_list + ) + ) }) diff --git a/test/R/testthat/test-categorical.R b/test/R/testthat/test-categorical.R index 92d7459f..19a7468e 100644 --- a/test/R/testthat/test-categorical.R +++ b/test/R/testthat/test-categorical.R @@ -1,7 +1,8 @@ test_that("In-sample one-hot encoding works for unordered categorical variables", { - x1 <- c(3,2,1,4,3,2,3,2,4,2) - x1_onehot <- oneHotInitializeAndEncode(x1) - x1_expected <- matrix( + x1 <- c(3, 2, 1, 4, 3, 2, 3, 2, 4, 2) + x1_onehot <- oneHotInitializeAndEncode(x1) + # fmt: skip + x1_expected <- matrix( c(0,0,1,0,0, 0,1,0,0,0, 1,0,0,0,0, @@ -13,11 +14,12 @@ test_that("In-sample one-hot encoding works for unordered categorical variables" 0,0,0,1,0, 0,1,0,0,0), byrow = TRUE, ncol = 5) - x1_levels_expected <- c("1","2","3","4") - - x2 <- c("a","c","b","c","d","a","c","a","b","d") - x2_onehot <- oneHotInitializeAndEncode(x2) - x2_expected <- matrix( + x1_levels_expected <- c("1", "2", "3", "4") + + x2 <- c("a", "c", "b", "c", "d", "a", "c", "a", "b", "d") + x2_onehot <- oneHotInitializeAndEncode(x2) + # fmt: skip + x2_expected <- matrix( c(1,0,0,0,0, 0,0,1,0,0, 0,1,0,0,0, @@ -29,11 +31,12 @@ test_that("In-sample one-hot encoding works for unordered categorical variables" 0,1,0,0,0, 0,0,0,1,0), byrow = TRUE, ncol = 5) - x2_levels_expected <- c("a","b","c","d") - - x3 <- c(3.2,2.4,1.5,4.6,3.2,2.4,3.2,2.4,4.6,2.4) - x3_onehot <- oneHotInitializeAndEncode(x3) - x3_expected <- matrix( + x2_levels_expected <- c("a", "b", "c", "d") + + x3 <- c(3.2, 2.4, 1.5, 4.6, 3.2, 2.4, 3.2, 2.4, 4.6, 2.4) + x3_onehot <- oneHotInitializeAndEncode(x3) + # fmt: skip + x3_expected <- matrix( c(0,0,1,0,0, 0,1,0,0,0, 1,0,0,0,0, @@ -45,111 +48,159 @@ test_that("In-sample one-hot encoding works for unordered categorical variables" 0,0,0,1,0, 0,1,0,0,0), byrow = TRUE, ncol = 5) - x3_levels_expected <- c("1.5","2.4","3.2","4.6") - - expect_equal(x1_onehot$Xtilde, x1_expected) - expect_equal(x2_onehot$Xtilde, x2_expected) - expect_equal(x3_onehot$Xtilde, x3_expected) - expect_equal(x1_onehot$unique_levels, x1_levels_expected) - expect_equal(x2_onehot$unique_levels, x2_levels_expected) - expect_equal(x3_onehot$unique_levels, x3_levels_expected) + x3_levels_expected <- c("1.5", "2.4", "3.2", "4.6") + + expect_equal(x1_onehot$Xtilde, x1_expected) + expect_equal(x2_onehot$Xtilde, x2_expected) + expect_equal(x3_onehot$Xtilde, x3_expected) + expect_equal(x1_onehot$unique_levels, x1_levels_expected) + expect_equal(x2_onehot$unique_levels, x2_levels_expected) + expect_equal(x3_onehot$unique_levels, x3_levels_expected) }) test_that("Out-of-sample one-hot encoding works for unordered categorical variables", { - x1 <- c(3,2,1,4,3,2,3,2,4,2) - x1_test <- c(1,2,4,3,5) - x1_test_onehot <- oneHotEncode(x1_test, levels(factor(x1))) - x1_test_expected <- matrix( + x1 <- c(3, 2, 1, 4, 3, 2, 3, 2, 4, 2) + x1_test <- c(1, 2, 4, 3, 5) + x1_test_onehot <- oneHotEncode(x1_test, levels(factor(x1))) + # fmt: skip + x1_test_expected <- matrix( c(1,0,0,0,0, 0,1,0,0,0, 0,0,0,1,0, 0,0,1,0,0, 0,0,0,0,1), byrow = TRUE, ncol = 5) - - x2 <- c("a","c","b","c","d","a","c","a","b","d") - x2_test <- c("a","c","g","b","f") - x2_test_onehot <- oneHotEncode(x2_test, levels(factor(x2))) - x2_test_expected <- matrix( + + x2 <- c("a", "c", "b", "c", "d", "a", "c", "a", "b", "d") + x2_test <- c("a", "c", "g", "b", "f") + x2_test_onehot <- oneHotEncode(x2_test, levels(factor(x2))) + # fmt: skip + x2_test_expected <- matrix( c(1,0,0,0,0, 0,0,1,0,0, 0,0,0,0,1, 0,1,0,0,0, 0,0,0,0,1), byrow = TRUE, ncol = 5) - - x3 <- c(3.2,2.4,1.5,4.6,3.2,2.4,3.2,2.4,4.6,2.4) - x3_test <- c(10.3,-0.5,4.6,3.2,1.8) - x3_test_onehot <- oneHotEncode(x3_test, levels(factor(x3))) - x3_test_expected <- matrix( + + x3 <- c(3.2, 2.4, 1.5, 4.6, 3.2, 2.4, 3.2, 2.4, 4.6, 2.4) + x3_test <- c(10.3, -0.5, 4.6, 3.2, 1.8) + x3_test_onehot <- oneHotEncode(x3_test, levels(factor(x3))) + # fmt: skip + x3_test_expected <- matrix( c(0,0,0,0,1, 0,0,0,0,1, 0,0,0,1,0, 0,0,1,0,0, 0,0,0,0,1), byrow = TRUE, ncol = 5) - - expect_equal(x1_test_onehot, x1_test_expected) - expect_equal(x2_test_onehot, x2_test_expected) - expect_equal(x3_test_onehot, x3_test_expected) + + expect_equal(x1_test_onehot, x1_test_expected) + expect_equal(x2_test_onehot, x2_test_expected) + expect_equal(x3_test_onehot, x3_test_expected) }) test_that("In-sample preprocessing for ordered categorical variables", { - string_var_response_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") - - x1 <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") - x1_preprocessing <- orderedCatInitializeAndPreprocess(x1) - x1_vector_expected <- c(1,3,2,4,3,5,4) - x1_levels_expected <- string_var_response_levels - - x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) - x2_preprocessing <- orderedCatInitializeAndPreprocess(x2) - x2_vector_expected <- c(1,3,2,4,3,5,4) - x2_levels_expected <- string_var_response_levels - - string_var_levels_reordered <- c("5. Strongly agree", "4. Agree", "3. Neither agree nor disagree", "2. Disagree", "1. Strongly disagree") - x3 <- factor(x1, levels = string_var_levels_reordered, ordered = TRUE) - x3_preprocessing <- orderedCatInitializeAndPreprocess(x3) - x3_vector_expected <- c(5,3,4,2,3,1,2) - x3_levels_expected <- string_var_levels_reordered - - x4 <- c(3,2,4,6,5,2,3,1,3,4,6) - x4_preprocessing <- orderedCatInitializeAndPreprocess(x4) - x4_vector_expected <- c(3,2,4,6,5,2,3,1,3,4,6) - x4_levels_expected <- c("1","2","3","4","5","6") - - expect_equal(x1_preprocessing$x_preprocessed, x1_vector_expected) - expect_equal(x2_preprocessing$x_preprocessed, x2_vector_expected) - expect_equal(x3_preprocessing$x_preprocessed, x3_vector_expected) - expect_equal(x4_preprocessing$x_preprocessed, x4_vector_expected) - expect_equal(x1_preprocessing$unique_levels, x1_levels_expected) - expect_equal(x2_preprocessing$unique_levels, x2_levels_expected) - expect_equal(x3_preprocessing$unique_levels, x3_levels_expected) - expect_equal(x4_preprocessing$unique_levels, x4_levels_expected) + string_var_response_levels <- c( + "1. Strongly disagree", + "2. Disagree", + "3. Neither agree nor disagree", + "4. Agree", + "5. Strongly agree" + ) + + x1 <- c( + "1. Strongly disagree", + "3. Neither agree nor disagree", + "2. Disagree", + "4. Agree", + "3. Neither agree nor disagree", + "5. Strongly agree", + "4. Agree" + ) + x1_preprocessing <- orderedCatInitializeAndPreprocess(x1) + x1_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + x1_levels_expected <- string_var_response_levels + + x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) + x2_preprocessing <- orderedCatInitializeAndPreprocess(x2) + x2_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + x2_levels_expected <- string_var_response_levels + + string_var_levels_reordered <- c( + "5. Strongly agree", + "4. Agree", + "3. Neither agree nor disagree", + "2. Disagree", + "1. Strongly disagree" + ) + x3 <- factor(x1, levels = string_var_levels_reordered, ordered = TRUE) + x3_preprocessing <- orderedCatInitializeAndPreprocess(x3) + x3_vector_expected <- c(5, 3, 4, 2, 3, 1, 2) + x3_levels_expected <- string_var_levels_reordered + + x4 <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + x4_preprocessing <- orderedCatInitializeAndPreprocess(x4) + x4_vector_expected <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + x4_levels_expected <- c("1", "2", "3", "4", "5", "6") + + expect_equal(x1_preprocessing$x_preprocessed, x1_vector_expected) + expect_equal(x2_preprocessing$x_preprocessed, x2_vector_expected) + expect_equal(x3_preprocessing$x_preprocessed, x3_vector_expected) + expect_equal(x4_preprocessing$x_preprocessed, x4_vector_expected) + expect_equal(x1_preprocessing$unique_levels, x1_levels_expected) + expect_equal(x2_preprocessing$unique_levels, x2_levels_expected) + expect_equal(x3_preprocessing$unique_levels, x3_levels_expected) + expect_equal(x4_preprocessing$unique_levels, x4_levels_expected) }) test_that("Out-of-sample preprocessing for ordered categorical variables", { - string_var_response_levels <- c("1. Strongly disagree", "2. Disagree", "3. Neither agree nor disagree", "4. Agree", "5. Strongly agree") - - x1 <- c("1. Strongly disagree", "3. Neither agree nor disagree", "2. Disagree", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") - x1_preprocessing <- orderedCatPreprocess(x1, string_var_response_levels) - x1_vector_expected <- c(1,3,2,4,3,5,4) - - x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) - x2_preprocessing <- orderedCatPreprocess(x2, string_var_response_levels) - x2_vector_expected <- c(1,3,2,4,3,5,4) - - x3 <- c("1. Strongly disagree", "6. Other", "7. Also other", "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") - expected_warning_message <- "Variable includes ordered categorical levels not included in the original training set" - expect_warning(x3_preprocessing <- orderedCatPreprocess(x3, string_var_response_levels), expected_warning_message) - x3_vector_expected <- c(1,6,6,4,3,5,4) - - x4 <- c(3,2,4,6,5,2,3,1,3,4,6) - x4_preprocessing <- orderedCatPreprocess(x4, c("1","2","3","4","5","6")) - x4_vector_expected <- c(3,2,4,6,5,2,3,1,3,4,6) - - expect_equal(x1_preprocessing, x1_vector_expected) - expect_equal(x2_preprocessing, x2_vector_expected) - expect_equal(x3_preprocessing, x3_vector_expected) - expect_equal(x4_preprocessing, x4_vector_expected) + string_var_response_levels <- c( + "1. Strongly disagree", + "2. Disagree", + "3. Neither agree nor disagree", + "4. Agree", + "5. Strongly agree" + ) + + x1 <- c( + "1. Strongly disagree", + "3. Neither agree nor disagree", + "2. Disagree", + "4. Agree", + "3. Neither agree nor disagree", + "5. Strongly agree", + "4. Agree" + ) + x1_preprocessing <- orderedCatPreprocess(x1, string_var_response_levels) + x1_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + + x2 <- factor(x1, levels = string_var_response_levels, ordered = TRUE) + x2_preprocessing <- orderedCatPreprocess(x2, string_var_response_levels) + x2_vector_expected <- c(1, 3, 2, 4, 3, 5, 4) + + x3 <- c( + "1. Strongly disagree", + "6. Other", + "7. Also other", + "4. Agree", + "3. Neither agree nor disagree", + "5. Strongly agree", + "4. Agree" + ) + expected_warning_message <- "Variable includes ordered categorical levels not included in the original training set" + expect_warning( + x3_preprocessing <- orderedCatPreprocess(x3, string_var_response_levels), + expected_warning_message + ) + x3_vector_expected <- c(1, 6, 6, 4, 3, 5, 4) + + x4 <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + x4_preprocessing <- orderedCatPreprocess(x4, c("1", "2", "3", "4", "5", "6")) + x4_vector_expected <- c(3, 2, 4, 6, 5, 2, 3, 1, 3, 4, 6) + + expect_equal(x1_preprocessing, x1_vector_expected) + expect_equal(x2_preprocessing, x2_vector_expected) + expect_equal(x3_preprocessing, x3_vector_expected) + expect_equal(x4_preprocessing, x4_vector_expected) }) diff --git a/test/R/testthat/test-data-preprocessing.R b/test/R/testthat/test-data-preprocessing.R index 61752aa7..ced792a0 100644 --- a/test/R/testthat/test-data-preprocessing.R +++ b/test/R/testthat/test-data-preprocessing.R @@ -13,7 +13,7 @@ # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$numeric_vars, c("x1","x2","x3")) # }) -# +# # test_that("Preprocessing of all-unordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -30,13 +30,13 @@ # expect_equal(preprocess_list$metadata$num_ordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 3) # expect_equal(preprocess_list$metadata$unordered_cat_vars, c("x1","x2","x3")) -# expect_equal(preprocess_list$metadata$unordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$unordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of all-ordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -51,13 +51,13 @@ # expect_equal(preprocess_list$metadata$num_ordered_cat_vars, 3) # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$ordered_cat_vars, c("x1","x2","x3")) -# expect_equal(preprocess_list$metadata$ordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$ordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of mixed-covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -78,7 +78,7 @@ # expect_equal(preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) # }) -# +# # test_that("Preprocessing of mixed-covariate matrix works", { # cov_input <- matrix(c(1:5,5:1,6:10),ncol=3,byrow=FALSE) # cov_mat <- matrix(c( @@ -98,7 +98,7 @@ # expect_equal(preprocess_list$metadata$unordered_cat_vars, c("x3")) # expect_equal(preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) -# +# # alt_preprocess_list <- createForestCovariates(cov_input, ordered_cat_vars = "x2", unordered_cat_vars = "x3") # expect_equal(alt_preprocess_list$data, cov_mat) # expect_equal(alt_preprocess_list$metadata$feature_types, c(0, rep(1,7))) @@ -110,16 +110,16 @@ # expect_equal(alt_preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(alt_preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) # }) -# +# # test_that("Preprocessing of out-of-sample mixed-covariate dataset works", { # metadata <- list( -# num_numeric_vars = 1, -# num_ordered_cat_vars = 1, -# num_unordered_cat_vars = 1, -# numeric_vars = c("x1"), -# ordered_cat_vars = c("x2"), -# unordered_cat_vars = c("x3"), -# ordered_unique_levels = list(x2=c("1","2","3","4","5")), +# num_numeric_vars = 1, +# num_ordered_cat_vars = 1, +# num_unordered_cat_vars = 1, +# numeric_vars = c("x1"), +# ordered_cat_vars = c("x2"), +# unordered_cat_vars = c("x3"), +# ordered_unique_levels = list(x2=c("1","2","3","4","5")), # unordered_unique_levels = list(x3=c("6","7","8","9","10")) # ) # cov_df <- data.frame(x1 = c(1:5,1), x2 = c(5:1,5), x3 = 6:11) @@ -134,7 +134,7 @@ # X_preprocessed <- createForestCovariatesFromMetadata(cov_df, metadata) # expect_equal(X_preprocessed, cov_mat) # }) -# +# # test_that("Preprocessing of all-numeric covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_mat <- matrix(c( @@ -151,7 +151,7 @@ # expect_equal(preprocess_list$metadata$original_var_indices, 1:3) # expect_equal(preprocess_list$metadata$numeric_vars, c("x1","x2","x3")) # }) -# +# # test_that("Preprocessing of all-unordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_df$x1 <- factor(cov_df$x1) @@ -173,13 +173,13 @@ # expect_equal(preprocess_list$metadata$unordered_cat_vars, c("x1","x2","x3")) # expected_var_indices <- c(rep(1,6),rep(2,6),rep(3,6)) # expect_equal(preprocess_list$metadata$original_var_indices, expected_var_indices) -# expect_equal(preprocess_list$metadata$unordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$unordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of all-ordered-categorical covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_df$x1 <- factor(cov_df$x1, ordered = TRUE) @@ -198,13 +198,13 @@ # expect_equal(preprocess_list$metadata$num_unordered_cat_vars, 0) # expect_equal(preprocess_list$metadata$ordered_cat_vars, c("x1","x2","x3")) # expect_equal(preprocess_list$metadata$original_var_indices, 1:3) -# expect_equal(preprocess_list$metadata$ordered_unique_levels, -# list(x1=c("1","2","3","4","5"), -# x2=c("1","2","3","4","5"), +# expect_equal(preprocess_list$metadata$ordered_unique_levels, +# list(x1=c("1","2","3","4","5"), +# x2=c("1","2","3","4","5"), # x3=c("6","7","8","9","10")) # ) # }) -# +# # test_that("Preprocessing of mixed-covariate dataset works", { # cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10) # cov_df$x2 <- factor(cov_df$x2, ordered = TRUE) @@ -229,17 +229,17 @@ # expect_equal(preprocess_list$metadata$ordered_unique_levels, list(x2=c("1","2","3","4","5"))) # expect_equal(preprocess_list$metadata$unordered_unique_levels, list(x3=c("6","7","8","9","10"))) # }) -# +# # test_that("Preprocessing of out-of-sample mixed-covariate dataset works", { # metadata <- list( -# num_numeric_vars = 1, -# num_ordered_cat_vars = 1, -# num_unordered_cat_vars = 1, +# num_numeric_vars = 1, +# num_ordered_cat_vars = 1, +# num_unordered_cat_vars = 1, # original_var_indices = c(1, 2, 3, 3, 3, 3, 3, 3), -# numeric_vars = c("x1"), -# ordered_cat_vars = c("x2"), -# unordered_cat_vars = c("x3"), -# ordered_unique_levels = list(x2=c("1","2","3","4","5")), +# numeric_vars = c("x1"), +# ordered_cat_vars = c("x2"), +# unordered_cat_vars = c("x3"), +# ordered_unique_levels = list(x2=c("1","2","3","4","5")), # unordered_unique_levels = list(x3=c("6","7","8","9","10")) # ) # cov_df <- data.frame(x1 = c(1:5,1), x2 = c(5:1,5), x3 = 6:11) diff --git a/test/R/testthat/test-dataset.R b/test/R/testthat/test-dataset.R index 80f8d8e0..d9ecadf9 100644 --- a/test/R/testthat/test-dataset.R +++ b/test/R/testthat/test-dataset.R @@ -1,66 +1,73 @@ test_that("ForestDataset can be constructed and updated", { - # Generate data - n <- 20 - num_covariates <- 10 - num_basis <- 5 - covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) - basis <- matrix(runif(n * num_basis), ncol = num_basis) - variance_weights <- runif(n) - - # Copy data to a ForestDataset object - forest_dataset <- createForestDataset(covariates, basis, variance_weights) - - # Run first round of expectations - expect_equal(forest_dataset$num_observations(), n) - expect_equal(forest_dataset$num_covariates(), num_covariates) - expect_equal(forest_dataset$num_basis(), num_basis) - expect_equal(forest_dataset$has_variance_weights(), T) - - # Update data - new_basis <- matrix(runif(n * num_basis), ncol = num_basis) - new_variance_weights <- runif(n) - expect_no_error( - forest_dataset$update_basis(new_basis) - ) - expect_no_error( - forest_dataset$update_variance_weights(new_variance_weights) - ) - - # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights - expect_equal(covariates, forest_dataset$get_covariates()) - expect_equal(new_basis, forest_dataset$get_basis()) - expect_equal(new_variance_weights, forest_dataset$get_variance_weights()) + # Generate data + n <- 20 + num_covariates <- 10 + num_basis <- 5 + covariates <- matrix(runif(n * num_covariates), ncol = num_covariates) + basis <- matrix(runif(n * num_basis), ncol = num_basis) + variance_weights <- runif(n) + + # Copy data to a ForestDataset object + forest_dataset <- createForestDataset(covariates, basis, variance_weights) + + # Run first round of expectations + expect_equal(forest_dataset$num_observations(), n) + expect_equal(forest_dataset$num_covariates(), num_covariates) + expect_equal(forest_dataset$num_basis(), num_basis) + expect_equal(forest_dataset$has_variance_weights(), T) + + # Update data + new_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + forest_dataset$update_basis(new_basis) + ) + expect_no_error( + forest_dataset$update_variance_weights(new_variance_weights) + ) + + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights + expect_equal(covariates, forest_dataset$get_covariates()) + expect_equal(new_basis, forest_dataset$get_basis()) + expect_equal(new_variance_weights, forest_dataset$get_variance_weights()) }) test_that("RandomEffectsDataset can be constructed and updated", { - # Generate data - n <- 20 - num_groups <- 4 - num_basis <- 5 - group_ids <- sample(as.integer(1:num_groups), size = n, replace = T) - rfx_basis <- cbind(1, matrix(runif(n*(num_basis-1)), ncol = (num_basis-1))) - variance_weights <- runif(n) - - # Copy data to a RandomEffectsDataset object - rfx_dataset <- createRandomEffectsDataset(group_ids, rfx_basis, variance_weights) - - # Run first round of expectations - expect_equal(rfx_dataset$num_observations(), n) - expect_equal(rfx_dataset$num_basis(), num_basis) - expect_equal(rfx_dataset$has_variance_weights(), T) - - # Update data - new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) - new_variance_weights <- runif(n) - expect_no_error( - rfx_dataset$update_basis(new_rfx_basis) - ) - expect_no_error( - rfx_dataset$update_variance_weights(new_variance_weights) - ) - - # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights - expect_equal(group_ids, rfx_dataset$get_group_labels()) - expect_equal(new_rfx_basis, rfx_dataset$get_basis()) - expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) + # Generate data + n <- 20 + num_groups <- 4 + num_basis <- 5 + group_ids <- sample(as.integer(1:num_groups), size = n, replace = T) + rfx_basis <- cbind( + 1, + matrix(runif(n * (num_basis - 1)), ncol = (num_basis - 1)) + ) + variance_weights <- runif(n) + + # Copy data to a RandomEffectsDataset object + rfx_dataset <- createRandomEffectsDataset( + group_ids, + rfx_basis, + variance_weights + ) + + # Run first round of expectations + expect_equal(rfx_dataset$num_observations(), n) + expect_equal(rfx_dataset$num_basis(), num_basis) + expect_equal(rfx_dataset$has_variance_weights(), T) + + # Update data + new_rfx_basis <- matrix(runif(n * num_basis), ncol = num_basis) + new_variance_weights <- runif(n) + expect_no_error( + rfx_dataset$update_basis(new_rfx_basis) + ) + expect_no_error( + rfx_dataset$update_variance_weights(new_variance_weights) + ) + + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights + expect_equal(group_ids, rfx_dataset$get_group_labels()) + expect_equal(new_rfx_basis, rfx_dataset$get_basis()) + expect_equal(new_variance_weights, rfx_dataset$get_variance_weights()) }) diff --git a/test/R/testthat/test-forest-container.R b/test/R/testthat/test-forest-container.R index 9b399a96..adbfa423 100644 --- a/test/R/testthat/test-forest-container.R +++ b/test/R/testthat/test-forest-container.R @@ -1,251 +1,297 @@ test_that("Univariate constant forest container", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest_samples <- createForestSamples(num_trees, 1, TRUE) - - # Initialize a forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest_samples$predict(forest_dataset) - pred_raw <- forest_samples$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Multiply first forest by 2.0 - forest_samples$multiply_forest(0, 2.0) - - # Check that predictions are all double - pred_orig <- pred - pred_expected <- pred * 2.0 - pred <- forest_samples$predict(forest_dataset) - - # Assertion - expect_equal(pred, pred_expected) - - # Add 1.0 to every tree in first forest - forest_samples$add_to_forest(0, 1.0) - - # Check that predictions are += num_trees - pred_expected <- pred + num_trees - pred <- forest_samples$predict(forest_dataset) - - # Assertion - expect_equal(pred, pred_expected) - - # Initialize a new forest with constant root predictions - forest_samples$add_forest_with_constant_leaves(0.) - - # Split the second forest as the first forest was split - forest_samples$add_numeric_split_tree(1, 0, 0, 0, 4.0, -5., 5.) - forest_samples$add_numeric_split_tree(1, 0, 1, 1, 4.0, -7.5, -2.5) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- cbind(pred_expected, pred_orig) - - # Assertion - expect_equal(pred, pred_expected_new) - - # Combine second forest with the first forest - forest_samples$combine_forests(c(0,1)) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- cbind(pred_expected + pred_orig, pred_orig) - - # Assertion - expect_equal(pred, pred_expected_new) - - # Divide first forest predictions by 2 - forest_samples$multiply_forest(0, 0.5) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- cbind((pred_expected + pred_orig)/2.0, pred_orig) - - # Assertion - expect_equal(pred, pred_expected_new) - - # Delete second forest - forest_samples$delete_sample(1) - - # Check that predictions are as expected - pred <- forest_samples$predict(forest_dataset) - pred_expected_new <- (pred_expected + pred_orig)/2.0 - - # Assertion - expect_equal(pred, pred_expected_new) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest_samples <- createForestSamples(num_trees, 1, TRUE) + + # Initialize a forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 0, 0, 4.0, -5., 5.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest_samples$add_numeric_split_tree(0, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest_samples$predict(forest_dataset) + pred_raw <- forest_samples$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Multiply first forest by 2.0 + forest_samples$multiply_forest(0, 2.0) + + # Check that predictions are all double + pred_orig <- pred + pred_expected <- pred * 2.0 + pred <- forest_samples$predict(forest_dataset) + + # Assertion + expect_equal(pred, pred_expected) + + # Add 1.0 to every tree in first forest + forest_samples$add_to_forest(0, 1.0) + + # Check that predictions are += num_trees + pred_expected <- pred + num_trees + pred <- forest_samples$predict(forest_dataset) + + # Assertion + expect_equal(pred, pred_expected) + + # Initialize a new forest with constant root predictions + forest_samples$add_forest_with_constant_leaves(0.) + + # Split the second forest as the first forest was split + forest_samples$add_numeric_split_tree(1, 0, 0, 0, 4.0, -5., 5.) + forest_samples$add_numeric_split_tree(1, 0, 1, 1, 4.0, -7.5, -2.5) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- cbind(pred_expected, pred_orig) + + # Assertion + expect_equal(pred, pred_expected_new) + + # Combine second forest with the first forest + forest_samples$combine_forests(c(0, 1)) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- cbind(pred_expected + pred_orig, pred_orig) + + # Assertion + expect_equal(pred, pred_expected_new) + + # Divide first forest predictions by 2 + forest_samples$multiply_forest(0, 0.5) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- cbind((pred_expected + pred_orig) / 2.0, pred_orig) + + # Assertion + expect_equal(pred, pred_expected_new) + + # Delete second forest + forest_samples$delete_sample(1) + + # Check that predictions are as expected + pred <- forest_samples$predict(forest_dataset) + pred_expected_new <- (pred_expected + pred_orig) / 2.0 + + # Assertion + expect_equal(pred, pred_expected_new) }) test_that("Collapse forests", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Create forest dataset - forest_dataset_test <- createForestDataset(covariates = X_test) - - # Run BART for 50 iterations - num_mcmc <- 50 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - batch_inds <- (seq(1,num_mcmc,1) - (num_mcmc - (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - 1) %/% batch_size + 1 - pred_orig_collapsed <- matrix(NA, nrow = nrow(pred_orig), ncol = max(batch_inds)) - for (i in 1:max(batch_inds)) { - pred_orig_collapsed[,i] <- rowSums(pred_orig[,batch_inds == i]) / sum(batch_inds == i) - } - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) - - # Now run BART for 52 iterations - num_mcmc <- 52 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - batch_inds <- (seq(1,num_mcmc,1) - (num_mcmc - (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - 1) %/% batch_size + 1 - pred_orig_collapsed <- matrix(NA, nrow = nrow(pred_orig), ncol = max(batch_inds)) - for (i in 1:max(batch_inds)) { - pred_orig_collapsed[,i] <- rowSums(pred_orig[,batch_inds == i]) / sum(batch_inds == i) - } - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) - - # Now run BART for 5 iterations - num_mcmc <- 5 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - pred_orig_collapsed <- as.matrix(rowSums(pred_orig) / batch_size) - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) - - # Now run BART for 4 iterations - num_mcmc <- 4 - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - num_gfr = 0, num_burnin = 0, num_mcmc = num_mcmc, - general_params = general_param_list) - - # Extract the mean forest container - mean_forest_container <- bart_model$mean_forests - - # Predict from the original container - pred_orig <- mean_forest_container$predict(forest_dataset_test) - - # Collapse the container in batches of 5 - batch_size <- 5 - mean_forest_container$collapse(batch_size) - - # Predict from the modified container - pred_new <- mean_forest_container$predict(forest_dataset_test) - - # Check that corresponding (sums of) predictions match - pred_orig_collapsed <- pred_orig - - # Assertion - expect_equal(pred_orig_collapsed, pred_new) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Create forest dataset + forest_dataset_test <- createForestDataset(covariates = X_test) + + # Run BART for 50 iterations + num_mcmc <- 50 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + batch_inds <- (seq(1, num_mcmc, 1) - + (num_mcmc - + (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - + 1) %/% + batch_size + + 1 + pred_orig_collapsed <- matrix( + NA, + nrow = nrow(pred_orig), + ncol = max(batch_inds) + ) + for (i in 1:max(batch_inds)) { + pred_orig_collapsed[, i] <- rowSums(pred_orig[, batch_inds == i]) / + sum(batch_inds == i) + } + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) + + # Now run BART for 52 iterations + num_mcmc <- 52 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + batch_inds <- (seq(1, num_mcmc, 1) - + (num_mcmc - + (num_mcmc %/% (num_mcmc %/% batch_size)) * (num_mcmc %/% batch_size)) - + 1) %/% + batch_size + + 1 + pred_orig_collapsed <- matrix( + NA, + nrow = nrow(pred_orig), + ncol = max(batch_inds) + ) + for (i in 1:max(batch_inds)) { + pred_orig_collapsed[, i] <- rowSums(pred_orig[, batch_inds == i]) / + sum(batch_inds == i) + } + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) + + # Now run BART for 5 iterations + num_mcmc <- 5 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + pred_orig_collapsed <- as.matrix(rowSums(pred_orig) / batch_size) + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) + + # Now run BART for 4 iterations + num_mcmc <- 4 + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + num_gfr = 0, + num_burnin = 0, + num_mcmc = num_mcmc, + general_params = general_param_list + ) + + # Extract the mean forest container + mean_forest_container <- bart_model$mean_forests + + # Predict from the original container + pred_orig <- mean_forest_container$predict(forest_dataset_test) + + # Collapse the container in batches of 5 + batch_size <- 5 + mean_forest_container$collapse(batch_size) + + # Predict from the modified container + pred_new <- mean_forest_container$predict(forest_dataset_test) + + # Check that corresponding (sums of) predictions match + pred_orig_collapsed <- pred_orig + + # Assertion + expect_equal(pred_orig_collapsed, pred_new) }) diff --git a/test/R/testthat/test-forest.R b/test/R/testthat/test-forest.R index 3c9409dd..063c64b6 100644 --- a/test/R/testthat/test-forest.R +++ b/test/R/testthat/test-forest.R @@ -1,162 +1,164 @@ test_that("Univariate forest construction", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest <- createForest(num_trees, 1, TRUE) - - # Initialize forest with 0.0 root predictions - forest$set_root_leaves(0.) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest$predict(forest_dataset) - pred_raw <- forest$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) - - # Check that predictions are the same (since the leaf is constant) - pred <- forest$predict(forest_dataset) - pred_raw <- forest$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) - - # Check that regular and "raw" predictions are the same (since the leaf is constant) - pred <- forest$predict(forest_dataset) - pred_raw <- forest$predict_raw(forest_dataset) - - # Assertion - expect_equal(pred, pred_raw) - - # Check the split count for the first tree in the ensemble - split_counts <- forest$get_tree_split_counts(0,p) - split_counts_expected <- c(1,1,0) - - # Assertion - expect_equal(split_counts, split_counts_expected) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest <- createForest(num_trees, 1, TRUE) + + # Initialize forest with 0.0 root predictions + forest$set_root_leaves(0.) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest$predict(forest_dataset) + pred_raw <- forest$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the root of the first tree in the ensemble at X[,1] > 4.0 + forest$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) + + # Check that predictions are the same (since the leaf is constant) + pred <- forest$predict(forest_dataset) + pred_raw <- forest$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 + forest$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) + + # Check that regular and "raw" predictions are the same (since the leaf is constant) + pred <- forest$predict(forest_dataset) + pred_raw <- forest$predict_raw(forest_dataset) + + # Assertion + expect_equal(pred, pred_raw) + + # Check the split count for the first tree in the ensemble + split_counts <- forest$get_tree_split_counts(0, p) + split_counts_expected <- c(1, 1, 0) + + # Assertion + expect_equal(split_counts, split_counts_expected) }) test_that("Univariate forest construction and low-level merge / arithmetic ops", { - # Create dataset and forest container - num_trees <- 10 - X = matrix(c(1.5, 8.7, 1.2, + # Create dataset and forest container + num_trees <- 10 + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - n <- nrow(X) - p <- ncol(X) - forest_dataset = createForestDataset(X) - forest1 <- createForest(num_trees, 1, TRUE) - forest2 <- createForest(num_trees, 1, TRUE) - - # Initialize forests with 0.0 root predictions - forest1$set_root_leaves(0.) - forest2$set_root_leaves(0.) - - # Check that predictions are as expected - pred1 <- forest1$predict(forest_dataset) - pred2 <- forest2$predict(forest_dataset) - pred_exp1 <- c(0,0,0,0,0,0) - pred_exp2 <- c(0,0,0,0,0,0) - - # Assertion - expect_equal(pred1, pred_exp1) - expect_equal(pred2, pred_exp2) - - # Split the root of the first tree of the first forest in the ensemble at X[,1] > 4.0 - forest1$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) - - # Split the root of the first tree of the second forest in the ensemble at X[,1] > 3.0 - forest2$add_numeric_split_tree(0, 0, 0, 3.0, -1., 1.) - - # Check that predictions are as expected - pred1 <- forest1$predict(forest_dataset) - pred2 <- forest2$predict(forest_dataset) - pred_exp1 <- c(-5,-5,-5,5,5,5) - pred_exp2 <- c(-1,-1,1,1,1,1) - - # Assertion - expect_equal(pred1, pred_exp1) - expect_equal(pred2, pred_exp2) - - # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 - forest1$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) - - # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 - forest2$add_numeric_split_tree(0, 1, 1, 4.0, -1.5, -0.5) - - # Check that predictions are as expected - pred1 <- forest1$predict(forest_dataset) - pred2 <- forest2$predict(forest_dataset) - pred_exp1 <- c(-2.5,-7.5,-7.5,5,5,5) - pred_exp2 <- c(-0.5,-1.5,1,1,1,1) - - # Assertion - expect_equal(pred1, pred_exp1) - expect_equal(pred2, pred_exp2) - - # Merge forests - forest1$merge_forest(forest2) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(-3.0,-9.0,-6.5,6.0,6.0,6.0) - - # Assertion - expect_equal(pred, pred_exp) - - # Add constant to every value of the combined forest - forest1$add_constant(0.5) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(7.0,1.0,3.5,16.0,16.0,16.0) - - # Assertion - expect_equal(pred, pred_exp) - - # Check that "old" forest is still intact - pred <- forest2$predict(forest_dataset) - pred_exp <- c(-0.5,-1.5,1,1,1,1) - - # Assertion - expect_equal(pred, pred_exp) - - # Subtract constant back off of every value of the combined forest - forest1$add_constant(-0.5) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(-3.0,-9.0,-6.5,6.0,6.0,6.0) - - # Assertion - expect_equal(pred, pred_exp) - - # Multiply every value of the combined forest by a constant - forest1$multiply_constant(2.0) - - # Check that predictions are as expected - pred <- forest1$predict(forest_dataset) - pred_exp <- c(-6.0,-18.0,-13.0,12.0,12.0,12.0) - - # Assertion - expect_equal(pred, pred_exp) + n <- nrow(X) + p <- ncol(X) + forest_dataset = createForestDataset(X) + forest1 <- createForest(num_trees, 1, TRUE) + forest2 <- createForest(num_trees, 1, TRUE) + + # Initialize forests with 0.0 root predictions + forest1$set_root_leaves(0.) + forest2$set_root_leaves(0.) + + # Check that predictions are as expected + pred1 <- forest1$predict(forest_dataset) + pred2 <- forest2$predict(forest_dataset) + pred_exp1 <- c(0, 0, 0, 0, 0, 0) + pred_exp2 <- c(0, 0, 0, 0, 0, 0) + + # Assertion + expect_equal(pred1, pred_exp1) + expect_equal(pred2, pred_exp2) + + # Split the root of the first tree of the first forest in the ensemble at X[,1] > 4.0 + forest1$add_numeric_split_tree(0, 0, 0, 4.0, -5., 5.) + + # Split the root of the first tree of the second forest in the ensemble at X[,1] > 3.0 + forest2$add_numeric_split_tree(0, 0, 0, 3.0, -1., 1.) + + # Check that predictions are as expected + pred1 <- forest1$predict(forest_dataset) + pred2 <- forest2$predict(forest_dataset) + pred_exp1 <- c(-5, -5, -5, 5, 5, 5) + pred_exp2 <- c(-1, -1, 1, 1, 1, 1) + + # Assertion + expect_equal(pred1, pred_exp1) + expect_equal(pred2, pred_exp2) + + # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 + forest1$add_numeric_split_tree(0, 1, 1, 4.0, -7.5, -2.5) + + # Split the left leaf of the first tree of the first forest in the ensemble at X[,2] > 4.0 + forest2$add_numeric_split_tree(0, 1, 1, 4.0, -1.5, -0.5) + + # Check that predictions are as expected + pred1 <- forest1$predict(forest_dataset) + pred2 <- forest2$predict(forest_dataset) + pred_exp1 <- c(-2.5, -7.5, -7.5, 5, 5, 5) + pred_exp2 <- c(-0.5, -1.5, 1, 1, 1, 1) + + # Assertion + expect_equal(pred1, pred_exp1) + expect_equal(pred2, pred_exp2) + + # Merge forests + forest1$merge_forest(forest2) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(-3.0, -9.0, -6.5, 6.0, 6.0, 6.0) + + # Assertion + expect_equal(pred, pred_exp) + + # Add constant to every value of the combined forest + forest1$add_constant(0.5) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(7.0, 1.0, 3.5, 16.0, 16.0, 16.0) + + # Assertion + expect_equal(pred, pred_exp) + + # Check that "old" forest is still intact + pred <- forest2$predict(forest_dataset) + pred_exp <- c(-0.5, -1.5, 1, 1, 1, 1) + + # Assertion + expect_equal(pred, pred_exp) + + # Subtract constant back off of every value of the combined forest + forest1$add_constant(-0.5) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(-3.0, -9.0, -6.5, 6.0, 6.0, 6.0) + + # Assertion + expect_equal(pred, pred_exp) + + # Multiply every value of the combined forest by a constant + forest1$multiply_constant(2.0) + + # Check that predictions are as expected + pred <- forest1$predict(forest_dataset) + pred_exp <- c(-6.0, -18.0, -13.0, 12.0, 12.0, 12.0) + + # Assertion + expect_equal(pred, pred_exp) }) diff --git a/test/R/testthat/test-residual.R b/test/R/testthat/test-residual.R index 4f663465..64bbb30b 100644 --- a/test/R/testthat/test-residual.R +++ b/test/R/testthat/test-residual.R @@ -1,85 +1,120 @@ test_that("Residual updates correctly propagated after forest sampling step", { - # Setup - # Create dataset - X = matrix(c(1.5, 8.7, 1.2, + # Setup + # Create dataset + # fmt: skip + X = matrix(c(1.5, 8.7, 1.2, 2.7, 3.4, 5.4, 3.6, 1.2, 9.3, 4.4, 5.4, 10.4, 5.3, 9.3, 3.6, 6.1, 10.4, 4.4), byrow = TRUE, nrow = 6) - W = matrix(c(1, 1, 1, 1, 1, 1), nrow = 6) - n = nrow(X) - p = ncol(X) - y = as.matrix(ifelse(X[,1]>4,-5,5) + rnorm(n,0,1)) - y_bar = mean(y) - y_std = sd(y) - resid = (y-y_bar)/y_std - forest_dataset = createForestDataset(X, W) - residual = createOutcome(resid) - variable_weights = rep(1.0/p, p) - feature_types = as.integer(rep(0, p)) - - # Forest parameters - num_trees = 50 - alpha = 0.95 - beta = 2.0 - min_samples_leaf = 1 - current_sigma2 = 1. - current_leaf_scale = as.matrix(1./num_trees,nrow=1,ncol=1) - cutpoint_grid_size = 100 - max_depth = 10 - a_forest = 0 - b_forest = 0 - - # RNG - cpp_rng = createCppRNG(-1) - - # Create forest sampler and forest container - global_model_config = createGlobalModelConfig(global_error_variance=current_sigma2) - forest_model_config = createForestModelConfig(feature_types=feature_types, num_trees=num_trees, - num_observations=n, alpha=alpha, beta=beta, - min_samples_leaf=min_samples_leaf, max_depth=max_depth, - leaf_model_type=0, leaf_model_scale=current_leaf_scale, - variable_weights=variable_weights, variance_forest_shape=a_forest, - variance_forest_scale=b_forest, cutpoint_grid_size=cutpoint_grid_size) - forest_model = createForestModel(forest_dataset, forest_model_config, global_model_config) - forest_samples = createForestSamples(num_trees, 1, FALSE) - active_forest = createForest(num_trees, 1, FALSE) - - # Initialize the leaves of each tree in the prognostic forest - active_forest$prepare_for_sampler(forest_dataset, residual, forest_model, 0, mean(resid)) - active_forest$adjust_residual(forest_dataset, residual, forest_model, FALSE, FALSE) - - # Run the forest sampling algorithm for a single iteration - forest_model$sample_one_iteration( - forest_dataset, residual, forest_samples, active_forest, - cpp_rng, forest_model_config, global_model_config, keep_forest = TRUE, gfr = TRUE - ) + W = matrix(c(1, 1, 1, 1, 1, 1), nrow = 6) + n = nrow(X) + p = ncol(X) + y = as.matrix(ifelse(X[, 1] > 4, -5, 5) + rnorm(n, 0, 1)) + y_bar = mean(y) + y_std = sd(y) + resid = (y - y_bar) / y_std + forest_dataset = createForestDataset(X, W) + residual = createOutcome(resid) + variable_weights = rep(1.0 / p, p) + feature_types = as.integer(rep(0, p)) - # Get the current residual after running the sampler - initial_resid = residual$get_data() - - # Get initial prediction from the tree ensemble - initial_yhat = as.numeric(forest_samples$predict(forest_dataset)) - - # Update the basis vector - scalar = 2.0 - W_update = W*scalar - forest_dataset$update_basis(W_update) - - # Update residual to reflect adjusted basis - forest_model$propagate_basis_update(forest_dataset, residual, active_forest) - - # Get updated prediction from the tree ensemble - updated_yhat = as.numeric(forest_samples$predict(forest_dataset)) - - # Compute the expected residual - expected_resid = initial_resid + initial_yhat - updated_yhat - - # Get the current residual after running the sampler - updated_resid = residual$get_data() - - # Assertion - expect_equal(expected_resid, updated_resid) + # Forest parameters + num_trees = 50 + alpha = 0.95 + beta = 2.0 + min_samples_leaf = 1 + current_sigma2 = 1. + current_leaf_scale = as.matrix(1. / num_trees, nrow = 1, ncol = 1) + cutpoint_grid_size = 100 + max_depth = 10 + a_forest = 0 + b_forest = 0 + + # RNG + cpp_rng = createCppRNG(-1) + + # Create forest sampler and forest container + global_model_config = createGlobalModelConfig( + global_error_variance = current_sigma2 + ) + forest_model_config = createForestModelConfig( + feature_types = feature_types, + num_trees = num_trees, + num_observations = n, + alpha = alpha, + beta = beta, + min_samples_leaf = min_samples_leaf, + max_depth = max_depth, + leaf_model_type = 0, + leaf_model_scale = current_leaf_scale, + variable_weights = variable_weights, + variance_forest_shape = a_forest, + variance_forest_scale = b_forest, + cutpoint_grid_size = cutpoint_grid_size + ) + forest_model = createForestModel( + forest_dataset, + forest_model_config, + global_model_config + ) + forest_samples = createForestSamples(num_trees, 1, FALSE) + active_forest = createForest(num_trees, 1, FALSE) + + # Initialize the leaves of each tree in the prognostic forest + active_forest$prepare_for_sampler( + forest_dataset, + residual, + forest_model, + 0, + mean(resid) + ) + active_forest$adjust_residual( + forest_dataset, + residual, + forest_model, + FALSE, + FALSE + ) + + # Run the forest sampling algorithm for a single iteration + forest_model$sample_one_iteration( + forest_dataset, + residual, + forest_samples, + active_forest, + cpp_rng, + forest_model_config, + global_model_config, + keep_forest = TRUE, + gfr = TRUE + ) + + # Get the current residual after running the sampler + initial_resid = residual$get_data() + + # Get initial prediction from the tree ensemble + initial_yhat = as.numeric(forest_samples$predict(forest_dataset)) + + # Update the basis vector + scalar = 2.0 + W_update = W * scalar + forest_dataset$update_basis(W_update) + + # Update residual to reflect adjusted basis + forest_model$propagate_basis_update(forest_dataset, residual, active_forest) + + # Get updated prediction from the tree ensemble + updated_yhat = as.numeric(forest_samples$predict(forest_dataset)) + + # Compute the expected residual + expected_resid = initial_resid + initial_yhat - updated_yhat + + # Get the current residual after running the sampler + updated_resid = residual$get_data() + + # Assertion + expect_equal(expected_resid, updated_resid) }) diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index 0d78957f..fe50af5f 100644 --- a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -1,162 +1,181 @@ test_that("BART Serialization", { - skip_on_cran() - - # Generate simulated data - n <- 100 - p <- 5 - X <- matrix(runif(n*p), ncol = p) - f_XW <- ( + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n * p), ncol = p) + # fmt: skip + f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) - noise_sd <- 1 - y <- f_XW + rnorm(n, 0, noise_sd) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - y_test <- y[test_inds] - y_train <- y[train_inds] - - # Sample a BART model - general_param_list <- list(num_chains = 1, keep_every = 1) - bart_model <- bart(X_train = X_train, y_train = y_train, - num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list) - y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) - - # Save to JSON - bart_json_string <- saveBARTModelToJsonString(bart_model) - - # Reload as a BART model - bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) - - # Predict from the roundtrip BART model - y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) - - # Assertion - expect_equal(y_hat_orig, y_hat_reloaded) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Sample a BART model + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart( + X_train = X_train, + y_train = y_train, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + general_params = general_param_list + ) + y_hat_orig <- rowMeans(predict(bart_model, X_test)$y_hat) + + # Save to JSON + bart_json_string <- saveBARTModelToJsonString(bart_model) + + # Reload as a BART model + bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string) + + # Predict from the roundtrip BART model + y_hat_reloaded <- rowMeans(predict(bart_model_roundtrip, X_test)$y_hat) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) }) test_that("BCF Serialization", { - skip_on_cran() - - n <- 500 - x1 <- runif(n) - x2 <- runif(n) - x3 <- runif(n) - x4 <- runif(n) - x5 <- runif(n) - X <- cbind(x1,x2,x3,x4,x5) - p <- ncol(X) - pi_x <- 0.25 + 0.5*X[,1] - mu_x <- pi_x * 5 - tau_x <- X[,2] * 2 - Z <- rbinom(n,1,pi_x) - E_XZ <- mu_x + Z*tau_x - y <- E_XZ + rnorm(n, 0, 1) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - pi_test <- pi_x[test_inds] - pi_train <- pi_x[train_inds] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - mu_test <- mu_x[test_inds] - mu_train <- mu_x[train_inds] - tau_test <- tau_x[test_inds] - tau_train <- tau_x[train_inds] - - # Sample a BCF model - bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, num_gfr = 100, num_burnin = 0, num_mcmc = 100) - bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) - mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) - tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) - y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) - - # Save to JSON - bcf_json_string <- saveBCFModelToJsonString(bcf_model) - - # Reload as a BCF model - bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) - - # Predict from the roundtrip BCF model - bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test) - mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) - tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) - y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) - - # Assertion - expect_equal(y_hat_orig, y_hat_reloaded) + skip_on_cran() + + n <- 500 + x1 <- runif(n) + x2 <- runif(n) + x3 <- runif(n) + x4 <- runif(n) + x5 <- runif(n) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + pi_x <- 0.25 + 0.5 * X[, 1] + mu_x <- pi_x * 5 + tau_x <- X[, 2] * 2 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + y <- E_XZ + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + + # Sample a BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + num_gfr = 100, + num_burnin = 0, + num_mcmc = 100 + ) + bcf_preds_orig <- predict(bcf_model, X_test, Z_test, pi_test) + mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) + tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) + y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) + + # Save to JSON + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + + # Reload as a BCF model + bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) + + # Predict from the roundtrip BCF model + bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test, pi_test) + mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) + tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) + y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) }) test_that("BCF Serialization (no propensity)", { - skip_on_cran() - - n <- 500 - x1 <- runif(n) - x2 <- runif(n) - x3 <- runif(n) - x4 <- runif(n) - x5 <- runif(n) - X <- cbind(x1,x2,x3,x4,x5) - p <- ncol(X) - pi_x <- 0.25 + 0.5*X[,1] - mu_x <- pi_x * 5 - tau_x <- X[,2] * 2 - Z <- rbinom(n,1,pi_x) - E_XZ <- mu_x + Z*tau_x - y <- E_XZ + rnorm(n, 0, 1) - test_set_pct <- 0.2 - n_test <- round(test_set_pct*n) - n_train <- n - n_test - test_inds <- sort(sample(1:n, n_test, replace = FALSE)) - train_inds <- (1:n)[!((1:n) %in% test_inds)] - X_test <- X[test_inds,] - X_train <- X[train_inds,] - pi_test <- pi_x[test_inds] - pi_train <- pi_x[train_inds] - Z_test <- Z[test_inds] - Z_train <- Z[train_inds] - y_test <- y[test_inds] - y_train <- y[train_inds] - mu_test <- mu_x[test_inds] - mu_train <- mu_x[train_inds] - tau_test <- tau_x[test_inds] - tau_train <- tau_x[train_inds] - - # Sample a BCF model - bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - num_gfr = 100, num_burnin = 0, num_mcmc = 100) - bcf_preds_orig <- predict(bcf_model, X_test, Z_test) - mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) - tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) - y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) - - # Save to JSON - bcf_json_string <- saveBCFModelToJsonString(bcf_model) - - # Reload as a BCF model - bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) - - # Predict from the roundtrip BCF model - bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test) - mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) - tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) - y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) - - # Assertion - expect_equal(y_hat_orig, y_hat_reloaded) -}) \ No newline at end of file + skip_on_cran() + + n <- 500 + x1 <- runif(n) + x2 <- runif(n) + x3 <- runif(n) + x4 <- runif(n) + x5 <- runif(n) + X <- cbind(x1, x2, x3, x4, x5) + p <- ncol(X) + pi_x <- 0.25 + 0.5 * X[, 1] + mu_x <- pi_x * 5 + tau_x <- X[, 2] * 2 + Z <- rbinom(n, 1, pi_x) + E_XZ <- mu_x + Z * tau_x + y <- E_XZ + rnorm(n, 0, 1) + test_set_pct <- 0.2 + n_test <- round(test_set_pct * n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds, ] + X_train <- X[train_inds, ] + pi_test <- pi_x[test_inds] + pi_train <- pi_x[train_inds] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + mu_test <- mu_x[test_inds] + mu_train <- mu_x[train_inds] + tau_test <- tau_x[test_inds] + tau_train <- tau_x[train_inds] + + # Sample a BCF model + bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + num_gfr = 100, + num_burnin = 0, + num_mcmc = 100 + ) + bcf_preds_orig <- predict(bcf_model, X_test, Z_test) + mu_hat_orig <- rowMeans(bcf_preds_orig[["mu_hat"]]) + tau_hat_orig <- rowMeans(bcf_preds_orig[["tau_hat"]]) + y_hat_orig <- rowMeans(bcf_preds_orig[["y_hat"]]) + + # Save to JSON + bcf_json_string <- saveBCFModelToJsonString(bcf_model) + + # Reload as a BCF model + bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string) + + # Predict from the roundtrip BCF model + bcf_preds_reloaded <- predict(bcf_model_roundtrip, X_test, Z_test) + mu_hat_reloaded <- rowMeans(bcf_preds_reloaded[["mu_hat"]]) + tau_hat_reloaded <- rowMeans(bcf_preds_reloaded[["tau_hat"]]) + y_hat_reloaded <- rowMeans(bcf_preds_reloaded[["y_hat"]]) + + # Assertion + expect_equal(y_hat_orig, y_hat_reloaded) +}) diff --git a/test/R/testthat/test-utils.R b/test/R/testthat/test-utils.R index 60991ad7..72434368 100644 --- a/test/R/testthat/test-utils.R +++ b/test/R/testthat/test-utils.R @@ -1,70 +1,86 @@ test_that("Array conversion", { - skip_on_cran() - - # Test data - scalar_1 <- 1.5 - scalar_2 <- -2.5 - scalar_3 <- 4 - array_1d_1 <- c(1.6, 3.4, 7.6, 8.7) - array_1d_2 <- c(2.5, 3.1, 5.6) - array_1d_3 <- c(2.5) - array_2d_1 <- matrix( - c(2.5,1.2,4.3,7.4,1.7,2.9,3.6,9.1,7.2,4.5,6.7,1.4), - nrow = 3, ncol = 4, byrow = T - ) - array_2d_2 <- matrix( - c(2.5,1.2,4.3,7.4,1.7,2.9,3.6,9.1), - nrow = 2, ncol = 4, byrow = T - ) - array_square_1 <- matrix( - c(2.5,1.2,1.7,2.9), - nrow = 2, ncol = 2, byrow = T - ) - array_square_2 <- matrix( - c(2.5,0.0,0.0,2.9), - nrow = 2, ncol = 2, byrow = T - ) - array_square_3 <- matrix( - c(2.5,0.0,0.0,0.0,2.9,0.0,0.0,0.0,5.6), - nrow = 3, ncol = 3, byrow = T - ) - - # Error cases - expect_error(expand_dims_1d(array_1d_1, 5)) - expect_error(expand_dims_1d(array_1d_2, 4)) - expect_error(expand_dims_2d(array_2d_1, 2, 4)) - expect_error(expand_dims_2d(array_2d_2, 3, 4)) - expect_error(expand_dims_2d_diag(array_square_1, 4)) - expect_error(expand_dims_2d_diag(array_square_2, 3)) - expect_error(expand_dims_2d_diag(array_square_3, 2)) - - # Working cases - expect_equal(c(scalar_1,scalar_1,scalar_1), expand_dims_1d(scalar_1, 3)) - expect_equal(c(scalar_2,scalar_2,scalar_2,scalar_2), expand_dims_1d(scalar_2, 4)) - expect_equal(c(scalar_3,scalar_3), expand_dims_1d(scalar_3, 2)) - expect_equal(c(array_1d_3,array_1d_3,array_1d_3), expand_dims_1d(array_1d_3, 3)) - - output_exp <- matrix(rep(scalar_1, 6), nrow = 2, byrow = T) - expect_equal(output_exp, expand_dims_2d(scalar_1, 2, 3)) - output_exp <- matrix(rep(scalar_2, 8), nrow = 2, byrow = T) - expect_equal(output_exp, expand_dims_2d(scalar_2, 2, 4)) - output_exp <- matrix(rep(scalar_3, 6), nrow = 3, byrow = T) - expect_equal(output_exp, expand_dims_2d(scalar_3, 3, 2)) - output_exp <- matrix(rep(array_1d_3, 6), nrow = 3, byrow = T) - expect_equal(output_exp, expand_dims_2d(array_1d_3, 3, 2)) - output_exp <- unname(rbind(array_1d_1, array_1d_1)) - expect_equal(output_exp, expand_dims_2d(array_1d_1, 2, 4)) - output_exp <- unname(rbind(array_1d_2, array_1d_2, array_1d_2)) - expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 3)) - output_exp <- unname(cbind(array_1d_2, array_1d_2, array_1d_2, array_1d_2)) - expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 4)) - output_exp <- unname(cbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) - expect_equal(output_exp, expand_dims_2d(array_1d_3, 1, 4)) - output_exp <- unname(rbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) - expect_equal(output_exp, expand_dims_2d(array_1d_3, 4, 1)) - - expect_equal(diag(scalar_1, 3), expand_dims_2d_diag(scalar_1, 3)) - expect_equal(diag(scalar_2, 2), expand_dims_2d_diag(scalar_2, 2)) - expect_equal(diag(scalar_3, 4), expand_dims_2d_diag(scalar_3, 4)) - expect_equal(diag(array_1d_3, 2), expand_dims_2d_diag(array_1d_3, 2)) -}) \ No newline at end of file + skip_on_cran() + + # Test data + scalar_1 <- 1.5 + scalar_2 <- -2.5 + scalar_3 <- 4 + array_1d_1 <- c(1.6, 3.4, 7.6, 8.7) + array_1d_2 <- c(2.5, 3.1, 5.6) + array_1d_3 <- c(2.5) + array_2d_1 <- matrix( + c(2.5, 1.2, 4.3, 7.4, 1.7, 2.9, 3.6, 9.1, 7.2, 4.5, 6.7, 1.4), + nrow = 3, + ncol = 4, + byrow = T + ) + array_2d_2 <- matrix( + c(2.5, 1.2, 4.3, 7.4, 1.7, 2.9, 3.6, 9.1), + nrow = 2, + ncol = 4, + byrow = T + ) + array_square_1 <- matrix( + c(2.5, 1.2, 1.7, 2.9), + nrow = 2, + ncol = 2, + byrow = T + ) + array_square_2 <- matrix( + c(2.5, 0.0, 0.0, 2.9), + nrow = 2, + ncol = 2, + byrow = T + ) + array_square_3 <- matrix( + c(2.5, 0.0, 0.0, 0.0, 2.9, 0.0, 0.0, 0.0, 5.6), + nrow = 3, + ncol = 3, + byrow = T + ) + + # Error cases + expect_error(expand_dims_1d(array_1d_1, 5)) + expect_error(expand_dims_1d(array_1d_2, 4)) + expect_error(expand_dims_2d(array_2d_1, 2, 4)) + expect_error(expand_dims_2d(array_2d_2, 3, 4)) + expect_error(expand_dims_2d_diag(array_square_1, 4)) + expect_error(expand_dims_2d_diag(array_square_2, 3)) + expect_error(expand_dims_2d_diag(array_square_3, 2)) + + # Working cases + expect_equal(c(scalar_1, scalar_1, scalar_1), expand_dims_1d(scalar_1, 3)) + expect_equal( + c(scalar_2, scalar_2, scalar_2, scalar_2), + expand_dims_1d(scalar_2, 4) + ) + expect_equal(c(scalar_3, scalar_3), expand_dims_1d(scalar_3, 2)) + expect_equal( + c(array_1d_3, array_1d_3, array_1d_3), + expand_dims_1d(array_1d_3, 3) + ) + + output_exp <- matrix(rep(scalar_1, 6), nrow = 2, byrow = T) + expect_equal(output_exp, expand_dims_2d(scalar_1, 2, 3)) + output_exp <- matrix(rep(scalar_2, 8), nrow = 2, byrow = T) + expect_equal(output_exp, expand_dims_2d(scalar_2, 2, 4)) + output_exp <- matrix(rep(scalar_3, 6), nrow = 3, byrow = T) + expect_equal(output_exp, expand_dims_2d(scalar_3, 3, 2)) + output_exp <- matrix(rep(array_1d_3, 6), nrow = 3, byrow = T) + expect_equal(output_exp, expand_dims_2d(array_1d_3, 3, 2)) + output_exp <- unname(rbind(array_1d_1, array_1d_1)) + expect_equal(output_exp, expand_dims_2d(array_1d_1, 2, 4)) + output_exp <- unname(rbind(array_1d_2, array_1d_2, array_1d_2)) + expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 3)) + output_exp <- unname(cbind(array_1d_2, array_1d_2, array_1d_2, array_1d_2)) + expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 4)) + output_exp <- unname(cbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) + expect_equal(output_exp, expand_dims_2d(array_1d_3, 1, 4)) + output_exp <- unname(rbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3)) + expect_equal(output_exp, expand_dims_2d(array_1d_3, 4, 1)) + + expect_equal(diag(scalar_1, 3), expand_dims_2d_diag(scalar_1, 3)) + expect_equal(diag(scalar_2, 2), expand_dims_2d_diag(scalar_2, 2)) + expect_equal(diag(scalar_3, 4), expand_dims_2d_diag(scalar_3, 4)) + expect_equal(diag(array_1d_3, 2), expand_dims_2d_diag(array_1d_3, 2)) +}) From 5089d68ca325df64fdef1095a2e989cf03afda97 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 19:43:30 -0500 Subject: [PATCH 12/53] Fixed BCF heteroskedasticity bug --- R/bcf.R | 59 +++++++++++++++++++++++++++---------- test/R/testthat/test-bart.R | 2 +- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index b69e0ad2..a30b4168 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -983,6 +983,16 @@ bcf <- function( } } + # Runtime checks for variance forest + if (include_variance_forest) { + if (sample_sigma2_global) { + warning( + "Global error variance will not be sampled with a heteroskedasticity" + ) + sample_sigma2_global <- F + } + } + # Handle standardization, prior calibration, and initialization of forest # differently for binary and continuous outcomes if (probit_outcome_model) { @@ -2827,26 +2837,43 @@ predict.bcfmodel <- function( rfx_basis ) * y_std - if (predict_mean) { - rfx_predictions <- rowMeans(rfx_predictions) - } } # Combine into y hat predictions - if (probability_scale) { - if (has_rfx) { - y_hat <- pnorm(mu_hat + treatment_term + rfx_predictions) - rfx_predictions <- pnorm(rfx_predictions) - } else { - y_hat <- pnorm(mu_hat + treatment_term) - } - mu_hat <- pnorm(mu_hat) - tau_hat <- pnorm(tau_hat) - } else { - if (has_rfx) { - y_hat <- mu_hat + treatment_term + rfx_predictions + needs_mean_term_preds <- predict_y_hat || + predict_mu_forest || + predict_tau_forest || + predict_rfx + if (needs_mean_term_preds) { + if (probability_scale) { + if (has_rfx) { + if (predict_y_hat) { + y_hat <- pnorm(mu_hat + treatment_term + rfx_predictions) + } + if (predict_rfx) { + rfx_predictions <- pnorm(rfx_predictions) + } + } else { + if (predict_y_hat) { + y_hat <- pnorm(mu_hat + treatment_term) + } + } + if (predict_mu_forest) { + mu_hat <- pnorm(mu_hat) + } + if (predict_tau_forest) { + tau_hat <- pnorm(tau_hat) + } } else { - y_hat <- mu_hat + treatment_term + if (has_rfx) { + if (predict_y_hat) { + y_hat <- mu_hat + treatment_term + rfx_predictions + } + } else { + if (predict_y_hat) { + y_hat <- mu_hat + treatment_term + } + } } } diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 92ab1f03..d5e3570c 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -433,7 +433,7 @@ test_that("BART Predictions", { ) # Check that cached predictions agree with results of predict() function - train_preds <- predict(bart_model, X = X_train) + train_preds <- predict(bart_model, covariates = X_train) train_preds_mean_cached <- bart_model$y_hat_train train_preds_mean_recomputed <- train_preds$mean_forest_predictions train_preds_variance_cached <- bart_model$sigma2_x_hat_train From 4601dc327ab60bce839d701284dd53a8cb7fc458 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 22:47:54 -0500 Subject: [PATCH 13/53] Fixed documentation for several functions --- NAMESPACE | 1 + R/bart.R | 2 +- R/posterior_transformation.R | 4 ++-- R/stochtree-package.R | 1 + man/compute_bart_posterior_interval.Rd | 4 ++-- man/compute_bcf_posterior_interval.Rd | 4 ++-- man/predict.bartmodel.Rd | 4 ++-- 7 files changed, 11 insertions(+), 9 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ab6544bc..930789db 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -81,6 +81,7 @@ importFrom(stats,pnorm) importFrom(stats,predict) importFrom(stats,qgamma) importFrom(stats,qnorm) +importFrom(stats,quantile) importFrom(stats,resid) importFrom(stats,rnorm) importFrom(stats,runif) diff --git a/R/bart.R b/R/bart.R index bd94b625..18cd757f 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1761,7 +1761,7 @@ bart <- function( #' Predict from a sampled BART model on new data #' #' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. -#' @param X Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. +#' @param covariates Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. #' @param leaf_basis (Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: `NULL`. #' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 064f9ed0..c8adf8bc 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -356,7 +356,7 @@ posterior_predictive_heuristic_multiplier <- function( #' #' This function computes posterior credible intervals for specified terms from a fitted BCF model. It supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions. #' @param model_object A fitted BCF model object of class `bcfmodel`. -#' @param term A character string specifying the model term for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. +#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). @@ -551,7 +551,7 @@ compute_bcf_posterior_interval <- function( #' #' This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. #' @param model_object A fitted BART or BCF model object of class `bartmodel`. -#' @param term A character string specifying the model term for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. +#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). diff --git a/R/stochtree-package.R b/R/stochtree-package.R index 97eeded1..41fcf99c 100644 --- a/R/stochtree-package.R +++ b/R/stochtree-package.R @@ -6,6 +6,7 @@ #' @importFrom stats predict #' @importFrom stats qgamma #' @importFrom stats qnorm +#' @importFrom stats quantile #' @importFrom stats pnorm #' @importFrom stats resid #' @importFrom stats rnorm diff --git a/man/compute_bart_posterior_interval.Rd b/man/compute_bart_posterior_interval.Rd index 66ddceaf..2ae16f24 100644 --- a/man/compute_bart_posterior_interval.Rd +++ b/man/compute_bart_posterior_interval.Rd @@ -18,6 +18,8 @@ compute_bart_posterior_interval( \arguments{ \item{model_object}{A fitted BART or BCF model object of class \code{bartmodel}.} +\item{terms}{A character string specifying the model term(s) for which to compute intervals. Options for BART models are \code{"mean_forest"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} + \item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} @@ -29,8 +31,6 @@ compute_bart_posterior_interval( \item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} \item{rfx_basis}{An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.} - -\item{term}{A character string specifying the model term for which to compute intervals. Options for BART models are \code{"mean_forest"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} } \value{ A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. diff --git a/man/compute_bcf_posterior_interval.Rd b/man/compute_bcf_posterior_interval.Rd index f82849ed..1ff4836d 100644 --- a/man/compute_bcf_posterior_interval.Rd +++ b/man/compute_bcf_posterior_interval.Rd @@ -19,6 +19,8 @@ compute_bcf_posterior_interval( \arguments{ \item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} +\item{terms}{A character string specifying the model term(s) for which to compute intervals. Options for BCF models are \code{"prognostic_function"}, \code{"cate"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} + \item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} @@ -32,8 +34,6 @@ compute_bcf_posterior_interval( \item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} \item{rfx_basis}{An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.} - -\item{term}{A character string specifying the model term for which to compute intervals. Options for BCF models are \code{"prognostic_function"}, \code{"cate"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}.} } \value{ A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. diff --git a/man/predict.bartmodel.Rd b/man/predict.bartmodel.Rd index 6d60b0db..0cb82678 100644 --- a/man/predict.bartmodel.Rd +++ b/man/predict.bartmodel.Rd @@ -19,6 +19,8 @@ \arguments{ \item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} +\item{covariates}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} + \item{leaf_basis}{(Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: \code{NULL}.} \item{rfx_group_ids}{(Optional) Test set group labels used for an additive random effects model. @@ -34,8 +36,6 @@ that were not in the training set.} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} \item{...}{(Optional) Other prediction parameters.} - -\item{X}{Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe.} } \value{ List of prediction matrices or single prediction matrix / vector, depending on the terms requested. From c0cbdeb5a2ff0d16f472c9f4f9ba4ca44a77b28d Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 23:30:41 -0500 Subject: [PATCH 14/53] Updated BART predict method in python --- stochtree/bart.py | 87 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 17 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index a1ffc317..1a847b6e 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1579,7 +1579,8 @@ def predict( rfx_group_ids: np.array = None, rfx_basis: np.array = None, type: str = "posterior", - terms: Union[list[str], str] = "all" + terms: Union[list[str], str] = "all", + scale: str = "linear" ) -> Union[np.array, tuple]: """Return predictions from every forest sampled (either / both of mean and variance). Return type is either a single array of predictions, if a BART model only includes a @@ -1599,11 +1600,25 @@ def predict( Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". terms : str, optional Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". Returns ------- Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + # Handle prediction type if not isinstance(type, str): raise ValueError("type must be a string") @@ -1635,6 +1650,13 @@ def predict( predict_rfx_intermediate = predict_y_hat and has_rfx predict_mean_forest_intermediate = predict_y_hat and has_mean_forest + # Check that we have at least one term to predict on probability scale + if (probability_scale and not predict_y_hat and not predict_mean_forest and not predict_rfx): + raise ValueError( + "scale can only be 'probability' if at least one mean term is requested" + ) + + # Check the model is valid if not self.is_sampled(): msg = ( "This BARTModel instance is not fitted yet. Call 'fit' with " @@ -1690,22 +1712,7 @@ def predict( if basis is not None: pred_dataset.add_basis(basis) - # Forest predictions - if predict_mean_forest or predict_mean_forest_intermediate: - mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar - if predict_mean: - mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1) - - if predict_rfx or predict_rfx_intermediate: - rfx_predictions = ( - self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std - ) - if predict_mean: - rfx_predictions = np.mean(rfx_predictions, axis = 1) - + # Variance forest predictions if predict_variance_forest: variance_pred_raw = ( self.forest_container_variance.forest_container_cpp.Predict( @@ -1725,6 +1732,24 @@ def predict( if predict_mean: variance_forest_predictions = np.mean(variance_forest_predictions, axis = 1) + # Forest predictions + if predict_mean_forest or predict_mean_forest_intermediate: + mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( + pred_dataset.dataset_cpp + ) + mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar + # if predict_mean: + # mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1) + + # Random effects predictions + if predict_rfx or predict_rfx_intermediate: + rfx_predictions = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) + # if predict_mean: + # rfx_predictions = np.mean(rfx_predictions, axis = 1) + + # Combine into y hat predictions if predict_y_hat and has_mean_forest and has_rfx: y_hat = mean_forest_predictions + rfx_predictions elif predict_y_hat and has_mean_forest: @@ -1732,6 +1757,34 @@ def predict( elif predict_y_hat and has_rfx: y_hat = rfx_predictions + if probability_scale: + if predict_y_hat and has_mean_forest and has_rfx: + y_hat = norm.ppf(mean_forest_predictions + rfx_predictions) + mean_forest_predictions = norm.ppf(mean_forest_predictions) + rfx_predictions = norm.ppf(rfx_predictions) + elif predict_y_hat and has_mean_forest: + y_hat = norm.ppf(mean_forest_predictions) + mean_forest_predictions = norm.ppf(mean_forest_predictions) + elif predict_y_hat and has_rfx: + y_hat = norm.ppf(rfx_predictions) + rfx_predictions = norm.ppf(rfx_predictions) + else: + if predict_y_hat and has_mean_forest and has_rfx: + y_hat = mean_forest_predictions + rfx_predictions + elif predict_y_hat and has_mean_forest: + y_hat = mean_forest_predictions + elif predict_y_hat and has_rfx: + y_hat = rfx_predictions + + # Collapse to posterior mean predictions if requested + if predict_mean: + if predict_mean_forest: + mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1) + if predict_rfx: + rfx_predictions = np.mean(rfx_predictions, axis = 1) + if predict_y_hat: + y_hat = np.mean(y_hat, axis = 1) + if predict_count == 1: if predict_y_hat: return y_hat From 78ceee2302a835c3cee8e28acaf7b979bd9bc078 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 23:31:00 -0500 Subject: [PATCH 15/53] Reformat python code --- stochtree/bart.py | 157 +++++++++++++++++++++++++++++++--------------- 1 file changed, 105 insertions(+), 52 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 1a847b6e..6e3a1e8b 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -23,7 +23,12 @@ ) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer -from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag +from .utils import ( + NotSampledError, + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, +) class BARTModel: @@ -262,10 +267,18 @@ def sample( keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] - rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"] - rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"] - rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"] - rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"] + rfx_working_parameter_prior_mean = general_params_updated[ + "rfx_working_parameter_prior_mean" + ] + rfx_group_parameter_prior_mean = general_params_updated[ + "rfx_group_parameter_prior_mean" + ] + rfx_working_parameter_prior_cov = general_params_updated[ + "rfx_working_parameter_prior_cov" + ] + rfx_group_parameter_prior_cov = general_params_updated[ + "rfx_group_parameter_prior_cov" + ] rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] num_threads = general_params_updated["num_threads"] @@ -282,7 +295,9 @@ def sample( b_leaf = mean_forest_params_updated["sigma2_leaf_scale"] keep_vars_mean = mean_forest_params_updated["keep_vars"] drop_vars_mean = mean_forest_params_updated["drop_vars"] - num_features_subsample_mean = mean_forest_params_updated["num_features_subsample"] + num_features_subsample_mean = mean_forest_params_updated[ + "num_features_subsample" + ] # 3. Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -298,7 +313,9 @@ def sample( b_forest = variance_forest_params_updated["var_forest_prior_scale"] keep_vars_variance = variance_forest_params_updated["keep_vars"] drop_vars_variance = variance_forest_params_updated["drop_vars"] - num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"] + num_features_subsample_variance = variance_forest_params_updated[ + "num_features_subsample" + ] # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: @@ -989,26 +1006,34 @@ def sample( else: raise ValueError("There must be at least 1 random effect component") else: - alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) - + alpha_init = _expand_dims_1d( + rfx_working_parameter_prior_mean, num_rfx_components + ) + if rfx_group_parameter_prior_mean is None: xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) else: - xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) - + xi_init = _expand_dims_2d( + rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups + ) + if rfx_working_parameter_prior_cov is None: sigma_alpha_init = np.identity(num_rfx_components) else: - sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) - + sigma_alpha_init = _expand_dims_2d_diag( + rfx_working_parameter_prior_cov, num_rfx_components + ) + if rfx_group_parameter_prior_cov is None: sigma_xi_init = np.identity(num_rfx_components) else: - sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) - + sigma_xi_init = _expand_dims_2d_diag( + rfx_group_parameter_prior_cov, num_rfx_components + ) + sigma_xi_shape = rfx_variance_prior_shape sigma_xi_scale = rfx_variance_prior_scale - + # Random effects sampling data structures rfx_dataset_train = RandomEffectsDataset() rfx_dataset_train.add_group_labels(rfx_group_ids_train) @@ -1046,9 +1071,13 @@ def sample( if sample_sigma2_leaf: self.leaf_scale_samples = np.empty(self.num_samples, dtype=np.float64) if self.include_mean_forest: - yhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + yhat_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) if self.include_variance_forest: - sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + sigma2_x_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) sample_counter = -1 # Forest Dataset (covariates and optional basis) @@ -1104,8 +1133,8 @@ def sample( max_depth=max_depth_mean, leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, - cutpoint_grid_size=cutpoint_grid_size, - num_features_subsample=num_features_subsample_mean + cutpoint_grid_size=cutpoint_grid_size, + num_features_subsample=num_features_subsample_mean, ) forest_sampler_mean = ForestSampler( forest_dataset_train, @@ -1128,7 +1157,7 @@ def sample( cutpoint_grid_size=cutpoint_grid_size, variance_forest_shape=a_forest, variance_forest_scale=b_forest, - num_features_subsample=num_features_subsample_variance + num_features_subsample=num_features_subsample_variance, ) forest_sampler_variance = ForestSampler( forest_dataset_train, @@ -1234,7 +1263,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions() + yhat_train_raw[:, sample_counter] = ( + forest_sampler_mean.get_cached_forest_predictions() + ) # Sample the variance forest if self.include_variance_forest: @@ -1253,7 +1284,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1435,7 +1468,9 @@ def sample( ) if keep_sample: - yhat_train_raw[:,sample_counter] = forest_sampler_mean.get_cached_forest_predictions() + yhat_train_raw[:, sample_counter] = ( + forest_sampler_mean.get_cached_forest_predictions() + ) # Sample the variance forest if self.include_variance_forest: @@ -1453,7 +1488,9 @@ def sample( ) if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1504,9 +1541,9 @@ def sample( if self.sample_sigma2_leaf: self.leaf_scale_samples = self.leaf_scale_samples[num_gfr:] if self.include_mean_forest: - yhat_train_raw = yhat_train_raw[:,num_gfr:] + yhat_train_raw = yhat_train_raw[:, num_gfr:] if self.include_variance_forest: - sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:] + sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:] self.num_samples -= num_gfr # Store predictions @@ -1553,7 +1590,10 @@ def sample( ) else: self.sigma2_x_train = ( - np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std + np.exp(sigma2_x_train_raw) + * self.sigma2_init + * self.y_std + * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -1577,10 +1617,10 @@ def predict( covariates: Union[np.array, pd.DataFrame], basis: np.array = None, rfx_group_ids: np.array = None, - rfx_basis: np.array = None, - type: str = "posterior", - terms: Union[list[str], str] = "all", - scale: str = "linear" + rfx_basis: np.array = None, + type: str = "posterior", + terms: Union[list[str], str] = "all", + scale: str = "linear", ) -> Union[np.array, tuple]: """Return predictions from every forest sampled (either / both of mean and variance). Return type is either a single array of predictions, if a BART model only includes a @@ -1634,28 +1674,39 @@ def predict( has_variance_forest = self.include_variance_forest has_rfx = self.has_rfx has_y_hat = has_mean_forest or has_rfx - predict_y_hat = ((has_y_hat and ("y_hat" in terms)) or - (has_y_hat and ("all" in terms))) - predict_mean_forest = ((has_mean_forest and ("mean_forest" in terms)) or - (has_mean_forest and ("all" in terms))) - predict_rfx = ((has_rfx and ("rfx" in terms)) or - (has_rfx and ("all" in terms))) - predict_variance_forest = ((has_variance_forest and ("variance_forest" in terms)) or - (has_variance_forest and ("all" in terms))) - predict_count = (predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest) + predict_y_hat = (has_y_hat and ("y_hat" in terms)) or ( + has_y_hat and ("all" in terms) + ) + predict_mean_forest = (has_mean_forest and ("mean_forest" in terms)) or ( + has_mean_forest and ("all" in terms) + ) + predict_rfx = (has_rfx and ("rfx" in terms)) or (has_rfx and ("all" in terms)) + predict_variance_forest = ( + has_variance_forest and ("variance_forest" in terms) + ) or (has_variance_forest and ("all" in terms)) + predict_count = ( + predict_y_hat + predict_mean_forest + predict_rfx + predict_variance_forest + ) if predict_count == 0: term_list = ", ".join(terms) - warnings.warn(f"None of the requested model terms, {term_list}, were fit in this model") + warnings.warn( + f"None of the requested model terms, {term_list}, were fit in this model" + ) return None predict_rfx_intermediate = predict_y_hat and has_rfx predict_mean_forest_intermediate = predict_y_hat and has_mean_forest # Check that we have at least one term to predict on probability scale - if (probability_scale and not predict_y_hat and not predict_mean_forest and not predict_rfx): + if ( + probability_scale + and not predict_y_hat + and not predict_mean_forest + and not predict_rfx + ): raise ValueError( "scale can only be 'probability' if at least one mean term is requested" ) - + # Check the model is valid if not self.is_sampled(): msg = ( @@ -1730,7 +1781,9 @@ def predict( variance_pred_raw * self.sigma2_init * self.y_std * self.y_std ) if predict_mean: - variance_forest_predictions = np.mean(variance_forest_predictions, axis = 1) + variance_forest_predictions = np.mean( + variance_forest_predictions, axis=1 + ) # Forest predictions if predict_mean_forest or predict_mean_forest_intermediate: @@ -1756,7 +1809,7 @@ def predict( y_hat = mean_forest_predictions elif predict_y_hat and has_rfx: y_hat = rfx_predictions - + if probability_scale: if predict_y_hat and has_mean_forest and has_rfx: y_hat = norm.ppf(mean_forest_predictions + rfx_predictions) @@ -1775,16 +1828,16 @@ def predict( y_hat = mean_forest_predictions elif predict_y_hat and has_rfx: y_hat = rfx_predictions - + # Collapse to posterior mean predictions if requested if predict_mean: if predict_mean_forest: - mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1) + mean_forest_predictions = np.mean(mean_forest_predictions, axis=1) if predict_rfx: - rfx_predictions = np.mean(rfx_predictions, axis = 1) + rfx_predictions = np.mean(rfx_predictions, axis=1) if predict_y_hat: - y_hat = np.mean(y_hat, axis = 1) - + y_hat = np.mean(y_hat, axis=1) + if predict_count == 1: if predict_y_hat: return y_hat @@ -1813,7 +1866,7 @@ def predict( else: result["variance_forest_predictions"] = None return result - + def predict_mean( self, covariates: np.array, From 32f13a60ac8377f7926076ab174089eaf5c77e30 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 23:32:30 -0500 Subject: [PATCH 16/53] Reformat other python source files --- stochtree/bcf.py | 227 ++++++++++++++++++++++-------------- stochtree/config.py | 6 +- stochtree/data.py | 36 ++++-- stochtree/forest.py | 105 +++++++++++------ stochtree/kernel.py | 153 +++++++++++++++++------- stochtree/random_effects.py | 22 ++-- stochtree/sampler.py | 6 +- stochtree/serialization.py | 4 +- stochtree/utils.py | 70 ++++++----- 9 files changed, 404 insertions(+), 225 deletions(-) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 7d384fa3..fcf0f8f3 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -23,7 +23,12 @@ ) from .sampler import RNG, ForestSampler, GlobalVarianceModel, LeafVarianceModel from .serialization import JSONSerializer -from .utils import NotSampledError, _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag +from .utils import ( + NotSampledError, + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, +) class BCFModel: @@ -323,10 +328,18 @@ def sample( keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] - rfx_working_parameter_prior_mean = general_params_updated["rfx_working_parameter_prior_mean"] - rfx_group_parameter_prior_mean = general_params_updated["rfx_group_parameter_prior_mean"] - rfx_working_parameter_prior_cov = general_params_updated["rfx_working_parameter_prior_cov"] - rfx_group_parameter_prior_cov = general_params_updated["rfx_group_parameter_prior_cov"] + rfx_working_parameter_prior_mean = general_params_updated[ + "rfx_working_parameter_prior_mean" + ] + rfx_group_parameter_prior_mean = general_params_updated[ + "rfx_group_parameter_prior_mean" + ] + rfx_working_parameter_prior_cov = general_params_updated[ + "rfx_working_parameter_prior_cov" + ] + rfx_group_parameter_prior_cov = general_params_updated[ + "rfx_group_parameter_prior_cov" + ] rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] num_threads = general_params_updated["num_threads"] @@ -343,7 +356,9 @@ def sample( b_leaf_mu = prognostic_forest_params_updated["sigma2_leaf_scale"] keep_vars_mu = prognostic_forest_params_updated["keep_vars"] drop_vars_mu = prognostic_forest_params_updated["drop_vars"] - num_features_subsample_mu = prognostic_forest_params_updated["num_features_subsample"] + num_features_subsample_mu = prognostic_forest_params_updated[ + "num_features_subsample" + ] # 3. Tau forest parameters num_trees_tau = treatment_effect_forest_params_updated["num_trees"] @@ -362,7 +377,9 @@ def sample( delta_max = treatment_effect_forest_params_updated["delta_max"] keep_vars_tau = treatment_effect_forest_params_updated["keep_vars"] drop_vars_tau = treatment_effect_forest_params_updated["drop_vars"] - num_features_subsample_tau = treatment_effect_forest_params_updated["num_features_subsample"] + num_features_subsample_tau = treatment_effect_forest_params_updated[ + "num_features_subsample" + ] # 4. Variance forest parameters num_trees_variance = variance_forest_params_updated["num_trees"] @@ -378,7 +395,9 @@ def sample( b_forest = variance_forest_params_updated["var_forest_prior_scale"] keep_vars_variance = variance_forest_params_updated["keep_vars"] drop_vars_variance = variance_forest_params_updated["drop_vars"] - num_features_subsample_variance = variance_forest_params_updated["num_features_subsample"] + num_features_subsample_variance = variance_forest_params_updated[ + "num_features_subsample" + ] # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: @@ -1193,34 +1212,24 @@ def sample( current_sigma2 = sigma2_init self.sigma2_init = sigma2_init # Skip variance_forest_init, since variance forests are not supported with probit link - b_leaf_mu = ( - 1.0 / num_trees_mu - if b_leaf_mu is None - else b_leaf_mu - ) - b_leaf_tau = ( - 1.0 / (2 * num_trees_tau) - if b_leaf_tau is None - else b_leaf_tau - ) + b_leaf_mu = 1.0 / num_trees_mu if b_leaf_mu is None else b_leaf_mu + b_leaf_tau = 1.0 / (2 * num_trees_tau) if b_leaf_tau is None else b_leaf_tau sigma2_leaf_mu = ( - 1 / num_trees_mu - if sigma2_leaf_mu is None - else sigma2_leaf_mu + 1 / num_trees_mu if sigma2_leaf_mu is None else sigma2_leaf_mu ) if isinstance(sigma2_leaf_mu, float): current_leaf_scale_mu = np.array([[sigma2_leaf_mu]]) else: raise ValueError("sigma2_leaf_mu must be a scalar") # Calibrate prior so that P(abs(tau(X)) < delta_max / dnorm(0)) = p - # Use p = 0.9 as an internal default rather than adding another - # user-facing "parameter" of the binary outcome BCF prior. - # Can be overriden by specifying `sigma2_leaf_init` in + # Use p = 0.9 as an internal default rather than adding another + # user-facing "parameter" of the binary outcome BCF prior. + # Can be overriden by specifying `sigma2_leaf_init` in # treatment_effect_forest_params. p = 0.6827 q_quantile = norm.ppf((p + 1) / 2.0) sigma2_leaf_tau = ( - ((delta_max / (q_quantile*norm.pdf(0)))**2) / num_trees_tau + ((delta_max / (q_quantile * norm.pdf(0))) ** 2) / num_trees_tau if sigma2_leaf_tau is None else sigma2_leaf_tau ) @@ -1231,7 +1240,9 @@ def sample( ) if isinstance(sigma2_leaf_tau, float): if Z_train.shape[1] > 1: - current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) + current_leaf_scale_tau = np.zeros( + (Z_train.shape[1], Z_train.shape[1]), dtype=float + ) np.fill_diagonal(current_leaf_scale_tau, sigma2_leaf_tau) else: current_leaf_scale_tau = np.array([[sigma2_leaf_tau]]) @@ -1304,7 +1315,9 @@ def sample( ) if isinstance(sigma2_leaf_tau, float): if Z_train.shape[1] > 1: - current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float) + current_leaf_scale_tau = np.zeros( + (Z_train.shape[1], Z_train.shape[1]), dtype=float + ) np.fill_diagonal(current_leaf_scale_tau, sigma2_leaf_tau) else: current_leaf_scale_tau = np.array([[sigma2_leaf_tau]]) @@ -1333,7 +1346,7 @@ def sample( if not a_forest: a_forest = 1.0 if not b_forest: - b_forest = 1.0 + b_forest = 1.0 # Runtime checks on RFX group ids self.has_rfx = False @@ -1381,23 +1394,31 @@ def sample( else: raise ValueError("There must be at least 1 random effect component") else: - alpha_init = _expand_dims_1d(rfx_working_parameter_prior_mean, num_rfx_components) - + alpha_init = _expand_dims_1d( + rfx_working_parameter_prior_mean, num_rfx_components + ) + if rfx_group_parameter_prior_mean is None: xi_init = np.tile(np.expand_dims(alpha_init, 1), (1, num_rfx_groups)) else: - xi_init = _expand_dims_2d(rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups) - + xi_init = _expand_dims_2d( + rfx_group_parameter_prior_mean, num_rfx_components, num_rfx_groups + ) + if rfx_working_parameter_prior_cov is None: sigma_alpha_init = np.identity(num_rfx_components) else: - sigma_alpha_init = _expand_dims_2d_diag(rfx_working_parameter_prior_cov, num_rfx_components) - + sigma_alpha_init = _expand_dims_2d_diag( + rfx_working_parameter_prior_cov, num_rfx_components + ) + if rfx_group_parameter_prior_cov is None: sigma_xi_init = np.identity(num_rfx_components) else: - sigma_xi_init = _expand_dims_2d_diag(rfx_group_parameter_prior_cov, num_rfx_components) - + sigma_xi_init = _expand_dims_2d_diag( + rfx_group_parameter_prior_cov, num_rfx_components + ) + sigma_xi_shape = rfx_variance_prior_shape sigma_xi_scale = rfx_variance_prior_scale @@ -1526,7 +1547,9 @@ def sample( muhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) tauhat_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) if self.include_variance_forest: - sigma2_x_train_raw = np.empty((self.n_train, self.num_samples), dtype=np.float64) + sigma2_x_train_raw = np.empty( + (self.n_train, self.num_samples), dtype=np.float64 + ) sample_counter = -1 # Prepare adaptive coding structure @@ -1699,7 +1722,7 @@ def sample( keep_sample = True if keep_sample: sample_counter += 1 - + if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) @@ -1725,7 +1748,7 @@ def sample( # Update outcome new_outcome = np.squeeze(resid_train) - forest_pred residual_train.update_data(new_outcome) - + # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( self.forest_container_mu, @@ -1742,7 +1765,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions() + muhat_train_raw[:, sample_counter] = ( + forest_sampler_mu.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1778,7 +1803,7 @@ def sample( num_threads, ) - # Cannot cache train set predictions for tau because the cached predictions in the + # Cannot cache train set predictions for tau because the cached predictions in the # tracking data structures are pre-multiplied by the basis (treatment) # ... @@ -1842,7 +1867,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1895,7 +1922,7 @@ def sample( keep_sample = False if keep_sample: sample_counter += 1 - + if self.probit_outcome_model: # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) @@ -1921,7 +1948,7 @@ def sample( # Update outcome new_outcome = np.squeeze(resid_train) - forest_pred residual_train.update_data(new_outcome) - + # Sample the prognostic forest forest_sampler_mu.sample_one_iteration( self.forest_container_mu, @@ -1938,7 +1965,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - muhat_train_raw[:,sample_counter] = forest_sampler_mu.get_cached_forest_predictions() + muhat_train_raw[:, sample_counter] = ( + forest_sampler_mu.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -1974,7 +2003,7 @@ def sample( num_threads, ) - # Cannot cache train set predictions for tau because the cached predictions in the + # Cannot cache train set predictions for tau because the cached predictions in the # tracking data structures are pre-multiplied by the basis (treatment) # ... @@ -2038,7 +2067,9 @@ def sample( # Cache train set predictions since they are already computed during sampling if keep_sample: - sigma2_x_train_raw[:,sample_counter] = forest_sampler_variance.get_cached_forest_predictions() + sigma2_x_train_raw[:, sample_counter] = ( + forest_sampler_variance.get_cached_forest_predictions() + ) # Sample variance parameters (if requested) if self.sample_sigma2_global: @@ -2095,9 +2126,9 @@ def sample( self.leaf_scale_mu_samples = self.leaf_scale_mu_samples[num_gfr:] if self.sample_sigma2_leaf_tau: self.leaf_scale_tau_samples = self.leaf_scale_tau_samples[num_gfr:] - muhat_train_raw = muhat_train_raw[:,num_gfr:] + muhat_train_raw = muhat_train_raw[:, num_gfr:] if self.include_variance_forest: - sigma2_x_train_raw = sigma2_x_train_raw[:,num_gfr:] + sigma2_x_train_raw = sigma2_x_train_raw[:, num_gfr:] self.num_samples -= num_gfr # Store predictions @@ -2179,7 +2210,10 @@ def sample( ) else: self.sigma2_x_train = ( - np.exp(sigma2_x_train_raw) * self.sigma2_init * self.y_std * self.y_std + np.exp(sigma2_x_train_raw) + * self.sigma2_init + * self.y_std + * self.y_std ) if self.has_test: sigma2_x_test_raw = ( @@ -2260,7 +2294,7 @@ def predict_tau( covariates_processed = X else: covariates_processed = self._covariate_preprocessor.transform(X) - + # Handle propensities if propensity is not None: if propensity.shape[0] != X.shape[0]: @@ -2272,9 +2306,11 @@ def predict_tau( "Propensity scores not provided, but no propensity model was trained during sampling" ) else: - internal_propensity_preds = self.bart_propensity_model.predict(covariates_processed) + internal_propensity_preds = self.bart_propensity_model.predict( + covariates_processed + ) propensity = np.mean( - internal_propensity_preds['y_hat'], axis=1, keepdims=True + internal_propensity_preds["y_hat"], axis=1, keepdims=True ) # Update covariates to include propensities if requested @@ -2361,7 +2397,7 @@ def predict_variance( covariates_processed = covariates else: covariates_processed = self._covariate_preprocessor.transform(covariates) - + # Handle propensities if propensity is not None: if propensity.shape[0] != covariates.shape[0]: @@ -2373,9 +2409,11 @@ def predict_variance( "Propensity scores not provided, but no propensity model was trained during sampling" ) else: - internal_propensity_preds = self.bart_propensity_model.predict(covariates_processed) + internal_propensity_preds = self.bart_propensity_model.predict( + covariates_processed + ) propensity = np.mean( - internal_propensity_preds['y_hat'], axis=1, keepdims=True + internal_propensity_preds["y_hat"], axis=1, keepdims=True ) # Update covariates to include propensities if requested @@ -2416,9 +2454,9 @@ def predict( Z: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, - rfx_basis: np.array = None, - type: str = "posterior", - terms: Union[list[str], str] = "all" + rfx_basis: np.array = None, + type: str = "posterior", + terms: Union[list[str], str] = "all", ) -> dict: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. @@ -2435,7 +2473,7 @@ def predict( Optional group labels used for an additive random effects model. rfx_basis : np.array, optional Optional basis for "random-slope" regression in an additive random effects model. - + type : str, optional Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". terms : str, optional @@ -2461,25 +2499,36 @@ def predict( has_variance_forest = self.include_variance_forest has_rfx = self.has_rfx has_y_hat = has_mu_forest or has_tau_forest or has_rfx - predict_y_hat = ((has_y_hat and ("y_hat" in terms)) or - (has_y_hat and ("all" in terms))) - predict_mu_forest = ((has_mu_forest and ("prognostic_function" in terms)) or - (has_mu_forest and ("all" in terms))) - predict_tau_forest = ((has_tau_forest and ("cate" in terms)) or - (has_tau_forest and ("all" in terms))) - predict_rfx = ((has_rfx and ("rfx" in terms)) or - (has_rfx and ("all" in terms))) - predict_variance_forest = ((has_variance_forest and ("variance_forest" in terms)) or - (has_variance_forest and ("all" in terms))) - predict_count = (predict_y_hat + predict_mu_forest + predict_tau_forest + predict_rfx + predict_variance_forest) + predict_y_hat = (has_y_hat and ("y_hat" in terms)) or ( + has_y_hat and ("all" in terms) + ) + predict_mu_forest = (has_mu_forest and ("prognostic_function" in terms)) or ( + has_mu_forest and ("all" in terms) + ) + predict_tau_forest = (has_tau_forest and ("cate" in terms)) or ( + has_tau_forest and ("all" in terms) + ) + predict_rfx = (has_rfx and ("rfx" in terms)) or (has_rfx and ("all" in terms)) + predict_variance_forest = ( + has_variance_forest and ("variance_forest" in terms) + ) or (has_variance_forest and ("all" in terms)) + predict_count = ( + predict_y_hat + + predict_mu_forest + + predict_tau_forest + + predict_rfx + + predict_variance_forest + ) if predict_count == 0: term_list = ", ".join(terms) - warnings.warn(f"None of the requested model terms, {term_list}, were fit in this model") + warnings.warn( + f"None of the requested model terms, {term_list}, were fit in this model" + ) return None predict_rfx_intermediate = predict_y_hat and has_rfx predict_mu_forest_intermediate = predict_y_hat and has_mu_forest predict_tau_forest_intermediate = predict_y_hat and has_tau_forest - + if not self.is_sampled(): msg = ( "This BCFModel instance is not fitted yet. Call 'fit' with " @@ -2523,7 +2572,7 @@ def predict( covariates_processed = X else: covariates_processed = self._covariate_preprocessor.transform(X) - + # Handle propensities if propensity is not None: if propensity.shape[0] != X.shape[0]: @@ -2535,9 +2584,11 @@ def predict( "Propensity scores not provided, but no propensity model was trained during sampling" ) else: - internal_propensity_preds = self.bart_propensity_model.predict(covariates_processed) + internal_propensity_preds = self.bart_propensity_model.predict( + covariates_processed + ) propensity = np.mean( - internal_propensity_preds['y_hat'], axis=1, keepdims=True + internal_propensity_preds["y_hat"], axis=1, keepdims=True ) # Update covariates to include propensities if requested @@ -2570,9 +2621,9 @@ def predict( tau_raw = tau_raw * adaptive_coding_weights tau_x = np.squeeze(tau_raw * self.y_std) if Z.shape[1] > 1: - treatment_term = np.multiply(np.atleast_3d(Z).swapaxes(1, 2), tau_x).sum( - axis=2 - ) + treatment_term = np.multiply( + np.atleast_3d(Z).swapaxes(1, 2), tau_x + ).sum(axis=2) if predict_mean: treatment_term = np.mean(treatment_term, axis=1) tau_x = np.mean(tau_x, axis=2) @@ -2588,7 +2639,7 @@ def predict( ) if predict_mean: rfx_preds = np.mean(rfx_preds, axis=1) - + if predict_y_hat and has_mu_forest and has_rfx: y_hat = mu_x + treatment_term + rfx_preds elif predict_y_hat and has_mu_forest: @@ -2694,9 +2745,7 @@ def to_json(self) -> str: bcf_json.add_boolean( "internal_propensity_model", self.internal_propensity_model ) - bcf_json.add_boolean( - "probit_outcome_model", self.probit_outcome_model - ) + bcf_json.add_boolean("probit_outcome_model", self.probit_outcome_model) # Add parameter samples if self.sample_sigma2_global: @@ -2781,9 +2830,7 @@ def from_json(self, json_string: str) -> None: self.internal_propensity_model = bcf_json.get_boolean( "internal_propensity_model" ) - self.probit_outcome_model = bcf_json.get_boolean( - "probit_outcome_model" - ) + self.probit_outcome_model = bcf_json.get_boolean("probit_outcome_model") # Unpack parameter samples if self.sample_sigma2_global: @@ -2909,7 +2956,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.internal_propensity_model = json_object_default.get_boolean( "internal_propensity_model" ) - + # Unpack number of samples for i in range(len(json_object_list)): if i == 0: @@ -2949,9 +2996,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: if self.sample_sigma2_leaf_tau: for i in range(len(json_object_list)): if i == 0: - self.sample_sigma2_leaf_tau = json_object_list[i].get_numeric_vector( - "sigma2_leaf_tau_samples", "parameters" - ) + self.sample_sigma2_leaf_tau = json_object_list[ + i + ].get_numeric_vector("sigma2_leaf_tau_samples", "parameters") else: sample_sigma2_leaf_tau = json_object_list[i].get_numeric_vector( "sigma2_leaf_tau_samples", "parameters" diff --git a/stochtree/config.py b/stochtree/config.py index 72cae512..84e851c3 100644 --- a/stochtree/config.py +++ b/stochtree/config.py @@ -119,7 +119,9 @@ def __init__( raise ValueError("`leaf_dimension` must be an integer greater than 0") if leaf_model_scale is None: diag_value = 1.0 / num_trees - leaf_model_scale_array = np.zeros((leaf_dimension, leaf_dimension), dtype=float) + leaf_model_scale_array = np.zeros( + (leaf_dimension, leaf_dimension), dtype=float + ) np.fill_diagonal(leaf_model_scale_array, diag_value) else: if isinstance(leaf_model_scale, np.ndarray): @@ -432,7 +434,7 @@ def get_feature_types(self) -> np.ndarray: """ return self.feature_types - def get_sweep_update_indices(self) -> Union[np.ndarray,None]: + def get_sweep_update_indices(self) -> Union[np.ndarray, None]: """ Query vector of (0-indexed) indices of trees to update in a sweep diff --git a/stochtree/data.py b/stochtree/data.py index 4e40a282..da3ca735 100644 --- a/stochtree/data.py +++ b/stochtree/data.py @@ -59,7 +59,9 @@ def update_basis(self, basis: np.array): Numpy array of basis vectors. """ if not self.has_basis(): - raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.") + raise ValueError( + "This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix." + ) if not isinstance(basis, np.ndarray): raise ValueError("basis must be a numpy array.") if np.ndim(basis) == 1: @@ -71,9 +73,13 @@ def update_basis(self, basis: np.array): n, p = basis_.shape basis_rowmajor = np.ascontiguousarray(basis_) if self.num_basis() != p: - raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).") + raise ValueError( + f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()})." + ) if self.num_observations() != n: - raise ValueError(f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()}).") + raise ValueError( + f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()})." + ) self.dataset_cpp.UpdateBasis(basis_rowmajor, n, p, True) def add_variance_weights(self, variance_weights: np.array): @@ -91,12 +97,14 @@ def add_variance_weights(self, variance_weights: np.array): n = variance_weights_.size if variance_weights_.ndim != 1: raise ValueError("variance_weights must be a 1-dimensional numpy array.") - + self.dataset_cpp.AddVarianceWeights(variance_weights_, n) - - def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): + + def update_variance_weights( + self, variance_weights: np.array, exponentiate: bool = False + ): """ - Update variance weights in a dataset. Allows users to build an ensemble that depends on + Update variance weights in a dataset. Allows users to build an ensemble that depends on variance weights that are updated throughout the sampler. Parameters @@ -107,7 +115,9 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool Whether to exponentiate the variance weights before storing them in the dataset. """ if not self.has_variance_weights(): - raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + raise ValueError( + "This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector." + ) if not isinstance(variance_weights, np.ndarray): raise ValueError("variance_weights must be a numpy array.") variance_weights_ = np.squeeze(variance_weights) @@ -115,7 +125,9 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool if variance_weights_.ndim != 1: raise ValueError("variance_weights must be a 1-dimensional numpy array.") if self.num_observations() != n: - raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + raise ValueError( + f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()})." + ) self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate) def num_observations(self) -> int: @@ -150,7 +162,7 @@ def num_basis(self) -> int: Dimension of the basis vector in the dataset, returning 0 if the dataset does not have a basis """ return self.dataset_cpp.NumBasis() - + def get_covariates(self) -> np.array: """ Return the covariates in a Dataset as a numpy array @@ -161,7 +173,7 @@ def get_covariates(self) -> np.array: Covariate data """ return self.dataset_cpp.GetCovariates() - + def get_basis(self) -> np.array: """ Return the bases in a Dataset as a numpy array @@ -172,7 +184,7 @@ def get_basis(self) -> np.array: Basis data """ return self.dataset_cpp.GetBasis() - + def get_variance_weights(self) -> np.array: """ Return the variance weights in a Dataset as a numpy array diff --git a/stochtree/forest.py b/stochtree/forest.py index d672ae5c..a92c5847 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -161,13 +161,13 @@ def set_root_leaves( def collapse(self, batch_size: int) -> None: """ - Collapse forests in this container by a pre-specified batch size. - For example, if we have a container of twenty 10-tree forests, and we - specify a `batch_size` of 5, then this method will yield four 50-tree - forests. "Excess" forests remaining after the size of a forest container - is divided by `batch_size` will be pruned from the beginning of the - container (i.e. earlier sampled forests will be deleted). This method - has no effect if `batch_size` is larger than the number of forests + Collapse forests in this container by a pre-specified batch size. + For example, if we have a container of twenty 10-tree forests, and we + specify a `batch_size` of 5, then this method will yield four 50-tree + forests. "Excess" forests remaining after the size of a forest container + is divided by `batch_size` will be pruned from the beginning of the + container (i.e. earlier sampled forests will be deleted). This method + has no effect if `batch_size` is larger than the number of forests in a container. Parameters @@ -177,12 +177,23 @@ def collapse(self, batch_size: int) -> None: """ container_size = self.num_samples() if batch_size <= container_size and batch_size > 1: - reverse_container_inds = np.linspace(start=container_size, stop=1, num=container_size, dtype=int) + reverse_container_inds = np.linspace( + start=container_size, stop=1, num=container_size, dtype=int + ) num_clean_batches = container_size // batch_size - batch_inds = (reverse_container_inds - (container_size - ((container_size // num_clean_batches) * num_clean_batches)) - 1) // batch_size + batch_inds = ( + reverse_container_inds + - ( + container_size + - ((container_size // num_clean_batches) * num_clean_batches) + ) + - 1 + ) // batch_size batch_inds = batch_inds.astype(int) for batch_ind in np.flip(np.unique(batch_inds[batch_inds >= 0])): - merge_forest_inds = np.sort(reverse_container_inds[batch_inds == batch_ind] - 1) + merge_forest_inds = np.sort( + reverse_container_inds[batch_inds == batch_ind] - 1 + ) num_merge_forests = len(merge_forest_inds) self.combine_forests(merge_forest_inds) for i in range(num_merge_forests - 1, 0, -1): @@ -194,10 +205,8 @@ def collapse(self, batch_size: int) -> None: num_delete_forests = len(delete_forest_inds) for i in range(num_delete_forests - 1, -1, -1): self.delete_sample(delete_forest_inds[i]) - - def combine_forests( - self, forest_inds: np.array - ) -> None: + + def combine_forests(self, forest_inds: np.array) -> None: """ Collapse specified forests into a single forest @@ -214,42 +223,56 @@ def combine_forests( forest_inds_sorted = forest_inds_sorted.astype(int) self.forest_container_cpp.CombineForests(forest_inds_sorted) - def add_to_forest( - self, forest_index: int, constant_value : float - ) -> None: + def add_to_forest(self, forest_index: int, constant_value: float) -> None: """ Add a constant value to every leaf of every tree of a given forest Parameters ---------- - forest_index : int + forest_index : int Index of forest whose leaves will be modified (0-indexed) - constant_value : float + constant_value : float Value to add to every leaf of every tree of the forest at `forest_index` """ - if not isinstance(forest_index, int) and not isinstance(constant_value, (int, float)): - raise ValueError("forest_index must be an integer and constant_multiple must be a float or int") - if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples(): - raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container") + if not isinstance(forest_index, int) and not isinstance( + constant_value, (int, float) + ): + raise ValueError( + "forest_index must be an integer and constant_multiple must be a float or int" + ) + if ( + not forest_index >= 0 + or not forest_index < self.forest_container_cpp.NumSamples() + ): + raise ValueError( + "forest_index must be >= 0 and less than the total number of samples in a forest container" + ) self.forest_container_cpp.AddToForest(forest_index, constant_value) - def multiply_forest( - self, forest_index: int, constant_multiple : float - ) -> None: + def multiply_forest(self, forest_index: int, constant_multiple: float) -> None: """ Multiply every leaf of every tree of a given forest by constant value Parameters ---------- - forest_index : int + forest_index : int Index of forest whose leaves will be modified (0-indexed) - constant_multiple : float + constant_multiple : float Value to multiply through by every leaf of every tree of the forest at `forest_index` """ - if not isinstance(forest_index, int) and not isinstance(constant_multiple, (int, float)): - raise ValueError("forest_index must be an integer and constant_multiple must be a float or int") - if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples(): - raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container") + if not isinstance(forest_index, int) and not isinstance( + constant_multiple, (int, float) + ): + raise ValueError( + "forest_index must be an integer and constant_multiple must be a float or int" + ) + if ( + not forest_index >= 0 + or not forest_index < self.forest_container_cpp.NumSamples() + ): + raise ValueError( + "forest_index must be >= 0 and less than the total number of samples in a forest container" + ) self.forest_container_cpp.MultiplyForest(forest_index, constant_multiple) def save_to_json_file(self, json_filename: str) -> None: @@ -1021,7 +1044,7 @@ def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None: else: self.forest_cpp.SetRootValue(leaf_value) self.internal_forest_is_empty = False - + def merge_forest(self, other_forest): """ Create a larger forest by merging the trees of this forest with those of another forest @@ -1034,13 +1057,19 @@ def merge_forest(self, other_forest): if not isinstance(other_forest, Forest): raise ValueError("other_forest must be an instance of the Forest class") if self.leaf_constant != other_forest.leaf_constant: - raise ValueError("Forests must have matching leaf dimensions in order to be merged") + raise ValueError( + "Forests must have matching leaf dimensions in order to be merged" + ) if self.output_dimension != other_forest.output_dimension: - raise ValueError("Forests must have matching leaf dimensions in order to be merged") + raise ValueError( + "Forests must have matching leaf dimensions in order to be merged" + ) if self.is_exponentiated != other_forest.is_exponentiated: - raise ValueError("Forests must have matching leaf dimensions in order to be merged") + raise ValueError( + "Forests must have matching leaf dimensions in order to be merged" + ) self.forest_cpp.MergeForest(other_forest.forest_cpp) - + def add_constant(self, constant_value): """ Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. @@ -1051,7 +1080,7 @@ def add_constant(self, constant_value): Value that will be added to every leaf of every tree """ self.forest_cpp.AddConstant(constant_value) - + def multiply_constant(self, constant_multiple): """ Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. diff --git a/stochtree/kernel.py b/stochtree/kernel.py index 86d9d8a5..58986f7f 100644 --- a/stochtree/kernel.py +++ b/stochtree/kernel.py @@ -2,20 +2,29 @@ import pandas as pd import numpy as np -from stochtree_cpp import cppComputeForestContainerLeafIndices, cppComputeForestMaxLeafIndex +from stochtree_cpp import ( + cppComputeForestContainerLeafIndices, + cppComputeForestMaxLeafIndex, +) from .bart import BARTModel from .bcf import BCFModel from .forest import ForestContainer -def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, propensity: np.array = None, forest_inds: Union[int, np.ndarray] = None): +def compute_forest_leaf_indices( + model_object: Union[BARTModel, BCFModel, ForestContainer], + covariates: Union[np.array, pd.DataFrame], + forest_type: str = None, + propensity: np.array = None, + forest_inds: Union[int, np.ndarray] = None, +): """ Compute and return a vector representation of a forest's leaf predictions for every observation in a dataset. - The vector has a "row-major" format that can be easily re-represented as as a CSR sparse matrix: elements are organized so that the first `n` elements - correspond to leaf predictions for all `n` observations in a dataset for the first tree in an ensemble, the next `n` elements correspond to predictions for - the second tree and so on. The "data" for each element corresponds to a uniquely mapped column index that corresponds to a single leaf of a single tree (i.e. + The vector has a "row-major" format that can be easily re-represented as as a CSR sparse matrix: elements are organized so that the first `n` elements + correspond to leaf predictions for all `n` observations in a dataset for the first tree in an ensemble, the next `n` elements correspond to predictions for + the second tree and so on. The "data" for each element corresponds to a uniquely mapped column index that corresponds to a single leaf of a single tree (i.e. if tree 1 has 3 leaves, its column indices range from 0 to 2, and then tree 2's leaf indices begin at 3, etc...). Parameters @@ -36,38 +45,52 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC * `'variance'`: Extracts leaf indices for the variance forest * **ForestContainer** * `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this - + propensity : `np.array`, optional Optional test set propensities. Must be provided if propensities were provided when the model was sampled. forest_inds : int or np.ndarray - Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest. + Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function will return leaf indices for every sample of a forest. This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. - + Returns ------- - Numpy array with dimensions `num_obs` by `num_trees`, where `num_obs` is the number of rows in `covaritates` and `num_trees` is the number of trees in the relevant forest of `model_object`. + Numpy array with dimensions `num_obs` by `num_trees`, where `num_obs` is the number of rows in `covaritates` and `num_trees` is the number of trees in the relevant forest of `model_object`. """ # Extract relevant forest container - if not isinstance(model_object, BARTModel) and not isinstance(model_object, BCFModel) and not isinstance(model_object, ForestContainer): - raise ValueError("model_object must be one of BARTModel, BCFModel, or ForestContainer") + if ( + not isinstance(model_object, BARTModel) + and not isinstance(model_object, BCFModel) + and not isinstance(model_object, ForestContainer) + ): + raise ValueError( + "model_object must be one of BARTModel, BCFModel, or ForestContainer" + ) if isinstance(model_object, BARTModel): model_type = "bart" if forest_type is None: - raise ValueError("forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')") + raise ValueError( + "forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')" + ) elif isinstance(model_object, BCFModel): model_type = "bcf" if forest_type is None: - raise ValueError("forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')") + raise ValueError( + "forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')" + ) else: model_type = "forest" if model_type == "bart": if forest_type == "mean": if not model_object.include_mean_forest: - raise ValueError("Mean forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Mean forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_mean else: if not model_object.include_variance_forest: - raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Variance forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_variance elif model_type == "bcf": if forest_type == "prognostic": @@ -76,17 +99,23 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC forest_container = model_object.forest_container_tau else: if not model_object.include_variance_forest: - raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Variance forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_variance else: forest_container = model_object - - if not isinstance(covariates, pd.DataFrame) and not isinstance(covariates, np.ndarray): + + if not isinstance(covariates, pd.DataFrame) and not isinstance( + covariates, np.ndarray + ): raise ValueError("covariates must be a matrix or dataframe") - + # Preprocess covariates if model_type == "bart" or model_type == "bcf": - covariates_processed = model_object._covariate_preprocessor.transform(covariates) + covariates_processed = model_object._covariate_preprocessor.transform( + covariates + ) else: covariates_processed = covariates covariates_processed = np.asfortranarray(covariates_processed) @@ -100,17 +129,21 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC "Propensity scores not provided, but no propensity model was trained during sampling" ) propensity = np.mean( - model_object.bart_propensity_model.predict(covariates), axis=1, keepdims=True + model_object.bart_propensity_model.predict(covariates), + axis=1, + keepdims=True, ) covariates_processed = np.c_[covariates_processed, propensity] - + # Preprocess forest indices num_forests = forest_container.num_samples() if forest_inds is None: forest_inds = np.arange(num_forests) elif isinstance(forest_inds, int): if not forest_inds >= 0 or not forest_inds < num_forests: - raise ValueError("The index in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The index in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) forest_inds = np.array([forest_inds]) elif isinstance(forest_inds, np.ndarray): if forest_inds.size > 1: @@ -118,13 +151,22 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC if forest_inds.ndim > 1: raise ValueError("forest_inds must be a one-dimensional numpy array") if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests): - raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The indices in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) else: raise ValueError("forest_inds must be a one-dimensional numpy array") - - return cppComputeForestContainerLeafIndices(forest_container.forest_container_cpp, covariates_processed, forest_inds) -def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, ForestContainer], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None): + return cppComputeForestContainerLeafIndices( + forest_container.forest_container_cpp, covariates_processed, forest_inds + ) + + +def compute_forest_max_leaf_index( + model_object: Union[BARTModel, BCFModel, ForestContainer], + forest_type: str = None, + forest_inds: Union[int, np.ndarray] = None, +): """ Compute and return the largest possible leaf index computable by `compute_forest_leaf_indices` for the forests in a designated forest sample container. @@ -144,36 +186,50 @@ def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, Fores * `'variance'`: Extracts leaf indices for the variance forest * **ForestContainer** * `NULL`: It is not necessary to disambiguate when this function is called directly on a `ForestSamples` object. This is the default value of this - + forest_inds : int or np.ndarray - Indices of the forest sample(s) for which to compute max leaf indices. If not provided, this function will return max leaf indices for every sample of a forest. + Indices of the forest sample(s) for which to compute max leaf indices. If not provided, this function will return max leaf indices for every sample of a forest. This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on. - + Returns ------- Numpy array containing the largest possible leaf index computable by `compute_forest_leaf_indices` for the forests in a designated forest sample container. """ # Extract relevant forest container - if not isinstance(model_object, BARTModel) and not isinstance(model_object, BCFModel) and not isinstance(model_object, ForestContainer): - raise ValueError("model_object must be one of BARTModel, BCFModel, or ForestContainer") + if ( + not isinstance(model_object, BARTModel) + and not isinstance(model_object, BCFModel) + and not isinstance(model_object, ForestContainer) + ): + raise ValueError( + "model_object must be one of BARTModel, BCFModel, or ForestContainer" + ) if isinstance(model_object, BARTModel): model_type = "bart" if forest_type is None: - raise ValueError("forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')") + raise ValueError( + "forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')" + ) elif isinstance(model_object, BCFModel): model_type = "bcf" if forest_type is None: - raise ValueError("forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')") + raise ValueError( + "forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')" + ) else: model_type = "forest" if model_type == "bart": if forest_type == "mean": if not model_object.include_mean_forest: - raise ValueError("Mean forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Mean forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_mean else: if not model_object.include_variance_forest: - raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Variance forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_variance elif model_type == "bcf": if forest_type == "prognostic": @@ -182,18 +238,22 @@ def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, Fores forest_container = model_object.forest_container_tau else: if not model_object.include_variance_forest: - raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type") + raise ValueError( + "Variance forest was not sampled for model_object, but requested by forest_type" + ) forest_container = model_object.forest_container_variance else: forest_container = model_object - + # Preprocess forest indices num_forests = forest_container.num_samples() if forest_inds is None: forest_inds = np.arange(num_forests) elif isinstance(forest_inds, int): if not forest_inds >= 0 or not forest_inds < num_forests: - raise ValueError("The index in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The index in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) forest_inds = np.array([forest_inds]) elif isinstance(forest_inds, np.ndarray): if forest_inds.size > 1: @@ -201,16 +261,19 @@ def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, Fores if forest_inds.ndim > 1: raise ValueError("forest_inds must be a one-dimensional numpy array") if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests): - raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container") + raise ValueError( + "The indices in forest_inds must be >= 0 and < the total number of samples in a forest container" + ) else: raise ValueError("forest_inds must be a one-dimensional numpy array") - + # Compute max index output_size = len(forest_inds) output = np.empty(output_size) for i in np.arange(output_size): - output[i] = cppComputeForestMaxLeafIndex(forest_container.forest_container_cpp, forest_inds[i]) - - # Return result - return output + output[i] = cppComputeForestMaxLeafIndex( + forest_container.forest_container_cpp, forest_inds[i] + ) + # Return result + return output diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 6c4093d4..95ac3df1 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -110,10 +110,12 @@ def add_variance_weights(self, variance_weights: np.array): ) n = variance_weights_.shape[0] self.rfx_dataset_cpp.AddVarianceWeights(variance_weights_, n) - - def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False): + + def update_variance_weights( + self, variance_weights: np.array, exponentiate: bool = False + ): """ - Update variance weights in a dataset. Allows users to build an ensemble that depends on + Update variance weights in a dataset. Allows users to build an ensemble that depends on variance weights that are updated throughout the sampler. Parameters @@ -124,7 +126,9 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool Whether to exponentiate the variance weights before storing them in the dataset. """ if not self.has_variance_weights(): - raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.") + raise ValueError( + "This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector." + ) if not isinstance(variance_weights, np.ndarray): raise ValueError("variance_weights must be a numpy array.") variance_weights_ = np.squeeze(variance_weights) @@ -134,9 +138,11 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool ) n = variance_weights_.shape[0] if self.num_observations() != n: - raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).") + raise ValueError( + f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()})." + ) self.rfx_dataset_cpp.UpdateVarianceWeights(variance_weights, n, exponentiate) - + def get_group_labels(self) -> np.array: """ Return the group labels in a RandomEffectsDataset as a numpy array @@ -147,7 +153,7 @@ def get_group_labels(self) -> np.array: One-dimensional numpy array of group labels. """ return self.rfx_dataset_cpp.GetGroupLabels() - + def get_basis(self) -> np.array: """ Return the bases in a RandomEffectsDataset as a numpy array @@ -158,7 +164,7 @@ def get_basis(self) -> np.array: Two-dimensional numpy array of basis vectors. """ return self.rfx_dataset_cpp.GetBasis() - + def get_variance_weights(self) -> np.array: """ Return the variance weights in a RandomEffectsDataset as a numpy array diff --git a/stochtree/sampler.py b/stochtree/sampler.py index cbab9ce6..cc5ad54a 100644 --- a/stochtree/sampler.py +++ b/stochtree/sampler.py @@ -151,7 +151,7 @@ def sample_one_iteration( ) if self.forest_sampler_cpp.GetMaxDepth() != forest_config.get_max_depth(): self.forest_sampler_cpp.SetMaxDepth(forest_config.get_max_depth()) - + # Unpack sweep update indices (initializing empty numpy array if None) sweep_update_indices = forest_config.get_sweep_update_indices() if sweep_update_indices is None: @@ -176,7 +176,7 @@ def sample_one_iteration( forest_config.get_num_features_subsample(), keep_forest, gfr, - num_threads, + num_threads, ) def prepare_for_sampler( @@ -270,7 +270,7 @@ def propagate_basis_update( self.forest_sampler_cpp.PropagateBasisUpdate( dataset.dataset_cpp, residual.residual_cpp, forest.forest_cpp ) - + def get_cached_forest_predictions(self) -> np.array: """ Extract an internally-cached prediction of a forest on the training dataset in a sampler. diff --git a/stochtree/serialization.py b/stochtree/serialization.py index 844ee54d..64f0d842 100644 --- a/stochtree/serialization.py +++ b/stochtree/serialization.py @@ -62,7 +62,9 @@ def add_random_effects(self, rfx_container: RandomEffectsContainer) -> None: Samples of a random effects model """ _ = self.json_cpp.AddRandomEffectsContainer(rfx_container.rfx_container_cpp) - _ = self.json_cpp.AddRandomEffectsLabelMapper(rfx_container.rfx_label_mapper_cpp) + _ = self.json_cpp.AddRandomEffectsLabelMapper( + rfx_container.rfx_label_mapper_cpp + ) _ = self.json_cpp.AddRandomEffectsGroupIDs(rfx_container.rfx_group_ids) self.json_cpp.IncrementRandomEffectsCount() self.num_rfx += 1 diff --git a/stochtree/utils.py b/stochtree/utils.py index da25cb9c..8082c76b 100644 --- a/stochtree/utils.py +++ b/stochtree/utils.py @@ -190,17 +190,17 @@ def _check_matrix_square(input: np.ndarray) -> bool: def _expand_dims_1d(input: Union[int, float, np.array], output_size: int) -> np.array: """ - Convert scalar input to 1D numpy array of dimension `output_size`, + Convert scalar input to 1D numpy array of dimension `output_size`, or check that input array is equivalent to a 1D array of dimension `output_size`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars. - + Parameters ---------- input : int, float, np.array Input to be converted to a 1D array (or passed through as-is) output_size : int Intended size of the output vector - + Returns ------- np.array @@ -209,30 +209,38 @@ def _expand_dims_1d(input: Union[int, float, np.array], output_size: int) -> np. if isinstance(input, np.ndarray): input = np.squeeze(input) if input.ndim > 1: - raise ValueError("`input` must be convertible to a 1D numpy array or scalar") + raise ValueError( + "`input` must be convertible to a 1D numpy array or scalar" + ) if input.ndim == 0: output = np.repeat(input, output_size) else: if input.shape[0] != output_size: - raise ValueError("`input` must be a 1D numpy array with `output_size` elements") + raise ValueError( + "`input` must be a 1D numpy array with `output_size` elements" + ) output = input elif isinstance(input, (int, float)): output = np.repeat(input, output_size) else: - raise ValueError("`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times") + raise ValueError( + "`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times" + ) return output -def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output_cols: int) -> np.array: +def _expand_dims_2d( + input: Union[int, float, np.array], output_rows: int, output_cols: int +) -> np.array: """ - Ensures that input is propagated appropriately to a 2D numpy array of dimension `output_rows` x `output_cols`. + Ensures that input is propagated appropriately to a 2D numpy array of dimension `output_rows` x `output_cols`. Handles the following cases: 1. `input` is a scalar: output is simply a (`output_rows`, `output_cols`) array with `input` repeated for each element 2. `input` is a 1D array of length `output_rows`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_cols` columns 3. `input` is a 1D array of length `output_cols`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_rows` rows 4. `input` is a 2D array of dimension (`output_rows`, `output_cols`): input is passed through as-is All other cases raise a `ValueError`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars. - + Parameters ---------- input : int, float, np.array @@ -241,7 +249,7 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output Intended number of rows in the output array output_cols : int Intended number of columns in the output array - + Returns ------- np.array @@ -253,9 +261,13 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output raise ValueError("`input` must be a 1D or 2D numpy array") elif input.ndim == 2: if input.shape[0] != output_rows: - raise ValueError("If `input` is passed as a 2D numpy array, it must contain `output_rows` rows") + raise ValueError( + "If `input` is passed as a 2D numpy array, it must contain `output_rows` rows" + ) if input.shape[1] != output_cols: - raise ValueError("If `input` is passed as a 2D numpy array, it must contain `output_cols` columns") + raise ValueError( + "If `input` is passed as a 2D numpy array, it must contain `output_cols` columns" + ) output = input elif input.ndim == 1: if input.shape[0] == output_cols: @@ -263,7 +275,9 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output elif input.shape[0] == output_rows: output = np.tile(input, (output_cols, 1)).T else: - raise ValueError("If `input` is a 1D numpy array, it must either contain `output_rows` or `output_cols` elements") + raise ValueError( + "If `input` is a 1D numpy array, it must either contain `output_rows` or `output_cols` elements" + ) elif input.ndim == 0: output = np.tile(input, (output_rows, output_cols)) elif isinstance(input, (int, float)): @@ -273,19 +287,21 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output return output -def _expand_dims_2d_diag(input: Union[int, float, np.array], output_size: int) -> np.array: +def _expand_dims_2d_diag( + input: Union[int, float, np.array], output_size: int +) -> np.array: """ - Convert scalar input to 2D square numpy array of dimension `output_size` x `output_size` with `input` along the diagonal, + Convert scalar input to 2D square numpy array of dimension `output_size` x `output_size` with `input` along the diagonal, or check that input array is equivalent to a 2D square array of dimension `output_size` x `output_size`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars. - + Parameters ---------- input : int, float, np.array Input to be converted to a 2D square array (or passed through as-is) output_size : int Intended row and column dimension of the square output matrix - + Returns ------- np.array @@ -294,23 +310,25 @@ def _expand_dims_2d_diag(input: Union[int, float, np.array], output_size: int) - if isinstance(input, np.ndarray): input = np.squeeze(input) if (input.ndim != 2) and (input.ndim != 0): - raise ValueError("`input` must be convertible to a 2D numpy array or scalar") - if input.ndim == 0: - output = np.zeros( - (output_size, output_size), dtype=float + raise ValueError( + "`input` must be convertible to a 2D numpy array or scalar" ) + if input.ndim == 0: + output = np.zeros((output_size, output_size), dtype=float) np.fill_diagonal(output, input) else: if input.shape[0] != input.shape[1]: raise ValueError("`input` must be a 2D square numpy array") if input.shape[0] != output_size: - raise ValueError("`input` must be a 2D square numpy array with exactly `output_size` rows and columns") + raise ValueError( + "`input` must be a 2D square numpy array with exactly `output_size` rows and columns" + ) output = input elif isinstance(input, (int, float)): - output = np.zeros( - (output_size, output_size), dtype=float - ) + output = np.zeros((output_size, output_size), dtype=float) np.fill_diagonal(output, input) else: - raise ValueError("`input` must be either a 2D square numpy array or a scalar that can be propagated along the diagonal of a square matrix") + raise ValueError( + "`input` must be either a 2D square numpy array or a scalar that can be propagated along the diagonal of a square matrix" + ) return output From 7dc71ce0f6eab82041dea39e8216ba918a784de9 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 23:38:01 -0500 Subject: [PATCH 17/53] Deprecate predict_mean and predict_variance methods in python BART --- stochtree/bart.py | 183 --------------------------------------------- stochtree/utils.py | 2 +- 2 files changed, 1 insertion(+), 184 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 6e3a1e8b..7592a1ff 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1867,189 +1867,6 @@ def predict( result["variance_forest_predictions"] = None return result - def predict_mean( - self, - covariates: np.array, - basis: np.array = None, - rfx_group_ids: np.array = None, - rfx_basis: np.array = None, - ) -> np.array: - """Predict expected conditional outcome from a BART model. - - Parameters - ---------- - covariates : np.array - Test set covariates. - basis : np.array, optional - Optional test set basis vector, must be provided if the model was trained with a leaf regression basis. - - Returns - ------- - np.array - Mean forest predictions. - """ - if not self.is_sampled(): - msg = ( - "This BARTModel instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - has_mean_predictions = self.include_mean_forest or self.has_rfx - if not has_mean_predictions: - msg = ( - "This BARTModel instance was not sampled with a mean forest or random effects. " - "Call 'fit' with appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - # Data checks - if not isinstance(covariates, pd.DataFrame) and not isinstance( - covariates, np.ndarray - ): - raise ValueError("covariates must be a pandas dataframe or numpy array") - if basis is not None: - if not isinstance(basis, np.ndarray): - raise ValueError("basis must be a numpy array") - if basis.shape[0] != covariates.shape[0]: - raise ValueError( - "covariates and basis must have the same number of rows" - ) - - # Convert everything to standard shape (2-dimensional) - if isinstance(covariates, np.ndarray): - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - if basis is not None: - if basis.ndim == 1: - basis = np.expand_dims(basis, 1) - - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): - raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." - ) - else: - warnings.warn( - "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, - ) - if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): - raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." - ) - covariates_processed = covariates - else: - covariates_processed = self._covariate_preprocessor.transform(covariates) - - # Dataset construction - pred_dataset = Dataset() - pred_dataset.add_covariates(covariates_processed) - if basis is not None: - pred_dataset.add_basis(basis) - - # Mean forest predictions - if self.include_mean_forest: - mean_pred_raw = self.forest_container_mean.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - mean_pred = mean_pred_raw * self.y_std + self.y_bar - - # RFX predictions - if self.has_rfx: - rfx_preds = ( - self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std - ) - if self.include_mean_forest: - mean_pred = mean_pred + rfx_preds - else: - mean_pred = rfx_preds + self.y_bar - - return mean_pred - - def predict_variance(self, covariates: np.array) -> np.array: - """Predict expected conditional variance from a BART model. - - Parameters - ---------- - covariates : np.array - Test set covariates. - - Returns - ------- - np.array - Variance forest predictions. - """ - if not self.is_sampled(): - msg = ( - "This BARTModel instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - if not self.include_variance_forest: - msg = ( - "This BARTModel instance was not sampled with a variance forest. " - "Call 'fit' with appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - # Data checks - if not isinstance(covariates, pd.DataFrame) and not isinstance( - covariates, np.ndarray - ): - raise ValueError("covariates must be a pandas dataframe or numpy array") - - # Convert everything to standard shape (2-dimensional) - if isinstance(covariates, np.ndarray): - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): - raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." - ) - else: - warnings.warn( - "This BART model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, - ) - if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): - raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BART model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." - ) - covariates_processed = covariates - else: - covariates_processed = self._covariate_preprocessor.transform(covariates) - - # Dataset construction - pred_dataset = Dataset() - pred_dataset.add_covariates(covariates_processed) - - # Variance forest predictions - variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - if self.sample_sigma2_global: - variance_pred = np.empty_like(variance_pred_raw) - for i in range(self.num_samples): - variance_pred[:, i] = ( - variance_pred_raw[:, i] * self.global_var_samples[i] - ) - else: - variance_pred = ( - variance_pred_raw * self.sigma2_init * self.y_std * self.y_std - ) - - return variance_pred - def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or diff --git a/stochtree/utils.py b/stochtree/utils.py index 8082c76b..92cc73a5 100644 --- a/stochtree/utils.py +++ b/stochtree/utils.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Union import numpy as np From e96268f56be3a1c120703aee134f95bd9ae0b0dd Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 15 Oct 2025 23:42:58 -0500 Subject: [PATCH 18/53] Deprecate predict_tau and predict_variance and expand predict method --- stochtree/bcf.py | 232 +++-------------------------------------------- 1 file changed, 15 insertions(+), 217 deletions(-) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index fcf0f8f3..9e520849 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2232,222 +2232,6 @@ def sample( sigma2_x_test_raw * self.sigma2_init * self.y_std * self.y_std ) - def predict_tau( - self, X: np.array, Z: np.array, propensity: np.array = None - ) -> np.array: - """Predict CATE function for every provided observation. - - Parameters - ---------- - X : np.array or pd.DataFrame - Test set covariates. - Z : np.array - Test set treatment indicators. - propensity : np.array, optional - Optional test set propensities. Must be provided if propensities were provided when the model was sampled. - - Returns - ------- - np.array - Array with as many rows as in `X` and as many columns as retained samples of the algorithm. - """ - if not self.is_sampled(): - msg = ( - "This BCFModel instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - # Convert everything to standard shape (2-dimensional) - if X.ndim == 1: - X = np.expand_dims(X, 1) - if Z.ndim == 1: - Z = np.expand_dims(Z, 1) - else: - if Z.ndim != 2: - raise ValueError("treatment must have 1 or 2 dimensions") - if propensity is not None: - if propensity.ndim == 1: - propensity = np.expand_dims(propensity, 1) - - # Data checks - if Z.shape[0] != X.shape[0]: - raise ValueError("X and Z must have the same number of rows") - - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(X, np.ndarray): - raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." - ) - else: - warnings.warn( - "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, - ) - if not np.issubdtype(X.dtype, np.floating) and not np.issubdtype( - X.dtype, np.integer - ): - raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." - ) - covariates_processed = X - else: - covariates_processed = self._covariate_preprocessor.transform(X) - - # Handle propensities - if propensity is not None: - if propensity.shape[0] != X.shape[0]: - raise ValueError("X and propensity must have the same number of rows") - else: - if self.propensity_covariate != "none": - if not self.internal_propensity_model: - raise ValueError( - "Propensity scores not provided, but no propensity model was trained during sampling" - ) - else: - internal_propensity_preds = self.bart_propensity_model.predict( - covariates_processed - ) - propensity = np.mean( - internal_propensity_preds["y_hat"], axis=1, keepdims=True - ) - - # Update covariates to include propensities if requested - if self.propensity_covariate == "none": - X_combined = covariates_processed - else: - X_combined = np.c_[covariates_processed, propensity] - - # Forest dataset - forest_dataset_test = Dataset() - forest_dataset_test.add_covariates(X_combined) - forest_dataset_test.add_basis(Z) - - # Estimate treatment effect - tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( - forest_dataset_test.dataset_cpp - ) - tau_raw = tau_raw - if self.adaptive_coding: - adaptive_coding_weights = np.expand_dims( - self.b1_samples - self.b0_samples, axis=(0, 2) - ) - tau_raw = tau_raw * adaptive_coding_weights - tau_x = np.squeeze(tau_raw * self.y_std) - - # Return result matrix - return tau_x - - def predict_variance( - self, covariates: np.array, propensity: np.array = None - ) -> np.array: - """Predict expected conditional variance from a BART model. - - Parameters - ---------- - covariates : np.array - Test set covariates. - propensity : np.array, optional - Test set propensity scores. Optional (not currently used in variance forests). - - Returns - ------- - np.array - Array of predictions corresponding to the variance forest. Each array will contain as many rows as in `covariates` and as many columns as retained samples of the algorithm. - """ - if not self.is_sampled(): - msg = ( - "This BARTModel instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - if not self.include_variance_forest: - msg = ( - "This BARTModel instance was not sampled with a variance forest. " - "Call 'fit' with appropriate arguments before using this model." - ) - raise NotSampledError(msg) - - # Convert everything to standard shape (2-dimensional) - if covariates.ndim == 1: - covariates = np.expand_dims(covariates, 1) - if propensity is not None: - if propensity.ndim == 1: - propensity = np.expand_dims(propensity, 1) - - # Covariate preprocessing - if not self._covariate_preprocessor._check_is_fitted(): - if not isinstance(covariates, np.ndarray): - raise ValueError( - "Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe." - ) - else: - warnings.warn( - "This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.", - RuntimeWarning, - ) - if not np.issubdtype( - covariates.dtype, np.floating - ) and not np.issubdtype(covariates.dtype, np.integer): - raise ValueError( - "Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe." - ) - covariates_processed = covariates - else: - covariates_processed = self._covariate_preprocessor.transform(covariates) - - # Handle propensities - if propensity is not None: - if propensity.shape[0] != covariates.shape[0]: - raise ValueError("X and propensity must have the same number of rows") - else: - if self.propensity_covariate != "none": - if not self.internal_propensity_model: - raise ValueError( - "Propensity scores not provided, but no propensity model was trained during sampling" - ) - else: - internal_propensity_preds = self.bart_propensity_model.predict( - covariates_processed - ) - propensity = np.mean( - internal_propensity_preds["y_hat"], axis=1, keepdims=True - ) - - # Update covariates to include propensities if requested - if self.propensity_covariate == "none": - X_combined = covariates_processed - else: - if propensity is not None: - X_combined = np.c_[covariates_processed, propensity] - else: - # Dummy propensities if not provided but also not needed - propensity = np.ones(covariates_processed.shape[0]) - propensity = np.expand_dims(propensity, 1) - X_combined = np.c_[covariates_processed, propensity] - - # Forest dataset - pred_dataset = Dataset() - pred_dataset.add_covariates(X_combined) - - variance_pred_raw = self.forest_container_variance.forest_container_cpp.Predict( - pred_dataset.dataset_cpp - ) - if self.sample_sigma2_global: - variance_pred = np.empty_like(variance_pred_raw) - for i in range(self.num_samples): - variance_pred[:, i] = ( - variance_pred_raw[:, i] * self.global_var_samples[i] - ) - else: - variance_pred = ( - variance_pred_raw * self.sigma2_init * self.y_std * self.y_std - ) - - return variance_pred - def predict( self, X: np.array, @@ -2457,6 +2241,7 @@ def predict( rfx_basis: np.array = None, type: str = "posterior", terms: Union[list[str], str] = "all", + scale: str = "linear" ) -> dict: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. @@ -2473,16 +2258,29 @@ def predict( Optional group labels used for an additive random effects model. rfx_basis : np.array, optional Optional basis for "random-slope" regression in an additive random effects model. - type : str, optional Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". terms : str, optional Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". + scale : str, optional + Scale on which to return predictions. Options are "linear" (the default), which returns predictions on the original outcome scale, and "probit", which returns predictions on the probit (latent) scale. Only applicable for models fit with `probit_outcome_model=True`. Returns ------- Dict of numpy arrays for each prediction term, or a simple numpy array if a single term is requested. """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + # Handle prediction type if not isinstance(type, str): raise ValueError("type must be a string") From 169e3faf8b5df07ce29ec97511030659595bbbfb Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Oct 2025 00:26:25 -0500 Subject: [PATCH 19/53] Updated python predict methods and tests --- stochtree/bart.py | 7 ---- stochtree/bcf.py | 83 +++++++++++++++++++++++++++---------- test/python/test_bcf.py | 34 +++++++-------- test/python/test_predict.py | 6 +-- 4 files changed, 79 insertions(+), 51 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 7592a1ff..655e9c8c 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1803,13 +1803,6 @@ def predict( # rfx_predictions = np.mean(rfx_predictions, axis = 1) # Combine into y hat predictions - if predict_y_hat and has_mean_forest and has_rfx: - y_hat = mean_forest_predictions + rfx_predictions - elif predict_y_hat and has_mean_forest: - y_hat = mean_forest_predictions - elif predict_y_hat and has_rfx: - y_hat = rfx_predictions - if probability_scale: if predict_y_hat and has_mean_forest and has_rfx: y_hat = norm.ppf(mean_forest_predictions + rfx_predictions) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 9e520849..62fae548 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2400,14 +2400,28 @@ def predict( forest_dataset_test.add_covariates(X_combined) forest_dataset_test.add_basis(Z) - # Compute predicted outcome and decomposed outcome model terms + # Compute predictions from the variance forest (if included) + if predict_variance_forest: + sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict( + forest_dataset_test.dataset_cpp + ) + if self.sample_sigma2_global: + sigma2_x = np.empty_like(sigma2_x_raw) + for i in range(self.num_samples): + sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] + else: + sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std + if predict_mean: + sigma2_x = np.mean(sigma2_x, axis=1) + + # Prognostic forest predictions if predict_mu_forest or predict_mu_forest_intermediate: mu_raw = self.forest_container_mu.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) mu_x = mu_raw * self.y_std + self.y_bar - if predict_mean: - mu_x = np.mean(mu_x, axis=1) + + # Treatment effect forest predictions if predict_tau_forest or predict_tau_forest_intermediate: tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( forest_dataset_test.dataset_cpp @@ -2422,15 +2436,10 @@ def predict( treatment_term = np.multiply( np.atleast_3d(Z).swapaxes(1, 2), tau_x ).sum(axis=2) - if predict_mean: - treatment_term = np.mean(treatment_term, axis=1) - tau_x = np.mean(tau_x, axis=2) else: treatment_term = Z * np.squeeze(tau_x) - if predict_mean: - treatment_term = np.mean(treatment_term, axis=1) - tau_x = np.mean(tau_x, axis=1) + # Random effects predictions if predict_rfx or predict_rfx_intermediate: rfx_preds = ( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std @@ -2438,27 +2447,54 @@ def predict( if predict_mean: rfx_preds = np.mean(rfx_preds, axis=1) + # Combine into y hat predictions if predict_y_hat and has_mu_forest and has_rfx: y_hat = mu_x + treatment_term + rfx_preds elif predict_y_hat and has_mu_forest: y_hat = mu_x + treatment_term elif predict_y_hat and has_rfx: y_hat = rfx_preds - - # Compute predictions from the variance forest (if included) - if predict_variance_forest: - sigma2_x_raw = self.forest_container_variance.forest_container_cpp.Predict( - forest_dataset_test.dataset_cpp - ) - if self.sample_sigma2_global: - sigma2_x = np.empty_like(sigma2_x_raw) - for i in range(self.num_samples): - sigma2_x[:, i] = sigma2_x_raw[:, i] * self.global_var_samples[i] + + needs_mean_term_preds = predict_y_hat or \ + predict_mu_forest or \ + predict_tau_forest or \ + predict_rfx + if needs_mean_term_preds: + if probability_scale: + if has_rfx: + if predict_y_hat: + y_hat = norm.cdf(mu_x + treatment_term + rfx_preds) + if predict_rfx: + rfx_preds = norm.cdf(rfx_preds) + else: + if predict_y_hat: + y_hat = norm.cdf(mu_x + treatment_term) + if predict_mu_forest: + mu_x = norm.cdf(mu_x) + if predict_tau_forest: + tau_x = norm.cdf(tau_x) else: - sigma2_x = sigma2_x_raw * self.sigma2_init * self.y_std * self.y_std - if predict_mean: - sigma2_x = np.mean(sigma2_x, axis=1) + if has_rfx: + if predict_y_hat: + y_hat = mu_x + treatment_term + rfx_preds + else: + if predict_y_hat: + y_hat = mu_x + treatment_term + # Collapse to posterior mean predictions if requested + if predict_mean: + if predict_mu_forest: + mu_x = np.mean(mu_x, axis=1) + if predict_tau_forest: + if Z.shape[1] > 1: + tau_x = np.mean(tau_x, axis=2) + else: + tau_x = np.mean(tau_x, axis=1) + if predict_rfx: + rfx_preds = np.mean(rfx_preds, axis=1) + if predict_y_hat: + y_hat = np.mean(y_hat, axis=1) + if predict_count == 1: if predict_y_hat: return y_hat @@ -2754,6 +2790,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.internal_propensity_model = json_object_default.get_boolean( "internal_propensity_model" ) + self.probit_outcome_model = json_object_default.get_boolean( + "probit_outcome_model" + ) # Unpack number of samples for i in range(len(json_object_list)): diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index 2e2f7fbf..ed0ac3ae 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -76,8 +76,8 @@ def test_binary_bcf(self): assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) - # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + # Check that we can predict just treatment effects + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and with propensity score @@ -106,7 +106,7 @@ def test_binary_bcf(self): assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -142,7 +142,7 @@ def test_binary_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -172,7 +172,7 @@ def test_binary_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") def test_continuous_univariate_bcf(self): # RNG @@ -245,7 +245,7 @@ def test_continuous_univariate_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run second BCF model with test set and propensity score @@ -281,7 +281,7 @@ def test_continuous_univariate_bcf(self): assert y_hat_2.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test, pi_test) + tau_hat_2 = bcf_model_2.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat_2.shape == (n_test, num_mcmc) # Combine into a single model @@ -336,7 +336,7 @@ def test_continuous_univariate_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -372,7 +372,7 @@ def test_continuous_univariate_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -402,7 +402,7 @@ def test_continuous_univariate_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") # Run second BCF model with test set and propensity score bcf_model_2 = BCFModel() @@ -430,7 +430,7 @@ def test_continuous_univariate_bcf(self): assert y_hat_2.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat_2 = bcf_model_2.predict_tau(X_test, Z_test) + tau_hat_2 = bcf_model_2.predict(X = X_test, Z = Z_test, terms = "cate") assert tau_hat_2.shape == (n_test, num_mcmc) # Combine into a single model @@ -528,7 +528,7 @@ def test_multivariate_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF without test set and with propensity score @@ -558,7 +558,7 @@ def test_multivariate_bcf(self): assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF with test set and without propensity score @@ -665,7 +665,7 @@ def test_binary_bcf_heteroskedastic(self): assert sigma2_x_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and with propensity score @@ -715,7 +715,7 @@ def test_binary_bcf_heteroskedastic(self): ) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test, pi_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -752,7 +752,7 @@ def test_binary_bcf_heteroskedastic(self): assert bcf_preds['variance_forest_predictions'].shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -781,7 +781,7 @@ def test_binary_bcf_heteroskedastic(self): assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict_tau(X_test, Z_test) + tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") def test_bcf_rfx_parameters(self): # RNG diff --git a/test/python/test_predict.py b/test/python/test_predict.py index 0dda27d4..8ff6d4ed 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -263,10 +263,6 @@ def test_bart_prediction(self): def test_bcf_prediction(self): # Generate data and test/train split rng = np.random.default_rng(1234) - - - # Convert the R code down below to Python - rng = np.random.default_rng(1234) n = 100 g = lambda x: np.where(x[:, 4] == 1, 2, np.where(x[:, 4] == 2, -1, -4)) x1 = rng.normal(size=n) @@ -328,7 +324,7 @@ def g(x5): num_mcmc = 10 ) - # Check that the default predict method returns a list + # Check that the default predict method returns a dictionary pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test) y_hat_posterior_test = pred['y_hat'] assert y_hat_posterior_test.shape == (20, 10) From 8f51e2be0cb6fd5a6afe037222d6a3fd201743e5 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 16 Oct 2025 12:09:30 -0500 Subject: [PATCH 20/53] Updated python prediction and interval methods --- R/posterior_transformation.R | 8 +- demo/debug/bart_predict_debug.py | 18 +++- demo/debug/bcf_predict_debug.py | 118 +++++++++++++++++++++++ stochtree/bart.py | 141 ++++++++++++++++++++++++++- stochtree/bcf.py | 159 ++++++++++++++++++++++++++++++- stochtree/forest.py | 4 - stochtree/utils.py | 33 +++++++ 7 files changed, 461 insertions(+), 20 deletions(-) create mode 100644 demo/debug/bcf_predict_debug.py diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index c8adf8bc..32e05b67 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -361,7 +361,7 @@ posterior_predictive_heuristic_multiplier <- function( #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). #' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). -#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities. +#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. #' @param rfx_group_ids An optional vector of group IDs for random effects. Required if the requested term includes random effects. #' @param rfx_basis An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. #' @@ -417,7 +417,6 @@ compute_bcf_posterior_interval <- function( "scale cannot be 'probability' for models not fit with a probit outcome model" ) } - probability_scale <- scale == "probability" # Check that all the necessary inputs were provided for interval computation needs_covariates_intermediate <- ((("y_hat" %in% terms) || @@ -547,9 +546,7 @@ compute_bcf_posterior_interval <- function( } } -#' Compute posterior credible intervals for BART model terms -#' -#' This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. +#' Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. #' @param model_object A fitted BART or BCF model object of class `bartmodel`. #' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). @@ -604,7 +601,6 @@ compute_bart_posterior_interval <- function( "scale cannot be 'probability' for models not fit with a probit outcome model" ) } - probability_scale <- scale == "probability" # Check that all the necessary inputs were provided for interval computation needs_covariates_intermediate <- ((("y_hat" %in% terms) || diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py index 6b477053..04197039 100644 --- a/demo/debug/bart_predict_debug.py +++ b/demo/debug/bart_predict_debug.py @@ -8,7 +8,7 @@ # Generate data rng = np.random.default_rng() -n = 100 +n = 500 p = 5 X = rng.uniform(low=0.0, high=1.0, size=(n, p)) f_X = np.where( @@ -42,7 +42,7 @@ X_test=X_test, num_gfr=10, num_burnin=0, - num_mcmc=10, + num_mcmc=1000, ) # # Check several predict approaches @@ -66,3 +66,17 @@ plt.ylabel("Actual") plt.title("Y hat") plt.show() + +# Compute posterior interval +intervals = bart_model.compute_posterior_interval( + terms = "all", + scale = "linear", + level = 0.95, + covariates = X_test +) + +# Check coverage +mean_coverage = np.mean( + (intervals["y_hat"]["lower"] <= f_X_test) & (f_X_test <= intervals["y_hat"]["upper"]) +) +print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}") diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py new file mode 100644 index 00000000..bb95e93a --- /dev/null +++ b/demo/debug/bcf_predict_debug.py @@ -0,0 +1,118 @@ +# Demo of updated predict method for BART + +# Load library +from stochtree import BCFModel +import numpy as np +from sklearn.model_selection import train_test_split +from scipy.stats import norm +import matplotlib.pyplot as plt + +# Generate data +rng = np.random.default_rng() +n = 1000 +p = 5 +X = rng.normal(loc=0.0, scale=1.0, size=(n, p)) +mu_X = X[:,0] +tau_X = 0.25 * X[:,1] +pi_X = norm.cdf(0.5 * X[:,1]) +Z = rng.binomial(n=1, p=pi_X, size=(n,)) +E_XZ = mu_X + tau_X * Z +snr = 2.0 +noise_sd = np.std(E_XZ) / snr +y = E_XZ + rng.normal(loc=0.0, scale=noise_sd, size=(n,)) + +# Train-test split +sample_inds = np.arange(n) +test_set_pct = 0.2 +train_inds, test_inds = train_test_split(sample_inds, test_size=test_set_pct) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +pi_train = pi_X[train_inds] +pi_test = pi_X[test_inds] +tau_train = tau_X[train_inds] +tau_test = tau_X[test_inds] +mu_train = mu_X[train_inds] +mu_test = mu_X[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +E_XZ_train = E_XZ[train_inds] +E_XZ_test = E_XZ[test_inds] + +# Fit simple BCF model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + pi_train=pi_train, + y_train=y_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, +) + +# Check several predict approaches +bcf_preds = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test) +y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)['y_hat'] +y_hat_mean_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, + type = "mean", + terms = ["y_hat"] +) +tau_hat_mean_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, + type = "mean", + terms = ["cate"] +) +# Check that this raises a warning +y_hat_test = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, + type = "mean", + terms = ["rfx", "variance"] +) + +# Plot predicted versus actual +plt.scatter(y_hat_mean_test, y_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Y hat") +plt.show() + +# Plot predicted versus actual +plt.clf() +plt.scatter(tau_hat_mean_test, tau_test, color="black") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("CATE function") +plt.show() + +# Compute posterior interval +intervals = bcf_model.compute_posterior_interval( + terms = "all", + scale = "linear", + level = 0.95, + covariates = X_test, + treatment = Z_test, + propensity = pi_test +) + +# Check coverage of E[Y | X, Z] +mean_coverage = np.mean( + (intervals["y_hat"]["lower"] <= E_XZ_test) & (E_XZ_test <= intervals["y_hat"]["upper"]) +) +print(f"Coverage of 95% posterior interval for E[Y|X,Z]: {mean_coverage:.3f}") + +# Check coverage of tau(X) +tau_coverage = np.mean( + (intervals["tau_hat"]["lower"] <= tau_test) & (tau_test <= intervals["tau_hat"]["upper"]) +) +print(f"Coverage of 95% posterior interval for tau(X): {tau_coverage:.3f}") + +# Check coverage of mu(X) +mu_coverage = np.mean( + (intervals["mu_hat"]["lower"] <= mu_test) & (mu_test <= intervals["mu_hat"]["upper"]) +) +print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}") diff --git a/stochtree/bart.py b/stochtree/bart.py index 655e9c8c..b7456e64 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1,7 +1,3 @@ -""" -Bayesian Additive Regression Trees (BART) module -""" - import warnings from math import log from numbers import Integral @@ -28,6 +24,8 @@ _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag, + _posterior_predictive_heuristic_multiplier, + _summarize_interval ) @@ -1860,6 +1858,114 @@ def predict( result["variance_forest_predictions"] = None return result + def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict: + """ + Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. + + Parameters + ---------- + terms : str, optional + Character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Defaults to `"all"`. + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. + level : float, optional + A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. + covariates : np.array, optional + Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). + basis : np.array, optional + Optional array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. + rfx_group_ids : np.array, optional + Optional vector of group IDs for random effects. Required if the requested term includes random effects. + rfx_basis : np.array, optional + Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. + + Returns + ------- + dict + A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. + """ + # Check the provided model object and requested term + self.is_sampled() + for term in terms: + self.has_term(term) + + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + + # Check that all the necessary inputs were provided for interval computation + needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.include_mean_forest + needs_covariates = ("mean_forest" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate + if needs_covariates: + if covariates is None: + raise ValueError( + "'covariates' must be provided in order to compute the requested intervals" + ) + if not isinstance(covariates, np.ndarray) and not isinstance( + covariates, pd.DataFrame + ): + raise ValueError("'covariates' must be a matrix or data frame") + needs_basis = needs_covariates and self.has_basis + if needs_basis: + if basis is None: + raise ValueError( + "'basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(basis, np.ndarray): + raise ValueError("'basis' must be a numpy array") + if basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'basis' must have the same number of rows as 'covariates'" + ) + needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx + needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior matrices for the requested model terms + predictions = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale) + has_multiple_terms = True if isinstance(predictions, dict) else False + + # Compute posterior intervals + if has_multiple_terms: + result = dict() + for term in predictions.keys(): + if predictions[term] is not None: + result[term] = _summarize_interval( + predictions[term], 1, level=level + ) + return result + else: + return _summarize_interval( + predictions, 1, level=level + ) + def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or @@ -2145,3 +2251,30 @@ def is_sampled(self) -> bool: `True` if a BART model has been sampled, `False` otherwise """ return self.sampled + + def has_term(self, term: str) -> bool: + """ + Whether or not a model includes a term. + + Parameters + ---------- + term : str + Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. + + Returns + ------- + bool + `True` if the model includes the specified term, `False` otherwise + """ + if term == "mean_forest": + return self.include_mean_forest + elif term == "variance_forest": + return self.include_variance_forest + elif term == "rfx": + return self.has_rfx + elif term == "y_hat": + return self.include_mean_forest or self.has_rfx + elif term == "all": + return True + else: + return False diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 62fae548..19f0c296 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1,7 +1,3 @@ -""" -Bayesian Causal Forests (BCF) module -""" - import warnings from typing import Any, Dict, Optional, Union @@ -28,6 +24,8 @@ _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag, + _posterior_predictive_heuristic_multiplier, + _summarize_interval ) @@ -2529,6 +2527,130 @@ def predict( else: result["variance_forest_predictions"] = None return result + + def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, treatment: np.array = None, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict: + """ + Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. + + Parameters + ---------- + terms : str, optional + Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. + scale : str, optional + Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. + level : float, optional + A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval. + covariates : np.array, optional + Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, treatment effect forest, variance forest, or overall predictions). + treatment : np.array, optional + Optional array of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). + propensity : np.array, optional + Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. + rfx_group_ids : np.array, optional + Optional vector of group IDs for random effects. Required if the requested term includes random effects. + rfx_basis : np.array, optional + Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. + + Returns + ------- + dict + A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. + """ + # Check the provided model object and requested term + self.is_sampled() + for term in terms: + self.has_term(term) + + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + + # Check that all the necessary inputs were provided for interval computation + needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) + needs_covariates = ("prognostic_function" in terms) or ("cate" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate + if needs_covariates: + if covariates is None: + raise ValueError( + "'covariates' must be provided in order to compute the requested intervals" + ) + if not isinstance(covariates, np.ndarray) and not isinstance( + covariates, pd.DataFrame + ): + raise ValueError("'covariates' must be a matrix or data frame") + needs_treatment = needs_covariates + if needs_treatment: + if treatment is None: + raise ValueError( + "'treatment' must be provided in order to compute the requested intervals" + ) + if not isinstance(treatment, np.ndarray): + raise ValueError("'treatment' must be a numpy array") + if treatment.shape[0] != covariates.shape[0]: + raise ValueError( + "'treatment' must have the same number of rows as 'covariates'" + ) + uses_propensity = self.propensity_covariate is not "none" + internal_propensity_model = self.internal_propensity_model + needs_propensity = needs_covariates and uses_propensity and not internal_propensity_model + if needs_propensity: + if propensity is None: + raise ValueError( + "'propensity' must be provided in order to compute the requested intervals" + ) + if not isinstance(propensity, np.ndarray): + raise ValueError("'propensity' must be a numpy array") + if propensity.shape[0] != covariates.shape[0]: + raise ValueError( + "'propensity' must have the same number of rows as 'covariates'" + ) + needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx + needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior matrices for the requested model terms + predictions = self.predict(X=covariates, Z=treatment, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale) + has_multiple_terms = True if isinstance(predictions, dict) else False + + # Compute posterior intervals + if has_multiple_terms: + result = dict() + for term in predictions.keys(): + if predictions[term] is not None: + result[term] = _summarize_interval( + predictions[term], 1, level=level + ) + return result + else: + return _summarize_interval( + predictions, 1, level=level + ) def to_json(self) -> str: """ @@ -2871,3 +2993,32 @@ def is_sampled(self) -> bool: `True` if a BCF model has been sampled, `False` otherwise """ return self.sampled + + def has_term(self, term: str) -> bool: + """ + Whether or not a model includes a term. + + Parameters + ---------- + term : str + Character string specifying the model term to check for. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. + + Returns + ------- + bool + `True` if the model includes the specified term, `False` otherwise + """ + if term == "prognostic_function": + return True + if term == "cate": + return True + elif term == "variance_forest": + return self.include_variance_forest + elif term == "rfx": + return self.has_rfx + elif term == "y_hat": + return True + elif term == "all": + return True + else: + return False diff --git a/stochtree/forest.py b/stochtree/forest.py index a92c5847..3cfde245 100644 --- a/stochtree/forest.py +++ b/stochtree/forest.py @@ -1,7 +1,3 @@ -""" -Python classes wrapping C++ forest container object -""" - from typing import Union import numpy as np diff --git a/stochtree/utils.py b/stochtree/utils.py index 92cc73a5..214beb2f 100644 --- a/stochtree/utils.py +++ b/stochtree/utils.py @@ -1,4 +1,5 @@ from typing import Union +import math import numpy as np @@ -332,3 +333,35 @@ def _expand_dims_2d_diag( "`input` must be either a 2D square numpy array or a scalar that can be propagated along the diagonal of a square matrix" ) return output + +def _posterior_predictive_heuristic_multiplier(num_samples: int, num_observations: int) -> int: + if num_samples >= 1000: + return 1 + else: + return math.ceil(1000 / num_samples) + +def _summarize_interval(array: np.ndarray, sample_dim: int = 2, level: float = 0.95) -> dict: + # Check that the array is numeric and at least 2 dimensional + if not isinstance(array, np.ndarray): + raise ValueError("`array` must be a numpy array") + if not _check_array_numeric(array): + raise ValueError("`array` must be a numeric numpy array") + if not len(array.shape) >= 2: + raise ValueError("`array` must be at least a 2-dimensional numpy array") + if not _check_is_int(sample_dim) or (sample_dim < 0) or (sample_dim >= len(array.shape)): + raise ValueError( + "`sample_dim` must be an integer between 0 and the number of dimensions of `array` - 1" + ) + if not isinstance(level, float) or (level <= 0) or (level >= 1): + raise ValueError("`level` must be a float between 0 and 1") + + # Compute lower and upper quantiles based on the requested interval + quantile_lb = (1 - level) / 2 + quantile_ub = 1 - quantile_lb + + # Calculate the interval + result_lb = np.quantile(array, q=quantile_lb, axis=sample_dim) + result_ub = np.quantile(array, q=quantile_ub, axis=sample_dim) + + # Return results as a dictionary + return {"lower": result_lb, "upper": result_ub} From 4f1beed35aedf15d216173ba7848a5c4d591e1a2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Oct 2025 01:02:33 -0500 Subject: [PATCH 21/53] Added posterior predictive sampling methods to BART and BCF in python --- R/posterior_transformation.R | 56 +++++++----- demo/debug/bart_predict_debug.py | 22 +++++ demo/debug/bcf_predict_debug.py | 22 +++++ stochtree/bart.py | 136 +++++++++++++++++++++++++++- stochtree/bcf.py | 146 ++++++++++++++++++++++++++++++- 5 files changed, 357 insertions(+), 25 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 32e05b67..94c15ed4 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -1,14 +1,14 @@ #' Sample from the posterior predictive distribution for outcomes modeled by BCF #' #' @param model_object A fitted BCF model object of class `bcfmodel`. -#' @param covariates (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). -#' @param treatment (Optional) A vector or matrix of treatment assignments. Required if the requested term is `"y_hat"` (overall predictions). -#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the requested term is `"y_hat"` (overall predictions) and the underlying model depends on user-provided propensities. +#' @param covariates A matrix or data frame of covariates. +#' @param treatment A vector or matrix of treatment assignments. +#' @param propensity (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. #' @param rfx_group_ids (Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects. #' @param rfx_basis (Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects. -#' @param num_draws (Optional) The number of samples to draw from the likelihood, for each draw of the posterior, in computing intervals. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws). +#' @param num_draws_per_sample (Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws). #' -#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples). #' #' @export #' @examples @@ -30,9 +30,9 @@ sample_bcf_posterior_predictive <- function( propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, - num_draws = NULL + num_draws_per_sample = NULL ) { - # Check the provided model object and requested term + # Check the provided model object check_model_is_valid(model_object) # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) @@ -123,7 +123,7 @@ sample_bcf_posterior_predictive <- function( } } - # Compute posterior predictive samples + # Compute posterior samples bcf_preds <- predict( model_object, X = covariates, @@ -132,8 +132,11 @@ sample_bcf_posterior_predictive <- function( rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, type = "posterior", - terms = c("all") + terms = c("all"), + scale = "linear" ) + + # Compute outcome mean and variance for every posterior draw has_rfx <- model_object$model_params$has_rfx has_variance_forest <- model_object$model_params$include_variance_forest samples_global_variance <- model_object$model_params$sample_sigma2_global @@ -155,16 +158,20 @@ sample_bcf_posterior_predictive <- function( ppd_variance <- model_object$model_params$initial_sigma2 } } - if (is.null(num_draws)) { + + # Sample from the posterior predictive distribution + if (is.null(num_draws_per_sample)) { ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier( num_posterior_draws, num_observations ) } else { - ppd_draw_multiplier <- num_draws + ppd_draw_multiplier <- num_draws_per_sample } num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance)) + + # Reshape data if (ppd_draw_multiplier > 1) { ppd_array <- array( ppd_vector, @@ -177,6 +184,7 @@ sample_bcf_posterior_predictive <- function( ) } + # Binarize outcomes for probit models if (is_probit) { ppd_array <- (ppd_array > 0.0) * 1 } @@ -187,13 +195,13 @@ sample_bcf_posterior_predictive <- function( #' Sample from the posterior predictive distribution for outcomes modeled by BART #' #' @param model_object A fitted BART model object of class `bartmodel`. -#' @param covariates A matrix or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). +#' @param covariates A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). #' @param basis A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. #' @param rfx_group_ids A vector of group IDs for random effects model. Required if the BART model includes random effects. #' @param rfx_basis A matrix of bases for random effects model. Required if the BART model includes random effects. -#' @param num_draws The number of posterior predictive samples to draw in computing intervals. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). +#' @param num_draws_per_sample The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). #' -#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +#' @returns Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples). #' #' @export #' @examples @@ -211,9 +219,9 @@ sample_bart_posterior_predictive <- function( basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, - num_draws = NULL + num_draws_per_sample = NULL ) { - # Check the provided model object and requested term + # Check the provided model object check_model_is_valid(model_object) # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) @@ -276,7 +284,7 @@ sample_bart_posterior_predictive <- function( } } - # Compute posterior predictive samples + # Compute posterior samples bart_preds <- predict( model_object, covariates = covariates, @@ -284,8 +292,11 @@ sample_bart_posterior_predictive <- function( rfx_group_ids = rfx_group_ids, rfx_basis = rfx_basis, type = "posterior", - terms = c("all") + terms = c("all"), + scale = "linear" ) + + # Compute outcome mean and variance for every posterior draw has_mean_term <- (model_object$model_params$include_mean_forest || model_object$model_params$has_rfx) has_variance_forest <- model_object$model_params$include_variance_forest @@ -312,16 +323,20 @@ sample_bart_posterior_predictive <- function( ppd_variance <- model_object$model_params$sigma2_init } } - if (is.null(num_draws)) { + + # Sample from the posterior predictive distribution + if (is.null(num_draws_per_sample)) { ppd_draw_multiplier <- posterior_predictive_heuristic_multiplier( num_posterior_draws, num_observations ) } else { - ppd_draw_multiplier <- num_draws + ppd_draw_multiplier <- num_draws_per_sample } num_ppd_draws <- ppd_draw_multiplier * num_posterior_draws * num_observations ppd_vector <- rnorm(num_ppd_draws, ppd_mean, sqrt(ppd_variance)) + + # Reshape data if (ppd_draw_multiplier > 1) { ppd_array <- array( ppd_vector, @@ -334,6 +349,7 @@ sample_bart_posterior_predictive <- function( ) } + # Binarize outcomes for probit models if (is_probit) { ppd_array <- (ppd_array > 0.0) * 1 } diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py index 04197039..15b8184d 100644 --- a/demo/debug/bart_predict_debug.py +++ b/demo/debug/bart_predict_debug.py @@ -80,3 +80,25 @@ (intervals["y_hat"]["lower"] <= f_X_test) & (f_X_test <= intervals["y_hat"]["upper"]) ) print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}") + +# Sample from the posterior predictive distribution +bart_ppd_samples = bart_model.sample_posterior_predictive( + covariates = X_test, num_draws_per_sample = 10 +) + +# Plot PPD mean vs actual +ppd_mean = np.mean(bart_ppd_samples, axis=(0, 2)) +plt.clf() +plt.scatter(ppd_mean, y_test, color="blue") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Posterior Predictive Mean Comparison") +plt.show() + +# Check coverage of posterior predictive distribution +ppd_intervals = np.percentile(bart_ppd_samples, [2.5, 97.5], axis=(0, 2)) +ppd_coverage = np.mean( + (ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :]) +) +print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}") diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index bb95e93a..dfdf22f4 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -116,3 +116,25 @@ (intervals["mu_hat"]["lower"] <= mu_test) & (mu_test <= intervals["mu_hat"]["upper"]) ) print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}") + +# Sample from the posterior predictive distribution +bcf_ppd_samples = bcf_model.sample_posterior_predictive( + covariates = X_test, treatment = Z_test, propensity = pi_test, num_draws_per_sample = 10 +) + +# Plot PPD mean vs actual +ppd_mean = np.mean(bcf_ppd_samples, axis=(0, 2)) +plt.clf() +plt.scatter(ppd_mean, y_test, color="blue") +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.xlabel("Predicted") +plt.ylabel("Actual") +plt.title("Posterior Predictive Mean Comparison") +plt.show() + +# Check coverage of posterior predictive distribution +ppd_intervals = np.percentile(bcf_ppd_samples, [2.5, 97.5], axis=(0, 2)) +ppd_coverage = np.mean( + (ppd_intervals[0, :] <= y_test) & (y_test <= ppd_intervals[1, :]) +) +print(f"Coverage of 95% posterior predictive interval for Y: {ppd_coverage:.3f}") diff --git a/stochtree/bart.py b/stochtree/bart.py index b7456e64..76b95691 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1884,10 +1884,12 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale dict A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. """ - # Check the provided model object and requested term - self.is_sampled() + # Check the provided model object and requested terms + if not self.is_sampled(): + raise ValueError("Model has not yet been sampled") for term in terms: - self.has_term(term) + if not self.has_term(term): + warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.") # Handle mean function scale if not isinstance(scale, str): @@ -1966,6 +1968,134 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale predictions, 1, level=level ) + def sample_posterior_predictive(self, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None) -> np.array: + """ + Sample from the posterior predictive distribution for outcomes modeled by BART + + Parameters + ---------- + covariates : np.array, optional + An array or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). + basis : np.array, optional + An array of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. + rfx_group_ids : np.array, optional + An array of group IDs for random effects. Required if the BART model includes random effects. + rfx_basis : np.array, optional + An array of basis function evaluations for random effects. Required if the BART model includes random effects. + num_draws_per_sample : int, optional + The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). + + Returns + ------- + np.array + A matrix of posterior predictive samples. If `num_draws = 1`. + """ + # Check the provided model object + if not self.is_sampled(): + raise ValueError("Model has not yet been sampled") + + # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) + is_probit = self.probit_outcome_model + + # Check that all the necessary inputs were provided for interval computation + needs_covariates = self.include_mean_forest + if needs_covariates: + if covariates is None: + raise ValueError( + "'covariates' must be provided in order to compute the requested intervals" + ) + if not isinstance(covariates, np.ndarray) and not isinstance( + covariates, pd.DataFrame + ): + raise ValueError("'covariates' must be a matrix or data frame") + needs_basis = needs_covariates and self.has_basis + if needs_basis: + if basis is None: + raise ValueError( + "'basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(basis, np.ndarray): + raise ValueError("'basis' must be a numpy array") + if basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'basis' must have the same number of rows as 'covariates'" + ) + needs_rfx_data = self.has_rfx + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior predictive samples + bart_preds = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms="all") + + # Compute outcome mean and variance for posterior predictive distribution + has_mean_term = (self.include_mean_forest or self.has_rfx) + has_variance_forest = self.include_variance_forest + samples_global_variance = self.sample_sigma2_global + num_posterior_draws = self.num_samples + num_observations = covariates.shape[0] + if has_mean_term: + ppd_mean = bart_preds["y_hat"] + else: + ppd_mean = 0. + if has_variance_forest: + ppd_variance = bart_preds["variance_forest_predictions"] + else: + if samples_global_variance: + ppd_variance = np.tile( + self.global_var_samples, + (num_observations, 1) + ) + else: + ppd_variance = self.sigma2_init + + # Sample from the posterior predictive distribution + if num_draws_per_sample is None: + ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier( + num_posterior_draws, + num_observations + ) + else: + ppd_draw_multiplier = num_draws_per_sample + if ppd_draw_multiplier > 1: + ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1)) + ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1)) + ppd_array = np.random.normal( + loc = ppd_mean, + scale = np.sqrt(ppd_variance), + size = (ppd_draw_multiplier, num_observations, num_posterior_draws) + ) + else: + ppd_array = np.random.normal( + loc = ppd_mean, + scale = np.sqrt(ppd_variance), + size = (num_observations, num_posterior_draws) + ) + + # Binarize outcome for probit models + if is_probit: + ppd_array = (ppd_array > 0.0) * 1 + + return ppd_array + def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 19f0c296..9dea8ecb 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2557,9 +2557,11 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale A dict containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a dict with intervals for each term is returned. """ # Check the provided model object and requested term - self.is_sampled() + if not self.is_sampled(): + raise ValueError("Model has not yet been sampled") for term in terms: - self.has_term(term) + if not self.has_term(term): + warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.") # Handle mean function scale if not isinstance(scale, str): @@ -2652,6 +2654,146 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale predictions, 1, level=level ) + def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None) -> np.array: + """ + Sample from the posterior predictive distribution for outcomes modeled by BART + + Parameters + ---------- + covariates : np.array + An array or data frame of covariates. + treatment : np.array + An array of treatment assignments. + propensity : np.array, optional + Optional array of propensity scores. Required if the underlying model depends on user-provided propensities. + rfx_group_ids : np.array, optional + Optional vector of group IDs for random effects. Required if the requested term includes random effects. + rfx_basis : np.array, optional + Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. + num_draws_per_sample : int, optional + The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). + + Returns + ------- + np.array + A matrix of posterior predictive samples. If `num_draws = 1`. + """ + # Check the provided model object + if not self.is_sampled(): + raise ValueError("Model has not yet been sampled") + + # Determine whether the outcome is continuous (Gaussian) or binary (probit-link) + is_probit = self.probit_outcome_model + + # Check that all the necessary inputs were provided for interval computation + needs_covariates = True + if needs_covariates: + if covariates is None: + raise ValueError( + "'covariates' must be provided in order to compute the requested intervals" + ) + if not isinstance(covariates, np.ndarray) and not isinstance( + covariates, pd.DataFrame + ): + raise ValueError("'covariates' must be a matrix or data frame") + needs_treatment = needs_covariates + if needs_treatment: + if treatment is None: + raise ValueError( + "'treatment' must be provided in order to compute the requested intervals" + ) + if not isinstance(treatment, np.ndarray): + raise ValueError("'treatment' must be a numpy array") + if treatment.shape[0] != covariates.shape[0]: + raise ValueError( + "'treatment' must have the same number of rows as 'covariates'" + ) + uses_propensity = self.propensity_covariate != "none" + internal_propensity_model = self.internal_propensity_model + needs_propensity = needs_covariates and uses_propensity and not internal_propensity_model + if needs_propensity: + if propensity is None: + raise ValueError( + "'propensity' must be provided in order to compute the requested intervals" + ) + if not isinstance(propensity, np.ndarray): + raise ValueError("'propensity' must be a numpy array") + if propensity.shape[0] != covariates.shape[0]: + raise ValueError( + "'propensity' must have the same number of rows as 'covariates'" + ) + needs_rfx_data = self.has_rfx + if needs_rfx_data: + if rfx_group_ids is None: + raise ValueError( + "'rfx_group_ids' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_group_ids, np.ndarray): + raise ValueError("'rfx_group_ids' must be a numpy array") + if rfx_group_ids.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_group_ids' must have the same length as the number of rows in 'covariates'" + ) + if rfx_basis is None: + raise ValueError( + "'rfx_basis' must be provided in order to compute the requested intervals" + ) + if not isinstance(rfx_basis, np.ndarray): + raise ValueError("'rfx_basis' must be a numpy array") + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError( + "'rfx_basis' must have the same number of rows as 'covariates'" + ) + + # Compute posterior predictive samples + bcf_preds = self.predict(X=covariates, Z=treatment, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms="all", scale="linear") + + # Compute outcome mean and variance for posterior predictive distribution + has_variance_forest = self.include_variance_forest + samples_global_variance = self.sample_sigma2_global + num_posterior_draws = self.num_samples + num_observations = covariates.shape[0] + ppd_mean = bcf_preds["y_hat"] + if has_variance_forest: + ppd_variance = bcf_preds["variance_forest_predictions"] + else: + if samples_global_variance: + ppd_variance = np.tile( + self.global_var_samples, + (num_observations, 1) + ) + else: + ppd_variance = self.sigma2_init + + # Sample from the posterior predictive distribution + if num_draws_per_sample is None: + ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier( + num_posterior_draws, + num_observations + ) + else: + ppd_draw_multiplier = num_draws_per_sample + if ppd_draw_multiplier > 1: + ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1)) + ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1)) + ppd_array = np.random.normal( + loc = ppd_mean, + scale = np.sqrt(ppd_variance), + size = (ppd_draw_multiplier, num_observations, num_posterior_draws) + ) + else: + ppd_array = np.random.normal( + loc = ppd_mean, + scale = np.sqrt(ppd_variance), + size = (num_observations, num_posterior_draws) + ) + + # Binarize outcome for probit models + if is_probit: + ppd_array = (ppd_array > 0.0) * 1 + + return ppd_array + def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or From 85888646ce93b57dcb8262cced67b51d0c361116 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Oct 2025 01:05:37 -0500 Subject: [PATCH 22/53] Reformatted python code --- demo/debug/bart_predict_debug.py | 26 +++---- demo/debug/bcf_predict_debug.py | 51 ++++++------ stochtree/bart.py | 105 ++++++++++++++++--------- stochtree/bcf.py | 128 ++++++++++++++++++++----------- 4 files changed, 189 insertions(+), 121 deletions(-) diff --git a/demo/debug/bart_predict_debug.py b/demo/debug/bart_predict_debug.py index 15b8184d..d66b1110 100644 --- a/demo/debug/bart_predict_debug.py +++ b/demo/debug/bart_predict_debug.py @@ -47,21 +47,15 @@ # # Check several predict approaches bart_preds = bart_model.predict(covariates=X_test) -y_hat_posterior_test = bart_model.predict(covariates=X_test)['y_hat'] -y_hat_mean_test = bart_model.predict( - covariates=X_test, - type = "mean", - terms = ["y_hat"] -) +y_hat_posterior_test = bart_model.predict(covariates=X_test)["y_hat"] +y_hat_mean_test = bart_model.predict(covariates=X_test, type="mean", terms=["y_hat"]) y_hat_test = bart_model.predict( - covariates=X_test, - type = "mean", - terms = ["rfx", "variance"] + covariates=X_test, type="mean", terms=["rfx", "variance"] ) # Plot predicted versus actual plt.scatter(y_hat_mean_test, y_test, color="black") -plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) plt.xlabel("Predicted") plt.ylabel("Actual") plt.title("Y hat") @@ -69,28 +63,26 @@ # Compute posterior interval intervals = bart_model.compute_posterior_interval( - terms = "all", - scale = "linear", - level = 0.95, - covariates = X_test + terms="all", scale="linear", level=0.95, covariates=X_test ) # Check coverage mean_coverage = np.mean( - (intervals["y_hat"]["lower"] <= f_X_test) & (f_X_test <= intervals["y_hat"]["upper"]) + (intervals["y_hat"]["lower"] <= f_X_test) + & (f_X_test <= intervals["y_hat"]["upper"]) ) print(f"Coverage of 95% posterior interval for f(X): {mean_coverage:.3f}") # Sample from the posterior predictive distribution bart_ppd_samples = bart_model.sample_posterior_predictive( - covariates = X_test, num_draws_per_sample = 10 + covariates=X_test, num_draws_per_sample=10 ) # Plot PPD mean vs actual ppd_mean = np.mean(bart_ppd_samples, axis=(0, 2)) plt.clf() plt.scatter(ppd_mean, y_test, color="blue") -plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) plt.xlabel("Predicted") plt.ylabel("Actual") plt.title("Posterior Predictive Mean Comparison") diff --git a/demo/debug/bcf_predict_debug.py b/demo/debug/bcf_predict_debug.py index dfdf22f4..2257684a 100644 --- a/demo/debug/bcf_predict_debug.py +++ b/demo/debug/bcf_predict_debug.py @@ -12,9 +12,9 @@ n = 1000 p = 5 X = rng.normal(loc=0.0, scale=1.0, size=(n, p)) -mu_X = X[:,0] -tau_X = 0.25 * X[:,1] -pi_X = norm.cdf(0.5 * X[:,1]) +mu_X = X[:, 0] +tau_X = 0.25 * X[:, 1] +pi_X = norm.cdf(0.5 * X[:, 1]) Z = rng.binomial(n=1, p=pi_X, size=(n,)) E_XZ = mu_X + tau_X * Z snr = 2.0 @@ -54,27 +54,23 @@ # Check several predict approaches bcf_preds = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test) -y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)['y_hat'] +y_hat_posterior_test = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test)[ + "y_hat" +] y_hat_mean_test = bcf_model.predict( - X=X_test, Z=Z_test, propensity=pi_test, - type = "mean", - terms = ["y_hat"] + X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["y_hat"] ) tau_hat_mean_test = bcf_model.predict( - X=X_test, Z=Z_test, propensity=pi_test, - type = "mean", - terms = ["cate"] + X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["cate"] ) # Check that this raises a warning y_hat_test = bcf_model.predict( - X=X_test, Z=Z_test, propensity=pi_test, - type = "mean", - terms = ["rfx", "variance"] + X=X_test, Z=Z_test, propensity=pi_test, type="mean", terms=["rfx", "variance"] ) # Plot predicted versus actual plt.scatter(y_hat_mean_test, y_test, color="black") -plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) plt.xlabel("Predicted") plt.ylabel("Actual") plt.title("Y hat") @@ -83,7 +79,7 @@ # Plot predicted versus actual plt.clf() plt.scatter(tau_hat_mean_test, tau_test, color="black") -plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) plt.xlabel("Predicted") plt.ylabel("Actual") plt.title("CATE function") @@ -91,42 +87,45 @@ # Compute posterior interval intervals = bcf_model.compute_posterior_interval( - terms = "all", - scale = "linear", - level = 0.95, - covariates = X_test, - treatment = Z_test, - propensity = pi_test + terms="all", + scale="linear", + level=0.95, + covariates=X_test, + treatment=Z_test, + propensity=pi_test, ) # Check coverage of E[Y | X, Z] mean_coverage = np.mean( - (intervals["y_hat"]["lower"] <= E_XZ_test) & (E_XZ_test <= intervals["y_hat"]["upper"]) + (intervals["y_hat"]["lower"] <= E_XZ_test) + & (E_XZ_test <= intervals["y_hat"]["upper"]) ) print(f"Coverage of 95% posterior interval for E[Y|X,Z]: {mean_coverage:.3f}") # Check coverage of tau(X) tau_coverage = np.mean( - (intervals["tau_hat"]["lower"] <= tau_test) & (tau_test <= intervals["tau_hat"]["upper"]) + (intervals["tau_hat"]["lower"] <= tau_test) + & (tau_test <= intervals["tau_hat"]["upper"]) ) print(f"Coverage of 95% posterior interval for tau(X): {tau_coverage:.3f}") # Check coverage of mu(X) mu_coverage = np.mean( - (intervals["mu_hat"]["lower"] <= mu_test) & (mu_test <= intervals["mu_hat"]["upper"]) + (intervals["mu_hat"]["lower"] <= mu_test) + & (mu_test <= intervals["mu_hat"]["upper"]) ) print(f"Coverage of 95% posterior interval for mu(X): {mu_coverage:.3f}") # Sample from the posterior predictive distribution bcf_ppd_samples = bcf_model.sample_posterior_predictive( - covariates = X_test, treatment = Z_test, propensity = pi_test, num_draws_per_sample = 10 + covariates=X_test, treatment=Z_test, propensity=pi_test, num_draws_per_sample=10 ) # Plot PPD mean vs actual ppd_mean = np.mean(bcf_ppd_samples, axis=(0, 2)) plt.clf() plt.scatter(ppd_mean, y_test, color="blue") -plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3,3))) +plt.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3))) plt.xlabel("Predicted") plt.ylabel("Actual") plt.title("Posterior Predictive Mean Comparison") diff --git a/stochtree/bart.py b/stochtree/bart.py index 76b95691..3fddac92 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -24,8 +24,8 @@ _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag, - _posterior_predictive_heuristic_multiplier, - _summarize_interval + _posterior_predictive_heuristic_multiplier, + _summarize_interval, ) @@ -1858,7 +1858,16 @@ def predict( result["variance_forest_predictions"] = None return result - def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict: + def compute_posterior_interval( + self, + terms: Union[list[str], str] = "all", + scale: str = "linear", + level: float = 0.95, + covariates: np.array = None, + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + ) -> dict: """ Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. @@ -1889,7 +1898,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale raise ValueError("Model has not yet been sampled") for term in terms: if not self.has_term(term): - warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.") + warnings.warn( + f"Term {term} was not sampled in this model and its intervals will not be returned." + ) # Handle mean function scale if not isinstance(scale, str): @@ -1903,8 +1914,14 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) # Check that all the necessary inputs were provided for interval computation - needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.include_mean_forest - needs_covariates = ("mean_forest" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate + needs_covariates_intermediate = ( + ("y_hat" in terms) or ("all" in terms) + ) and self.include_mean_forest + needs_covariates = ( + ("mean_forest" in terms) + or ("variance_forest" in terms) + or needs_covariates_intermediate + ) if needs_covariates: if covariates is None: raise ValueError( @@ -1926,7 +1943,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale raise ValueError( "'basis' must have the same number of rows as 'covariates'" ) - needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx + needs_rfx_data_intermediate = ( + ("y_hat" in terms) or ("all" in terms) + ) and self.has_rfx needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate if needs_rfx_data: if rfx_group_ids is None: @@ -1951,7 +1970,15 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) # Compute posterior matrices for the requested model terms - predictions = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale) + predictions = self.predict( + covariates=covariates, + basis=basis, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms=terms, + scale=scale, + ) has_multiple_terms = True if isinstance(predictions, dict) else False # Compute posterior intervals @@ -1964,11 +1991,16 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) return result else: - return _summarize_interval( - predictions, 1, level=level - ) - - def sample_posterior_predictive(self, covariates: np.array = None, basis: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None) -> np.array: + return _summarize_interval(predictions, 1, level=level) + + def sample_posterior_predictive( + self, + covariates: np.array = None, + basis: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + num_draws_per_sample: int = None, + ) -> np.array: """ Sample from the posterior predictive distribution for outcomes modeled by BART @@ -1984,7 +2016,7 @@ def sample_posterior_predictive(self, covariates: np.array = None, basis: np.arr An array of basis function evaluations for random effects. Required if the BART model includes random effects. num_draws_per_sample : int, optional The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). - + Returns ------- np.array @@ -2044,10 +2076,17 @@ def sample_posterior_predictive(self, covariates: np.array = None, basis: np.arr ) # Compute posterior predictive samples - bart_preds = self.predict(covariates=covariates, basis=basis, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms="all") + bart_preds = self.predict( + covariates=covariates, + basis=basis, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms="all", + ) # Compute outcome mean and variance for posterior predictive distribution - has_mean_term = (self.include_mean_forest or self.has_rfx) + has_mean_term = self.include_mean_forest or self.has_rfx has_variance_forest = self.include_variance_forest samples_global_variance = self.sample_sigma2_global num_posterior_draws = self.num_samples @@ -2055,23 +2094,19 @@ def sample_posterior_predictive(self, covariates: np.array = None, basis: np.arr if has_mean_term: ppd_mean = bart_preds["y_hat"] else: - ppd_mean = 0. + ppd_mean = 0.0 if has_variance_forest: ppd_variance = bart_preds["variance_forest_predictions"] else: if samples_global_variance: - ppd_variance = np.tile( - self.global_var_samples, - (num_observations, 1) - ) + ppd_variance = np.tile(self.global_var_samples, (num_observations, 1)) else: ppd_variance = self.sigma2_init - + # Sample from the posterior predictive distribution if num_draws_per_sample is None: ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier( - num_posterior_draws, - num_observations + num_posterior_draws, num_observations ) else: ppd_draw_multiplier = num_draws_per_sample @@ -2079,23 +2114,23 @@ def sample_posterior_predictive(self, covariates: np.array = None, basis: np.arr ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1)) ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1)) ppd_array = np.random.normal( - loc = ppd_mean, - scale = np.sqrt(ppd_variance), - size = (ppd_draw_multiplier, num_observations, num_posterior_draws) + loc=ppd_mean, + scale=np.sqrt(ppd_variance), + size=(ppd_draw_multiplier, num_observations, num_posterior_draws), ) else: ppd_array = np.random.normal( - loc = ppd_mean, - scale = np.sqrt(ppd_variance), - size = (num_observations, num_posterior_draws) + loc=ppd_mean, + scale=np.sqrt(ppd_variance), + size=(num_observations, num_posterior_draws), ) - + # Binarize outcome for probit models if is_probit: ppd_array = (ppd_array > 0.0) * 1 - + return ppd_array - + def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or @@ -2381,7 +2416,7 @@ def is_sampled(self) -> bool: `True` if a BART model has been sampled, `False` otherwise """ return self.sampled - + def has_term(self, term: str) -> bool: """ Whether or not a model includes a term. @@ -2390,7 +2425,7 @@ def has_term(self, term: str) -> bool: ---------- term : str Character string specifying the model term to check for. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. - + Returns ------- bool diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 9dea8ecb..e361b76e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -24,8 +24,8 @@ _expand_dims_1d, _expand_dims_2d, _expand_dims_2d_diag, - _posterior_predictive_heuristic_multiplier, - _summarize_interval + _posterior_predictive_heuristic_multiplier, + _summarize_interval, ) @@ -2239,7 +2239,7 @@ def predict( rfx_basis: np.array = None, type: str = "posterior", terms: Union[list[str], str] = "all", - scale: str = "linear" + scale: str = "linear", ) -> dict: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. @@ -2278,7 +2278,7 @@ def predict( "scale cannot be 'probability' for models not fit with a probit outcome model" ) probability_scale = scale == "probability" - + # Handle prediction type if not isinstance(type, str): raise ValueError("type must be a string") @@ -2418,7 +2418,7 @@ def predict( forest_dataset_test.dataset_cpp ) mu_x = mu_raw * self.y_std + self.y_bar - + # Treatment effect forest predictions if predict_tau_forest or predict_tau_forest_intermediate: tau_raw = self.forest_container_tau.forest_container_cpp.PredictRaw( @@ -2452,11 +2452,10 @@ def predict( y_hat = mu_x + treatment_term elif predict_y_hat and has_rfx: y_hat = rfx_preds - - needs_mean_term_preds = predict_y_hat or \ - predict_mu_forest or \ - predict_tau_forest or \ - predict_rfx + + needs_mean_term_preds = ( + predict_y_hat or predict_mu_forest or predict_tau_forest or predict_rfx + ) if needs_mean_term_preds: if probability_scale: if has_rfx: @@ -2492,7 +2491,7 @@ def predict( rfx_preds = np.mean(rfx_preds, axis=1) if predict_y_hat: y_hat = np.mean(y_hat, axis=1) - + if predict_count == 1: if predict_y_hat: return y_hat @@ -2527,8 +2526,18 @@ def predict( else: result["variance_forest_predictions"] = None return result - - def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale: str = "linear", level: float = 0.95, covariates: np.array = None, treatment: np.array = None, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None) -> dict: + + def compute_posterior_interval( + self, + terms: Union[list[str], str] = "all", + scale: str = "linear", + level: float = 0.95, + covariates: np.array = None, + treatment: np.array = None, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + ) -> dict: """ Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. @@ -2561,7 +2570,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale raise ValueError("Model has not yet been sampled") for term in terms: if not self.has_term(term): - warnings.warn(f"Term {term} was not sampled in this model and its intervals will not be returned.") + warnings.warn( + f"Term {term} was not sampled in this model and its intervals will not be returned." + ) # Handle mean function scale if not isinstance(scale, str): @@ -2575,8 +2586,13 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) # Check that all the necessary inputs were provided for interval computation - needs_covariates_intermediate = (("y_hat" in terms) or ("all" in terms)) - needs_covariates = ("prognostic_function" in terms) or ("cate" in terms) or ("variance_forest" in terms) or needs_covariates_intermediate + needs_covariates_intermediate = ("y_hat" in terms) or ("all" in terms) + needs_covariates = ( + ("prognostic_function" in terms) + or ("cate" in terms) + or ("variance_forest" in terms) + or needs_covariates_intermediate + ) if needs_covariates: if covariates is None: raise ValueError( @@ -2600,7 +2616,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) uses_propensity = self.propensity_covariate is not "none" internal_propensity_model = self.internal_propensity_model - needs_propensity = needs_covariates and uses_propensity and not internal_propensity_model + needs_propensity = ( + needs_covariates and uses_propensity and not internal_propensity_model + ) if needs_propensity: if propensity is None: raise ValueError( @@ -2612,7 +2630,9 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale raise ValueError( "'propensity' must have the same number of rows as 'covariates'" ) - needs_rfx_data_intermediate = (("y_hat" in terms) or ("all" in terms)) and self.has_rfx + needs_rfx_data_intermediate = ( + ("y_hat" in terms) or ("all" in terms) + ) and self.has_rfx needs_rfx_data = ("rfx" in terms) or needs_rfx_data_intermediate if needs_rfx_data: if rfx_group_ids is None: @@ -2637,7 +2657,16 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) # Compute posterior matrices for the requested model terms - predictions = self.predict(X=covariates, Z=treatment, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms=terms, scale=scale) + predictions = self.predict( + X=covariates, + Z=treatment, + propensity=propensity, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms=terms, + scale=scale, + ) has_multiple_terms = True if isinstance(predictions, dict) else False # Compute posterior intervals @@ -2650,11 +2679,17 @@ def compute_posterior_interval(self, terms: Union[list[str], str] = "all", scale ) return result else: - return _summarize_interval( - predictions, 1, level=level - ) + return _summarize_interval(predictions, 1, level=level) - def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, propensity: np.array = None, rfx_group_ids: np.array = None, rfx_basis: np.array = None, num_draws_per_sample: int = None) -> np.array: + def sample_posterior_predictive( + self, + covariates: np.array, + treatment: np.array, + propensity: np.array = None, + rfx_group_ids: np.array = None, + rfx_basis: np.array = None, + num_draws_per_sample: int = None, + ) -> np.array: """ Sample from the posterior predictive distribution for outcomes modeled by BART @@ -2672,7 +2707,7 @@ def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. num_draws_per_sample : int, optional The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws). - + Returns ------- np.array @@ -2710,7 +2745,9 @@ def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, ) uses_propensity = self.propensity_covariate != "none" internal_propensity_model = self.internal_propensity_model - needs_propensity = needs_covariates and uses_propensity and not internal_propensity_model + needs_propensity = ( + needs_covariates and uses_propensity and not internal_propensity_model + ) if needs_propensity: if propensity is None: raise ValueError( @@ -2746,7 +2783,16 @@ def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, ) # Compute posterior predictive samples - bcf_preds = self.predict(X=covariates, Z=treatment, propensity=propensity, rfx_group_ids=rfx_group_ids, rfx_basis=rfx_basis, type="posterior", terms="all", scale="linear") + bcf_preds = self.predict( + X=covariates, + Z=treatment, + propensity=propensity, + rfx_group_ids=rfx_group_ids, + rfx_basis=rfx_basis, + type="posterior", + terms="all", + scale="linear", + ) # Compute outcome mean and variance for posterior predictive distribution has_variance_forest = self.include_variance_forest @@ -2758,18 +2804,14 @@ def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, ppd_variance = bcf_preds["variance_forest_predictions"] else: if samples_global_variance: - ppd_variance = np.tile( - self.global_var_samples, - (num_observations, 1) - ) + ppd_variance = np.tile(self.global_var_samples, (num_observations, 1)) else: ppd_variance = self.sigma2_init - + # Sample from the posterior predictive distribution if num_draws_per_sample is None: ppd_draw_multiplier = _posterior_predictive_heuristic_multiplier( - num_posterior_draws, - num_observations + num_posterior_draws, num_observations ) else: ppd_draw_multiplier = num_draws_per_sample @@ -2777,23 +2819,23 @@ def sample_posterior_predictive(self, covariates: np.array, treatment: np.array, ppd_mean = np.tile(ppd_mean, (ppd_draw_multiplier, 1, 1)) ppd_variance = np.tile(ppd_variance, (ppd_draw_multiplier, 1, 1)) ppd_array = np.random.normal( - loc = ppd_mean, - scale = np.sqrt(ppd_variance), - size = (ppd_draw_multiplier, num_observations, num_posterior_draws) + loc=ppd_mean, + scale=np.sqrt(ppd_variance), + size=(ppd_draw_multiplier, num_observations, num_posterior_draws), ) else: ppd_array = np.random.normal( - loc = ppd_mean, - scale = np.sqrt(ppd_variance), - size = (num_observations, num_posterior_draws) + loc=ppd_mean, + scale=np.sqrt(ppd_variance), + size=(num_observations, num_posterior_draws), ) - + # Binarize outcome for probit models if is_probit: ppd_array = (ppd_array > 0.0) * 1 - + return ppd_array - + def to_json(self) -> str: """ Converts a sampled BART model to JSON string representation (which can then be saved to a file or @@ -3144,7 +3186,7 @@ def has_term(self, term: str) -> bool: ---------- term : str Character string specifying the model term to check for. Options for BCF models are `"prognostic_function"`, `"cate"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. - + Returns ------- bool From f57b26888a2cf3cfb0432584c3e1f80fb7649ed4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Oct 2025 09:57:39 -0500 Subject: [PATCH 23/53] Updated R package and unit tests --- man/compute_bart_posterior_interval.Rd | 4 ++-- man/sample_bart_posterior_predictive.Rd | 8 +++---- man/sample_bcf_posterior_predictive.Rd | 12 +++++------ test/R/testthat/test-bcf.R | 28 +++++++++++++------------ test/R/testthat/test-predict.R | 26 ++++++++++++----------- 5 files changed, 41 insertions(+), 37 deletions(-) diff --git a/man/compute_bart_posterior_interval.Rd b/man/compute_bart_posterior_interval.Rd index 2ae16f24..59a0a895 100644 --- a/man/compute_bart_posterior_interval.Rd +++ b/man/compute_bart_posterior_interval.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/posterior_transformation.R \name{compute_bart_posterior_interval} \alias{compute_bart_posterior_interval} -\title{Compute posterior credible intervals for BART model terms} +\title{Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.} \usage{ compute_bart_posterior_interval( model_object, @@ -36,7 +36,7 @@ compute_bart_posterior_interval( A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned. } \description{ -This function computes posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. +Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions. } \examples{ n <- 100 diff --git a/man/sample_bart_posterior_predictive.Rd b/man/sample_bart_posterior_predictive.Rd index d8989143..5bce8442 100644 --- a/man/sample_bart_posterior_predictive.Rd +++ b/man/sample_bart_posterior_predictive.Rd @@ -10,13 +10,13 @@ sample_bart_posterior_predictive( basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, - num_draws = NULL + num_draws_per_sample = NULL ) } \arguments{ \item{model_object}{A fitted BART model object of class \code{bartmodel}.} -\item{covariates}{A matrix or data frame of covariates at which to compute the intervals. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} +\item{covariates}{A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest).} \item{basis}{A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models.} @@ -24,10 +24,10 @@ sample_bart_posterior_predictive( \item{rfx_basis}{A matrix of bases for random effects model. Required if the BART model includes random effects.} -\item{num_draws}{The number of posterior predictive samples to draw in computing intervals. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).} +\item{num_draws_per_sample}{The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in a BART model (i.e. if the BART model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure intervals are based on at least 1000 posterior predictive draws).} } \value{ -Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples). } \description{ Sample from the posterior predictive distribution for outcomes modeled by BART diff --git a/man/sample_bcf_posterior_predictive.Rd b/man/sample_bcf_posterior_predictive.Rd index d4e827db..0c77d7c1 100644 --- a/man/sample_bcf_posterior_predictive.Rd +++ b/man/sample_bcf_posterior_predictive.Rd @@ -11,26 +11,26 @@ sample_bcf_posterior_predictive( propensity = NULL, rfx_group_ids = NULL, rfx_basis = NULL, - num_draws = NULL + num_draws_per_sample = NULL ) } \arguments{ \item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} -\item{covariates}{(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).} +\item{covariates}{A matrix or data frame of covariates.} -\item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} +\item{treatment}{A vector or matrix of treatment assignments.} -\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the requested term is \code{"y_hat"} (overall predictions) and the underlying model depends on user-provided propensities.} +\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} \item{rfx_group_ids}{(Optional) A vector of group IDs for random effects model. Required if the BCF model includes random effects.} \item{rfx_basis}{(Optional) A matrix of bases for random effects model. Required if the BCF model includes random effects.} -\item{num_draws}{(Optional) The number of samples to draw from the likelihood, for each draw of the posterior, in computing intervals. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).} +\item{num_draws_per_sample}{(Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in a BCF model (i.e. if the BCF model has >1000 draws, we use 1 draw from the likelihood per sample, otherwise we upsample to ensure at least 1000 posterior predictive draws).} } \value{ -Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws) if num_draws > 1, otherwise (num_observations, num_posterior_samples). +Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples). } \description{ Sample from the posterior predictive distribution for outcomes modeled by BCF diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index a94fe3f3..8f0d69f0 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -615,19 +615,21 @@ test_that("BCF Predictions", { # Run a BCF model with only GFR general_params <- list(num_chains = 1, keep_every = 1) variance_forest_params <- list(num_trees = 50) - bcf_model <- bcf( - X_train = X_train, - y_train = y_train, - Z_train = Z_train, - propensity_train = pi_train, - X_test = X_test, - Z_test = Z_test, - propensity_test = pi_test, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10, - general_params = general_params, - variance_forest_params = variance_forest_params + expect_warning( + bcf_model <- bcf( + X_train = X_train, + y_train = y_train, + Z_train = Z_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + general_params = general_params, + variance_forest_params = variance_forest_params + ) ) # Check that cached predictions agree with results of predict() function diff --git a/test/R/testthat/test-predict.R b/test/R/testthat/test-predict.R index e64c85a6..bdd9d66b 100644 --- a/test/R/testthat/test-predict.R +++ b/test/R/testthat/test-predict.R @@ -375,18 +375,20 @@ test_that("BCF predictions with pre-summarization", { # Fit a heteroskedastic BCF model var_params <- list(num_trees = 20) - het_bcf_model <- bcf( - X_train = X_train, - Z_train = Z_train, - y_train = y_train, - propensity_train = pi_train, - X_test = X_test, - Z_test = Z_test, - propensity_test = pi_test, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10, - variance_forest_params = var_params + expect_warning( + het_bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 10, + variance_forest_params = var_params + ) ) # Check that the default predict method returns a list From 83d553dc6a7705bdfcaf5b88dc0b07eaff5ee6b7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Oct 2025 14:02:46 -0500 Subject: [PATCH 24/53] Updated indentation --- R/bart.R | 62 ++++++++++++++++++++++++++++---------------------------- R/bcf.R | 30 +++++++++++++-------------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/R/bart.R b/R/bart.R index 6e25bbb6..9853a0b8 100644 --- a/R/bart.R +++ b/R/bart.R @@ -250,16 +250,16 @@ bart <- function( drop_vars_variance <- variance_forest_params_updated$drop_vars num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - # Set a function-scoped RNG if user provided a random seed - custom_rng <- random_seed >= 0 - if (custom_rng) { - # Store original global environment RNG state - original_global_seed <- .Random.seed - # Set new seed and store associated RNG state - set.seed(random_seed) - function_scoped_seed <- .Random.seed - } - + # Set a function-scoped RNG if user provided a random seed + custom_rng <- random_seed >= 0 + if (custom_rng) { + # Store original global environment RNG state + original_global_seed <- .Random.seed + # Set new seed and store associated RNG state + set.seed(random_seed) + function_scoped_seed <- .Random.seed + } + # Check if there are enough GFR samples to seed num_chains samplers if (num_gfr > 0) { if (num_chains > num_gfr) { @@ -1748,27 +1748,27 @@ bart <- function( } class(result) <- "bartmodel" - # Clean up classes with external pointers to C++ data structures - if (include_mean_forest) { - rm(forest_model_mean) - } - if (include_variance_forest) { - rm(forest_model_variance) - } - rm(forest_dataset_train) - if (has_test) { - rm(forest_dataset_test) - } - if (has_rfx) { - rm(rfx_dataset_train, rfx_tracker_train, rfx_model) - } - rm(outcome_train) - rm(rng) - - # Restore global RNG state if user provided a random seed - if (custom_rng) { - .Random.seed <- original_global_seed - } + # Clean up classes with external pointers to C++ data structures + if (include_mean_forest) { + rm(forest_model_mean) + } + if (include_variance_forest) { + rm(forest_model_variance) + } + rm(forest_dataset_train) + if (has_test) { + rm(forest_dataset_test) + } + if (has_rfx) { + rm(rfx_dataset_train, rfx_tracker_train, rfx_model) + } + rm(outcome_train) + rm(rng) + + # Restore global RNG state if user provided a random seed + if (custom_rng) { + .Random.seed <- original_global_seed + } return(result) } diff --git a/R/bcf.R b/R/bcf.R index 4be368a8..68eb63db 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -340,16 +340,16 @@ bcf <- function( keep_vars_variance <- variance_forest_params_updated$keep_vars drop_vars_variance <- variance_forest_params_updated$drop_vars num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample - - # Set a function-scoped RNG if user provided a random seed - custom_rng <- random_seed >= 0 - if (custom_rng) { - # Store original global environment RNG state - original_global_seed <- .Random.seed - # Set new seed and store associated RNG state - set.seed(random_seed) - function_scoped_seed <- .Random.seed - } + + # Set a function-scoped RNG if user provided a random seed + custom_rng <- random_seed >= 0 + if (custom_rng) { + # Store original global environment RNG state + original_global_seed <- .Random.seed + # Set new seed and store associated RNG state + set.seed(random_seed) + function_scoped_seed <- .Random.seed + } # Check if there are enough GFR samples to seed num_chains samplers if (num_gfr > 0) { @@ -2558,11 +2558,11 @@ bcf <- function( result[["bart_propensity_model"]] = bart_model_propensity } class(result) <- "bcfmodel" - - # Restore global RNG state if user provided a random seed - if (custom_rng) { - .Random.seed <- original_global_seed - } + + # Restore global RNG state if user provided a random seed + if (custom_rng) { + .Random.seed <- original_global_seed + } return(result) } From f11a0f4ca331862ef06a1e4652159e2e287e5398 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 17 Oct 2025 17:54:53 -0500 Subject: [PATCH 25/53] Fixed several prediction bugs in R for BART / BCF --- R/bart.R | 26 +++++++++++++++----------- R/bcf.R | 10 ++++++---- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/R/bart.R b/R/bart.R index 9853a0b8..4b45bbb0 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1696,7 +1696,7 @@ bart <- function( "is_leaf_constant" = is_leaf_constant, "leaf_regression" = leaf_regression, "requires_basis" = requires_basis, - "num_covariates" = ncol(X_train), + "num_covariates" = num_cov_orig, "num_basis" = ifelse( is.null(leaf_basis_train), 0, @@ -1896,12 +1896,10 @@ predict.bartmodel <- function( ) } - # Preprocess covariates + # Check that covariates are matrix or data frame if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { stop("covariates must be a matrix or dataframe") } - train_set_metadata <- object$train_set_metadata - X <- preprocessPredictionData(covariates, train_set_metadata) # Convert all input data to matrices if not already converted if ((is.null(dim(leaf_basis))) && (!is.null(leaf_basis))) { @@ -1915,11 +1913,13 @@ predict.bartmodel <- function( if ((object$model_params$requires_basis) && (is.null(leaf_basis))) { stop("Basis (leaf_basis) must be provided for this model") } - if ((!is.null(leaf_basis)) && (nrow(X) != nrow(leaf_basis))) { - stop("X and leaf_basis must have the same number of rows") + if ((!is.null(leaf_basis)) && (nrow(covariates) != nrow(leaf_basis))) { + stop("covariates and leaf_basis must have the same number of rows") } - if (object$model_params$num_covariates != ncol(X)) { - stop("X and leaf_basis must have the same number of rows") + if (object$model_params$num_covariates != ncol(covariates)) { + stop( + "covariates must contain the same number of columns as the BART model's training dataset" + ) } if ((predict_rfx) && (is.null(rfx_group_ids))) { stop( @@ -1938,6 +1938,10 @@ predict.bartmodel <- function( ) } + # Preprocess covariates + train_set_metadata <- object$train_set_metadata + covariates <- preprocessPredictionData(covariates, train_set_metadata) + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE if (predict_rfx) { @@ -1956,14 +1960,14 @@ predict.bartmodel <- function( # Produce basis for the "intercept-only" random effects case if ((predict_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) + rfx_basis <- matrix(rep(1, nrow(covariates)), ncol = 1) } # Create prediction dataset if (!is.null(leaf_basis)) { - prediction_dataset <- createForestDataset(X, leaf_basis) + prediction_dataset <- createForestDataset(covariates, leaf_basis) } else { - prediction_dataset <- createForestDataset(X) + prediction_dataset <- createForestDataset(covariates) } # Compute variance forest predictions diff --git a/R/bcf.R b/R/bcf.R index 68eb63db..a68862dd 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2708,12 +2708,10 @@ predict.bcfmodel <- function( predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) - # Preprocess covariates + # Make sure covariates are matrix or data frame if ((!is.data.frame(X)) && (!is.matrix(X))) { stop("X must be a matrix or dataframe") } - train_set_metadata <- object$train_set_metadata - X <- preprocessPredictionData(X, train_set_metadata) # Convert all input data to matrices if not already converted if ((is.null(dim(Z))) && (!is.null(Z))) { @@ -2762,6 +2760,10 @@ predict.bcfmodel <- function( ) } + # Preprocess covariates + train_set_metadata <- object$train_set_metadata + X <- preprocessPredictionData(X, train_set_metadata) + # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...) has_rfx <- FALSE if (!is.null(rfx_group_ids)) { @@ -2846,7 +2848,7 @@ predict.bcfmodel <- function( } # Compute rfx predictions - if (predict_rfx) { + if (predict_rfx || predict_rfx_intermediate) { rfx_predictions <- object$rfx_samples$predict( rfx_group_ids, rfx_basis From c03b2ea097fa043dd7f63aa818a4a82d2fd39d83 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 21 Oct 2025 18:23:50 -0500 Subject: [PATCH 26/53] Fixed bug in probit + RFX and added contrast computation function for BCF --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/bart.R | 30 ++- R/bcf.R | 32 ++- R/posterior_transformation.R | 250 ++++++++++++++++++++ man/compute_bcf_posterior_interval.Rd | 2 +- man/compute_contrast_bcf_model.Rd | 130 ++++++++++ tools/debug/bcf_cate_debug.R | 328 ++++++++++++++++++++++++++ 8 files changed, 756 insertions(+), 19 deletions(-) create mode 100644 man/compute_contrast_bcf_model.Rd create mode 100644 tools/debug/bcf_cate_debug.R diff --git a/DESCRIPTION b/DESCRIPTION index 15e4c0c0..1e1b9806 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Description: Flexible stochastic tree ensemble software. License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.2 +RoxygenNote: 7.3.3 LinkingTo: cpp11, BH Suggests: diff --git a/NAMESPACE b/NAMESPACE index 930789db..e70dec9f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,7 @@ export(computeForestLeafVariances) export(computeForestMaxLeafIndex) export(compute_bart_posterior_interval) export(compute_bcf_posterior_interval) +export(compute_contrast_bcf_model) export(convertPreprocessorToJson) export(createBARTModelFromCombinedJson) export(createBARTModelFromCombinedJsonString) diff --git a/R/bart.R b/R/bart.R index 4b45bbb0..00cf41f4 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1106,18 +1106,25 @@ bart <- function( if (include_mean_forest) { if (probit_outcome_model) { # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict( + outcome_pred <- active_forest_mean$predict( forest_dataset_train ) - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) # Update outcome - outcome_train$update_data(resid_train - forest_pred) + outcome_train$update_data(resid_train - outcome_pred) } # Sample mean forest @@ -1467,18 +1474,25 @@ bart <- function( if (include_mean_forest) { if (probit_outcome_model) { # Sample latent probit variable, z | - - forest_pred <- active_forest_mean$predict( + outcome_pred <- active_forest_mean$predict( forest_dataset_train ) - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) # Update outcome - outcome_train$update_data(resid_train - forest_pred) + outcome_train$update_data(resid_train - outcome_pred) } forest_model_mean$sample_one_iteration( diff --git a/R/bcf.R b/R/bcf.R index a68862dd..1f0a992a 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -1460,16 +1460,23 @@ bcf <- function( tau_forest_pred <- active_forest_tau$predict( forest_dataset_train ) - forest_pred <- mu_forest_pred + tau_forest_pred - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] + outcome_pred <- mu_forest_pred + tau_forest_pred + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) # Update outcome - outcome_train$update_data(resid_train - forest_pred) + outcome_train$update_data(resid_train - outcome_pred) } # Sample the prognostic forest @@ -2057,16 +2064,23 @@ bcf <- function( tau_forest_pred <- active_forest_tau$predict( forest_dataset_train ) - forest_pred <- mu_forest_pred + tau_forest_pred - mu0 <- forest_pred[y_train == 0] - mu1 <- forest_pred[y_train == 1] + outcome_pred <- mu_forest_pred + tau_forest_pred + if (has_rfx) { + rfx_pred <- rfx_model$predict( + rfx_dataset_train, + rfx_tracker_train + ) + outcome_pred <- outcome_pred + rfx_pred + } + mu0 <- outcome_pred[y_train == 0] + mu1 <- outcome_pred[y_train == 1] u0 <- runif(sum(y_train == 0), 0, pnorm(0 - mu0)) u1 <- runif(sum(y_train == 1), pnorm(0 - mu1), 1) resid_train[y_train == 0] <- mu0 + qnorm(u0) resid_train[y_train == 1] <- mu1 + qnorm(u1) # Update outcome - outcome_train$update_data(resid_train - forest_pred) + outcome_train$update_data(resid_train - outcome_pred) } # Sample the prognostic forest @@ -2771,7 +2785,7 @@ predict.bcfmodel <- function( group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) if (sum(is.na(group_ids_factor)) > 0) { stop( - "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train" + "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids" ) } rfx_group_ids <- as.integer(group_ids_factor) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 94c15ed4..eb0eb3a1 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -1,3 +1,253 @@ +#' Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference. +#' For simple BCF models with binary treatment, this will yield the same prediction as requesting `terms = "cate"` +#' in the `predict.bcfmodel` function. For more general models, such as models with continuous / multivariate treatments or +#' an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute a +#' any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two term +#' contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend of the +#' contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" terminology of a classic +#' two-arm experiment. We mirror the function calls and terminology of the `predict.bcfmodel` function, labeling each prediction +#' data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the +#' control prediction. +#' +#' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs. +#' @param X_0 Covariates used for prediction in the "control" case. +#' @param X_1 Covariates used for prediction in the "treatment" case. +#' @param Z_0 Treatments used for prediction in the "control" case. +#' @param Z_1 Treatments used for prediction in the "treatment" case. +#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. +#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. +#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects +#' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation +#' for group labels that were not in the training set. +#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects +#' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation +#' for group labels that were not in the training set. +#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. +#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". +#' @param ... (Optional) Other prediction parameters. +#' +#' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. +#' @export +#' +#' @examples +#' n <- 500 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' mu_x <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +#' ) +#' pi_x <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) +#' ) +#' tau_x <- ( +#' ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + +#' ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + +#' ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + +#' ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) +#' ) +#' Z <- rbinom(n, 1, pi_x) +#' noise_sd <- 1 +#' y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' pi_test <- pi_x[test_inds] +#' pi_train <- pi_x[train_inds] +#' Z_test <- Z[test_inds] +#' Z_train <- Z[train_inds] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' mu_test <- mu_x[test_inds] +#' mu_train <- mu_x[train_inds] +#' tau_test <- tau_x[test_inds] +#' tau_train <- tau_x[train_inds] +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train, num_gfr = 10, +#' num_burnin = 0, num_mcmc = 10) +#' preds <- compute_posterior_contrast_bcf_model( +#' bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test), +#' propensity_0 = pi_test, propensity_1 = pi_test +#' ) +compute_contrast_bcf_model <- function( + object, + X_0, + X_1, + Z_0, + Z_1, + propensity_0 = NULL, + propensity_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear", + ... +) { + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + + # Handle prediction type + if (!is.character(type)) { + stop("type must be a string or character vector") + } + if (!(type %in% c("mean", "posterior"))) { + stop("type must either be 'mean' or 'posterior") + } + predict_mean <- type == "mean" + + # Make sure covariates are matrix or data frame + if ((!is.data.frame(X_0)) && (!is.matrix(X_0))) { + stop("X_0 must be a matrix or dataframe") + } + if ((!is.data.frame(X_1)) && (!is.matrix(X_1))) { + stop("X_1 must be a matrix or dataframe") + } + + # Convert all input data to matrices if not already converted + if ((is.null(dim(Z_0))) && (!is.null(Z_0))) { + Z_0 <- as.matrix(as.numeric(Z_0)) + } + if ((is.null(dim(Z_1))) && (!is.null(Z_1))) { + Z_1 <- as.matrix(as.numeric(Z_1)) + } + if ((is.null(dim(propensity_0))) && (!is.null(propensity_0))) { + propensity_0 <- as.matrix(propensity_0) + } + if ((is.null(dim(propensity_1))) && (!is.null(propensity_1))) { + propensity_1 <- as.matrix(propensity_1) + } + if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) { + rfx_basis_0 <- as.matrix(rfx_basis_0) + } + if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) { + rfx_basis_1 <- as.matrix(rfx_basis_1) + } + + # Data checks + if ( + (object$model_params$propensity_covariate != "none") && + ((is.null(propensity_0)) || + (is.null(propensity_1))) + ) { + if (!object$model_params$internal_propensity_model) { + stop("propensity_0 and propensity_1 must be provided for this model") + } + } + if (nrow(X_0) != nrow(Z_0)) { + stop("X_0 and Z_0 must have the same number of rows") + } + if (nrow(X_1) != nrow(Z_1)) { + stop("X_1 and Z_1 must have the same number of rows") + } + if (object$model_params$num_covariates != ncol(X_0)) { + stop( + "X_0 and must have the same number of columns as the covariates used to train the model" + ) + } + if (object$model_params$num_covariates != ncol(X_1)) { + stop( + "X_1 and must have the same number of columns as the covariates used to train the model" + ) + } + if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_0))) { + stop( + "Random effect group labels (rfx_group_ids_0) must be provided for this model" + ) + } + if ((object$model_params$has_rfx) && (is.null(rfx_group_ids_1))) { + stop( + "Random effect group labels (rfx_group_ids_1) must be provided for this model" + ) + } + if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_0))) { + stop("Random effects basis (rfx_basis_0) must be provided for this model") + } + if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis_1))) { + stop("Random effects basis (rfx_basis_1) must be provided for this model") + } + if ( + (object$model_params$num_rfx_basis > 0) && + (ncol(rfx_basis_0) != object$model_params$num_rfx_basis) + ) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) + } + if ( + (object$model_params$num_rfx_basis > 0) && + (ncol(rfx_basis_1) != object$model_params$num_rfx_basis) + ) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) + } + + # Predict for the control arm + control_preds <- predict( + object = object, + X = X_0, + Z = Z_0, + propensity = propensity_0, + rfx_group_ids = rfx_group_ids_0, + rfx_basis = rfx_basis_0, + type = "posterior", + term = "y_hat", + scale = "linear" + ) + + # Predict for the treatment arm + treatment_preds <- predict( + object = object, + X = X_1, + Z = Z_1, + propensity = propensity_1, + rfx_group_ids = rfx_group_ids_1, + rfx_basis = rfx_basis_1, + type = "posterior", + term = "y_hat", + scale = "linear" + ) + + # Transform to probability scale if requested + if (probability_scale) { + treatment_preds <- pnorm(treatment_preds) + control_preds <- pnorm(control_preds) + } + + # Compute and return contrast + if (predict_mean) { + return(rowMeans(treatment_preds - control_preds)) + } else { + return(treatment_preds - control_preds) + } +} + + #' Sample from the posterior predictive distribution for outcomes modeled by BCF #' #' @param model_object A fitted BCF model object of class `bcfmodel`. diff --git a/man/compute_bcf_posterior_interval.Rd b/man/compute_bcf_posterior_interval.Rd index 1ff4836d..880226e6 100644 --- a/man/compute_bcf_posterior_interval.Rd +++ b/man/compute_bcf_posterior_interval.Rd @@ -29,7 +29,7 @@ compute_bcf_posterior_interval( \item{treatment}{(Optional) A vector or matrix of treatment assignments. Required if the requested term is \code{"y_hat"} (overall predictions).} -\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the requested term is \code{"y_hat"} (overall predictions) and the underlying model depends on user-provided propensities.} +\item{propensity}{(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities.} \item{rfx_group_ids}{An optional vector of group IDs for random effects. Required if the requested term includes random effects.} diff --git a/man/compute_contrast_bcf_model.Rd b/man/compute_contrast_bcf_model.Rd new file mode 100644 index 00000000..fa1e5b30 --- /dev/null +++ b/man/compute_contrast_bcf_model.Rd @@ -0,0 +1,130 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_contrast_bcf_model} +\alias{compute_contrast_bcf_model} +\title{Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference. +For simple BCF models with binary treatment, this will yield the same prediction as requesting \code{terms = "cate"} +in the \code{predict.bcfmodel} function. For more general models, such as models with continuous / multivariate treatments or +an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute a +any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two term +contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or \code{Y0} term and the minuend of the +contrast as the \code{Y1} term, though the requested contrast need not match the "control vs treatment" terminology of a classic +two-arm experiment. We mirror the function calls and terminology of the \code{predict.bcfmodel} function, labeling each prediction +data term with a \code{1} to denote its contribution to the treatment prediction of a contrast and \code{0} to denote inclusion in the +control prediction.} +\usage{ +compute_contrast_bcf_model( + object, + X_0, + X_1, + Z_0, + Z_1, + propensity_0 = NULL, + propensity_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear", + ... +) +} +\arguments{ +\item{object}{Object of type \code{bcfmodel} containing draws of a Bayesian causal forest model and associated sampling outputs.} + +\item{X_0}{Covariates used for prediction in the "control" case.} + +\item{X_1}{Covariates used for prediction in the "treatment" case.} + +\item{Z_0}{Treatments used for prediction in the "control" case.} + +\item{Z_1}{Treatments used for prediction in the "treatment" case.} + +\item{propensity_0}{(Optional) Propensities used for prediction in the "control" case.} + +\item{propensity_1}{(Optional) Propensities used for prediction in the "treatment" case.} + +\item{rfx_group_ids_0}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation +for group labels that were not in the training set.} + +\item{rfx_group_ids_1}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation +for group labels that were not in the training set.} + +\item{rfx_basis_0}{(Optional) Test set basis for used for prediction from an additive random effects model in the "control" case.} + +\item{rfx_basis_1}{(Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case.} + +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} + +\item{...}{(Optional) Other prediction parameters.} +} +\value{ +List of prediction matrices or single prediction matrix / vector, depending on the terms requested. +} +\description{ +Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference. +For simple BCF models with binary treatment, this will yield the same prediction as requesting \code{terms = "cate"} +in the \code{predict.bcfmodel} function. For more general models, such as models with continuous / multivariate treatments or +an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute a +any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two term +contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or \code{Y0} term and the minuend of the +contrast as the \code{Y1} term, though the requested contrast need not match the "control vs treatment" terminology of a classic +two-arm experiment. We mirror the function calls and terminology of the \code{predict.bcfmodel} function, labeling each prediction +data term with a \code{1} to denote its contribution to the treatment prediction of a contrast and \code{0} to denote inclusion in the +control prediction. +} +\examples{ +n <- 500 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +mu_x <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) +) +pi_x <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) +) +tau_x <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) +) +Z <- rbinom(n, 1, pi_x) +noise_sd <- 1 +y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, num_gfr = 10, + num_burnin = 0, num_mcmc = 10) +preds <- compute_posterior_contrast_bcf_model( + bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test), + propensity_0 = pi_test, propensity_1 = pi_test +) +} diff --git a/tools/debug/bcf_cate_debug.R b/tools/debug/bcf_cate_debug.R new file mode 100644 index 00000000..8f5bcb9b --- /dev/null +++ b/tools/debug/bcf_cate_debug.R @@ -0,0 +1,328 @@ +# Demo of CATE computation function for BCF + +# Load library +library(stochtree) + +# Generate data +n <- 500 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +E_XZ <- mu_x + Z * tau_x +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] + +# Fit BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Compute CATE posterior +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via predict +tau_hat_posterior_test_comparison <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + type = "posterior", + terms = "cate", + scale = "linear" +) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) + +# Generate data for a BCF model with random effects +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +group_ids <- rep(c(1, 2), n %/% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE) +rfx_basis <- cbind(1, Z) +rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) +E_XZ <- mu_x + Z * tau_x + rfx_term +snr <- 2 +y <- E_XZ + rnorm(n, 0, 1) * (sd(E_XZ) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +group_ids_test <- group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds, ] +rfx_basis_train <- rfx_basis[train_inds, ] + +# Fit BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Compute CATE posterior +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + rfx_group_ids_0 = group_ids_test, + rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = cbind(1, rep(0, n_test)), + rfx_basis_1 = cbind(1, rep(1, n_test)), + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via predict +tau_forest_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "cate", + scale = "linear" +) +rfx_term_posterior_treated <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(1, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +rfx_term_posterior_control <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(0, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +tau_hat_posterior_test_comparison <- (tau_forest_posterior_test + + (rfx_term_posterior_treated - rfx_term_posterior_control)) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) + +# Generate data for a probit BCF model with random effects +X <- matrix(rnorm(n * p), ncol = p) +mu_x <- X[, 1] +tau_x <- 0.25 * X[, 2] +pi_x <- pnorm(0.5 * X[, 1]) +Z <- rbinom(n, 1, pi_x) +group_ids <- rep(c(1, 2), n %/% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE) +rfx_basis <- cbind(1, Z) +rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) +E_XZ <- mu_x + Z * tau_x + rfx_term +# E_XZ <- mu_x + Z * tau_x + rfx_term +W <- E_XZ + rnorm(n, 0, 1) +y <- as.numeric(W > 0) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +pi_test <- pi_x[test_inds] +pi_train <- pi_x[train_inds] +Z_test <- Z[test_inds] +Z_train <- Z[train_inds] +W_test <- W[test_inds] +W_train <- W[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +mu_test <- mu_x[test_inds] +mu_train <- mu_x[train_inds] +tau_test <- tau_x[test_inds] +tau_train <- tau_x[train_inds] +group_ids_test <- group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds, ] +rfx_basis_train <- rfx_basis[train_inds, ] + +# Fit BCF model +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = group_ids_test, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + general_params = list(probit_outcome_model = T) +) + +# Compute CATE posterior on probability scale +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + rfx_group_ids_0 = group_ids_test, + rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = cbind(1, rep(0, n_test)), + rfx_basis_1 = cbind(1, rep(1, n_test)), + type = "posterior", + scale = "probability" +) + +# Compute the same quantity via predict +mu_forest_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "prognostic_function", + scale = "linear" +) +tau_forest_posterior_test <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "cate", + scale = "linear" +) +rfx_term_posterior_treated <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(1, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +rfx_term_posterior_control <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = cbind(1, rep(0, n_test)), + type = "posterior", + terms = "rfx", + scale = "linear" +) +w_hat_0 <- mu_forest_posterior_test + + rfx_term_posterior_control +w_hat_1 <- mu_forest_posterior_test + + tau_forest_posterior_test + + rfx_term_posterior_treated +tau_hat_posterior_test_comparison <- pnorm(w_hat_1) - pnorm(w_hat_0) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) From 2cd37750ee6a8f09d41319315f252289e42c6423 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 21 Oct 2025 19:23:22 -0500 Subject: [PATCH 27/53] Added contrast function for BART --- NAMESPACE | 1 + R/bcf.R | 2 +- R/posterior_transformation.R | 235 +++++++++++++++++++++++++++-- man/compute_contrast_bart_model.Rd | 96 ++++++++++++ man/compute_contrast_bcf_model.Rd | 22 +-- man/predict.bcfmodel.Rd | 2 +- tools/debug/bart_contrast_debug.R | 171 +++++++++++++++++++++ 7 files changed, 505 insertions(+), 24 deletions(-) create mode 100644 man/compute_contrast_bart_model.Rd create mode 100644 tools/debug/bart_contrast_debug.R diff --git a/NAMESPACE b/NAMESPACE index e70dec9f..5303244d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,7 @@ export(computeForestLeafVariances) export(computeForestMaxLeafIndex) export(compute_bart_posterior_interval) export(compute_bcf_posterior_interval) +export(compute_contrast_bart_model) export(compute_contrast_bcf_model) export(convertPreprocessorToJson) export(createBARTModelFromCombinedJson) diff --git a/R/bcf.R b/R/bcf.R index 1f0a992a..dd6d9b95 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2591,7 +2591,7 @@ bcf <- function( #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. -#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". #' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index eb0eb3a1..50f2c494 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -10,21 +10,21 @@ #' control prediction. #' #' @param object Object of type `bcfmodel` containing draws of a Bayesian causal forest model and associated sampling outputs. -#' @param X_0 Covariates used for prediction in the "control" case. -#' @param X_1 Covariates used for prediction in the "treatment" case. -#' @param Z_0 Treatments used for prediction in the "control" case. -#' @param Z_1 Treatments used for prediction in the "treatment" case. -#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. -#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. +#' @param X_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. +#' @param X_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. +#' @param Z_0 Treatments used for prediction in the "control" case. Must be a matrix or vector. +#' @param Z_1 Treatments used for prediction in the "treatment" case. Must be a matrix or vector. +#' @param propensity_0 (Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector. +#' @param propensity_1 (Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector. #' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects #' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation -#' for group labels that were not in the training set. +#' for group labels that were not in the training set. Must be a vector. #' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects #' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation -#' for group labels that were not in the training set. -#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. -#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. -#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". +#' for group labels that were not in the training set. Must be a vector. +#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. +#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. #' @@ -247,6 +247,219 @@ compute_contrast_bcf_model <- function( } } +#' Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference. +#' This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects +#' bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or +#' `Y0` term and the minuend of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" +#' terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the `predict.bartmodel` +#' function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and +#' `0` to denote inclusion in the control prediction. +#' +#' Only valid when there is either a mean forest or a random effects term in the BART model. +#' +#' @param object Object of type `bart` containing draws of a regression forest and associated sampling outputs. +#' @param covariates_0 Covariates used for prediction in the "control" case. Must be a matrix or dataframe. +#' @param covariates_1 Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. +#' @param leaf_basis_0 (Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: `NULL`. +#' @param leaf_basis_1 (Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: `NULL`. +#' @param rfx_group_ids_0 (Optional) Test set group labels used for prediction from an additive random effects +#' model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation +#' for group labels that were not in the training set. Must be a vector. +#' @param rfx_group_ids_1 (Optional) Test set group labels used for prediction from an additive random effects +#' model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation +#' for group labels that were not in the training set. Must be a vector. +#' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. +#' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. +#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". +#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". +#' +#' @return Contrast matrix or vector, depending on whether type = "mean" or "posterior". +#' @export +#' +#' @examples +#' n <- 100 +#' p <- 5 +#' X <- matrix(runif(n*p), ncol = p) +#' W <- matrix(runif(n*1), ncol = 1) +#' f_XW <- ( +#' ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + +#' ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + +#' ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + +#' ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) +#' ) +#' noise_sd <- 1 +#' y <- f_XW + rnorm(n, 0, noise_sd) +#' test_set_pct <- 0.2 +#' n_test <- round(test_set_pct*n) +#' n_train <- n - n_test +#' test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +#' train_inds <- (1:n)[!((1:n) %in% test_inds)] +#' X_test <- X[test_inds,] +#' X_train <- X[train_inds,] +#' W_test <- W[test_inds,] +#' W_train <- W[train_inds,] +#' y_test <- y[test_inds] +#' y_train <- y[train_inds] +#' bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train, +#' num_gfr = 10, num_burnin = 0, num_mcmc = 10) +#' contrast <- compute_contrast_bart_model( +#' bart_model, +#' covariates_0 = X_test, +#' covariates_1 = X_test, +#' leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), +#' leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), +#' type = "posterior", +#' scale = "linear" +#' ) +compute_contrast_bart_model <- function( + object, + covariates_0, + covariates_1, + leaf_basis_0 = NULL, + leaf_basis_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear", + ... +) { + # Handle mean function scale + if (!is.character(scale)) { + stop("scale must be a string or character vector") + } + if (!(scale %in% c("linear", "probability"))) { + stop("scale must either be 'linear' or 'probability'") + } + is_probit <- object$model_params$probit_outcome_model + if ((scale == "probability") && (!is_probit)) { + stop( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + } + probability_scale <- scale == "probability" + + # Handle prediction type + if (!is.character(type)) { + stop("type must be a string or character vector") + } + if (!(type %in% c("mean", "posterior"))) { + stop("type must either be 'mean' or 'posterior'") + } + predict_mean <- type == "mean" + + # Handle prediction terms + has_mean_forest <- object$model_params$include_mean_forest + has_rfx <- object$model_params$has_rfx + if ((!has_mean_forest) && (!has_rfx)) { + stop( + "Model must have either or both of mean forest or random effects terms to compute the requested contrast." + ) + } + + # Check that covariates are matrix or data frame + if ((!is.data.frame(covariates_0)) && (!is.matrix(covariates_0))) { + stop("covariates_0 must be a matrix or dataframe") + } + if ((!is.data.frame(covariates_1)) && (!is.matrix(covariates_1))) { + stop("covariates_1 must be a matrix or dataframe") + } + + # Convert all input data to matrices if not already converted + if ((is.null(dim(leaf_basis_0))) && (!is.null(leaf_basis_0))) { + leaf_basis_0 <- as.matrix(leaf_basis_0) + } + if ((is.null(dim(leaf_basis_1))) && (!is.null(leaf_basis_1))) { + leaf_basis_1 <- as.matrix(leaf_basis_1) + } + if ((is.null(dim(rfx_basis_0))) && (!is.null(rfx_basis_0))) { + rfx_basis_0 <- as.matrix(rfx_basis_0) + } + if ((is.null(dim(rfx_basis_1))) && (!is.null(rfx_basis_1))) { + rfx_basis_1 <- as.matrix(rfx_basis_1) + } + + # Data checks + if ( + (object$model_params$requires_basis) && + (is.null(leaf_basis_0) || is.null(leaf_basis_1)) + ) { + stop("leaf_basis_0 and leaf_basis_1 must be provided for this model") + } + if ((!is.null(leaf_basis_0)) && (nrow(covariates_0) != nrow(leaf_basis_0))) { + stop("covariates_0 and leaf_basis_0 must have the same number of rows") + } + if ((!is.null(leaf_basis_1)) && (nrow(covariates_1) != nrow(leaf_basis_1))) { + stop("covariates_1 and leaf_basis_1 must have the same number of rows") + } + if (object$model_params$num_covariates != ncol(covariates_0)) { + stop( + "covariates_0 must contain the same number of columns as the BART model's training dataset" + ) + } + if (object$model_params$num_covariates != ncol(covariates_1)) { + stop( + "covariates_1 must contain the same number of columns as the BART model's training dataset" + ) + } + if ((has_rfx) && (is.null(rfx_group_ids_0) || is.null(rfx_group_ids_1))) { + stop( + "rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model" + ) + } + if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) { + stop( + "rfx_basis_0 and rfx_basis_1 must be provided for this model" + ) + } + if ( + (object$model_params$num_rfx_basis > 0) && + ((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) || + (ncol(rfx_basis_1) != object$model_params$num_rfx_basis)) + ) { + stop( + "rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model" + ) + } + + # Predict for the control arm + control_preds <- predict( + object = object, + covariates = covariates_0, + leaf_basis = leaf_basis_0, + rfx_group_ids = rfx_group_ids_0, + rfx_basis = rfx_basis_0, + type = "posterior", + term = "y_hat", + scale = "linear" + ) + + # Predict for the treatment arm + treatment_preds <- predict( + object = object, + covariates = covariates_1, + leaf_basis = leaf_basis_1, + rfx_group_ids = rfx_group_ids_1, + rfx_basis = rfx_basis_1, + type = "posterior", + term = "y_hat", + scale = "linear" + ) + + # Transform to probability scale if requested + if (probability_scale) { + treatment_preds <- pnorm(treatment_preds) + control_preds <- pnorm(control_preds) + } + + # Compute and return contrast + if (predict_mean) { + return(rowMeans(treatment_preds - control_preds)) + } else { + return(treatment_preds - control_preds) + } +} #' Sample from the posterior predictive distribution for outcomes modeled by BCF #' diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd new file mode 100644 index 00000000..99aa7eaf --- /dev/null +++ b/man/compute_contrast_bart_model.Rd @@ -0,0 +1,96 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/posterior_transformation.R +\name{compute_contrast_bart_model} +\alias{compute_contrast_bart_model} +\title{Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference. +This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects +bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or +\code{Y0} term and the minuend of the contrast as the \code{Y1} term, though the requested contrast need not match the "control vs treatment" +terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the \code{predict.bartmodel} +function, labeling each prediction data term with a \code{1} to denote its contribution to the treatment prediction of a contrast and +\code{0} to denote inclusion in the control prediction.} +\usage{ +compute_contrast_bart_model( + object, + covariates_0, + covariates_1, + leaf_basis_0 = NULL, + leaf_basis_1 = NULL, + rfx_group_ids_0 = NULL, + rfx_group_ids_1 = NULL, + rfx_basis_0 = NULL, + rfx_basis_1 = NULL, + type = "posterior", + scale = "linear", + ... +) +} +\arguments{ +\item{object}{Object of type \code{bart} containing draws of a regression forest and associated sampling outputs.} + +\item{covariates_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} + +\item{covariates_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} + +\item{leaf_basis_0}{(Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: \code{NULL}.} + +\item{leaf_basis_1}{(Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: \code{NULL}.} + +\item{rfx_group_ids_0}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation +for group labels that were not in the training set. Must be a vector.} + +\item{rfx_group_ids_1}{(Optional) Test set group labels used for prediction from an additive random effects +model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation +for group labels that were not in the training set. Must be a vector.} + +\item{rfx_basis_0}{(Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.} + +\item{rfx_basis_1}{(Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.} + +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} + +\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} +} +\value{ +Contrast matrix or vector, depending on whether type = "mean" or "posterior". +} +\description{ +Only valid when there is either a mean forest or a random effects term in the BART model. +} +\examples{ +n <- 100 +p <- 5 +X <- matrix(runif(n*p), ncol = p) +W <- matrix(runif(n*1), ncol = 1) +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) +) +noise_sd <- 1 +y <- f_XW + rnorm(n, 0, noise_sd) +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) \%in\% test_inds)] +X_test <- X[test_inds,] +X_train <- X[train_inds,] +W_test <- W[test_inds,] +W_train <- W[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] +bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train, + num_gfr = 10, num_burnin = 0, num_mcmc = 10) +contrast <- compute_contrast_bart_model( + bart_model, + covariates_0 = X_test, + covariates_1 = X_test, + leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), + leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), + type = "posterior", + scale = "linear" +) +} diff --git a/man/compute_contrast_bcf_model.Rd b/man/compute_contrast_bcf_model.Rd index fa1e5b30..082ffe04 100644 --- a/man/compute_contrast_bcf_model.Rd +++ b/man/compute_contrast_bcf_model.Rd @@ -33,31 +33,31 @@ compute_contrast_bcf_model( \arguments{ \item{object}{Object of type \code{bcfmodel} containing draws of a Bayesian causal forest model and associated sampling outputs.} -\item{X_0}{Covariates used for prediction in the "control" case.} +\item{X_0}{Covariates used for prediction in the "control" case. Must be a matrix or dataframe.} -\item{X_1}{Covariates used for prediction in the "treatment" case.} +\item{X_1}{Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe.} -\item{Z_0}{Treatments used for prediction in the "control" case.} +\item{Z_0}{Treatments used for prediction in the "control" case. Must be a matrix or vector.} -\item{Z_1}{Treatments used for prediction in the "treatment" case.} +\item{Z_1}{Treatments used for prediction in the "treatment" case. Must be a matrix or vector.} -\item{propensity_0}{(Optional) Propensities used for prediction in the "control" case.} +\item{propensity_0}{(Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector.} -\item{propensity_1}{(Optional) Propensities used for prediction in the "treatment" case.} +\item{propensity_1}{(Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector.} \item{rfx_group_ids_0}{(Optional) Test set group labels used for prediction from an additive random effects model in the "control" case. We do not currently support (but plan to in the near future), test set evaluation -for group labels that were not in the training set.} +for group labels that were not in the training set. Must be a vector.} \item{rfx_group_ids_1}{(Optional) Test set group labels used for prediction from an additive random effects model in the "treatment" case. We do not currently support (but plan to in the near future), test set evaluation -for group labels that were not in the training set.} +for group labels that were not in the training set. Must be a vector.} -\item{rfx_basis_0}{(Optional) Test set basis for used for prediction from an additive random effects model in the "control" case.} +\item{rfx_basis_0}{(Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector.} -\item{rfx_basis_1}{(Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case.} +\item{rfx_basis_1}{(Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.} -\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index 71f275f6..7e8d6e0a 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -32,7 +32,7 @@ that were not in the training set.} \item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} -\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} +\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} \item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} diff --git a/tools/debug/bart_contrast_debug.R b/tools/debug/bart_contrast_debug.R new file mode 100644 index 00000000..647d12b0 --- /dev/null +++ b/tools/debug/bart_contrast_debug.R @@ -0,0 +1,171 @@ +# Demo of CATE computation function for BCF + +# Load library +library(stochtree) + +# Generate data +n <- 500 +p <- 5 +X <- matrix(rnorm(n * p), ncol = p) +W <- matrix(rnorm(n * 1), ncol = 1) +# fmt: skip +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5 * W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5 * W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5 * W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5 * W[,1]) +) +E_Y <- f_XW +snr <- 2 +y <- E_Y + rnorm(n, 0, 1) * (sd(E_Y) / snr) + +# Train-test split +test_set_pct <- 0.2 +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +W_test <- W[test_inds] +W_train <- W[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Fit BART model +bart_model <- bart( + X_train = X_train, + leaf_basis_train = W_train, + y_train = y_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Compute contrast posterior +contrast_posterior_test <- compute_contrast_bart_model( + bart_model, + covariates_0 = X_test, + covariates_1 = X_test, + leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), + leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 <- predict( + bart_model, + covariates = X_test, + leaf_basis = matrix(0, nrow = n_test, ncol = 1), + type = "posterior", + term = "y_hat", + scale = "linear" +) +y_hat_posterior_test_1 <- predict( + bart_model, + covariates = X_test, + leaf_basis = matrix(1, nrow = n_test, ncol = 1), + type = "posterior", + term = "y_hat", + scale = "linear" +) +contrast_posterior_test_comparison <- (y_hat_posterior_test_1 - + y_hat_posterior_test_0) + +# Compare results +contrast_diff <- contrast_posterior_test_comparison - contrast_posterior_test +all( + abs(contrast_diff) < 0.001 +) + +# Generate data for a BCF model with random effects +X <- matrix(rnorm(n * p), ncol = p) +W <- matrix(rnorm(n * 1), ncol = 1) +# fmt: skip +f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5 * W[,1]) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5 * W[,1]) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5 * W[,1]) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5 * W[,1]) +) +group_ids <- rep(c(1, 2), n %/% 2) +rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow = 2, byrow = TRUE) +rfx_basis <- cbind(1, runif(n)) +rfx_term <- rowSums(rfx_coefs[group_ids, ] * rfx_basis) +E_Y <- f_XW + rfx_term +snr <- 2 +y <- E_Y + rnorm(n, 0, 1) * (sd(E_Y) / snr) + +# Train-test split +n_test <- round(test_set_pct * n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_test <- X[test_inds, ] +X_train <- X[train_inds, ] +W_test <- W[test_inds] +W_train <- W[train_inds] +y_test <- y[test_inds] +y_train <- y[train_inds] +group_ids_test <- group_ids[test_inds] +group_ids_train <- group_ids[train_inds] +rfx_basis_test <- rfx_basis[test_inds, ] +rfx_basis_train <- rfx_basis[train_inds, ] + +# Fit BART model +bart_model <- bart( + X_train = X_train, + leaf_basis_train = W_train, + y_train = y_train, + rfx_group_ids_train = group_ids_train, + rfx_basis_train = rfx_basis_train, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000 +) + +# Compute contrast posterior +contrast_posterior_test <- compute_contrast_bart_model( + bart_model, + covariates_0 = X_test, + covariates_1 = X_test, + leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1), + leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1), + rfx_group_ids_0 = group_ids_test, + rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = rfx_basis_test, + rfx_basis_1 = rfx_basis_test, + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 <- predict( + bart_model, + covariates = X_test, + leaf_basis = matrix(0, nrow = n_test, ncol = 1), + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + term = "y_hat", + scale = "linear" +) +y_hat_posterior_test_1 <- predict( + bart_model, + covariates = X_test, + leaf_basis = matrix(1, nrow = n_test, ncol = 1), + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + term = "y_hat", + scale = "linear" +) +contrast_posterior_test_comparison <- (y_hat_posterior_test_1 - + y_hat_posterior_test_0) + +# Compare results +contrast_diff <- contrast_posterior_test_comparison - contrast_posterior_test +all( + abs(contrast_diff) < 0.001 +) From e46202350528b1c7f6bbd10b3d319c736389907f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 00:25:43 -0500 Subject: [PATCH 28/53] Updated doc examples --- R/posterior_transformation.R | 4 ++-- man/compute_contrast_bart_model.Rd | 2 +- man/compute_contrast_bcf_model.Rd | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index 50f2c494..f8bad094 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -76,7 +76,7 @@ #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, #' propensity_train = pi_train, num_gfr = 10, #' num_burnin = 0, num_mcmc = 10) -#' preds <- compute_posterior_contrast_bcf_model( +#' tau_hat_test <- compute_contrast_bcf_model( #' bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test), #' propensity_0 = pi_test, propensity_1 = pi_test #' ) @@ -302,7 +302,7 @@ compute_contrast_bcf_model <- function( #' y_train <- y[train_inds] #' bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train, #' num_gfr = 10, num_burnin = 0, num_mcmc = 10) -#' contrast <- compute_contrast_bart_model( +#' contrast_test <- compute_contrast_bart_model( #' bart_model, #' covariates_0 = X_test, #' covariates_1 = X_test, diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd index 99aa7eaf..1e4a3ac7 100644 --- a/man/compute_contrast_bart_model.Rd +++ b/man/compute_contrast_bart_model.Rd @@ -84,7 +84,7 @@ y_test <- y[test_inds] y_train <- y[train_inds] bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) -contrast <- compute_contrast_bart_model( +contrast_test <- compute_contrast_bart_model( bart_model, covariates_0 = X_test, covariates_1 = X_test, diff --git a/man/compute_contrast_bcf_model.Rd b/man/compute_contrast_bcf_model.Rd index 082ffe04..aafde613 100644 --- a/man/compute_contrast_bcf_model.Rd +++ b/man/compute_contrast_bcf_model.Rd @@ -123,7 +123,7 @@ tau_train <- tau_x[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) -preds <- compute_posterior_contrast_bcf_model( +tau_hat_test <- compute_contrast_bcf_model( bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test), propensity_0 = pi_test, propensity_1 = pi_test ) From 8bc98e1684003fcc14835a90c8322f69c62fe21a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 01:42:34 -0500 Subject: [PATCH 29/53] Shifted RFX parameters to a new list in the R interface for BCF and BART --- R/bart.R | 46 ++++++++++++++++++++++++++----------- R/bcf.R | 40 ++++++++++++++++++++++++++------ man/bart.Rd | 19 +++++++++------ man/bcf.Rd | 13 ++++++++++- test/R/testthat/test-bart.R | 36 ++++++++++++++--------------- test/R/testthat/test-bcf.R | 32 +++++++++++++------------- 6 files changed, 124 insertions(+), 62 deletions(-) diff --git a/R/bart.R b/R/bart.R index 00cf41f4..b058a675 100644 --- a/R/bart.R +++ b/R/bart.R @@ -45,12 +45,6 @@ #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. #' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`. -#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. -#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. #' #' @param mean_forest_params (Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional. @@ -83,6 +77,15 @@ #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. #' +#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +#' +#' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +#' - `group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +#' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +#' - `group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +#' - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. +#' - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. +#' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). #' @export #' @@ -127,7 +130,8 @@ bart <- function( previous_model_warmstart_sample_num = NULL, general_params = list(), mean_forest_params = list(), - variance_forest_params = list() + variance_forest_params = list(), + rfx_params = list() ) { # Update general BART parameters general_params_default <- list( @@ -198,6 +202,20 @@ bart <- function( variance_forest_params ) + # Update rfx parameters + rfx_params_default <- list( + working_parameter_prior_mean = NULL, + group_parameter_prior_mean = NULL, + working_parameter_prior_cov = NULL, + group_parameter_prior_cov = NULL, + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + rfx_params_updated <- preprocessParams( + rfx_params_default, + rfx_params + ) + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size <- general_params_updated$cutpoint_grid_size @@ -214,12 +232,6 @@ bart <- function( num_chains <- general_params_updated$num_chains verbose <- general_params_updated$verbose probit_outcome_model <- general_params_updated$probit_outcome_model - rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean - rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean - rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov - rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov - rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape - rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale num_threads <- general_params_updated$num_threads # 2. Mean forest parameters @@ -250,6 +262,14 @@ bart <- function( drop_vars_variance <- variance_forest_params_updated$drop_vars num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample + # 4. RFX parameters + rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean + rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean + rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov + rfx_group_parameter_prior_cov <- rfx_params_updated$group_parameter_prior_cov + rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape + rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale + # Set a function-scoped RNG if user provided a random seed custom_rng <- random_seed >= 0 if (custom_rng) { diff --git a/R/bcf.R b/R/bcf.R index dd6d9b95..444899d2 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -103,6 +103,15 @@ #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. #' +#' @param rfx_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +#' +#' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +#' - `group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +#' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +#' - `group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +#' - `variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. +#' - `variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. +#' #' @return List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). #' @export #' @@ -172,7 +181,8 @@ bcf <- function( general_params = list(), prognostic_forest_params = list(), treatment_effect_forest_params = list(), - variance_forest_params = list() + variance_forest_params = list(), + rfx_params = list() ) { # Update general BCF parameters general_params_default <- list( @@ -269,6 +279,20 @@ bcf <- function( variance_forest_params ) + # Update random effects parameters + rfx_params_default <- list( + working_parameter_prior_mean = NULL, + group_parameter_prior_mean = NULL, + working_parameter_prior_cov = NULL, + group_parameter_prior_cov = NULL, + variance_prior_shape = 1, + variance_prior_scale = 1 + ) + rfx_params_updated <- preprocessParams( + rfx_params_default, + rfx_params + ) + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size <- general_params_updated$cutpoint_grid_size @@ -290,12 +314,6 @@ bcf <- function( num_chains <- general_params_updated$num_chains verbose <- general_params_updated$verbose probit_outcome_model <- general_params_updated$probit_outcome_model - rfx_working_parameter_prior_mean <- general_params_updated$rfx_working_parameter_prior_mean - rfx_group_parameter_prior_mean <- general_params_updated$rfx_group_parameter_prior_mean - rfx_working_parameter_prior_cov <- general_params_updated$rfx_working_parameter_prior_cov - rfx_group_parameter_prior_cov <- general_params_updated$rfx_group_parameter_prior_cov - rfx_variance_prior_shape <- general_params_updated$rfx_variance_prior_shape - rfx_variance_prior_scale <- general_params_updated$rfx_variance_prior_scale num_threads <- general_params_updated$num_threads # 2. Mu forest parameters @@ -341,6 +359,14 @@ bcf <- function( drop_vars_variance <- variance_forest_params_updated$drop_vars num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample + # 5. Random effects parameters + rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean + rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean + rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov + rfx_group_parameter_prior_cov <- rfx_params_updated$group_parameter_prior_cov + rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape + rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale + # Set a function-scoped RNG if user provided a random seed custom_rng <- random_seed >= 0 if (custom_rng) { diff --git a/man/bart.Rd b/man/bart.Rd index c11c619b..6bc14615 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -21,7 +21,8 @@ bart( previous_model_warmstart_sample_num = NULL, general_params = list(), mean_forest_params = list(), - variance_forest_params = list() + variance_forest_params = list(), + rfx_params = list() ) } \arguments{ @@ -84,12 +85,6 @@ that were not in the training set.} \item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. \item \code{probit_outcome_model} Whether or not the outcome should be modeled as explicitly binary via a probit link. If \code{TRUE}, \code{y} must only contain the values \code{0} and \code{1}. Default: \code{FALSE}. -\item \code{rfx_working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. -\item \code{rfx_variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{num_threads} Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads. }} @@ -124,6 +119,16 @@ that were not in the training set.} \item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. \item \code{num_features_subsample} How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. }} + +\item{random_effects_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +\itemize{ +\item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +\item \code{variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +}} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). diff --git a/man/bcf.Rd b/man/bcf.Rd index f7d42e93..ac6449db 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -24,7 +24,8 @@ bcf( general_params = list(), prognostic_forest_params = list(), treatment_effect_forest_params = list(), - variance_forest_params = list() + variance_forest_params = list(), + rfx_params = list() ) } \arguments{ @@ -150,6 +151,16 @@ that were not in the training set.} \item \code{drop_vars} Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: \code{NULL}. If both \code{drop_vars} and \code{keep_vars} are set, \code{drop_vars} will be ignored. \item \code{num_features_subsample} How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. }} + +\item{rfx_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +\itemize{ +\item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. +\item \code{working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. +\item \code{variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +\item \code{variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. +}} } \value{ List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk). diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index d5e3570c..de45d73d 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -505,13 +505,13 @@ test_that("Random Effects BART", { ) # Specify all rfx parameters as scalars - general_param_list <- list( - rfx_working_parameter_prior_mean = 1., - rfx_group_parameter_prior_mean = 1., - rfx_working_parameter_prior_cov = 1., - rfx_group_parameter_prior_cov = 1., - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1 + rfx_param_list <- list( + working_parameter_prior_mean = 1., + group_parameter_prior_mean = 1., + working_parameter_prior_cov = 1., + group_parameter_prior_cov = 1., + variance_prior_shape = 1, + variance_prior_scale = 1 ) mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) expect_no_error( @@ -528,19 +528,19 @@ test_that("Random Effects BART", { num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list + mean_forest_params = mean_forest_param_list, + rfx_params = rfx_param_list ) ) # Specify all relevant rfx parameters as vectors - general_param_list <- list( - rfx_working_parameter_prior_mean = c(1., 1.), - rfx_group_parameter_prior_mean = c(1., 1.), - rfx_working_parameter_prior_cov = diag(1., 2), - rfx_group_parameter_prior_cov = diag(1., 2), - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1 + rfx_param_list <- list( + working_parameter_prior_mean = c(1., 1.), + group_parameter_prior_mean = c(1., 1.), + working_parameter_prior_cov = diag(1., 2), + group_parameter_prior_cov = diag(1., 2), + variance_prior_shape = 1, + variance_prior_scale = 1 ) mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) expect_no_error( @@ -557,8 +557,8 @@ test_that("Random Effects BART", { num_gfr = 0, num_burnin = 10, num_mcmc = 10, - general_params = general_param_list, - mean_forest_params = mean_forest_param_list + mean_forest_params = mean_forest_param_list, + rfx_params = rfx_param_list ) ) }) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 8f0d69f0..4cf26609 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -726,13 +726,13 @@ test_that("Random Effects BCF", { ) # Specify all rfx parameters as scalars - general_param_list <- list( - rfx_working_parameter_prior_mean = 1., - rfx_group_parameter_prior_mean = 1., - rfx_working_parameter_prior_cov = 1., - rfx_group_parameter_prior_cov = 1., - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1 + rfx_param_list <- list( + working_parameter_prior_mean = 1., + group_parameter_prior_mean = 1., + working_parameter_prior_cov = 1., + group_parameter_prior_cov = 1., + variance_prior_shape = 1, + variance_prior_scale = 1 ) expect_no_error( bcf_model <- bcf( @@ -750,18 +750,18 @@ test_that("Random Effects BCF", { num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list + rfx_params = rfx_param_list ) ) # Specify all relevant rfx parameters as vectors - general_param_list <- list( - rfx_working_parameter_prior_mean = c(1., 1.), - rfx_group_parameter_prior_mean = c(1., 1.), - rfx_working_parameter_prior_cov = diag(1., 2), - rfx_group_parameter_prior_cov = diag(1., 2), - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1 + rfx_param_list <- list( + working_parameter_prior_mean = c(1., 1.), + group_parameter_prior_mean = c(1., 1.), + working_parameter_prior_cov = diag(1., 2), + group_parameter_prior_cov = diag(1., 2), + variance_prior_shape = 1, + variance_prior_scale = 1 ) expect_no_error( bcf_model <- bcf( @@ -779,7 +779,7 @@ test_that("Random Effects BCF", { num_gfr = 10, num_burnin = 0, num_mcmc = 10, - general_params = general_param_list + rfx_params = rfx_param_list ) ) }) From e157bb64bf7199ca057351a4d5f53f5e8b1d9532 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 01:48:08 -0500 Subject: [PATCH 30/53] Updated docs --- R/posterior_transformation.R | 9 +++------ man/compute_contrast_bart_model.Rd | 3 +-- man/compute_contrast_bcf_model.Rd | 7 ++----- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index f8bad094..b8f1e1de 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -26,7 +26,6 @@ #' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". -#' @param ... (Optional) Other prediction parameters. #' #' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. #' @export @@ -77,7 +76,7 @@ #' propensity_train = pi_train, num_gfr = 10, #' num_burnin = 0, num_mcmc = 10) #' tau_hat_test <- compute_contrast_bcf_model( -#' bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test), +#' bcf_model, X_0=X_test, X_1=X_test, Z_0=rep(0, n_test), Z_1=rep(1, n_test), #' propensity_0 = pi_test, propensity_1 = pi_test #' ) compute_contrast_bcf_model <- function( @@ -93,8 +92,7 @@ compute_contrast_bcf_model <- function( rfx_basis_0 = NULL, rfx_basis_1 = NULL, type = "posterior", - scale = "linear", - ... + scale = "linear" ) { # Handle mean function scale if (!is.character(scale)) { @@ -322,8 +320,7 @@ compute_contrast_bart_model <- function( rfx_basis_0 = NULL, rfx_basis_1 = NULL, type = "posterior", - scale = "linear", - ... + scale = "linear" ) { # Handle mean function scale if (!is.character(scale)) { diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd index 1e4a3ac7..c5bb2d09 100644 --- a/man/compute_contrast_bart_model.Rd +++ b/man/compute_contrast_bart_model.Rd @@ -21,8 +21,7 @@ compute_contrast_bart_model( rfx_basis_0 = NULL, rfx_basis_1 = NULL, type = "posterior", - scale = "linear", - ... + scale = "linear" ) } \arguments{ diff --git a/man/compute_contrast_bcf_model.Rd b/man/compute_contrast_bcf_model.Rd index aafde613..8aa1ec59 100644 --- a/man/compute_contrast_bcf_model.Rd +++ b/man/compute_contrast_bcf_model.Rd @@ -26,8 +26,7 @@ compute_contrast_bcf_model( rfx_basis_0 = NULL, rfx_basis_1 = NULL, type = "posterior", - scale = "linear", - ... + scale = "linear" ) } \arguments{ @@ -60,8 +59,6 @@ for group labels that were not in the training set. Must be a vector.} \item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} - -\item{...}{(Optional) Other prediction parameters.} } \value{ List of prediction matrices or single prediction matrix / vector, depending on the terms requested. @@ -124,7 +121,7 @@ bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train, num_gfr = 10, num_burnin = 0, num_mcmc = 10) tau_hat_test <- compute_contrast_bcf_model( - bcf_model, X0=X_test, X1=X_test, Z0=rep(0, n_test), Z1=rep(1, n_test), + bcf_model, X_0=X_test, X_1=X_test, Z_0=rep(0, n_test), Z_1=rep(1, n_test), propensity_0 = pi_test, propensity_1 = pi_test ) } From c28578b6a5420f0a8ab422bf7faf5a2263922b8b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 01:57:25 -0500 Subject: [PATCH 31/53] Reflected RFX parameter list through python interface as well --- R/bart.R | 6 ---- R/bcf.R | 12 ------- man/bcf.Rd | 6 ---- stochtree/bart.py | 68 ++++++++++++++++++++++++---------------- stochtree/bcf.py | 66 +++++++++++++++++++++++--------------- test/python/test_bart.py | 32 +++++++++---------- test/python/test_bcf.py | 32 +++++++++---------- 7 files changed, 113 insertions(+), 109 deletions(-) diff --git a/R/bart.R b/R/bart.R index b058a675..789fccdf 100644 --- a/R/bart.R +++ b/R/bart.R @@ -149,12 +149,6 @@ bart <- function( num_chains = 1, verbose = FALSE, probit_outcome_model = FALSE, - rfx_working_parameter_prior_mean = NULL, - rfx_group_parameter_prior_mean = NULL, - rfx_working_parameter_prior_cov = NULL, - rfx_group_parameter_prior_cov = NULL, - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1, num_threads = -1 ) general_params_updated <- preprocessParams( diff --git a/R/bcf.R b/R/bcf.R index 444899d2..41136fcc 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -47,12 +47,6 @@ #' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. #' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`. #' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`. -#' - `rfx_working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -#' - `rfx_group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -#' - `rfx_working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -#' - `rfx_group_parameter_prior_cov` Prior covariance matrix for the random effects "group parameters." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -#' - `rfx_variance_prior_shape` Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. -#' - `rfx_variance_prior_scale` Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. #' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. #' #' @param prognostic_forest_params (Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional. @@ -205,12 +199,6 @@ bcf <- function( num_chains = 1, verbose = FALSE, probit_outcome_model = FALSE, - rfx_working_parameter_prior_mean = NULL, - rfx_group_parameter_prior_mean = NULL, - rfx_working_parameter_prior_cov = NULL, - rfx_group_parameter_prior_cov = NULL, - rfx_variance_prior_shape = 1, - rfx_variance_prior_scale = 1, num_threads = -1 ) general_params_updated <- preprocessParams( diff --git a/man/bcf.Rd b/man/bcf.Rd index ac6449db..819ed66f 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -92,12 +92,6 @@ that were not in the training set.} \item \code{num_chains} How many independent MCMC chains should be sampled. If \code{num_mcmc = 0}, this is ignored. If \code{num_gfr = 0}, then each chain is run from root for \code{num_mcmc * keep_every + num_burnin} iterations, with \code{num_mcmc} samples retained. If \code{num_gfr > 0}, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that \code{num_gfr >= num_chains}. Default: \code{1}. \item \code{verbose} Whether or not to print progress during the sampling loops. Default: \code{FALSE}. \item \code{probit_outcome_model} Whether or not the outcome should be modeled as explicitly binary via a probit link. If \code{TRUE}, \code{y} must only contain the values \code{0} and \code{1}. Default: \code{FALSE}. -\item \code{rfx_working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. -\item \code{rfx_working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_group_parameter_prior_cov} Prior covariance matrix for the random effects "group parameters." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. -\item \code{rfx_variance_prior_shape} Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. -\item \code{rfx_variance_prior_scale} Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: \code{1}. \item \code{num_threads} Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to \code{1}, otherwise to the maximum number of available threads. }} diff --git a/stochtree/bart.py b/stochtree/bart.py index 3fddac92..bf985640 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -85,6 +85,7 @@ def sample( general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, + rfx_params: Optional[Dict[str, Any]] = None, previous_model_json: Optional[str] = None, previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: @@ -135,12 +136,6 @@ def sample( * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. - * `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. - * `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. * `num_threads`: Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. mean_forest_params : dict, optional @@ -160,7 +155,7 @@ def sample( * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. variance_forest_params : dict, optional - Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. + Dictionary of variance forest model parameters, each of which has a default value processed internally, so this argument is optional. * `num_trees` (`int`): Number of trees in the conditional variance model. Defaults to `0`. Variance is only modeled using a tree / forest if `num_trees > 0`. * `alpha` (`float`): Prior probability of splitting for a tree of depth 0 in the conditional variance model. Tree split prior combines `alpha` and `beta` via `alpha*(1+node_depth)^-beta`. Defaults to `0.95`. @@ -175,6 +170,16 @@ def sample( * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. + rfx_params : dict, optional + Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. + + * `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + * `variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + previous_model_json : str, optional JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`. previous_model_warmstart_sample_num : int, optional @@ -200,12 +205,6 @@ def sample( "keep_every": 1, "num_chains": 1, "probit_outcome_model": False, - "rfx_working_parameter_prior_mean": None, - "rfx_group_parameter_prior_mean": None, - "rfx_working_parameter_prior_cov": None, - "rfx_group_parameter_prior_cov": None, - "rfx_variance_prior_shape": 1.0, - "rfx_variance_prior_scale": 1.0, "num_threads": -1, } general_params_updated = _preprocess_params( @@ -250,6 +249,19 @@ def sample( variance_forest_params_default, variance_forest_params ) + # Update random effects parameters + rfx_params_default = { + "working_parameter_prior_mean": None, + "group_parameter_prior_mean": None, + "working_parameter_prior_cov": None, + "group_parameter_prior_cov": None, + "variance_prior_shape": 1.0, + "variance_prior_scale": 1.0, + } + rfx_params_updated = _preprocess_params( + rfx_params_default, rfx_params + ) + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size = general_params_updated["cutpoint_grid_size"] @@ -265,20 +277,6 @@ def sample( keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] - rfx_working_parameter_prior_mean = general_params_updated[ - "rfx_working_parameter_prior_mean" - ] - rfx_group_parameter_prior_mean = general_params_updated[ - "rfx_group_parameter_prior_mean" - ] - rfx_working_parameter_prior_cov = general_params_updated[ - "rfx_working_parameter_prior_cov" - ] - rfx_group_parameter_prior_cov = general_params_updated[ - "rfx_group_parameter_prior_cov" - ] - rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] - rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] num_threads = general_params_updated["num_threads"] # 2. Mean forest parameters @@ -315,6 +313,22 @@ def sample( "num_features_subsample" ] + # 4. Random effects parameters + rfx_working_parameter_prior_mean = rfx_params_updated[ + "working_parameter_prior_mean" + ] + rfx_group_parameter_prior_mean = rfx_params_updated[ + "group_parameter_prior_mean" + ] + rfx_working_parameter_prior_cov = rfx_params_updated[ + "working_parameter_prior_cov" + ] + rfx_group_parameter_prior_cov = rfx_params_updated[ + "group_parameter_prior_cov" + ] + rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] + rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True diff --git a/stochtree/bcf.py b/stochtree/bcf.py index e361b76e..facacc79 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -97,6 +97,7 @@ def sample( prognostic_forest_params: Optional[Dict[str, Any]] = None, treatment_effect_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, + rfx_params: Optional[Dict[str, Any]] = None, ) -> None: """Runs a BCF sampler on provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set. @@ -155,12 +156,6 @@ def sample( * `keep_every` (`int`): How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Defaults to `1`. Setting `keep_every = k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples. * `num_chains` (`int`): How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Defaults to `1`. * `probit_outcome_model` (`bool`): Whether or not the outcome should be modeled as explicitly binary via a probit link. If `True`, `y` must only contain the values `0` and `1`. Default: `False`. - * `rfx_working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. - * `rfx_working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. - * `rfx_variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. - * `rfx_variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. * `num_threads`: Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads. prognostic_forest_params : dict, optional @@ -213,6 +208,16 @@ def sample( * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. + rfx_params : dict, optional + Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. + + * `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. + * `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `group_parameter_prior_cov`: Prior covariance matrix for the random effects "group parameters." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. + * `variance_prior_shape`: Shape parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + * `variance_prior_scale`: Scale parameter for the inverse gamma prior on the variance of the random effects "group parameter." Default: `1`. + Returns ------- self : BCFModel @@ -237,12 +242,6 @@ def sample( "keep_every": 1, "num_chains": 1, "probit_outcome_model": False, - "rfx_working_parameter_prior_mean": None, - "rfx_group_parameter_prior_mean": None, - "rfx_working_parameter_prior_cov": None, - "rfx_group_parameter_prior_cov": None, - "rfx_variance_prior_shape": 1.0, - "rfx_variance_prior_scale": 1.0, "num_threads": -1, } general_params_updated = _preprocess_params( @@ -307,6 +306,19 @@ def sample( variance_forest_params_default, variance_forest_params ) + # Update random effects parameters + rfx_params_default = { + "working_parameter_prior_mean": None, + "group_parameter_prior_mean": None, + "working_parameter_prior_cov": None, + "group_parameter_prior_cov": None, + "variance_prior_shape": 1.0, + "variance_prior_scale": 1.0, + } + rfx_params_updated = _preprocess_params( + rfx_params_default, rfx_params + ) + ### Unpack all parameter values # 1. General parameters cutpoint_grid_size = general_params_updated["cutpoint_grid_size"] @@ -326,20 +338,6 @@ def sample( keep_every = general_params_updated["keep_every"] num_chains = general_params_updated["num_chains"] self.probit_outcome_model = general_params_updated["probit_outcome_model"] - rfx_working_parameter_prior_mean = general_params_updated[ - "rfx_working_parameter_prior_mean" - ] - rfx_group_parameter_prior_mean = general_params_updated[ - "rfx_group_parameter_prior_mean" - ] - rfx_working_parameter_prior_cov = general_params_updated[ - "rfx_working_parameter_prior_cov" - ] - rfx_group_parameter_prior_cov = general_params_updated[ - "rfx_group_parameter_prior_cov" - ] - rfx_variance_prior_shape = general_params_updated["rfx_variance_prior_shape"] - rfx_variance_prior_scale = general_params_updated["rfx_variance_prior_scale"] num_threads = general_params_updated["num_threads"] # 2. Mu forest parameters @@ -397,6 +395,22 @@ def sample( "num_features_subsample" ] + # 5. Random effects parameters + rfx_working_parameter_prior_mean = rfx_params_updated[ + "working_parameter_prior_mean" + ] + rfx_group_parameter_prior_mean = rfx_params_updated[ + "group_parameter_prior_mean" + ] + rfx_working_parameter_prior_cov = rfx_params_updated[ + "working_parameter_prior_cov" + ] + rfx_group_parameter_prior_cov = rfx_params_updated[ + "group_parameter_prior_cov" + ] + rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] + rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 4b22ab7b..6e26d84e 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -1122,13 +1122,13 @@ def conditional_stddev(X): ) # Specify scalar rfx parameters - general_params = { - "rfx_working_parameter_prior_mean": 1., - "rfx_group_parameter_prior_mean": 1., - "rfx_working_parameter_prior_cov": 1., - "rfx_group_parameter_prior_cov": 1., - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "working_parameter_prior_mean": 1., + "group_parameter_prior_mean": 1., + "working_parameter_prior_cov": 1., + "group_parameter_prior_cov": 1., + "variance_prior_shape": 1, + "variance_prior_scale": 1 } bart_model_2 = BARTModel() bart_model_2.sample( @@ -1144,17 +1144,17 @@ def conditional_stddev(X): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params, + rfx_params=rfx_params, ) # Specify all relevant rfx parameters as vectors - general_params = { - "rfx_working_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_group_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_working_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_group_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "working_parameter_prior_mean": np.repeat(1., num_rfx_basis), + "group_parameter_prior_mean": np.repeat(1., num_rfx_basis), + "working_parameter_prior_cov": np.identity(num_rfx_basis), + "group_parameter_prior_cov": np.identity(num_rfx_basis), + "variance_prior_shape": 1, + "variance_prior_scale": 1 } bart_model_3 = BARTModel() bart_model_3.sample( @@ -1170,5 +1170,5 @@ def conditional_stddev(X): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params, + rfx_params=rfx_params, ) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index ed0ac3ae..e03631fa 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -860,13 +860,13 @@ def rfx_term(group_labels, basis): ) # Specify scalar rfx parameters - general_params = { - "rfx_working_parameter_prior_mean": 1., - "rfx_group_parameter_prior_mean": 1., - "rfx_working_parameter_prior_cov": 1., - "rfx_group_parameter_prior_cov": 1., - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "working_parameter_prior_mean": 1., + "group_parameter_prior_mean": 1., + "working_parameter_prior_cov": 1., + "group_parameter_prior_cov": 1., + "variance_prior_shape": 1, + "variance_prior_scale": 1 } bcf_model_2 = BCFModel() bcf_model_2.sample( @@ -884,17 +884,17 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params + rfx_params=rfx_params ) # Specify all relevant rfx parameters as vectors - general_params = { - "rfx_working_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_group_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "rfx_working_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_group_parameter_prior_cov": np.identity(num_rfx_basis), - "rfx_variance_prior_shape": 1, - "rfx_variance_prior_scale": 1 + rfx_params = { + "working_parameter_prior_mean": np.repeat(1., num_rfx_basis), + "group_parameter_prior_mean": np.repeat(1., num_rfx_basis), + "working_parameter_prior_cov": np.identity(num_rfx_basis), + "group_parameter_prior_cov": np.identity(num_rfx_basis), + "variance_prior_shape": 1, + "variance_prior_scale": 1 } bcf_model_3 = BCFModel() bcf_model_3.sample( @@ -912,5 +912,5 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params + rfx_params=rfx_params ) From 7d7f38ac579cb7b4b33d5724069d882c2a307c54 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 02:24:02 -0500 Subject: [PATCH 32/53] Updated R interface and added contrast methods to python interface --- R/bart.R | 4 +- R/bcf.R | 6 +- R/posterior_transformation.R | 8 +- man/bart.Rd | 2 +- man/bcf.Rd | 4 +- man/compute_contrast_bart_model.Rd | 4 +- man/compute_contrast_bcf_model.Rd | 4 +- stochtree/bart.py | 134 +++++++++++++++++++++++++++++ stochtree/bcf.py | 109 +++++++++++++++++++++++ 9 files changed, 259 insertions(+), 16 deletions(-) diff --git a/R/bart.R b/R/bart.R index 789fccdf..4369c273 100644 --- a/R/bart.R +++ b/R/bart.R @@ -131,7 +131,7 @@ bart <- function( general_params = list(), mean_forest_params = list(), variance_forest_params = list(), - rfx_params = list() + random_effects_params = list() ) { # Update general BART parameters general_params_default <- list( @@ -207,7 +207,7 @@ bart <- function( ) rfx_params_updated <- preprocessParams( rfx_params_default, - rfx_params + random_effects_params ) ### Unpack all parameter values diff --git a/R/bcf.R b/R/bcf.R index 41136fcc..7c27cefe 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -97,7 +97,7 @@ #' - `drop_vars` Vector of variable names or column indices denoting variables that should be excluded from the forest. Default: `NULL`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. #' - `num_features_subsample` How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. #' -#' @param rfx_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +#' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. #' #' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. #' - `group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. @@ -176,7 +176,7 @@ bcf <- function( prognostic_forest_params = list(), treatment_effect_forest_params = list(), variance_forest_params = list(), - rfx_params = list() + random_effects_params = list() ) { # Update general BCF parameters general_params_default <- list( @@ -278,7 +278,7 @@ bcf <- function( ) rfx_params_updated <- preprocessParams( rfx_params_default, - rfx_params + random_effects_params ) ### Unpack all parameter values diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index b8f1e1de..b401cffd 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -24,8 +24,8 @@ #' for group labels that were not in the training set. Must be a vector. #' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. #' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. -#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". -#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". +#' @param type (Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". +#' @param scale (Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' #' @return List of prediction matrices or single prediction matrix / vector, depending on the terms requested. #' @export @@ -268,8 +268,8 @@ compute_contrast_bcf_model <- function( #' for group labels that were not in the training set. Must be a vector. #' @param rfx_basis_0 (Optional) Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. #' @param rfx_basis_1 (Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. -#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". -#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". +#' @param type (Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". +#' @param scale (Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' #' @return Contrast matrix or vector, depending on whether type = "mean" or "posterior". #' @export diff --git a/man/bart.Rd b/man/bart.Rd index 6bc14615..73600fd4 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -22,7 +22,7 @@ bart( general_params = list(), mean_forest_params = list(), variance_forest_params = list(), - rfx_params = list() + random_effects_params = list() ) } \arguments{ diff --git a/man/bcf.Rd b/man/bcf.Rd index 819ed66f..ca3b2983 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -25,7 +25,7 @@ bcf( prognostic_forest_params = list(), treatment_effect_forest_params = list(), variance_forest_params = list(), - rfx_params = list() + random_effects_params = list() ) } \arguments{ @@ -146,7 +146,7 @@ that were not in the training set.} \item \code{num_features_subsample} How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. }} -\item{rfx_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. +\item{random_effects_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ \item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. \item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. diff --git a/man/compute_contrast_bart_model.Rd b/man/compute_contrast_bart_model.Rd index c5bb2d09..8a0c3096 100644 --- a/man/compute_contrast_bart_model.Rd +++ b/man/compute_contrast_bart_model.Rd @@ -47,9 +47,9 @@ for group labels that were not in the training set. Must be a vector.} \item{rfx_basis_1}{(Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.} -\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} +\item{type}{(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} -\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} +\item{scale}{(Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} } \value{ Contrast matrix or vector, depending on whether type = "mean" or "posterior". diff --git a/man/compute_contrast_bcf_model.Rd b/man/compute_contrast_bcf_model.Rd index 8aa1ec59..d28e77b0 100644 --- a/man/compute_contrast_bcf_model.Rd +++ b/man/compute_contrast_bcf_model.Rd @@ -56,9 +56,9 @@ for group labels that were not in the training set. Must be a vector.} \item{rfx_basis_1}{(Optional) Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector.} -\item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} +\item{type}{(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior".} -\item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} +\item{scale}{(Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing \code{y == 1} before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} } \value{ List of prediction matrices or single prediction matrix / vector, depending on the terms requested. diff --git a/stochtree/bart.py b/stochtree/bart.py index bf985640..2e3c6706 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1872,6 +1872,140 @@ def predict( result["variance_forest_predictions"] = None return result + def compute_contrast( + self, + covariates_0: Union[np.array, pd.DataFrame], + covariates_1: Union[np.array, pd.DataFrame], + basis_0: np.array = None, + basis_1: np.array = None, + rfx_group_ids_0: np.array = None, + rfx_group_ids_1: np.array = None, + rfx_basis_0: np.array = None, + rfx_basis_1: np.array = None, + type: str = "posterior", + scale: str = "linear", + ) -> Union[np.array, tuple]: + """Compute a contrast using a BART model by making two sets of outcome predictions and taking their + difference. This function provides the flexibility to compute any contrast of interest by specifying + covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. + For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend + of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" + terminology of a classic two-treatment causal inference problem. We mirror the function calls and + terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote + its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction. + + Parameters + ---------- + covariates_0 : np.array or pd.DataFrame + Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. + covariates_1 : np.array or pd.DataFrame + Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. + basis_0 : np.array, optional + Bases used for prediction in the "control" case (by e.g. dot product with leaf values). + basis_1 : np.array, optional + Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). + rfx_group_ids_0 : np.array, optional + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that + were not in the training set. Must be a numpy array. + rfx_group_ids_1 : np.array, optional + Test set group labels used for prediction from an additive random effects model in the "treatment" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that + were not in the training set. Must be a numpy array. + rfx_basis_0 : np.array, optional + Test set basis for used for prediction from an additive random effects model in the "control" case. + rfx_basis_1 : np.array, optional + Test set basis for used for prediction from an additive random effects model in the "treatment" case. + type : str, optional + Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". + scale : str, optional + Scale of the contrast. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". + + Returns + ------- + Array, either 1d or 2d depending on whether type = "mean" or "posterior". + """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Handle prediction terms + has_mean_forest = self.include_mean_forest + has_rfx = self.has_rfx + + # Check that we have at least one term to predict on probability scale + if ( + not has_mean_forest + and not has_rfx + ): + raise ValueError( + "Contrast cannot be computed as the model does not have a mean forest or random effects term" + ) + + # Check the model is valid + if not self.is_sampled(): + msg = ( + "This BARTModel instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this model." + ) + raise NotSampledError(msg) + + # Data checks + if not isinstance(covariates_0, pd.DataFrame) and not isinstance( + covariates_0, np.ndarray + ): + raise ValueError("covariates_0 must be a pandas dataframe or numpy array") + if not isinstance(covariates_1, pd.DataFrame) and not isinstance( + covariates_1, np.ndarray + ): + raise ValueError("covariates_1 must be a pandas dataframe or numpy array") + if basis_0 is not None: + if not isinstance(basis_0, np.ndarray): + raise ValueError("basis_0 must be a numpy array") + if basis_0.shape[0] != covariates_0.shape[0]: + raise ValueError( + "covariates_0 and basis_0 must have the same number of rows" + ) + if basis_1 is not None: + if not isinstance(basis_1, np.ndarray): + raise ValueError("basis_1 must be a numpy array") + if basis_1.shape[0] != covariates_1.shape[0]: + raise ValueError( + "covariates_1 and basis_1 must have the same number of rows" + ) + + # Predict for the control arm + control_preds = self.predict(covariates=covariates_0, basis=basis_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", terms="y_hat", scale="linear") + + # Predict for the treatment arm + treatment_preds = self.predict(covariates=covariates_1, basis=basis_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", terms="y_hat", scale="linear") + + # Transform to probability scale if requested + if probability_scale: + treatment_preds = norm.ppf(treatment_preds) + control_preds = norm.ppf(control_preds) + + # Compute and return contrast + if predict_mean: + return(np.mean(treatment_preds - control_preds, axis=1)) + else: + return(treatment_preds - control_preds) + def compute_posterior_interval( self, terms: Union[list[str], str] = "all", diff --git a/stochtree/bcf.py b/stochtree/bcf.py index facacc79..16e5d7e0 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2541,6 +2541,115 @@ def predict( result["variance_forest_predictions"] = None return result + def compute_contrast( + self, + X_0: np.array, + X_1: np.array, + Z_0: np.array, + Z_1: np.array, + propensity_0: np.array = None, + propensity_1: np.array = None, + rfx_group_ids_0: np.array = None, + rfx_group_ids_1: np.array = None, + rfx_basis_0: np.array = None, + rfx_basis_1: np.array = None, + type: str = "posterior", + scale: str = "linear", + ) -> dict: + """Compute a contrast using a BCF model by making two sets of outcome predictions and taking their + difference. This function provides the flexibility to compute any contrast of interest by specifying + covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. + For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend + of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" + terminology of a classic two-treatment causal inference problem. We mirror the function calls and + terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote + its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction. + + Parameters + ---------- + X_0 : np.array or pd.DataFrame + Covariates used for prediction in the "control" case. Must be a numpy array or dataframe. + X_1 : np.array or pd.DataFrame + Covariates used for prediction in the "treatment" case. Must be a numpy array or dataframe. + Z_0 : np.array + Treatments used for prediction in the "control" case. Must be a numpy array or vector. + Z_1 : np.array + Treatments used for prediction in the "treatment" case. Must be a numpy array or vector. + propensity_0 : `np.array`, optional + Propensities used for prediction in the "control" case. Must be a numpy array or vector. + propensity_1 : `np.array`, optional + Propensities used for prediction in the "treatment" case. Must be a numpy array or vector. + rfx_group_ids_0 : np.array, optional + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that + were not in the training set. Must be a numpy array. + rfx_group_ids_1 : np.array, optional + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that + were not in the training set. Must be a numpy array. + rfx_basis_0 : np.array, optional + Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a numpy array. + rfx_basis_1 : np.array, optional + Test set basis for used for prediction from an additive random effects model in the "treatment" case. Must be a numpy array. + type : str, optional + Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". + scale : str, optional + Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing `y == 1` before taking their difference. "probability" is only valid for models fit with a probit outcome model. Default: "linear". + + Returns + ------- + Array, either 1d or 2d depending on whether type = "mean" or "posterior". + """ + # Handle mean function scale + if not isinstance(scale, str): + raise ValueError("scale must be a string") + if scale not in ["linear", "probability"]: + raise ValueError("scale must either be 'linear' or 'probability'") + is_probit = self.probit_outcome_model + if (scale == "probability") and (not is_probit): + raise ValueError( + "scale cannot be 'probability' for models not fit with a probit outcome model" + ) + probability_scale = scale == "probability" + + # Handle prediction type + if not isinstance(type, str): + raise ValueError("type must be a string") + if type not in ["mean", "posterior"]: + raise ValueError("type must either be 'mean' or 'posterior'") + predict_mean = type == "mean" + + # Check the model is valid + if not self.is_sampled(): + msg = ( + "This BCFModel instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this model." + ) + raise NotSampledError(msg) + + # Data checks + if Z_0.shape[0] != X_0.shape[0]: + raise ValueError("X_0 and Z_0 must have the same number of rows") + if Z_1.shape[0] != X_1.shape[0]: + raise ValueError("X_1 and Z_1 must have the same number of rows") + + # Predict for the control arm + control_preds = self.predict(X=X_0, Z=Z_0, propensity=propensity_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", terms="y_hat", scale="linear") + + # Predict for the treatment arm + treatment_preds = self.predict(X=X_1, Z=Z_1, propensity=propensity_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", terms="y_hat", scale="linear") + + # Transform to probability scale if requested + if probability_scale: + treatment_preds = norm.ppf(treatment_preds) + control_preds = norm.ppf(control_preds) + + # Compute and return contrast + if predict_mean: + return(np.mean(treatment_preds - control_preds, axis=1)) + else: + return(treatment_preds - control_preds) + def compute_posterior_interval( self, terms: Union[list[str], str] = "all", From 51d3da7e870dc281d61fb12067052c7b61880127 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 02:33:24 -0500 Subject: [PATCH 33/53] Updated R unit tests --- test/R/testthat/test-bart.R | 4 ++-- test/R/testthat/test-bcf.R | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index de45d73d..c3f923aa 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -529,7 +529,7 @@ test_that("Random Effects BART", { num_burnin = 10, num_mcmc = 10, mean_forest_params = mean_forest_param_list, - rfx_params = rfx_param_list + random_effects_params = rfx_param_list ) ) @@ -558,7 +558,7 @@ test_that("Random Effects BART", { num_burnin = 10, num_mcmc = 10, mean_forest_params = mean_forest_param_list, - rfx_params = rfx_param_list + random_effects_params = rfx_param_list ) ) }) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 4cf26609..221c333f 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -750,7 +750,7 @@ test_that("Random Effects BCF", { num_gfr = 10, num_burnin = 0, num_mcmc = 10, - rfx_params = rfx_param_list + random_effects_params = rfx_param_list ) ) @@ -779,7 +779,7 @@ test_that("Random Effects BCF", { num_gfr = 10, num_burnin = 0, num_mcmc = 10, - rfx_params = rfx_param_list + random_effects_params = rfx_param_list ) ) }) From f572cb9ac6d53bdc2de3db6c6da62aea1e3c57ba Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 12:05:27 -0500 Subject: [PATCH 34/53] Reformat python code --- stochtree/bart.py | 61 +++++++++++++++++++++++++++-------------------- stochtree/bcf.py | 56 +++++++++++++++++++++++++++---------------- 2 files changed, 70 insertions(+), 47 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 2e3c6706..86eb67f1 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -258,9 +258,7 @@ def sample( "variance_prior_shape": 1.0, "variance_prior_scale": 1.0, } - rfx_params_updated = _preprocess_params( - rfx_params_default, rfx_params - ) + rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params) ### Unpack all parameter values # 1. General parameters @@ -323,9 +321,7 @@ def sample( rfx_working_parameter_prior_cov = rfx_params_updated[ "working_parameter_prior_cov" ] - rfx_group_parameter_prior_cov = rfx_params_updated[ - "group_parameter_prior_cov" - ] + rfx_group_parameter_prior_cov = rfx_params_updated["group_parameter_prior_cov"] rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] @@ -1885,13 +1881,13 @@ def compute_contrast( type: str = "posterior", scale: str = "linear", ) -> Union[np.array, tuple]: - """Compute a contrast using a BART model by making two sets of outcome predictions and taking their - difference. This function provides the flexibility to compute any contrast of interest by specifying - covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. - For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend - of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" - terminology of a classic two-treatment causal inference problem. We mirror the function calls and - terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote + """Compute a contrast using a BART model by making two sets of outcome predictions and taking their + difference. This function provides the flexibility to compute any contrast of interest by specifying + covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. + For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend + of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" + terminology of a classic two-treatment causal inference problem. We mirror the function calls and + terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction. Parameters @@ -1905,12 +1901,12 @@ def compute_contrast( basis_1 : np.array, optional Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). rfx_group_ids_0 : np.array, optional - Test set group labels used for prediction from an additive random effects model in the "control" case. - We do not currently support (but plan to in the near future), test set evaluation for group labels that + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. Must be a numpy array. rfx_group_ids_1 : np.array, optional - Test set group labels used for prediction from an additive random effects model in the "treatment" case. - We do not currently support (but plan to in the near future), test set evaluation for group labels that + Test set group labels used for prediction from an additive random effects model in the "treatment" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. Must be a numpy array. rfx_basis_0 : np.array, optional Test set basis for used for prediction from an additive random effects model in the "control" case. @@ -1949,10 +1945,7 @@ def compute_contrast( has_rfx = self.has_rfx # Check that we have at least one term to predict on probability scale - if ( - not has_mean_forest - and not has_rfx - ): + if not has_mean_forest and not has_rfx: raise ValueError( "Contrast cannot be computed as the model does not have a mean forest or random effects term" ) @@ -1988,12 +1981,28 @@ def compute_contrast( raise ValueError( "covariates_1 and basis_1 must have the same number of rows" ) - + # Predict for the control arm - control_preds = self.predict(covariates=covariates_0, basis=basis_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", terms="y_hat", scale="linear") + control_preds = self.predict( + covariates=covariates_0, + basis=basis_0, + rfx_group_ids=rfx_group_ids_0, + rfx_basis=rfx_basis_0, + type="posterior", + terms="y_hat", + scale="linear", + ) # Predict for the treatment arm - treatment_preds = self.predict(covariates=covariates_1, basis=basis_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", terms="y_hat", scale="linear") + treatment_preds = self.predict( + covariates=covariates_1, + basis=basis_1, + rfx_group_ids=rfx_group_ids_1, + rfx_basis=rfx_basis_1, + type="posterior", + terms="y_hat", + scale="linear", + ) # Transform to probability scale if requested if probability_scale: @@ -2002,9 +2011,9 @@ def compute_contrast( # Compute and return contrast if predict_mean: - return(np.mean(treatment_preds - control_preds, axis=1)) + return np.mean(treatment_preds - control_preds, axis=1) else: - return(treatment_preds - control_preds) + return treatment_preds - control_preds def compute_posterior_interval( self, diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 16e5d7e0..d4fae5ad 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -315,9 +315,7 @@ def sample( "variance_prior_shape": 1.0, "variance_prior_scale": 1.0, } - rfx_params_updated = _preprocess_params( - rfx_params_default, rfx_params - ) + rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params) ### Unpack all parameter values # 1. General parameters @@ -405,9 +403,7 @@ def sample( rfx_working_parameter_prior_cov = rfx_params_updated[ "working_parameter_prior_cov" ] - rfx_group_parameter_prior_cov = rfx_params_updated[ - "group_parameter_prior_cov" - ] + rfx_group_parameter_prior_cov = rfx_params_updated["group_parameter_prior_cov"] rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] @@ -2556,13 +2552,13 @@ def compute_contrast( type: str = "posterior", scale: str = "linear", ) -> dict: - """Compute a contrast using a BCF model by making two sets of outcome predictions and taking their - difference. This function provides the flexibility to compute any contrast of interest by specifying - covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. - For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend - of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" - terminology of a classic two-treatment causal inference problem. We mirror the function calls and - terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote + """Compute a contrast using a BCF model by making two sets of outcome predictions and taking their + difference. This function provides the flexibility to compute any contrast of interest by specifying + covariates, leaf basis, and random effects bases / IDs for both sides of a two term contrast. + For simplicity, we refer to the subtrahend of the contrast as the "control" or `Y0` term and the minuend + of the contrast as the `Y1` term, though the requested contrast need not match the "control vs treatment" + terminology of a classic two-treatment causal inference problem. We mirror the function calls and + terminology of the `predict.bartmodel` function, labeling each prediction data term with a `1` to denote its contribution to the treatment prediction of a contrast and `0` to denote inclusion in the control prediction. Parameters @@ -2580,12 +2576,12 @@ def compute_contrast( propensity_1 : `np.array`, optional Propensities used for prediction in the "treatment" case. Must be a numpy array or vector. rfx_group_ids_0 : np.array, optional - Test set group labels used for prediction from an additive random effects model in the "control" case. - We do not currently support (but plan to in the near future), test set evaluation for group labels that + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. Must be a numpy array. rfx_group_ids_1 : np.array, optional - Test set group labels used for prediction from an additive random effects model in the "control" case. - We do not currently support (but plan to in the near future), test set evaluation for group labels that + Test set group labels used for prediction from an additive random effects model in the "control" case. + We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set. Must be a numpy array. rfx_basis_0 : np.array, optional Test set basis for used for prediction from an additive random effects model in the "control" case. Must be a numpy array. @@ -2634,10 +2630,28 @@ def compute_contrast( raise ValueError("X_1 and Z_1 must have the same number of rows") # Predict for the control arm - control_preds = self.predict(X=X_0, Z=Z_0, propensity=propensity_0, rfx_group_ids=rfx_group_ids_0, rfx_basis=rfx_basis_0, type="posterior", terms="y_hat", scale="linear") + control_preds = self.predict( + X=X_0, + Z=Z_0, + propensity=propensity_0, + rfx_group_ids=rfx_group_ids_0, + rfx_basis=rfx_basis_0, + type="posterior", + terms="y_hat", + scale="linear", + ) # Predict for the treatment arm - treatment_preds = self.predict(X=X_1, Z=Z_1, propensity=propensity_1, rfx_group_ids=rfx_group_ids_1, rfx_basis=rfx_basis_1, type="posterior", terms="y_hat", scale="linear") + treatment_preds = self.predict( + X=X_1, + Z=Z_1, + propensity=propensity_1, + rfx_group_ids=rfx_group_ids_1, + rfx_basis=rfx_basis_1, + type="posterior", + terms="y_hat", + scale="linear", + ) # Transform to probability scale if requested if probability_scale: @@ -2646,9 +2660,9 @@ def compute_contrast( # Compute and return contrast if predict_mean: - return(np.mean(treatment_preds - control_preds, axis=1)) + return np.mean(treatment_preds - control_preds, axis=1) else: - return(treatment_preds - control_preds) + return treatment_preds - control_preds def compute_posterior_interval( self, From 6ee1a9b8df2566479e1ed28bfdfd14f93b5c1505 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 14:46:23 -0500 Subject: [PATCH 35/53] Fixed bug at the interface of probit / adaptive coding / RFX in Python --- stochtree/bart.py | 40 ++++++++++++++++++------------- stochtree/bcf.py | 48 +++++++++++++++++++++++-------------- stochtree/random_effects.py | 22 +++++++++++++++++ 3 files changed, 75 insertions(+), 35 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 86eb67f1..ae4c86b5 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1233,9 +1233,12 @@ def sample( if self.include_mean_forest: if self.probit_outcome_model: # Sample latent probit variable z | - - forest_pred = active_forest_mean.predict(forest_dataset_train) - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + outcome_pred = active_forest_mean.predict(forest_dataset_train) + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1252,7 +1255,7 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) # Sample the mean forest @@ -1437,11 +1440,14 @@ def sample( if self.include_mean_forest: if self.probit_outcome_model: # Sample latent probit variable z | - - forest_pred = active_forest_mean.predict( + outcome_pred = active_forest_mean.predict( forest_dataset_train ) - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1458,7 +1464,7 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) # Sample the mean forest @@ -1813,15 +1819,15 @@ def predict( # Combine into y hat predictions if probability_scale: if predict_y_hat and has_mean_forest and has_rfx: - y_hat = norm.ppf(mean_forest_predictions + rfx_predictions) - mean_forest_predictions = norm.ppf(mean_forest_predictions) - rfx_predictions = norm.ppf(rfx_predictions) + y_hat = norm.cdf(mean_forest_predictions + rfx_predictions) + mean_forest_predictions = norm.cdf(mean_forest_predictions) + rfx_predictions = norm.cdf(rfx_predictions) elif predict_y_hat and has_mean_forest: - y_hat = norm.ppf(mean_forest_predictions) - mean_forest_predictions = norm.ppf(mean_forest_predictions) + y_hat = norm.cdf(mean_forest_predictions) + mean_forest_predictions = norm.cdf(mean_forest_predictions) elif predict_y_hat and has_rfx: - y_hat = norm.ppf(rfx_predictions) - rfx_predictions = norm.ppf(rfx_predictions) + y_hat = norm.cdf(rfx_predictions) + rfx_predictions = norm.cdf(rfx_predictions) else: if predict_y_hat and has_mean_forest and has_rfx: y_hat = mean_forest_predictions + rfx_predictions @@ -2006,8 +2012,8 @@ def compute_contrast( # Transform to probability scale if requested if probability_scale: - treatment_preds = norm.ppf(treatment_preds) - control_preds = norm.ppf(control_preds) + treatment_preds = norm.cdf(treatment_preds) + control_preds = norm.cdf(control_preds) # Compute and return contrast if predict_mean: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index d4fae5ad..b5bbe661 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1735,9 +1735,12 @@ def sample( # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) forest_pred_tau = active_forest_tau.predict(forest_dataset_train) - forest_pred = forest_pred_mu + forest_pred_tau - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + outcome_pred = forest_pred_mu + forest_pred_tau + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1754,7 +1757,7 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) # Sample the prognostic forest @@ -1817,18 +1820,21 @@ def sample( # Sample coding parameters (if requested) if self.adaptive_coding: - mu_x = active_forest_mu.predict_raw(forest_dataset_train) + partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train) tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + partial_outcome_pred = partial_outcome_pred + rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) - partial_resid_mu = np.squeeze(resid_train - mu_x) + partial_resid = np.squeeze(resid_train - partial_outcome_pred) s_ty0 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 0) + tau_x * partial_resid * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 1) + tau_x * partial_resid * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -1935,9 +1941,12 @@ def sample( # Sample latent probit variable z | - forest_pred_mu = active_forest_mu.predict(forest_dataset_train) forest_pred_tau = active_forest_tau.predict(forest_dataset_train) - forest_pred = forest_pred_mu + forest_pred_tau - mu0 = forest_pred[y_train[:, 0] == 0] - mu1 = forest_pred[y_train[:, 0] == 1] + outcome_pred = forest_pred_mu + forest_pred_tau + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + outcome_pred = outcome_pred + rfx_pred + mu0 = outcome_pred[y_train[:, 0] == 0] + mu1 = outcome_pred[y_train[:, 0] == 1] n0 = np.sum(y_train[:, 0] == 0) n1 = np.sum(y_train[:, 0] == 1) u0 = self.rng.uniform( @@ -1954,7 +1963,7 @@ def sample( resid_train[y_train[:, 0] == 1, 0] = mu1 + norm.ppf(u1) # Update outcome - new_outcome = np.squeeze(resid_train) - forest_pred + new_outcome = np.squeeze(resid_train) - outcome_pred residual_train.update_data(new_outcome) # Sample the prognostic forest @@ -2017,18 +2026,21 @@ def sample( # Sample coding parameters (if requested) if self.adaptive_coding: - mu_x = active_forest_mu.predict_raw(forest_dataset_train) + partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train) tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) + if self.has_rfx: + rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + partial_outcome_pred = partial_outcome_pred + rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) - partial_resid_mu = np.squeeze(resid_train - mu_x) + partial_resid = np.squeeze(resid_train - partial_outcome_pred) s_ty0 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 0) + tau_x * partial_resid * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid_mu * (np.squeeze(Z_train) == 1) + tau_x * partial_resid * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -2655,8 +2667,8 @@ def compute_contrast( # Transform to probability scale if requested if probability_scale: - treatment_preds = norm.ppf(treatment_preds) - control_preds = norm.ppf(control_preds) + treatment_preds = norm.cdf(treatment_preds) + control_preds = norm.cdf(control_preds) # Compute and return contrast if predict_mean: diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index 95ac3df1..c93fa7e9 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -426,6 +426,28 @@ def sample( rng.rng_cpp, ) + def predict( + self, rfx_dataset: RandomEffectsDataset, rfx_tracker: RandomEffectsTracker + ) -> np.ndarray: + """ + Predict random effects for each observation in `rfx_dataset` + + Parameters + ---------- + rfx_dataset: RandomEffectsDataset + Object of type `RandomEffectsDataset` + rfx_tracker: RandomEffectsTracker + Object of type `RandomEffectsTracker` + + Returns + ------- + np.ndarray + Numpy array with as many rows as observations in `rfx_dataset` and as many columns as samples in the container + """ + return self.rfx_model_cpp.Predict( + rfx_dataset.rfx_dataset_cpp, rfx_tracker.rfx_tracker_cpp + ) + def set_working_parameter(self, working_parameter: np.ndarray) -> None: """ Set values for the "working parameter." This is typically used for initialization, From 8426d98a698b328348ea57adddcfb006ea095628 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 15:24:26 -0500 Subject: [PATCH 36/53] Fixed RFX prior mean for python interface --- stochtree/bart.py | 9 ++------- stochtree/bcf.py | 11 +++-------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index ae4c86b5..20f45cc4 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1003,14 +1003,9 @@ def sample( # Prior parameters if rfx_working_parameter_prior_mean is None: if num_rfx_components == 1: - alpha_init = np.array([1]) + alpha_init = np.array([0.0], dtype=float) elif num_rfx_components > 1: - alpha_init = np.concatenate( - ( - np.ones(1, dtype=float), - np.zeros(num_rfx_components - 1, dtype=float), - ) - ) + alpha_init = np.zeros(num_rfx_components, dtype=float) else: raise ValueError("There must be at least 1 random effect component") else: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index b5bbe661..e050a738 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1391,14 +1391,9 @@ def sample( # Prior parameters if rfx_working_parameter_prior_mean is None: if num_rfx_components == 1: - alpha_init = np.array([1]) + alpha_init = np.array([0.0], dtype=float) elif num_rfx_components > 1: - alpha_init = np.concatenate( - ( - np.ones(1, dtype=float), - np.zeros(num_rfx_components - 1, dtype=float), - ) - ) + alpha_init = np.zeros(num_rfx_components, dtype=float) else: raise ValueError("There must be at least 1 random effect component") else: @@ -2763,7 +2758,7 @@ def compute_posterior_interval( raise ValueError( "'treatment' must have the same number of rows as 'covariates'" ) - uses_propensity = self.propensity_covariate is not "none" + uses_propensity = self.propensity_covariate != "none" internal_propensity_model = self.internal_propensity_model needs_propensity = ( needs_covariates and uses_propensity and not internal_propensity_model From 19302ac91186c744ac5c827e785e71857ad2d8ac Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 16:57:18 -0500 Subject: [PATCH 37/53] Fixed bug in BCF probit RFX in Python --- stochtree/bart.py | 4 ---- stochtree/bcf.py | 26 ++++++++++++-------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 20f45cc4..0eb22d87 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -1800,16 +1800,12 @@ def predict( pred_dataset.dataset_cpp ) mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar - # if predict_mean: - # mean_forest_predictions = np.mean(mean_forest_predictions, axis = 1) # Random effects predictions if predict_rfx or predict_rfx_intermediate: rfx_predictions = ( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std ) - # if predict_mean: - # rfx_predictions = np.mean(rfx_predictions, axis = 1) # Combine into y hat predictions if probability_scale: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index e050a738..ddad0e85 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -1815,21 +1815,21 @@ def sample( # Sample coding parameters (if requested) if self.adaptive_coding: - partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train) + mu_x = active_forest_mu.predict_raw(forest_dataset_train) tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) + partial_resid_train = np.squeeze(resid_train - mu_x) if self.has_rfx: - rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) - partial_outcome_pred = partial_outcome_pred + rfx_pred + rfx_pred = np.squeeze(rfx_model.predict(rfx_dataset_train, rfx_tracker)) + partial_resid_train = partial_resid_train - rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) - partial_resid = np.squeeze(resid_train - partial_outcome_pred) s_ty0 = np.sum( - tau_x * partial_resid * (np.squeeze(Z_train) == 0) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid * (np.squeeze(Z_train) == 1) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -2021,21 +2021,21 @@ def sample( # Sample coding parameters (if requested) if self.adaptive_coding: - partial_outcome_pred = active_forest_mu.predict_raw(forest_dataset_train) + mu_x = active_forest_mu.predict_raw(forest_dataset_train) tau_x = np.squeeze( active_forest_tau.predict_raw(forest_dataset_train) ) + partial_resid_train = np.squeeze(resid_train - mu_x) if self.has_rfx: - rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) - partial_outcome_pred = partial_outcome_pred + rfx_pred + rfx_pred = np.squeeze(rfx_model.predict(rfx_dataset_train, rfx_tracker)) + partial_resid_train = partial_resid_train - rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) - partial_resid = np.squeeze(resid_train - partial_outcome_pred) s_ty0 = np.sum( - tau_x * partial_resid * (np.squeeze(Z_train) == 0) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 0) ) s_ty1 = np.sum( - tau_x * partial_resid * (np.squeeze(Z_train) == 1) + tau_x * partial_resid_train * (np.squeeze(Z_train) == 1) ) current_b_0 = self.rng.normal( loc=(s_ty0 / (s_tt0 + 2 * current_sigma2)), @@ -2459,8 +2459,6 @@ def predict( rfx_preds = ( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std ) - if predict_mean: - rfx_preds = np.mean(rfx_preds, axis=1) # Combine into y hat predictions if predict_y_hat and has_mu_forest and has_rfx: From 533f2a18885d6124f8be19a5e0753f5dd5aa2243 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 22 Oct 2025 16:57:35 -0500 Subject: [PATCH 38/53] Added demo / debug scripts for new python functionality --- demo/debug/bart_contrast_debug.py | 181 +++++++++++++++++++++ demo/debug/bcf_contrast_debug.py | 243 ++++++++++++++++++++++++++++ demo/debug/probit_bart_rfx_debug.py | 124 ++++++++++++++ demo/debug/probit_bcf_rfx_debug.py | 115 +++++++++++++ 4 files changed, 663 insertions(+) create mode 100644 demo/debug/bart_contrast_debug.py create mode 100644 demo/debug/bcf_contrast_debug.py create mode 100644 demo/debug/probit_bart_rfx_debug.py create mode 100644 demo/debug/probit_bcf_rfx_debug.py diff --git a/demo/debug/bart_contrast_debug.py b/demo/debug/bart_contrast_debug.py new file mode 100644 index 00000000..15ce5705 --- /dev/null +++ b/demo/debug/bart_contrast_debug.py @@ -0,0 +1,181 @@ +# Demo of contrast computation function for BART + +# Load libraries +from stochtree import BARTModel +from sklearn.model_selection import train_test_split +import numpy as np + +# Generate data +n = 500 +p = 5 +rng = np.random.default_rng(1234) +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +W = rng.normal(loc=0.0, scale=1.0, size=(n, 1)) +f_XW = np.where( + ((0 <= X[:, 0]) & (X[:, 0] < 0.25)), + -7.5 * W[:, 0], + np.where( + ((0.25 <= X[:, 0]) & (X[:, 0] < 0.5)), + -2.5 * W[:, 0], + np.where( + ((0.5 <= X[:, 0]) & (X[:, 0] < 0.75)), + 2.5 * W[:, 0], + 7.5 * W[:, 0], + ), + ), +) +E_Y = f_XW +snr = 2 +y = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_Y) / snr) + +# Train-test split +test_set_pct = 0.2 +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +W_train = W[train_inds, :] +W_test = W[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BART model +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + leaf_basis_train=W_train, + y_train=y_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, +) + +# Compute contrast posterior +contrast_posterior_test = bart_model.compute_contrast( + covariates_0=X_test, + covariates_1=X_test, + basis_0=np.zeros((n_test, 1)), + basis_1=np.ones((n_test, 1)), + type="posterior", + scale="linear", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bart_model.predict( + covariates=X_test, + basis=np.zeros((n_test, 1)), + type="posterior", + terms="y_hat", + scale="linear", +) +y_hat_posterior_test_1 = bart_model.predict( + covariates=X_test, + basis=np.ones((n_test, 1)), + type="posterior", + terms="y_hat", + scale="linear", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) + +# Generate data for a BART model with random effects +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +W = rng.normal(loc=0.0, scale=1.0, size=(n, 1)) +f_XW = np.where( + ((0 <= X[:, 0]) & (X[:, 0] < 0.25)), + -7.5 * W[:, 0], + np.where( + ((0.25 <= X[:, 0]) & (X[:, 0] < 0.5)), + -2.5 * W[:, 0], + np.where( + ((0.5 <= X[:, 0]) & (X[:, 0] < 0.75)), + 2.5 * W[:, 0], + 7.5 * W[:, 0], + ), + ), +) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_Y = f_XW + rfx_term +snr = 2 +y = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_Y) / snr) + +# Train-test split +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +W_train = W[train_inds, :] +W_test = W[test_inds, :] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BART model +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + leaf_basis_train=W_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, +) + +# Compute contrast posterior +contrast_posterior_test = bart_model.compute_contrast( + covariates_0=X_test, + covariates_1=X_test, + basis_0=np.zeros((n_test, 1)), + basis_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="linear", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bart_model.predict( + covariates=X_test, + basis=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +y_hat_posterior_test_1 = bart_model.predict( + covariates=X_test, + basis=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) diff --git a/demo/debug/bcf_contrast_debug.py b/demo/debug/bcf_contrast_debug.py new file mode 100644 index 00000000..8a98c33c --- /dev/null +++ b/demo/debug/bcf_contrast_debug.py @@ -0,0 +1,243 @@ +# Demo of contrast computation function for BCF + +# Load libraries +from stochtree import BCFModel +from sklearn.model_selection import train_test_split +from scipy.stats import norm +import numpy as np + +# Generate data +n = 500 +p = 5 +rng = np.random.default_rng(1234) +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +mu_x = X[:, 0] +tau_x = 0.25 * X[:, 1] +pi_x = norm.cdf(0.5 * X[:, 0]) +Z = rng.binomial(n=1, p=pi_x, size=(n,)) +E_XZ = mu_x + Z * tau_x +snr = 2 +y = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_XZ) / snr) + +# Train-test split +test_set_pct = 0.2 +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BCF model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, +) + +# Compute contrast posterior +contrast_posterior_test = bcf_model.compute_contrast( + X_0=X_test, + X_1=X_test, + Z_0=np.zeros((n_test, 1)), + Z_1=np.ones((n_test, 1)), + type="posterior", + scale="linear", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bcf_model.predict( + X=X_test, Z=np.zeros((n_test, 1)), type="posterior", terms="y_hat", scale="linear" +) +y_hat_posterior_test_1 = bcf_model.predict( + X=X_test, Z=np.ones((n_test, 1)), type="posterior", terms="y_hat", scale="linear" +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) + +# Generate data for a BCF model with random effects +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +mu_x = X[:, 0] +tau_x = 0.25 * X[:, 1] +pi_x = norm.cdf(0.5 * X[:, 0]) +Z = rng.binomial(n=1, p=pi_x, size=(n,)) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_XZ = mu_x + Z * tau_x + rfx_term +snr = 2 +y = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) * (np.std(E_XZ) / snr) + +# Train-test split +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BCF model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, +) + +# Compute contrast posterior +contrast_posterior_test = bcf_model.compute_contrast( + X_0=X_test, + X_1=X_test, + Z_0=np.zeros((n_test, 1)), + Z_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="linear", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bcf_model.predict( + X=X_test, + Z=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +y_hat_posterior_test_1 = bcf_model.predict( + X=X_test, + Z=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) + +# Generate data for a probit BCF model with random effects +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +mu_x = X[:, 0] +tau_x = 0.25 * X[:, 1] +pi_x = norm.cdf(0.5 * X[:, 0]) +Z = rng.binomial(n=1, p=pi_x, size=(n,)) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_XZ = mu_x + Z * tau_x + rfx_term +W = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) +y = (W > 0) * 1.0 + +# Train-test split +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +W_train = W[train_inds] +W_test = W[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BCF model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={"probit_outcome_model": True}, +) + +# Compute contrast posterior +contrast_posterior_test = bcf_model.compute_contrast( + X_0=X_test, + X_1=X_test, + Z_0=np.zeros((n_test, 1)), + Z_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="probability", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bcf_model.predict( + X=X_test, + Z=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="probability", +) +y_hat_posterior_test_1 = bcf_model.predict( + X=X_test, + Z=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="probability", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) diff --git a/demo/debug/probit_bart_rfx_debug.py b/demo/debug/probit_bart_rfx_debug.py new file mode 100644 index 00000000..ae2e8c10 --- /dev/null +++ b/demo/debug/probit_bart_rfx_debug.py @@ -0,0 +1,124 @@ +# Debuggin probit BCF with random for BART + +# Load libraries +from stochtree import BARTModel +from sklearn.model_selection import train_test_split +import numpy as np +import matplotlib.pyplot as plt + +# Generate data for a probit BART model with random effects +n = 500 +p = 5 +rng = np.random.default_rng(1234) +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +W = rng.normal(loc=0.0, scale=1.0, size=(n, 1)) +f_XW = np.where( + ((0 <= X[:, 0]) & (X[:, 0] < 0.25)), + -7.5 * W[:, 0], + np.where( + ((0.25 <= X[:, 0]) & (X[:, 0] < 0.5)), + -2.5 * W[:, 0], + np.where( + ((0.5 <= X[:, 0]) & (X[:, 0] < 0.75)), + 2.5 * W[:, 0], + 7.5 * W[:, 0], + ), + ), +) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_Y = f_XW + rfx_term +Z = E_Y + rng.normal(loc=0.0, scale=1.0, size=(n,)) +y = (Z > 0) * 1.0 + +# Train-test split +test_set_pct = 0.2 +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +W_train = W[train_inds, :] +W_test = W[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BART model +bart_model = BARTModel() +bart_model.sample( + X_train=X_train, + leaf_basis_train=W_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={"probit_outcome_model": True}, +) + +# Compute contrast posterior +contrast_posterior_test = bart_model.compute_contrast( + covariates_0=X_test, + covariates_1=X_test, + basis_0=np.zeros((n_test, 1)), + basis_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="linear", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bart_model.predict( + covariates=X_test, + basis=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +y_hat_posterior_test_1 = bart_model.predict( + covariates=X_test, + basis=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="linear", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) + +# Plot predicted versus actual outcome +Z_hat_test = bart_model.predict( + covariates=X_test, + basis=W_test, + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="y_hat", + scale="linear", +) +plt.scatter(Z_hat_test, Z_test, alpha=0.5) +plt.axline((0, 0), slope=1, color="red", linestyle="--") +plt.show() diff --git a/demo/debug/probit_bcf_rfx_debug.py b/demo/debug/probit_bcf_rfx_debug.py new file mode 100644 index 00000000..3036e509 --- /dev/null +++ b/demo/debug/probit_bcf_rfx_debug.py @@ -0,0 +1,115 @@ +# Debugging probit BCF with random effects + +# Load libraries +from stochtree import BCFModel +from scipy.stats import norm +from sklearn.model_selection import train_test_split +import numpy as np +import matplotlib.pyplot as plt + +# Generate data for a probit BCF model with random effects +n = 500 +p = 5 +rng = np.random.default_rng(1234) +X = rng.uniform(low=0.0, high=1.0, size=(n, p)) +mu_x = X[:, 0] +tau_x = 0.25 * X[:, 1] +pi_x = norm.cdf(0.5 * X[:, 0]) +Z = rng.binomial(n=1, p=pi_x, size=(n,)) +num_rfx_groups = 3 +group_labels = rng.choice(num_rfx_groups, size=n) +basis = np.empty((n, 2)) +basis[:, 0] = 1.0 +basis[:, 1] = rng.uniform(0, 1, (n,)) +rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) +E_XZ = mu_x + Z * tau_x + rfx_term +W = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) +y = (W > 0) * 1.0 + +# Train-test split +test_set_pct = 0.2 +train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 +) +X_train = X[train_inds, :] +X_test = X[test_inds, :] +Z_train = Z[train_inds] +Z_test = Z[test_inds] +W_train = W[train_inds] +W_test = W[test_inds] +y_train = y[train_inds] +y_test = y[test_inds] +group_ids_train = group_labels[train_inds] +group_ids_test = group_labels[test_inds] +rfx_basis_train = basis[train_inds, :] +rfx_basis_test = basis[test_inds, :] +n_test = len(test_inds) +n_train = len(train_inds) + +# Fit BCF model +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + rfx_basis_train=rfx_basis_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + general_params={"probit_outcome_model": True}, +) + +# Compute contrast posterior +contrast_posterior_test = bcf_model.compute_contrast( + X_0=X_test, + X_1=X_test, + Z_0=np.zeros((n_test, 1)), + Z_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=rfx_basis_test, + rfx_basis_1=rfx_basis_test, + type="posterior", + scale="probability", +) + +# Compute the same quantity via two predict calls +y_hat_posterior_test_0 = bcf_model.predict( + X=X_test, + Z=np.zeros((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="probability", +) +y_hat_posterior_test_1 = bcf_model.predict( + X=X_test, + Z=np.ones((n_test, 1)), + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="posterior", + terms="y_hat", + scale="probability", +) +contrast_posterior_test_comparison = y_hat_posterior_test_1 - y_hat_posterior_test_0 + +# Compare results +contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) + +# Plot predicted versus actual outcome +W_hat_test = bcf_model.predict( + X=X_test, + Z=Z_test, + rfx_group_ids=group_ids_test, + rfx_basis=rfx_basis_test, + type="mean", + terms="y_hat", + scale="linear", +) +plt.scatter(W_hat_test, W_test, alpha=0.5) +plt.axline((0, 0), slope=1, color="red", linestyle="--") +plt.show() From 97cb983105e9e3cd634047ed6615b4794d5cffcc Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 00:14:53 -0500 Subject: [PATCH 39/53] Updated RFX DGP in the probit BCF debug python script --- demo/debug/probit_bcf_rfx_debug.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/demo/debug/probit_bcf_rfx_debug.py b/demo/debug/probit_bcf_rfx_debug.py index 3036e509..9c2dbdfb 100644 --- a/demo/debug/probit_bcf_rfx_debug.py +++ b/demo/debug/probit_bcf_rfx_debug.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt # Generate data for a probit BCF model with random effects -n = 500 +n = 1000 p = 5 rng = np.random.default_rng(1234) X = rng.uniform(low=0.0, high=1.0, size=(n, p)) @@ -21,7 +21,7 @@ basis = np.empty((n, 2)) basis[:, 0] = 1.0 basis[:, 1] = rng.uniform(0, 1, (n,)) -rfx_coefs = np.array([[-2, -2], [0, 0], [2, 2]]) +rfx_coefs = np.array([[-1, 1], [0, 1], [1, 1]]) rfx_term = np.sum(rfx_coefs[group_labels, :] * basis, axis=1) E_XZ = mu_x + Z * tau_x + rfx_term W = E_XZ + rng.normal(loc=0.0, scale=1.0, size=(n,)) From d6e4480b3e710833fc1351a54ff4e378732645b0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 00:43:02 -0500 Subject: [PATCH 40/53] Added RFX spec argument to R bcf --- R/bcf.R | 100 +++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 74 insertions(+), 26 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 7c27cefe..6e8a420f 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -99,6 +99,7 @@ #' #' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. #' +#' - `model_spec` Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. #' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. #' - `group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. #' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. @@ -269,6 +270,7 @@ bcf <- function( # Update random effects parameters rfx_params_default <- list( + model_spec = "custom", working_parameter_prior_mean = NULL, group_parameter_prior_mean = NULL, working_parameter_prior_cov = NULL, @@ -348,6 +350,7 @@ bcf <- function( num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample # 5. Random effects parameters + rfx_model_spec <- rfx_params_updated$model_spec rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov @@ -769,20 +772,6 @@ bcf <- function( } } - # Random effects covariance prior - if (has_rfx) { - if (is.null(rfx_prior_var)) { - rfx_prior_var <- rep(1, ncol(rfx_basis_train)) - } else { - if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { - stop("rfx_prior_var must be a numeric vector") - } - if (length(rfx_prior_var) != ncol(rfx_basis_train)) { - stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") - } - } - } - # Update variable weights variable_weights_adj <- 1 / sapply(original_var_indices, function(x) sum(original_var_indices == x)) @@ -799,40 +788,74 @@ bcf <- function( ] <- 0 } - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + # Handle the rfx basis matrices has_basis_rfx <- FALSE num_basis_rfx <- 0 if (has_rfx) { - if (is.null(rfx_basis_train)) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_train)) { + stop( + "A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'" + ) + } + has_basis_rfx <- TRUE + num_basis_rfx <- ncol(rfx_basis_train) + } else if (rfx_model_spec == "intercept_only") { rfx_basis_train <- matrix( rep(1, nrow(X_train)), nrow = nrow(X_train), ncol = 1 ) - } else { has_basis_rfx <- TRUE - num_basis_rfx <- ncol(rfx_basis_train) + num_basis_rfx <- 1 + } else if (rfx_model_spec == "intercept_plus_treatment") { + rfx_basis_train <- cbind( + rep(1, nrow(X_train)), + Z_train + ) + has_basis_rfx <- TRUE + num_basis_rfx <- 1 + ncol(Z_train) } num_rfx_groups <- length(unique(rfx_group_ids_train)) num_rfx_components <- ncol(rfx_basis_train) if (num_rfx_groups == 1) { warning( - "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" ) } } if (has_rfx_test) { - if (is.null(rfx_basis_test)) { - if (!is.null(rfx_basis_train)) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_test)) { stop( - "Random effects basis provided for training set, must also be provided for the test set" + "A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'" ) } + } else if (rfx_model_spec == "intercept_only") { rfx_basis_test <- matrix( rep(1, nrow(X_test)), nrow = nrow(X_test), ncol = 1 ) + } else if (rfx_model_spec == "intercept_plus_treatment") { + rfx_basis_test <- cbind( + rep(1, nrow(X_test)), + Z_test + ) + } + } + + # Random effects covariance prior + if (has_rfx) { + if (is.null(rfx_prior_var)) { + rfx_prior_var <- rep(1, ncol(rfx_basis_train)) + } else { + if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) { + stop("rfx_prior_var must be a numeric vector") + } + if (length(rfx_prior_var) != ncol(rfx_basis_train)) { + stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)") + } } } @@ -2536,7 +2559,8 @@ bcf <- function( "sample_sigma2_global" = sample_sigma2_global, "sample_sigma2_leaf_mu" = sample_sigma2_leaf_mu, "sample_sigma2_leaf_tau" = sample_sigma2_leaf_tau, - "probit_outcome_model" = probit_outcome_model + "probit_outcome_model" = probit_outcome_model, + "rfx_model_spec" = rfx_model_spec ) result <- list( "forests_mu" = forest_samples_mu, @@ -2806,9 +2830,24 @@ predict.bcfmodel <- function( has_rfx <- TRUE } - # Produce basis for the "intercept-only" random effects case - if ((object$model_params$has_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(X)), ncol = 1) + # Handle RFX model specification + if (object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + } else if (object$model_params$rfx_model_spec == "intercept_only") { + rfx_basis <- matrix( + rep(1, nrow(X)), + nrow = nrow(X), + ncol = 1 + ) + } else if (object$model_params$rfx_model_spec == "intercept_plus_treatment") { + rfx_basis <- cbind( + rep(1, nrow(X)), + Z + ) } # Add propensities to covariate set if necessary @@ -3650,6 +3689,9 @@ createBCFModelFromJson <- function(json_object) { model_params[["probit_outcome_model"]] <- json_object$get_boolean( "probit_outcome_model" ) + model_params[["rfx_model_spec"]] <- json_object$get_string( + "rfx_model_spec" + ) output[["model_params"]] <- model_params # Unpack sampled parameters @@ -4069,6 +4111,9 @@ createBCFModelFromCombinedJson <- function(json_object_list) { model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( "probit_outcome_model" ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) # Combine values that are sample-specific for (i in 1:length(json_object_list)) { @@ -4423,6 +4468,9 @@ createBCFModelFromCombinedJsonString <- function(json_string_list) { model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( "probit_outcome_model" ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) # Combine values that are sample-specific for (i in 1:length(json_object_list)) { From bb09de508e325ef271a0714f625e1abb811b9ce8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 01:37:43 -0500 Subject: [PATCH 41/53] Handling RFX specifications in the predict BCF method --- R/bcf.R | 77 +++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 6e8a420f..bfcd7774 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2720,6 +2720,10 @@ predict.bcfmodel <- function( predict_mean <- type == "mean" # Handle prediction terms + rfx_model_spec = object$model_params$rfx_model_spec + rfx_intercept_only <- rfx_model_spec == "intercept_only" + rfx_intercept_plus_treatment <- (rfx_model_spec == "intercept_plus_treatment") + rfx_intercept <- rfx_intercept_only || rfx_intercept_plus_treatment if (!is.character(terms)) { stop("type must be a string or character vector") } @@ -2756,7 +2760,10 @@ predict.bcfmodel <- function( )) return(NULL) } - predict_rfx_intermediate <- (predict_y_hat && has_rfx) + predict_rfx_intermediate <- ((predict_y_hat && has_rfx)) + predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept_only) || + (predict_mu_forest && has_rfx && rfx_intercept_plus_treatment) || + (predict_tau_forest && has_rfx && rfx_intercept_plus_treatment)) predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) @@ -2887,30 +2894,33 @@ predict.bcfmodel <- function( # Compute mu forest predictions if (predict_mu_forest || predict_mu_forest_intermediate) { - mu_hat <- object$forests_mu$predict(forest_dataset_pred) * y_std + y_bar + mu_hat_forest <- object$forests_mu$predict(forest_dataset_pred) * + y_std + + y_bar } # Compute CATE forest predictions if (predict_tau_forest || predict_tau_forest_intermediate) { if (object$model_params$adaptive_coding) { tau_hat_raw <- object$forests_tau$predict_raw(forest_dataset_pred) - tau_hat <- t( + tau_hat_forest <- t( t(tau_hat_raw) * (object$b_1_samples - object$b_0_samples) ) * y_std } else { - tau_hat <- object$forests_tau$predict_raw(forest_dataset_pred) * y_std + tau_hat_forest <- object$forests_tau$predict_raw(forest_dataset_pred) * + y_std } if (object$model_params$multivariate_treatment) { - tau_dim <- dim(tau_hat) + tau_dim <- dim(tau_hat_forest) tau_num_obs <- tau_dim[1] tau_num_samples <- tau_dim[3] treatment_term <- matrix(NA_real_, nrow = tau_num_obs, tau_num_samples) for (i in 1:nrow(Z)) { - treatment_term[i, ] <- colSums(tau_hat[i, , ] * Z[i, ]) + treatment_term[i, ] <- colSums(tau_hat_forest[i, , ] * Z[i, ]) } } else { - treatment_term <- tau_hat * as.numeric(Z) + treatment_term <- tau_hat_forest * as.numeric(Z) } } @@ -2923,6 +2933,41 @@ predict.bcfmodel <- function( y_std } + # Extract "raw" rfx coefficients for each rfx basis term if needed + if (predict_rfx_raw) { + # Extract the raw RFX samples and scale by train set outcome sd + rfx_param_list <- object$rfx_samples$extract_parameter_samples() + rfx_beta_draws <- rfx_param_list$beta_samples * + object$model_params$outcome_scale + + # Construct a matrix with the correct random effects + rfx_predictions_raw <- array( + NA, + dim = c( + nrow(X), + ncol(rfx_basis), + object$model_params$num_samples + ) + ) + for (i in 1:nrow(X)) { + rfx_predictions_raw[i, , ] <- + rfx_beta_draws[, rfx_group_ids[i], ] + } + + # Add these RFX predictions to mu and tau if warranted by the RFX model spec + if (predict_mu_forest && rfx_intercept) { + mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ] + } else { + mu_hat_final <- mu_hat_forest + } + if (predict_tau_forest && rfx_intercept_plus_treatment) { + tau_hat_final <- (tau_hat_forest + + rfx_predictions_raw[, 2:ncol(rfx_basis), ]) + } else { + tau_hat_final <- tau_hat_forest + } + } + # Combine into y hat predictions needs_mean_term_preds <- predict_y_hat || predict_mu_forest || @@ -2932,32 +2977,38 @@ predict.bcfmodel <- function( if (probability_scale) { if (has_rfx) { if (predict_y_hat) { - y_hat <- pnorm(mu_hat + treatment_term + rfx_predictions) + y_hat <- pnorm(mu_hat_forest + treatment_term + rfx_predictions) } if (predict_rfx) { rfx_predictions <- pnorm(rfx_predictions) } } else { if (predict_y_hat) { - y_hat <- pnorm(mu_hat + treatment_term) + y_hat <- pnorm(mu_hat_forest + treatment_term) } } if (predict_mu_forest) { - mu_hat <- pnorm(mu_hat) + mu_hat <- pnorm(mu_hat_final) } if (predict_tau_forest) { - tau_hat <- pnorm(tau_hat) + tau_hat <- pnorm(tau_hat_final) } } else { if (has_rfx) { if (predict_y_hat) { - y_hat <- mu_hat + treatment_term + rfx_predictions + y_hat <- mu_hat_forest + treatment_term + rfx_predictions } } else { if (predict_y_hat) { - y_hat <- mu_hat + treatment_term + y_hat <- mu_hat_forest + treatment_term } } + if (predict_mu_forest) { + mu_hat <- mu_hat_final + } + if (predict_tau_forest) { + tau_hat <- tau_hat_final + } } } From 691eed7fa15763a3610540b60dab49ae6db917f6 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 01:59:28 -0500 Subject: [PATCH 42/53] Updated predict method and added demo script to check the behavior --- R/bcf.R | 56 ++++++++++++++++++++++-------------- tools/debug/bcf_cate_debug.R | 55 +++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 21 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index bfcd7774..fb43dff0 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2628,7 +2628,7 @@ bcf <- function( #' @param rfx_group_ids (Optional) Test set group labels used for an additive random effects model. #' We do not currently support (but plan to in the near future), test set evaluation for group labels #' that were not in the training set. -#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. +#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". #' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "cate", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". @@ -2830,7 +2830,7 @@ predict.bcfmodel <- function( group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids) if (sum(is.na(group_ids_factor)) > 0) { stop( - "All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids" + "All random effect group labels provided in rfx_group_ids must have been present at sampling time" ) } rfx_group_ids <- as.integer(group_ids_factor) @@ -2838,23 +2838,33 @@ predict.bcfmodel <- function( } # Handle RFX model specification - if (object$model_params$rfx_model_spec == "custom") { - if (is.null(rfx_basis)) { - stop( - "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" - ) + if (has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + } else if (object$model_params$rfx_model_spec == "intercept_only") { + # Only construct a basis if user-provided basis missing + if (is.null(rfx_basis)) { + rfx_basis <- matrix( + rep(1, nrow(X)), + nrow = nrow(X), + ncol = 1 + ) + } + } else if ( + object$model_params$rfx_model_spec == "intercept_plus_treatment" + ) { + # Only construct a basis if user-provided basis missing + if (is.null(rfx_basis)) { + rfx_basis <- cbind( + rep(1, nrow(X)), + Z + ) + } } - } else if (object$model_params$rfx_model_spec == "intercept_only") { - rfx_basis <- matrix( - rep(1, nrow(X)), - nrow = nrow(X), - ncol = 1 - ) - } else if (object$model_params$rfx_model_spec == "intercept_plus_treatment") { - rfx_basis <- cbind( - rep(1, nrow(X)), - Z - ) } # Add propensities to covariate set if necessary @@ -2953,14 +2963,18 @@ predict.bcfmodel <- function( rfx_predictions_raw[i, , ] <- rfx_beta_draws[, rfx_group_ids[i], ] } + } - # Add these RFX predictions to mu and tau if warranted by the RFX model spec - if (predict_mu_forest && rfx_intercept) { + # Add raw RFX predictions to mu and tau if warranted by the RFX model spec + if (predict_mu_forest || predict_mu_forest_intermediate) { + if (rfx_intercept && predict_rfx_raw) { mu_hat_final <- mu_hat_forest + rfx_predictions_raw[, 1, ] } else { mu_hat_final <- mu_hat_forest } - if (predict_tau_forest && rfx_intercept_plus_treatment) { + } + if (predict_tau_forest || predict_tau_forest_intermediate) { + if (rfx_intercept_plus_treatment && predict_rfx_raw) { tau_hat_final <- (tau_hat_forest + rfx_predictions_raw[, 2:ncol(rfx_basis), ]) } else { diff --git a/tools/debug/bcf_cate_debug.R b/tools/debug/bcf_cate_debug.R index 8f5bcb9b..2803eb79 100644 --- a/tools/debug/bcf_cate_debug.R +++ b/tools/debug/bcf_cate_debug.R @@ -193,6 +193,61 @@ all( abs(tau_diff) < 0.001 ) +# Now repeat the same process but via random effects model spec +bcf_model <- bcf( + X_train = X_train, + Z_train = Z_train, + y_train = y_train, + propensity_train = pi_train, + rfx_group_ids_train = group_ids_train, + X_test = X_test, + Z_test = Z_test, + propensity_test = pi_test, + rfx_group_ids_test = group_ids_test, + num_gfr = 10, + num_burnin = 0, + num_mcmc = 1000, + random_effects_params = list( + model_spec = "intercept_plus_treatment" + ) +) + +# Compute CATE posterior +tau_hat_posterior_test <- compute_contrast_bcf_model( + bcf_model, + X_0 = X_test, + X_1 = X_test, + Z_0 = rep(0, n_test), + Z_1 = rep(1, n_test), + propensity_0 = pi_test, + propensity_1 = pi_test, + rfx_group_ids_0 = group_ids_test, + rfx_group_ids_1 = group_ids_test, + rfx_basis_0 = cbind(1, rep(0, n_test)), + rfx_basis_1 = cbind(1, rep(1, n_test)), + type = "posterior", + scale = "linear" +) + +# Compute the same quantity via predict +tau_hat_posterior_test_comparison <- predict( + bcf_model, + X = X_test, + Z = Z_test, + propensity = pi_test, + rfx_group_ids = group_ids_test, + rfx_basis = rfx_basis_test, + type = "posterior", + terms = "cate", + scale = "linear" +) + +# Compare results +tau_diff <- tau_hat_posterior_test_comparison - tau_hat_posterior_test +all( + abs(tau_diff) < 0.001 +) + # Generate data for a probit BCF model with random effects X <- matrix(rnorm(n * p), ncol = p) mu_x <- X[, 1] From a53e58d662521dfa580b22b83157214271f5ffb3 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 02:09:17 -0500 Subject: [PATCH 43/53] Updated predict and demo script --- R/bcf.R | 17 +++++++++-------- tools/debug/bcf_cate_debug.R | 3 +-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index fb43dff0..4b41d400 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -2808,15 +2808,16 @@ predict.bcfmodel <- function( ) } if ((object$model_params$has_rfx_basis) && (is.null(rfx_basis))) { - stop("Random effects basis (rfx_basis) must be provided for this model") + if (object$model_params$rfx_model_spec == "custom") { + stop("Random effects basis (rfx_basis) must be provided for this model") + } } - if ( - (object$model_params$num_rfx_basis > 0) && - (ncol(rfx_basis) != object$model_params$num_rfx_basis) - ) { - stop( - "Random effects basis has a different dimension than the basis used to train this model" - ) + if ((object$model_params$num_rfx_basis > 0) && (!is.null(rfx_basis))) { + if (ncol(rfx_basis) != object$model_params$num_rfx_basis) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) + } } # Preprocess covariates diff --git a/tools/debug/bcf_cate_debug.R b/tools/debug/bcf_cate_debug.R index 2803eb79..cf4f6b97 100644 --- a/tools/debug/bcf_cate_debug.R +++ b/tools/debug/bcf_cate_debug.R @@ -229,14 +229,13 @@ tau_hat_posterior_test <- compute_contrast_bcf_model( scale = "linear" ) -# Compute the same quantity via predict +# Compute the same quantity directly via predict tau_hat_posterior_test_comparison <- predict( bcf_model, X = X_test, Z = Z_test, propensity = pi_test, rfx_group_ids = group_ids_test, - rfx_basis = rfx_basis_test, type = "posterior", terms = "cate", scale = "linear" From ab6f0ec79d5696b41c9ca2ec7caac20576ee83c2 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 10:21:18 -0500 Subject: [PATCH 44/53] Updated serialization --- R/bcf.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/bcf.R b/R/bcf.R index 4b41d400..46943e1b 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -3402,6 +3402,10 @@ saveBCFModelToJson <- function(object) { object$rfx_unique_group_ids ) } + jsonobj$add_boolean( + "rfx_model_spec", + object$model_params$rfx_model_spec + ) # Add propensity model (if it exists) if (object$model_params$internal_propensity_model) { From 998ebc3e13b40a8ead32caec623d282e0ba3be27 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 10:37:36 -0500 Subject: [PATCH 45/53] Fixed R serialization bug --- R/bcf.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/bcf.R b/R/bcf.R index 46943e1b..3b75b781 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -3402,7 +3402,7 @@ saveBCFModelToJson <- function(object) { object$rfx_unique_group_ids ) } - jsonobj$add_boolean( + jsonobj$add_string( "rfx_model_spec", object$model_params$rfx_model_spec ) From 58d221c5ad935dd6f4546cf251757890dfb68bb7 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 15:30:31 -0500 Subject: [PATCH 46/53] Added rfx model_spec to python BCF interface and included demo script to test it --- R/bcf.R | 22 ++++- demo/debug/bcf_contrast_debug.py | 41 ++++++++ src/py_stochtree.cpp | 62 ++++++++++++ stochtree/bcf.py | 159 ++++++++++++++++++++++++------- stochtree/random_effects.py | 30 ++++++ 5 files changed, 277 insertions(+), 37 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 3b75b781..274899a6 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -358,6 +358,19 @@ bcf <- function( rfx_variance_prior_shape <- rfx_params_updated$variance_prior_shape rfx_variance_prior_scale <- rfx_params_updated$variance_prior_scale + # Handle random effects specification + if (!is.character(rfx_model_spec)) { + stop("rfx_model_spec must be a string or character vector") + } + if ( + !(rfx_model_spec %in% + c("custom", "intercept_only", "intercept_plus_treatment")) + ) { + stop( + "rfx_model_spec must either be 'custom', 'intercept_only', or 'intercept_plus_treatment'" + ) + } + # Set a function-scoped RNG if user provided a random seed custom_rng <- random_seed >= 0 if (custom_rng) { @@ -2760,9 +2773,8 @@ predict.bcfmodel <- function( )) return(NULL) } - predict_rfx_intermediate <- ((predict_y_hat && has_rfx)) - predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept_only) || - (predict_mu_forest && has_rfx && rfx_intercept_plus_treatment) || + predict_rfx_intermediate <- (predict_y_hat && has_rfx) + predict_rfx_raw <- ((predict_mu_forest && has_rfx && rfx_intercept) || (predict_tau_forest && has_rfx && rfx_intercept_plus_treatment)) predict_mu_forest_intermediate <- (predict_y_hat && has_mu_forest) predict_tau_forest_intermediate <- (predict_y_hat && has_tau_forest) @@ -2946,12 +2958,12 @@ predict.bcfmodel <- function( # Extract "raw" rfx coefficients for each rfx basis term if needed if (predict_rfx_raw) { - # Extract the raw RFX samples and scale by train set outcome sd + # Extract the raw RFX samples and scale by train set outcome standard deviation rfx_param_list <- object$rfx_samples$extract_parameter_samples() rfx_beta_draws <- rfx_param_list$beta_samples * object$model_params$outcome_scale - # Construct a matrix with the correct random effects + # Construct a matrix with the appropriate group random effects arranged for each observation rfx_predictions_raw <- array( NA, dim = c( diff --git a/demo/debug/bcf_contrast_debug.py b/demo/debug/bcf_contrast_debug.py index 8a98c33c..006780d7 100644 --- a/demo/debug/bcf_contrast_debug.py +++ b/demo/debug/bcf_contrast_debug.py @@ -153,6 +153,47 @@ contrast_diff = contrast_posterior_test_comparison - contrast_posterior_test np.allclose(contrast_diff, 0, atol=0.001) +# Now repeat the same process but via random effects model spec +bcf_model = BCFModel() +bcf_model.sample( + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + rfx_group_ids_train=group_ids_train, + num_gfr=10, + num_burnin=0, + num_mcmc=1000, + random_effects_params={"model_spec": "intercept_plus_treatment"}, +) + +# Compute CATE posterior +tau_hat_posterior_test = bcf_model.compute_contrast( + X_0=X_test, + X_1=X_test, + Z_0=np.zeros((n_test, 1)), + Z_1=np.ones((n_test, 1)), + rfx_group_ids_0=group_ids_test, + rfx_group_ids_1=group_ids_test, + rfx_basis_0=np.concatenate((np.ones((n_test, 1)), np.zeros((n_test, 1))), axis=1), + rfx_basis_1=np.ones((n_test, 2)), + type="posterior", + scale="linear", +) + +# Compute the same quantity via predict +tau_hat_posterior_test_comparison = bcf_model.predict( + X=X_test, + Z=Z_test, + rfx_group_ids=group_ids_test, + type="posterior", + terms="cate", + scale="linear", +) + +# Compare results +contrast_diff = tau_hat_posterior_test_comparison - tau_hat_posterior_test +np.allclose(contrast_diff, 0, atol=0.001) + # Generate data for a probit BCF model with random effects X = rng.uniform(low=0.0, high=1.0, size=(n, p)) mu_x = X[:, 0] diff --git a/src/py_stochtree.cpp b/src/py_stochtree.cpp index 66621d52..f5c30d7d 100644 --- a/src/py_stochtree.cpp +++ b/src/py_stochtree.cpp @@ -1418,6 +1418,64 @@ class RandomEffectsContainerCpp { int NumGroups() { return rfx_container_->NumGroups(); } + py::array_t GetBeta() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + int num_groups = rfx_container_->NumGroups(); + std::vector beta_raw = rfx_container_->GetBeta(); + auto result = py::array_t(py::detail::any_container({num_components, num_groups, num_samples})); + auto accessor = result.mutable_unchecked<3>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_groups; j++) { + for (int k = 0; k < num_samples; k++) { + accessor(i,j,k) = beta_raw[k*num_groups*num_components + j*num_components + i]; + } + } + } + return result; + } + py::array_t GetXi() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + int num_groups = rfx_container_->NumGroups(); + std::vector xi_raw = rfx_container_->GetXi(); + auto result = py::array_t(py::detail::any_container({num_components, num_groups, num_samples})); + auto accessor = result.mutable_unchecked<3>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_groups; j++) { + for (int k = 0; k < num_samples; k++) { + accessor(i,j,k) = xi_raw[k*num_groups*num_components + j*num_components + i]; + } + } + } + return result; + } + py::array_t GetAlpha() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + std::vector alpha_raw = rfx_container_->GetAlpha(); + auto result = py::array_t(py::detail::any_container({num_components, num_samples})); + auto accessor = result.mutable_unchecked<2>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_samples; j++) { + accessor(i,j) = alpha_raw[j*num_components + i]; + } + } + return result; + } + py::array_t GetSigma() { + int num_samples = rfx_container_->NumSamples(); + int num_components = rfx_container_->NumComponents(); + std::vector sigma_raw = rfx_container_->GetSigma(); + auto result = py::array_t(py::detail::any_container({num_components, num_samples})); + auto accessor = result.mutable_unchecked<2>(); + for (int i = 0; i < num_components; i++) { + for (int j = 0; j < num_samples; j++) { + accessor(i,j) = sigma_raw[j*num_components + i]; + } + } + return result; + } void DeleteSample(int sample_num) { rfx_container_->DeleteSample(sample_num); } @@ -2294,6 +2352,10 @@ PYBIND11_MODULE(stochtree_cpp, m) { .def("NumSamples", &RandomEffectsContainerCpp::NumSamples) .def("NumComponents", &RandomEffectsContainerCpp::NumComponents) .def("NumGroups", &RandomEffectsContainerCpp::NumGroups) + .def("GetBeta", &RandomEffectsContainerCpp::GetBeta) + .def("GetXi", &RandomEffectsContainerCpp::GetXi) + .def("GetAlpha", &RandomEffectsContainerCpp::GetAlpha) + .def("GetSigma", &RandomEffectsContainerCpp::GetSigma) .def("DeleteSample", &RandomEffectsContainerCpp::DeleteSample) .def("Predict", &RandomEffectsContainerCpp::Predict) .def("SaveToJsonFile", &RandomEffectsContainerCpp::SaveToJsonFile) diff --git a/stochtree/bcf.py b/stochtree/bcf.py index ddad0e85..7e67a11e 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -97,7 +97,7 @@ def sample( prognostic_forest_params: Optional[Dict[str, Any]] = None, treatment_effect_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, - rfx_params: Optional[Dict[str, Any]] = None, + random_effects_params: Optional[Dict[str, Any]] = None, ) -> None: """Runs a BCF sampler on provided training set. Outcome predictions and estimates of the prognostic and treatment effect functions will be cached for the training set and (if provided) the test set. @@ -211,6 +211,7 @@ def sample( rfx_params : dict, optional Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. + * `model_spec`: Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. * `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. * `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. * `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. @@ -308,6 +309,7 @@ def sample( # Update random effects parameters rfx_params_default = { + "model_spec": "custom", "working_parameter_prior_mean": None, "group_parameter_prior_mean": None, "working_parameter_prior_cov": None, @@ -315,7 +317,7 @@ def sample( "variance_prior_shape": 1.0, "variance_prior_scale": 1.0, } - rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params) + rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params) ### Unpack all parameter values # 1. General parameters @@ -394,6 +396,7 @@ def sample( ] # 5. Random effects parameters + self.rfx_model_spec = rfx_params_updated["model_spec"] rfx_working_parameter_prior_mean = rfx_params_updated[ "working_parameter_prior_mean" ] @@ -407,6 +410,12 @@ def sample( rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] + # Check random effects specification + if not isinstance(self.rfx_model_spec, str): + raise ValueError("rfx_model_spec must be a string") + if self.rfx_model_spec not in ["custom", "intercept_only", "intercept_plus_treatment"]: + raise ValueError("type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'") + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True @@ -1368,23 +1377,45 @@ def sample( "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" ) - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx = False + # Handle the rfx basis matrices + self.has_rfx_basis = False + self.num_rfx_basis = 0 if self.has_rfx: - if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) - else: - has_basis_rfx = True + if self.rfx_model_spec == "custom": + if rfx_basis_train is None: + raise ValueError( + "rfx_basis_train must be provided when rfx_model_spec = 'custom'" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_train is None: + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis_train is None: + rfx_basis_train = np.concatenate( + (np.ones((rfx_group_ids_train.shape[0], 1)), Z_train), axis=1 + ) + self.has_rfx_basis = True + self.num_rfx_basis = rfx_basis_train.shape[1] num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] num_rfx_components = rfx_basis_train.shape[1] - # TODO warn if num_rfx_groups is 1 + if num_rfx_groups == 1: + warnings.warn( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) if has_rfx_test: - if rfx_basis_test is None: - if has_basis_rfx: + if self.rfx_model_spec == "custom": + if rfx_basis_test is None: raise ValueError( - "Random effects basis provided for training set, must also be provided for the test set" + "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_test is None: + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) + elif self.rfx_model_spec == "intercept_plus_treatment": + if rfx_basis_test is None: + rfx_basis_test = np.concatenate( + (np.ones((rfx_group_ids_test.shape[0], 1)), Z_test), axis=1 ) - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) # Set up random effects structures if self.has_rfx: @@ -1821,7 +1852,9 @@ def sample( ) partial_resid_train = np.squeeze(resid_train - mu_x) if self.has_rfx: - rfx_pred = np.squeeze(rfx_model.predict(rfx_dataset_train, rfx_tracker)) + rfx_pred = np.squeeze( + rfx_model.predict(rfx_dataset_train, rfx_tracker) + ) partial_resid_train = partial_resid_train - rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) @@ -2027,7 +2060,9 @@ def sample( ) partial_resid_train = np.squeeze(resid_train - mu_x) if self.has_rfx: - rfx_pred = np.squeeze(rfx_model.predict(rfx_dataset_train, rfx_tracker)) + rfx_pred = np.squeeze( + rfx_model.predict(rfx_dataset_train, rfx_tracker) + ) partial_resid_train = partial_resid_train - rfx_pred s_tt0 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 0)) s_tt1 = np.sum(tau_x * tau_x * (np.squeeze(Z_train) == 1)) @@ -2257,9 +2292,11 @@ def predict( type: str = "posterior", terms: Union[list[str], str] = "all", scale: str = "linear", - ) -> dict: + ) -> Union[dict[str, np.array], np.array]: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. + When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or + partially included in mu_x and partially included in tau_x `rfx_model_spec == "intercept_plus_treatment"`. Parameters ---------- @@ -2272,7 +2309,7 @@ def predict( rfx_group_ids : np.array, optional Optional group labels used for an additive random effects model. rfx_basis : np.array, optional - Optional basis for "random-slope" regression in an additive random effects model. + Optional basis for "random-slope" regression in an additive random effects model. Not necessary if `rfx_model_spec` is "intercept_only" or "intercept_plus_treatment", but if rfx_basis is provided, it will supercede the basis implied by `rfx_model_spec`. type : str, optional Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". terms : str, optional @@ -2304,6 +2341,10 @@ def predict( predict_mean = type == "mean" # Handle prediction terms + rfx_model_spec = self.rfx_model_spec + rfx_intercept_only = rfx_model_spec == "intercept_only" + rfx_intercept_plus_treatment = rfx_model_spec == "intercept_plus_treatment" + rfx_intercept = rfx_intercept_only or rfx_intercept_plus_treatment if not isinstance(terms, str) and not isinstance(terms, list): raise ValueError("type must be a string or list of strings") num_terms = 1 if isinstance(terms, str) else len(terms) @@ -2339,6 +2380,9 @@ def predict( ) return None predict_rfx_intermediate = predict_y_hat and has_rfx + predict_rfx_raw = (predict_mu_forest and has_rfx and rfx_intercept) or ( + predict_tau_forest and has_rfx and rfx_intercept_plus_treatment + ) predict_mu_forest_intermediate = predict_y_hat and has_mu_forest predict_tau_forest_intermediate = predict_y_hat and has_tau_forest @@ -2434,7 +2478,7 @@ def predict( mu_raw = self.forest_container_mu.forest_container_cpp.Predict( forest_dataset_test.dataset_cpp ) - mu_x = mu_raw * self.y_std + self.y_bar + mu_x_forest = mu_raw * self.y_std + self.y_bar # Treatment effect forest predictions if predict_tau_forest or predict_tau_forest_intermediate: @@ -2446,13 +2490,27 @@ def predict( self.b1_samples - self.b0_samples, axis=(0, 2) ) tau_raw = tau_raw * adaptive_coding_weights - tau_x = np.squeeze(tau_raw * self.y_std) + tau_x_forest = np.squeeze(tau_raw * self.y_std) if Z.shape[1] > 1: treatment_term = np.multiply( - np.atleast_3d(Z).swapaxes(1, 2), tau_x + np.atleast_3d(Z).swapaxes(1, 2), tau_x_forest ).sum(axis=2) else: - treatment_term = Z * np.squeeze(tau_x) + treatment_term = Z * np.squeeze(tau_x_forest) + + # Random effects data checks + if has_rfx: + if rfx_group_ids is None: + raise ValueError( + "rfx_group_ids must be provided if rfx_basis is provided" + ) + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError("X and rfx_basis must have the same number of rows") + if rfx_basis.shape[1] != self.num_rfx_basis: + raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model") # Random effects predictions if predict_rfx or predict_rfx_intermediate: @@ -2460,14 +2518,33 @@ def predict( self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std ) - # Combine into y hat predictions - if predict_y_hat and has_mu_forest and has_rfx: - y_hat = mu_x + treatment_term + rfx_preds - elif predict_y_hat and has_mu_forest: - y_hat = mu_x + treatment_term - elif predict_y_hat and has_rfx: - y_hat = rfx_preds + # Extract "raw" rfx predictions for each rfx basis term if needed + if predict_rfx_raw: + # Extract the raw RFX samples and scale by train set outcome standard deviation + rfx_samples_raw = self.rfx_container.extract_parameter_samples() + rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std + + # Construct a matrix with the appropriate group random effects arranged for each observation + rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2])) + for i in range(X.shape[0]): + rfx_predictions_raw[i, :, :] = rfx_beta_draws[ + :, rfx_group_ids[i], : + ] + + # Add raw RFX predictions to mu and tau if warranted by the RFX model spec + if predict_mu_forest or predict_mu_forest_intermediate: + if rfx_intercept and predict_rfx_raw: + mu_x = mu_x_forest + np.squeeze(rfx_predictions_raw[:, 0, :]) + else: + mu_x = mu_x_forest + if predict_tau_forest or predict_tau_forest_intermediate: + if rfx_intercept_plus_treatment and predict_rfx_raw: + tau_x = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) + else: + tau_x = tau_x_forest + + # Combine into y hat predictions needs_mean_term_preds = ( predict_y_hat or predict_mu_forest or predict_tau_forest or predict_rfx ) @@ -2475,12 +2552,12 @@ def predict( if probability_scale: if has_rfx: if predict_y_hat: - y_hat = norm.cdf(mu_x + treatment_term + rfx_preds) + y_hat = norm.cdf(mu_x_forest + treatment_term + rfx_preds) if predict_rfx: rfx_preds = norm.cdf(rfx_preds) else: if predict_y_hat: - y_hat = norm.cdf(mu_x + treatment_term) + y_hat = norm.cdf(mu_x_forest + treatment_term) if predict_mu_forest: mu_x = norm.cdf(mu_x) if predict_tau_forest: @@ -2488,10 +2565,14 @@ def predict( else: if has_rfx: if predict_y_hat: - y_hat = mu_x + treatment_term + rfx_preds + y_hat = mu_x_forest + treatment_term + rfx_preds else: if predict_y_hat: - y_hat = mu_x + treatment_term + y_hat = mu_x_forest + treatment_term + if predict_mu_forest: + mu_x = mu_x + if predict_tau_forest: + tau_x = tau_x # Collapse to posterior mean predictions if requested if predict_mean: @@ -3018,6 +3099,9 @@ def to_json(self) -> str: bcf_json.add_boolean("sample_sigma2_leaf_tau", self.sample_sigma2_leaf_tau) bcf_json.add_boolean("include_variance_forest", self.include_variance_forest) bcf_json.add_boolean("has_rfx", self.has_rfx) + bcf_json.add_boolean("has_rfx_basis", self.has_rfx_basis) + bcf_json.add_scalar("num_rfx_basis", self.num_rfx_basis) + bcf_json.add_boolean("multivariate_treatment", self.multivariate_treatment) bcf_json.add_scalar("num_gfr", self.num_gfr) bcf_json.add_scalar("num_burnin", self.num_burnin) bcf_json.add_scalar("num_mcmc", self.num_mcmc) @@ -3028,6 +3112,7 @@ def to_json(self) -> str: "internal_propensity_model", self.internal_propensity_model ) bcf_json.add_boolean("probit_outcome_model", self.probit_outcome_model) + bcf_json.add_string("rfx_model_spec", self.rfx_model_spec) # Add parameter samples if self.sample_sigma2_global: @@ -3073,6 +3158,9 @@ def from_json(self, json_string: str) -> None: # Unpack forests self.include_variance_forest = bcf_json.get_boolean("include_variance_forest") self.has_rfx = bcf_json.get_boolean("has_rfx") + self.has_rfx_basis = bcf_json.get_boolean("has_rfx_basis") + self.num_rfx_basis = bcf_json.get_scalar("num_rfx_basis") + self.multivariate_treatment = bcf_json.get_boolean("multivariate_treatment") # TODO: don't just make this a placeholder that we overwrite self.forest_container_mu = ForestContainer(0, 0, False, False) self.forest_container_mu.forest_container_cpp.LoadFromJson( @@ -3113,6 +3201,7 @@ def from_json(self, json_string: str) -> None: "internal_propensity_model" ) self.probit_outcome_model = bcf_json.get_boolean("probit_outcome_model") + self.rfx_model_spec = bcf_json.get_string("rfx_model_spec") # Unpack parameter samples if self.sample_sigma2_global: @@ -3206,6 +3295,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # Unpack random effects self.has_rfx = json_object_default.get_boolean("has_rfx") + self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") + self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") + self.multivariate_treatment = json_object_default.get_boolean("multivariate_treatment") if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -3241,6 +3333,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.probit_outcome_model = json_object_default.get_boolean( "probit_outcome_model" ) + self.rfx_model_spec = json_object_default.get_string( + "rfx_model_spec" + ) # Unpack number of samples for i in range(len(json_object_list)): diff --git a/stochtree/random_effects.py b/stochtree/random_effects.py index c93fa7e9..f3e1a187 100644 --- a/stochtree/random_effects.py +++ b/stochtree/random_effects.py @@ -367,6 +367,36 @@ def predict(self, group_labels: np.array, basis: np.array) -> np.ndarray: return self.rfx_container_cpp.Predict( rfx_dataset.rfx_dataset_cpp, self.rfx_label_mapper_cpp ) + + def extract_parameter_samples(self) -> dict[str, np.ndarray]: + """ + Extract the random effects parameters sampled. With the "redundant parameterization" of Gelman et al (2008), + this includes four parameters: alpha (the "working parameter" shared across every group), xi + (the "group parameter" sampled separately for each group), beta (the product of alpha and xi, + which corresponds to the overall group-level random effects), and sigma (group-independent prior + variance for each component of xi). + + Returns + ------- + dict[str, np.ndarray] + dict of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. + The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`. + The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. + """ + # num_samples = self.rfx_container_cpp.NumSamples() + # num_components = self.rfx_container_cpp.NumComponents() + # num_groups = self.rfx_container_cpp.NumGroups() + beta_samples = np.squeeze(self.rfx_container_cpp.GetBeta()) + xi_samples = np.squeeze(self.rfx_container_cpp.GetXi()) + alpha_samples = np.squeeze(self.rfx_container_cpp.GetAlpha()) + sigma_samples = np.squeeze(self.rfx_container_cpp.GetSigma()) + output = { + "beta_samples": beta_samples, + "xi_samples": xi_samples, + "alpha_samples": alpha_samples, + "sigma_samples": sigma_samples + } + return output class RandomEffectsModel: From 33116f52a82f07f0824eb802c9256507adff6646 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 15:31:30 -0500 Subject: [PATCH 47/53] Formatted all R package code with air --- R/calibration.R | 40 +- R/config.R | 966 ++++++++--------- R/data.R | 592 +++++------ R/forest.R | 2528 ++++++++++++++++++++++---------------------- R/kernel.R | 458 ++++---- R/model.R | 601 ++++++----- R/random_effects.R | 904 ++++++++-------- R/serialization.R | 1248 +++++++++++----------- R/utils.R | 1416 ++++++++++++------------- R/variance.R | 24 +- 10 files changed, 4388 insertions(+), 4389 deletions(-) diff --git a/R/calibration.R b/R/calibration.R index cbcd293e..a82f84f4 100644 --- a/R/calibration.R +++ b/R/calibration.R @@ -22,25 +22,25 @@ #' sigma2hat <- mean(resid(lm(y~X))^2) #' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat) calibrateInverseGammaErrorVariance <- function( - y, - X, - W = NULL, - nu = 3, - quant = 0.9, - standardize = TRUE + y, + X, + W = NULL, + nu = 3, + quant = 0.9, + standardize = TRUE ) { - # Compute regression basis - if (!is.null(W)) { - basis <- cbind(X, W) - } else { - basis <- X - } - # Standardize outcome if requested - if (standardize) { - y <- (y - mean(y)) / sd(y) - } - # Compute the "regression-based" overestimate of sigma^2 - sigma2hat <- mean(resid(lm(y ~ basis))^2) - # Calibrate lambda based on the implied quantile of sigma2hat - return((sigma2hat * qgamma(1 - quant, nu)) / nu) + # Compute regression basis + if (!is.null(W)) { + basis <- cbind(X, W) + } else { + basis <- X + } + # Standardize outcome if requested + if (standardize) { + y <- (y - mean(y)) / sd(y) + } + # Compute the "regression-based" overestimate of sigma^2 + sigma2hat <- mean(resid(lm(y ~ basis))^2) + # Calibrate lambda based on the implied quantile of sigma2hat + return((sigma2hat * qgamma(1 - quant, nu)) / nu) } diff --git a/R/config.R b/R/config.R index bf4d51bc..8110a9f5 100644 --- a/R/config.R +++ b/R/config.R @@ -10,423 +10,423 @@ #' forest model they wish to run. ForestModelConfig <- R6::R6Class( - classname = "ForestModelConfig", - cloneable = FALSE, - public = list( - #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - feature_types = NULL, - - #' @field sweep_update_indices Vector of trees to update in a sweep - sweep_update_indices = NULL, - - #' @field num_trees Number of trees in the forest being sampled - num_trees = NULL, - - #' @field num_features Number of features in training dataset - num_features = NULL, - - #' @field num_observations Number of observations in training dataset - num_observations = NULL, - - #' @field leaf_dimension Dimension of the leaf model - leaf_dimension = NULL, - - #' @field alpha Root node split probability in tree prior - alpha = NULL, - - #' @field beta Depth prior penalty in tree prior - beta = NULL, - - #' @field min_samples_leaf Minimum number of samples in a tree leaf - min_samples_leaf = NULL, - - #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. - max_depth = NULL, - - #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) - leaf_model_type = NULL, - - #' @field leaf_model_scale Scale parameter used in Gaussian leaf models - leaf_model_scale = NULL, - - #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset - variable_weights = NULL, - - #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) - variance_forest_shape = NULL, - - #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) - variance_forest_scale = NULL, - - #' @field cutpoint_grid_size Number of unique cutpoints to consider - cutpoint_grid_size = NULL, - - #' @field num_features_subsample Number of features to subsample for the GFR algorithm - num_features_subsample = NULL, - - #' Create a new ForestModelConfig object. - #' - #' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep - #' @param num_trees Number of trees in the forest being sampled - #' @param num_features Number of features in training dataset - #' @param num_observations Number of observations in training dataset - #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset - #' @param leaf_dimension Dimension of the leaf model (default: `1`) - #' @param alpha Root node split probability in tree prior (default: `0.95`) - #' @param beta Depth prior penalty in tree prior (default: `2.0`) - #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) - #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. - #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `0`. - #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. - #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. - #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. - #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) - #' @param num_features_subsample Number of features to subsample for the GFR algorithm - #' - #' @return A new ForestModelConfig object. - initialize = function( - feature_types = NULL, - sweep_update_indices = NULL, - num_trees = NULL, - num_features = NULL, - num_observations = NULL, - variable_weights = NULL, - leaf_dimension = 1, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = -1, - leaf_model_type = 1, - leaf_model_scale = NULL, - variance_forest_shape = 1.0, - variance_forest_scale = 1.0, - cutpoint_grid_size = 100, - num_features_subsample = NULL - ) { - if (is.null(feature_types)) { - if (is.null(num_features)) { - stop( - "Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object." - ) - } - warning( - "`feature_types` not provided, will be assumed to be numeric" - ) - feature_types <- rep(0, num_features) - } else { - if (is.null(num_features)) { - num_features <- length(feature_types) - } - } - if (is.null(variable_weights)) { - warning( - "`variable_weights` not provided, will be assumed to be equal-weighted" - ) - variable_weights <- rep(1 / num_features, num_features) - } - if (is.null(num_trees)) { - stop("num_trees must be provided") - } - if (!is.null(sweep_update_indices)) { - stopifnot(min(sweep_update_indices) >= 0) - stopifnot(max(sweep_update_indices) < num_trees) - } - if (is.null(num_observations)) { - stop("num_observations must be provided") - } - if (num_features != length(feature_types)) { - stop("`feature_types` must have `num_features` total elements") - } - if (num_features != length(variable_weights)) { - stop( - "`variable_weights` must have `num_features` total elements" - ) - } - self$feature_types <- feature_types - self$sweep_update_indices <- sweep_update_indices - self$variable_weights <- variable_weights - self$num_trees <- num_trees - self$num_features <- num_features - self$num_observations <- num_observations - self$leaf_dimension <- leaf_dimension - self$alpha <- alpha - self$beta <- beta - self$min_samples_leaf <- min_samples_leaf - self$max_depth <- max_depth - self$variance_forest_shape <- variance_forest_shape - self$variance_forest_scale <- variance_forest_scale - self$cutpoint_grid_size <- cutpoint_grid_size - if (is.null(num_features_subsample)) { - num_features_subsample <- num_features - } - if (num_features_subsample > num_features) { - stop( - "`num_features_subsample` cannot be larger than `num_features`" - ) - } - if (num_features_subsample <= 0) { - stop("`num_features_subsample` must be at least 1") - } - self$num_features_subsample <- num_features_subsample - - if (!(as.integer(leaf_model_type) == leaf_model_type)) { - stop("`leaf_model_type` must be an integer between 0 and 3") - if ((leaf_model_type < 0) | (leaf_model_type > 3)) { - stop("`leaf_model_type` must be an integer between 0 and 3") - } - } - self$leaf_model_type <- leaf_model_type - - if (is.null(leaf_model_scale)) { - self$leaf_model_scale <- diag(1 / num_trees, leaf_dimension) - } else if (is.matrix(leaf_model_scale)) { - if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { - stop("`leaf_model_scale` must be a square matrix") - } - if (ncol(leaf_model_scale) != leaf_dimension) { - stop( - "`leaf_model_scale` must have `leaf_dimension` rows and columns" - ) - } - self$leaf_model_scale <- leaf_model_scale - } else { - if (leaf_model_scale <= 0) { - stop( - "`leaf_model_scale` must be positive, if provided as scalar" - ) - } - self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) - } - }, - - #' @description - #' Update feature types - #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - update_feature_types = function(feature_types) { - stopifnot(length(feature_types) == self$num_features) - self$feature_types <- feature_types - }, - - #' @description - #' Update sweep update indices - #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep - update_sweep_indices = function(sweep_update_indices) { - if (!is.null(sweep_update_indices)) { - stopifnot(min(sweep_update_indices) >= 0) - stopifnot(max(sweep_update_indices) < self$num_trees) - } - self$sweep_update_indices <- sweep_update_indices - }, - - #' @description - #' Update variable weights - #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset - update_variable_weights = function(variable_weights) { - stopifnot(length(variable_weights) == self$num_features) - self$variable_weights <- variable_weights - }, - - #' @description - #' Update root node split probability in tree prior - #' @param alpha Root node split probability in tree prior - update_alpha = function(alpha) { - self$alpha <- alpha - }, - - #' @description - #' Update depth prior penalty in tree prior - #' @param beta Depth prior penalty in tree prior - update_beta = function(beta) { - self$beta <- beta - }, - - #' @description - #' Update minimum number of samples per leaf node in the tree prior - #' @param min_samples_leaf Minimum number of samples in a tree leaf - update_min_samples_leaf = function(min_samples_leaf) { - self$min_samples_leaf <- min_samples_leaf - }, - - #' @description - #' Update max depth in the tree prior - #' @param max_depth Maximum depth of any tree in the ensemble in the model - update_max_depth = function(max_depth) { - self$max_depth <- max_depth - }, - - #' @description - #' Update scale parameter used in Gaussian leaf models - #' @param leaf_model_scale Scale parameter used in Gaussian leaf models - update_leaf_model_scale = function(leaf_model_scale) { - if (is.matrix(leaf_model_scale)) { - if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { - stop("`leaf_model_scale` must be a square matrix") - } - if (ncol(leaf_model_scale) != self$leaf_dimension) { - stop( - "`leaf_model_scale` must have `leaf_dimension` rows and columns" - ) - } - self$leaf_model_scale <- leaf_model_scale - } else { - if (leaf_model_scale <= 0) { - stop( - "`leaf_model_scale` must be positive, if provided as scalar" - ) - } - self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension) - } - }, - - #' @description - #' Update shape parameter for IG leaf models - #' @param variance_forest_shape Shape parameter for IG leaf models - update_variance_forest_shape = function(variance_forest_shape) { - self$variance_forest_shape <- variance_forest_shape - }, - - #' @description - #' Update scale parameter for IG leaf models - #' @param variance_forest_scale Scale parameter for IG leaf models - update_variance_forest_scale = function(variance_forest_scale) { - self$variance_forest_scale <- variance_forest_scale - }, - - #' @description - #' Update number of unique cutpoints to consider - #' @param cutpoint_grid_size Number of unique cutpoints to consider - update_cutpoint_grid_size = function(cutpoint_grid_size) { - self$cutpoint_grid_size <- cutpoint_grid_size - }, - - #' @description - #' Update number of features to subsample for the GFR algorithm - #' @param num_features_subsample Number of features to subsample for the GFR algorithm - update_num_features_subsample = function(num_features_subsample) { - if (num_features_subsample > self$num_features) { - stop( - "`num_features_subsample` cannot be larger than `num_features`" - ) - } - if (num_features_subsample <= 0) { - stop("`num_features_subsample` must at least 1") - } - self$num_features_subsample <- num_features_subsample - }, - - #' @description - #' Query feature types for this ForestModelConfig object - #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - get_feature_types = function() { - return(self$feature_types) - }, - - #' @description - #' Query sweep update indices for this ForestModelConfig object - #' @returns Vector of (0-indexed) indices of trees to update in a sweep - get_sweep_indices = function() { - return(self$sweep_update_indices) - }, - - #' @description - #' Query variable weights for this ForestModelConfig object - #' @returns Vector specifying sampling probability for all p covariates in ForestDataset - get_variable_weights = function() { - return(self$variable_weights) - }, - - #' @description - #' Query number of trees - #' @returns Number of trees in a forest - get_num_trees = function() { - return(self$num_trees) - }, - - #' @description - #' Query number of features - #' @returns Number of features in a forest model training set - get_num_features = function() { - return(self$num_features) - }, - - #' @description - #' Query number of observations - #' @returns Number of observations in a forest model training set - get_num_observations = function() { - return(self$num_observations) - }, - - #' @description - #' Query root node split probability in tree prior for this ForestModelConfig object - #' @returns Root node split probability in tree prior - get_alpha = function() { - return(self$alpha) - }, - - #' @description - #' Query depth prior penalty in tree prior for this ForestModelConfig object - #' @returns Depth prior penalty in tree prior - get_beta = function() { - return(self$beta) - }, - - #' @description - #' Query root node split probability in tree prior for this ForestModelConfig object - #' @returns Minimum number of samples in a tree leaf - get_min_samples_leaf = function() { - return(self$min_samples_leaf) - }, - - #' @description - #' Query root node split probability in tree prior for this ForestModelConfig object - #' @returns Maximum depth of any tree in the ensemble in the model - get_max_depth = function() { - return(self$max_depth) - }, - - #' @description - #' Query (integer-coded) type of leaf model - #' @returns Integer coded leaf model type - get_leaf_model_type = function() { - return(self$leaf_model_type) - }, - - #' @description - #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object - #' @returns Scale parameter used in Gaussian leaf models - get_leaf_model_scale = function() { - return(self$leaf_model_scale) - }, - - #' @description - #' Query shape parameter for IG leaf models for this ForestModelConfig object - #' @returns Shape parameter for IG leaf models - get_variance_forest_shape = function() { - return(self$variance_forest_shape) - }, - - #' @description - #' Query scale parameter for IG leaf models for this ForestModelConfig object - #' @returns Scale parameter for IG leaf models - get_variance_forest_scale = function() { - return(self$variance_forest_scale) - }, - - #' @description - #' Query number of unique cutpoints to consider for this ForestModelConfig object - #' @returns Number of unique cutpoints to consider - get_cutpoint_grid_size = function() { - return(self$cutpoint_grid_size) - }, - - #' @description - #' Query number of features to subsample for the GFR algorithm - #' @returns Number of features to subsample for the GFR algorithm - get_num_features_subsample = function() { - return(self$num_features_subsample) + classname = "ForestModelConfig", + cloneable = FALSE, + public = list( + #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + feature_types = NULL, + + #' @field sweep_update_indices Vector of trees to update in a sweep + sweep_update_indices = NULL, + + #' @field num_trees Number of trees in the forest being sampled + num_trees = NULL, + + #' @field num_features Number of features in training dataset + num_features = NULL, + + #' @field num_observations Number of observations in training dataset + num_observations = NULL, + + #' @field leaf_dimension Dimension of the leaf model + leaf_dimension = NULL, + + #' @field alpha Root node split probability in tree prior + alpha = NULL, + + #' @field beta Depth prior penalty in tree prior + beta = NULL, + + #' @field min_samples_leaf Minimum number of samples in a tree leaf + min_samples_leaf = NULL, + + #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. + max_depth = NULL, + + #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) + leaf_model_type = NULL, + + #' @field leaf_model_scale Scale parameter used in Gaussian leaf models + leaf_model_scale = NULL, + + #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + variable_weights = NULL, + + #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_shape = NULL, + + #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_scale = NULL, + + #' @field cutpoint_grid_size Number of unique cutpoints to consider + cutpoint_grid_size = NULL, + + #' @field num_features_subsample Number of features to subsample for the GFR algorithm + num_features_subsample = NULL, + + #' Create a new ForestModelConfig object. + #' + #' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep + #' @param num_trees Number of trees in the forest being sampled + #' @param num_features Number of features in training dataset + #' @param num_observations Number of observations in training dataset + #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + #' @param leaf_dimension Dimension of the leaf model (default: `1`) + #' @param alpha Root node split probability in tree prior (default: `0.95`) + #' @param beta Depth prior penalty in tree prior (default: `2.0`) + #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) + #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. + #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `0`. + #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. + #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. + #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. + #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) + #' @param num_features_subsample Number of features to subsample for the GFR algorithm + #' + #' @return A new ForestModelConfig object. + initialize = function( + feature_types = NULL, + sweep_update_indices = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1.0, + variance_forest_scale = 1.0, + cutpoint_grid_size = 100, + num_features_subsample = NULL + ) { + if (is.null(feature_types)) { + if (is.null(num_features)) { + stop( + "Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object." + ) + } + warning( + "`feature_types` not provided, will be assumed to be numeric" + ) + feature_types <- rep(0, num_features) + } else { + if (is.null(num_features)) { + num_features <- length(feature_types) + } + } + if (is.null(variable_weights)) { + warning( + "`variable_weights` not provided, will be assumed to be equal-weighted" + ) + variable_weights <- rep(1 / num_features, num_features) + } + if (is.null(num_trees)) { + stop("num_trees must be provided") + } + if (!is.null(sweep_update_indices)) { + stopifnot(min(sweep_update_indices) >= 0) + stopifnot(max(sweep_update_indices) < num_trees) + } + if (is.null(num_observations)) { + stop("num_observations must be provided") + } + if (num_features != length(feature_types)) { + stop("`feature_types` must have `num_features` total elements") + } + if (num_features != length(variable_weights)) { + stop( + "`variable_weights` must have `num_features` total elements" + ) + } + self$feature_types <- feature_types + self$sweep_update_indices <- sweep_update_indices + self$variable_weights <- variable_weights + self$num_trees <- num_trees + self$num_features <- num_features + self$num_observations <- num_observations + self$leaf_dimension <- leaf_dimension + self$alpha <- alpha + self$beta <- beta + self$min_samples_leaf <- min_samples_leaf + self$max_depth <- max_depth + self$variance_forest_shape <- variance_forest_shape + self$variance_forest_scale <- variance_forest_scale + self$cutpoint_grid_size <- cutpoint_grid_size + if (is.null(num_features_subsample)) { + num_features_subsample <- num_features + } + if (num_features_subsample > num_features) { + stop( + "`num_features_subsample` cannot be larger than `num_features`" + ) + } + if (num_features_subsample <= 0) { + stop("`num_features_subsample` must be at least 1") + } + self$num_features_subsample <- num_features_subsample + + if (!(as.integer(leaf_model_type) == leaf_model_type)) { + stop("`leaf_model_type` must be an integer between 0 and 3") + if ((leaf_model_type < 0) | (leaf_model_type > 3)) { + stop("`leaf_model_type` must be an integer between 0 and 3") + } + } + self$leaf_model_type <- leaf_model_type + + if (is.null(leaf_model_scale)) { + self$leaf_model_scale <- diag(1 / num_trees, leaf_dimension) + } else if (is.matrix(leaf_model_scale)) { + if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { + stop("`leaf_model_scale` must be a square matrix") + } + if (ncol(leaf_model_scale) != leaf_dimension) { + stop( + "`leaf_model_scale` must have `leaf_dimension` rows and columns" + ) + } + self$leaf_model_scale <- leaf_model_scale + } else { + if (leaf_model_scale <= 0) { + stop( + "`leaf_model_scale` must be positive, if provided as scalar" + ) + } + self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) + } + }, + + #' @description + #' Update feature types + #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + update_feature_types = function(feature_types) { + stopifnot(length(feature_types) == self$num_features) + self$feature_types <- feature_types + }, + + #' @description + #' Update sweep update indices + #' @param sweep_update_indices Vector of (0-indexed) indices of trees to update in a sweep + update_sweep_indices = function(sweep_update_indices) { + if (!is.null(sweep_update_indices)) { + stopifnot(min(sweep_update_indices) >= 0) + stopifnot(max(sweep_update_indices) < self$num_trees) + } + self$sweep_update_indices <- sweep_update_indices + }, + + #' @description + #' Update variable weights + #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + update_variable_weights = function(variable_weights) { + stopifnot(length(variable_weights) == self$num_features) + self$variable_weights <- variable_weights + }, + + #' @description + #' Update root node split probability in tree prior + #' @param alpha Root node split probability in tree prior + update_alpha = function(alpha) { + self$alpha <- alpha + }, + + #' @description + #' Update depth prior penalty in tree prior + #' @param beta Depth prior penalty in tree prior + update_beta = function(beta) { + self$beta <- beta + }, + + #' @description + #' Update minimum number of samples per leaf node in the tree prior + #' @param min_samples_leaf Minimum number of samples in a tree leaf + update_min_samples_leaf = function(min_samples_leaf) { + self$min_samples_leaf <- min_samples_leaf + }, + + #' @description + #' Update max depth in the tree prior + #' @param max_depth Maximum depth of any tree in the ensemble in the model + update_max_depth = function(max_depth) { + self$max_depth <- max_depth + }, + + #' @description + #' Update scale parameter used in Gaussian leaf models + #' @param leaf_model_scale Scale parameter used in Gaussian leaf models + update_leaf_model_scale = function(leaf_model_scale) { + if (is.matrix(leaf_model_scale)) { + if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { + stop("`leaf_model_scale` must be a square matrix") + } + if (ncol(leaf_model_scale) != self$leaf_dimension) { + stop( + "`leaf_model_scale` must have `leaf_dimension` rows and columns" + ) } - ) + self$leaf_model_scale <- leaf_model_scale + } else { + if (leaf_model_scale <= 0) { + stop( + "`leaf_model_scale` must be positive, if provided as scalar" + ) + } + self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension) + } + }, + + #' @description + #' Update shape parameter for IG leaf models + #' @param variance_forest_shape Shape parameter for IG leaf models + update_variance_forest_shape = function(variance_forest_shape) { + self$variance_forest_shape <- variance_forest_shape + }, + + #' @description + #' Update scale parameter for IG leaf models + #' @param variance_forest_scale Scale parameter for IG leaf models + update_variance_forest_scale = function(variance_forest_scale) { + self$variance_forest_scale <- variance_forest_scale + }, + + #' @description + #' Update number of unique cutpoints to consider + #' @param cutpoint_grid_size Number of unique cutpoints to consider + update_cutpoint_grid_size = function(cutpoint_grid_size) { + self$cutpoint_grid_size <- cutpoint_grid_size + }, + + #' @description + #' Update number of features to subsample for the GFR algorithm + #' @param num_features_subsample Number of features to subsample for the GFR algorithm + update_num_features_subsample = function(num_features_subsample) { + if (num_features_subsample > self$num_features) { + stop( + "`num_features_subsample` cannot be larger than `num_features`" + ) + } + if (num_features_subsample <= 0) { + stop("`num_features_subsample` must at least 1") + } + self$num_features_subsample <- num_features_subsample + }, + + #' @description + #' Query feature types for this ForestModelConfig object + #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + get_feature_types = function() { + return(self$feature_types) + }, + + #' @description + #' Query sweep update indices for this ForestModelConfig object + #' @returns Vector of (0-indexed) indices of trees to update in a sweep + get_sweep_indices = function() { + return(self$sweep_update_indices) + }, + + #' @description + #' Query variable weights for this ForestModelConfig object + #' @returns Vector specifying sampling probability for all p covariates in ForestDataset + get_variable_weights = function() { + return(self$variable_weights) + }, + + #' @description + #' Query number of trees + #' @returns Number of trees in a forest + get_num_trees = function() { + return(self$num_trees) + }, + + #' @description + #' Query number of features + #' @returns Number of features in a forest model training set + get_num_features = function() { + return(self$num_features) + }, + + #' @description + #' Query number of observations + #' @returns Number of observations in a forest model training set + get_num_observations = function() { + return(self$num_observations) + }, + + #' @description + #' Query root node split probability in tree prior for this ForestModelConfig object + #' @returns Root node split probability in tree prior + get_alpha = function() { + return(self$alpha) + }, + + #' @description + #' Query depth prior penalty in tree prior for this ForestModelConfig object + #' @returns Depth prior penalty in tree prior + get_beta = function() { + return(self$beta) + }, + + #' @description + #' Query root node split probability in tree prior for this ForestModelConfig object + #' @returns Minimum number of samples in a tree leaf + get_min_samples_leaf = function() { + return(self$min_samples_leaf) + }, + + #' @description + #' Query root node split probability in tree prior for this ForestModelConfig object + #' @returns Maximum depth of any tree in the ensemble in the model + get_max_depth = function() { + return(self$max_depth) + }, + + #' @description + #' Query (integer-coded) type of leaf model + #' @returns Integer coded leaf model type + get_leaf_model_type = function() { + return(self$leaf_model_type) + }, + + #' @description + #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object + #' @returns Scale parameter used in Gaussian leaf models + get_leaf_model_scale = function() { + return(self$leaf_model_scale) + }, + + #' @description + #' Query shape parameter for IG leaf models for this ForestModelConfig object + #' @returns Shape parameter for IG leaf models + get_variance_forest_shape = function() { + return(self$variance_forest_shape) + }, + + #' @description + #' Query scale parameter for IG leaf models for this ForestModelConfig object + #' @returns Scale parameter for IG leaf models + get_variance_forest_scale = function() { + return(self$variance_forest_scale) + }, + + #' @description + #' Query number of unique cutpoints to consider for this ForestModelConfig object + #' @returns Number of unique cutpoints to consider + get_cutpoint_grid_size = function() { + return(self$cutpoint_grid_size) + }, + + #' @description + #' Query number of features to subsample for the GFR algorithm + #' @returns Number of features to subsample for the GFR algorithm + get_num_features_subsample = function() { + return(self$num_features_subsample) + } + ) ) #' Object used to get / set global parameters and other global model @@ -441,35 +441,35 @@ ForestModelConfig <- R6::R6Class( #' of a model they wish to run. GlobalModelConfig <- R6::R6Class( - classname = "GlobalModelConfig", - cloneable = FALSE, - public = list( - #' @field global_error_variance Global error variance parameter - global_error_variance = NULL, - - #' Create a new GlobalModelConfig object. - #' - #' @param global_error_variance Global error variance parameter (default: `1.0`) - #' - #' @return A new GlobalModelConfig object. - initialize = function(global_error_variance = 1.0) { - self$global_error_variance <- global_error_variance - }, - - #' @description - #' Update global error variance parameter - #' @param global_error_variance Global error variance parameter - update_global_error_variance = function(global_error_variance) { - self$global_error_variance <- global_error_variance - }, - - #' @description - #' Query global error variance parameter for this GlobalModelConfig object - #' @returns Global error variance parameter - get_global_error_variance = function() { - return(self$global_error_variance) - } - ) + classname = "GlobalModelConfig", + cloneable = FALSE, + public = list( + #' @field global_error_variance Global error variance parameter + global_error_variance = NULL, + + #' Create a new GlobalModelConfig object. + #' + #' @param global_error_variance Global error variance parameter (default: `1.0`) + #' + #' @return A new GlobalModelConfig object. + initialize = function(global_error_variance = 1.0) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Update global error variance parameter + #' @param global_error_variance Global error variance parameter + update_global_error_variance = function(global_error_variance) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Query global error variance parameter for this GlobalModelConfig object + #' @returns Global error variance parameter + get_global_error_variance = function() { + return(self$global_error_variance) + } + ) ) #' Create a forest model config object @@ -497,45 +497,45 @@ GlobalModelConfig <- R6::R6Class( #' @examples #' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) createForestModelConfig <- function( - feature_types = NULL, - sweep_update_indices = NULL, - num_trees = NULL, - num_features = NULL, - num_observations = NULL, - variable_weights = NULL, - leaf_dimension = 1, - alpha = 0.95, - beta = 2.0, - min_samples_leaf = 5, - max_depth = -1, - leaf_model_type = 1, - leaf_model_scale = NULL, - variance_forest_shape = 1.0, - variance_forest_scale = 1.0, - cutpoint_grid_size = 100, - num_features_subsample = NULL + feature_types = NULL, + sweep_update_indices = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2.0, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1.0, + variance_forest_scale = 1.0, + cutpoint_grid_size = 100, + num_features_subsample = NULL ) { - return(invisible( - (ForestModelConfig$new( - feature_types, - sweep_update_indices, - num_trees, - num_features, - num_observations, - variable_weights, - leaf_dimension, - alpha, - beta, - min_samples_leaf, - max_depth, - leaf_model_type, - leaf_model_scale, - variance_forest_shape, - variance_forest_scale, - cutpoint_grid_size, - num_features_subsample - )) + return(invisible( + (ForestModelConfig$new( + feature_types, + sweep_update_indices, + num_trees, + num_features, + num_observations, + variable_weights, + leaf_dimension, + alpha, + beta, + min_samples_leaf, + max_depth, + leaf_model_type, + leaf_model_scale, + variance_forest_shape, + variance_forest_scale, + cutpoint_grid_size, + num_features_subsample )) + )) } #' Create a global model config object @@ -547,5 +547,5 @@ createForestModelConfig <- function( #' @examples #' config <- createGlobalModelConfig(global_error_variance = 100) createGlobalModelConfig <- function(global_error_variance = 1.0) { - return(invisible((GlobalModelConfig$new(global_error_variance)))) + return(invisible((GlobalModelConfig$new(global_error_variance)))) } diff --git a/R/data.R b/R/data.R index 13cd714f..5fcdeb4d 100644 --- a/R/data.R +++ b/R/data.R @@ -6,110 +6,110 @@ #' weights are optional. ForestDataset <- R6::R6Class( - classname = "ForestDataset", - cloneable = FALSE, - public = list( - #' @field data_ptr External pointer to a C++ ForestDataset class - data_ptr = NULL, - - #' @description - #' Create a new ForestDataset object. - #' @param covariates Matrix of covariates - #' @param basis (Optional) Matrix of bases used to define a leaf regression - #' @param variance_weights (Optional) Vector of observation-specific variance weights - #' @return A new `ForestDataset` object. - initialize = function( - covariates, - basis = NULL, - variance_weights = NULL - ) { - self$data_ptr <- create_forest_dataset_cpp() - forest_dataset_add_covariates_cpp(self$data_ptr, covariates) - if (!is.null(basis)) { - forest_dataset_add_basis_cpp(self$data_ptr, basis) - } - if (!is.null(variance_weights)) { - forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) - } - }, - - #' @description - #' Update basis matrix in a dataset - #' @param basis Updated matrix of bases used to define a leaf regression - update_basis = function(basis) { - stopifnot(self$has_basis()) - forest_dataset_update_basis_cpp(self$data_ptr, basis) - }, - - #' @description - #' Update variance_weights in a dataset - #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights - #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. - update_variance_weights = function(variance_weights, exponentiate = F) { - stopifnot(self$has_variance_weights()) - forest_dataset_update_var_weights_cpp( - self$data_ptr, - variance_weights, - exponentiate - ) - }, - - #' @description - #' Return number of observations in a `ForestDataset` object - #' @return Observation count - num_observations = function() { - return(dataset_num_rows_cpp(self$data_ptr)) - }, - - #' @description - #' Return number of covariates in a `ForestDataset` object - #' @return Covariate count - num_covariates = function() { - return(dataset_num_covariates_cpp(self$data_ptr)) - }, - - #' @description - #' Return number of bases in a `ForestDataset` object - #' @return Basis count - num_basis = function() { - return(dataset_num_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return covariates as an R matrix - #' @return Covariate data - get_covariates = function() { - return(forest_dataset_get_covariates_cpp(self$data_ptr)) - }, - - #' @description - #' Return bases as an R matrix - #' @return Basis data - get_basis = function() { - return(forest_dataset_get_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return variance weights as an R vector - #' @return Variance weight data - get_variance_weights = function() { - return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has a basis matrix - #' @return True if basis matrix is loaded, false otherwise - has_basis = function() { - return(dataset_has_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has variance weights - #' @return True if variance weights are loaded, false otherwise - has_variance_weights = function() { - return(dataset_has_variance_weights_cpp(self$data_ptr)) - } - ) + classname = "ForestDataset", + cloneable = FALSE, + public = list( + #' @field data_ptr External pointer to a C++ ForestDataset class + data_ptr = NULL, + + #' @description + #' Create a new ForestDataset object. + #' @param covariates Matrix of covariates + #' @param basis (Optional) Matrix of bases used to define a leaf regression + #' @param variance_weights (Optional) Vector of observation-specific variance weights + #' @return A new `ForestDataset` object. + initialize = function( + covariates, + basis = NULL, + variance_weights = NULL + ) { + self$data_ptr <- create_forest_dataset_cpp() + forest_dataset_add_covariates_cpp(self$data_ptr, covariates) + if (!is.null(basis)) { + forest_dataset_add_basis_cpp(self$data_ptr, basis) + } + if (!is.null(variance_weights)) { + forest_dataset_add_weights_cpp(self$data_ptr, variance_weights) + } + }, + + #' @description + #' Update basis matrix in a dataset + #' @param basis Updated matrix of bases used to define a leaf regression + update_basis = function(basis) { + stopifnot(self$has_basis()) + forest_dataset_update_basis_cpp(self$data_ptr, basis) + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. + update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + forest_dataset_update_var_weights_cpp( + self$data_ptr, + variance_weights, + exponentiate + ) + }, + + #' @description + #' Return number of observations in a `ForestDataset` object + #' @return Observation count + num_observations = function() { + return(dataset_num_rows_cpp(self$data_ptr)) + }, + + #' @description + #' Return number of covariates in a `ForestDataset` object + #' @return Covariate count + num_covariates = function() { + return(dataset_num_covariates_cpp(self$data_ptr)) + }, + + #' @description + #' Return number of bases in a `ForestDataset` object + #' @return Basis count + num_basis = function() { + return(dataset_num_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return covariates as an R matrix + #' @return Covariate data + get_covariates = function() { + return(forest_dataset_get_covariates_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(forest_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(forest_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has a basis matrix + #' @return True if basis matrix is loaded, false otherwise + has_basis = function() { + return(dataset_has_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has variance weights + #' @return True if variance weights are loaded, false otherwise + has_variance_weights = function() { + return(dataset_has_variance_weights_cpp(self$data_ptr)) + } + ) ) #' Outcome / partial residual used to sample an additive model. @@ -123,90 +123,90 @@ ForestDataset <- R6::R6Class( #' (trees, group random effects, etc...). Outcome <- R6::R6Class( - classname = "Outcome", - cloneable = FALSE, - public = list( - #' @field data_ptr External pointer to a C++ Outcome class - data_ptr = NULL, - - #' @description - #' Create a new Outcome object. - #' @param outcome Vector of outcome values - #' @return A new `Outcome` object. - initialize = function(outcome) { - self$data_ptr <- create_column_vector_cpp(outcome) - }, - - #' @description - #' Extract raw data in R from the underlying C++ object - #' @return R vector containing (copy of) the values in `Outcome` object - get_data = function() { - return(get_residual_cpp(self$data_ptr)) - }, - - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector` - #' @param update_vector Vector to be added to outcome - #' @return None - add_vector = function(update_vector) { - if (!is.numeric(update_vector)) { - stop("update_vector must be a numeric vector or 2d matrix") - } else { - dim_vec <- dim(update_vector) - if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) { - stop( - "if update_vector is provided as a matrix, it must be 2d" - ) - } - update_vector <- as.numeric(update_vector) - } - } - add_to_column_vector_cpp(self$data_ptr, update_vector) - }, - - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector` - #' @param update_vector Vector to be subtracted from outcome - #' @return None - subtract_vector = function(update_vector) { - if (!is.numeric(update_vector)) { - stop("update_vector must be a numeric vector or 2d matrix") - } else { - dim_vec <- dim(update_vector) - if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) { - stop( - "if update_vector is provided as a matrix, it must be 2d" - ) - } - update_vector <- as.numeric(update_vector) - } - } - subtract_from_column_vector_cpp(self$data_ptr, update_vector) - }, - - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector` - #' @param new_vector Vector from which to overwrite the current data - #' @return None - update_data = function(new_vector) { - if (!is.numeric(new_vector)) { - stop("update_vector must be a numeric vector or 2d matrix") - } else { - dim_vec <- dim(new_vector) - if (!is.null(dim_vec)) { - if (length(dim_vec) > 2) { - stop( - "if update_vector is provided as a matrix, it must be 2d" - ) - } - new_vector <- as.numeric(new_vector) - } - } - overwrite_column_vector_cpp(self$data_ptr, new_vector) + classname = "Outcome", + cloneable = FALSE, + public = list( + #' @field data_ptr External pointer to a C++ Outcome class + data_ptr = NULL, + + #' @description + #' Create a new Outcome object. + #' @param outcome Vector of outcome values + #' @return A new `Outcome` object. + initialize = function(outcome) { + self$data_ptr <- create_column_vector_cpp(outcome) + }, + + #' @description + #' Extract raw data in R from the underlying C++ object + #' @return R vector containing (copy of) the values in `Outcome` object + get_data = function() { + return(get_residual_cpp(self$data_ptr)) + }, + + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by adding the values of `update_vector` + #' @param update_vector Vector to be added to outcome + #' @return None + add_vector = function(update_vector) { + if (!is.numeric(update_vector)) { + stop("update_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(update_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } + update_vector <- as.numeric(update_vector) + } + } + add_to_column_vector_cpp(self$data_ptr, update_vector) + }, + + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by subtracting the values of `update_vector` + #' @param update_vector Vector to be subtracted from outcome + #' @return None + subtract_vector = function(update_vector) { + if (!is.numeric(update_vector)) { + stop("update_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(update_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } + update_vector <- as.numeric(update_vector) } - ) + } + subtract_from_column_vector_cpp(self$data_ptr, update_vector) + }, + + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of `new_vector` + #' @param new_vector Vector from which to overwrite the current data + #' @return None + update_data = function(new_vector) { + if (!is.numeric(new_vector)) { + stop("update_vector must be a numeric vector or 2d matrix") + } else { + dim_vec <- dim(new_vector) + if (!is.null(dim_vec)) { + if (length(dim_vec) > 2) { + stop( + "if update_vector is provided as a matrix, it must be 2d" + ) + } + new_vector <- as.numeric(new_vector) + } + } + overwrite_column_vector_cpp(self$data_ptr, new_vector) + } + ) ) #' Dataset used to sample a random effects model @@ -216,104 +216,104 @@ Outcome <- R6::R6Class( #' bases, and variance weights. Variance weights are optional. RandomEffectsDataset <- R6::R6Class( - classname = "RandomEffectsDataset", - cloneable = FALSE, - public = list( - #' @field data_ptr External pointer to a C++ RandomEffectsDataset class - data_ptr = NULL, - - #' @description - #' Create a new RandomEffectsDataset object. - #' @param group_labels Vector of group labels - #' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones) - #' @param variance_weights (Optional) Vector of observation-specific variance weights - #' @return A new `RandomEffectsDataset` object. - initialize = function(group_labels, basis, variance_weights = NULL) { - self$data_ptr <- create_rfx_dataset_cpp() - rfx_dataset_add_group_labels_cpp(self$data_ptr, group_labels) - rfx_dataset_add_basis_cpp(self$data_ptr, basis) - if (!is.null(variance_weights)) { - rfx_dataset_add_weights_cpp(self$data_ptr, variance_weights) - } - }, - - #' @description - #' Update basis matrix in a dataset - #' @param basis Updated matrix of bases used to define random slopes / intercepts - update_basis = function(basis) { - stopifnot(self$has_basis()) - rfx_dataset_update_basis_cpp(self$data_ptr, basis) - }, - - #' @description - #' Update variance_weights in a dataset - #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights - #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F. - update_variance_weights = function(variance_weights, exponentiate = F) { - stopifnot(self$has_variance_weights()) - rfx_dataset_update_var_weights_cpp( - self$data_ptr, - variance_weights, - exponentiate - ) - }, - - #' @description - #' Return number of observations in a `RandomEffectsDataset` object - #' @return Observation count - num_observations = function() { - return(rfx_dataset_num_rows_cpp(self$data_ptr)) - }, - - #' @description - #' Return dimension of the basis matrix in a `RandomEffectsDataset` object - #' @return Basis vector count - num_basis = function() { - return(rfx_dataset_num_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return group labels as an R vector - #' @return Group label data - get_group_labels = function() { - return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) - }, - - #' @description - #' Return bases as an R matrix - #' @return Basis data - get_basis = function() { - return(rfx_dataset_get_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Return variance weights as an R vector - #' @return Variance weight data - get_variance_weights = function() { - return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has group label indices - #' @return True if group label vector is loaded, false otherwise - has_group_labels = function() { - return(rfx_dataset_has_group_labels_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has a basis matrix - #' @return True if basis matrix is loaded, false otherwise - has_basis = function() { - return(rfx_dataset_has_basis_cpp(self$data_ptr)) - }, - - #' @description - #' Whether or not a dataset has variance weights - #' @return True if variance weights are loaded, false otherwise - has_variance_weights = function() { - return(rfx_dataset_has_variance_weights_cpp(self$data_ptr)) - } - ) + classname = "RandomEffectsDataset", + cloneable = FALSE, + public = list( + #' @field data_ptr External pointer to a C++ RandomEffectsDataset class + data_ptr = NULL, + + #' @description + #' Create a new RandomEffectsDataset object. + #' @param group_labels Vector of group labels + #' @param basis Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones) + #' @param variance_weights (Optional) Vector of observation-specific variance weights + #' @return A new `RandomEffectsDataset` object. + initialize = function(group_labels, basis, variance_weights = NULL) { + self$data_ptr <- create_rfx_dataset_cpp() + rfx_dataset_add_group_labels_cpp(self$data_ptr, group_labels) + rfx_dataset_add_basis_cpp(self$data_ptr, basis) + if (!is.null(variance_weights)) { + rfx_dataset_add_weights_cpp(self$data_ptr, variance_weights) + } + }, + + #' @description + #' Update basis matrix in a dataset + #' @param basis Updated matrix of bases used to define random slopes / intercepts + update_basis = function(basis) { + stopifnot(self$has_basis()) + rfx_dataset_update_basis_cpp(self$data_ptr, basis) + }, + + #' @description + #' Update variance_weights in a dataset + #' @param variance_weights Updated vector of variance weights used to define individual variance / case weights + #' @param exponentiate Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F. + update_variance_weights = function(variance_weights, exponentiate = F) { + stopifnot(self$has_variance_weights()) + rfx_dataset_update_var_weights_cpp( + self$data_ptr, + variance_weights, + exponentiate + ) + }, + + #' @description + #' Return number of observations in a `RandomEffectsDataset` object + #' @return Observation count + num_observations = function() { + return(rfx_dataset_num_rows_cpp(self$data_ptr)) + }, + + #' @description + #' Return dimension of the basis matrix in a `RandomEffectsDataset` object + #' @return Basis vector count + num_basis = function() { + return(rfx_dataset_num_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return group labels as an R vector + #' @return Group label data + get_group_labels = function() { + return(rfx_dataset_get_group_labels_cpp(self$data_ptr)) + }, + + #' @description + #' Return bases as an R matrix + #' @return Basis data + get_basis = function() { + return(rfx_dataset_get_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Return variance weights as an R vector + #' @return Variance weight data + get_variance_weights = function() { + return(rfx_dataset_get_variance_weights_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has group label indices + #' @return True if group label vector is loaded, false otherwise + has_group_labels = function() { + return(rfx_dataset_has_group_labels_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has a basis matrix + #' @return True if basis matrix is loaded, false otherwise + has_basis = function() { + return(rfx_dataset_has_basis_cpp(self$data_ptr)) + }, + + #' @description + #' Whether or not a dataset has variance weights + #' @return True if variance weights are loaded, false otherwise + has_variance_weights = function() { + return(rfx_dataset_has_variance_weights_cpp(self$data_ptr)) + } + ) ) #' Create a forest dataset object @@ -333,11 +333,11 @@ RandomEffectsDataset <- R6::R6Class( #' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix) #' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector) createForestDataset <- function( - covariates, - basis = NULL, - variance_weights = NULL + covariates, + basis = NULL, + variance_weights = NULL ) { - return(invisible((ForestDataset$new(covariates, basis, variance_weights)))) + return(invisible((ForestDataset$new(covariates, basis, variance_weights)))) } #' Create an outcome object @@ -352,7 +352,7 @@ createForestDataset <- function( #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100) #' outcome <- createOutcome(y) createOutcome <- function(outcome) { - return(invisible((Outcome$new(outcome)))) + return(invisible((Outcome$new(outcome)))) } #' Create a random effects dataset object @@ -371,11 +371,11 @@ createOutcome <- function(outcome) { #' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis) #' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector) createRandomEffectsDataset <- function( - group_labels, - basis, - variance_weights = NULL + group_labels, + basis, + variance_weights = NULL ) { - return(invisible( - (RandomEffectsDataset$new(group_labels, basis, variance_weights)) - )) + return(invisible( + (RandomEffectsDataset$new(group_labels, basis, variance_weights)) + )) } diff --git a/R/forest.R b/R/forest.R index a554c2d5..ae136a97 100644 --- a/R/forest.R +++ b/R/forest.R @@ -4,897 +4,897 @@ #' Wrapper around a C++ container of tree ensembles ForestSamples <- R6::R6Class( - classname = "ForestSamples", - cloneable = FALSE, - public = list( - #' @field forest_container_ptr External pointer to a C++ ForestContainer class - forest_container_ptr = NULL, - - #' @description - #' Create a new ForestContainer object. - #' @param num_trees Number of trees - #' @param leaf_dimension Dimensionality of the outcome model - #' @param is_leaf_constant Whether leaf is constant - #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned - #' @return A new `ForestContainer` object. - initialize = function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE - ) { - self$forest_container_ptr <- forest_container_cpp( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - ) - }, - - #' @description - #' Collapse forests in this container by a pre-specified batch size. - #' For example, if we have a container of twenty 10-tree forests, and we - #' specify a `batch_size` of 5, then this method will yield four 50-tree - #' forests. "Excess" forests remaining after the size of a forest container - #' is divided by `batch_size` will be pruned from the beginning of the - #' container (i.e. earlier sampled forests will be deleted). This method - #' has no effect if `batch_size` is larger than the number of forests - #' in a container. - #' @param batch_size Number of forests to be collapsed into a single forest - collapse = function(batch_size) { - container_size <- self$num_samples() - if ((batch_size <= container_size) && (batch_size > 1)) { - reverse_container_inds <- seq(container_size, 1, -1) - num_clean_batches <- container_size %/% batch_size - batch_inds <- (reverse_container_inds - - (container_size - - (container_size %/% num_clean_batches) * - num_clean_batches) - - 1) %/% - batch_size - for (batch_ind in unique(batch_inds[batch_inds >= 0])) { - merge_forest_inds <- sort( - reverse_container_inds[batch_inds == batch_ind] - 1 - ) - num_merge_forests <- length(merge_forest_inds) - self$combine_forests(merge_forest_inds) - for (i in num_merge_forests:2) { - self$delete_sample(merge_forest_inds[i]) - } - forest_scale_factor <- 1.0 / num_merge_forests - self$multiply_forest( - merge_forest_inds[1], - forest_scale_factor - ) - } - if (min(batch_inds) < 0) { - delete_forest_inds <- sort( - reverse_container_inds[batch_inds < 0] - 1 - ) - for (i in length(delete_forest_inds):1) { - self$delete_sample(delete_forest_inds[i]) - } - } - } - }, - - #' @description - #' Merge specified forests into a single forest - #' @param forest_inds Indices of forests to be combined (0-indexed) - combine_forests = function(forest_inds) { - stopifnot(max(forest_inds) < self$num_samples()) - stopifnot(min(forest_inds) >= 0) - stopifnot(length(forest_inds) > 1) - stopifnot(all(as.integer(forest_inds) == forest_inds)) - forest_inds_sorted <- as.integer(sort(forest_inds)) - combine_forests_forest_container_cpp( - self$forest_container_ptr, - forest_inds_sorted - ) - }, - - #' @description - #' Add a constant value to every leaf of every tree of a given forest - #' @param forest_index Index of forest whose leaves will be modified (0-indexed) - #' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index` - add_to_forest = function(forest_index, constant_value) { - stopifnot(forest_index < self$num_samples()) - stopifnot(forest_index >= 0) - add_to_forest_forest_container_cpp( - self$forest_container_ptr, - forest_index, - constant_value - ) - }, - - #' @description - #' Multiply every leaf of every tree of a given forest by constant value - #' @param forest_index Index of forest whose leaves will be modified (0-indexed) - #' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index` - multiply_forest = function(forest_index, constant_multiple) { - stopifnot(forest_index < self$num_samples()) - stopifnot(forest_index >= 0) - multiply_forest_forest_container_cpp( - self$forest_container_ptr, - forest_index, - constant_multiple - ) - }, - - #' @description - #' Create a new `ForestContainer` object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return A new `ForestContainer` object. - load_from_json = function(json_object, json_forest_label) { - self$forest_container_ptr <- forest_container_from_json_cpp( - json_object$json_ptr, - json_forest_label - ) - }, - - #' @description - #' Append to a `ForestContainer` object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return None - append_from_json = function(json_object, json_forest_label) { - forest_container_append_from_json_cpp( - self$forest_container_ptr, - json_object$json_ptr, - json_forest_label - ) - }, - - #' @description - #' Create a new `ForestContainer` object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return A new `ForestContainer` object. - load_from_json_string = function(json_string, json_forest_label) { - self$forest_container_ptr <- forest_container_from_json_string_cpp( - json_string, - json_forest_label - ) - }, - - #' @description - #' Append to a `ForestContainer` object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy - #' @return None - append_from_json_string = function(json_string, json_forest_label) { - forest_container_append_from_json_string_cpp( - self$forest_container_ptr, - json_string, - json_forest_label - ) - }, - - #' @description - #' Predict every tree ensemble on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return matrix of predictions with as many rows as in forest_dataset - #' and as many columns as samples in the `ForestContainer` - predict = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - return(predict_forest_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr - )) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return Array of predictions for each observation in `forest_dataset` and - #' each sample in the `ForestSamples` class with each prediction having the - #' dimensionality of the forests' leaf model. In the case of a constant leaf model - #' or univariate leaf regression, this array is two-dimensional (number of observations, - #' number of forest samples). In the case of a multivariate leaf regression, - #' this array is three-dimension (number of observations, leaf model dimension, - #' number of samples). - predict_raw = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - # Unpack dimensions - output_dim <- leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) - num_samples <- num_samples_forest_container_cpp( - self$forest_container_ptr - ) - n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - - # Predict leaf values from forest - predictions <- predict_forest_raw_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr - ) - if (output_dim > 1) { - dim(predictions) <- c(n, output_dim, num_samples) - } else { - dim(predictions) <- c(n, num_samples) - } - - return(predictions) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @param forest_num Index of the forest sample within the container - #' @return matrix of predictions with as many rows as in forest_dataset - #' and as many columns as dimensions in the leaves of trees in `ForestContainer` - predict_raw_single_forest = function(forest_dataset, forest_num) { - stopifnot(!is.null(forest_dataset$data_ptr)) - # Unpack dimensions - output_dim <- leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) - n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - - # Predict leaf values from forest - output <- predict_forest_raw_single_forest_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr, - forest_num - ) - return(output) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @param forest_num Index of the forest sample within the container - #' @param tree_num Index of the tree to be queried - #' @return matrix of predictions with as many rows as in `forest_dataset` - #' and as many columns as dimensions in the leaves of trees in `ForestContainer` - predict_raw_single_tree = function( - forest_dataset, - forest_num, - tree_num - ) { - stopifnot(!is.null(forest_dataset$data_ptr)) - - # Predict leaf values from forest - output <- predict_forest_raw_single_tree_cpp( - self$forest_container_ptr, - forest_dataset$data_ptr, - forest_num, - tree_num - ) - return(output) - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param forest_num Index of the forest sample within the container. - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - set_root_leaves = function(forest_num, leaf_value) { - stopifnot(!is.null(self$forest_container_ptr)) - stopifnot( - num_samples_forest_container_cpp(self$forest_container_ptr) == 0 - ) - - # Set leaf values - if (length(leaf_value) == 1) { - stopifnot( - leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) == - 1 - ) - set_leaf_value_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } else if (length(leaf_value) > 1) { - stopifnot( - leaf_dimension_forest_container_cpp( - self$forest_container_ptr - ) == - length(leaf_value) - ) - set_leaf_vector_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } else { - stop( - "leaf_value must be a numeric value or vector of length >= 1" - ) - } - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) - #' @param outcome `Outcome` Outcome class (residual / partial residual) - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - prepare_for_sampler = function( - dataset, - outcome, - forest_model, - leaf_model_int, - leaf_value - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_container_ptr)) - stopifnot( - num_samples_forest_container_cpp(self$forest_container_ptr) == 0 - ) - - # Initialize the model - initialize_forest_model_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_container_ptr, - forest_model$tracker_ptr, - leaf_value, - leaf_model_int - ) - }, - - #' @description - #' Adjusts residual based on the predictions of a forest - #' - #' This is typically run just once at the beginning of a forest sampling algorithm. - #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. - #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest - #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param requires_basis Whether or not a forest requires a basis for prediction - #' @param forest_num Index of forest used to update residuals - #' @param add Whether forest predictions should be added to or subtracted from residuals - adjust_residual = function( - dataset, - outcome, - forest_model, - requires_basis, - forest_num, - add - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_container_ptr)) - - adjust_residual_forest_container_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_container_ptr, - forest_model$tracker_ptr, - requires_basis, - forest_num, - add - ) - }, - - #' @description - #' Store the trees and metadata of `ForestDataset` class in a json file - #' @param json_filename Name of output json file (must end in ".json") - save_json = function(json_filename) { - invisible(json_save_forest_container_cpp( - self$forest_container_ptr, - json_filename - )) - }, - - #' @description - #' Load trees and metadata for an ensemble from a json file. Note that - #' any trees and metadata already present in `ForestDataset` class will - #' be overwritten. - #' @param json_filename Name of model input json file (must end in ".json") - load_json = function(json_filename) { - invisible(json_load_forest_container_cpp( - self$forest_container_ptr, - json_filename - )) - }, - - #' @description - #' Return number of samples in a `ForestContainer` object - #' @return Sample count - num_samples = function() { - return(num_samples_forest_container_cpp(self$forest_container_ptr)) - }, - - #' @description - #' Return number of trees in each ensemble of a `ForestContainer` object - #' @return Tree count - num_trees = function() { - return(num_trees_forest_container_cpp(self$forest_container_ptr)) - }, - - #' @description - #' Return output dimension of trees in a `ForestContainer` object - #' @return Leaf node parameter size - leaf_dimension = function() { - return(leaf_dimension_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Return constant leaf status of trees in a `ForestContainer` object - #' @return `TRUE` if leaves are constant, `FALSE` otherwise - is_constant_leaf = function() { - return(is_constant_leaf_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Return exponentiation status of trees in a `ForestContainer` object - #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise - is_exponentiated = function() { - return(is_exponentiated_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Add a new all-root ensemble to the container, with all of the leaves - #' set to the value / vector provided - #' @param leaf_value Value (or vector of values) to initialize root nodes in tree - add_forest_with_constant_leaves = function(leaf_value) { - if (length(leaf_value) > 1) { - add_sample_vector_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } else { - add_sample_value_forest_container_cpp( - self$forest_container_ptr, - leaf_value - ) - } - }, - - #' @description - #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble - #' @param forest_num Index of the forest which contains the tree to be split - #' @param tree_num Index of the tree to be split - #' @param leaf_num Leaf to be split - #' @param feature_num Feature that defines the new split - #' @param split_threshold Value that defines the cutoff of the new split - #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node - #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node - add_numeric_split_tree = function( + classname = "ForestSamples", + cloneable = FALSE, + public = list( + #' @field forest_container_ptr External pointer to a C++ ForestContainer class + forest_container_ptr = NULL, + + #' @description + #' Create a new ForestContainer object. + #' @param num_trees Number of trees + #' @param leaf_dimension Dimensionality of the outcome model + #' @param is_leaf_constant Whether leaf is constant + #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned + #' @return A new `ForestContainer` object. + initialize = function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE + ) { + self$forest_container_ptr <- forest_container_cpp( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + ) + }, + + #' @description + #' Collapse forests in this container by a pre-specified batch size. + #' For example, if we have a container of twenty 10-tree forests, and we + #' specify a `batch_size` of 5, then this method will yield four 50-tree + #' forests. "Excess" forests remaining after the size of a forest container + #' is divided by `batch_size` will be pruned from the beginning of the + #' container (i.e. earlier sampled forests will be deleted). This method + #' has no effect if `batch_size` is larger than the number of forests + #' in a container. + #' @param batch_size Number of forests to be collapsed into a single forest + collapse = function(batch_size) { + container_size <- self$num_samples() + if ((batch_size <= container_size) && (batch_size > 1)) { + reverse_container_inds <- seq(container_size, 1, -1) + num_clean_batches <- container_size %/% batch_size + batch_inds <- (reverse_container_inds - + (container_size - + (container_size %/% num_clean_batches) * + num_clean_batches) - + 1) %/% + batch_size + for (batch_ind in unique(batch_inds[batch_inds >= 0])) { + merge_forest_inds <- sort( + reverse_container_inds[batch_inds == batch_ind] - 1 + ) + num_merge_forests <- length(merge_forest_inds) + self$combine_forests(merge_forest_inds) + for (i in num_merge_forests:2) { + self$delete_sample(merge_forest_inds[i]) + } + forest_scale_factor <- 1.0 / num_merge_forests + self$multiply_forest( + merge_forest_inds[1], + forest_scale_factor + ) + } + if (min(batch_inds) < 0) { + delete_forest_inds <- sort( + reverse_container_inds[batch_inds < 0] - 1 + ) + for (i in length(delete_forest_inds):1) { + self$delete_sample(delete_forest_inds[i]) + } + } + } + }, + + #' @description + #' Merge specified forests into a single forest + #' @param forest_inds Indices of forests to be combined (0-indexed) + combine_forests = function(forest_inds) { + stopifnot(max(forest_inds) < self$num_samples()) + stopifnot(min(forest_inds) >= 0) + stopifnot(length(forest_inds) > 1) + stopifnot(all(as.integer(forest_inds) == forest_inds)) + forest_inds_sorted <- as.integer(sort(forest_inds)) + combine_forests_forest_container_cpp( + self$forest_container_ptr, + forest_inds_sorted + ) + }, + + #' @description + #' Add a constant value to every leaf of every tree of a given forest + #' @param forest_index Index of forest whose leaves will be modified (0-indexed) + #' @param constant_value Value to add to every leaf of every tree of the forest at `forest_index` + add_to_forest = function(forest_index, constant_value) { + stopifnot(forest_index < self$num_samples()) + stopifnot(forest_index >= 0) + add_to_forest_forest_container_cpp( + self$forest_container_ptr, + forest_index, + constant_value + ) + }, + + #' @description + #' Multiply every leaf of every tree of a given forest by constant value + #' @param forest_index Index of forest whose leaves will be modified (0-indexed) + #' @param constant_multiple Value to multiply through by every leaf of every tree of the forest at `forest_index` + multiply_forest = function(forest_index, constant_multiple) { + stopifnot(forest_index < self$num_samples()) + stopifnot(forest_index >= 0) + multiply_forest_forest_container_cpp( + self$forest_container_ptr, + forest_index, + constant_multiple + ) + }, + + #' @description + #' Create a new `ForestContainer` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return A new `ForestContainer` object. + load_from_json = function(json_object, json_forest_label) { + self$forest_container_ptr <- forest_container_from_json_cpp( + json_object$json_ptr, + json_forest_label + ) + }, + + #' @description + #' Append to a `ForestContainer` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return None + append_from_json = function(json_object, json_forest_label) { + forest_container_append_from_json_cpp( + self$forest_container_ptr, + json_object$json_ptr, + json_forest_label + ) + }, + + #' @description + #' Create a new `ForestContainer` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return A new `ForestContainer` object. + load_from_json_string = function(json_string, json_forest_label) { + self$forest_container_ptr <- forest_container_from_json_string_cpp( + json_string, + json_forest_label + ) + }, + + #' @description + #' Append to a `ForestContainer` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_forest_label Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy + #' @return None + append_from_json_string = function(json_string, json_forest_label) { + forest_container_append_from_json_string_cpp( + self$forest_container_ptr, + json_string, + json_forest_label + ) + }, + + #' @description + #' Predict every tree ensemble on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return matrix of predictions with as many rows as in forest_dataset + #' and as many columns as samples in the `ForestContainer` + predict = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + return(predict_forest_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr + )) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is two-dimensional (number of observations, + #' number of forest samples). In the case of a multivariate leaf regression, + #' this array is three-dimension (number of observations, leaf model dimension, + #' number of samples). + predict_raw = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + # Unpack dimensions + output_dim <- leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) + num_samples <- num_samples_forest_container_cpp( + self$forest_container_ptr + ) + n <- dataset_num_rows_cpp(forest_dataset$data_ptr) + + # Predict leaf values from forest + predictions <- predict_forest_raw_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr + ) + if (output_dim > 1) { + dim(predictions) <- c(n, output_dim, num_samples) + } else { + dim(predictions) <- c(n, num_samples) + } + + return(predictions) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @param forest_num Index of the forest sample within the container + #' @return matrix of predictions with as many rows as in forest_dataset + #' and as many columns as dimensions in the leaves of trees in `ForestContainer` + predict_raw_single_forest = function(forest_dataset, forest_num) { + stopifnot(!is.null(forest_dataset$data_ptr)) + # Unpack dimensions + output_dim <- leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) + n <- dataset_num_rows_cpp(forest_dataset$data_ptr) + + # Predict leaf values from forest + output <- predict_forest_raw_single_forest_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr, + forest_num + ) + return(output) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @param forest_num Index of the forest sample within the container + #' @param tree_num Index of the tree to be queried + #' @return matrix of predictions with as many rows as in `forest_dataset` + #' and as many columns as dimensions in the leaves of trees in `ForestContainer` + predict_raw_single_tree = function( + forest_dataset, + forest_num, + tree_num + ) { + stopifnot(!is.null(forest_dataset$data_ptr)) + + # Predict leaf values from forest + output <- predict_forest_raw_single_tree_cpp( + self$forest_container_ptr, + forest_dataset$data_ptr, + forest_num, + tree_num + ) + return(output) + }, + + #' @description + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. + #' @param forest_num Index of the forest sample within the container. + #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. + set_root_leaves = function(forest_num, leaf_value) { + stopifnot(!is.null(self$forest_container_ptr)) + stopifnot( + num_samples_forest_container_cpp(self$forest_container_ptr) == 0 + ) + + # Set leaf values + if (length(leaf_value) == 1) { + stopifnot( + leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) == + 1 + ) + set_leaf_value_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } else if (length(leaf_value) > 1) { + stopifnot( + leaf_dimension_forest_container_cpp( + self$forest_container_ptr + ) == + length(leaf_value) + ) + set_leaf_vector_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } else { + stop( + "leaf_value must be a numeric value or vector of length >= 1" + ) + } + }, + + #' @description + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. + #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) + #' @param outcome `Outcome` Outcome class (residual / partial residual) + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). + #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. + prepare_for_sampler = function( + dataset, + outcome, + forest_model, + leaf_model_int, + leaf_value + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_container_ptr)) + stopifnot( + num_samples_forest_container_cpp(self$forest_container_ptr) == 0 + ) + + # Initialize the model + initialize_forest_model_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_container_ptr, + forest_model$tracker_ptr, + leaf_value, + leaf_model_int + ) + }, + + #' @description + #' Adjusts residual based on the predictions of a forest + #' + #' This is typically run just once at the beginning of a forest sampling algorithm. + #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. + #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest + #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param requires_basis Whether or not a forest requires a basis for prediction + #' @param forest_num Index of forest used to update residuals + #' @param add Whether forest predictions should be added to or subtracted from residuals + adjust_residual = function( + dataset, + outcome, + forest_model, + requires_basis, + forest_num, + add + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_container_ptr)) + + adjust_residual_forest_container_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_container_ptr, + forest_model$tracker_ptr, + requires_basis, + forest_num, + add + ) + }, + + #' @description + #' Store the trees and metadata of `ForestDataset` class in a json file + #' @param json_filename Name of output json file (must end in ".json") + save_json = function(json_filename) { + invisible(json_save_forest_container_cpp( + self$forest_container_ptr, + json_filename + )) + }, + + #' @description + #' Load trees and metadata for an ensemble from a json file. Note that + #' any trees and metadata already present in `ForestDataset` class will + #' be overwritten. + #' @param json_filename Name of model input json file (must end in ".json") + load_json = function(json_filename) { + invisible(json_load_forest_container_cpp( + self$forest_container_ptr, + json_filename + )) + }, + + #' @description + #' Return number of samples in a `ForestContainer` object + #' @return Sample count + num_samples = function() { + return(num_samples_forest_container_cpp(self$forest_container_ptr)) + }, + + #' @description + #' Return number of trees in each ensemble of a `ForestContainer` object + #' @return Tree count + num_trees = function() { + return(num_trees_forest_container_cpp(self$forest_container_ptr)) + }, + + #' @description + #' Return output dimension of trees in a `ForestContainer` object + #' @return Leaf node parameter size + leaf_dimension = function() { + return(leaf_dimension_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Return constant leaf status of trees in a `ForestContainer` object + #' @return `TRUE` if leaves are constant, `FALSE` otherwise + is_constant_leaf = function() { + return(is_constant_leaf_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Return exponentiation status of trees in a `ForestContainer` object + #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise + is_exponentiated = function() { + return(is_exponentiated_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Add a new all-root ensemble to the container, with all of the leaves + #' set to the value / vector provided + #' @param leaf_value Value (or vector of values) to initialize root nodes in tree + add_forest_with_constant_leaves = function(leaf_value) { + if (length(leaf_value) > 1) { + add_sample_vector_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } else { + add_sample_value_forest_container_cpp( + self$forest_container_ptr, + leaf_value + ) + } + }, + + #' @description + #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble + #' @param forest_num Index of the forest which contains the tree to be split + #' @param tree_num Index of the tree to be split + #' @param leaf_num Leaf to be split + #' @param feature_num Feature that defines the new split + #' @param split_threshold Value that defines the cutoff of the new split + #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node + #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node + add_numeric_split_tree = function( + forest_num, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) { + if (length(left_leaf_value) > 1) { + add_numeric_split_tree_vector_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) + } else { + add_numeric_split_tree_value_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) + } + }, + + #' @description + #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest + #' @param forest_num Index of the forest which contains tree `tree_num` + #' @param tree_num Index of the tree for which leaf indices will be retrieved + get_tree_leaves = function(forest_num, tree_num) { + return(get_tree_leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given tree in a given forest + #' @param forest_num Index of the forest which contains tree `tree_num` + #' @param tree_num Index of the tree for which split counts will be retrieved + #' @param num_features Total number of features in the training set + get_tree_split_counts = function(forest_num, tree_num, num_features) { + return(get_tree_split_counts_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + num_features + )) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given forest + #' @param forest_num Index of the forest for which split counts will be retrieved + #' @param num_features Total number of features in the training set + get_forest_split_counts = function(forest_num, num_features) { + return(get_forest_split_counts_forest_container_cpp( + self$forest_container_ptr, + forest_num, + num_features + )) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees + #' @param num_features Total number of features in the training set + get_aggregate_split_counts = function(num_features) { + return(get_overall_split_counts_forest_container_cpp( + self$forest_container_ptr, + num_features + )) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree + #' @param num_features Total number of features in the training set + get_granular_split_counts = function(num_features) { + n_samples <- self$num_samples() + n_trees <- self$num_trees() + output <- get_granular_split_count_array_forest_container_cpp( + self$forest_container_ptr, + num_features + ) + dim(output) <- c(n_samples, n_trees, num_features) + return(output) + }, + + #' @description + #' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object + #' @param ensemble_num Ensemble number + #' @param tree_num Tree index within ensemble `ensemble_num` + #' @return Maximum leaf depth + ensemble_tree_max_depth = function(ensemble_num, tree_num) { + return(ensemble_tree_max_depth_forest_container_cpp( + self$forest_container_ptr, + ensemble_num, + tree_num + )) + }, + + #' @description + #' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object + #' @param ensemble_num Ensemble number + #' @return Average maximum depth + average_ensemble_max_depth = function(ensemble_num) { + return(ensemble_average_max_depth_forest_container_cpp( + self$forest_container_ptr, + ensemble_num + )) + }, + + #' @description + #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object + #' @return Average maximum depth + average_max_depth = function() { + return(average_max_depth_forest_container_cpp( + self$forest_container_ptr + )) + }, + + #' @description + #' Number of leaves in a given ensemble in a `ForestSamples` object + #' @param forest_num Index of the ensemble to be queried + #' @return Count of leaves in the ensemble stored at `forest_num` + num_forest_leaves = function(forest_num) { + return(num_leaves_ensemble_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) + }, + + #' @description + #' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object + #' @param forest_num Index of the ensemble to be queried + #' @return Average maximum depth + sum_leaves_squared = function(forest_num) { + return(sum_leaves_squared_ensemble_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) + }, + + #' @description + #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return `TRUE` if node is a leaf, `FALSE` otherwise + is_leaf_node = function(forest_num, tree_num, node_id) { + return(is_leaf_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return `TRUE` if node is a numeric split node, `FALSE` otherwise + is_numeric_split_node = function(forest_num, tree_num, node_id) { + return(is_numeric_split_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return `TRUE` if node is a categorical split node, `FALSE` otherwise + is_categorical_split_node = function(forest_num, tree_num, node_id) { + return(is_categorical_split_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Parent node of given node of a given tree in a given forest in a `ForestSamples` object + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Integer ID of the parent node + parent_node = function(forest_num, tree_num, node_id) { + return(parent_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Left child node of given node of a given tree in a given forest in a `ForestSamples` object + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Integer ID of the left child node + left_child_node = function(forest_num, tree_num, node_id) { + return(left_child_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Right child node of given node of a given tree in a given forest in a `ForestSamples` object + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Integer ID of the right child node + right_child_node = function(forest_num, tree_num, node_id) { + return(right_child_node_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Integer valued depth of the node + node_depth = function(forest_num, tree_num, node_id) { + return(node_depth_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` is node is a leaf. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Integer valued depth of the node + node_split_index = function(forest_num, tree_num, node_id) { + if (self$is_leaf_node(forest_num, tree_num, node_id)) { + return(-1) + } else { + return(split_index_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + } + }, + + #' @description + #' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object. + #' Returns `Inf` if the node is a leaf or a categorical split node. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Threshold defining a split for the node + node_split_threshold = function(forest_num, tree_num, node_id) { + if ( + self$is_leaf_node(forest_num, tree_num, node_id) || + self$is_categorical_split_node( forest_num, tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) { - if (length(left_leaf_value) > 1) { - add_numeric_split_tree_vector_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } else { - add_numeric_split_tree_value_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } - }, - - #' @description - #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest - #' @param forest_num Index of the forest which contains tree `tree_num` - #' @param tree_num Index of the tree for which leaf indices will be retrieved - get_tree_leaves = function(forest_num, tree_num) { - return(get_tree_leaves_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given tree in a given forest - #' @param forest_num Index of the forest which contains tree `tree_num` - #' @param tree_num Index of the tree for which split counts will be retrieved - #' @param num_features Total number of features in the training set - get_tree_split_counts = function(forest_num, tree_num, num_features) { - return(get_tree_split_counts_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given forest - #' @param forest_num Index of the forest for which split counts will be retrieved - #' @param num_features Total number of features in the training set - get_forest_split_counts = function(forest_num, num_features) { - return(get_forest_split_counts_forest_container_cpp( - self$forest_container_ptr, - forest_num, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees - #' @param num_features Total number of features in the training set - get_aggregate_split_counts = function(num_features) { - return(get_overall_split_counts_forest_container_cpp( - self$forest_container_ptr, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree - #' @param num_features Total number of features in the training set - get_granular_split_counts = function(num_features) { - n_samples <- self$num_samples() - n_trees <- self$num_trees() - output <- get_granular_split_count_array_forest_container_cpp( - self$forest_container_ptr, - num_features - ) - dim(output) <- c(n_samples, n_trees, num_features) - return(output) - }, - - #' @description - #' Maximum depth of a specific tree in a specific ensemble in a `ForestSamples` object - #' @param ensemble_num Ensemble number - #' @param tree_num Tree index within ensemble `ensemble_num` - #' @return Maximum leaf depth - ensemble_tree_max_depth = function(ensemble_num, tree_num) { - return(ensemble_tree_max_depth_forest_container_cpp( - self$forest_container_ptr, - ensemble_num, - tree_num - )) - }, - - #' @description - #' Average the maximum depth of each tree in a given ensemble in a `ForestSamples` object - #' @param ensemble_num Ensemble number - #' @return Average maximum depth - average_ensemble_max_depth = function(ensemble_num) { - return(ensemble_average_max_depth_forest_container_cpp( - self$forest_container_ptr, - ensemble_num - )) - }, - - #' @description - #' Average the maximum depth of each tree in each ensemble in a `ForestContainer` object - #' @return Average maximum depth - average_max_depth = function() { - return(average_max_depth_forest_container_cpp( - self$forest_container_ptr - )) - }, - - #' @description - #' Number of leaves in a given ensemble in a `ForestSamples` object - #' @param forest_num Index of the ensemble to be queried - #' @return Count of leaves in the ensemble stored at `forest_num` - num_forest_leaves = function(forest_num) { - return(num_leaves_ensemble_forest_container_cpp( - self$forest_container_ptr, - forest_num - )) - }, - - #' @description - #' Sum of squared (raw) leaf values in a given ensemble in a `ForestSamples` object - #' @param forest_num Index of the ensemble to be queried - #' @return Average maximum depth - sum_leaves_squared = function(forest_num) { - return(sum_leaves_squared_ensemble_forest_container_cpp( - self$forest_container_ptr, - forest_num - )) - }, - - #' @description - #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a leaf - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return `TRUE` if node is a leaf, `FALSE` otherwise - is_leaf_node = function(forest_num, tree_num, node_id) { - return(is_leaf_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a numeric split node - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return `TRUE` if node is a numeric split node, `FALSE` otherwise - is_numeric_split_node = function(forest_num, tree_num, node_id) { - return(is_numeric_split_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Whether or not a given node of a given tree in a given forest in the `ForestSamples` is a categorical split node - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return `TRUE` if node is a categorical split node, `FALSE` otherwise - is_categorical_split_node = function(forest_num, tree_num, node_id) { - return(is_categorical_split_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Parent node of given node of a given tree in a given forest in a `ForestSamples` object - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer ID of the parent node - parent_node = function(forest_num, tree_num, node_id) { - return(parent_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Left child node of given node of a given tree in a given forest in a `ForestSamples` object - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer ID of the left child node - left_child_node = function(forest_num, tree_num, node_id) { - return(left_child_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Right child node of given node of a given tree in a given forest in a `ForestSamples` object - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer ID of the right child node - right_child_node = function(forest_num, tree_num, node_id) { - return(right_child_node_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Depth of given node of a given tree in a given forest in a `ForestSamples` object, with 0 depth for the root node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer valued depth of the node - node_depth = function(forest_num, tree_num, node_id) { - return(node_depth_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Split index of given node of a given tree in a given forest in a `ForestSamples` object. Returns `-1` is node is a leaf. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Integer valued depth of the node - node_split_index = function(forest_num, tree_num, node_id) { - if (self$is_leaf_node(forest_num, tree_num, node_id)) { - return(-1) - } else { - return(split_index_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - } - }, - - #' @description - #' Threshold that defines a numeric split for a given node of a given tree in a given forest in a `ForestSamples` object. - #' Returns `Inf` if the node is a leaf or a categorical split node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Threshold defining a split for the node - node_split_threshold = function(forest_num, tree_num, node_id) { - if ( - self$is_leaf_node(forest_num, tree_num, node_id) || - self$is_categorical_split_node( - forest_num, - tree_num, - node_id - ) - ) { - return(Inf) - } else { - return(split_theshold_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - } - }, - - #' @description - #' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object. - #' Returns `c(Inf)` if the node is a leaf or a numeric split node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Categories defining a split for the node - node_split_categories = function(forest_num, tree_num, node_id) { - if ( - self$is_leaf_node(forest_num, tree_num, node_id) || - self$is_numeric_split_node(forest_num, tree_num, node_id) - ) { - return(c(Inf)) - } else { - return(split_categories_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - } - }, - - #' @description - #' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object. - #' Values are stale if the node is a split node. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @param node_id Index of the node to be queried - #' @return Vector (often univariate) of leaf values - node_leaf_values = function(forest_num, tree_num, node_id) { - return(leaf_values_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num, - node_id - )) - }, - - #' @description - #' Number of nodes in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree nodes - num_nodes = function(forest_num, tree_num) { - return(num_nodes_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Number of leaves in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree leaves - num_leaves = function(forest_num, tree_num) { - return(num_leaves_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree leaf parents - num_leaf_parents = function(forest_num, tree_num) { - return(num_leaf_parents_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Number of split nodes in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Count of total tree split nodes - num_split_nodes = function(forest_num, tree_num) { - return(num_split_nodes_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Array of node indices in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Indices of tree nodes - nodes = function(forest_num, tree_num) { - return(nodes_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object. - #' @param forest_num Index of the forest to be queried - #' @param tree_num Index of the tree to be queried - #' @return Indices of leaf nodes - leaves = function(forest_num, tree_num) { - return(leaves_forest_container_cpp( - self$forest_container_ptr, - forest_num, - tree_num - )) - }, - - #' @description - #' Modify the ``ForestSamples`` object by removing the forest sample indexed by `forest_num - #' @param forest_num Index of the forest to be removed - delete_sample = function(forest_num) { - return(remove_sample_forest_container_cpp( - self$forest_container_ptr, - forest_num - )) - } - ) + node_id + ) + ) { + return(Inf) + } else { + return(split_theshold_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + } + }, + + #' @description + #' Array of category indices that define a categorical split for a given node of a given tree in a given forest in a `ForestSamples` object. + #' Returns `c(Inf)` if the node is a leaf or a numeric split node. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Categories defining a split for the node + node_split_categories = function(forest_num, tree_num, node_id) { + if ( + self$is_leaf_node(forest_num, tree_num, node_id) || + self$is_numeric_split_node(forest_num, tree_num, node_id) + ) { + return(c(Inf)) + } else { + return(split_categories_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + } + }, + + #' @description + #' Leaf node value(s) for a given node of a given tree in a given forest in a `ForestSamples` object. + #' Values are stale if the node is a split node. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @param node_id Index of the node to be queried + #' @return Vector (often univariate) of leaf values + node_leaf_values = function(forest_num, tree_num, node_id) { + return(leaf_values_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num, + node_id + )) + }, + + #' @description + #' Number of nodes in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree nodes + num_nodes = function(forest_num, tree_num) { + return(num_nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Number of leaves in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree leaves + num_leaves = function(forest_num, tree_num) { + return(num_leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree leaf parents + num_leaf_parents = function(forest_num, tree_num) { + return(num_leaf_parents_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Number of split nodes in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Count of total tree split nodes + num_split_nodes = function(forest_num, tree_num) { + return(num_split_nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Array of node indices in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Indices of tree nodes + nodes = function(forest_num, tree_num) { + return(nodes_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Array of leaf indices in a given tree in a given forest in a `ForestSamples` object. + #' @param forest_num Index of the forest to be queried + #' @param tree_num Index of the tree to be queried + #' @return Indices of leaf nodes + leaves = function(forest_num, tree_num) { + return(leaves_forest_container_cpp( + self$forest_container_ptr, + forest_num, + tree_num + )) + }, + + #' @description + #' Modify the ``ForestSamples`` object by removing the forest sample indexed by `forest_num + #' @param forest_num Index of the forest to be removed + delete_sample = function(forest_num) { + return(remove_sample_forest_container_cpp( + self$forest_container_ptr, + forest_num + )) + } + ) ) #' Class that stores a single ensemble of decision trees (often treated as the "active forest") @@ -903,330 +903,330 @@ ForestSamples <- R6::R6Class( #' Wrapper around a C++ tree ensemble Forest <- R6::R6Class( - classname = "Forest", - cloneable = FALSE, - public = list( - #' @field forest_ptr External pointer to a C++ TreeEnsemble class - forest_ptr = NULL, - - #' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called. - internal_forest_is_empty = TRUE, - - #' @description - #' Create a new Forest object. - #' @param num_trees Number of trees in the forest - #' @param leaf_dimension Dimensionality of the outcome model - #' @param is_leaf_constant Whether leaf is constant - #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned - #' @return A new `Forest` object. - initialize = function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE - ) { - self$forest_ptr <- active_forest_cpp( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - ) - self$internal_forest_is_empty <- TRUE - }, - - #' @description - #' Create a larger forest by merging the trees of this forest with those of another forest - #' @param forest Forest to be merged into this forest - merge_forest = function(forest) { - stopifnot(self$leaf_dimension() == forest$leaf_dimension()) - stopifnot(self$is_constant_leaf() == forest$is_constant_leaf()) - stopifnot(self$is_exponentiated() == forest$is_exponentiated()) - forest_merge_cpp(self$forest_ptr, forest$forest_ptr) - }, - - #' @description - #' Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. - #' @param constant_value Value that will be added to every leaf of every tree - add_constant = function(constant_value) { - forest_add_constant_cpp(self$forest_ptr, constant_value) - }, - - #' @description - #' Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. - #' @param constant_multiple Value that will be multiplied by every leaf of every tree - multiply_constant = function(constant_multiple) { - forest_multiply_constant_cpp(self$forest_ptr, constant_multiple) - }, - - #' @description - #' Predict forest on every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return vector of predictions with as many rows as in `forest_dataset` - predict = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - stopifnot(!is.null(self$forest_ptr)) - return(predict_active_forest_cpp( - self$forest_ptr, - forest_dataset$data_ptr - )) - }, - - #' @description - #' Predict "raw" leaf values (without being multiplied by basis) for every sample in `forest_dataset` - #' @param forest_dataset `ForestDataset` R class - #' @return Array of predictions for each observation in `forest_dataset` and - #' each sample in the `ForestSamples` class with each prediction having the - #' dimensionality of the forests' leaf model. In the case of a constant leaf model - #' or univariate leaf regression, this array is a vector (length is the number of - #' observations). In the case of a multivariate leaf regression, - #' this array is a matrix (number of observations by leaf model dimension, - #' number of samples). - predict_raw = function(forest_dataset) { - stopifnot(!is.null(forest_dataset$data_ptr)) - # Unpack dimensions - output_dim <- leaf_dimension_active_forest_cpp(self$forest_ptr) - n <- dataset_num_rows_cpp(forest_dataset$data_ptr) - - # Predict leaf values from forest - predictions <- predict_raw_active_forest_cpp( - self$forest_ptr, - forest_dataset$data_ptr - ) - if (output_dim > 1) { - dim(predictions) <- c(n, output_dim) - } - - return(predictions) - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - set_root_leaves = function(leaf_value) { - stopifnot(!is.null(self$forest_ptr)) - stopifnot(self$internal_forest_is_empty) - - # Set leaf values - if (length(leaf_value) == 1) { - stopifnot( - leaf_dimension_active_forest_cpp(self$forest_ptr) == 1 - ) - set_leaf_value_active_forest_cpp(self$forest_ptr, leaf_value) - } else if (length(leaf_value) > 1) { - stopifnot( - leaf_dimension_active_forest_cpp(self$forest_ptr) == - length(leaf_value) - ) - set_leaf_vector_active_forest_cpp(self$forest_ptr, leaf_value) - } else { - stop( - "leaf_value must be a numeric value or vector of length >= 1" - ) - } - - self$internal_forest_is_empty = FALSE - }, - - #' @description - #' Set a constant predicted value for every tree in the ensemble. - #' Stops program if any tree is more than a root node. - #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) - #' @param outcome `Outcome` Outcome class (residual / partial residual) - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). - #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. - prepare_for_sampler = function( - dataset, - outcome, - forest_model, - leaf_model_int, - leaf_value - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_ptr)) - stopifnot(self$internal_forest_is_empty) - - # Initialize the model - initialize_forest_model_active_forest_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_ptr, - forest_model$tracker_ptr, - leaf_value, - leaf_model_int - ) - - self$internal_forest_is_empty = FALSE - }, - - #' @description - #' Adjusts residual based on the predictions of a forest - #' - #' This is typically run just once at the beginning of a forest sampling algorithm. - #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. - #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest - #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling - #' @param requires_basis Whether or not a forest requires a basis for prediction - #' @param add Whether forest predictions should be added to or subtracted from residuals - adjust_residual = function( - dataset, - outcome, - forest_model, - requires_basis, - add - ) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(forest_model$tracker_ptr)) - stopifnot(!is.null(self$forest_ptr)) - - adjust_residual_active_forest_cpp( - dataset$data_ptr, - outcome$data_ptr, - self$forest_ptr, - forest_model$tracker_ptr, - requires_basis, - add - ) - }, - - #' @description - #' Return number of trees in each ensemble of a `Forest` object - #' @return Tree count - num_trees = function() { - return(num_trees_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Return output dimension of trees in a `Forest` object - #' @return Leaf node parameter size - leaf_dimension = function() { - return(leaf_dimension_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Return constant leaf status of trees in a `Forest` object - #' @return `TRUE` if leaves are constant, `FALSE` otherwise - is_constant_leaf = function() { - return(is_leaf_constant_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Return exponentiation status of trees in a `Forest` object - #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise - is_exponentiated = function() { - return(is_exponentiated_active_forest_cpp(self$forest_ptr)) - }, - - #' @description - #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble - #' @param tree_num Index of the tree to be split - #' @param leaf_num Leaf to be split - #' @param feature_num Feature that defines the new split - #' @param split_threshold Value that defines the cutoff of the new split - #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node - #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node - add_numeric_split_tree = function( - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) { - if (length(left_leaf_value) > 1) { - add_numeric_split_tree_vector_active_forest_cpp( - self$forest_ptr, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } else { - add_numeric_split_tree_value_active_forest_cpp( - self$forest_ptr, - tree_num, - leaf_num, - feature_num, - split_threshold, - left_leaf_value, - right_leaf_value - ) - } - }, - - #' @description - #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest - #' @param tree_num Index of the tree for which leaf indices will be retrieved - get_tree_leaves = function(tree_num) { - return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num)) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in a given tree in the forest - #' @param tree_num Index of the tree for which split counts will be retrieved - #' @param num_features Total number of features in the training set - get_tree_split_counts = function(tree_num, num_features) { - return(get_tree_split_counts_active_forest_cpp( - self$forest_ptr, - tree_num, - num_features - )) - }, - - #' @description - #' Retrieve a vector of split counts for every training set variable in the forest - #' @param num_features Total number of features in the training set - get_forest_split_counts = function(num_features) { - return(get_overall_split_counts_active_forest_cpp( - self$forest_ptr, - num_features - )) - }, - - #' @description - #' Maximum depth of a specific tree in the forest - #' @param tree_num Tree index within forest - #' @return Maximum leaf depth - tree_max_depth = function(tree_num) { - return(ensemble_tree_max_depth_active_forest_cpp( - self$forest_ptr, - tree_num - )) - }, - - #' @description - #' Average the maximum depth of each tree in the forest - #' @return Average maximum depth - average_max_depth = function() { - return(ensemble_average_max_depth_active_forest_cpp( - self$forest_ptr - )) - }, - - #' @description - #' When a forest object is created, it is "empty" in the sense that none - #' of its component trees have leaves with values. There are two ways to - #' "initialize" a Forest object. First, the `set_root_leaves()` method - #' simply initializes every tree in the forest to a single node carrying - #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()` - #' method initializes every tree in the forest to a single node with the - #' same value and also propagates this information through to a ForestModel - #' object, which must be synchronized with a Forest during a forest - #' sampler loop. - #' @return `TRUE` if a Forest has not yet been initialized with a constant - #' root value, `FALSE` otherwise if the forest has already been - #' initialized / grown. - is_empty = function() { - return(self$internal_forest_is_empty) - } - ) + classname = "Forest", + cloneable = FALSE, + public = list( + #' @field forest_ptr External pointer to a C++ TreeEnsemble class + forest_ptr = NULL, + + #' @field internal_forest_is_empty Whether the forest has not yet been "initialized" such that its `predict` function can be called. + internal_forest_is_empty = TRUE, + + #' @description + #' Create a new Forest object. + #' @param num_trees Number of trees in the forest + #' @param leaf_dimension Dimensionality of the outcome model + #' @param is_leaf_constant Whether leaf is constant + #' @param is_exponentiated Whether forest predictions should be exponentiated before being returned + #' @return A new `Forest` object. + initialize = function( + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE + ) { + self$forest_ptr <- active_forest_cpp( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated + ) + self$internal_forest_is_empty <- TRUE + }, + + #' @description + #' Create a larger forest by merging the trees of this forest with those of another forest + #' @param forest Forest to be merged into this forest + merge_forest = function(forest) { + stopifnot(self$leaf_dimension() == forest$leaf_dimension()) + stopifnot(self$is_constant_leaf() == forest$is_constant_leaf()) + stopifnot(self$is_exponentiated() == forest$is_exponentiated()) + forest_merge_cpp(self$forest_ptr, forest$forest_ptr) + }, + + #' @description + #' Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves. + #' @param constant_value Value that will be added to every leaf of every tree + add_constant = function(constant_value) { + forest_add_constant_cpp(self$forest_ptr, constant_value) + }, + + #' @description + #' Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves. + #' @param constant_multiple Value that will be multiplied by every leaf of every tree + multiply_constant = function(constant_multiple) { + forest_multiply_constant_cpp(self$forest_ptr, constant_multiple) + }, + + #' @description + #' Predict forest on every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return vector of predictions with as many rows as in `forest_dataset` + predict = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + stopifnot(!is.null(self$forest_ptr)) + return(predict_active_forest_cpp( + self$forest_ptr, + forest_dataset$data_ptr + )) + }, + + #' @description + #' Predict "raw" leaf values (without being multiplied by basis) for every sample in `forest_dataset` + #' @param forest_dataset `ForestDataset` R class + #' @return Array of predictions for each observation in `forest_dataset` and + #' each sample in the `ForestSamples` class with each prediction having the + #' dimensionality of the forests' leaf model. In the case of a constant leaf model + #' or univariate leaf regression, this array is a vector (length is the number of + #' observations). In the case of a multivariate leaf regression, + #' this array is a matrix (number of observations by leaf model dimension, + #' number of samples). + predict_raw = function(forest_dataset) { + stopifnot(!is.null(forest_dataset$data_ptr)) + # Unpack dimensions + output_dim <- leaf_dimension_active_forest_cpp(self$forest_ptr) + n <- dataset_num_rows_cpp(forest_dataset$data_ptr) + + # Predict leaf values from forest + predictions <- predict_raw_active_forest_cpp( + self$forest_ptr, + forest_dataset$data_ptr + ) + if (output_dim > 1) { + dim(predictions) <- c(n, output_dim) + } + + return(predictions) + }, + + #' @description + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. + #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. + set_root_leaves = function(leaf_value) { + stopifnot(!is.null(self$forest_ptr)) + stopifnot(self$internal_forest_is_empty) + + # Set leaf values + if (length(leaf_value) == 1) { + stopifnot( + leaf_dimension_active_forest_cpp(self$forest_ptr) == 1 + ) + set_leaf_value_active_forest_cpp(self$forest_ptr, leaf_value) + } else if (length(leaf_value) > 1) { + stopifnot( + leaf_dimension_active_forest_cpp(self$forest_ptr) == + length(leaf_value) + ) + set_leaf_vector_active_forest_cpp(self$forest_ptr, leaf_value) + } else { + stop( + "leaf_value must be a numeric value or vector of length >= 1" + ) + } + + self$internal_forest_is_empty = FALSE + }, + + #' @description + #' Set a constant predicted value for every tree in the ensemble. + #' Stops program if any tree is more than a root node. + #' @param dataset `ForestDataset` Dataset class (covariates, basis, etc...) + #' @param outcome `Outcome` Outcome class (residual / partial residual) + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param leaf_model_int Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). + #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension. + prepare_for_sampler = function( + dataset, + outcome, + forest_model, + leaf_model_int, + leaf_value + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_ptr)) + stopifnot(self$internal_forest_is_empty) + + # Initialize the model + initialize_forest_model_active_forest_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_ptr, + forest_model$tracker_ptr, + leaf_value, + leaf_model_int + ) + + self$internal_forest_is_empty = FALSE + }, + + #' @description + #' Adjusts residual based on the predictions of a forest + #' + #' This is typically run just once at the beginning of a forest sampling algorithm. + #' After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual. + #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest + #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions + #' @param forest_model `ForestModel` object storing tracking structures used in training / sampling + #' @param requires_basis Whether or not a forest requires a basis for prediction + #' @param add Whether forest predictions should be added to or subtracted from residuals + adjust_residual = function( + dataset, + outcome, + forest_model, + requires_basis, + add + ) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(forest_model$tracker_ptr)) + stopifnot(!is.null(self$forest_ptr)) + + adjust_residual_active_forest_cpp( + dataset$data_ptr, + outcome$data_ptr, + self$forest_ptr, + forest_model$tracker_ptr, + requires_basis, + add + ) + }, + + #' @description + #' Return number of trees in each ensemble of a `Forest` object + #' @return Tree count + num_trees = function() { + return(num_trees_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Return output dimension of trees in a `Forest` object + #' @return Leaf node parameter size + leaf_dimension = function() { + return(leaf_dimension_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Return constant leaf status of trees in a `Forest` object + #' @return `TRUE` if leaves are constant, `FALSE` otherwise + is_constant_leaf = function() { + return(is_leaf_constant_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Return exponentiation status of trees in a `Forest` object + #' @return `TRUE` if leaf predictions must be exponentiated, `FALSE` otherwise + is_exponentiated = function() { + return(is_exponentiated_active_forest_cpp(self$forest_ptr)) + }, + + #' @description + #' Add a numeric (i.e. `X[,i] <= c`) split to a given tree in the ensemble + #' @param tree_num Index of the tree to be split + #' @param leaf_num Leaf to be split + #' @param feature_num Feature that defines the new split + #' @param split_threshold Value that defines the cutoff of the new split + #' @param left_leaf_value Value (or vector of values) to assign to the newly created left node + #' @param right_leaf_value Value (or vector of values) to assign to the newly created right node + add_numeric_split_tree = function( + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) { + if (length(left_leaf_value) > 1) { + add_numeric_split_tree_vector_active_forest_cpp( + self$forest_ptr, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) + } else { + add_numeric_split_tree_value_active_forest_cpp( + self$forest_ptr, + tree_num, + leaf_num, + feature_num, + split_threshold, + left_leaf_value, + right_leaf_value + ) + } + }, + + #' @description + #' Retrieve a vector of indices of leaf nodes for a given tree in a given forest + #' @param tree_num Index of the tree for which leaf indices will be retrieved + get_tree_leaves = function(tree_num) { + return(get_tree_leaves_active_forest_cpp(self$forest_ptr, tree_num)) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in a given tree in the forest + #' @param tree_num Index of the tree for which split counts will be retrieved + #' @param num_features Total number of features in the training set + get_tree_split_counts = function(tree_num, num_features) { + return(get_tree_split_counts_active_forest_cpp( + self$forest_ptr, + tree_num, + num_features + )) + }, + + #' @description + #' Retrieve a vector of split counts for every training set variable in the forest + #' @param num_features Total number of features in the training set + get_forest_split_counts = function(num_features) { + return(get_overall_split_counts_active_forest_cpp( + self$forest_ptr, + num_features + )) + }, + + #' @description + #' Maximum depth of a specific tree in the forest + #' @param tree_num Tree index within forest + #' @return Maximum leaf depth + tree_max_depth = function(tree_num) { + return(ensemble_tree_max_depth_active_forest_cpp( + self$forest_ptr, + tree_num + )) + }, + + #' @description + #' Average the maximum depth of each tree in the forest + #' @return Average maximum depth + average_max_depth = function() { + return(ensemble_average_max_depth_active_forest_cpp( + self$forest_ptr + )) + }, + + #' @description + #' When a forest object is created, it is "empty" in the sense that none + #' of its component trees have leaves with values. There are two ways to + #' "initialize" a Forest object. First, the `set_root_leaves()` method + #' simply initializes every tree in the forest to a single node carrying + #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()` + #' method initializes every tree in the forest to a single node with the + #' same value and also propagates this information through to a ForestModel + #' object, which must be synchronized with a Forest during a forest + #' sampler loop. + #' @return `TRUE` if a Forest has not yet been initialized with a constant + #' root value, `FALSE` otherwise if the forest has already been + #' initialized / grown. + is_empty = function() { + return(self$internal_forest_is_empty) + } + ) ) #' Create a container of forest samples @@ -1246,19 +1246,19 @@ Forest <- R6::R6Class( #' is_exponentiated <- FALSE #' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) createForestSamples <- function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE ) { - return(invisible( - (ForestSamples$new( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - )) + return(invisible( + (ForestSamples$new( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated )) + )) } #' Create a forest @@ -1278,19 +1278,19 @@ createForestSamples <- function( #' is_exponentiated <- FALSE #' forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) createForest <- function( - num_trees, - leaf_dimension = 1, - is_leaf_constant = FALSE, - is_exponentiated = FALSE + num_trees, + leaf_dimension = 1, + is_leaf_constant = FALSE, + is_exponentiated = FALSE ) { - return(invisible( - (Forest$new( - num_trees, - leaf_dimension, - is_leaf_constant, - is_exponentiated - )) + return(invisible( + (Forest$new( + num_trees, + leaf_dimension, + is_leaf_constant, + is_exponentiated )) + )) } #' Reset an active forest, either from a specific forest in a `ForestContainer` @@ -1316,25 +1316,25 @@ createForest <- function( #' resetActiveForest(active_forest, forest_samples, 0) #' resetActiveForest(active_forest) resetActiveForest <- function( - active_forest, - forest_samples = NULL, - forest_num = NULL + active_forest, + forest_samples = NULL, + forest_num = NULL ) { - if (is.null(forest_samples)) { - root_reset_active_forest_cpp(active_forest$forest_ptr) - active_forest$internal_forest_is_empty = TRUE - } else { - if (is.null(forest_num)) { - stop( - "`forest_num` must be specified if `forest_samples` is provided" - ) - } - reset_active_forest_cpp( - active_forest$forest_ptr, - forest_samples$forest_container_ptr, - forest_num - ) + if (is.null(forest_samples)) { + root_reset_active_forest_cpp(active_forest$forest_ptr) + active_forest$internal_forest_is_empty = TRUE + } else { + if (is.null(forest_num)) { + stop( + "`forest_num` must be specified if `forest_samples` is provided" + ) } + reset_active_forest_cpp( + active_forest$forest_ptr, + forest_samples$forest_container_ptr, + forest_num + ) + } } #' Re-initialize a forest model (tracking data structures) from a specific forest in a `ForestContainer` @@ -1394,17 +1394,17 @@ resetActiveForest <- function( #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) resetForestModel <- function( - forest_model, - forest, - dataset, - residual, - is_mean_model + forest_model, + forest, + dataset, + residual, + is_mean_model ) { - reset_forest_model_cpp( - forest_model$tracker_ptr, - forest$forest_ptr, - dataset$data_ptr, - residual$data_ptr, - is_mean_model - ) + reset_forest_model_cpp( + forest_model$tracker_ptr, + forest$forest_ptr, + dataset$data_ptr, + residual$data_ptr, + is_mean_model + ) } diff --git a/R/kernel.R b/R/kernel.R index d7e9661e..2b643b98 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -48,113 +48,113 @@ #' computeForestLeafIndices(bart_model, X, "mean", 0) #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) computeForestLeafIndices <- function( - model_object, - covariates, - forest_type = NULL, - propensity = NULL, - forest_inds = NULL + model_object, + covariates, + forest_type = NULL, + propensity = NULL, + forest_inds = NULL ) { - # Extract relevant forest container - stopifnot(any(c( - inherits(model_object, "bartmodel"), - inherits(model_object, "bcfmodel"), - inherits(model_object, "ForestSamples") - ))) - model_type <- ifelse( - inherits(model_object, "bartmodel"), - "bart", - ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") - ) - if (model_type == "bart") { - stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type == "mean") { - if (!model_object$model_params$include_mean_forest) { - stop("Mean forest was not sampled in the bart model provided") - } - forest_container <- model_object$mean_forests - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bart model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else if (model_type == "bcf") { - stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type == "prognostic") { - forest_container <- model_object$forests_mu - } else if (forest_type == "treatment") { - forest_container <- model_object$forests_tau - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bcf model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else { - forest_container <- model_object - } - - # Preprocess covariates - if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { - stop("covariates must be a matrix or dataframe") + # Extract relevant forest container + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel"), + inherits(model_object, "ForestSamples") + ))) + model_type <- ifelse( + inherits(model_object, "bartmodel"), + "bart", + ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") + ) + if (model_type == "bart") { + stopifnot(forest_type %in% c("mean", "variance")) + if (forest_type == "mean") { + if (!model_object$model_params$include_mean_forest) { + stop("Mean forest was not sampled in the bart model provided") + } + forest_container <- model_object$mean_forests + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bart model provided" + ) + } + forest_container <- model_object$variance_forests } - if (model_type %in% c("bart", "bcf")) { - train_set_metadata <- model_object$train_set_metadata - covariates_processed <- preprocessPredictionData( - covariates, - train_set_metadata + } else if (model_type == "bcf") { + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) + if (forest_type == "prognostic") { + forest_container <- model_object$forests_mu + } else if (forest_type == "treatment") { + forest_container <- model_object$forests_tau + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bcf model provided" ) - } else { - if (!is.matrix(covariates)) { - stop( - "covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`" - ) - } - covariates_processed <- covariates + } + forest_container <- model_object$variance_forests } + } else { + forest_container <- model_object + } - # Handle BCF propensity covariate - if (model_type == "bcf") { - # Add propensities to covariate set if necessary - if (model_object$model_params$propensity_covariate != "none") { - if (is.null(propensity)) { - if (!model_object$model_params$internal_propensity_model) { - stop("propensity must be provided for this model") - } - # Compute propensity score using the internal bart model - propensity <- rowMeans( - predict( - model_object$bart_propensity_model, - covariates - )$y_hat - ) - } - covariates_processed <- cbind(covariates_processed, propensity) - } + # Preprocess covariates + if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) { + stop("covariates must be a matrix or dataframe") + } + if (model_type %in% c("bart", "bcf")) { + train_set_metadata <- model_object$train_set_metadata + covariates_processed <- preprocessPredictionData( + covariates, + train_set_metadata + ) + } else { + if (!is.matrix(covariates)) { + stop( + "covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`" + ) } + covariates_processed <- covariates + } - # Preprocess forest indices - num_forests <- forest_container$num_samples() - if (is.null(forest_inds)) { - forest_inds <- as.integer(1:num_forests - 1) - } else { - stopifnot(all(forest_inds <= num_forests - 1)) - stopifnot(all(forest_inds >= 0)) - forest_inds <- as.integer(forest_inds) + # Handle BCF propensity covariate + if (model_type == "bcf") { + # Add propensities to covariate set if necessary + if (model_object$model_params$propensity_covariate != "none") { + if (is.null(propensity)) { + if (!model_object$model_params$internal_propensity_model) { + stop("propensity must be provided for this model") + } + # Compute propensity score using the internal bart model + propensity <- rowMeans( + predict( + model_object$bart_propensity_model, + covariates + )$y_hat + ) + } + covariates_processed <- cbind(covariates_processed, propensity) } + } - # Compute leaf indices - leaf_ind_matrix <- compute_leaf_indices_cpp( - forest_container$forest_container_ptr, - covariates_processed, - forest_inds - ) + # Preprocess forest indices + num_forests <- forest_container$num_samples() + if (is.null(forest_inds)) { + forest_inds <- as.integer(1:num_forests - 1) + } else { + stopifnot(all(forest_inds <= num_forests - 1)) + stopifnot(all(forest_inds >= 0)) + forest_inds <- as.integer(forest_inds) + } - return(leaf_ind_matrix) + # Compute leaf indices + leaf_ind_matrix <- compute_leaf_indices_cpp( + forest_container$forest_container_ptr, + covariates_processed, + forest_inds + ) + + return(leaf_ind_matrix) } #' Compute vector of forest leaf scale parameters @@ -193,80 +193,80 @@ computeForestLeafIndices <- function( #' computeForestLeafVariances(bart_model, "mean", 0) #' computeForestLeafVariances(bart_model, "mean", c(1,3,5)) computeForestLeafVariances <- function( - model_object, - forest_type, - forest_inds = NULL + model_object, + forest_type, + forest_inds = NULL ) { - # Extract relevant forest container - stopifnot(any(c( - inherits(model_object, "bartmodel"), - inherits(model_object, "bcfmodel") - ))) - model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") - if (model_type == "bart") { - stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type == "mean") { - if (!model_object$model_params$include_mean_forest) { - stop("Mean forest was not sampled in the bart model provided") - } - if (!model_object$model_params$sample_sigma2_leaf) { - stop( - "Leaf scale parameter was not sampled for the mean forest in the bart model provided" - ) - } - leaf_scale_vector <- model_object$sigma2_leaf_samples - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bart model provided" - ) - } - stop( - "Leaf scale parameter was not sampled for the variance forest in the bart model provided" - ) - } - } else { - stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type == "prognostic") { - if (!model_object$model_params$sample_sigma2_leaf_mu) { - stop( - "Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided" - ) - } - leaf_scale_vector <- model_object$sigma2_leaf_mu_samples - } else if (forest_type == "treatment") { - if (!model_object$model_params$sample_sigma2_leaf_tau) { - stop( - "Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided" - ) - } - leaf_scale_vector <- model_object$sigma2_leaf_tau_samples - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bcf model provided" - ) - } - stop( - "Leaf scale parameter was not sampled for the variance forest in the bcf model provided" - ) - } + # Extract relevant forest container + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel") + ))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") + if (model_type == "bart") { + stopifnot(forest_type %in% c("mean", "variance")) + if (forest_type == "mean") { + if (!model_object$model_params$include_mean_forest) { + stop("Mean forest was not sampled in the bart model provided") + } + if (!model_object$model_params$sample_sigma2_leaf) { + stop( + "Leaf scale parameter was not sampled for the mean forest in the bart model provided" + ) + } + leaf_scale_vector <- model_object$sigma2_leaf_samples + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bart model provided" + ) + } + stop( + "Leaf scale parameter was not sampled for the variance forest in the bart model provided" + ) } - - # Preprocess forest indices - num_forests <- model_object$model_params$num_samples - if (is.null(forest_inds)) { - forest_inds <- as.integer(1:num_forests) - } else { - stopifnot(all(forest_inds <= num_forests - 1)) - stopifnot(all(forest_inds >= 0)) - forest_inds <- as.integer(forest_inds + 1) + } else { + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) + if (forest_type == "prognostic") { + if (!model_object$model_params$sample_sigma2_leaf_mu) { + stop( + "Leaf scale parameter was not sampled for the prognostic forest in the bcf model provided" + ) + } + leaf_scale_vector <- model_object$sigma2_leaf_mu_samples + } else if (forest_type == "treatment") { + if (!model_object$model_params$sample_sigma2_leaf_tau) { + stop( + "Leaf scale parameter was not sampled for the treatment effect forest in the bcf model provided" + ) + } + leaf_scale_vector <- model_object$sigma2_leaf_tau_samples + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bcf model provided" + ) + } + stop( + "Leaf scale parameter was not sampled for the variance forest in the bcf model provided" + ) } + } + + # Preprocess forest indices + num_forests <- model_object$model_params$num_samples + if (is.null(forest_inds)) { + forest_inds <- as.integer(1:num_forests) + } else { + stopifnot(all(forest_inds <= num_forests - 1)) + stopifnot(all(forest_inds >= 0)) + forest_inds <- as.integer(forest_inds + 1) + } - # Gather leaf scale parameters - leaf_scale_params <- leaf_scale_vector[forest_inds] + # Gather leaf scale parameters + leaf_scale_params <- leaf_scale_vector[forest_inds] - return(leaf_scale_params) + return(leaf_scale_params) } #' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container. @@ -304,72 +304,72 @@ computeForestLeafVariances <- function( #' computeForestMaxLeafIndex(bart_model, "mean", 0) #' computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9)) computeForestMaxLeafIndex <- function( - model_object, - forest_type = NULL, - forest_inds = NULL + model_object, + forest_type = NULL, + forest_inds = NULL ) { - # Extract relevant forest container - stopifnot(any(c( - inherits(model_object, "bartmodel"), - inherits(model_object, "bcfmodel"), - inherits(model_object, "ForestSamples") - ))) - model_type <- ifelse( - inherits(model_object, "bartmodel"), - "bart", - ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") - ) - if (model_type == "bart") { - stopifnot(forest_type %in% c("mean", "variance")) - if (forest_type == "mean") { - if (!model_object$model_params$include_mean_forest) { - stop("Mean forest was not sampled in the bart model provided") - } - forest_container <- model_object$mean_forests - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bart model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else if (model_type == "bcf") { - stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) - if (forest_type == "prognostic") { - forest_container <- model_object$forests_mu - } else if (forest_type == "treatment") { - forest_container <- model_object$forests_tau - } else if (forest_type == "variance") { - if (!model_object$model_params$include_variance_forest) { - stop( - "Variance forest was not sampled in the bcf model provided" - ) - } - forest_container <- model_object$variance_forests - } - } else { - forest_container <- model_object - } - - # Preprocess forest indices - num_forests <- forest_container$num_samples() - if (is.null(forest_inds)) { - forest_inds <- as.integer(1:num_forests - 1) - } else { - stopifnot(all(forest_inds <= num_forests - 1)) - stopifnot(all(forest_inds >= 0)) - forest_inds <- as.integer(forest_inds) + # Extract relevant forest container + stopifnot(any(c( + inherits(model_object, "bartmodel"), + inherits(model_object, "bcfmodel"), + inherits(model_object, "ForestSamples") + ))) + model_type <- ifelse( + inherits(model_object, "bartmodel"), + "bart", + ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples") + ) + if (model_type == "bart") { + stopifnot(forest_type %in% c("mean", "variance")) + if (forest_type == "mean") { + if (!model_object$model_params$include_mean_forest) { + stop("Mean forest was not sampled in the bart model provided") + } + forest_container <- model_object$mean_forests + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bart model provided" + ) + } + forest_container <- model_object$variance_forests } - - # Compute leaf indices - output <- rep(NA, length(forest_inds)) - for (i in 1:length(forest_inds)) { - output[i] <- forest_container_get_max_leaf_index_cpp( - forest_container$forest_container_ptr, - forest_inds[i] + } else if (model_type == "bcf") { + stopifnot(forest_type %in% c("prognostic", "treatment", "variance")) + if (forest_type == "prognostic") { + forest_container <- model_object$forests_mu + } else if (forest_type == "treatment") { + forest_container <- model_object$forests_tau + } else if (forest_type == "variance") { + if (!model_object$model_params$include_variance_forest) { + stop( + "Variance forest was not sampled in the bcf model provided" ) + } + forest_container <- model_object$variance_forests } + } else { + forest_container <- model_object + } + + # Preprocess forest indices + num_forests <- forest_container$num_samples() + if (is.null(forest_inds)) { + forest_inds <- as.integer(1:num_forests - 1) + } else { + stopifnot(all(forest_inds <= num_forests - 1)) + stopifnot(all(forest_inds >= 0)) + forest_inds <- as.integer(forest_inds) + } + + # Compute leaf indices + output <- rep(NA, length(forest_inds)) + for (i in 1:length(forest_inds)) { + output[i] <- forest_container_get_max_leaf_index_cpp( + forest_container$forest_container_ptr, + forest_inds[i] + ) + } - return(output) + return(output) } diff --git a/R/model.R b/R/model.R index 38df5970..5549880d 100644 --- a/R/model.R +++ b/R/model.R @@ -6,20 +6,20 @@ #' the C++ random number generator is initialized using `std::random_device`. CppRNG <- R6::R6Class( - classname = "CppRNG", - cloneable = FALSE, - public = list( - #' @field rng_ptr External pointer to a C++ std::mt19937 class - rng_ptr = NULL, + classname = "CppRNG", + cloneable = FALSE, + public = list( + #' @field rng_ptr External pointer to a C++ std::mt19937 class + rng_ptr = NULL, - #' @description - #' Create a new CppRNG object. - #' @param random_seed (Optional) random seed for sampling - #' @return A new `CppRNG` object. - initialize = function(random_seed = -1) { - self$rng_ptr <- rng_cpp(random_seed) - } - ) + #' @description + #' Create a new CppRNG object. + #' @param random_seed (Optional) random seed for sampling + #' @return A new `CppRNG` object. + initialize = function(random_seed = -1) { + self$rng_ptr <- rng_cpp(random_seed) + } + ) ) #' Class that defines and samples a forest model @@ -30,292 +30,291 @@ CppRNG <- R6::R6Class( #' (using either MCMC or the grow-from-root algorithm). ForestModel <- R6::R6Class( - classname = "ForestModel", - cloneable = FALSE, - public = list( - #' @field tracker_ptr External pointer to a C++ ForestTracker class - tracker_ptr = NULL, + classname = "ForestModel", + cloneable = FALSE, + public = list( + #' @field tracker_ptr External pointer to a C++ ForestTracker class + tracker_ptr = NULL, - #' @field tree_prior_ptr External pointer to a C++ TreePrior class - tree_prior_ptr = NULL, + #' @field tree_prior_ptr External pointer to a C++ TreePrior class + tree_prior_ptr = NULL, - #' @description - #' Create a new ForestModel object. - #' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures - #' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - #' @param num_trees Number of trees in the forest being sampled - #' @param n Number of observations in `forest_dataset` - #' @param alpha Root node split probability in tree prior - #' @param beta Depth prior penalty in tree prior - #' @param min_samples_leaf Minimum number of samples in a tree leaf - #' @param max_depth Maximum depth that any tree can reach - #' @return A new `ForestModel` object. - initialize = function( - forest_dataset, - feature_types, - num_trees, - n, - alpha, - beta, - min_samples_leaf, - max_depth = -1 - ) { - stopifnot(!is.null(forest_dataset$data_ptr)) - self$tracker_ptr <- forest_tracker_cpp( - forest_dataset$data_ptr, - feature_types, - num_trees, - n - ) - self$tree_prior_ptr <- tree_prior_cpp( - alpha, - beta, - min_samples_leaf, - max_depth - ) - }, + #' @description + #' Create a new ForestModel object. + #' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures + #' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + #' @param num_trees Number of trees in the forest being sampled + #' @param n Number of observations in `forest_dataset` + #' @param alpha Root node split probability in tree prior + #' @param beta Depth prior penalty in tree prior + #' @param min_samples_leaf Minimum number of samples in a tree leaf + #' @param max_depth Maximum depth that any tree can reach + #' @return A new `ForestModel` object. + initialize = function( + forest_dataset, + feature_types, + num_trees, + n, + alpha, + beta, + min_samples_leaf, + max_depth = -1 + ) { + stopifnot(!is.null(forest_dataset$data_ptr)) + self$tracker_ptr <- forest_tracker_cpp( + forest_dataset$data_ptr, + feature_types, + num_trees, + n + ) + self$tree_prior_ptr <- tree_prior_cpp( + alpha, + beta, + min_samples_leaf, + max_depth + ) + }, - #' @description - #' Run a single iteration of the forest sampling algorithm (MCMC or GFR) - #' @param forest_dataset Dataset used to sample the forest - #' @param residual Outcome used to sample the forest - #' @param forest_samples Container of forest samples - #' @param active_forest "Active" forest updated by the sampler in each iteration - #' @param rng Wrapper around C++ random number generator - #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings - #' @param global_model_config GlobalModelConfig object containing global model parameters and settings - #' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads. - #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. - #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. - sample_one_iteration = function( - forest_dataset, - residual, - forest_samples, - active_forest, - rng, - forest_model_config, - global_model_config, - num_threads = -1, - keep_forest = TRUE, - gfr = TRUE - ) { - if (active_forest$is_empty()) { - stop( - "`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods." - ) - } + #' @description + #' Run a single iteration of the forest sampling algorithm (MCMC or GFR) + #' @param forest_dataset Dataset used to sample the forest + #' @param residual Outcome used to sample the forest + #' @param forest_samples Container of forest samples + #' @param active_forest "Active" forest updated by the sampler in each iteration + #' @param rng Wrapper around C++ random number generator + #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings + #' @param global_model_config GlobalModelConfig object containing global model parameters and settings + #' @param num_threads Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to `1`, otherwise to the maximum number of available threads. + #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. + #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. + sample_one_iteration = function( + forest_dataset, + residual, + forest_samples, + active_forest, + rng, + forest_model_config, + global_model_config, + num_threads = -1, + keep_forest = TRUE, + gfr = TRUE + ) { + if (active_forest$is_empty()) { + stop( + "`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods." + ) + } - # Unpack parameters from model config object - feature_types <- forest_model_config$feature_types - sweep_update_indices <- forest_model_config$sweep_update_indices - leaf_model_int <- forest_model_config$leaf_model_type - leaf_model_scale <- forest_model_config$leaf_model_scale - variable_weights <- forest_model_config$variable_weights - a_forest <- forest_model_config$variance_forest_shape - b_forest <- forest_model_config$variance_forest_scale - global_scale <- global_model_config$global_error_variance - cutpoint_grid_size <- forest_model_config$cutpoint_grid_size - num_features_subsample <- forest_model_config$num_features_subsample + # Unpack parameters from model config object + feature_types <- forest_model_config$feature_types + sweep_update_indices <- forest_model_config$sweep_update_indices + leaf_model_int <- forest_model_config$leaf_model_type + leaf_model_scale <- forest_model_config$leaf_model_scale + variable_weights <- forest_model_config$variable_weights + a_forest <- forest_model_config$variance_forest_shape + b_forest <- forest_model_config$variance_forest_scale + global_scale <- global_model_config$global_error_variance + cutpoint_grid_size <- forest_model_config$cutpoint_grid_size + num_features_subsample <- forest_model_config$num_features_subsample - # Default to empty integer vector if sweep_update_indices is NULL - if (is.null(sweep_update_indices)) { - # sweep_update_indices <- integer(0) - sweep_update_indices <- 0:(forest_model_config$num_trees - 1) - } + # Default to empty integer vector if sweep_update_indices is NULL + if (is.null(sweep_update_indices)) { + # sweep_update_indices <- integer(0) + sweep_update_indices <- 0:(forest_model_config$num_trees - 1) + } - # Detect changes to tree prior - if ( - forest_model_config$alpha != - get_alpha_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_alpha_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$alpha - ) - } - if ( - forest_model_config$beta != - get_beta_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_beta_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$beta - ) - } - if ( - forest_model_config$min_samples_leaf != - get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_min_samples_leaf_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$min_samples_leaf - ) - } - if ( - forest_model_config$max_depth != - get_max_depth_tree_prior_cpp(self$tree_prior_ptr) - ) { - update_max_depth_tree_prior_cpp( - self$tree_prior_ptr, - forest_model_config$max_depth - ) - } + # Detect changes to tree prior + if ( + forest_model_config$alpha != + get_alpha_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_alpha_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$alpha + ) + } + if ( + forest_model_config$beta != get_beta_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_beta_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$beta + ) + } + if ( + forest_model_config$min_samples_leaf != + get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_min_samples_leaf_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$min_samples_leaf + ) + } + if ( + forest_model_config$max_depth != + get_max_depth_tree_prior_cpp(self$tree_prior_ptr) + ) { + update_max_depth_tree_prior_cpp( + self$tree_prior_ptr, + forest_model_config$max_depth + ) + } - # Run the sampler - if (gfr) { - sample_gfr_one_iteration_cpp( - forest_dataset$data_ptr, - residual$data_ptr, - forest_samples$forest_container_ptr, - active_forest$forest_ptr, - self$tracker_ptr, - self$tree_prior_ptr, - rng$rng_ptr, - sweep_update_indices, - feature_types, - cutpoint_grid_size, - leaf_model_scale, - variable_weights, - a_forest, - b_forest, - global_scale, - leaf_model_int, - keep_forest, - num_features_subsample, - num_threads - ) - } else { - sample_mcmc_one_iteration_cpp( - forest_dataset$data_ptr, - residual$data_ptr, - forest_samples$forest_container_ptr, - active_forest$forest_ptr, - self$tracker_ptr, - self$tree_prior_ptr, - rng$rng_ptr, - sweep_update_indices, - feature_types, - cutpoint_grid_size, - leaf_model_scale, - variable_weights, - a_forest, - b_forest, - global_scale, - leaf_model_int, - keep_forest, - num_threads - ) - } - }, + # Run the sampler + if (gfr) { + sample_gfr_one_iteration_cpp( + forest_dataset$data_ptr, + residual$data_ptr, + forest_samples$forest_container_ptr, + active_forest$forest_ptr, + self$tracker_ptr, + self$tree_prior_ptr, + rng$rng_ptr, + sweep_update_indices, + feature_types, + cutpoint_grid_size, + leaf_model_scale, + variable_weights, + a_forest, + b_forest, + global_scale, + leaf_model_int, + keep_forest, + num_features_subsample, + num_threads + ) + } else { + sample_mcmc_one_iteration_cpp( + forest_dataset$data_ptr, + residual$data_ptr, + forest_samples$forest_container_ptr, + active_forest$forest_ptr, + self$tracker_ptr, + self$tree_prior_ptr, + rng$rng_ptr, + sweep_update_indices, + feature_types, + cutpoint_grid_size, + leaf_model_scale, + variable_weights, + a_forest, + b_forest, + global_scale, + leaf_model_int, + keep_forest, + num_threads + ) + } + }, - #' @description - #' Extract an internally-cached prediction of a forest on the training dataset in a sampler. - #' @return Vector with as many elements as observations in the training dataset - get_cached_forest_predictions = function() { - get_cached_forest_predictions_cpp(self$tracker_ptr) - }, + #' @description + #' Extract an internally-cached prediction of a forest on the training dataset in a sampler. + #' @return Vector with as many elements as observations in the training dataset + get_cached_forest_predictions = function() { + get_cached_forest_predictions_cpp(self$tracker_ptr) + }, - #' @description - #' Propagates basis update through to the (full/partial) residual by iteratively - #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions - #' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual. - #' - #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside - #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). - #' Once a basis has been updated, the overall "function" represented by a tree model has - #' changed and this should be reflected through to the residual before the next sampling loop is run. - #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest - #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions - #' @param active_forest "Active" forest updated by the sampler in each iteration - propagate_basis_update = function(dataset, outcome, active_forest) { - stopifnot(!is.null(dataset$data_ptr)) - stopifnot(!is.null(outcome$data_ptr)) - stopifnot(!is.null(self$tracker_ptr)) - stopifnot(!is.null(active_forest$forest_ptr)) + #' @description + #' Propagates basis update through to the (full/partial) residual by iteratively + #' (a) adding back in the previous prediction of each tree, (b) recomputing predictions + #' for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual. + #' + #' This is useful in cases where a basis (for e.g. leaf regression) is updated outside + #' of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). + #' Once a basis has been updated, the overall "function" represented by a tree model has + #' changed and this should be reflected through to the residual before the next sampling loop is run. + #' @param dataset `ForestDataset` object storing the covariates and bases for a given forest + #' @param outcome `Outcome` object storing the residuals to be updated based on forest predictions + #' @param active_forest "Active" forest updated by the sampler in each iteration + propagate_basis_update = function(dataset, outcome, active_forest) { + stopifnot(!is.null(dataset$data_ptr)) + stopifnot(!is.null(outcome$data_ptr)) + stopifnot(!is.null(self$tracker_ptr)) + stopifnot(!is.null(active_forest$forest_ptr)) - propagate_basis_update_active_forest_cpp( - dataset$data_ptr, - outcome$data_ptr, - active_forest$forest_ptr, - self$tracker_ptr - ) - }, + propagate_basis_update_active_forest_cpp( + dataset$data_ptr, + outcome$data_ptr, + active_forest$forest_ptr, + self$tracker_ptr + ) + }, - #' @description - #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. - #' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data. - #' @param residual Outcome used to sample the forest - #' @return None - propagate_residual_update = function(residual) { - propagate_trees_column_vector_cpp( - self$tracker_ptr, - residual$data_ptr - ) - }, + #' @description + #' Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree. + #' This function is run after the `Outcome` class's `update_data` method, which overwrites the partial residual with an entirely new stream of outcome data. + #' @param residual Outcome used to sample the forest + #' @return None + propagate_residual_update = function(residual) { + propagate_trees_column_vector_cpp( + self$tracker_ptr, + residual$data_ptr + ) + }, - #' @description - #' Update alpha in the tree prior - #' @param alpha New value of alpha to be used - #' @return None - update_alpha = function(alpha) { - update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha) - }, + #' @description + #' Update alpha in the tree prior + #' @param alpha New value of alpha to be used + #' @return None + update_alpha = function(alpha) { + update_alpha_tree_prior_cpp(self$tree_prior_ptr, alpha) + }, - #' @description - #' Update beta in the tree prior - #' @param beta New value of beta to be used - #' @return None - update_beta = function(beta) { - update_beta_tree_prior_cpp(self$tree_prior_ptr, beta) - }, + #' @description + #' Update beta in the tree prior + #' @param beta New value of beta to be used + #' @return None + update_beta = function(beta) { + update_beta_tree_prior_cpp(self$tree_prior_ptr, beta) + }, - #' @description - #' Update min_samples_leaf in the tree prior - #' @param min_samples_leaf New value of min_samples_leaf to be used - #' @return None - update_min_samples_leaf = function(min_samples_leaf) { - update_min_samples_leaf_tree_prior_cpp( - self$tree_prior_ptr, - min_samples_leaf - ) - }, + #' @description + #' Update min_samples_leaf in the tree prior + #' @param min_samples_leaf New value of min_samples_leaf to be used + #' @return None + update_min_samples_leaf = function(min_samples_leaf) { + update_min_samples_leaf_tree_prior_cpp( + self$tree_prior_ptr, + min_samples_leaf + ) + }, - #' @description - #' Update max_depth in the tree prior - #' @param max_depth New value of max_depth to be used - #' @return None - update_max_depth = function(max_depth) { - update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth) - }, + #' @description + #' Update max_depth in the tree prior + #' @param max_depth New value of max_depth to be used + #' @return None + update_max_depth = function(max_depth) { + update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth) + }, - #' @description - #' Update alpha in the tree prior - #' @return Value of alpha in the tree prior - get_alpha = function() { - get_alpha_tree_prior_cpp(self$tree_prior_ptr) - }, + #' @description + #' Update alpha in the tree prior + #' @return Value of alpha in the tree prior + get_alpha = function() { + get_alpha_tree_prior_cpp(self$tree_prior_ptr) + }, - #' @description - #' Update beta in the tree prior - #' @return Value of beta in the tree prior - get_beta = function() { - get_beta_tree_prior_cpp(self$tree_prior_ptr) - }, + #' @description + #' Update beta in the tree prior + #' @return Value of beta in the tree prior + get_beta = function() { + get_beta_tree_prior_cpp(self$tree_prior_ptr) + }, - #' @description - #' Query min_samples_leaf in the tree prior - #' @return Value of min_samples_leaf in the tree prior - get_min_samples_leaf = function() { - get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) - }, + #' @description + #' Query min_samples_leaf in the tree prior + #' @return Value of min_samples_leaf in the tree prior + get_min_samples_leaf = function() { + get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr) + }, - #' @description - #' Query max_depth in the tree prior - #' @return Value of max_depth in the tree prior - get_max_depth = function() { - get_max_depth_tree_prior_cpp(self$tree_prior_ptr) - } - ) + #' @description + #' Query max_depth in the tree prior + #' @return Value of max_depth in the tree prior + get_max_depth = function() { + get_max_depth_tree_prior_cpp(self$tree_prior_ptr) + } + ) ) #' Create an R class that wraps a C++ random number generator @@ -329,7 +328,7 @@ ForestModel <- R6::R6Class( #' rng <- createCppRNG(1234) #' rng <- createCppRNG() createCppRNG <- function(random_seed = -1) { - return(invisible((CppRNG$new(random_seed)))) + return(invisible((CppRNG$new(random_seed)))) } #' Create a forest model object @@ -360,22 +359,22 @@ createCppRNG <- function(random_seed = -1) { #' global_model_config <- createGlobalModelConfig(global_error_variance=1.0) #' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) createForestModel <- function( - forest_dataset, - forest_model_config, - global_model_config + forest_dataset, + forest_model_config, + global_model_config ) { - return(invisible( - (ForestModel$new( - forest_dataset, - forest_model_config$feature_types, - forest_model_config$num_trees, - forest_model_config$num_observations, - forest_model_config$alpha, - forest_model_config$beta, - forest_model_config$min_samples_leaf, - forest_model_config$max_depth - )) + return(invisible( + (ForestModel$new( + forest_dataset, + forest_model_config$feature_types, + forest_model_config$num_trees, + forest_model_config$num_observations, + forest_model_config$alpha, + forest_model_config$beta, + forest_model_config$min_samples_leaf, + forest_model_config$max_depth )) + )) } @@ -394,13 +393,13 @@ createForestModel <- function( #' num_samples <- 5 #' sample_without_replacement(a, p, num_samples) sample_without_replacement <- function( + population_vector, + sampling_probabilities, + sample_size +) { + return(sample_without_replacement_integer_cpp( population_vector, sampling_probabilities, sample_size -) { - return(sample_without_replacement_integer_cpp( - population_vector, - sampling_probabilities, - sample_size - )) + )) } diff --git a/R/random_effects.R b/R/random_effects.R index b91c2678..868897ed 100644 --- a/R/random_effects.R +++ b/R/random_effects.R @@ -9,234 +9,234 @@ #' needed for prediction / serialization RandomEffectSamples <- R6::R6Class( - classname = "RandomEffectSamples", - cloneable = FALSE, - public = list( - #' @field rfx_container_ptr External pointer to a C++ StochTree::RandomEffectsContainer class - rfx_container_ptr = NULL, - - #' @field label_mapper_ptr External pointer to a C++ StochTree::LabelMapper class - label_mapper_ptr = NULL, - - #' @field training_group_ids Unique vector of group IDs that were in the training dataset - training_group_ids = NULL, - - #' @description - #' Create a new RandomEffectSamples object. - #' @return A new `RandomEffectSamples` object. - initialize = function() {}, - - #' @description - #' Construct RandomEffectSamples object from other "in-session" R objects - #' @param num_components Number of "components" or bases defining the random effects regression - #' @param num_groups Number of random effects groups - #' @param random_effects_tracker Object of type `RandomEffectsTracker` - #' @return None - load_in_session = function( - num_components, - num_groups, - random_effects_tracker - ) { - # Initialize - self$rfx_container_ptr <- rfx_container_cpp( - num_components, - num_groups - ) - self$label_mapper_ptr <- rfx_label_mapper_cpp( - random_effects_tracker$rfx_tracker_ptr - ) - self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp( - random_effects_tracker$rfx_tracker_ptr - ) - }, - - #' @description - #' Construct RandomEffectSamples object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return A new `RandomEffectSamples` object. - load_from_json = function( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - self$rfx_container_ptr <- rfx_container_from_json_cpp( - json_object$json_ptr, - json_rfx_container_label - ) - self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp( - json_object$json_ptr, - json_rfx_mapper_label - ) - self$training_group_ids <- rfx_group_ids_from_json_cpp( - json_object$json_ptr, - json_rfx_groupids_label - ) - }, - - #' @description - #' Append random effect draws to `RandomEffectSamples` object from a json object - #' @param json_object Object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return None - append_from_json = function( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - rfx_container_append_from_json_cpp( - self$rfx_container_ptr, - json_object$json_ptr, - json_rfx_container_label - ) - }, - - #' @description - #' Construct RandomEffectSamples object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return A new `RandomEffectSamples` object. - load_from_json_string = function( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - self$rfx_container_ptr <- rfx_container_from_json_string_cpp( - json_string, - json_rfx_container_label - ) - self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp( - json_string, - json_rfx_mapper_label - ) - self$training_group_ids <- rfx_group_ids_from_json_string_cpp( - json_string, - json_rfx_groupids_label - ) - }, - - #' @description - #' Append random effect draws to `RandomEffectSamples` object from a json object - #' @param json_string JSON string which parses into object of class `CppJson` - #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy - #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy - #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy - #' @return None - append_from_json_string = function( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) { - # Append RFX objects - rfx_container_append_from_json_string_cpp( - self$rfx_container_ptr, - json_string, - json_rfx_container_label - ) - }, - - #' @description - #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`. - #' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`. - #' @param rfx_group_ids Indices of random effects groups in a prediction set - #' @param rfx_basis (Optional) Basis used for random effects prediction - #' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model. - predict = function(rfx_group_ids, rfx_basis = NULL) { - num_obs = length(rfx_group_ids) - if (is.null(rfx_basis)) { - rfx_basis <- matrix(rep(1, num_obs), ncol = 1) - } - num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) - num_components = rfx_container_num_components_cpp( - self$rfx_container_ptr - ) - num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) - rfx_group_ids_int <- as.integer(rfx_group_ids) - stopifnot(sum(abs(rfx_group_ids_int - rfx_group_ids)) < 1e-6) - stopifnot(sum(!(rfx_group_ids %in% self$training_group_ids)) == 0) - stopifnot(ncol(rfx_basis) == num_components) - rfx_dataset <- createRandomEffectsDataset( - rfx_group_ids_int, - rfx_basis - ) - output <- rfx_container_predict_cpp( - self$rfx_container_ptr, - rfx_dataset$data_ptr, - self$label_mapper_ptr - ) - dim(output) <- c(num_obs, num_samples) - return(output) - }, - - #' @description - #' Extract the random effects parameters sampled. With the "redundant parameterization" - #' of Gelman et al (2008), this includes four parameters: alpha (the "working parameter" - #' shared across every group), xi (the "group parameter" sampled separately for each group), - #' beta (the product of alpha and xi, which corresponds to the overall group-level random effects), - #' and sigma (group-independent prior variance for each component of xi). - #' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. - #' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and is simply a matrix if `num_components = 1`. - #' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. - extract_parameter_samples = function() { - num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) - num_components = rfx_container_num_components_cpp( - self$rfx_container_ptr - ) - num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) - beta_samples <- rfx_container_get_beta_cpp(self$rfx_container_ptr) - xi_samples <- rfx_container_get_xi_cpp(self$rfx_container_ptr) - alpha_samples <- rfx_container_get_alpha_cpp(self$rfx_container_ptr) - sigma_samples <- rfx_container_get_sigma_cpp(self$rfx_container_ptr) - if (num_components == 1) { - dim(beta_samples) <- c(num_groups, num_samples) - dim(xi_samples) <- c(num_groups, num_samples) - } else if (num_components > 1) { - dim(beta_samples) <- c(num_components, num_groups, num_samples) - dim(xi_samples) <- c(num_components, num_groups, num_samples) - dim(alpha_samples) <- c(num_components, num_samples) - dim(sigma_samples) <- c(num_components, num_samples) - } else { - stop( - "Invalid random effects sample container, num_components is less than 1" - ) - } - - output = list( - "beta_samples" = beta_samples, - "xi_samples" = xi_samples, - "alpha_samples" = alpha_samples, - "sigma_samples" = sigma_samples - ) - return(output) - }, - - #' @description - #' Modify the `RandomEffectsSamples` object by removing the parameter samples index by `sample_num`. - #' @param sample_num Index of the RFX sample to be removed - delete_sample = function(sample_num) { - rfx_container_delete_sample_cpp(self$rfx_container_ptr, sample_num) - }, - - #' @description - #' Convert the mapping of group IDs to random effect components indices from C++ to R native format - #' @return List mapping group ID to random effect components. - extract_label_mapping = function() { - keys_and_vals <- rfx_label_mapper_to_list_cpp(self$label_mapper_ptr) - result <- as.list(keys_and_vals[[2]] + 1) - setNames(result, keys_and_vals[[1]]) - return(result) - } - ) + classname = "RandomEffectSamples", + cloneable = FALSE, + public = list( + #' @field rfx_container_ptr External pointer to a C++ StochTree::RandomEffectsContainer class + rfx_container_ptr = NULL, + + #' @field label_mapper_ptr External pointer to a C++ StochTree::LabelMapper class + label_mapper_ptr = NULL, + + #' @field training_group_ids Unique vector of group IDs that were in the training dataset + training_group_ids = NULL, + + #' @description + #' Create a new RandomEffectSamples object. + #' @return A new `RandomEffectSamples` object. + initialize = function() {}, + + #' @description + #' Construct RandomEffectSamples object from other "in-session" R objects + #' @param num_components Number of "components" or bases defining the random effects regression + #' @param num_groups Number of random effects groups + #' @param random_effects_tracker Object of type `RandomEffectsTracker` + #' @return None + load_in_session = function( + num_components, + num_groups, + random_effects_tracker + ) { + # Initialize + self$rfx_container_ptr <- rfx_container_cpp( + num_components, + num_groups + ) + self$label_mapper_ptr <- rfx_label_mapper_cpp( + random_effects_tracker$rfx_tracker_ptr + ) + self$training_group_ids <- rfx_tracker_get_unique_group_ids_cpp( + random_effects_tracker$rfx_tracker_ptr + ) + }, + + #' @description + #' Construct RandomEffectSamples object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return A new `RandomEffectSamples` object. + load_from_json = function( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + self$rfx_container_ptr <- rfx_container_from_json_cpp( + json_object$json_ptr, + json_rfx_container_label + ) + self$label_mapper_ptr <- rfx_label_mapper_from_json_cpp( + json_object$json_ptr, + json_rfx_mapper_label + ) + self$training_group_ids <- rfx_group_ids_from_json_cpp( + json_object$json_ptr, + json_rfx_groupids_label + ) + }, + + #' @description + #' Append random effect draws to `RandomEffectSamples` object from a json object + #' @param json_object Object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return None + append_from_json = function( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + rfx_container_append_from_json_cpp( + self$rfx_container_ptr, + json_object$json_ptr, + json_rfx_container_label + ) + }, + + #' @description + #' Construct RandomEffectSamples object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return A new `RandomEffectSamples` object. + load_from_json_string = function( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + self$rfx_container_ptr <- rfx_container_from_json_string_cpp( + json_string, + json_rfx_container_label + ) + self$label_mapper_ptr <- rfx_label_mapper_from_json_string_cpp( + json_string, + json_rfx_mapper_label + ) + self$training_group_ids <- rfx_group_ids_from_json_string_cpp( + json_string, + json_rfx_groupids_label + ) + }, + + #' @description + #' Append random effect draws to `RandomEffectSamples` object from a json object + #' @param json_string JSON string which parses into object of class `CppJson` + #' @param json_rfx_container_label Label referring to a particular rfx sample container (i.e. "random_effect_container_0") in the overall json hierarchy + #' @param json_rfx_mapper_label Label referring to a particular rfx label mapper (i.e. "random_effect_label_mapper_0") in the overall json hierarchy + #' @param json_rfx_groupids_label Label referring to a particular set of rfx group IDs (i.e. "random_effect_groupids_0") in the overall json hierarchy + #' @return None + append_from_json_string = function( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) { + # Append RFX objects + rfx_container_append_from_json_string_cpp( + self$rfx_container_ptr, + json_string, + json_rfx_container_label + ) + }, + + #' @description + #' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`. + #' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`. + #' @param rfx_group_ids Indices of random effects groups in a prediction set + #' @param rfx_basis (Optional) Basis used for random effects prediction + #' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model. + predict = function(rfx_group_ids, rfx_basis = NULL) { + num_obs = length(rfx_group_ids) + if (is.null(rfx_basis)) { + rfx_basis <- matrix(rep(1, num_obs), ncol = 1) + } + num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) + num_components = rfx_container_num_components_cpp( + self$rfx_container_ptr + ) + num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) + rfx_group_ids_int <- as.integer(rfx_group_ids) + stopifnot(sum(abs(rfx_group_ids_int - rfx_group_ids)) < 1e-6) + stopifnot(sum(!(rfx_group_ids %in% self$training_group_ids)) == 0) + stopifnot(ncol(rfx_basis) == num_components) + rfx_dataset <- createRandomEffectsDataset( + rfx_group_ids_int, + rfx_basis + ) + output <- rfx_container_predict_cpp( + self$rfx_container_ptr, + rfx_dataset$data_ptr, + self$label_mapper_ptr + ) + dim(output) <- c(num_obs, num_samples) + return(output) + }, + + #' @description + #' Extract the random effects parameters sampled. With the "redundant parameterization" + #' of Gelman et al (2008), this includes four parameters: alpha (the "working parameter" + #' shared across every group), xi (the "group parameter" sampled separately for each group), + #' beta (the product of alpha and xi, which corresponds to the overall group-level random effects), + #' and sigma (group-independent prior variance for each component of xi). + #' @return List of arrays. The alpha array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. + #' The xi and beta arrays have dimension (`num_components`, `num_groups`, `num_samples`) and are simply matrices if `num_components = 1`. + #' The sigma array has dimension (`num_components`, `num_samples`) and is simply a vector if `num_components = 1`. + extract_parameter_samples = function() { + num_samples = rfx_container_num_samples_cpp(self$rfx_container_ptr) + num_components = rfx_container_num_components_cpp( + self$rfx_container_ptr + ) + num_groups = rfx_container_num_groups_cpp(self$rfx_container_ptr) + beta_samples <- rfx_container_get_beta_cpp(self$rfx_container_ptr) + xi_samples <- rfx_container_get_xi_cpp(self$rfx_container_ptr) + alpha_samples <- rfx_container_get_alpha_cpp(self$rfx_container_ptr) + sigma_samples <- rfx_container_get_sigma_cpp(self$rfx_container_ptr) + if (num_components == 1) { + dim(beta_samples) <- c(num_groups, num_samples) + dim(xi_samples) <- c(num_groups, num_samples) + } else if (num_components > 1) { + dim(beta_samples) <- c(num_components, num_groups, num_samples) + dim(xi_samples) <- c(num_components, num_groups, num_samples) + dim(alpha_samples) <- c(num_components, num_samples) + dim(sigma_samples) <- c(num_components, num_samples) + } else { + stop( + "Invalid random effects sample container, num_components is less than 1" + ) + } + + output = list( + "beta_samples" = beta_samples, + "xi_samples" = xi_samples, + "alpha_samples" = alpha_samples, + "sigma_samples" = sigma_samples + ) + return(output) + }, + + #' @description + #' Modify the `RandomEffectsSamples` object by removing the parameter samples index by `sample_num`. + #' @param sample_num Index of the RFX sample to be removed + delete_sample = function(sample_num) { + rfx_container_delete_sample_cpp(self$rfx_container_ptr, sample_num) + }, + + #' @description + #' Convert the mapping of group IDs to random effect components indices from C++ to R native format + #' @return List mapping group ID to random effect components. + extract_label_mapping = function() { + keys_and_vals <- rfx_label_mapper_to_list_cpp(self$label_mapper_ptr) + result <- as.list(keys_and_vals[[2]] + 1) + setNames(result, keys_and_vals[[1]]) + return(result) + } + ) ) #' Class that defines a "tracker" for random effects models, most notably @@ -249,21 +249,21 @@ RandomEffectSamples <- R6::R6Class( #' group, and predictions for each observation. RandomEffectsTracker <- R6::R6Class( - classname = "RandomEffectsTracker", - cloneable = FALSE, - public = list( - #' @field rfx_tracker_ptr External pointer to a C++ StochTree::RandomEffectsTracker class - rfx_tracker_ptr = NULL, - - #' @description - #' Create a new RandomEffectsTracker object. - #' @param rfx_group_indices Integer indices indicating groups used to define random effects - #' @return A new `RandomEffectsTracker` object. - initialize = function(rfx_group_indices) { - # Initialize - self$rfx_tracker_ptr <- rfx_tracker_cpp(rfx_group_indices) - } - ) + classname = "RandomEffectsTracker", + cloneable = FALSE, + public = list( + #' @field rfx_tracker_ptr External pointer to a C++ StochTree::RandomEffectsTracker class + rfx_tracker_ptr = NULL, + + #' @description + #' Create a new RandomEffectsTracker object. + #' @param rfx_group_indices Integer indices indicating groups used to define random effects + #' @return A new `RandomEffectsTracker` object. + initialize = function(rfx_group_indices) { + # Initialize + self$rfx_tracker_ptr <- rfx_tracker_cpp(rfx_group_indices) + } + ) ) #' The core "model" class for sampling random effects. @@ -273,158 +273,158 @@ RandomEffectsTracker <- R6::R6Class( #' sampling from the conditional posterior of each parameter. RandomEffectsModel <- R6::R6Class( - classname = "RandomEffectsModel", - cloneable = FALSE, - public = list( - #' @field rfx_model_ptr External pointer to a C++ StochTree::RandomEffectsModel class - rfx_model_ptr = NULL, - - #' @field num_groups Number of groups in the random effects model - num_groups = NULL, - - #' @field num_components Number of components (i.e. dimension of basis) in the random effects model - num_components = NULL, - - #' @description - #' Create a new RandomEffectsModel object. - #' @param num_components Number of "components" or bases defining the random effects regression - #' @param num_groups Number of random effects groups - #' @return A new `RandomEffectsModel` object. - initialize = function(num_components, num_groups) { - # Initialize - self$rfx_model_ptr <- rfx_model_cpp(num_components, num_groups) - self$num_components <- num_components - self$num_groups <- num_groups - }, - - #' @description - #' Sample from random effects model. - #' @param rfx_dataset Object of type `RandomEffectsDataset` - #' @param residual Object of type `Outcome` - #' @param rfx_tracker Object of type `RandomEffectsTracker` - #' @param rfx_samples Object of type `RandomEffectSamples` - #' @param keep_sample Whether sample should be retained in `rfx_samples`. If `FALSE`, the state of `rfx_tracker` will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning. - #' @param global_variance Scalar global variance parameter - #' @param rng Object of type `CppRNG` - #' @return None - sample_random_effect = function( - rfx_dataset, - residual, - rfx_tracker, - rfx_samples, - keep_sample, - global_variance, - rng - ) { - rfx_model_sample_random_effects_cpp( - self$rfx_model_ptr, - rfx_dataset$data_ptr, - residual$data_ptr, - rfx_tracker$rfx_tracker_ptr, - rfx_samples$rfx_container_ptr, - keep_sample, - global_variance, - rng$rng_ptr - ) - }, - - #' @description - #' Predict from (a single sample of a) random effects model. - #' @param rfx_dataset Object of type `RandomEffectsDataset` - #' @param rfx_tracker Object of type `RandomEffectsTracker` - #' @return Vector of predictions with size matching number of observations in rfx_dataset - predict = function(rfx_dataset, rfx_tracker) { - pred <- rfx_model_predict_cpp( - self$rfx_model_ptr, - rfx_dataset$data_ptr, - rfx_tracker$rfx_tracker_ptr - ) - return(pred) - }, - - #' @description - #' Set value for the "working parameter." This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_working_parameter = function(value) { - stopifnot(is.double(value)) - stopifnot(!is.matrix(value)) - stopifnot(length(value) == self$num_components) - rfx_model_set_working_parameter_cpp(self$rfx_model_ptr, value) - }, - - #' @description - #' Set value for the "group parameters." This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_group_parameters = function(value) { - stopifnot(is.double(value)) - stopifnot(is.matrix(value)) - stopifnot(nrow(value) == self$num_components) - stopifnot(ncol(value) == self$num_groups) - rfx_model_set_group_parameters_cpp(self$rfx_model_ptr, value) - }, - - #' @description - #' Set value for the working parameter covariance. This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_working_parameter_cov = function(value) { - stopifnot(is.double(value)) - stopifnot(is.matrix(value)) - stopifnot(nrow(value) == self$num_components) - stopifnot(ncol(value) == self$num_components) - rfx_model_set_working_parameter_covariance_cpp( - self$rfx_model_ptr, - value - ) - }, - - #' @description - #' Set value for the group parameter covariance. This is typically - #' used for initialization, but could also be used to interrupt - #' or override the sampler. - #' @param value Parameter input - #' @return None - set_group_parameter_cov = function(value) { - stopifnot(is.double(value)) - stopifnot(is.matrix(value)) - stopifnot(nrow(value) == self$num_components) - stopifnot(ncol(value) == self$num_components) - rfx_model_set_group_parameter_covariance_cpp( - self$rfx_model_ptr, - value - ) - }, - - #' @description - #' Set shape parameter for the group parameter variance prior. - #' @param value Parameter input - #' @return None - set_variance_prior_shape = function(value) { - stopifnot(is.double(value)) - stopifnot(!is.matrix(value)) - stopifnot(length(value) == 1) - rfx_model_set_variance_prior_shape_cpp(self$rfx_model_ptr, value) - }, - - #' @description - #' Set shape parameter for the group parameter variance prior. - #' @param value Parameter input - #' @return None - set_variance_prior_scale = function(value) { - stopifnot(is.double(value)) - stopifnot(!is.matrix(value)) - stopifnot(length(value) == 1) - rfx_model_set_variance_prior_scale_cpp(self$rfx_model_ptr, value) - } - ) + classname = "RandomEffectsModel", + cloneable = FALSE, + public = list( + #' @field rfx_model_ptr External pointer to a C++ StochTree::RandomEffectsModel class + rfx_model_ptr = NULL, + + #' @field num_groups Number of groups in the random effects model + num_groups = NULL, + + #' @field num_components Number of components (i.e. dimension of basis) in the random effects model + num_components = NULL, + + #' @description + #' Create a new RandomEffectsModel object. + #' @param num_components Number of "components" or bases defining the random effects regression + #' @param num_groups Number of random effects groups + #' @return A new `RandomEffectsModel` object. + initialize = function(num_components, num_groups) { + # Initialize + self$rfx_model_ptr <- rfx_model_cpp(num_components, num_groups) + self$num_components <- num_components + self$num_groups <- num_groups + }, + + #' @description + #' Sample from random effects model. + #' @param rfx_dataset Object of type `RandomEffectsDataset` + #' @param residual Object of type `Outcome` + #' @param rfx_tracker Object of type `RandomEffectsTracker` + #' @param rfx_samples Object of type `RandomEffectSamples` + #' @param keep_sample Whether sample should be retained in `rfx_samples`. If `FALSE`, the state of `rfx_tracker` will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning. + #' @param global_variance Scalar global variance parameter + #' @param rng Object of type `CppRNG` + #' @return None + sample_random_effect = function( + rfx_dataset, + residual, + rfx_tracker, + rfx_samples, + keep_sample, + global_variance, + rng + ) { + rfx_model_sample_random_effects_cpp( + self$rfx_model_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_tracker$rfx_tracker_ptr, + rfx_samples$rfx_container_ptr, + keep_sample, + global_variance, + rng$rng_ptr + ) + }, + + #' @description + #' Predict from (a single sample of a) random effects model. + #' @param rfx_dataset Object of type `RandomEffectsDataset` + #' @param rfx_tracker Object of type `RandomEffectsTracker` + #' @return Vector of predictions with size matching number of observations in rfx_dataset + predict = function(rfx_dataset, rfx_tracker) { + pred <- rfx_model_predict_cpp( + self$rfx_model_ptr, + rfx_dataset$data_ptr, + rfx_tracker$rfx_tracker_ptr + ) + return(pred) + }, + + #' @description + #' Set value for the "working parameter." This is typically + #' used for initialization, but could also be used to interrupt + #' or override the sampler. + #' @param value Parameter input + #' @return None + set_working_parameter = function(value) { + stopifnot(is.double(value)) + stopifnot(!is.matrix(value)) + stopifnot(length(value) == self$num_components) + rfx_model_set_working_parameter_cpp(self$rfx_model_ptr, value) + }, + + #' @description + #' Set value for the "group parameters." This is typically + #' used for initialization, but could also be used to interrupt + #' or override the sampler. + #' @param value Parameter input + #' @return None + set_group_parameters = function(value) { + stopifnot(is.double(value)) + stopifnot(is.matrix(value)) + stopifnot(nrow(value) == self$num_components) + stopifnot(ncol(value) == self$num_groups) + rfx_model_set_group_parameters_cpp(self$rfx_model_ptr, value) + }, + + #' @description + #' Set value for the working parameter covariance. This is typically + #' used for initialization, but could also be used to interrupt + #' or override the sampler. + #' @param value Parameter input + #' @return None + set_working_parameter_cov = function(value) { + stopifnot(is.double(value)) + stopifnot(is.matrix(value)) + stopifnot(nrow(value) == self$num_components) + stopifnot(ncol(value) == self$num_components) + rfx_model_set_working_parameter_covariance_cpp( + self$rfx_model_ptr, + value + ) + }, + + #' @description + #' Set value for the group parameter covariance. This is typically + #' used for initialization, but could also be used to interrupt + #' or override the sampler. + #' @param value Parameter input + #' @return None + set_group_parameter_cov = function(value) { + stopifnot(is.double(value)) + stopifnot(is.matrix(value)) + stopifnot(nrow(value) == self$num_components) + stopifnot(ncol(value) == self$num_components) + rfx_model_set_group_parameter_covariance_cpp( + self$rfx_model_ptr, + value + ) + }, + + #' @description + #' Set shape parameter for the group parameter variance prior. + #' @param value Parameter input + #' @return None + set_variance_prior_shape = function(value) { + stopifnot(is.double(value)) + stopifnot(!is.matrix(value)) + stopifnot(length(value) == 1) + rfx_model_set_variance_prior_shape_cpp(self$rfx_model_ptr, value) + }, + + #' @description + #' Set shape parameter for the group parameter variance prior. + #' @param value Parameter input + #' @return None + set_variance_prior_scale = function(value) { + stopifnot(is.double(value)) + stopifnot(!is.matrix(value)) + stopifnot(length(value) == 1) + rfx_model_set_variance_prior_scale_cpp(self$rfx_model_ptr, value) + } + ) ) #' Create a `RandomEffectSamples` object @@ -444,13 +444,13 @@ RandomEffectsModel <- R6::R6Class( #' rfx_tracker <- createRandomEffectsTracker(rfx_group_ids) #' rfx_samples <- createRandomEffectSamples(num_components, num_groups, rfx_tracker) createRandomEffectSamples <- function( - num_components, - num_groups, - random_effects_tracker + num_components, + num_groups, + random_effects_tracker ) { - invisible(output <- RandomEffectSamples$new()) - output$load_in_session(num_components, num_groups, random_effects_tracker) - return(output) + invisible(output <- RandomEffectSamples$new()) + output$load_in_session(num_components, num_groups, random_effects_tracker) + return(output) } #' Create a `RandomEffectsTracker` object @@ -467,7 +467,7 @@ createRandomEffectSamples <- function( #' num_components <- ncol(rfx_basis) #' rfx_tracker <- createRandomEffectsTracker(rfx_group_ids) createRandomEffectsTracker <- function(rfx_group_indices) { - return(invisible((RandomEffectsTracker$new(rfx_group_indices)))) + return(invisible((RandomEffectsTracker$new(rfx_group_indices)))) } #' Create a `RandomEffectsModel` object @@ -485,7 +485,7 @@ createRandomEffectsTracker <- function(rfx_group_indices) { #' num_components <- ncol(rfx_basis) #' rfx_model <- createRandomEffectsModel(num_components, num_groups) createRandomEffectsModel <- function(num_components, num_groups) { - return(invisible((RandomEffectsModel$new(num_components, num_groups)))) + return(invisible((RandomEffectsModel$new(num_components, num_groups)))) } #' Reset a `RandomEffectsModel` object based on the parameters indexed by `sample_num` in a `RandomEffectsSamples` object @@ -531,23 +531,23 @@ createRandomEffectsModel <- function(num_components, num_groups) { #' } #' resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) resetRandomEffectsModel <- function( - rfx_model, - rfx_samples, - sample_num, - sigma_alpha_init + rfx_model, + rfx_samples, + sample_num, + sigma_alpha_init ) { - if (!is.matrix(sigma_alpha_init)) { - if (!is.double(sigma_alpha_init)) { - stop("`sigma_alpha_init` must be a numeric scalar or matrix") - } - sigma_alpha_init <- as.matrix(sigma_alpha_init) + if (!is.matrix(sigma_alpha_init)) { + if (!is.double(sigma_alpha_init)) { + stop("`sigma_alpha_init` must be a numeric scalar or matrix") } - reset_rfx_model_cpp( - rfx_model$rfx_model_ptr, - rfx_samples$rfx_container_ptr, - sample_num - ) - rfx_model$set_working_parameter_cov(sigma_alpha_init) + sigma_alpha_init <- as.matrix(sigma_alpha_init) + } + reset_rfx_model_cpp( + rfx_model$rfx_model_ptr, + rfx_samples$rfx_container_ptr, + sample_num + ) + rfx_model$set_working_parameter_cov(sigma_alpha_init) } #' Reset a `RandomEffectsTracker` object based on the parameters indexed by `sample_num` in a `RandomEffectsSamples` object @@ -595,18 +595,18 @@ resetRandomEffectsModel <- function( #' resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0) #' resetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome, rfx_samples) resetRandomEffectsTracker <- function( - rfx_tracker, - rfx_model, - rfx_dataset, - residual, - rfx_samples + rfx_tracker, + rfx_model, + rfx_dataset, + residual, + rfx_samples ) { - reset_rfx_tracker_cpp( - rfx_tracker$rfx_tracker_ptr, - rfx_dataset$data_ptr, - residual$data_ptr, - rfx_model$rfx_model_ptr - ) + reset_rfx_tracker_cpp( + rfx_tracker$rfx_tracker_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_model$rfx_model_ptr + ) } #' Reset a `RandomEffectsModel` object to its "default" state @@ -656,20 +656,20 @@ resetRandomEffectsTracker <- function( #' rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, #' sigma_xi_init, sigma_xi_shape, sigma_xi_scale) rootResetRandomEffectsModel <- function( - rfx_model, - alpha_init, - xi_init, - sigma_alpha_init, - sigma_xi_init, - sigma_xi_shape, - sigma_xi_scale + rfx_model, + alpha_init, + xi_init, + sigma_alpha_init, + sigma_xi_init, + sigma_xi_shape, + sigma_xi_scale ) { - rfx_model$set_working_parameter(alpha_init) - rfx_model$set_group_parameters(xi_init) - rfx_model$set_working_parameter_cov(sigma_alpha_init) - rfx_model$set_group_parameter_cov(sigma_xi_init) - rfx_model$set_variance_prior_shape(sigma_xi_shape) - rfx_model$set_variance_prior_scale(sigma_xi_scale) + rfx_model$set_working_parameter(alpha_init) + rfx_model$set_group_parameters(xi_init) + rfx_model$set_working_parameter_cov(sigma_alpha_init) + rfx_model$set_group_parameter_cov(sigma_xi_init) + rfx_model$set_variance_prior_shape(sigma_xi_shape) + rfx_model$set_variance_prior_scale(sigma_xi_scale) } #' Reset a `RandomEffectsTracker` object to its "default" state @@ -717,15 +717,15 @@ rootResetRandomEffectsModel <- function( #' sigma_xi_init, sigma_xi_shape, sigma_xi_scale) #' rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome) rootResetRandomEffectsTracker <- function( - rfx_tracker, - rfx_model, - rfx_dataset, - residual + rfx_tracker, + rfx_model, + rfx_dataset, + residual ) { - root_reset_rfx_tracker_cpp( - rfx_tracker$rfx_tracker_ptr, - rfx_dataset$data_ptr, - residual$data_ptr, - rfx_model$rfx_model_ptr - ) + root_reset_rfx_tracker_cpp( + rfx_tracker$rfx_tracker_ptr, + rfx_dataset$data_ptr, + residual$data_ptr, + rfx_model$rfx_model_ptr + ) } diff --git a/R/serialization.R b/R/serialization.R index c52e9fec..1b579bd3 100644 --- a/R/serialization.R +++ b/R/serialization.R @@ -4,521 +4,521 @@ #' Wrapper around a C++ container of tree ensembles CppJson <- R6::R6Class( - classname = "CppJson", - cloneable = FALSE, - public = list( - #' @field json_ptr External pointer to a C++ nlohmann::json object - json_ptr = NULL, - - #' @field num_forests Number of forests in the nlohmann::json object - num_forests = NULL, - - #' @field forest_labels Names of forest objects in the overall nlohmann::json object - forest_labels = NULL, - - #' @field num_rfx Number of random effects terms in the nlohman::json object - num_rfx = NULL, - - #' @field rfx_container_labels Names of rfx container objects in the overall nlohmann::json object - rfx_container_labels = NULL, - - #' @field rfx_mapper_labels Names of rfx label mapper objects in the overall nlohmann::json object - rfx_mapper_labels = NULL, - - #' @field rfx_groupid_labels Names of rfx group id objects in the overall nlohmann::json object - rfx_groupid_labels = NULL, - - #' @description - #' Create a new CppJson object. - #' @return A new `CppJson` object. - initialize = function() { - self$json_ptr <- init_json_cpp() - self$num_forests <- 0 - self$forest_labels <- c() - self$num_rfx <- 0 - self$rfx_container_labels <- c() - self$rfx_mapper_labels <- c() - self$rfx_groupid_labels <- c() - }, - - #' @description - #' Convert a forest container to json and add to the current `CppJson` object - #' @param forest_samples `ForestSamples` R class - #' @return None - add_forest = function(forest_samples) { - forest_label <- json_add_forest_cpp( - self$json_ptr, - forest_samples$forest_container_ptr - ) - self$num_forests <- self$num_forests + 1 - self$forest_labels <- c(self$forest_labels, forest_label) - }, - - #' @description - #' Convert a random effects container to json and add to the current `CppJson` object - #' @param rfx_samples `RandomEffectSamples` R class - #' @return None - add_random_effects = function(rfx_samples) { - rfx_container_label <- json_add_rfx_container_cpp( - self$json_ptr, - rfx_samples$rfx_container_ptr - ) - self$rfx_container_labels <- c( - self$rfx_container_labels, - rfx_container_label - ) - rfx_mapper_label <- json_add_rfx_label_mapper_cpp( - self$json_ptr, - rfx_samples$label_mapper_ptr - ) - self$rfx_mapper_labels <- c( - self$rfx_mapper_labels, - rfx_mapper_label - ) - rfx_groupid_label <- json_add_rfx_groupids_cpp( - self$json_ptr, - rfx_samples$training_group_ids - ) - self$rfx_groupid_labels <- c( - self$rfx_groupid_labels, - rfx_groupid_label - ) - json_increment_rfx_count_cpp(self$json_ptr) - self$num_rfx <- self$num_rfx + 1 - }, - - #' @description - #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Numeric value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_scalar = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_double_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_double_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Integer value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_integer = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_integer_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_integer_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Numeric value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_boolean = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_bool_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_bool_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a string value to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_value Numeric value of the field to be added to json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_string = function(field_name, field_value, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - json_add_string_cpp(self$json_ptr, field_name, field_value) - } else { - json_add_string_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_value - ) - } - }, - - #' @description - #' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_vector Vector to be stored in json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_vector = function(field_name, field_vector, subfolder_name = NULL) { - field_vector <- as.numeric(field_vector) - if (is.null(subfolder_name)) { - json_add_vector_cpp(self$json_ptr, field_name, field_vector) - } else { - json_add_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_vector - ) - } - }, - - #' @description - #' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_vector Vector to be stored in json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_integer_vector = function( - field_name, - field_vector, - subfolder_name = NULL - ) { - field_vector <- as.numeric(field_vector) - if (is.null(subfolder_name)) { - json_add_integer_vector_cpp( - self$json_ptr, - field_name, - field_vector - ) - } else { - json_add_integer_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_vector - ) - } - }, - - #' @description - #' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be added to json - #' @param field_vector Character vector to be stored in json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value - #' @return None - add_string_vector = function( - field_name, - field_vector, - subfolder_name = NULL - ) { - if (is.null(subfolder_name)) { - json_add_string_vector_cpp( - self$json_ptr, - field_name, - field_vector - ) - } else { - json_add_string_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name, - field_vector - ) - } - }, - - #' @description - #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" - #' @param field_name The name of the field to be added to json - #' @param field_list List to be stored in json - #' @return None - add_list = function(field_name, field_list) { - stopifnot(sum(!sapply(field_list, is.vector)) == 0) - list_names <- names(field_list) - for (i in 1:length(field_list)) { - vec_name <- list_names[i] - vec <- field_list[[i]] - json_add_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name, - vec - ) - } - }, - - #' @description - #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" - #' @param field_name The name of the field to be added to json - #' @param field_list List to be stored in json - #' @return None - add_string_list = function(field_name, field_list) { - stopifnot(sum(!sapply(field_list, is.vector)) == 0) - list_names <- names(field_list) - for (i in 1:length(field_list)) { - vec_name <- list_names[i] - vec <- field_list[[i]] - json_add_string_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name, - vec - ) - } - }, - - #' @description - #' Retrieve a scalar value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_scalar = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_double_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_double_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_integer = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_integer_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_integer_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_boolean = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_bool_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_bool_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a string value from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_string = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_string_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_string_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_vector = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_vector_cpp(self$json_ptr, field_name) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_integer_vector = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_integer_vector_cpp( - self$json_ptr, - field_name - ) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_integer_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") - #' @param field_name The name of the field to be accessed from json - #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored - #' @return None - get_string_vector = function(field_name, subfolder_name = NULL) { - if (is.null(subfolder_name)) { - stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) - result <- json_extract_string_vector_cpp( - self$json_ptr, - field_name - ) - } else { - stopifnot(json_contains_field_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - )) - result <- json_extract_string_vector_subfolder_cpp( - self$json_ptr, - subfolder_name, - field_name - ) - } - return(result) - }, - - #' @description - #' Reconstruct a list of numeric vectors from the json object stored under "field_name" - #' @param field_name The name of the field to be added to json - #' @param key_names Vector of names of list elements (each of which is a vector) - #' @return None - get_numeric_list = function(field_name, key_names) { - output <- list() - for (i in 1:length(key_names)) { - vec_name <- key_names[i] - output[[vec_name]] <- json_extract_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name - ) - } - return(output) - }, - - #' @description - #' Reconstruct a list of string vectors from the json object stored under "field_name" - #' @param field_name The name of the field to be added to json - #' @param key_names Vector of names of list elements (each of which is a vector) - #' @return None - get_string_list = function(field_name, key_names) { - output <- list() - for (i in 1:length(key_names)) { - vec_name <- key_names[i] - output[[vec_name]] <- json_extract_string_vector_subfolder_cpp( - self$json_ptr, - field_name, - vec_name - ) - } - return(output) - }, - - #' @description - #' Convert a JSON object to in-memory string - #' @return JSON string - return_json_string = function() { - return(get_json_string_cpp(self$json_ptr)) - }, - - #' @description - #' Save a json object to file - #' @param filename String of filepath, must end in ".json" - #' @return None - save_file = function(filename) { - json_save_file_cpp(self$json_ptr, filename) - }, - - #' @description - #' Load a json object from file - #' @param filename String of filepath, must end in ".json" - #' @return None - load_from_file = function(filename) { - json_load_file_cpp(self$json_ptr, filename) - }, - - #' @description - #' Load a json object from string - #' @param json_string JSON string dump - #' @return None - load_from_string = function(json_string) { - json_load_string_cpp(self$json_ptr, json_string) - } - ) + classname = "CppJson", + cloneable = FALSE, + public = list( + #' @field json_ptr External pointer to a C++ nlohmann::json object + json_ptr = NULL, + + #' @field num_forests Number of forests in the nlohmann::json object + num_forests = NULL, + + #' @field forest_labels Names of forest objects in the overall nlohmann::json object + forest_labels = NULL, + + #' @field num_rfx Number of random effects terms in the nlohman::json object + num_rfx = NULL, + + #' @field rfx_container_labels Names of rfx container objects in the overall nlohmann::json object + rfx_container_labels = NULL, + + #' @field rfx_mapper_labels Names of rfx label mapper objects in the overall nlohmann::json object + rfx_mapper_labels = NULL, + + #' @field rfx_groupid_labels Names of rfx group id objects in the overall nlohmann::json object + rfx_groupid_labels = NULL, + + #' @description + #' Create a new CppJson object. + #' @return A new `CppJson` object. + initialize = function() { + self$json_ptr <- init_json_cpp() + self$num_forests <- 0 + self$forest_labels <- c() + self$num_rfx <- 0 + self$rfx_container_labels <- c() + self$rfx_mapper_labels <- c() + self$rfx_groupid_labels <- c() + }, + + #' @description + #' Convert a forest container to json and add to the current `CppJson` object + #' @param forest_samples `ForestSamples` R class + #' @return None + add_forest = function(forest_samples) { + forest_label <- json_add_forest_cpp( + self$json_ptr, + forest_samples$forest_container_ptr + ) + self$num_forests <- self$num_forests + 1 + self$forest_labels <- c(self$forest_labels, forest_label) + }, + + #' @description + #' Convert a random effects container to json and add to the current `CppJson` object + #' @param rfx_samples `RandomEffectSamples` R class + #' @return None + add_random_effects = function(rfx_samples) { + rfx_container_label <- json_add_rfx_container_cpp( + self$json_ptr, + rfx_samples$rfx_container_ptr + ) + self$rfx_container_labels <- c( + self$rfx_container_labels, + rfx_container_label + ) + rfx_mapper_label <- json_add_rfx_label_mapper_cpp( + self$json_ptr, + rfx_samples$label_mapper_ptr + ) + self$rfx_mapper_labels <- c( + self$rfx_mapper_labels, + rfx_mapper_label + ) + rfx_groupid_label <- json_add_rfx_groupids_cpp( + self$json_ptr, + rfx_samples$training_group_ids + ) + self$rfx_groupid_labels <- c( + self$rfx_groupid_labels, + rfx_groupid_label + ) + json_increment_rfx_count_cpp(self$json_ptr) + self$num_rfx <- self$num_rfx + 1 + }, + + #' @description + #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Numeric value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_scalar = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_double_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_double_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Integer value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_integer = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_integer_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_integer_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Numeric value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_boolean = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_bool_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_bool_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a string value to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_value Numeric value of the field to be added to json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_string = function(field_name, field_value, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + json_add_string_cpp(self$json_ptr, field_name, field_value) + } else { + json_add_string_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_value + ) + } + }, + + #' @description + #' Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_vector Vector to be stored in json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_vector = function(field_name, field_vector, subfolder_name = NULL) { + field_vector <- as.numeric(field_vector) + if (is.null(subfolder_name)) { + json_add_vector_cpp(self$json_ptr, field_name, field_vector) + } else { + json_add_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) + } + }, + + #' @description + #' Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_vector Vector to be stored in json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_integer_vector = function( + field_name, + field_vector, + subfolder_name = NULL + ) { + field_vector <- as.numeric(field_vector) + if (is.null(subfolder_name)) { + json_add_integer_vector_cpp( + self$json_ptr, + field_name, + field_vector + ) + } else { + json_add_integer_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) + } + }, + + #' @description + #' Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be added to json + #' @param field_vector Character vector to be stored in json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which to place the value + #' @return None + add_string_vector = function( + field_name, + field_vector, + subfolder_name = NULL + ) { + if (is.null(subfolder_name)) { + json_add_string_vector_cpp( + self$json_ptr, + field_name, + field_vector + ) + } else { + json_add_string_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name, + field_vector + ) + } + }, + + #' @description + #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" + #' @param field_name The name of the field to be added to json + #' @param field_list List to be stored in json + #' @return None + add_list = function(field_name, field_list) { + stopifnot(sum(!sapply(field_list, is.vector)) == 0) + list_names <- names(field_list) + for (i in 1:length(field_list)) { + vec_name <- list_names[i] + vec <- field_list[[i]] + json_add_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name, + vec + ) + } + }, + + #' @description + #' Add a list of vectors (as an object map of arrays) to the json object under the name "field_name" + #' @param field_name The name of the field to be added to json + #' @param field_list List to be stored in json + #' @return None + add_string_list = function(field_name, field_list) { + stopifnot(sum(!sapply(field_list, is.vector)) == 0) + list_names <- names(field_list) + for (i in 1:length(field_list)) { + vec_name <- list_names[i] + vec <- field_list[[i]] + json_add_string_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name, + vec + ) + } + }, + + #' @description + #' Retrieve a scalar value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_scalar = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_double_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_double_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_integer = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_integer_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_integer_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_boolean = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_bool_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_bool_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a string value from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_string = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_string_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_string_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_vector = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_vector_cpp(self$json_ptr, field_name) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_integer_vector = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_integer_vector_cpp( + self$json_ptr, + field_name + ) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_integer_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name") + #' @param field_name The name of the field to be accessed from json + #' @param subfolder_name (Optional) Name of the subfolder / hierarchy under which the field is stored + #' @return None + get_string_vector = function(field_name, subfolder_name = NULL) { + if (is.null(subfolder_name)) { + stopifnot(json_contains_field_cpp(self$json_ptr, field_name)) + result <- json_extract_string_vector_cpp( + self$json_ptr, + field_name + ) + } else { + stopifnot(json_contains_field_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + )) + result <- json_extract_string_vector_subfolder_cpp( + self$json_ptr, + subfolder_name, + field_name + ) + } + return(result) + }, + + #' @description + #' Reconstruct a list of numeric vectors from the json object stored under "field_name" + #' @param field_name The name of the field to be added to json + #' @param key_names Vector of names of list elements (each of which is a vector) + #' @return None + get_numeric_list = function(field_name, key_names) { + output <- list() + for (i in 1:length(key_names)) { + vec_name <- key_names[i] + output[[vec_name]] <- json_extract_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name + ) + } + return(output) + }, + + #' @description + #' Reconstruct a list of string vectors from the json object stored under "field_name" + #' @param field_name The name of the field to be added to json + #' @param key_names Vector of names of list elements (each of which is a vector) + #' @return None + get_string_list = function(field_name, key_names) { + output <- list() + for (i in 1:length(key_names)) { + vec_name <- key_names[i] + output[[vec_name]] <- json_extract_string_vector_subfolder_cpp( + self$json_ptr, + field_name, + vec_name + ) + } + return(output) + }, + + #' @description + #' Convert a JSON object to in-memory string + #' @return JSON string + return_json_string = function() { + return(get_json_string_cpp(self$json_ptr)) + }, + + #' @description + #' Save a json object to file + #' @param filename String of filepath, must end in ".json" + #' @return None + save_file = function(filename) { + json_save_file_cpp(self$json_ptr, filename) + }, + + #' @description + #' Load a json object from file + #' @param filename String of filepath, must end in ".json" + #' @return None + load_from_file = function(filename) { + json_load_file_cpp(self$json_ptr, filename) + }, + + #' @description + #' Load a json object from string + #' @param json_string JSON string dump + #' @return None + load_from_string = function(json_string) { + json_load_string_cpp(self$json_ptr, json_string) + } + ) ) #' Load a container of forest samples from json @@ -536,9 +536,9 @@ CppJson <- R6::R6Class( #' bart_json <- saveBARTModelToJson(bart_model) #' mean_forest <- loadForestContainerJson(bart_json, "forest_0") loadForestContainerJson <- function(json_object, json_forest_label) { - invisible(output <- ForestSamples$new(0, 1, T)) - output$load_from_json(json_object, json_forest_label) - return(output) + invisible(output <- ForestSamples$new(0, 1, T)) + output$load_from_json(json_object, json_forest_label) + return(output) } #' Combine multiple JSON model objects containing forests (with the same hierarchy / schema) into a single forest_container @@ -556,19 +556,19 @@ loadForestContainerJson <- function(json_object, json_forest_label) { #' bart_json <- list(saveBARTModelToJson(bart_model)) #' mean_forest <- loadForestContainerCombinedJson(bart_json, "forest_0") loadForestContainerCombinedJson <- function( - json_object_list, - json_forest_label + json_object_list, + json_forest_label ) { - invisible(output <- ForestSamples$new(0, 1, T)) - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output$load_from_json(json_object, json_forest_label) - } else { - output$append_from_json(json_object, json_forest_label) - } + invisible(output <- ForestSamples$new(0, 1, T)) + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output$load_from_json(json_object, json_forest_label) + } else { + output$append_from_json(json_object, json_forest_label) } - return(output) + } + return(output) } #' Combine multiple JSON strings representing model objects containing forests (with the same hierarchy / schema) into a single forest_container @@ -586,19 +586,19 @@ loadForestContainerCombinedJson <- function( #' bart_json_string <- list(saveBARTModelToJsonString(bart_model)) #' mean_forest <- loadForestContainerCombinedJsonString(bart_json_string, "forest_0") loadForestContainerCombinedJsonString <- function( - json_string_list, - json_forest_label + json_string_list, + json_forest_label ) { - invisible(output <- ForestSamples$new(0, 1, T)) - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - if (i == 1) { - output$load_from_json_string(json_string, json_forest_label) - } else { - output$append_from_json_string(json_string, json_forest_label) - } + invisible(output <- ForestSamples$new(0, 1, T)) + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + if (i == 1) { + output$load_from_json_string(json_string, json_forest_label) + } else { + output$append_from_json_string(json_string, json_forest_label) } - return(output) + } + return(output) } #' Load a container of random effect samples from json @@ -621,17 +621,17 @@ loadForestContainerCombinedJsonString <- function( #' bart_json <- saveBARTModelToJson(bart_model) #' rfx_samples <- loadRandomEffectSamplesJson(bart_json, 0) loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { - json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) - json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) - json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) - invisible(output <- RandomEffectSamples$new()) - output$load_from_json( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - return(output) + json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) + json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) + json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) + invisible(output <- RandomEffectSamples$new()) + output$load_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) + return(output) } #' Combine multiple JSON model objects containing random effects (with the same hierarchy / schema) into a single container @@ -654,32 +654,32 @@ loadRandomEffectSamplesJson <- function(json_object, json_rfx_num) { #' bart_json <- list(saveBARTModelToJson(bart_model)) #' rfx_samples <- loadRandomEffectSamplesCombinedJson(bart_json, 0) loadRandomEffectSamplesCombinedJson <- function( - json_object_list, - json_rfx_num + json_object_list, + json_rfx_num ) { - json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) - json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) - json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) - invisible(output <- RandomEffectSamples$new()) - for (i in 1:length(json_object_list)) { - json_object <- json_object_list[[i]] - if (i == 1) { - output$load_from_json( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } else { - output$append_from_json( - json_object, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } + json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) + json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) + json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) + invisible(output <- RandomEffectSamples$new()) + for (i in 1:length(json_object_list)) { + json_object <- json_object_list[[i]] + if (i == 1) { + output$load_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) + } else { + output$append_from_json( + json_object, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } - return(output) + } + return(output) } #' Combine multiple JSON strings representing model objects containing random effects (with the same hierarchy / schema) into a single container @@ -702,32 +702,32 @@ loadRandomEffectSamplesCombinedJson <- function( #' bart_json_string <- list(saveBARTModelToJsonString(bart_model)) #' rfx_samples <- loadRandomEffectSamplesCombinedJsonString(bart_json_string, 0) loadRandomEffectSamplesCombinedJsonString <- function( - json_string_list, - json_rfx_num + json_string_list, + json_rfx_num ) { - json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) - json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) - json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) - invisible(output <- RandomEffectSamples$new()) - for (i in 1:length(json_string_list)) { - json_string <- json_string_list[[i]] - if (i == 1) { - output$load_from_json_string( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } else { - output$append_from_json_string( - json_string, - json_rfx_container_label, - json_rfx_mapper_label, - json_rfx_groupids_label - ) - } + json_rfx_container_label <- paste0("random_effect_container_", json_rfx_num) + json_rfx_mapper_label <- paste0("random_effect_label_mapper_", json_rfx_num) + json_rfx_groupids_label <- paste0("random_effect_groupids_", json_rfx_num) + invisible(output <- RandomEffectSamples$new()) + for (i in 1:length(json_string_list)) { + json_string <- json_string_list[[i]] + if (i == 1) { + output$load_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) + } else { + output$append_from_json_string( + json_string, + json_rfx_container_label, + json_rfx_mapper_label, + json_rfx_groupids_label + ) } - return(output) + } + return(output) } #' Load a vector from json @@ -745,16 +745,16 @@ loadRandomEffectSamplesCombinedJsonString <- function( #' example_json$add_vector("myvec", example_vec) #' roundtrip_vec <- loadVectorJson(example_json, "myvec") loadVectorJson <- function( - json_object, - json_vector_label, - subfolder_name = NULL + json_object, + json_vector_label, + subfolder_name = NULL ) { - if (is.null(subfolder_name)) { - output <- json_object$get_vector(json_vector_label) - } else { - output <- json_object$get_vector(json_vector_label, subfolder_name) - } - return(output) + if (is.null(subfolder_name)) { + output <- json_object$get_vector(json_vector_label) + } else { + output <- json_object$get_vector(json_vector_label, subfolder_name) + } + return(output) } #' Load a scalar from json @@ -772,16 +772,16 @@ loadVectorJson <- function( #' example_json$add_scalar("myscalar", example_scalar) #' roundtrip_scalar <- loadScalarJson(example_json, "myscalar") loadScalarJson <- function( - json_object, - json_scalar_label, - subfolder_name = NULL + json_object, + json_scalar_label, + subfolder_name = NULL ) { - if (is.null(subfolder_name)) { - output <- json_object$get_scalar(json_scalar_label) - } else { - output <- json_object$get_scalar(json_scalar_label, subfolder_name) - } - return(output) + if (is.null(subfolder_name)) { + output <- json_object$get_scalar(json_scalar_label) + } else { + output <- json_object$get_scalar(json_scalar_label, subfolder_name) + } + return(output) } #' Create a new (empty) C++ Json object @@ -794,7 +794,7 @@ loadScalarJson <- function( #' example_json <- createCppJson() #' example_json$add_vector("myvec", example_vec) createCppJson <- function() { - return(invisible((CppJson$new()))) + return(invisible((CppJson$new()))) } #' Create a C++ Json object from a Json file @@ -812,9 +812,9 @@ createCppJson <- function() { #' example_json_roundtrip <- createCppJsonFile(file.path(tmpjson)) #' unlink(tmpjson) createCppJsonFile <- function(json_filename) { - invisible((output <- CppJson$new())) - output$load_from_file(json_filename) - return(output) + invisible((output <- CppJson$new())) + output$load_from_file(json_filename) + return(output) } #' Create a C++ Json object from a Json string @@ -830,7 +830,7 @@ createCppJsonFile <- function(json_filename) { #' example_json_string <- example_json$return_json_string() #' example_json_roundtrip <- createCppJsonString(example_json_string) createCppJsonString <- function(json_string) { - invisible((output <- CppJson$new())) - output$load_from_string(json_string) - return(output) + invisible((output <- CppJson$new())) + output$load_from_string(json_string) + return(output) } diff --git a/R/utils.R b/R/utils.R index ad9752bc..c169cc1c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -6,18 +6,18 @@ #' #' @return Parameter list with defaults overriden by values supplied in `user_params` preprocessParams <- function(default_params, user_params = NULL) { - # Override defaults from general_params - if (!is.null(user_params)) { - for (key in names(user_params)) { - if (key %in% names(default_params)) { - val <- user_params[[key]] - if (!is.null(val)) default_params[[key]] <- val - } - } + # Override defaults from general_params + if (!is.null(user_params)) { + for (key in names(user_params)) { + if (key %in% names(default_params)) { + val <- user_params[[key]] + if (!is.null(val)) default_params[[key]] <- val + } } + } - # Return result - return(default_params) + # Return result + return(default_params) } #' Preprocess covariates. DataFrames will be preprocessed based on their column @@ -35,19 +35,19 @@ preprocessParams <- function(default_params, user_params = NULL) { #' preprocess_list <- preprocessTrainData(cov_mat) #' X <- preprocess_list$X preprocessTrainData <- function(input_data) { - # Input checks - if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { - stop("Covariates provided must be a dataframe or matrix") - } - - # Routing the correct preprocessing function - if (is.matrix(input_data)) { - output <- preprocessTrainMatrix(input_data) - } else { - output <- preprocessTrainDataFrame(input_data) - } - - return(output) + # Input checks + if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { + stop("Covariates provided must be a dataframe or matrix") + } + + # Routing the correct preprocessing function + if (is.matrix(input_data)) { + output <- preprocessTrainMatrix(input_data) + } else { + output <- preprocessTrainDataFrame(input_data) + } + + return(output) } #' Preprocess covariates. DataFrames will be preprocessed based on their column @@ -66,19 +66,19 @@ preprocessTrainData <- function(input_data) { #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionData(cov_df, metadata) preprocessPredictionData <- function(input_data, metadata) { - # Input checks - if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { - stop("Covariates provided must be a dataframe or matrix") - } - - # Routing the correct preprocessing function - if (is.matrix(input_data)) { - X <- preprocessPredictionMatrix(input_data, metadata) - } else { - X <- preprocessPredictionDataFrame(input_data, metadata) - } - - return(X) + # Input checks + if ((!is.matrix(input_data)) && (!is.data.frame(input_data))) { + stop("Covariates provided must be a dataframe or matrix") + } + + # Routing the correct preprocessing function + if (is.matrix(input_data)) { + X <- preprocessPredictionMatrix(input_data, metadata) + } else { + X <- preprocessPredictionDataFrame(input_data, metadata) + } + + return(X) } #' Preprocess a matrix of covariate values, assuming all columns are numeric. @@ -96,38 +96,38 @@ preprocessPredictionData <- function(input_data, metadata) { #' preprocess_list <- preprocessTrainMatrix(cov_mat) #' X <- preprocess_list$X preprocessTrainMatrix <- function(input_matrix) { - # Input checks - if (!is.matrix(input_matrix)) { - stop("covariates provided must be a matrix") - } - - # Unpack metadata (assuming all variables are numeric) - names(input_matrix) <- paste0("x", 1:ncol(input_matrix)) - df_vars <- names(input_matrix) - num_ordered_cat_vars <- 0 - num_unordered_cat_vars <- 0 - num_numeric_vars <- ncol(input_matrix) - numeric_vars <- names(input_matrix) - feature_types <- rep(0, ncol(input_matrix)) - - # Unpack data - X <- input_matrix - - # Aggregate results into a list - metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, - num_numeric_vars = num_numeric_vars, - numeric_vars = numeric_vars, - original_var_indices = 1:num_numeric_vars - ) - output <- list( - data = X, - metadata = metadata - ) - - return(output) + # Input checks + if (!is.matrix(input_matrix)) { + stop("covariates provided must be a matrix") + } + + # Unpack metadata (assuming all variables are numeric) + names(input_matrix) <- paste0("x", 1:ncol(input_matrix)) + df_vars <- names(input_matrix) + num_ordered_cat_vars <- 0 + num_unordered_cat_vars <- 0 + num_numeric_vars <- ncol(input_matrix) + numeric_vars <- names(input_matrix) + feature_types <- rep(0, ncol(input_matrix)) + + # Unpack data + X <- input_matrix + + # Aggregate results into a list + metadata <- list( + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, + num_numeric_vars = num_numeric_vars, + numeric_vars = numeric_vars, + original_var_indices = 1:num_numeric_vars + ) + output <- list( + data = X, + metadata = metadata + ) + + return(output) } #' Preprocess a matrix of covariate values, assuming all columns are numeric. @@ -145,17 +145,17 @@ preprocessTrainMatrix <- function(input_matrix) { #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionMatrix(cov_mat, metadata) preprocessPredictionMatrix <- function(input_matrix, metadata) { - # Input checks - if (!is.matrix(input_matrix)) { - stop("covariates provided must be a matrix") - } - if (!(ncol(input_matrix) == metadata$num_numeric_vars)) { - stop( - "Prediction set covariates have inconsistent dimension from train set covariates" - ) - } + # Input checks + if (!is.matrix(input_matrix)) { + stop("covariates provided must be a matrix") + } + if (!(ncol(input_matrix) == metadata$num_numeric_vars)) { + stop( + "Prediction set covariates have inconsistent dimension from train set covariates" + ) + } - return(input_matrix) + return(input_matrix) } #' Preprocess a dataframe of covariate values, converting categorical variables @@ -170,126 +170,126 @@ preprocessPredictionMatrix <- function(input_matrix, metadata) { #' of variable, unique categories associated with categorical variables, and the #' vector of feature types needed for calls to BART and BCF. preprocessTrainDataFrame <- function(input_df) { - # Input checks / details - if (!is.data.frame(input_df)) { - stop("covariates provided must be a data frame") - } - df_vars <- names(input_df) - - # Detect ordered and unordered categorical variables - - # First, ordered categorical: users must have explicitly - # converted this to a factor with ordered = TRUE - factor_mask <- sapply(input_df, is.factor) - ordered_mask <- sapply(input_df, is.ordered) - ordered_cat_matches <- factor_mask & ordered_mask - ordered_cat_vars <- df_vars[ordered_cat_matches] - ordered_cat_var_inds <- unname(which(ordered_cat_matches)) - num_ordered_cat_vars <- length(ordered_cat_vars) - if (num_ordered_cat_vars > 0) { - ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] - } - - # Next, unordered categorical: we will convert character - # columns but not integer columns (users must explicitly - # convert these to factor) - character_mask <- sapply(input_df, is.character) - unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask - unordered_cat_vars <- df_vars[unordered_cat_matches] - unordered_cat_var_inds <- unname(which(unordered_cat_matches)) - num_unordered_cat_vars <- length(unordered_cat_vars) - if (num_unordered_cat_vars > 0) { - unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] - } - - # Numeric variables - numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches) - numeric_vars <- df_vars[numeric_matches] - numeric_var_inds <- unname(which(numeric_matches)) - num_numeric_vars <- length(numeric_vars) - if (num_numeric_vars > 0) { - numeric_df <- input_df[, numeric_vars, drop = FALSE] - } - - # Empty outputs - X <- double(0) - unordered_unique_levels <- list() - ordered_unique_levels <- list() - feature_types <- integer(0) - original_var_indices <- integer(0) - - # First, extract the numeric covariates - if (num_numeric_vars > 0) { - Xnum <- double(0) - for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[, i])) - Xnum <- cbind(Xnum, numeric_df[, i]) - } - X <- cbind(X, unname(Xnum)) - feature_types <- c(feature_types, rep(0, ncol(Xnum))) - original_var_indices <- c(original_var_indices, numeric_var_inds) - } - - # Next, run some simple preprocessing on the ordered categorical covariates - if (num_ordered_cat_vars > 0) { - Xordcat <- double(0) - for (i in 1:ncol(ordered_cat_df)) { - var_name <- names(ordered_cat_df)[i] - preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[, - i - ]) - ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels - Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed) - } - X <- cbind(X, unname(Xordcat)) - feature_types <- c(feature_types, rep(1, ncol(Xordcat))) - original_var_indices <- c(original_var_indices, ordered_cat_var_inds) + # Input checks / details + if (!is.data.frame(input_df)) { + stop("covariates provided must be a data frame") + } + df_vars <- names(input_df) + + # Detect ordered and unordered categorical variables + + # First, ordered categorical: users must have explicitly + # converted this to a factor with ordered = TRUE + factor_mask <- sapply(input_df, is.factor) + ordered_mask <- sapply(input_df, is.ordered) + ordered_cat_matches <- factor_mask & ordered_mask + ordered_cat_vars <- df_vars[ordered_cat_matches] + ordered_cat_var_inds <- unname(which(ordered_cat_matches)) + num_ordered_cat_vars <- length(ordered_cat_vars) + if (num_ordered_cat_vars > 0) { + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + } + + # Next, unordered categorical: we will convert character + # columns but not integer columns (users must explicitly + # convert these to factor) + character_mask <- sapply(input_df, is.character) + unordered_cat_matches <- (factor_mask & (!ordered_mask)) | character_mask + unordered_cat_vars <- df_vars[unordered_cat_matches] + unordered_cat_var_inds <- unname(which(unordered_cat_matches)) + num_unordered_cat_vars <- length(unordered_cat_vars) + if (num_unordered_cat_vars > 0) { + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + } + + # Numeric variables + numeric_matches <- (!ordered_cat_matches) & (!unordered_cat_matches) + numeric_vars <- df_vars[numeric_matches] + numeric_var_inds <- unname(which(numeric_matches)) + num_numeric_vars <- length(numeric_vars) + if (num_numeric_vars > 0) { + numeric_df <- input_df[, numeric_vars, drop = FALSE] + } + + # Empty outputs + X <- double(0) + unordered_unique_levels <- list() + ordered_unique_levels <- list() + feature_types <- integer(0) + original_var_indices <- integer(0) + + # First, extract the numeric covariates + if (num_numeric_vars > 0) { + Xnum <- double(0) + for (i in 1:ncol(numeric_df)) { + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } - - # Finally, one-hot encode the unordered categorical covariates - if (num_unordered_cat_vars > 0) { - one_hot_mats <- list() - for (i in 1:ncol(unordered_cat_df)) { - var_name <- names(unordered_cat_df)[i] - encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i]) - unordered_unique_levels[[var_name]] <- encode_list$unique_levels - one_hot_mats[[var_name]] <- encode_list$Xtilde - one_hot_var <- rep( - unordered_cat_var_inds[i], - ncol(encode_list$Xtilde) - ) - original_var_indices <- c(original_var_indices, one_hot_var) - } - Xcat <- do.call(cbind, one_hot_mats) - X <- cbind(X, unname(Xcat)) - feature_types <- c(feature_types, rep(1, ncol(Xcat))) - } - - # Aggregate results into a list - metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, - num_numeric_vars = num_numeric_vars, - original_var_indices = original_var_indices - ) - if (num_ordered_cat_vars > 0) { - metadata[["ordered_cat_vars"]] = ordered_cat_vars - metadata[["ordered_unique_levels"]] = ordered_unique_levels + X <- cbind(X, unname(Xnum)) + feature_types <- c(feature_types, rep(0, ncol(Xnum))) + original_var_indices <- c(original_var_indices, numeric_var_inds) + } + + # Next, run some simple preprocessing on the ordered categorical covariates + if (num_ordered_cat_vars > 0) { + Xordcat <- double(0) + for (i in 1:ncol(ordered_cat_df)) { + var_name <- names(ordered_cat_df)[i] + preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[, + i + ]) + ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels + Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed) } - if (num_unordered_cat_vars > 0) { - metadata[["unordered_cat_vars"]] = unordered_cat_vars - metadata[["unordered_unique_levels"]] = unordered_unique_levels + X <- cbind(X, unname(Xordcat)) + feature_types <- c(feature_types, rep(1, ncol(Xordcat))) + original_var_indices <- c(original_var_indices, ordered_cat_var_inds) + } + + # Finally, one-hot encode the unordered categorical covariates + if (num_unordered_cat_vars > 0) { + one_hot_mats <- list() + for (i in 1:ncol(unordered_cat_df)) { + var_name <- names(unordered_cat_df)[i] + encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i]) + unordered_unique_levels[[var_name]] <- encode_list$unique_levels + one_hot_mats[[var_name]] <- encode_list$Xtilde + one_hot_var <- rep( + unordered_cat_var_inds[i], + ncol(encode_list$Xtilde) + ) + original_var_indices <- c(original_var_indices, one_hot_var) } - if (num_numeric_vars > 0) { - metadata[["numeric_vars"]] = numeric_vars - } - output <- list( - data = X, - metadata = metadata - ) - - return(output) + Xcat <- do.call(cbind, one_hot_mats) + X <- cbind(X, unname(Xcat)) + feature_types <- c(feature_types, rep(1, ncol(Xcat))) + } + + # Aggregate results into a list + metadata <- list( + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, + num_numeric_vars = num_numeric_vars, + original_var_indices = original_var_indices + ) + if (num_ordered_cat_vars > 0) { + metadata[["ordered_cat_vars"]] = ordered_cat_vars + metadata[["ordered_unique_levels"]] = ordered_unique_levels + } + if (num_unordered_cat_vars > 0) { + metadata[["unordered_cat_vars"]] = unordered_cat_vars + metadata[["unordered_unique_levels"]] = unordered_unique_levels + } + if (num_numeric_vars > 0) { + metadata[["numeric_vars"]] = numeric_vars + } + output <- list( + data = X, + metadata = metadata + ) + + return(output) } #' Preprocess a dataframe of covariate values, converting categorical variables @@ -309,70 +309,70 @@ preprocessTrainDataFrame <- function(input_df) { #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- preprocessPredictionDataFrame(cov_df, metadata) preprocessPredictionDataFrame <- function(input_df, metadata) { - if (!is.data.frame(input_df)) { - stop("covariates provided must be a data frame") - } - df_vars <- names(input_df) - num_ordered_cat_vars <- metadata$num_ordered_cat_vars - num_unordered_cat_vars <- metadata$num_unordered_cat_vars - num_numeric_vars <- metadata$num_numeric_vars - - if (num_ordered_cat_vars > 0) { - ordered_cat_vars <- metadata$ordered_cat_vars - ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + if (!is.data.frame(input_df)) { + stop("covariates provided must be a data frame") + } + df_vars <- names(input_df) + num_ordered_cat_vars <- metadata$num_ordered_cat_vars + num_unordered_cat_vars <- metadata$num_unordered_cat_vars + num_numeric_vars <- metadata$num_numeric_vars + + if (num_ordered_cat_vars > 0) { + ordered_cat_vars <- metadata$ordered_cat_vars + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + } + if (num_unordered_cat_vars > 0) { + unordered_cat_vars <- metadata$unordered_cat_vars + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + } + if (num_numeric_vars > 0) { + numeric_vars <- metadata$numeric_vars + numeric_df <- input_df[, numeric_vars, drop = FALSE] + } + + # Empty outputs + X <- double(0) + + # First, extract the numeric covariates + if (num_numeric_vars > 0) { + Xnum <- double(0) + for (i in 1:ncol(numeric_df)) { + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } - if (num_unordered_cat_vars > 0) { - unordered_cat_vars <- metadata$unordered_cat_vars - unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + X <- cbind(X, unname(Xnum)) + } + + # Next, run some simple preprocessing on the ordered categorical covariates + if (num_ordered_cat_vars > 0) { + Xordcat <- double(0) + for (i in 1:ncol(ordered_cat_df)) { + var_name <- names(ordered_cat_df)[i] + x_preprocessed <- orderedCatPreprocess( + ordered_cat_df[, i], + metadata$ordered_unique_levels[[var_name]] + ) + Xordcat <- cbind(Xordcat, x_preprocessed) } - if (num_numeric_vars > 0) { - numeric_vars <- metadata$numeric_vars - numeric_df <- input_df[, numeric_vars, drop = FALSE] + X <- cbind(X, unname(Xordcat)) + } + + # Finally, one-hot encode the unordered categorical covariates + if (num_unordered_cat_vars > 0) { + one_hot_mats <- list() + for (i in 1:ncol(unordered_cat_df)) { + var_name <- names(unordered_cat_df)[i] + Xtilde <- oneHotEncode( + unordered_cat_df[, i], + metadata$unordered_unique_levels[[var_name]] + ) + one_hot_mats[[var_name]] <- Xtilde } + Xcat <- do.call(cbind, one_hot_mats) + X <- cbind(X, unname(Xcat)) + } - # Empty outputs - X <- double(0) - - # First, extract the numeric covariates - if (num_numeric_vars > 0) { - Xnum <- double(0) - for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[, i])) - Xnum <- cbind(Xnum, numeric_df[, i]) - } - X <- cbind(X, unname(Xnum)) - } - - # Next, run some simple preprocessing on the ordered categorical covariates - if (num_ordered_cat_vars > 0) { - Xordcat <- double(0) - for (i in 1:ncol(ordered_cat_df)) { - var_name <- names(ordered_cat_df)[i] - x_preprocessed <- orderedCatPreprocess( - ordered_cat_df[, i], - metadata$ordered_unique_levels[[var_name]] - ) - Xordcat <- cbind(Xordcat, x_preprocessed) - } - X <- cbind(X, unname(Xordcat)) - } - - # Finally, one-hot encode the unordered categorical covariates - if (num_unordered_cat_vars > 0) { - one_hot_mats <- list() - for (i in 1:ncol(unordered_cat_df)) { - var_name <- names(unordered_cat_df)[i] - Xtilde <- oneHotEncode( - unordered_cat_df[, i], - metadata$unordered_unique_levels[[var_name]] - ) - one_hot_mats[[var_name]] <- Xtilde - } - Xcat <- do.call(cbind, one_hot_mats) - X <- cbind(X, unname(Xcat)) - } - - return(X) + return(X) } #' Convert the persistent aspects of a covariate preprocessor to (in-memory) C++ JSON object @@ -388,59 +388,59 @@ preprocessPredictionDataFrame <- function(input_df, metadata) { #' preprocess_list <- preprocessTrainData(cov_mat) #' preprocessor_json <- convertPreprocessorToJson(preprocess_list$metadata) convertPreprocessorToJson <- function(object) { - jsonobj <- createCppJson() - if (is.null(object$feature_types)) { - stop("This covariate preprocessor has not yet been fit") - } - - # Add internal scalars - jsonobj$add_integer("num_numeric_vars", object$num_numeric_vars) - jsonobj$add_integer("num_ordered_cat_vars", object$num_ordered_cat_vars) - jsonobj$add_integer("num_unordered_cat_vars", object$num_unordered_cat_vars) - - # Add internal vectors - jsonobj$add_vector("feature_types", object$feature_types) - jsonobj$add_vector("original_var_indices", object$original_var_indices) - if (object$num_numeric_vars > 0) { - jsonobj$add_string_vector("numeric_vars", object$numeric_vars) + jsonobj <- createCppJson() + if (is.null(object$feature_types)) { + stop("This covariate preprocessor has not yet been fit") + } + + # Add internal scalars + jsonobj$add_integer("num_numeric_vars", object$num_numeric_vars) + jsonobj$add_integer("num_ordered_cat_vars", object$num_ordered_cat_vars) + jsonobj$add_integer("num_unordered_cat_vars", object$num_unordered_cat_vars) + + # Add internal vectors + jsonobj$add_vector("feature_types", object$feature_types) + jsonobj$add_vector("original_var_indices", object$original_var_indices) + if (object$num_numeric_vars > 0) { + jsonobj$add_string_vector("numeric_vars", object$numeric_vars) + } + if (object$num_ordered_cat_vars > 0) { + jsonobj$add_string_vector("ordered_cat_vars", object$ordered_cat_vars) + for (i in 1:object$num_ordered_cat_vars) { + var_key <- names(object$ordered_unique_levels)[i] + jsonobj$add_string( + paste0("key_", i), + var_key, + "ordered_unique_level_keys" + ) + jsonobj$add_string_vector( + var_key, + object$ordered_unique_levels[[i]], + "ordered_unique_levels" + ) } - if (object$num_ordered_cat_vars > 0) { - jsonobj$add_string_vector("ordered_cat_vars", object$ordered_cat_vars) - for (i in 1:object$num_ordered_cat_vars) { - var_key <- names(object$ordered_unique_levels)[i] - jsonobj$add_string( - paste0("key_", i), - var_key, - "ordered_unique_level_keys" - ) - jsonobj$add_string_vector( - var_key, - object$ordered_unique_levels[[i]], - "ordered_unique_levels" - ) - } - } - if (object$num_unordered_cat_vars > 0) { - jsonobj$add_string_vector( - "unordered_cat_vars", - object$unordered_cat_vars - ) - for (i in 1:object$num_unordered_cat_vars) { - var_key <- names(object$unordered_unique_levels)[i] - jsonobj$add_string( - paste0("key_", i), - var_key, - "unordered_unique_level_keys" - ) - jsonobj$add_string_vector( - var_key, - object$unordered_unique_levels[[i]], - "unordered_unique_levels" - ) - } + } + if (object$num_unordered_cat_vars > 0) { + jsonobj$add_string_vector( + "unordered_cat_vars", + object$unordered_cat_vars + ) + for (i in 1:object$num_unordered_cat_vars) { + var_key <- names(object$unordered_unique_levels)[i] + jsonobj$add_string( + paste0("key_", i), + var_key, + "unordered_unique_level_keys" + ) + jsonobj$add_string_vector( + var_key, + object$unordered_unique_levels[[i]], + "unordered_unique_levels" + ) } + } - return(jsonobj) + return(jsonobj) } #' Convert the persistent aspects of a covariate preprocessor to (in-memory) JSON string @@ -456,11 +456,11 @@ convertPreprocessorToJson <- function(object) { #' preprocess_list <- preprocessTrainData(cov_mat) #' preprocessor_json_string <- savePreprocessorToJsonString(preprocess_list$metadata) savePreprocessorToJsonString <- function(object) { - # Convert to Json - jsonobj <- convertPreprocessorToJson(object) + # Convert to Json + jsonobj <- convertPreprocessorToJson(object) - # Dump to string - return(jsonobj$return_json_string()) + # Dump to string + return(jsonobj$return_json_string()) } #' Reload a covariate preprocessor object from a JSON string containing a serialized preprocessor @@ -476,66 +476,66 @@ savePreprocessorToJsonString <- function(object) { #' preprocessor_json <- convertPreprocessorToJson(preprocess_list$metadata) #' preprocessor_roundtrip <- createPreprocessorFromJson(preprocessor_json) createPreprocessorFromJson <- function(json_object) { - # Initialize the metadata list - metadata <- list() - - # Unpack internal scalars - metadata[["num_numeric_vars"]] <- json_object$get_integer( - "num_numeric_vars" + # Initialize the metadata list + metadata <- list() + + # Unpack internal scalars + metadata[["num_numeric_vars"]] <- json_object$get_integer( + "num_numeric_vars" + ) + metadata[["num_ordered_cat_vars"]] <- json_object$get_integer( + "num_ordered_cat_vars" + ) + metadata[["num_unordered_cat_vars"]] <- json_object$get_integer( + "num_unordered_cat_vars" + ) + + # Unpack internal vectors + metadata[["feature_types"]] <- json_object$get_vector("feature_types") + metadata[["original_var_indices"]] <- json_object$get_vector( + "original_var_indices" + ) + if (metadata$num_numeric_vars > 0) { + metadata[["numeric_vars"]] <- json_object$get_string_vector( + "numeric_vars" ) - metadata[["num_ordered_cat_vars"]] <- json_object$get_integer( - "num_ordered_cat_vars" + } + if (metadata$num_ordered_cat_vars > 0) { + metadata[["ordered_cat_vars"]] <- json_object$get_string_vector( + "ordered_cat_vars" ) - metadata[["num_unordered_cat_vars"]] <- json_object$get_integer( - "num_unordered_cat_vars" - ) - - # Unpack internal vectors - metadata[["feature_types"]] <- json_object$get_vector("feature_types") - metadata[["original_var_indices"]] <- json_object$get_vector( - "original_var_indices" - ) - if (metadata$num_numeric_vars > 0) { - metadata[["numeric_vars"]] <- json_object$get_string_vector( - "numeric_vars" - ) - } - if (metadata$num_ordered_cat_vars > 0) { - metadata[["ordered_cat_vars"]] <- json_object$get_string_vector( - "ordered_cat_vars" - ) - ordered_unique_levels <- list() - for (i in 1:metadata$num_ordered_cat_vars) { - var_key <- json_object$get_string( - paste0("key_", i), - "ordered_unique_level_keys" - ) - ordered_unique_levels[[var_key]] <- json_object$get_string_vector( - var_key, - "ordered_unique_levels" - ) - } - metadata[["ordered_unique_levels"]] <- ordered_unique_levels + ordered_unique_levels <- list() + for (i in 1:metadata$num_ordered_cat_vars) { + var_key <- json_object$get_string( + paste0("key_", i), + "ordered_unique_level_keys" + ) + ordered_unique_levels[[var_key]] <- json_object$get_string_vector( + var_key, + "ordered_unique_levels" + ) } - if (metadata$num_unordered_cat_vars > 0) { - metadata[["unordered_cat_vars"]] <- json_object$get_string_vector( - "unordered_cat_vars" - ) - unordered_unique_levels <- list() - for (i in 1:metadata$num_unordered_cat_vars) { - var_key <- json_object$get_string( - paste0("key_", i), - "unordered_unique_level_keys" - ) - unordered_unique_levels[[var_key]] <- json_object$get_string_vector( - var_key, - "unordered_unique_levels" - ) - } - metadata[["unordered_unique_levels"]] <- unordered_unique_levels + metadata[["ordered_unique_levels"]] <- ordered_unique_levels + } + if (metadata$num_unordered_cat_vars > 0) { + metadata[["unordered_cat_vars"]] <- json_object$get_string_vector( + "unordered_cat_vars" + ) + unordered_unique_levels <- list() + for (i in 1:metadata$num_unordered_cat_vars) { + var_key <- json_object$get_string( + paste0("key_", i), + "unordered_unique_level_keys" + ) + unordered_unique_levels[[var_key]] <- json_object$get_string_vector( + var_key, + "unordered_unique_levels" + ) } + metadata[["unordered_unique_levels"]] <- unordered_unique_levels + } - return(metadata) + return(metadata) } #' Reload a covariate preprocessor object from a JSON string containing a serialized preprocessor @@ -551,13 +551,13 @@ createPreprocessorFromJson <- function(json_object) { #' preprocessor_json_string <- savePreprocessorToJsonString(preprocess_list$metadata) #' preprocessor_roundtrip <- createPreprocessorFromJsonString(preprocessor_json_string) createPreprocessorFromJsonString <- function(json_string) { - # Load a `CppJson` object from string - preprocessor_json <- createCppJsonString(json_string) + # Load a `CppJson` object from string + preprocessor_json <- createCppJsonString(json_string) - # Create and return the BCF object - preprocessor_object <- createPreprocessorFromJson(preprocessor_json) + # Create and return the BCF object + preprocessor_object <- createPreprocessorFromJson(preprocessor_json) - return(preprocessor_object) + return(preprocessor_object) } #' Preprocess a dataframe of covariate values, converting categorical variables @@ -579,129 +579,129 @@ createPreprocessorFromJsonString <- function(json_string) { #' preprocess_list <- createForestCovariates(cov_df) #' X <- preprocess_list$X createForestCovariates <- function( - input_data, - ordered_cat_vars = NULL, - unordered_cat_vars = NULL + input_data, + ordered_cat_vars = NULL, + unordered_cat_vars = NULL ) { - if (is.matrix(input_data)) { - input_df <- as.data.frame(input_data) - names(input_df) <- paste0("x", 1:ncol(input_data)) - if (!is.null(ordered_cat_vars)) { - if (is.numeric(ordered_cat_vars)) { - ordered_cat_vars <- paste0("x", as.integer(ordered_cat_vars)) - } - } - if (!is.null(unordered_cat_vars)) { - if (is.numeric(unordered_cat_vars)) { - unordered_cat_vars <- paste0( - "x", - as.integer(unordered_cat_vars) - ) - } - } - } else if (is.data.frame(input_data)) { - input_df <- input_data - } else { - stop("input_data must be either a matrix or a data frame") - } - df_vars <- names(input_df) - if (is.null(ordered_cat_vars)) { - ordered_cat_matches <- rep(FALSE, length(df_vars)) - } else { - ordered_cat_matches <- df_vars %in% ordered_cat_vars - } - if (is.null(unordered_cat_vars)) { - unordered_cat_matches <- rep(FALSE, length(df_vars)) - } else { - unordered_cat_matches <- df_vars %in% unordered_cat_vars - } - numeric_matches <- ((!ordered_cat_matches) & (!unordered_cat_matches)) - ordered_cat_vars <- df_vars[ordered_cat_matches] - unordered_cat_vars <- df_vars[unordered_cat_matches] - numeric_vars <- df_vars[numeric_matches] - num_ordered_cat_vars <- length(ordered_cat_vars) - num_unordered_cat_vars <- length(unordered_cat_vars) - num_numeric_vars <- length(numeric_vars) - if (num_ordered_cat_vars > 0) { - ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] - } - if (num_unordered_cat_vars > 0) { - unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] - } - if (num_numeric_vars > 0) { - numeric_df <- input_df[, numeric_vars, drop = FALSE] - } - - # Empty outputs - X <- double(0) - unordered_unique_levels <- list() - ordered_unique_levels <- list() - feature_types <- integer(0) - - # First, extract the numeric covariates - if (num_numeric_vars > 0) { - Xnum <- double(0) - for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[, i])) - Xnum <- cbind(Xnum, numeric_df[, i]) - } - X <- cbind(X, unname(Xnum)) - feature_types <- c(feature_types, rep(0, ncol(Xnum))) - } - - # Next, run some simple preprocessing on the ordered categorical covariates - if (num_ordered_cat_vars > 0) { - Xordcat <- double(0) - for (i in 1:ncol(ordered_cat_df)) { - var_name <- names(ordered_cat_df)[i] - preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[, - i - ]) - ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels - Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed) - } - X <- cbind(X, unname(Xordcat)) - feature_types <- c(feature_types, rep(1, ncol(Xordcat))) + if (is.matrix(input_data)) { + input_df <- as.data.frame(input_data) + names(input_df) <- paste0("x", 1:ncol(input_data)) + if (!is.null(ordered_cat_vars)) { + if (is.numeric(ordered_cat_vars)) { + ordered_cat_vars <- paste0("x", as.integer(ordered_cat_vars)) + } } - - # Finally, one-hot encode the unordered categorical covariates - if (num_unordered_cat_vars > 0) { - one_hot_mats <- list() - for (i in 1:ncol(unordered_cat_df)) { - var_name <- names(unordered_cat_df)[i] - encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i]) - unordered_unique_levels[[var_name]] <- encode_list$unique_levels - one_hot_mats[[var_name]] <- encode_list$Xtilde - } - Xcat <- do.call(cbind, one_hot_mats) - X <- cbind(X, unname(Xcat)) - feature_types <- c(feature_types, rep(1, ncol(Xcat))) + if (!is.null(unordered_cat_vars)) { + if (is.numeric(unordered_cat_vars)) { + unordered_cat_vars <- paste0( + "x", + as.integer(unordered_cat_vars) + ) + } } - - # Aggregate results into a list - metadata <- list( - feature_types = feature_types, - num_ordered_cat_vars = num_ordered_cat_vars, - num_unordered_cat_vars = num_unordered_cat_vars, - num_numeric_vars = num_numeric_vars - ) - if (num_ordered_cat_vars > 0) { - metadata[["ordered_cat_vars"]] = ordered_cat_vars - metadata[["ordered_unique_levels"]] = ordered_unique_levels + } else if (is.data.frame(input_data)) { + input_df <- input_data + } else { + stop("input_data must be either a matrix or a data frame") + } + df_vars <- names(input_df) + if (is.null(ordered_cat_vars)) { + ordered_cat_matches <- rep(FALSE, length(df_vars)) + } else { + ordered_cat_matches <- df_vars %in% ordered_cat_vars + } + if (is.null(unordered_cat_vars)) { + unordered_cat_matches <- rep(FALSE, length(df_vars)) + } else { + unordered_cat_matches <- df_vars %in% unordered_cat_vars + } + numeric_matches <- ((!ordered_cat_matches) & (!unordered_cat_matches)) + ordered_cat_vars <- df_vars[ordered_cat_matches] + unordered_cat_vars <- df_vars[unordered_cat_matches] + numeric_vars <- df_vars[numeric_matches] + num_ordered_cat_vars <- length(ordered_cat_vars) + num_unordered_cat_vars <- length(unordered_cat_vars) + num_numeric_vars <- length(numeric_vars) + if (num_ordered_cat_vars > 0) { + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + } + if (num_unordered_cat_vars > 0) { + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + } + if (num_numeric_vars > 0) { + numeric_df <- input_df[, numeric_vars, drop = FALSE] + } + + # Empty outputs + X <- double(0) + unordered_unique_levels <- list() + ordered_unique_levels <- list() + feature_types <- integer(0) + + # First, extract the numeric covariates + if (num_numeric_vars > 0) { + Xnum <- double(0) + for (i in 1:ncol(numeric_df)) { + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } - if (num_unordered_cat_vars > 0) { - metadata[["unordered_cat_vars"]] = unordered_cat_vars - metadata[["unordered_unique_levels"]] = unordered_unique_levels + X <- cbind(X, unname(Xnum)) + feature_types <- c(feature_types, rep(0, ncol(Xnum))) + } + + # Next, run some simple preprocessing on the ordered categorical covariates + if (num_ordered_cat_vars > 0) { + Xordcat <- double(0) + for (i in 1:ncol(ordered_cat_df)) { + var_name <- names(ordered_cat_df)[i] + preprocess_list <- orderedCatInitializeAndPreprocess(ordered_cat_df[, + i + ]) + ordered_unique_levels[[var_name]] <- preprocess_list$unique_levels + Xordcat <- cbind(Xordcat, preprocess_list$x_preprocessed) } - if (num_numeric_vars > 0) { - metadata[["numeric_vars"]] = numeric_vars + X <- cbind(X, unname(Xordcat)) + feature_types <- c(feature_types, rep(1, ncol(Xordcat))) + } + + # Finally, one-hot encode the unordered categorical covariates + if (num_unordered_cat_vars > 0) { + one_hot_mats <- list() + for (i in 1:ncol(unordered_cat_df)) { + var_name <- names(unordered_cat_df)[i] + encode_list <- oneHotInitializeAndEncode(unordered_cat_df[, i]) + unordered_unique_levels[[var_name]] <- encode_list$unique_levels + one_hot_mats[[var_name]] <- encode_list$Xtilde } - output <- list( - data = X, - metadata = metadata - ) - - return(output) + Xcat <- do.call(cbind, one_hot_mats) + X <- cbind(X, unname(Xcat)) + feature_types <- c(feature_types, rep(1, ncol(Xcat))) + } + + # Aggregate results into a list + metadata <- list( + feature_types = feature_types, + num_ordered_cat_vars = num_ordered_cat_vars, + num_unordered_cat_vars = num_unordered_cat_vars, + num_numeric_vars = num_numeric_vars + ) + if (num_ordered_cat_vars > 0) { + metadata[["ordered_cat_vars"]] = ordered_cat_vars + metadata[["ordered_unique_levels"]] = ordered_unique_levels + } + if (num_unordered_cat_vars > 0) { + metadata[["unordered_cat_vars"]] = unordered_cat_vars + metadata[["unordered_unique_levels"]] = unordered_unique_levels + } + if (num_numeric_vars > 0) { + metadata[["numeric_vars"]] = numeric_vars + } + output <- list( + data = X, + metadata = metadata + ) + + return(output) } #' Preprocess a dataframe of covariate values, converting categorical variables @@ -722,75 +722,75 @@ createForestCovariates <- function( #' num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3")) #' X_preprocessed <- createForestCovariatesFromMetadata(cov_df, metadata) createForestCovariatesFromMetadata <- function(input_data, metadata) { - if (is.matrix(input_data)) { - input_df <- as.data.frame(input_data) - names(input_df) <- paste0("x", 1:ncol(input_data)) - } else if (is.data.frame(input_data)) { - input_df <- input_data - } else { - stop("input_data must be either a matrix or a data frame") - } - df_vars <- names(input_df) - num_ordered_cat_vars <- metadata$num_ordered_cat_vars - num_unordered_cat_vars <- metadata$num_unordered_cat_vars - num_numeric_vars <- metadata$num_numeric_vars - - if (num_ordered_cat_vars > 0) { - ordered_cat_vars <- metadata$ordered_cat_vars - ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + if (is.matrix(input_data)) { + input_df <- as.data.frame(input_data) + names(input_df) <- paste0("x", 1:ncol(input_data)) + } else if (is.data.frame(input_data)) { + input_df <- input_data + } else { + stop("input_data must be either a matrix or a data frame") + } + df_vars <- names(input_df) + num_ordered_cat_vars <- metadata$num_ordered_cat_vars + num_unordered_cat_vars <- metadata$num_unordered_cat_vars + num_numeric_vars <- metadata$num_numeric_vars + + if (num_ordered_cat_vars > 0) { + ordered_cat_vars <- metadata$ordered_cat_vars + ordered_cat_df <- input_df[, ordered_cat_vars, drop = FALSE] + } + if (num_unordered_cat_vars > 0) { + unordered_cat_vars <- metadata$unordered_cat_vars + unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + } + if (num_numeric_vars > 0) { + numeric_vars <- metadata$numeric_vars + numeric_df <- input_df[, numeric_vars, drop = FALSE] + } + + # Empty outputs + X <- double(0) + + # First, extract the numeric covariates + if (num_numeric_vars > 0) { + Xnum <- double(0) + for (i in 1:ncol(numeric_df)) { + stopifnot(is.numeric(numeric_df[, i])) + Xnum <- cbind(Xnum, numeric_df[, i]) } - if (num_unordered_cat_vars > 0) { - unordered_cat_vars <- metadata$unordered_cat_vars - unordered_cat_df <- input_df[, unordered_cat_vars, drop = FALSE] + X <- cbind(X, unname(Xnum)) + } + + # Next, run some simple preprocessing on the ordered categorical covariates + if (num_ordered_cat_vars > 0) { + Xordcat <- double(0) + for (i in 1:ncol(ordered_cat_df)) { + var_name <- names(ordered_cat_df)[i] + x_preprocessed <- orderedCatPreprocess( + ordered_cat_df[, i], + metadata$ordered_unique_levels[[var_name]] + ) + Xordcat <- cbind(Xordcat, x_preprocessed) } - if (num_numeric_vars > 0) { - numeric_vars <- metadata$numeric_vars - numeric_df <- input_df[, numeric_vars, drop = FALSE] + X <- cbind(X, unname(Xordcat)) + } + + # Finally, one-hot encode the unordered categorical covariates + if (num_unordered_cat_vars > 0) { + one_hot_mats <- list() + for (i in 1:ncol(unordered_cat_df)) { + var_name <- names(unordered_cat_df)[i] + Xtilde <- oneHotEncode( + unordered_cat_df[, i], + metadata$unordered_unique_levels[[var_name]] + ) + one_hot_mats[[var_name]] <- Xtilde } + Xcat <- do.call(cbind, one_hot_mats) + X <- cbind(X, unname(Xcat)) + } - # Empty outputs - X <- double(0) - - # First, extract the numeric covariates - if (num_numeric_vars > 0) { - Xnum <- double(0) - for (i in 1:ncol(numeric_df)) { - stopifnot(is.numeric(numeric_df[, i])) - Xnum <- cbind(Xnum, numeric_df[, i]) - } - X <- cbind(X, unname(Xnum)) - } - - # Next, run some simple preprocessing on the ordered categorical covariates - if (num_ordered_cat_vars > 0) { - Xordcat <- double(0) - for (i in 1:ncol(ordered_cat_df)) { - var_name <- names(ordered_cat_df)[i] - x_preprocessed <- orderedCatPreprocess( - ordered_cat_df[, i], - metadata$ordered_unique_levels[[var_name]] - ) - Xordcat <- cbind(Xordcat, x_preprocessed) - } - X <- cbind(X, unname(Xordcat)) - } - - # Finally, one-hot encode the unordered categorical covariates - if (num_unordered_cat_vars > 0) { - one_hot_mats <- list() - for (i in 1:ncol(unordered_cat_df)) { - var_name <- names(unordered_cat_df)[i] - Xtilde <- oneHotEncode( - unordered_cat_df[, i], - metadata$unordered_unique_levels[[var_name]] - ) - one_hot_mats[[var_name]] <- Xtilde - } - Xcat <- do.call(cbind, one_hot_mats) - X <- cbind(X, unname(Xcat)) - } - - return(X) + return(X) } #' Convert a vector of unordered categorical data (either numeric or character @@ -813,15 +813,15 @@ createForestCovariatesFromMetadata <- function(input_data, metadata) { #' x <- c("a","c","b","c","d","a","c","a","b","d") #' x_onehot <- oneHotInitializeAndEncode(x) oneHotInitializeAndEncode <- function(x_input) { - stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) - if (is.factor(x_input) && is.ordered(x_input)) { - warning("One-hot encoding an ordered categorical variable") - } - x_factor <- factor(x_input) - unique_levels <- levels(x_factor) - Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0) - output <- list(Xtilde = Xtilde, unique_levels = unique_levels) - return(output) + stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) + if (is.factor(x_input) && is.ordered(x_input)) { + warning("One-hot encoding an ordered categorical variable") + } + x_factor <- factor(x_input) + unique_levels <- levels(x_factor) + Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0) + output <- list(Xtilde = Xtilde, unique_levels = unique_levels) + return(output) } #' Convert a vector of unordered categorical data (either numeric or character @@ -848,34 +848,34 @@ oneHotInitializeAndEncode <- function(x_input) { #' x_test <- sample(1:9, 10, TRUE) #' x_onehot <- oneHotEncode(x_test, levels(factor(x))) oneHotEncode <- function(x_input, unique_levels) { - stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) - stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0)) - num_unique_levels <- length(unique_levels) - in_sample <- x_input %in% unique_levels - out_of_sample <- !(x_input %in% unique_levels) - has_out_of_sample <- sum(out_of_sample) > 0 - if (has_out_of_sample) { - x_factor_insample <- factor(x_input[in_sample], levels = unique_levels) - Xtilde <- matrix( - 0, - nrow = length(x_input), - ncol = num_unique_levels + 1 - ) - Xtilde_insample <- cbind( - unname(model.matrix(~ 0 + x_factor_insample)), - 0 - ) - Xtilde_out_of_sample <- cbind( - matrix(0, nrow = sum(out_of_sample), ncol = num_unique_levels), - 1 - ) - Xtilde[in_sample, ] <- Xtilde_insample - Xtilde[out_of_sample, ] <- Xtilde_out_of_sample - } else { - x_factor <- factor(x_input, levels = unique_levels) - Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0) - } - return(Xtilde) + stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) + stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0)) + num_unique_levels <- length(unique_levels) + in_sample <- x_input %in% unique_levels + out_of_sample <- !(x_input %in% unique_levels) + has_out_of_sample <- sum(out_of_sample) > 0 + if (has_out_of_sample) { + x_factor_insample <- factor(x_input[in_sample], levels = unique_levels) + Xtilde <- matrix( + 0, + nrow = length(x_input), + ncol = num_unique_levels + 1 + ) + Xtilde_insample <- cbind( + unname(model.matrix(~ 0 + x_factor_insample)), + 0 + ) + Xtilde_out_of_sample <- cbind( + matrix(0, nrow = sum(out_of_sample), ncol = num_unique_levels), + 1 + ) + Xtilde[in_sample, ] <- Xtilde_insample + Xtilde[out_of_sample, ] <- Xtilde_out_of_sample + } else { + x_factor <- factor(x_input, levels = unique_levels) + Xtilde <- cbind(unname(model.matrix(~ 0 + x_factor)), 0) + } + return(Xtilde) } #' Run some simple preprocessing of ordered categorical variables, converting @@ -897,17 +897,17 @@ oneHotEncode <- function(x_input, unique_levels) { #' preprocess_list <- orderedCatInitializeAndPreprocess(x) #' x_preprocessed <- preprocess_list$x_preprocessed orderedCatInitializeAndPreprocess <- function(x_input) { - stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) - already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input)) - if (already_ordered_factor) { - x_preprocessed <- as.integer(x_input) - unique_levels <- levels(x_input) - } else { - x_factor <- factor(x_input, ordered = TRUE) - x_preprocessed <- as.integer(x_factor) - unique_levels <- levels(x_factor) - } - return(list(x_preprocessed = x_preprocessed, unique_levels = unique_levels)) + stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) + already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input)) + if (already_ordered_factor) { + x_preprocessed <- as.integer(x_input) + unique_levels <- levels(x_input) + } else { + x_factor <- factor(x_input, ordered = TRUE) + x_preprocessed <- as.integer(x_factor) + unique_levels <- levels(x_factor) + } + return(list(x_preprocessed = x_preprocessed, unique_levels = unique_levels)) } #' Run some simple preprocessing of ordered categorical variables, converting @@ -933,56 +933,56 @@ orderedCatInitializeAndPreprocess <- function(x_input) { #' "4. Agree", "3. Neither agree nor disagree", "5. Strongly agree", "4. Agree") #' x_processed <- orderedCatPreprocess(x, x_levels) orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { - stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) - stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0)) - already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input)) - if (already_ordered_factor) { - # Run time checks - levels_not_in_reflist <- !(levels(x_input) %in% unique_levels) - if (sum(levels_not_in_reflist) > 0) { - if (!is.null(var_name)) { - warning_message <- paste0( - "Variable ", - var_name, - " includes ordered categorical levels not included in the original training set" - ) - } else { - warning_message <- paste0( - "Variable includes ordered categorical levels not included in the original training set" - ) - } - warning(warning_message) - } - # Preprocessing - x_string <- as.character(x_input) - x_factor <- factor(x_string, unique_levels, ordered = TRUE) - x_preprocessed <- as.integer(x_factor) - x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1 - } else { - x_factor <- factor(x_input, ordered = TRUE) - # Run time checks - levels_not_in_reflist <- !(levels(x_factor) %in% unique_levels) - if (sum(levels_not_in_reflist) > 0) { - if (!is.null(var_name)) { - warning_message <- paste0( - "Variable ", - var_name, - " includes ordered categorical levels not included in the original training set" - ) - } else { - warning_message <- paste0( - "Variable includes ordered categorical levels not included in the original training set" - ) - } - warning(warning_message) - } - # Preprocessing - x_string <- as.character(x_input) - x_factor <- factor(x_string, unique_levels, ordered = TRUE) - x_preprocessed <- as.integer(x_factor) - x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1 + stopifnot((is.null(dim(x_input)) && length(x_input) > 0)) + stopifnot((is.null(dim(unique_levels)) && length(unique_levels) > 0)) + already_ordered_factor <- (is.factor(x_input)) && (is.ordered(x_input)) + if (already_ordered_factor) { + # Run time checks + levels_not_in_reflist <- !(levels(x_input) %in% unique_levels) + if (sum(levels_not_in_reflist) > 0) { + if (!is.null(var_name)) { + warning_message <- paste0( + "Variable ", + var_name, + " includes ordered categorical levels not included in the original training set" + ) + } else { + warning_message <- paste0( + "Variable includes ordered categorical levels not included in the original training set" + ) + } + warning(warning_message) } - return(x_preprocessed) + # Preprocessing + x_string <- as.character(x_input) + x_factor <- factor(x_string, unique_levels, ordered = TRUE) + x_preprocessed <- as.integer(x_factor) + x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1 + } else { + x_factor <- factor(x_input, ordered = TRUE) + # Run time checks + levels_not_in_reflist <- !(levels(x_factor) %in% unique_levels) + if (sum(levels_not_in_reflist) > 0) { + if (!is.null(var_name)) { + warning_message <- paste0( + "Variable ", + var_name, + " includes ordered categorical levels not included in the original training set" + ) + } else { + warning_message <- paste0( + "Variable includes ordered categorical levels not included in the original training set" + ) + } + warning(warning_message) + } + # Preprocessing + x_string <- as.character(x_input) + x_factor <- factor(x_string, unique_levels, ordered = TRUE) + x_preprocessed <- as.integer(x_factor) + x_preprocessed[is.na(x_preprocessed)] <- length(unique_levels) + 1 + } + return(x_preprocessed) } #' Convert scalar input to vector of dimension `output_size`, @@ -993,19 +993,19 @@ orderedCatPreprocess <- function(x_input, unique_levels, var_name = NULL) { #' @return A vector of length `output_size` #' @export expand_dims_1d <- function(input, output_size) { - if (length(input) == 1) { - output <- rep(input, output_size) - } else if (is.numeric(input)) { - if (length(input) != output_size) { - stop("`input` must be a 1D numpy array with `output_size` elements") - } - output <- input - } else { - stop( - "`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times" - ) + if (length(input) == 1) { + output <- rep(input, output_size) + } else if (is.numeric(input)) { + if (length(input) != output_size) { + stop("`input` must be a 1D numpy array with `output_size` elements") } - return(output) + output <- input + } else { + stop( + "`input` must be either a 1D numpy array or a scalar that can be repeated `output_size` times" + ) + } + return(output) } #' Ensures that input is propagated appropriately to a matrix of dimension `output_rows` x `output_cols`. @@ -1022,41 +1022,41 @@ expand_dims_1d <- function(input, output_size) { #' @return A matrix of dimension `output_rows` x `output_cols` #' @export expand_dims_2d <- function(input, output_rows, output_cols) { - if (length(input) == 1) { - output <- matrix( - rep(input, output_rows * output_cols), - ncol = output_cols - ) - } else if (is.numeric(input)) { - if (length(input) == output_cols) { - output <- matrix( - rep(input, output_rows), - nrow = output_rows, - byrow = T - ) - } else if (length(input) == output_rows) { - output <- matrix( - rep(input, output_cols), - ncol = output_cols, - byrow = F - ) - } else { - stop( - "If `input` is a vector, it must either contain `output_rows` or `output_cols` elements" - ) - } - } else if (is.matrix(input)) { - if (nrow(input) != output_rows) { - stop("`input` must be a matrix with `output_rows` rows") - } - if (ncol(input) != output_cols) { - stop("`input` must be a matrix with `output_cols` columns") - } - output <- input + if (length(input) == 1) { + output <- matrix( + rep(input, output_rows * output_cols), + ncol = output_cols + ) + } else if (is.numeric(input)) { + if (length(input) == output_cols) { + output <- matrix( + rep(input, output_rows), + nrow = output_rows, + byrow = T + ) + } else if (length(input) == output_rows) { + output <- matrix( + rep(input, output_cols), + ncol = output_cols, + byrow = F + ) } else { - stop("`input` must be either a matrix, vector or a scalar") + stop( + "If `input` is a vector, it must either contain `output_rows` or `output_cols` elements" + ) + } + } else if (is.matrix(input)) { + if (nrow(input) != output_rows) { + stop("`input` must be a matrix with `output_rows` rows") } - return(output) + if (ncol(input) != output_cols) { + stop("`input` must be a matrix with `output_cols` columns") + } + output <- input + } else { + stop("`input` must be either a matrix, vector or a scalar") + } + return(output) } #' Convert scalar input to square matrix of dimension `output_size` x `output_size` with `input` along the diagonal, @@ -1067,20 +1067,20 @@ expand_dims_2d <- function(input, output_rows, output_cols) { #' @return A square matrix of dimension `output_size` x `output_size` #' @export expand_dims_2d_diag <- function(input, output_size) { - if (length(input) == 1) { - output <- as.matrix(diag(input, output_size)) - } else if (is.matrix(input)) { - if (nrow(input) != ncol(input)) { - stop("`input` must be a square matrix") - } - if (nrow(input) != output_size) { - stop( - "`input` must be a square matrix with `output_size` rows and columns" - ) - } - output <- input - } else { - stop("`input` must be either a square matrix or a scalar") + if (length(input) == 1) { + output <- as.matrix(diag(input, output_size)) + } else if (is.matrix(input)) { + if (nrow(input) != ncol(input)) { + stop("`input` must be a square matrix") + } + if (nrow(input) != output_size) { + stop( + "`input` must be a square matrix with `output_size` rows and columns" + ) } - return(output) + output <- input + } else { + stop("`input` must be either a square matrix or a scalar") + } + return(output) } diff --git a/R/variance.R b/R/variance.R index b0ad722e..11c6c326 100644 --- a/R/variance.R +++ b/R/variance.R @@ -19,19 +19,19 @@ #' b <- 1.0 #' sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset, rng, a, b) sampleGlobalErrorVarianceOneIteration <- function( - residual, - dataset, - rng, + residual, + dataset, + rng, + a, + b +) { + return(sample_sigma2_one_iteration_cpp( + residual$data_ptr, + dataset$data_ptr, + rng$rng_ptr, a, b -) { - return(sample_sigma2_one_iteration_cpp( - residual$data_ptr, - dataset$data_ptr, - rng$rng_ptr, - a, - b - )) + )) } #' Sample one iteration of the leaf parameter variance model (only for univariate basis and constant leaf!) @@ -54,5 +54,5 @@ sampleGlobalErrorVarianceOneIteration <- function( #' b <- 1.0 #' tau <- sampleLeafVarianceOneIteration(active_forest, rng, a, b) sampleLeafVarianceOneIteration <- function(forest, rng, a, b) { - return(sample_tau_one_iteration_cpp(forest$forest_ptr, rng$rng_ptr, a, b)) + return(sample_tau_one_iteration_cpp(forest$forest_ptr, rng$rng_ptr, a, b)) } From 784fac9fb6cc5ac4f77bbbb3d34d238436b6d0e0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 23 Oct 2025 15:36:58 -0500 Subject: [PATCH 48/53] Updated python unit tests --- test/python/test_bcf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index e03631fa..d9d5ee5c 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -884,7 +884,7 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - rfx_params=rfx_params + random_effects_params=rfx_params ) # Specify all relevant rfx parameters as vectors @@ -912,5 +912,5 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - rfx_params=rfx_params + random_effects_params=rfx_params ) From 8b2d62e372cc7d874ae69c2a32fb4ce677dd705b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 29 Oct 2025 17:41:59 -0500 Subject: [PATCH 49/53] Added support for random intercept specification in python BART --- stochtree/bart.py | 99 ++++++++++++++++++++++++++++++++++++++--------- stochtree/bcf.py | 39 ++++++++++++------- 2 files changed, 105 insertions(+), 33 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index 0eb22d87..bfd81ac2 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -85,7 +85,7 @@ def sample( general_params: Optional[Dict[str, Any]] = None, mean_forest_params: Optional[Dict[str, Any]] = None, variance_forest_params: Optional[Dict[str, Any]] = None, - rfx_params: Optional[Dict[str, Any]] = None, + random_effects_params: Optional[Dict[str, Any]] = None, previous_model_json: Optional[str] = None, previous_model_warmstart_sample_num: Optional[int] = None, ) -> None: @@ -170,9 +170,10 @@ def sample( * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. - rfx_params : dict, optional + random_effects_params : dict, optional Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. + * `model_spec`: Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. * `working_parameter_prior_mean`: Prior mean for the random effects "working parameter". Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. * `group_parameter_prior_mean`: Prior mean for the random effects "group parameters." Default: `None`. Must be a 1D numpy array whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. * `working_parameter_prior_cov`: Prior covariance matrix for the random effects "working parameter." Default: `None`. Must be a square numpy matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. @@ -251,6 +252,7 @@ def sample( # Update random effects parameters rfx_params_default = { + "model_spec": "custom", "working_parameter_prior_mean": None, "group_parameter_prior_mean": None, "working_parameter_prior_cov": None, @@ -258,7 +260,7 @@ def sample( "variance_prior_shape": 1.0, "variance_prior_scale": 1.0, } - rfx_params_updated = _preprocess_params(rfx_params_default, rfx_params) + rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params) ### Unpack all parameter values # 1. General parameters @@ -312,6 +314,7 @@ def sample( ] # 4. Random effects parameters + self.rfx_model_spec = rfx_params_updated["model_spec"] rfx_working_parameter_prior_mean = rfx_params_updated[ "working_parameter_prior_mean" ] @@ -325,6 +328,12 @@ def sample( rfx_variance_prior_shape = rfx_params_updated["variance_prior_shape"] rfx_variance_prior_scale = rfx_params_updated["variance_prior_scale"] + # Check random effects specification + if not isinstance(self.rfx_model_spec, str): + raise ValueError("rfx_model_spec must be a string") + if self.rfx_model_spec not in ["custom", "intercept_only"]: + raise ValueError("type must either be 'custom' or 'intercept_only'") + # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: keep_gfr = True @@ -980,24 +989,35 @@ def sample( "All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train" ) - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided - has_basis_rfx = False + # Handle the rfx basis matrices + self.has_rfx_basis = False + self.num_rfx_basis = 0 if self.has_rfx: - if rfx_basis_train is None: - rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) - else: - has_basis_rfx = True + if self.rfx_model_spec == "custom": + if rfx_basis_train is None: + raise ValueError( + "rfx_basis_train must be provided when rfx_model_spec = 'custom'" + ) + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_train is None: + rfx_basis_train = np.ones((rfx_group_ids_train.shape[0], 1)) + self.has_rfx_basis = True + self.num_rfx_basis = rfx_basis_train.shape[1] num_rfx_groups = np.unique(rfx_group_ids_train).shape[0] num_rfx_components = rfx_basis_train.shape[1] - # TODO warn if num_rfx_groups is 1 + if num_rfx_groups == 1: + warnings.warn( + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" + ) if has_rfx_test: - if rfx_basis_test is None: - if has_basis_rfx: + if self.rfx_model_spec == "custom": + if rfx_basis_test is None: raise ValueError( - "Random effects basis provided for training set, must also be provided for the test set" + "rfx_basis_test must be provided when rfx_model_spec = 'custom' and a test set is provided" ) - rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) - + elif self.rfx_model_spec == "intercept_only": + if rfx_basis_test is None: + rfx_basis_test = np.ones((rfx_group_ids_test.shape[0], 1)) # Set up random effects structures if self.has_rfx: # Prior parameters @@ -1676,6 +1696,8 @@ def predict( predict_mean = type == "mean" # Handle prediction terms + rfx_model_spec = self.rfx_model_spec + rfx_intercept = rfx_model_spec == "intercept_only" if not isinstance(terms, str) and not isinstance(terms, list): raise ValueError("type must be a string or list of strings") num_terms = 1 if isinstance(terms, str) else len(terms) @@ -1801,12 +1823,51 @@ def predict( ) mean_forest_predictions = mean_pred_raw * self.y_std + self.y_bar + # Random effects data checks + if has_rfx: + if rfx_group_ids is None: + raise ValueError( + "rfx_group_ids must be provided if rfx_basis is provided" + ) + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) + if rfx_basis.shape[0] != covariates.shape[0]: + raise ValueError("X and rfx_basis must have the same number of rows") + if rfx_basis.shape[1] != self.num_rfx_basis: + raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model") + # Random effects predictions if predict_rfx or predict_rfx_intermediate: - rfx_predictions = ( - self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std - ) - + if rfx_basis is not None: + rfx_predictions = ( + self.rfx_container.predict(rfx_group_ids, rfx_basis) * self.y_std + ) + else: + # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only" + if not rfx_intercept: + raise ValueError( + "rfx_basis must be provided for random effects models with random slopes" + ) + + # Extract the raw RFX samples and scale by train set outcome standard deviation + rfx_samples_raw = self.rfx_container.extract_parameter_samples() + rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std + + # Construct an array with the appropriate group random effects arranged for each observation + n_train = covariates.shape[0] + if rfx_beta_draws.ndim != 2: + raise ValueError( + "BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices" + ) + else: + rfx_predictions_raw = np.empty(shape=(n_train, 1, rfx_beta_draws.shape[1])) + for i in range(n_train): + rfx_predictions_raw[i, 0, :] = rfx_beta_draws[ + rfx_group_ids[i], : + ] + rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :]) + # Combine into y hat predictions if probability_scale: if predict_y_hat and has_mean_forest and has_rfx: diff --git a/stochtree/bcf.py b/stochtree/bcf.py index 7e67a11e..c3105e7a 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -208,7 +208,7 @@ def sample( * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored. * `num_features_subsample` (`int`): How many features to subsample when growing each tree for the GFR algorithm. Defaults to the number of features in the training dataset. - rfx_params : dict, optional + random_effects_params : dict, optional Dictionary of random effects parameters, each of which has a default value processed internally, so this argument is optional. * `model_spec`: Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. @@ -2504,13 +2504,13 @@ def predict( raise ValueError( "rfx_group_ids must be provided if rfx_basis is provided" ) - if rfx_basis is not None: - if rfx_basis.ndim == 1: - rfx_basis = np.expand_dims(rfx_basis, 1) - if rfx_basis.shape[0] != X.shape[0]: - raise ValueError("X and rfx_basis must have the same number of rows") - if rfx_basis.shape[1] != self.num_rfx_basis: - raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model") + if rfx_basis is not None: + if rfx_basis.ndim == 1: + rfx_basis = np.expand_dims(rfx_basis, 1) + if rfx_basis.shape[0] != X.shape[0]: + raise ValueError("X and rfx_basis must have the same number of rows") + if rfx_basis.shape[1] != self.num_rfx_basis: + raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model") # Random effects predictions if predict_rfx or predict_rfx_intermediate: @@ -2524,12 +2524,23 @@ def predict( rfx_samples_raw = self.rfx_container.extract_parameter_samples() rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std - # Construct a matrix with the appropriate group random effects arranged for each observation - rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2])) - for i in range(X.shape[0]): - rfx_predictions_raw[i, :, :] = rfx_beta_draws[ - :, rfx_group_ids[i], : - ] + # Construct an array with the appropriate group random effects arranged for each observation + if rfx_beta_draws.ndim == 3: + rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2])) + for i in range(X.shape[0]): + rfx_predictions_raw[i, :, :] = rfx_beta_draws[ + :, rfx_group_ids[i], : + ] + elif rfx_beta_draws.ndim == 2: + rfx_predictions_raw = np.empty(shape=(X.shape[0], 1, rfx_beta_draws.shape[1])) + for i in range(X.shape[0]): + rfx_predictions_raw[i, 0, :] = rfx_beta_draws[ + rfx_group_ids[i], : + ] + else: + raise ValueError( + "Unexpected number of dimensions in extracted random effects samples" + ) # Add raw RFX predictions to mu and tau if warranted by the RFX model spec if predict_mu_forest or predict_mu_forest_intermediate: From 3c2d4ceffcbe98e0033d7653cf43f01b7a840e69 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 29 Oct 2025 18:06:22 -0500 Subject: [PATCH 50/53] Updated serialization methods to reflect python BART updates --- stochtree/bart.py | 11 +++++++++++ test/python/test_bart.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/stochtree/bart.py b/stochtree/bart.py index bfd81ac2..b1dcf576 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -2386,6 +2386,8 @@ def to_json(self) -> str: bart_json.add_boolean("include_mean_forest", self.include_mean_forest) bart_json.add_boolean("include_variance_forest", self.include_variance_forest) bart_json.add_boolean("has_rfx", self.has_rfx) + bart_json.add_boolean("has_rfx_basis", self.has_rfx_basis) + bart_json.add_scalar("num_rfx_basis", self.num_rfx_basis) bart_json.add_integer("num_gfr", self.num_gfr) bart_json.add_integer("num_burnin", self.num_burnin) bart_json.add_integer("num_mcmc", self.num_mcmc) @@ -2393,6 +2395,7 @@ def to_json(self) -> str: bart_json.add_integer("num_basis", self.num_basis) bart_json.add_boolean("requires_basis", self.has_basis) bart_json.add_boolean("probit_outcome_model", self.probit_outcome_model) + bart_json.add_string("rfx_model_spec", self.rfx_model_spec) # Add parameter samples if self.sample_sigma2_global: @@ -2427,6 +2430,8 @@ def from_json(self, json_string: str) -> None: self.include_mean_forest = bart_json.get_boolean("include_mean_forest") self.include_variance_forest = bart_json.get_boolean("include_variance_forest") self.has_rfx = bart_json.get_boolean("has_rfx") + self.has_rfx_basis = bart_json.get_boolean("has_rfx_basis") + self.num_rfx_basis = bart_json.get_scalar("num_rfx_basis") if self.include_mean_forest: # TODO: don't just make this a placeholder that we overwrite self.forest_container_mean = ForestContainer(0, 0, False, False) @@ -2465,6 +2470,7 @@ def from_json(self, json_string: str) -> None: self.num_basis = bart_json.get_integer("num_basis") self.has_basis = bart_json.get_boolean("requires_basis") self.probit_outcome_model = bart_json.get_boolean("probit_outcome_model") + self.rfx_model_spec = bart_json.get_string("rfx_model_spec") # Unpack parameter samples if self.sample_sigma2_global: @@ -2550,6 +2556,8 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: # Unpack random effects self.has_rfx = json_object_default.get_boolean("has_rfx") + self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") + self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -2575,6 +2583,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.probit_outcome_model = json_object_default.get_boolean( "probit_outcome_model" ) + self.rfx_model_spec = json_object_default.get_string( + "rfx_model_spec" + ) # Unpack number of samples for i in range(len(json_object_list)): diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 6e26d84e..0156f9e9 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -1123,6 +1123,7 @@ def conditional_stddev(X): # Specify scalar rfx parameters rfx_params = { + "model_spec": "custom", "working_parameter_prior_mean": 1., "group_parameter_prior_mean": 1., "working_parameter_prior_cov": 1., @@ -1144,11 +1145,12 @@ def conditional_stddev(X): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - rfx_params=rfx_params, + random_effects_params=rfx_params, ) # Specify all relevant rfx parameters as vectors rfx_params = { + "model_spec": "custom", "working_parameter_prior_mean": np.repeat(1., num_rfx_basis), "group_parameter_prior_mean": np.repeat(1., num_rfx_basis), "working_parameter_prior_cov": np.identity(num_rfx_basis), @@ -1170,5 +1172,32 @@ def conditional_stddev(X): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - rfx_params=rfx_params, + random_effects_params=rfx_params, ) + + # Fit a simpler intercept-only RFX model + rfx_params = { + "model_spec": "intercept_only" + } + bart_model_4 = BARTModel() + bart_model_4.sample( + X_train=X_train, + y_train=y_train, + leaf_basis_train=basis_train, + rfx_group_ids_train=group_labels_train, + X_test=X_test, + leaf_basis_test=basis_test, + rfx_group_ids_test=group_labels_test, + num_gfr=num_gfr, + num_burnin=num_burnin, + num_mcmc=num_mcmc, + random_effects_params=rfx_params, + ) + preds = bart_model_4.predict( + covariates=X_test, + basis=basis_test, + rfx_group_ids=group_labels_test, + type="posterior", + terms="rfx" + ) + assert preds.shape == (n_test, num_mcmc) From 5d6535ca27c7cd274905301b6a14dbb14fc9c6dc Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 29 Oct 2025 19:03:02 -0500 Subject: [PATCH 51/53] Updated R BART to include intercept only RFX specification --- R/bart.R | 120 ++++++++++++++++++++++++++++-------- man/RandomEffectSamples.Rd | 2 +- man/bart.Rd | 1 + man/bcf.Rd | 1 + man/predict.bcfmodel.Rd | 2 +- test/R/testthat/test-bart.R | 23 +++++++ 6 files changed, 123 insertions(+), 26 deletions(-) diff --git a/R/bart.R b/R/bart.R index 4369c273..3fd6bb86 100644 --- a/R/bart.R +++ b/R/bart.R @@ -79,6 +79,7 @@ #' #' @param random_effects_params (Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. #' +#' - `model_spec` Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through `rfx_basis_train`. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (`Z_train`) will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, `rfx_basis_train` and `rfx_basis_test` (if provided) will be ignored. #' - `working_parameter_prior_mean` Prior mean for the random effects "working parameter". Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. #' - `group_parameters_prior_mean` Prior mean for the random effects "group parameters." Default: `NULL`. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. #' - `working_parameter_prior_cov` Prior covariance matrix for the random effects "working parameter." Default: `NULL`. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. @@ -198,6 +199,7 @@ bart <- function( # Update rfx parameters rfx_params_default <- list( + model_spec = "custom", working_parameter_prior_mean = NULL, group_parameter_prior_mean = NULL, working_parameter_prior_cov = NULL, @@ -257,6 +259,7 @@ bart <- function( num_features_subsample_variance <- variance_forest_params_updated$num_features_subsample # 4. RFX parameters + rfx_model_spec <- rfx_params_updated$model_spec rfx_working_parameter_prior_mean <- rfx_params_updated$working_parameter_prior_mean rfx_group_parameter_prior_mean <- rfx_params_updated$group_parameter_prior_mean rfx_working_parameter_prior_cov <- rfx_params_updated$working_parameter_prior_cov @@ -614,35 +617,43 @@ bart <- function( } } - # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided + # Handle the rfx basis matrices has_basis_rfx <- FALSE num_basis_rfx <- 0 if (has_rfx) { - if (is.null(rfx_basis_train)) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_train)) { + stop( + "A user-provided basis (`rfx_basis_train`) must be provided when the random effects model spec is 'custom'" + ) + } + has_basis_rfx <- TRUE + num_basis_rfx <- ncol(rfx_basis_train) + } else if (rfx_model_spec == "intercept_only") { rfx_basis_train <- matrix( rep(1, nrow(X_train)), nrow = nrow(X_train), ncol = 1 ) - } else { has_basis_rfx <- TRUE - num_basis_rfx <- ncol(rfx_basis_train) + num_basis_rfx <- 1 } num_rfx_groups <- length(unique(rfx_group_ids_train)) num_rfx_components <- ncol(rfx_basis_train) if (num_rfx_groups == 1) { warning( - "Only one group was provided for random effect sampling, so the 'redundant parameterization' is likely overkill" + "Only one group was provided for random effect sampling, so the random effects model is likely overkill" ) } } if (has_rfx_test) { - if (is.null(rfx_basis_test)) { - if (has_basis_rfx) { + if (rfx_model_spec == "custom") { + if (is.null(rfx_basis_test)) { stop( - "Random effects basis provided for training set, must also be provided for the test set" + "A user-provided basis (`rfx_basis_test`) must be provided when the random effects model spec is 'custom'" ) } + } else if (rfx_model_spec == "intercept_only") { rfx_basis_test <- matrix( rep(1, nrow(X_test)), nrow = nrow(X_test), @@ -1744,7 +1755,8 @@ bart <- function( "sample_sigma2_leaf" = sample_sigma2_leaf, "include_mean_forest" = include_mean_forest, "include_variance_forest" = include_variance_forest, - "probit_outcome_model" = probit_outcome_model + "probit_outcome_model" = probit_outcome_model, + "rfx_model_spec" = rfx_model_spec ) result <- list( "model_params" = model_params, @@ -1878,6 +1890,8 @@ predict.bartmodel <- function( predict_mean <- type == "mean" # Handle prediction terms + rfx_model_spec <- object$model_params$rfx_model_spec + rfx_intercept <- rfx_model_spec == "intercept_only" if (!is.character(terms)) { stop("type must be a string or character vector") } @@ -1954,16 +1968,17 @@ predict.bartmodel <- function( "Random effect group labels (rfx_group_ids) must be provided for this model" ) } - if ((predict_rfx) && (is.null(rfx_basis))) { + if ((predict_rfx) && (is.null(rfx_basis)) && (!rfx_intercept)) { stop("Random effects basis (rfx_basis) must be provided for this model") } if ( - (object$model_params$num_rfx_basis > 0) && - (ncol(rfx_basis) != object$model_params$num_rfx_basis) + (object$model_params$num_rfx_basis > 0) && (!rfx_intercept) ) { - stop( - "Random effects basis has a different dimension than the basis used to train this model" - ) + if (ncol(rfx_basis) != object$model_params$num_rfx_basis) { + stop( + "Random effects basis has a different dimension than the basis used to train this model" + ) + } } # Preprocess covariates @@ -1986,11 +2001,26 @@ predict.bartmodel <- function( } } - # Produce basis for the "intercept-only" random effects case - if ((predict_rfx) && (is.null(rfx_basis))) { - rfx_basis <- matrix(rep(1, nrow(covariates)), ncol = 1) + # Handle RFX model specification + if (has_rfx) { + if (object$model_params$rfx_model_spec == "custom") { + if (is.null(rfx_basis)) { + stop( + "A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'" + ) + } + } else if (object$model_params$rfx_model_spec == "intercept_only") { + # Only construct a basis if user-provided basis missing + if (is.null(rfx_basis)) { + rfx_basis <- matrix( + rep(1, nrow(covariates)), + nrow = nrow(covariates), + ncol = 1 + ) + } + } } - + # Create prediction dataset if (!is.null(leaf_basis)) { prediction_dataset <- createForestDataset(covariates, leaf_basis) @@ -2033,11 +2063,40 @@ predict.bartmodel <- function( # Compute rfx predictions (if needed) if (predict_rfx || predict_rfx_intermediate) { - rfx_predictions <- object$rfx_samples$predict( - rfx_group_ids, - rfx_basis - ) * - y_std + if (!is.null(rfx_basis)) { + rfx_predictions <- object$rfx_samples$predict( + rfx_group_ids, + rfx_basis + ) * + y_std + } else { + # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only" + if (!rfx_intercept) { + stop("rfx_basis must be provided for random effects models with random slopes") + } + + # Extract the raw RFX samples and scale by train set outcome standard deviation + rfx_param_list <- object$rfx_samples$extract_parameter_samples() + rfx_beta_draws <- rfx_param_list$beta_samples * y_std + + # Construct a matrix with the appropriate group random effects arranged for each observation + rfx_predictions_raw <- array( + NA, + dim = c( + nrow(X), + ncol(rfx_basis), + object$model_params$num_samples + ) + ) + for (i in 1:nrow(X)) { + rfx_predictions_raw[i, , ] <- + rfx_beta_draws[, rfx_group_ids[i], ] + } + + # Intercept-only model, so the random effect prediction is simply the + # value of the respective group's intercept coefficient for each observation + rfx_predictions = rfx_predictions_raw[, 1, ] + } } # Combine into y hat predictions @@ -2310,6 +2369,10 @@ saveBARTModelToJson <- function(object) { "probit_outcome_model", object$model_params$probit_outcome_model ) + jsonobj$add_string( + "rfx_model_spec", + object$model_params$rfx_model_spec + ) if (object$model_params$sample_sigma2_global) { jsonobj$add_vector( "sigma2_global_samples", @@ -2554,6 +2617,9 @@ createBARTModelFromJson <- function(json_object) { model_params[["probit_outcome_model"]] <- json_object$get_boolean( "probit_outcome_model" ) + model_params[["rfx_model_spec"]] <- json_object$get_string( + "rfx_model_spec" + ) output[["model_params"]] <- model_params @@ -2825,6 +2891,9 @@ createBARTModelFromCombinedJson <- function(json_object_list) { model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( "probit_outcome_model" ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) model_params[["num_chains"]] <- json_object_default$get_scalar("num_chains") model_params[["keep_every"]] <- json_object_default$get_scalar("keep_every") @@ -3066,6 +3135,9 @@ createBARTModelFromCombinedJsonString <- function(json_string_list) { model_params[["probit_outcome_model"]] <- json_object_default$get_boolean( "probit_outcome_model" ) + model_params[["rfx_model_spec"]] <- json_object_default$get_string( + "rfx_model_spec" + ) # Combine values that are sample-specific for (i in 1:length(json_object_list)) { diff --git a/man/RandomEffectSamples.Rd b/man/RandomEffectSamples.Rd index ae5e9ac0..beb85a8b 100644 --- a/man/RandomEffectSamples.Rd +++ b/man/RandomEffectSamples.Rd @@ -240,7 +240,7 @@ and sigma (group-independent prior variance for each component of xi). \subsection{Returns}{ List of arrays. The alpha array has dimension (\code{num_components}, \code{num_samples}) and is simply a vector if \code{num_components = 1}. -The xi and beta arrays have dimension (\code{num_components}, \code{num_groups}, \code{num_samples}) and is simply a matrix if \code{num_components = 1}. +The xi and beta arrays have dimension (\code{num_components}, \code{num_groups}, \code{num_samples}) and are simply matrices if \code{num_components = 1}. The sigma array has dimension (\code{num_components}, \code{num_samples}) and is simply a vector if \code{num_components = 1}. } } diff --git a/man/bart.Rd b/man/bart.Rd index 73600fd4..c76ec963 100644 --- a/man/bart.Rd +++ b/man/bart.Rd @@ -122,6 +122,7 @@ that were not in the training set.} \item{random_effects_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ +\item \code{model_spec} Specification of the random effects model. Options are "custom" and "intercept_only". If "custom" is specified, then a user-provided basis must be passed through \code{rfx_basis_train}. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (\code{Z_train}) will be dispatched internally at sampling and prediction time. Default: "custom". If "intercept_only" is specified, \code{rfx_basis_train} and \code{rfx_basis_test} (if provided) will be ignored. \item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. \item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. \item \code{working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. diff --git a/man/bcf.Rd b/man/bcf.Rd index ca3b2983..55e5e181 100644 --- a/man/bcf.Rd +++ b/man/bcf.Rd @@ -148,6 +148,7 @@ that were not in the training set.} \item{random_effects_params}{(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional. \itemize{ +\item \code{model_spec} Specification of the random effects model. Options are "custom", "intercept_only", and "intercept_plus_treatment". If "custom" is specified, then a user-provided basis must be passed through \code{rfx_basis_train}. If "intercept_only" is specified, a random effects basis of all ones will be dispatched internally at sampling and prediction time. If "intercept_plus_treatment" is specified, a random effects basis that combines an "intercept" basis of all ones with the treatment variable (\code{Z_train}) will be dispatched internally at sampling and prediction time. Default: "custom". If either "intercept_only" or "intercept_plus_treatment" is specified, \code{rfx_basis_train} and \code{rfx_basis_test} (if provided) will be ignored. \item \code{working_parameter_prior_mean} Prior mean for the random effects "working parameter". Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. \item \code{group_parameters_prior_mean} Prior mean for the random effects "group parameters." Default: \code{NULL}. Must be a vector whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a vector. \item \code{working_parameter_prior_cov} Prior covariance matrix for the random effects "working parameter." Default: \code{NULL}. Must be a square matrix whose dimension matches the number of random effects bases, or a scalar value that will be expanded to a diagonal matrix. diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index 7e8d6e0a..bda63aa5 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -30,7 +30,7 @@ We do not currently support (but plan to in the near future), test set evaluation for group labels that were not in the training set.} -\item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model.} +\item{rfx_basis}{(Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects \code{model_spec} of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used.} \item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index c3f923aa..59ad2504 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -561,4 +561,27 @@ test_that("Random Effects BART", { random_effects_params = rfx_param_list ) ) + + # Specify simpler intercept-only RFX model + rfx_param_list <- list( + model_spec = "intercept_only" + ) + mean_forest_param_list <- list(sample_sigma2_leaf = FALSE) + expect_no_error({ + bart_model <- bart( + X_train = X_train, + y_train = y_train, + X_test = X_test, + leaf_basis_train = W_train, + leaf_basis_test = W_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + num_gfr = 0, + num_burnin = 10, + num_mcmc = 10, + mean_forest_params = mean_forest_param_list, + random_effects_params = rfx_param_list + ) + preds <- predict(bart_model, covariates = X_test, leaf_basis = W_test, rfx_group_ids = rfx_group_ids_test, type = "posterior", terms = "rfx") + }) }) From c00e48069510986c8279d32ce62c08bccad275cd Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 29 Oct 2025 19:19:45 -0500 Subject: [PATCH 52/53] Reformatting R and Python code --- R/bart.R | 12 ++++++------ stochtree/bart.py | 28 ++++++++++++++++----------- stochtree/bcf.py | 49 ++++++++++++++++++++++++++++++----------------- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/R/bart.R b/R/bart.R index 3fd6bb86..4f152ba2 100644 --- a/R/bart.R +++ b/R/bart.R @@ -1971,9 +1971,7 @@ predict.bartmodel <- function( if ((predict_rfx) && (is.null(rfx_basis)) && (!rfx_intercept)) { stop("Random effects basis (rfx_basis) must be provided for this model") } - if ( - (object$model_params$num_rfx_basis > 0) && (!rfx_intercept) - ) { + if ((object$model_params$num_rfx_basis > 0) && (!rfx_intercept)) { if (ncol(rfx_basis) != object$model_params$num_rfx_basis) { stop( "Random effects basis has a different dimension than the basis used to train this model" @@ -2020,7 +2018,7 @@ predict.bartmodel <- function( } } } - + # Create prediction dataset if (!is.null(leaf_basis)) { prediction_dataset <- createForestDataset(covariates, leaf_basis) @@ -2072,7 +2070,9 @@ predict.bartmodel <- function( } else { # Sanity check -- this branch should only occur if rfx_model_spec == "intercept_only" if (!rfx_intercept) { - stop("rfx_basis must be provided for random effects models with random slopes") + stop( + "rfx_basis must be provided for random effects models with random slopes" + ) } # Extract the raw RFX samples and scale by train set outcome standard deviation @@ -2093,7 +2093,7 @@ predict.bartmodel <- function( rfx_beta_draws[, rfx_group_ids[i], ] } - # Intercept-only model, so the random effect prediction is simply the + # Intercept-only model, so the random effect prediction is simply the # value of the respective group's intercept coefficient for each observation rfx_predictions = rfx_predictions_raw[, 1, ] } diff --git a/stochtree/bart.py b/stochtree/bart.py index b1dcf576..3f81c531 100644 --- a/stochtree/bart.py +++ b/stochtree/bart.py @@ -260,7 +260,9 @@ def sample( "variance_prior_shape": 1.0, "variance_prior_scale": 1.0, } - rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params) + rfx_params_updated = _preprocess_params( + rfx_params_default, random_effects_params + ) ### Unpack all parameter values # 1. General parameters @@ -1459,7 +1461,9 @@ def sample( forest_dataset_train ) if self.has_rfx: - rfx_pred = rfx_model.predict(rfx_dataset_train, rfx_tracker) + rfx_pred = rfx_model.predict( + rfx_dataset_train, rfx_tracker + ) outcome_pred = outcome_pred + rfx_pred mu0 = outcome_pred[y_train[:, 0] == 0] mu1 = outcome_pred[y_train[:, 0] == 1] @@ -1835,8 +1839,10 @@ def predict( if rfx_basis.shape[0] != covariates.shape[0]: raise ValueError("X and rfx_basis must have the same number of rows") if rfx_basis.shape[1] != self.num_rfx_basis: - raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model") - + raise ValueError( + "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + ) + # Random effects predictions if predict_rfx or predict_rfx_intermediate: if rfx_basis is not None: @@ -1849,10 +1855,10 @@ def predict( raise ValueError( "rfx_basis must be provided for random effects models with random slopes" ) - + # Extract the raw RFX samples and scale by train set outcome standard deviation rfx_samples_raw = self.rfx_container.extract_parameter_samples() - rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std + rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std # Construct an array with the appropriate group random effects arranged for each observation n_train = covariates.shape[0] @@ -1861,13 +1867,15 @@ def predict( "BART models fit with random intercept models should only yield 2 dimensional random effect sample matrices" ) else: - rfx_predictions_raw = np.empty(shape=(n_train, 1, rfx_beta_draws.shape[1])) + rfx_predictions_raw = np.empty( + shape=(n_train, 1, rfx_beta_draws.shape[1]) + ) for i in range(n_train): rfx_predictions_raw[i, 0, :] = rfx_beta_draws[ rfx_group_ids[i], : ] rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :]) - + # Combine into y hat predictions if probability_scale: if predict_y_hat and has_mean_forest and has_rfx: @@ -2583,9 +2591,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.probit_outcome_model = json_object_default.get_boolean( "probit_outcome_model" ) - self.rfx_model_spec = json_object_default.get_string( - "rfx_model_spec" - ) + self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") # Unpack number of samples for i in range(len(json_object_list)): diff --git a/stochtree/bcf.py b/stochtree/bcf.py index c3105e7a..be4410cc 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -317,7 +317,9 @@ def sample( "variance_prior_shape": 1.0, "variance_prior_scale": 1.0, } - rfx_params_updated = _preprocess_params(rfx_params_default, random_effects_params) + rfx_params_updated = _preprocess_params( + rfx_params_default, random_effects_params + ) ### Unpack all parameter values # 1. General parameters @@ -413,8 +415,14 @@ def sample( # Check random effects specification if not isinstance(self.rfx_model_spec, str): raise ValueError("rfx_model_spec must be a string") - if self.rfx_model_spec not in ["custom", "intercept_only", "intercept_plus_treatment"]: - raise ValueError("type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'") + if self.rfx_model_spec not in [ + "custom", + "intercept_only", + "intercept_plus_treatment", + ]: + raise ValueError( + "type must either be 'custom', 'intercept_only', 'intercept_plus_treatment'" + ) # Override keep_gfr if there are no MCMC samples if num_mcmc == 0: @@ -2295,7 +2303,7 @@ def predict( ) -> Union[dict[str, np.array], np.array]: """Predict outcome model components (CATE function and prognostic function) as well as overall outcome for every provided observation. Predicted outcomes are computed as `yhat = mu_x + Z*tau_x` where mu_x is a sample of the prognostic function and tau_x is a sample of the treatment effect (CATE) function. - When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or + When random effects are present, they are either included in yhat additively if `rfx_model_spec == "custom"`. They are included in mu_x if `rfx_model_spec == "intercept_only"` or partially included in mu_x and partially included in tau_x `rfx_model_spec == "intercept_plus_treatment"`. Parameters @@ -2508,9 +2516,13 @@ def predict( if rfx_basis.ndim == 1: rfx_basis = np.expand_dims(rfx_basis, 1) if rfx_basis.shape[0] != X.shape[0]: - raise ValueError("X and rfx_basis must have the same number of rows") + raise ValueError( + "X and rfx_basis must have the same number of rows" + ) if rfx_basis.shape[1] != self.num_rfx_basis: - raise ValueError("rfx_basis must have the same number of columns as the random effects basis used to sample this model") + raise ValueError( + "rfx_basis must have the same number of columns as the random effects basis used to sample this model" + ) # Random effects predictions if predict_rfx or predict_rfx_intermediate: @@ -2522,26 +2534,28 @@ def predict( if predict_rfx_raw: # Extract the raw RFX samples and scale by train set outcome standard deviation rfx_samples_raw = self.rfx_container.extract_parameter_samples() - rfx_beta_draws = rfx_samples_raw['beta_samples'] * self.y_std + rfx_beta_draws = rfx_samples_raw["beta_samples"] * self.y_std # Construct an array with the appropriate group random effects arranged for each observation if rfx_beta_draws.ndim == 3: - rfx_predictions_raw = np.empty(shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2])) + rfx_predictions_raw = np.empty( + shape=(X.shape[0], rfx_beta_draws.shape[0], rfx_beta_draws.shape[2]) + ) for i in range(X.shape[0]): rfx_predictions_raw[i, :, :] = rfx_beta_draws[ :, rfx_group_ids[i], : ] elif rfx_beta_draws.ndim == 2: - rfx_predictions_raw = np.empty(shape=(X.shape[0], 1, rfx_beta_draws.shape[1])) + rfx_predictions_raw = np.empty( + shape=(X.shape[0], 1, rfx_beta_draws.shape[1]) + ) for i in range(X.shape[0]): - rfx_predictions_raw[i, 0, :] = rfx_beta_draws[ - rfx_group_ids[i], : - ] + rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_ids[i], :] else: raise ValueError( "Unexpected number of dimensions in extracted random effects samples" ) - + # Add raw RFX predictions to mu and tau if warranted by the RFX model spec if predict_mu_forest or predict_mu_forest_intermediate: if rfx_intercept and predict_rfx_raw: @@ -2553,7 +2567,6 @@ def predict( tau_x = tau_x_forest + np.squeeze(rfx_predictions_raw[:, 1:, :]) else: tau_x = tau_x_forest - # Combine into y hat predictions needs_mean_term_preds = ( @@ -3308,7 +3321,9 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.has_rfx = json_object_default.get_boolean("has_rfx") self.has_rfx_basis = json_object_default.get_boolean("has_rfx_basis") self.num_rfx_basis = json_object_default.get_scalar("num_rfx_basis") - self.multivariate_treatment = json_object_default.get_boolean("multivariate_treatment") + self.multivariate_treatment = json_object_default.get_boolean( + "multivariate_treatment" + ) if self.has_rfx: self.rfx_container = RandomEffectsContainer() for i in range(len(json_object_list)): @@ -3344,9 +3359,7 @@ def from_json_string_list(self, json_string_list: list[str]) -> None: self.probit_outcome_model = json_object_default.get_boolean( "probit_outcome_model" ) - self.rfx_model_spec = json_object_default.get_string( - "rfx_model_spec" - ) + self.rfx_model_spec = json_object_default.get_string("rfx_model_spec") # Unpack number of samples for i in range(len(json_object_list)): From b0b59f257b51b28087e5f14e4190ef31ced4739f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 29 Oct 2025 19:20:02 -0500 Subject: [PATCH 53/53] Reformatting R and Python unit test code --- test/R/testthat/test-bart.R | 9 +- test/python/test_bart.py | 46 ++-- test/python/test_bcf.py | 213 ++++++++++------ test/python/test_data.py | 23 +- test/python/test_forest.py | 82 +++--- test/python/test_forest_container.py | 78 +++--- test/python/test_json.py | 172 ++++++++----- test/python/test_kernel.py | 74 +++--- test/python/test_predict.py | 358 +++++++++++++++------------ test/python/test_preprocessor.py | 164 ++++++------ test/python/test_residual.py | 18 +- test/python/test_utils.py | 138 ++++++++--- 12 files changed, 801 insertions(+), 574 deletions(-) diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index 59ad2504..23013ec2 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -582,6 +582,13 @@ test_that("Random Effects BART", { mean_forest_params = mean_forest_param_list, random_effects_params = rfx_param_list ) - preds <- predict(bart_model, covariates = X_test, leaf_basis = W_test, rfx_group_ids = rfx_group_ids_test, type = "posterior", terms = "rfx") + preds <- predict( + bart_model, + covariates = X_test, + leaf_basis = W_test, + rfx_group_ids = rfx_group_ids_test, + type = "posterior", + terms = "rfx" + ) }) }) diff --git a/test/python/test_bart.py b/test/python/test_bart.py index 0156f9e9..3243b86a 100644 --- a/test/python/test_bart.py +++ b/test/python/test_bart.py @@ -84,7 +84,7 @@ def outcome_mean(X): # Assertions bart_preds_combined = bart_model_3.predict(covariates=X_train) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -192,7 +192,7 @@ def outcome_mean(X, W): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -300,7 +300,7 @@ def outcome_mean(X, W): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -411,7 +411,10 @@ def conditional_stddev(X): # Assertions bart_preds_combined = bart_model_3.predict(covariates=X_train) - y_hat_train_combined, sigma2_x_train_combined = bart_preds_combined['y_hat'], bart_preds_combined['variance_forest_predictions'] + y_hat_train_combined, sigma2_x_train_combined = ( + bart_preds_combined["y_hat"], + bart_preds_combined["variance_forest_predictions"], + ) assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) assert sigma2_x_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( @@ -424,7 +427,8 @@ def conditional_stddev(X): sigma2_x_train_combined[:, 0:num_mcmc], bart_model.sigma2_x_train ) np.testing.assert_allclose( - sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], bart_model_2.sigma2_x_train + sigma2_x_train_combined[:, num_mcmc : (2 * num_mcmc)], + bart_model_2.sigma2_x_train, ) np.testing.assert_allclose( bart_model_3.global_var_samples[0:num_mcmc], bart_model.global_var_samples @@ -543,7 +547,7 @@ def conditional_stddev(X): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -668,7 +672,7 @@ def conditional_stddev(X): bart_preds_combined = bart_model_3.predict( covariates=X_train, basis=basis_train ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -825,7 +829,7 @@ def rfx_term(group_labels, basis): rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -999,7 +1003,7 @@ def conditional_stddev(X): rfx_group_ids=group_labels_train, rfx_basis=rfx_basis_train, ) - y_hat_train_combined = bart_preds_combined['y_hat'] + y_hat_train_combined = bart_preds_combined["y_hat"] assert y_hat_train_combined.shape == (n_train, num_mcmc * 2) np.testing.assert_allclose( y_hat_train_combined[:, 0:num_mcmc], bart_model.y_hat_train @@ -1059,7 +1063,7 @@ def outcome_mean(X, W): # Define the group rfx function def rfx_term(group_labels, basis): return np.where( - group_labels == 0, -5 + 1. * basis[:,1], 5 - 1. * basis[:,1] + group_labels == 0, -5 + 1.0 * basis[:, 1], 5 - 1.0 * basis[:, 1] ) # Define the conditional standard deviation function @@ -1124,12 +1128,12 @@ def conditional_stddev(X): # Specify scalar rfx parameters rfx_params = { "model_spec": "custom", - "working_parameter_prior_mean": 1., - "group_parameter_prior_mean": 1., - "working_parameter_prior_cov": 1., - "group_parameter_prior_cov": 1., + "working_parameter_prior_mean": 1.0, + "group_parameter_prior_mean": 1.0, + "working_parameter_prior_cov": 1.0, + "group_parameter_prior_cov": 1.0, "variance_prior_shape": 1, - "variance_prior_scale": 1 + "variance_prior_scale": 1, } bart_model_2 = BARTModel() bart_model_2.sample( @@ -1151,12 +1155,12 @@ def conditional_stddev(X): # Specify all relevant rfx parameters as vectors rfx_params = { "model_spec": "custom", - "working_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "group_parameter_prior_mean": np.repeat(1., num_rfx_basis), + "working_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), + "group_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), "working_parameter_prior_cov": np.identity(num_rfx_basis), "group_parameter_prior_cov": np.identity(num_rfx_basis), "variance_prior_shape": 1, - "variance_prior_scale": 1 + "variance_prior_scale": 1, } bart_model_3 = BARTModel() bart_model_3.sample( @@ -1176,9 +1180,7 @@ def conditional_stddev(X): ) # Fit a simpler intercept-only RFX model - rfx_params = { - "model_spec": "intercept_only" - } + rfx_params = {"model_spec": "intercept_only"} bart_model_4 = BARTModel() bart_model_4.sample( X_train=X_train, @@ -1198,6 +1200,6 @@ def conditional_stddev(X): basis=basis_test, rfx_group_ids=group_labels_test, type="posterior", - terms="rfx" + terms="rfx", ) assert preds.shape == (n_test, num_mcmc) diff --git a/test/python/test_bcf.py b/test/python/test_bcf.py index d9d5ee5c..dac1ea25 100644 --- a/test/python/test_bcf.py +++ b/test/python/test_bcf.py @@ -71,13 +71,19 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check that we can predict just treatment effects - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and with propensity score @@ -101,12 +107,18 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -136,13 +148,17 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -166,13 +182,17 @@ def test_binary_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") def test_continuous_univariate_bcf(self): # RNG @@ -239,13 +259,19 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run second BCF model with test set and propensity score @@ -275,13 +301,19 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds_2 = bcf_model_2.predict(X_test, Z_test, pi_test) - tau_hat_2, mu_hat_2, y_hat_2 = bcf_preds_2['tau_hat'], bcf_preds_2['mu_hat'], bcf_preds_2['y_hat'] + tau_hat_2, mu_hat_2, y_hat_2 = ( + bcf_preds_2["tau_hat"], + bcf_preds_2["mu_hat"], + bcf_preds_2["y_hat"], + ) assert tau_hat_2.shape == (n_test, num_mcmc) assert mu_hat_2.shape == (n_test, num_mcmc) assert y_hat_2.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat_2 = bcf_model_2.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat_2 = bcf_model_2.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat_2.shape == (n_test, num_mcmc) # Combine into a single model @@ -291,7 +323,11 @@ def test_continuous_univariate_bcf(self): # Assertions bcf_preds_3 = bcf_model_3.predict(X_test, Z_test, pi_test) - tau_hat_3, mu_hat_3, y_hat_3 = bcf_preds_3['tau_hat'], bcf_preds_3['mu_hat'], bcf_preds_3['y_hat'] + tau_hat_3, mu_hat_3, y_hat_3 = ( + bcf_preds_3["tau_hat"], + bcf_preds_3["mu_hat"], + bcf_preds_3["y_hat"], + ) assert tau_hat_3.shape == (n_train, num_mcmc * 2) assert mu_hat_3.shape == (n_train, num_mcmc * 2) assert y_hat_3.shape == (n_train, num_mcmc * 2) @@ -330,13 +366,19 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -366,13 +408,17 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -396,13 +442,17 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") # Run second BCF model with test set and propensity score bcf_model_2 = BCFModel() @@ -424,13 +474,17 @@ def test_continuous_univariate_bcf(self): # Check overall prediction method bcf_preds_2 = bcf_model_2.predict(X_test, Z_test) - tau_hat_2, mu_hat_2, y_hat_2 = bcf_preds_2['tau_hat'], bcf_preds_2['mu_hat'], bcf_preds_2['y_hat'] + tau_hat_2, mu_hat_2, y_hat_2 = ( + bcf_preds_2["tau_hat"], + bcf_preds_2["mu_hat"], + bcf_preds_2["y_hat"], + ) assert tau_hat_2.shape == (n_test, num_mcmc) assert mu_hat_2.shape == (n_test, num_mcmc) assert y_hat_2.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat_2 = bcf_model_2.predict(X = X_test, Z = Z_test, terms = "cate") + tau_hat_2 = bcf_model_2.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat_2.shape == (n_test, num_mcmc) # Combine into a single model @@ -440,7 +494,11 @@ def test_continuous_univariate_bcf(self): # Assertions bcf_preds_3 = bcf_model_3.predict(X_test, Z_test) - tau_hat_3, mu_hat_3, y_hat_3 = bcf_preds_3['tau_hat'], bcf_preds_3['mu_hat'], bcf_preds_3['y_hat'] + tau_hat_3, mu_hat_3, y_hat_3 = ( + bcf_preds_3["tau_hat"], + bcf_preds_3["mu_hat"], + bcf_preds_3["y_hat"], + ) assert tau_hat_3.shape == (n_train, num_mcmc * 2) assert mu_hat_3.shape == (n_train, num_mcmc * 2) assert y_hat_3.shape == (n_train, num_mcmc * 2) @@ -522,13 +580,19 @@ def test_multivariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF without test set and with propensity score @@ -552,13 +616,19 @@ def test_multivariate_bcf(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'] + tau_hat, mu_hat, y_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc, treatment_dim) # Run BCF with test set and without propensity score @@ -658,14 +728,21 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - tau_hat, mu_hat, y_hat, sigma2_x_hat = bcf_preds['tau_hat'], bcf_preds['mu_hat'], bcf_preds['y_hat'], bcf_preds['variance_forest_predictions'] + tau_hat, mu_hat, y_hat, sigma2_x_hat = ( + bcf_preds["tau_hat"], + bcf_preds["mu_hat"], + bcf_preds["y_hat"], + bcf_preds["variance_forest_predictions"], + ) assert tau_hat.shape == (n_test, num_mcmc) assert mu_hat.shape == (n_test, num_mcmc) assert y_hat.shape == (n_test, num_mcmc) assert sigma2_x_hat.shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and with propensity score @@ -690,32 +767,28 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test, pi_test) - assert bcf_preds['tau_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['variance_forest_predictions'].shape == (n_test, num_mcmc) + assert bcf_preds["tau_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["variance_forest_predictions"].shape == (n_test, num_mcmc) # Check predictions match bcf_preds = bcf_model.predict(X_train, Z_train, pi_train) - assert bcf_preds['tau_hat'].shape == (n_train, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_train, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_train, num_mcmc) - assert bcf_preds['variance_forest_predictions'].shape == (n_train, num_mcmc) - np.testing.assert_allclose( - bcf_preds['y_hat'], bcf_model.y_hat_train - ) - np.testing.assert_allclose( - bcf_preds['mu_hat'], bcf_model.mu_hat_train - ) - np.testing.assert_allclose( - bcf_preds['tau_hat'], bcf_model.tau_hat_train - ) + assert bcf_preds["tau_hat"].shape == (n_train, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_train, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_train, num_mcmc) + assert bcf_preds["variance_forest_predictions"].shape == (n_train, num_mcmc) + np.testing.assert_allclose(bcf_preds["y_hat"], bcf_model.y_hat_train) + np.testing.assert_allclose(bcf_preds["mu_hat"], bcf_model.mu_hat_train) + np.testing.assert_allclose(bcf_preds["tau_hat"], bcf_model.tau_hat_train) np.testing.assert_allclose( - bcf_preds['variance_forest_predictions'], bcf_model.sigma2_x_train + bcf_preds["variance_forest_predictions"], bcf_model.sigma2_x_train ) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, propensity = pi_test, terms = "cate") + tau_hat = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="cate" + ) assert tau_hat.shape == (n_test, num_mcmc) # Run BCF with test set and without propensity score @@ -746,13 +819,13 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - assert bcf_preds['tau_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['variance_forest_predictions'].shape == (n_test, num_mcmc) + assert bcf_preds["tau_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["variance_forest_predictions"].shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") assert tau_hat.shape == (n_test, num_mcmc) # Run BCF without test set and without propensity score @@ -776,13 +849,13 @@ def test_binary_bcf_heteroskedastic(self): # Check overall prediction method bcf_preds = bcf_model.predict(X_test, Z_test) - assert bcf_preds['tau_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['mu_hat'].shape == (n_test, num_mcmc) - assert bcf_preds['y_hat'].shape == (n_test, num_mcmc) + assert bcf_preds["tau_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["mu_hat"].shape == (n_test, num_mcmc) + assert bcf_preds["y_hat"].shape == (n_test, num_mcmc) # Check treatment effect prediction method - tau_hat = bcf_model.predict(X = X_test, Z = Z_test, terms = "cate") - + tau_hat = bcf_model.predict(X=X_test, Z=Z_test, terms="cate") + def test_bcf_rfx_parameters(self): # RNG random_seed = 101 @@ -811,9 +884,9 @@ def test_bcf_rfx_parameters(self): # Define the group rfx function def rfx_term(group_labels, basis): return np.where( - group_labels == 0, -5 + 1. * basis[:,1], 5 - 1. * basis[:,1] + group_labels == 0, -5 + 1.0 * basis[:, 1], 5 - 1.0 * basis[:, 1] ) - + # Generate outcome epsilon = rng.normal(0, 1, n) y = mu_X + tau_X * Z + rfx_term(group_labels, rfx_basis) + epsilon @@ -856,17 +929,17 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - general_params=general_params + general_params=general_params, ) # Specify scalar rfx parameters rfx_params = { - "working_parameter_prior_mean": 1., - "group_parameter_prior_mean": 1., - "working_parameter_prior_cov": 1., - "group_parameter_prior_cov": 1., + "working_parameter_prior_mean": 1.0, + "group_parameter_prior_mean": 1.0, + "working_parameter_prior_cov": 1.0, + "group_parameter_prior_cov": 1.0, "variance_prior_shape": 1, - "variance_prior_scale": 1 + "variance_prior_scale": 1, } bcf_model_2 = BCFModel() bcf_model_2.sample( @@ -884,17 +957,17 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - random_effects_params=rfx_params + random_effects_params=rfx_params, ) # Specify all relevant rfx parameters as vectors rfx_params = { - "working_parameter_prior_mean": np.repeat(1., num_rfx_basis), - "group_parameter_prior_mean": np.repeat(1., num_rfx_basis), + "working_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), + "group_parameter_prior_mean": np.repeat(1.0, num_rfx_basis), "working_parameter_prior_cov": np.identity(num_rfx_basis), "group_parameter_prior_cov": np.identity(num_rfx_basis), "variance_prior_shape": 1, - "variance_prior_scale": 1 + "variance_prior_scale": 1, } bcf_model_3 = BCFModel() bcf_model_3.sample( @@ -912,5 +985,5 @@ def rfx_term(group_labels, basis): num_gfr=num_gfr, num_burnin=num_burnin, num_mcmc=num_mcmc, - random_effects_params=rfx_params + random_effects_params=rfx_params, ) diff --git a/test/python/test_data.py b/test/python/test_data.py index 09c75154..cda8f58b 100644 --- a/test/python/test_data.py +++ b/test/python/test_data.py @@ -2,6 +2,7 @@ from stochtree import Dataset, RandomEffectsDataset + class TestDataset: def test_dataset_update(self): # Generate data @@ -12,7 +13,7 @@ def test_dataset_update(self): covariates = rng.uniform(0, 1, size=(n, num_covariates)) basis = rng.uniform(0, 1, size=(n, num_basis)) variance_weights = rng.uniform(0, 1, size=n) - + # Construct dataset forest_dataset = Dataset() forest_dataset.add_covariates(covariates) @@ -22,18 +23,21 @@ def test_dataset_update(self): assert forest_dataset.num_covariates() == num_covariates assert forest_dataset.num_basis() == num_basis assert forest_dataset.has_variance_weights() - + # Update dataset new_basis = rng.uniform(0, 1, size=(n, num_basis)) new_variance_weights = rng.uniform(0, 1, size=n) with np.testing.assert_no_warnings(): forest_dataset.update_basis(new_basis) forest_dataset.update_variance_weights(new_variance_weights) - + # Check that we recover the correct data through get_covariates, get_basis, and get_variance_weights np.testing.assert_array_equal(forest_dataset.get_covariates(), covariates) np.testing.assert_array_equal(forest_dataset.get_basis(), new_basis) - np.testing.assert_array_equal(forest_dataset.get_variance_weights(), new_variance_weights) + np.testing.assert_array_equal( + forest_dataset.get_variance_weights(), new_variance_weights + ) + class TestRFXDataset: def test_rfx_dataset_update(self): @@ -48,7 +52,7 @@ def test_rfx_dataset_update(self): if num_basis > 1: basis[:, 1:] = rng.uniform(-1, 1, (n, num_basis - 1)) variance_weights = rng.uniform(0, 1, size=n) - + # Construct dataset rfx_dataset = RandomEffectsDataset() rfx_dataset.add_group_labels(group_labels) @@ -57,16 +61,17 @@ def test_rfx_dataset_update(self): assert rfx_dataset.num_observations() == n assert rfx_dataset.num_basis() == num_basis assert rfx_dataset.has_variance_weights() - + # Update dataset new_basis = rng.uniform(0, 1, size=(n, num_basis)) new_variance_weights = rng.uniform(0, 1, size=n) with np.testing.assert_no_warnings(): rfx_dataset.update_basis(new_basis) rfx_dataset.update_variance_weights(new_variance_weights) - + # Check that we recover the correct data through get_group_labels, get_basis, and get_variance_weights np.testing.assert_array_equal(rfx_dataset.get_group_labels(), group_labels) np.testing.assert_array_equal(rfx_dataset.get_basis(), new_basis) - np.testing.assert_array_equal(rfx_dataset.get_variance_weights(), new_variance_weights) - + np.testing.assert_array_equal( + rfx_dataset.get_variance_weights(), new_variance_weights + ) diff --git a/test/python/test_forest.py b/test/python/test_forest.py index 267b09b4..9d5e6b48 100644 --- a/test/python/test_forest.py +++ b/test/python/test_forest.py @@ -6,14 +6,14 @@ class TestPredict: def test_constant_forest_construction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -26,41 +26,41 @@ def test_constant_forest_construction(self): # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest.predict(forest_dataset) - pred_exp = np.array([0.,0.,0.,0.,0.,0.]) + pred_exp = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest.add_numeric_split(0, 0, 0, 4.0, -5., 5.) + forest.add_numeric_split(0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest.predict(forest_dataset) - pred_exp = np.array([-5.,-5.,-5.,5.,5.,5.]) - + pred_exp = np.array([-5.0, -5.0, -5.0, 5.0, 5.0, 5.0]) + # Assertion np.testing.assert_almost_equal(pred, pred_exp) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest.add_numeric_split(0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest.predict(forest_dataset) - pred_exp = np.array([-2.5,-7.5,-7.5,5.,5.,5.]) - + pred_exp = np.array([-2.5, -7.5, -7.5, 5.0, 5.0, 5.0]) + # Assertion np.testing.assert_almost_equal(pred, pred_exp) - + def test_constant_forest_merge_arithmetic(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -78,25 +78,25 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred1 = forest1.predict(forest_dataset) pred2 = forest2.predict(forest_dataset) - pred_exp1 = np.array([0.,0.,0.,0.,0.,0.]) - pred_exp2 = np.array([0.,0.,0.,0.,0.,0.]) + pred_exp1 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + pred_exp2 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # Assertion np.testing.assert_almost_equal(pred1, pred_exp1) np.testing.assert_almost_equal(pred2, pred_exp2) # Split the root of the first tree in the first forest at X[,1] > 4.0 - forest1.add_numeric_split(0, 0, 0, 4.0, -5., 5.) - + forest1.add_numeric_split(0, 0, 0, 4.0, -5.0, 5.0) + # Split the root of the first tree in the second forest at X[,1] > 3.0 - forest2.add_numeric_split(0, 0, 0, 3.0, -1., 1.) + forest2.add_numeric_split(0, 0, 0, 3.0, -1.0, 1.0) # Check that predictions are as expected pred1 = forest1.predict(forest_dataset) pred2 = forest2.predict(forest_dataset) - pred_exp1 = np.array([-5.,-5.,-5.,5.,5.,5.]) - pred_exp2 = np.array([-1.,-1.,1.,1.,1.,1.]) - + pred_exp1 = np.array([-5.0, -5.0, -5.0, 5.0, 5.0, 5.0]) + pred_exp2 = np.array([-1.0, -1.0, 1.0, 1.0, 1.0, 1.0]) + # Assertion np.testing.assert_almost_equal(pred1, pred_exp1) np.testing.assert_almost_equal(pred2, pred_exp2) @@ -106,13 +106,13 @@ def test_constant_forest_merge_arithmetic(self): # Split the left leaf of the first tree in the first forest at X[,2] > 4.0 forest2.add_numeric_split(0, 1, 1, 4.0, -1.5, -0.5) - + # Check that predictions are as expected pred1 = forest1.predict(forest_dataset) pred2 = forest2.predict(forest_dataset) - pred_exp1 = np.array([-2.5,-7.5,-7.5,5.,5.,5.]) - pred_exp2 = np.array([-0.5,-1.5,1.,1.,1.,1.]) - + pred_exp1 = np.array([-2.5, -7.5, -7.5, 5.0, 5.0, 5.0]) + pred_exp2 = np.array([-0.5, -1.5, 1.0, 1.0, 1.0, 1.0]) + # Assertion np.testing.assert_almost_equal(pred1, pred_exp1) np.testing.assert_almost_equal(pred2, pred_exp2) @@ -122,7 +122,7 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([-3.0,-9.0,-6.5,6.0,6.0,6.0]) + pred_exp = np.array([-3.0, -9.0, -6.5, 6.0, 6.0, 6.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) @@ -132,14 +132,14 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([7.0,1.0,3.5,16.0,16.0,16.0]) + pred_exp = np.array([7.0, 1.0, 3.5, 16.0, 16.0, 16.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) # Check that "old" forest is still intact pred = forest2.predict(forest_dataset) - pred_exp = np.array([-0.5,-1.5,1.,1.,1.,1.]) + pred_exp = np.array([-0.5, -1.5, 1.0, 1.0, 1.0, 1.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) @@ -149,7 +149,7 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([-3.0,-9.0,-6.5,6.0,6.0,6.0]) + pred_exp = np.array([-3.0, -9.0, -6.5, 6.0, 6.0, 6.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) @@ -159,7 +159,7 @@ def test_constant_forest_merge_arithmetic(self): # Check that predictions are as expected pred = forest1.predict(forest_dataset) - pred_exp = np.array([-6.0,-18.0,-13.0,12.0,12.0,12.0]) + pred_exp = np.array([-6.0, -18.0, -13.0, 12.0, 12.0, 12.0]) # Assertion np.testing.assert_almost_equal(pred, pred_exp) diff --git a/test/python/test_forest_container.py b/test/python/test_forest_container.py index 9eda8569..4368f9b3 100644 --- a/test/python/test_forest_container.py +++ b/test/python/test_forest_container.py @@ -7,14 +7,14 @@ class TestPredict: def test_constant_leaf_forest_container(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -23,11 +23,11 @@ def test_constant_leaf_forest_container(self): forest_samples = ForestContainer(num_trees, output_dim, True, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Split the root of the first tree in the ensemble at X[,1] > 4.0 # and then split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) # Store the predictions of the "original" forest before modifications @@ -35,48 +35,48 @@ def test_constant_leaf_forest_container(self): # Multiply first forest by 2.0 forest_samples.multiply_forest(0, 2.0) - + # Check that predictions are all double pred = forest_samples.predict(forest_dataset) pred_expected = pred_orig * 2.0 - + # Assertion np.testing.assert_almost_equal(pred, pred_expected) # Add 1.0 to every tree in first forest forest_samples.add_to_forest(0, 1.0) - + # Check that predictions are += num_trees pred_expected = pred + num_trees pred = forest_samples.predict(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_expected) # Initialize a new forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Split the second forest as the first forest was split - forest_samples.add_numeric_split(1, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(1, 0, 0, 0, 4.0, -5.0, 5.0) forest_samples.add_numeric_split(1, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that predictions are as expected pred_expected_new = np.c_[pred_expected, pred_orig] pred = forest_samples.predict(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_expected_new) # Combine second forest with the first forest - forest_samples.combine_forests(np.array([0,1])) + forest_samples.combine_forests(np.array([0, 1])) # Check that predictions are as expected pred_expected_new = np.c_[pred_expected + pred_orig, pred_orig] pred = forest_samples.predict(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_expected_new) - + def test_collapse_forest_container(self): # RNG rng = np.random.default_rng() @@ -143,13 +143,22 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match container_inds = np.linspace(start=1, stop=num_mcmc, num=num_mcmc) - batch_inds = (container_inds - (num_mcmc - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size))) - 1) // batch_size + batch_inds = ( + container_inds + - ( + num_mcmc + - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size)) + ) + - 1 + ) // batch_size batch_inds = batch_inds.astype(int) num_batches = np.max(batch_inds) + 1 pred_orig_collapsed = np.empty((n_test, num_batches)) for i in range(num_batches): - pred_orig_collapsed[:,i] = np.sum(pred_orig[:,batch_inds == i], axis=1) / np.sum(batch_inds == i) - + pred_orig_collapsed[:, i] = np.sum( + pred_orig[:, batch_inds == i], axis=1 + ) / np.sum(batch_inds == i) + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) @@ -180,13 +189,22 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match container_inds = np.linspace(start=1, stop=num_mcmc, num=num_mcmc) - batch_inds = (container_inds - (num_mcmc - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size))) - 1) // batch_size + batch_inds = ( + container_inds + - ( + num_mcmc + - ((num_mcmc // (num_mcmc // batch_size)) * (num_mcmc // batch_size)) + ) + - 1 + ) // batch_size batch_inds = batch_inds.astype(int) num_batches = np.max(batch_inds) + 1 pred_orig_collapsed = np.empty((n_test, num_batches)) for i in range(num_batches): - pred_orig_collapsed[:,i] = np.sum(pred_orig[:,batch_inds == i], axis=1) / np.sum(batch_inds == i) - + pred_orig_collapsed[:, i] = np.sum( + pred_orig[:, batch_inds == i], axis=1 + ) / np.sum(batch_inds == i) + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) @@ -218,8 +236,8 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match num_batches = 1 pred_orig_collapsed = np.empty((n_test, num_batches)) - pred_orig_collapsed[:,0] = np.sum(pred_orig, axis=1) / batch_size - + pred_orig_collapsed[:, 0] = np.sum(pred_orig, axis=1) / batch_size + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) @@ -250,6 +268,6 @@ def outcome_mean(X): # Check that corresponding (sums of) predictions match pred_orig_collapsed = pred_orig - + # Assertion np.testing.assert_almost_equal(pred_orig_collapsed, pred_new) diff --git a/test/python/test_json.py b/test/python/test_json.py index 1d0dc66d..48d4845b 100644 --- a/test/python/test_json.py +++ b/test/python/test_json.py @@ -14,7 +14,7 @@ JSONSerializer, Residual, ForestModelConfig, - GlobalModelConfig + GlobalModelConfig, ) @@ -41,17 +41,15 @@ def test_array(self): assert b == json_test.get_string_vector("b") def test_preprocessor(self): - df = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=False, - categories=["c", "b", "a"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) + df = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=False, + categories=["c", "b", "a"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) cov_transformer = CovariatePreprocessor() df_transformed_orig = cov_transformer.fit_transform(df) cov_transformer_json = cov_transformer.to_json() @@ -60,27 +58,25 @@ def test_preprocessor(self): df_transformed_reloaded = cov_transformer_reloaded.transform(df) np.testing.assert_array_equal(df_transformed_orig, df_transformed_reloaded) - df_2 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=False, - categories=["c", "b", "a"], - ), - "x3": pd.Categorical( - ["a", "c", "d", "b", "d", "b"], - ordered=False, - categories=["c", "b", "a", "d"], - ), - "x4": pd.Categorical( - ["a", "b", "f", "f", "c", "a"], - ordered=True, - categories=["c", "b", "a", "f"], - ), - "x5": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) + df_2 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=False, + categories=["c", "b", "a"], + ), + "x3": pd.Categorical( + ["a", "c", "d", "b", "d", "b"], + ordered=False, + categories=["c", "b", "a", "d"], + ), + "x4": pd.Categorical( + ["a", "b", "f", "f", "c", "a"], + ordered=True, + categories=["c", "b", "a", "f"], + ), + "x5": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) cov_transformer_2 = CovariatePreprocessor() df_transformed_orig_2 = cov_transformer_2.fit_transform(df_2) cov_transformer_json_2 = cov_transformer_2.to_json() @@ -89,9 +85,14 @@ def test_preprocessor(self): df_transformed_reloaded_2 = cov_transformer_reloaded_2.transform(df_2) np.testing.assert_array_equal(df_transformed_orig_2, df_transformed_reloaded_2) - np_3 = np.array( - [[1.5, 1.2], [2.7, 5.4], [3.6, 9.3], [4.4, 10.4], [5.3, 3.6], [6.1, 4.4]] - ) + np_3 = np.array([ + [1.5, 1.2], + [2.7, 5.4], + [3.6, 9.3], + [4.4, 10.4], + [5.3, 3.6], + [6.1, 4.4], + ]) cov_transformer_3 = CovariatePreprocessor() df_transformed_orig_3 = cov_transformer_3.fit_transform(np_3) cov_transformer_json_3 = cov_transformer_3.to_json() @@ -131,7 +132,7 @@ def outcome_mean(X): # Extract original predictions bart_preds = bart_model.predict(X) - forest_preds_y_mcmc_retrieved = bart_preds['y_hat'] + forest_preds_y_mcmc_retrieved = bart_preds["y_hat"] # Roundtrip to / from JSON json_test = JSONSerializer() @@ -213,7 +214,7 @@ def outcome_mean(X, W): residual = Residual(resid) # Forest samplers and temporary tracking data structures - leaf_model_type = 0 if p_W == 0 else 1 + 1*(p_W > 1) + leaf_model_type = 0 if p_W == 0 else 1 + 1 * (p_W > 1) forest_config = ForestModelConfig( num_trees=num_trees, num_features=p_X, @@ -231,9 +232,7 @@ def outcome_mean(X, W): global_config = GlobalModelConfig(global_error_variance=global_variance_init) forest_container = ForestContainer(num_trees, W.shape[1], False, False) active_forest = Forest(num_trees, W.shape[1], False, False) - forest_sampler = ForestSampler( - dataset, global_config, forest_config - ) + forest_sampler = ForestSampler(dataset, global_config, forest_config) cpp_rng = RNG(random_seed) global_var_model = GlobalVarianceModel() @@ -241,9 +240,10 @@ def outcome_mean(X, W): num_warmstart = 10 num_mcmc = 100 num_samples = num_warmstart + num_mcmc - global_var_samples = np.concatenate( - (np.array([global_variance_init]), np.repeat(0, num_samples)) - ) + global_var_samples = np.concatenate(( + np.array([global_variance_init]), + np.repeat(0, num_samples), + )) if p_W > 0: init_val = np.repeat(0.0, W.shape[1]) else: @@ -264,8 +264,8 @@ def outcome_mean(X, W): dataset, residual, cpp_rng, - global_config, - forest_config, + global_config, + forest_config, True, True, ) @@ -281,8 +281,8 @@ def outcome_mean(X, W): dataset, residual, cpp_rng, - global_config, - forest_config, + global_config, + forest_config, True, True, ) @@ -334,18 +334,20 @@ def outcome_mean(X, W): # Run BART bart_orig = BARTModel() - bart_orig.sample(X_train=X, y_train=y, leaf_basis_train=W, num_gfr=10, num_mcmc=10) + bart_orig.sample( + X_train=X, y_train=y, leaf_basis_train=W, num_gfr=10, num_mcmc=10 + ) # Extract predictions from the sampler bart_preds_orig = bart_orig.predict(X, W) - y_hat_orig = bart_preds_orig['y_hat'] + y_hat_orig = bart_preds_orig["y_hat"] # "Round-trip" the model to JSON string and back and check that the predictions agree bart_json_string = bart_orig.to_json() bart_reloaded = BARTModel() bart_reloaded.from_json(bart_json_string) bart_preds_reloaded = bart_reloaded.predict(X, W) - y_hat_reloaded = bart_preds_reloaded['y_hat'] + y_hat_reloaded = bart_preds_reloaded["y_hat"] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) def test_bart_rfx_string(self): @@ -407,19 +409,26 @@ def rfx_mean(group_labels, basis): # Run BART bart_orig = BARTModel() - bart_orig.sample(X_train=X, y_train=y, leaf_basis_train=W, rfx_group_ids_train=group_labels, - rfx_basis_train=basis, num_gfr=10, num_mcmc=10) + bart_orig.sample( + X_train=X, + y_train=y, + leaf_basis_train=W, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis, + num_gfr=10, + num_mcmc=10, + ) # Extract predictions from the sampler bart_preds_orig = bart_orig.predict(X, W, group_labels, basis) - y_hat_orig = bart_preds_orig['y_hat'] + y_hat_orig = bart_preds_orig["y_hat"] # "Round-trip" the model to JSON string and back and check that the predictions agree bart_json_string = bart_orig.to_json() bart_reloaded = BARTModel() bart_reloaded.from_json(bart_json_string) bart_preds_reloaded = bart_reloaded.predict(X, W, group_labels, basis) - y_hat_reloaded = bart_preds_reloaded['y_hat'] + y_hat_reloaded = bart_preds_reloaded["y_hat"] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) def test_bcf_string(self): @@ -450,16 +459,22 @@ def test_bcf_string(self): # Extract predictions from the sampler bcf_preds_orig = bcf_orig.predict(X, Z, pi_X) - mu_hat_orig, tau_hat_orig, y_hat_orig = bcf_preds_orig['mu_hat'], bcf_preds_orig['tau_hat'], bcf_preds_orig['y_hat'] + mu_hat_orig, tau_hat_orig, y_hat_orig = ( + bcf_preds_orig["mu_hat"], + bcf_preds_orig["tau_hat"], + bcf_preds_orig["y_hat"], + ) # "Round-trip" the model to JSON string and back and check that the predictions agree bcf_json_string = bcf_orig.to_json() bcf_reloaded = BCFModel() bcf_reloaded.from_json(bcf_json_string) - bcf_preds_reloaded = bcf_reloaded.predict( - X, Z, pi_X + bcf_preds_reloaded = bcf_reloaded.predict(X, Z, pi_X) + mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = ( + bcf_preds_reloaded["mu_hat"], + bcf_preds_reloaded["tau_hat"], + bcf_preds_reloaded["y_hat"], ) - mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = bcf_preds_reloaded['mu_hat'], bcf_preds_reloaded['tau_hat'], bcf_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) @@ -511,21 +526,36 @@ def rfx_mean(group_labels, basis): # Run BCF bcf_orig = BCFModel() bcf_orig.sample( - X_train=X, Z_train=Z, y_train=y, pi_train=pi_X, rfx_group_ids_train=group_labels, rfx_basis_train=basis, num_gfr=10, num_mcmc=10 + X_train=X, + Z_train=Z, + y_train=y, + pi_train=pi_X, + rfx_group_ids_train=group_labels, + rfx_basis_train=basis, + num_gfr=10, + num_mcmc=10, ) # Extract predictions from the sampler bcf_preds_orig = bcf_orig.predict(X, Z, pi_X, group_labels, basis) - mu_hat_orig, tau_hat_orig, rfx_hat_orig, y_hat_orig = bcf_preds_orig['mu_hat'], bcf_preds_orig['tau_hat'], bcf_preds_orig['rfx_predictions'], bcf_preds_orig['y_hat'] + mu_hat_orig, tau_hat_orig, rfx_hat_orig, y_hat_orig = ( + bcf_preds_orig["mu_hat"], + bcf_preds_orig["tau_hat"], + bcf_preds_orig["rfx_predictions"], + bcf_preds_orig["y_hat"], + ) # "Round-trip" the model to JSON string and back and check that the predictions agree bcf_json_string = bcf_orig.to_json() bcf_reloaded = BCFModel() bcf_reloaded.from_json(bcf_json_string) - bcf_preds_reloaded = bcf_reloaded.predict( - X, Z, pi_X, group_labels, basis + bcf_preds_reloaded = bcf_reloaded.predict(X, Z, pi_X, group_labels, basis) + mu_hat_reloaded, tau_hat_reloaded, rfx_hat_reloaded, y_hat_reloaded = ( + bcf_preds_reloaded["mu_hat"], + bcf_preds_reloaded["tau_hat"], + bcf_preds_reloaded["rfx_predictions"], + bcf_preds_reloaded["y_hat"], ) - mu_hat_reloaded, tau_hat_reloaded, rfx_hat_reloaded, y_hat_reloaded = bcf_preds_reloaded['mu_hat'], bcf_preds_reloaded['tau_hat'], bcf_preds_reloaded['rfx_predictions'], bcf_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) @@ -557,16 +587,22 @@ def test_bcf_propensity_string(self): # Extract predictions from the sampler bcf_preds_orig = bcf_orig.predict(X, Z, pi_X) - mu_hat_orig, tau_hat_orig, y_hat_orig = bcf_preds_orig['mu_hat'], bcf_preds_orig['tau_hat'], bcf_preds_orig['y_hat'] + mu_hat_orig, tau_hat_orig, y_hat_orig = ( + bcf_preds_orig["mu_hat"], + bcf_preds_orig["tau_hat"], + bcf_preds_orig["y_hat"], + ) # "Round-trip" the model to JSON string and back and check that the predictions agree bcf_json_string = bcf_orig.to_json() bcf_reloaded = BCFModel() bcf_reloaded.from_json(bcf_json_string) - bcf_preds_reloaded = bcf_reloaded.predict( - X, Z, pi_X + bcf_preds_reloaded = bcf_reloaded.predict(X, Z, pi_X) + mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = ( + bcf_preds_reloaded["mu_hat"], + bcf_preds_reloaded["tau_hat"], + bcf_preds_reloaded["y_hat"], ) - mu_hat_reloaded, tau_hat_reloaded, y_hat_reloaded = bcf_preds_reloaded['mu_hat'], bcf_preds_reloaded['tau_hat'], bcf_preds_reloaded['y_hat'] np.testing.assert_almost_equal(y_hat_orig, y_hat_reloaded) np.testing.assert_almost_equal(tau_hat_orig, tau_hat_reloaded) np.testing.assert_almost_equal(mu_hat_orig, mu_hat_reloaded) diff --git a/test/python/test_kernel.py b/test/python/test_kernel.py index 6d630874..6a6bff09 100644 --- a/test/python/test_kernel.py +++ b/test/python/test_kernel.py @@ -5,22 +5,22 @@ Dataset, Forest, ForestContainer, - compute_forest_leaf_indices, - compute_forest_max_leaf_index + compute_forest_leaf_indices, + compute_forest_max_leaf_index, ) class TestKernel: def test_forest(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 2 output_dim = 1 @@ -29,54 +29,54 @@ def test_forest(self): forest_samples = ForestContainer(num_trees, output_dim, True, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) computed = compute_forest_leaf_indices(forest_samples, X) max_leaf_index = compute_forest_max_leaf_index(forest_samples) expected = np.array([ - [0], - [0], - [0], - [1], - [1], + [0], + [0], + [0], [1], - [2], - [2], - [2], - [2], - [2], - [2] + [1], + [1], + [2], + [2], + [2], + [2], + [2], + [2], ]) - + # Assertion np.testing.assert_almost_equal(computed, expected) assert max_leaf_index == [2] # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) computed = compute_forest_leaf_indices(forest_samples, X) max_leaf_index = compute_forest_max_leaf_index(forest_samples) expected = np.array([ - [2], - [1], - [1], - [0], - [0], + [2], + [1], + [1], [0], - [3], - [3], - [3], - [3], - [3], - [3] + [0], + [0], + [3], + [3], + [3], + [3], + [3], + [3], ]) - + # Assertion np.testing.assert_almost_equal(computed, expected) assert max_leaf_index == [3] diff --git a/test/python/test_predict.py b/test/python/test_predict.py index 8ff6d4ed..03f36cb2 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -10,14 +10,14 @@ class TestPredict: def test_constant_leaf_prediction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -26,7 +26,7 @@ def test_constant_leaf_prediction(self): forest_samples = ForestContainer(num_trees, output_dim, True, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) @@ -36,50 +36,43 @@ def test_constant_leaf_prediction(self): np.testing.assert_almost_equal(pred, pred_raw) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_raw) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - + # Assertion np.testing.assert_almost_equal(pred, pred_raw) - + # Check the split count for the first tree in the ensemble split_counts = forest_samples.get_tree_split_counts(0, 0, p) - split_counts_expected = np.array([1,1,0]) - + split_counts_expected = np.array([1, 1, 0]) + # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) - + def test_univariate_regression_leaf_prediction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) - W = np.array( - [[-1], - [-1], - [-1], - [1], - [1], - [1]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) + W = np.array([[-1], [-1], [-1], [1], [1], [1]]) n, p = X.shape num_trees = 10 output_dim = 1 @@ -89,7 +82,7 @@ def test_univariate_regression_leaf_prediction(self): forest_samples = ForestContainer(num_trees, output_dim, False, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(0.) + forest_samples.add_sample(0.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) @@ -99,52 +92,45 @@ def test_univariate_regression_leaf_prediction(self): np.testing.assert_almost_equal(pred, pred_raw) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.) + forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5.0, 5.0) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - pred_manual = pred_raw*W - + pred_manual = pred_raw * W + # Assertion np.testing.assert_almost_equal(pred, pred_manual) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5) - + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) - pred_manual = pred_raw*W - + pred_manual = pred_raw * W + # Assertion np.testing.assert_almost_equal(pred, pred_manual) - + # Check the split count for the first tree in the ensemble split_counts = forest_samples.get_tree_split_counts(0, 0, p) split_counts_expected = np.array([1, 1, 0]) - + # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) - + def test_multivariate_regression_leaf_prediction(self): # Create dataset - X = np.array( - [[1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4]] - ) - W = np.array( - [[1,-1], - [1,-1], - [1,-1], - [1, 1], - [1, 1], - [1, 1]] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) + W = np.array([[1, -1], [1, -1], [1, -1], [1, 1], [1, 1], [1, 1]]) n, p = X.shape num_trees = 10 output_dim = 2 @@ -155,62 +141,74 @@ def test_multivariate_regression_leaf_prediction(self): forest_samples = ForestContainer(num_trees, output_dim, False, False) # Initialize a forest with constant root predictions - forest_samples.add_sample(np.array([1.,1.])) + forest_samples.add_sample(np.array([1.0, 1.0])) num_samples += 1 # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) pred_intermediate = pred_raw * W - pred_manual = pred_intermediate.sum(axis=1, keepdims = True) + pred_manual = pred_intermediate.sum(axis=1, keepdims=True) # Assertion np.testing.assert_almost_equal(pred, pred_manual) # Split the root of the first tree in the ensemble at X[,1] > 4.0 - forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, np.array([-5.,-1.]), np.array([5.,1.])) + forest_samples.add_numeric_split( + 0, 0, 0, 0, 4.0, np.array([-5.0, -1.0]), np.array([5.0, 1.0]) + ) # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) pred_intermediate = pred_raw * W - pred_manual = pred_intermediate.sum(axis=1, keepdims = True) - + pred_manual = pred_intermediate.sum(axis=1, keepdims=True) + # Assertion np.testing.assert_almost_equal(pred, pred_manual) # Split the left leaf of the first tree in the ensemble at X[,2] > 4.0 - forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, np.array([-7.5,2.5]), np.array([-2.5,7.5])) - + forest_samples.add_numeric_split( + 0, 0, 1, 1, 4.0, np.array([-7.5, 2.5]), np.array([-2.5, 7.5]) + ) + # Check that regular and "raw" predictions are the same (since the leaf is constant) pred = forest_samples.predict(forest_dataset) pred_raw = forest_samples.predict_raw(forest_dataset) pred_intermediate = pred_raw * W - pred_manual = pred_intermediate.sum(axis=1, keepdims = True) - + pred_manual = pred_intermediate.sum(axis=1, keepdims=True) + # Assertion np.testing.assert_almost_equal(pred, pred_manual) - + # Check the split count for the first tree in the ensemble split_counts = forest_samples.get_tree_split_counts(0, 0, p) split_counts_expected = np.array([1, 1, 0]) - + # Assertion np.testing.assert_almost_equal(split_counts, split_counts_expected) - + def test_bart_prediction(self): # Generate data and test/train split rng = np.random.default_rng(1234) n = 100 p = 5 X = rng.uniform(size=(n, p)) - f_XW = np.where((0 <= X[:, 0]) & (X[:, 0] < 0.25), -7.5, - np.where((0.25 <= X[:, 0]) & (X[:, 0] < 0.5), -2.5, - np.where((0.5 <= X[:, 0]) & (X[:, 0] < 0.75), 2.5, 7.5))) + f_XW = np.where( + (0 <= X[:, 0]) & (X[:, 0] < 0.25), + -7.5, + np.where( + (0.25 <= X[:, 0]) & (X[:, 0] < 0.5), + -2.5, + np.where((0.5 <= X[:, 0]) & (X[:, 0] < 0.75), 2.5, 7.5), + ), + ) noise_sd = 1 y = f_XW + rng.normal(0, noise_sd, size=n) test_set_pct = 0.2 - train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=1234) + train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 + ) X_train = X[train_inds, :] X_test = X[test_inds, :] y_train = y[train_inds] @@ -218,48 +216,65 @@ def test_bart_prediction(self): # Fit a "classic" BART model bart_model = BARTModel() - bart_model.sample(X_train = X_train, y_train = y_train, num_gfr=10, num_burnin=0, num_mcmc=10) + bart_model.sample( + X_train=X_train, y_train=y_train, num_gfr=10, num_burnin=0, num_mcmc=10 + ) # Check that the default predict method returns a dictionary pred = bart_model.predict(covariates=X_test) - y_hat_posterior_test = pred['y_hat'] + y_hat_posterior_test = pred["y_hat"] assert y_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean pred_mean = bart_model.predict(covariates=X_test, type="mean") - y_hat_mean_test = pred_mean['y_hat'] - np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) + y_hat_mean_test = pred_mean["y_hat"] + np.testing.assert_almost_equal( + y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) + ) # Fit a heteroskedastic BART model - var_params = {'num_trees': 20} + var_params = {"num_trees": 20} het_bart_model = BARTModel() het_bart_model.sample( - X_train = X_train, y_train = y_train, - num_gfr=10, num_burnin=0, num_mcmc=10, - variance_forest_params=var_params + X_train=X_train, + y_train=y_train, + num_gfr=10, + num_burnin=0, + num_mcmc=10, + variance_forest_params=var_params, ) # Check that the default predict method returns a dictionary pred = het_bart_model.predict(covariates=X_test) - y_hat_posterior_test = pred['y_hat'] - sigma2_hat_posterior_test = pred['variance_forest_predictions'] + y_hat_posterior_test = pred["y_hat"] + sigma2_hat_posterior_test = pred["variance_forest_predictions"] assert y_hat_posterior_test.shape == (20, 10) assert sigma2_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean pred_mean = het_bart_model.predict(covariates=X_test, type="mean") - y_hat_mean_test = pred_mean['y_hat'] - sigma2_hat_mean_test = pred_mean['variance_forest_predictions'] - np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) - np.testing.assert_almost_equal(sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1)) + y_hat_mean_test = pred_mean["y_hat"] + sigma2_hat_mean_test = pred_mean["variance_forest_predictions"] + np.testing.assert_almost_equal( + y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) + ) + np.testing.assert_almost_equal( + sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1) + ) # Check that the "single-term" pre-aggregated predictions # match those computed by pre-aggregated predictions returned in a dictionary - y_hat_mean_test_single_term = het_bart_model.predict(covariates=X_test, type="mean", terms="y_hat") - sigma2_hat_mean_test_single_term = het_bart_model.predict(covariates=X_test, type="mean", terms="variance_forest") + y_hat_mean_test_single_term = het_bart_model.predict( + covariates=X_test, type="mean", terms="y_hat" + ) + sigma2_hat_mean_test_single_term = het_bart_model.predict( + covariates=X_test, type="mean", terms="variance_forest" + ) np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) - np.testing.assert_almost_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) - + np.testing.assert_almost_equal( + sigma2_hat_mean_test, sigma2_hat_mean_test_single_term + ) + def test_bcf_prediction(self): # Generate data and test/train split rng = np.random.default_rng(1234) @@ -268,121 +283,138 @@ def test_bcf_prediction(self): x1 = rng.normal(size=n) x2 = rng.normal(size=n) x3 = rng.normal(size=n) - x4 = rng.binomial(n=1,p=0.5,size=(n,)) - x5 = rng.choice(a=[0,1,2], size=(n,), replace=True) - x4_cat = pd.Categorical(x4, categories=[0,1], ordered=True) - x5_cat = pd.Categorical(x4, categories=[0,1,2], ordered=True) + x4 = rng.binomial(n=1, p=0.5, size=(n,)) + x5 = rng.choice(a=[0, 1, 2], size=(n,), replace=True) + x4_cat = pd.Categorical(x4, categories=[0, 1], ordered=True) + x5_cat = pd.Categorical(x4, categories=[0, 1, 2], ordered=True) p = 5 - X = pd.DataFrame(data={ - "x1": pd.Series(x1), - "x2": pd.Series(x2), - "x3": pd.Series(x3), - "x4": pd.Series(x4_cat), - "x5": pd.Series(x5_cat) - }) + X = pd.DataFrame( + data={ + "x1": pd.Series(x1), + "x2": pd.Series(x2), + "x3": pd.Series(x3), + "x4": pd.Series(x4_cat), + "x5": pd.Series(x5_cat), + } + ) + def g(x5): - return np.where( - x5 == 0, 2.0, - np.where( - x5 == 1, -1.0, -4.0 - ) - ) + return np.where(x5 == 0, 2.0, np.where(x5 == 1, -1.0, -4.0)) + p = X.shape[1] - mu_x = 1.0 + g(x5) + x1*x3 - tau_x = 1.0 + 2*x2*x4 + mu_x = 1.0 + g(x5) + x1 * x3 + tau_x = 1.0 + 2 * x2 * x4 pi_x = ( - 0.8 * norm.cdf(3.0 * mu_x / np.squeeze(np.std(mu_x)) - 0.5 * x1) + - 0.05 + rng.uniform(low=0., high=0.1, size=(n,)) + 0.8 * norm.cdf(3.0 * mu_x / np.squeeze(np.std(mu_x)) - 0.5 * x1) + + 0.05 + + rng.uniform(low=0.0, high=0.1, size=(n,)) ) Z = rng.binomial(n=1, p=pi_x, size=(n,)) - E_XZ = mu_x + tau_x*Z + E_XZ = mu_x + tau_x * Z snr = 2 - y = E_XZ + rng.normal(loc=0., scale=np.std(E_XZ) / snr, size=(n,)) + y = E_XZ + rng.normal(loc=0.0, scale=np.std(E_XZ) / snr, size=(n,)) test_set_pct = 0.2 - train_inds, test_inds = train_test_split(np.arange(n), test_size=test_set_pct, random_state=1234) - X_train = X.iloc[train_inds,:] - X_test = X.iloc[test_inds,:] + train_inds, test_inds = train_test_split( + np.arange(n), test_size=test_set_pct, random_state=1234 + ) + X_train = X.iloc[train_inds, :] + X_test = X.iloc[test_inds, :] Z_train = Z[train_inds] Z_test = Z[test_inds] pi_x_train = pi_x[train_inds] pi_x_test = pi_x[test_inds] y_train = y[train_inds] y_test = y[test_inds] - + # Fit a "classic" BCF model bcf_model = BCFModel() bcf_model.sample( - X_train = X_train, - Z_train = Z_train, - y_train = y_train, - pi_train = pi_x_train, - X_test = X_test, - Z_test = Z_test, - pi_test = pi_x_test, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10 + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + pi_train=pi_x_train, + X_test=X_test, + Z_test=Z_test, + pi_test=pi_x_test, + num_gfr=10, + num_burnin=0, + num_mcmc=10, ) # Check that the default predict method returns a dictionary pred = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test) - y_hat_posterior_test = pred['y_hat'] + y_hat_posterior_test = pred["y_hat"] assert y_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean - pred_mean = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, type="mean") - y_hat_mean_test = pred_mean['y_hat'] - np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) + pred_mean = bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_x_test, type="mean" + ) + y_hat_mean_test = pred_mean["y_hat"] + np.testing.assert_almost_equal( + y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) + ) # Check that we warn and return None when requesting terms that weren't fit with pytest.warns(UserWarning): pred_mean = bcf_model.predict( - X=X_test, Z=Z_test, propensity=pi_x_test, - type="mean", terms=["rfx", "variance_forest"] + X=X_test, + Z=Z_test, + propensity=pi_x_test, + type="mean", + terms=["rfx", "variance_forest"], ) - + # Fit a heteroskedastic BCF model - var_params = {'num_trees': 20} + var_params = {"num_trees": 20} het_bcf_model = BCFModel() het_bcf_model.sample( - X_train = X_train, - Z_train = Z_train, - y_train = y_train, - pi_train = pi_x_train, - X_test = X_test, - Z_test = Z_test, - pi_test = pi_x_test, - num_gfr = 10, - num_burnin = 0, - num_mcmc = 10, - variance_forest_params=var_params + X_train=X_train, + Z_train=Z_train, + y_train=y_train, + pi_train=pi_x_train, + X_test=X_test, + Z_test=Z_test, + pi_test=pi_x_test, + num_gfr=10, + num_burnin=0, + num_mcmc=10, + variance_forest_params=var_params, ) # Check that the default predict method returns a dictionary pred = het_bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test) - y_hat_posterior_test = pred['y_hat'] - sigma2_hat_posterior_test = pred['variance_forest_predictions'] + y_hat_posterior_test = pred["y_hat"] + sigma2_hat_posterior_test = pred["variance_forest_predictions"] assert y_hat_posterior_test.shape == (20, 10) assert sigma2_hat_posterior_test.shape == (20, 10) # Check that the pre-aggregated predictions match with those computed by np.mean - pred_mean = het_bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_x_test, type="mean") - y_hat_mean_test = pred_mean['y_hat'] - sigma2_hat_mean_test = pred_mean['variance_forest_predictions'] - np.testing.assert_almost_equal(y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1)) - np.testing.assert_almost_equal(sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1)) + pred_mean = het_bcf_model.predict( + X=X_test, Z=Z_test, propensity=pi_x_test, type="mean" + ) + y_hat_mean_test = pred_mean["y_hat"] + sigma2_hat_mean_test = pred_mean["variance_forest_predictions"] + np.testing.assert_almost_equal( + y_hat_mean_test, np.mean(y_hat_posterior_test, axis=1) + ) + np.testing.assert_almost_equal( + sigma2_hat_mean_test, np.mean(sigma2_hat_posterior_test, axis=1) + ) # Check that the "single-term" pre-aggregated predictions # match those computed by pre-aggregated predictions returned in a dictionary y_hat_mean_test_single_term = het_bcf_model.predict( - X=X_test, Z=Z_test, propensity=pi_x_test, - type="mean", terms="y_hat" + X=X_test, Z=Z_test, propensity=pi_x_test, type="mean", terms="y_hat" ) sigma2_hat_mean_test_single_term = het_bcf_model.predict( - X=X_test, Z=Z_test, propensity=pi_x_test, - type="mean", terms="variance_forest" + X=X_test, + Z=Z_test, + propensity=pi_x_test, + type="mean", + terms="variance_forest", ) np.testing.assert_almost_equal(y_hat_mean_test, y_hat_mean_test_single_term) - np.testing.assert_almost_equal(sigma2_hat_mean_test, sigma2_hat_mean_test_single_term) - - + np.testing.assert_almost_equal( + sigma2_hat_mean_test, sigma2_hat_mean_test_single_term + ) diff --git a/test/python/test_preprocessor.py b/test/python/test_preprocessor.py index f40ef204..748fa18a 100644 --- a/test/python/test_preprocessor.py +++ b/test/python/test_preprocessor.py @@ -7,16 +7,14 @@ class TestPreprocessor: def test_numpy(self): cov_transformer = CovariatePreprocessor() - np_1 = np.array( - [ - [1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4], - ] - ) + np_1 = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) np_1_transformed = cov_transformer.fit_transform(np_1) np.testing.assert_array_equal(np_1, np_1_transformed) np.testing.assert_array_equal( @@ -24,23 +22,19 @@ def test_numpy(self): ) def test_pandas(self): - df_1 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": [8.7, 3.4, 1.2, 5.4, 9.3, 10.4], - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) - np_1 = np.array( - [ - [1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4], - ] - ) + df_1 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": [8.7, 3.4, 1.2, 5.4, 9.3, 10.4], + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) + np_1 = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) cov_transformer = CovariatePreprocessor() df_1_transformed = cov_transformer.fit_transform(df_1) np.testing.assert_array_equal(np_1, df_1_transformed) @@ -48,27 +42,23 @@ def test_pandas(self): cov_transformer._processed_feature_types, np.array([0, 0, 0]) ) - df_2 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=True, - categories=["c", "b", "a"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) - np_2 = np.array( - [ - [1.5, 2, 1.2], - [2.7, 1, 5.4], - [3.6, 0, 9.3], - [4.4, 2, 10.4], - [5.3, 1, 3.6], - [6.1, 0, 4.4], - ] - ) + df_2 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=True, + categories=["c", "b", "a"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) + np_2 = np.array([ + [1.5, 2, 1.2], + [2.7, 1, 5.4], + [3.6, 0, 9.3], + [4.4, 2, 10.4], + [5.3, 1, 3.6], + [6.1, 0, 4.4], + ]) cov_transformer = CovariatePreprocessor() df_2_transformed = cov_transformer.fit_transform(df_2) np.testing.assert_array_equal(np_2, df_2_transformed) @@ -76,27 +66,23 @@ def test_pandas(self): cov_transformer._processed_feature_types, np.array([0, 1, 0]) ) - df_3 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c"], - ordered=False, - categories=["c", "b", "a"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], - } - ) - np_3 = np.array( - [ - [1.5, 0, 0, 1, 1.2], - [2.7, 0, 1, 0, 5.4], - [3.6, 1, 0, 0, 9.3], - [4.4, 0, 0, 1, 10.4], - [5.3, 0, 1, 0, 3.6], - [6.1, 1, 0, 0, 4.4], - ] - ) + df_3 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c"], + ordered=False, + categories=["c", "b", "a"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4], + }) + np_3 = np.array([ + [1.5, 0, 0, 1, 1.2], + [2.7, 0, 1, 0, 5.4], + [3.6, 1, 0, 0, 9.3], + [4.4, 0, 0, 1, 10.4], + [5.3, 0, 1, 0, 3.6], + [6.1, 1, 0, 0, 4.4], + ]) cov_transformer = CovariatePreprocessor() df_3_transformed = cov_transformer.fit_transform(df_3) np.testing.assert_array_equal(np_3, df_3_transformed) @@ -104,28 +90,24 @@ def test_pandas(self): cov_transformer._processed_feature_types, np.array([0, 1, 1, 1, 0]) ) - df_4 = pd.DataFrame( - { - "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1, 7.6], - "x2": pd.Categorical( - ["a", "b", "c", "a", "b", "c", "c"], - ordered=False, - categories=["c", "b", "a", "d"], - ), - "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4, 3.4], - } - ) - np_4 = np.array( - [ - [1.5, 0, 0, 1, 0, 1.2], - [2.7, 0, 1, 0, 0, 5.4], - [3.6, 1, 0, 0, 0, 9.3], - [4.4, 0, 0, 1, 0, 10.4], - [5.3, 0, 1, 0, 0, 3.6], - [6.1, 1, 0, 0, 0, 4.4], - [7.6, 1, 0, 0, 0, 3.4], - ] - ) + df_4 = pd.DataFrame({ + "x1": [1.5, 2.7, 3.6, 4.4, 5.3, 6.1, 7.6], + "x2": pd.Categorical( + ["a", "b", "c", "a", "b", "c", "c"], + ordered=False, + categories=["c", "b", "a", "d"], + ), + "x3": [1.2, 5.4, 9.3, 10.4, 3.6, 4.4, 3.4], + }) + np_4 = np.array([ + [1.5, 0, 0, 1, 0, 1.2], + [2.7, 0, 1, 0, 0, 5.4], + [3.6, 1, 0, 0, 0, 9.3], + [4.4, 0, 0, 1, 0, 10.4], + [5.3, 0, 1, 0, 0, 3.6], + [6.1, 1, 0, 0, 0, 4.4], + [7.6, 1, 0, 0, 0, 3.4], + ]) cov_transformer = CovariatePreprocessor() df_4_transformed = cov_transformer.fit_transform(df_4) np.testing.assert_array_equal(np_4, df_4_transformed) diff --git a/test/python/test_residual.py b/test/python/test_residual.py index a0dd1c09..9be4eaa9 100644 --- a/test/python/test_residual.py +++ b/test/python/test_residual.py @@ -15,16 +15,14 @@ class TestResidual: def test_basis_update(self): # Create dataset - X = np.array( - [ - [1.5, 8.7, 1.2], - [2.7, 3.4, 5.4], - [3.6, 1.2, 9.3], - [4.4, 5.4, 10.4], - [5.3, 9.3, 3.6], - [6.1, 10.4, 4.4], - ] - ) + X = np.array([ + [1.5, 8.7, 1.2], + [2.7, 3.4, 5.4], + [3.6, 1.2, 9.3], + [4.4, 5.4, 10.4], + [5.3, 9.3, 3.6], + [6.1, 10.4, 4.4], + ]) W = np.array([[1], [1], [1], [1], [1], [1]]) n = X.shape[0] p = X.shape[1] diff --git a/test/python/test_utils.py b/test/python/test_utils.py index 5394802f..9b6da787 100644 --- a/test/python/test_utils.py +++ b/test/python/test_utils.py @@ -9,9 +9,9 @@ _check_matrix_square, _standardize_array_to_list, _standardize_array_to_np, - _expand_dims_1d, - _expand_dims_2d, - _expand_dims_2d_diag + _expand_dims_1d, + _expand_dims_2d, + _expand_dims_2d_diag, ) @@ -90,9 +90,11 @@ def test_standardize(self): array_np3 = np.array([8.2, 4.5, 3.8]) array_np4 = np.array([[8.2, 4.5, 3.8]]) nonconforming_array_np1 = np.array([[8.2, 4.5, 3.8], [1.6, 3.4, 7.6]]) - nonconforming_array_np2 = np.array( - [[8.2, 4.5, 3.8], [1.6, 3.4, 7.6], [1.6, 3.4, 7.6]] - ) + nonconforming_array_np2 = np.array([ + [8.2, 4.5, 3.8], + [1.6, 3.4, 7.6], + [1.6, 3.4, 7.6], + ]) non_array_1 = 100000000 non_array_2 = "a" @@ -117,7 +119,7 @@ def test_standardize(self): _ = _standardize_array_to_np(non_array_2) _ = _standardize_array_to_np(nonconforming_array_np1) _ = _standardize_array_to_np(nonconforming_array_np2) - + def test_array_conversion(self): scalar_1 = 1.5 scalar_2 = -2.5 @@ -125,11 +127,15 @@ def test_array_conversion(self): array_1d_1 = np.array([1.6, 3.4, 7.6, 8.7]) array_1d_2 = np.array([2.5, 3.1, 5.6]) array_1d_3 = np.array([2.5]) - array_2d_1 = np.array([[2.5,1.2,4.3,7.4],[1.7,2.9,3.6,9.1],[7.2,4.5,6.7,1.4]]) - array_2d_2 = np.array([[2.5,1.2,4.3,7.4],[1.7,2.9,3.6,9.1]]) - array_square_1 = np.array([[2.5,1.2],[1.7,2.9]]) - array_square_2 = np.array([[2.5,0.0],[0.0,2.9]]) - array_square_3 = np.array([[2.5,0.0,0.0],[0.0,2.9,0.0],[0.0,0.0,5.6]]) + array_2d_1 = np.array([ + [2.5, 1.2, 4.3, 7.4], + [1.7, 2.9, 3.6, 9.1], + [7.2, 4.5, 6.7, 1.4], + ]) + array_2d_2 = np.array([[2.5, 1.2, 4.3, 7.4], [1.7, 2.9, 3.6, 9.1]]) + array_square_1 = np.array([[2.5, 1.2], [1.7, 2.9]]) + array_square_2 = np.array([[2.5, 0.0], [0.0, 2.9]]) + array_square_3 = np.array([[2.5, 0.0, 0.0], [0.0, 2.9, 0.0], [0.0, 0.0, 5.6]]) with pytest.raises(ValueError): _ = _expand_dims_1d(array_1d_1, 5) _ = _expand_dims_1d(array_1d_2, 4) @@ -139,23 +145,91 @@ def test_array_conversion(self): _ = _expand_dims_2d_diag(array_square_1, 4) _ = _expand_dims_2d_diag(array_square_2, 3) _ = _expand_dims_2d_diag(array_square_3, 2) - - np.testing.assert_array_equal(np.array([scalar_1,scalar_1,scalar_1]), _expand_dims_1d(scalar_1, 3)) - np.testing.assert_array_equal(np.array([scalar_2,scalar_2,scalar_2,scalar_2]), _expand_dims_1d(scalar_2, 4)) - np.testing.assert_array_equal(np.array([scalar_3,scalar_3]), _expand_dims_1d(scalar_3, 2)) - np.testing.assert_array_equal(np.array([array_1d_3[0],array_1d_3[0],array_1d_3[0]]), _expand_dims_1d(array_1d_3, 3)) - - np.testing.assert_array_equal(np.array([[scalar_1,scalar_1,scalar_1],[scalar_1,scalar_1,scalar_1]]), _expand_dims_2d(scalar_1, 2, 3)) - np.testing.assert_array_equal(np.array([[scalar_2,scalar_2,scalar_2,scalar_2],[scalar_2,scalar_2,scalar_2,scalar_2]]), _expand_dims_2d(scalar_2, 2, 4)) - np.testing.assert_array_equal(np.array([[scalar_3,scalar_3],[scalar_3,scalar_3],[scalar_3,scalar_3]]), _expand_dims_2d(scalar_3, 3, 2)) - np.testing.assert_array_equal(np.array([[array_1d_3[0],array_1d_3[0]],[array_1d_3[0],array_1d_3[0]],[array_1d_3[0],array_1d_3[0]]]), _expand_dims_2d(array_1d_3, 3, 2)) - np.testing.assert_array_equal(np.vstack((array_1d_1, array_1d_1)), _expand_dims_2d(array_1d_1, 2, 4)) - np.testing.assert_array_equal(np.vstack((array_1d_2, array_1d_2, array_1d_2)), _expand_dims_2d(array_1d_2, 3, 3)) - np.testing.assert_array_equal(np.column_stack((array_1d_2, array_1d_2, array_1d_2, array_1d_2)), _expand_dims_2d(array_1d_2, 3, 4)) - np.testing.assert_array_equal(np.column_stack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)), _expand_dims_2d(array_1d_3, 1, 4)) - np.testing.assert_array_equal(np.vstack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)), _expand_dims_2d(array_1d_3, 4, 1)) - - np.testing.assert_array_equal(np.array([[scalar_1,0.0,0.0],[0.0,scalar_1,0.0],[0.0,0.0,scalar_1]]), _expand_dims_2d_diag(scalar_1, 3)) - np.testing.assert_array_equal(np.array([[scalar_2,0.0],[0.0,scalar_2]]), _expand_dims_2d_diag(scalar_2, 2)) - np.testing.assert_array_equal(np.array([[scalar_3,0.0,0.0,0.0],[0.0,scalar_3,0.0,0.0],[0.0,0.0,scalar_3,0.0],[0.0,0.0,0.0,scalar_3]]), _expand_dims_2d_diag(scalar_3, 4)) - np.testing.assert_array_equal(np.array([[array_1d_3[0],0.0],[0.0,array_1d_3[0]]]), _expand_dims_2d_diag(array_1d_3, 2)) + + np.testing.assert_array_equal( + np.array([scalar_1, scalar_1, scalar_1]), _expand_dims_1d(scalar_1, 3) + ) + np.testing.assert_array_equal( + np.array([scalar_2, scalar_2, scalar_2, scalar_2]), + _expand_dims_1d(scalar_2, 4), + ) + np.testing.assert_array_equal( + np.array([scalar_3, scalar_3]), _expand_dims_1d(scalar_3, 2) + ) + np.testing.assert_array_equal( + np.array([array_1d_3[0], array_1d_3[0], array_1d_3[0]]), + _expand_dims_1d(array_1d_3, 3), + ) + + np.testing.assert_array_equal( + np.array([[scalar_1, scalar_1, scalar_1], [scalar_1, scalar_1, scalar_1]]), + _expand_dims_2d(scalar_1, 2, 3), + ) + np.testing.assert_array_equal( + np.array([ + [scalar_2, scalar_2, scalar_2, scalar_2], + [scalar_2, scalar_2, scalar_2, scalar_2], + ]), + _expand_dims_2d(scalar_2, 2, 4), + ) + np.testing.assert_array_equal( + np.array([ + [scalar_3, scalar_3], + [scalar_3, scalar_3], + [scalar_3, scalar_3], + ]), + _expand_dims_2d(scalar_3, 3, 2), + ) + np.testing.assert_array_equal( + np.array([ + [array_1d_3[0], array_1d_3[0]], + [array_1d_3[0], array_1d_3[0]], + [array_1d_3[0], array_1d_3[0]], + ]), + _expand_dims_2d(array_1d_3, 3, 2), + ) + np.testing.assert_array_equal( + np.vstack((array_1d_1, array_1d_1)), _expand_dims_2d(array_1d_1, 2, 4) + ) + np.testing.assert_array_equal( + np.vstack((array_1d_2, array_1d_2, array_1d_2)), + _expand_dims_2d(array_1d_2, 3, 3), + ) + np.testing.assert_array_equal( + np.column_stack((array_1d_2, array_1d_2, array_1d_2, array_1d_2)), + _expand_dims_2d(array_1d_2, 3, 4), + ) + np.testing.assert_array_equal( + np.column_stack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)), + _expand_dims_2d(array_1d_3, 1, 4), + ) + np.testing.assert_array_equal( + np.vstack((array_1d_3, array_1d_3, array_1d_3, array_1d_3)), + _expand_dims_2d(array_1d_3, 4, 1), + ) + + np.testing.assert_array_equal( + np.array([ + [scalar_1, 0.0, 0.0], + [0.0, scalar_1, 0.0], + [0.0, 0.0, scalar_1], + ]), + _expand_dims_2d_diag(scalar_1, 3), + ) + np.testing.assert_array_equal( + np.array([[scalar_2, 0.0], [0.0, scalar_2]]), + _expand_dims_2d_diag(scalar_2, 2), + ) + np.testing.assert_array_equal( + np.array([ + [scalar_3, 0.0, 0.0, 0.0], + [0.0, scalar_3, 0.0, 0.0], + [0.0, 0.0, scalar_3, 0.0], + [0.0, 0.0, 0.0, scalar_3], + ]), + _expand_dims_2d_diag(scalar_3, 4), + ) + np.testing.assert_array_equal( + np.array([[array_1d_3[0], 0.0], [0.0, array_1d_3[0]]]), + _expand_dims_2d_diag(array_1d_3, 2), + )