From 11282db2438bb71624b953d627c449e2cd7c0e75 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Wed, 20 Jul 2022 12:00:05 -0500 Subject: [PATCH 01/10] allow user to pass dataframe to `validation` argument of xgboost. Preserve weights for internal validation set. --- R/boost_tree.R | 63 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/R/boost_tree.R b/R/boost_tree.R index cf12e5dea..52cb7e32a 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -235,7 +235,12 @@ xgb_train <- function( num_class <- length(levels(y)) - if (!is.numeric(validation) || validation < 0 || validation >= 1) { + if(is.data.frame(validation)) { + if(length(colnames(validation)) != length(colnames(x))+1){ + msg <- paste0("`validation` should contain ", length(colnames(x))+1, " columns") + rlang::abort(msg) + } + } else if (!is.numeric(validation) || validation < 0 || validation >= 1) { rlang::abort("`validation` should be on [0, 1).") } @@ -399,19 +404,54 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir } if (!inherits(x, "xgb.DMatrix")) { - if (validation > 0) { - # Split data - m <- floor(n * (1 - validation)) + 1 - trn_index <- sample(1:n, size = max(m, 2)) - val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) - watch_list <- list(validation = val_data) - info_list <- list(label = y[trn_index]) - if (!is.null(weights)) { - info_list$weight <- weights[trn_index] + if (is.numeric(validation)) { + + if (validation > 0) { + + # get splits index + m <- floor(n * (1 - validation)) + 1 + trn_index <- sample(1:n, size = max(m, 2)) + info_list <- list(label = y[trn_index]) + + if (!is.null(weights)){ + + val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], weight = weights[-trn_index], missing = NA) + watch_list <- list(validation = val_data) + + info_list$weight <- weights[trn_index] + dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) + + } else + + val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) + watch_list <- list(validation = val_data) + + dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) + } + + } else if (is.data.frame(validation)) { + + validation <- as.matrix(validation) + # Assuming whichever column is not present in x is the outcome + # Not ideal bc validation could contain abritarty columns + # Would need the colname of `y` + y_index <- which(!(colnames(validation) %in% colnames(x))) + + val_info_list <- list(label = validation[,y_index]) + + check_weights <- sapply(validation, hardhat::is_case_weights) + + if (any(check_weights)) { + weights_col_num <- which(check_weights) + val_info_list$weight <- validation[, weights_col_num, drop = T] + val_data <- xgboost::xgb.DMatrix(validation[,-y_index], missing = NA, info = val_info_list) } - dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) + val_data <- xgboost::xgb.DMatrix(validation[,-y_index], label = validation[,y_index], missing = NA) + watch_list <- list(validation = val_data) + + dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA) } else { info_list <- list(label = y) @@ -421,6 +461,7 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list) watch_list <- list(training = dat) } + } else { dat <- xgboost::setinfo(x, "label", y) if (!is.null(weights)) { From 940502bbdf17c2eaaeb5125193381be8fdb49985 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Wed, 20 Jul 2022 12:00:14 -0500 Subject: [PATCH 02/10] typo --- R/boost_tree.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/boost_tree.R b/R/boost_tree.R index 52cb7e32a..b4723df93 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -434,7 +434,7 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir validation <- as.matrix(validation) # Assuming whichever column is not present in x is the outcome - # Not ideal bc validation could contain abritarty columns + # Not ideal bc validation could contain arbitrary column that isn't the intended outcome # Would need the colname of `y` y_index <- which(!(colnames(validation) %in% colnames(x))) From 4338a6c99890520bb54632d758f9e3393711154a Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Thu, 11 Aug 2022 11:47:40 -0500 Subject: [PATCH 03/10] remove name atr of vectors --- R/convert_data.R | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/R/convert_data.R b/R/convert_data.R index ef8fa0673..35c5c1516 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -102,9 +102,11 @@ remove_intercept = remove_intercept ) + if (composition == "data.frame") { if (is.matrix(y)) { y <- as.data.frame(y) + colnames(y) <- all.vars(formula[[2]]) } res <- list( @@ -117,11 +119,14 @@ options = options ) } else { - # Since a matrix is requested, try to convert y but check - # to see if it is possible + if (will_make_matrix(y)) { y <- as.matrix(y) + colnames(y) <- all.vars(formula[[2]]) + } else { + attr(y, "colnames") <- all.vars(formula[[2]]) } + res <- list( x = x, @@ -325,7 +330,7 @@ make_formula <- function(x, y, short = TRUE) { will_make_matrix <- function(y) { if (is.matrix(y) | is.vector(y)) - return(FALSE) + return(TRUE) cls <- unique(unlist(lapply(y, class))) if (length(cls) > 1) return(FALSE) From e06547e88a0e3ea6fac1e6aba0fc9d668622360e Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Thu, 11 Aug 2022 11:48:33 -0500 Subject: [PATCH 04/10] preserve colnames of y when y is provided as data.frame with numeric only --- R/fit.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/R/fit.R b/R/fit.R index 6cda2e2c0..69f54ebc9 100644 --- a/R/fit.R +++ b/R/fit.R @@ -257,7 +257,11 @@ fit_xy.model_spec <- if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) { if (is.matrix(y)) { y <- y[, 1] + } else if (!is.null(colnames(y))){ + # preserves colname of y + y <- as.matrix(y) } else { + #strips colname of y y <- y[[1]] } } @@ -411,6 +415,7 @@ check_xy_interface <- function(x, y, cl, model) { } # `y` can be a vector (which is not a class), or a factor (which is not a vector) + if (!is.null(y) && !is.vector(y)) inher(y, c("data.frame", "matrix", "factor"), cl) From b3e249d93ae270a57c9af2a7400654e772d17362 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Thu, 11 Aug 2022 11:49:13 -0500 Subject: [PATCH 05/10] remove attempt to add attributes to vectors --- R/convert_data.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/convert_data.R b/R/convert_data.R index 35c5c1516..e37bbed5d 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -123,8 +123,6 @@ if (will_make_matrix(y)) { y <- as.matrix(y) colnames(y) <- all.vars(formula[[2]]) - } else { - attr(y, "colnames") <- all.vars(formula[[2]]) } res <- From 5940ef70be691fff1e1e568c0687115ad88a02a9 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Thu, 11 Aug 2022 11:53:03 -0500 Subject: [PATCH 06/10] * add improved error handling * handles case where validation has additional columns * add error for when `y` is vector and validation is a dataframe - fit_xy() issue. --- R/boost_tree.R | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/R/boost_tree.R b/R/boost_tree.R index b4723df93..dd2a60938 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -236,9 +236,17 @@ xgb_train <- function( num_class <- length(levels(y)) if(is.data.frame(validation)) { - if(length(colnames(validation)) != length(colnames(x))+1){ - msg <- paste0("`validation` should contain ", length(colnames(x))+1, " columns") - rlang::abort(msg) + if(is.null(colnames(y)) | is.vector(y) && is.null(attr(y, 'colnames'))){ + rlang::abort("`y` must be named when `validation` is a dataframe") + } else if (!(colnames(y) %in% colnames(validation))){ + wrong_col <- colnames(y) + rlang::abort(paste0("`",wrong_col,"`", " column not found in `validation`")) + } else if (!all(colnames(x) %in% colnames(validation))){ + missing_cols <- colnames(x)[which(!(colnames(x) %in% colnames(validation)))] + missing_cols_txt <- paste0("`", missing_cols, "`", collapse = ",") + + rlang::abort(glue::glue("`validation` is missing column(s): {missing_cols_txt}")) + } } else if (!is.numeric(validation) || validation < 0 || validation >= 1) { rlang::abort("`validation` should be on [0, 1).") @@ -424,7 +432,7 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir } else - val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) + val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) watch_list <- list(validation = val_data) dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) @@ -432,11 +440,11 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir } else if (is.data.frame(validation)) { + predictor_cols <- which(colnames(validation) %in% colnames(x)) + validation <- as.matrix(validation) - # Assuming whichever column is not present in x is the outcome - # Not ideal bc validation could contain arbitrary column that isn't the intended outcome - # Would need the colname of `y` - y_index <- which(!(colnames(validation) %in% colnames(x))) + + y_index <- which(colnames(validation) %in% colnames(y)) val_info_list <- list(label = validation[,y_index]) @@ -445,10 +453,10 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir if (any(check_weights)) { weights_col_num <- which(check_weights) val_info_list$weight <- validation[, weights_col_num, drop = T] - val_data <- xgboost::xgb.DMatrix(validation[,-y_index], missing = NA, info = val_info_list) + val_data <- xgboost::xgb.DMatrix(validation[,predictor_cols], missing = NA, info = val_info_list) } - val_data <- xgboost::xgb.DMatrix(validation[,-y_index], label = validation[,y_index], missing = NA) + val_data <- xgboost::xgb.DMatrix(validation[,predictor_cols], label = validation[,y_index], missing = NA) watch_list <- list(validation = val_data) dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA) From a46404c00dd1dc1ee99711e9c332df880fd42b0a Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Mon, 22 Aug 2022 08:11:59 -0500 Subject: [PATCH 07/10] support for dataframe as arg for validation --- R/boost_tree.R | 117 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 82 insertions(+), 35 deletions(-) diff --git a/R/boost_tree.R b/R/boost_tree.R index dd2a60938..62d7b0797 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -236,20 +236,30 @@ xgb_train <- function( num_class <- length(levels(y)) if(is.data.frame(validation)) { - if(is.null(colnames(y)) | is.vector(y) && is.null(attr(y, 'colnames'))){ - rlang::abort("`y` must be named when `validation` is a dataframe") - } else if (!(colnames(y) %in% colnames(validation))){ - wrong_col <- colnames(y) - rlang::abort(paste0("`",wrong_col,"`", " column not found in `validation`")) - } else if (!all(colnames(x) %in% colnames(validation))){ - missing_cols <- colnames(x)[which(!(colnames(x) %in% colnames(validation)))] - missing_cols_txt <- paste0("`", missing_cols, "`", collapse = ",") - - rlang::abort(glue::glue("`validation` is missing column(s): {missing_cols_txt}")) + if(is.matrix(y) | is.data.frame(y) | is.numeric(y)){ + if (is.null(colnames(y))){ + rlang::abort("`y` must be named when `validation` is a dataframe") + } else if (!(colnames(y) %in% colnames(validation))){ + wrong_col <- colnames(y) + rlang::abort(paste0("`",wrong_col,"`", " column not found in `validation`")) + } + } else { + if (is.null(attr(y, "col_name"))) { + rlang::abort("`y` must be named when `validation` is a dataframe") + } else if (!(attr(y, "col_name") %in% colnames(validation))) { + wrong_col <- attr(y, "col_name") + rlang::abort(paste0("`",wrong_col,"`", " column not found in `validation`")) + } + } + if (!all(colnames(x) %in% colnames(validation))){ + missing_cols <- colnames(x)[which(!(colnames(x) %in% colnames(validation)))] + missing_cols_txt <- paste0("`", missing_cols, "`", collapse = ",") + rlang::abort(glue::glue("`validation` is missing column(s): {missing_cols_txt}")) } + } else if (!is.numeric(validation) || validation < 0 || validation >= 1) { - rlang::abort("`validation` should be on [0, 1).") + rlang::abort("`validation` should be on [0, 1).") } if (!is.null(early_stop)) { @@ -393,21 +403,34 @@ xgb_predict <- function(object, new_data, ...) { as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) { lvls <- levels(y) n <- nrow(x) + y_is_factor <- is.factor(y) if (is.data.frame(x)) { x <- as.matrix(x) } - if (is.factor(y)) { + if (y_is_factor) { + + y_col_name <- attr(y, "col_name") + if (length(lvls) < 3) { if (event_level == "first") { y <- -as.numeric(y) + 2 + y <- as.matrix(y) + colnames(y) <- y_col_name + } else { + y <- as.numeric(y) - 1 + y <- as.matrix(y) + colnames(y) <- y_col_name } } else { if (event_level == "second") rlang::warn("`event_level` can only be set for binary variables.") + y <- as.numeric(y) - 1 + y <- as.matrix(y) + colnames(y) <- y_col_name } } @@ -422,52 +445,76 @@ as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "fir trn_index <- sample(1:n, size = max(m, 2)) info_list <- list(label = y[trn_index]) - if (!is.null(weights)){ + val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) + watch_list <- list(validation = val_data) + + dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) - val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], weight = weights[-trn_index], missing = NA) - watch_list <- list(validation = val_data) + } else { - info_list$weight <- weights[trn_index] - dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) + info_list <- list(label = y) - } else + if (!is.null(weights)) { + info_list$weight <- weights + } - val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA) - watch_list <- list(validation = val_data) + dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list) + watch_list <- list(training = dat) - dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list) } } else if (is.data.frame(validation)) { predictor_cols <- which(colnames(validation) %in% colnames(x)) + y_index <- which(colnames(validation) %in% colnames(y)) - validation <- as.matrix(validation) + if (y_is_factor){ - y_index <- which(colnames(validation) %in% colnames(y)) + y_val <- validation[,y_index, drop = T] - val_info_list <- list(label = validation[,y_index]) + if (length(lvls) < 3) { + if (event_level == "first") { - check_weights <- sapply(validation, hardhat::is_case_weights) + y_val <- -as.numeric(y_val) + 2 + y_val <- as.matrix(y_val) + colnames(y_val) <- y_col_name + } else { - if (any(check_weights)) { - weights_col_num <- which(check_weights) - val_info_list$weight <- validation[, weights_col_num, drop = T] - val_data <- xgboost::xgb.DMatrix(validation[,predictor_cols], missing = NA, info = val_info_list) - } + y_val <- as.numeric(y_val) - 1 + y_val <- as.matrix(y_val) + colnames(y_val) <- y_col_name + } + } else { + if (event_level == "second") rlang::warn("`event_level` can only be set for binary variables. More than two outcome classes in `validation`") - val_data <- xgboost::xgb.DMatrix(validation[,predictor_cols], label = validation[,y_index], missing = NA) - watch_list <- list(validation = val_data) + y_val <- as.numeric(y_val) - 1 + y_val <- as.matrix(y_val) + colnames(y_val) <- y_col_name + } - dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA) + validation <- as.matrix(validation[,predictor_cols]) + rownames(validation) <- NULL + val_info_list <- list(label = y_val) + val_data <- xgboost::xgb.DMatrix(validation, missing = NA, info = val_info_list) + + } else { + + validation <- as.matrix(validation) + rownames(validation) <- NULL + val_info_list <- list(label = validation[,y_index]) + val_data <- xgboost::xgb.DMatrix(validation[,predictor_cols], missing = NA, info = val_info_list) + + } - } else { info_list <- list(label = y) + if (!is.null(weights)) { info_list$weight <- weights } + dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list) - watch_list <- list(training = dat) + watch_list <- list(validation = val_data) + } } else { From 9037d8744d9c5e71df52d251af67e2cdd16b9f55 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Mon, 22 Aug 2022 08:13:09 -0500 Subject: [PATCH 08/10] update logic to preserve colnames of y --- R/fit.R | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/R/fit.R b/R/fit.R index 69f54ebc9..b2d11d4b7 100644 --- a/R/fit.R +++ b/R/fit.R @@ -255,15 +255,23 @@ fit_xy.model_spec <- } if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) { + + y_col_name <- colnames(y) + if (is.matrix(y)) { y <- y[, 1] - } else if (!is.null(colnames(y))){ + } else if (!is.null(colnames(y)) && is.numeric(y[,1,drop=T])) { # preserves colname of y y <- as.matrix(y) } else { #strips colname of y y <- y[[1]] } + + if(object$engine == "xgboost" && object$mode == "classification" && is.factor(y)){ + + attr(y, "col_name") <- y_col_name + } } cl <- match.call(expand.dots = TRUE) From 10e513313952225ba5df08ea74d0b1ec5764ada0 Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Mon, 22 Aug 2022 08:13:35 -0500 Subject: [PATCH 09/10] preserve colname of y when y is a factor --- R/fit_helpers.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index d4fbdf6b8..70613b476 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -147,6 +147,10 @@ form_xy <- function(object, control, env, check_outcome(env$y, object) + if(object$engine == "xgboost" && object$mode == "classification" && is.factor(env$y)){ + attr(env$y, "col_name") <- all.vars(env$formula[[2]]) + } + res <- xy_xy( object = object, env = env, #weights! From c175d3e4899753e19bae695c2f31bd6a6974e15f Mon Sep 17 00:00:00 2001 From: Joey Couse <54423399+joeycouse@users.noreply.github.com> Date: Mon, 22 Aug 2022 12:51:30 -0500 Subject: [PATCH 10/10] update test to expect named numeric matrix --- tests/testthat/test_convert_data.R | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test_convert_data.R b/tests/testthat/test_convert_data.R index e84bb4812..03fd9a05e 100644 --- a/tests/testthat/test_convert_data.R +++ b/tests/testthat/test_convert_data.R @@ -378,7 +378,11 @@ test_that("numeric x and y, matrix composition", { remove_intercept = TRUE ) expect_equal(format_x_for_test(expected$x, df = FALSE), observed$x) - expect_equal(mtcars$mpg, observed$y) + + expected_y <- as.matrix(mtcars$mpg) + names(expected_y) <- NULL + colnames(expected_y) <- "mpg" + expect_equal(expected_y, observed$y) new_obs <- .convert_form_to_xy_new(observed,