diff --git a/R/boost_tree.R b/R/boost_tree.R
index cce9d1ccd..4698ec53a 100644
--- a/R/boost_tree.R
+++ b/R/boost_tree.R
@@ -232,8 +232,31 @@ xgb_train <- function(
   num_class <- length(levels(y))
 
-  if (!is.numeric(validation) || validation < 0 || validation >= 1) {
-    rlang::abort("`validation` should be on [0, 1).")
+  if (is.data.frame(validation)) {
+    if (is.matrix(y) | is.data.frame(y) | is.numeric(y)) {
+      if (is.null(colnames(y))) {
+        rlang::abort("`y` must be named when `validation` is a dataframe")
+      } else if (!(colnames(y) %in% colnames(validation))) {
+        wrong_col <- colnames(y)
+        rlang::abort(paste0("`", wrong_col, "`", " column not found in `validation`"))
+      }
+    } else {
+      if (is.null(attr(y, "col_name"))) {
+        rlang::abort("`y` must be named when `validation` is a dataframe")
+      } else if (!(attr(y, "col_name") %in% colnames(validation))) {
+        wrong_col <- attr(y, "col_name")
+        rlang::abort(paste0("`", wrong_col, "`", " column not found in `validation`"))
+      }
+    }
+
+    if (!all(colnames(x) %in% colnames(validation))) {
+      missing_cols <- colnames(x)[which(!(colnames(x) %in% colnames(validation)))]
+      missing_cols_txt <- paste0("`", missing_cols, "`", collapse = ",")
+      rlang::abort(glue::glue("`validation` is missing column(s): {missing_cols_txt}"))
+    }
+
+  } else if (!is.numeric(validation) || validation < 0 || validation >= 1) {
+    rlang::abort("`validation` should be on [0, 1).")
   }
 
   if (!is.null(early_stop)) {
@@ -409,47 +432,120 @@ xgb_predict <- function(object, new_data, ...) {
 
 as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) {
   lvls <- levels(y)
   n <- nrow(x)
+  y_is_factor <- is.factor(y)
 
   if (is.data.frame(x)) {
     x <- as.matrix(x)
   }
 
-  if (is.factor(y)) {
+  if (y_is_factor) {
+
+    y_col_name <- attr(y, "col_name")
+
     if (length(lvls) < 3) {
       if (event_level == "first") {
         y <- -as.numeric(y) + 2
+        y <- as.matrix(y)
+        colnames(y) <- y_col_name
+      } else {
+        y <- as.numeric(y) - 1
+        y <- as.matrix(y)
+        colnames(y) <- y_col_name
       }
     } else {
       if (event_level == "second")
         rlang::warn("`event_level` can only be set for binary variables.")
+      y <- as.numeric(y) - 1
+      y <- as.matrix(y)
+      colnames(y) <- y_col_name
     }
   }
 
   if (!inherits(x, "xgb.DMatrix")) {
-    if (validation > 0) {
-      # Split data
-      m <- floor(n * (1 - validation)) + 1
-      trn_index <- sample(1:n, size = max(m, 2))
-      val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
-      watch_list <- list(validation = val_data)
-      info_list <- list(label = y[trn_index])
-      if (!is.null(weights)) {
-        info_list$weight <- weights[trn_index]
+    if (is.numeric(validation)) {
+
+      if (validation > 0) {
+
+        # get splits index
+        m <- floor(n * (1 - validation)) + 1
+        trn_index <- sample(1:n, size = max(m, 2))
+        info_list <- list(label = y[trn_index])
+
+        val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
+        watch_list <- list(validation = val_data)
+
+        dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)
+
+      } else {
+
+        info_list <- list(label = y)
+
+        if (!is.null(weights)) {
+          info_list$weight <- weights
+        }
+
+        dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
+        watch_list <- list(training = dat)
+
       }
-      dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)
 
+    } else if (is.data.frame(validation)) {
+
+      predictor_cols <- which(colnames(validation) %in% colnames(x))
+      y_index <- which(colnames(validation) %in% colnames(y))
+
+      if (y_is_factor) {
+
+        y_val <- validation[, y_index, drop = T]
+
+        if (length(lvls) < 3) {
+          if (event_level == "first") {
+
+            y_val <- -as.numeric(y_val) + 2
+            y_val <- as.matrix(y_val)
+            colnames(y_val) <- y_col_name
+          } else {
+
+            y_val <- as.numeric(y_val) - 1
+            y_val <- as.matrix(y_val)
+            colnames(y_val) <- y_col_name
+          }
+        } else {
+          if (event_level == "second")
+            rlang::warn("`event_level` can only be set for binary variables. More than two outcome classes in `validation`")
+
+          y_val <- as.numeric(y_val) - 1
+          y_val <- as.matrix(y_val)
+          colnames(y_val) <- y_col_name
+        }
+
+        validation <- as.matrix(validation[, predictor_cols])
+        rownames(validation) <- NULL
+        val_info_list <- list(label = y_val)
+        val_data <- xgboost::xgb.DMatrix(validation, missing = NA, info = val_info_list)
+
+      } else {
+
+        validation <- as.matrix(validation)
+        rownames(validation) <- NULL
+        val_info_list <- list(label = validation[, y_index])
+        val_data <- xgboost::xgb.DMatrix(validation[, predictor_cols], missing = NA, info = val_info_list)
+
+      }
 
-    } else {
       info_list <- list(label = y)
+
       if (!is.null(weights)) {
         info_list$weight <- weights
       }
+
       dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
-      watch_list <- list(training = dat)
+      watch_list <- list(validation = val_data)
+    }
+
   } else {
     dat <- xgboost::setinfo(x, "label", y)
     if (!is.null(weights)) {
diff --git a/R/convert_data.R b/R/convert_data.R
index ef8fa0673..e37bbed5d 100644
--- a/R/convert_data.R
+++ b/R/convert_data.R
@@ -102,9 +102,11 @@
       remove_intercept = remove_intercept
     )
+
   if (composition == "data.frame") {
     if (is.matrix(y)) {
       y <- as.data.frame(y)
+      colnames(y) <- all.vars(formula[[2]])
     }
     res <- list(
@@ -117,11 +119,12 @@
       options = options
     )
   } else {
-    # Since a matrix is requested, try to convert y but check
-    # to see if it is possible
+
     if (will_make_matrix(y)) {
       y <- as.matrix(y)
+      colnames(y) <- all.vars(formula[[2]])
     }
+
     res <- list(
       x = x,
@@ -325,7 +328,7 @@ make_formula <- function(x, y, short = TRUE) {
 
 will_make_matrix <- function(y) {
   if (is.matrix(y) | is.vector(y))
-    return(FALSE)
+    return(TRUE)
   cls <- unique(unlist(lapply(y, class)))
   if (length(cls) > 1)
     return(FALSE)
diff --git a/R/fit.R b/R/fit.R
index 6cda2e2c0..b2d11d4b7 100644
--- a/R/fit.R
+++ b/R/fit.R
@@ -255,11 +255,23 @@ fit_xy.model_spec <-
     }
 
     if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) {
+
+      y_col_name <- colnames(y)
+
       if (is.matrix(y)) {
         y <- y[, 1]
+      } else if (!is.null(colnames(y)) && is.numeric(y[, 1, drop = T])) {
+        # preserves colname of y
+        y <- as.matrix(y)
       } else {
+        # strips colname of y
        y <- y[[1]]
       }
+
+      if (object$engine == "xgboost" && object$mode == "classification" && is.factor(y)) {
+
+        attr(y, "col_name") <- y_col_name
+      }
     }
 
     cl <- match.call(expand.dots = TRUE)
@@ -411,6 +423,7 @@ check_xy_interface <- function(x, y, cl, model) {
   }
 
   # `y` can be a vector (which is not a class), or a factor (which is not a vector)
+
   if (!is.null(y) && !is.vector(y))
     inher(y, c("data.frame", "matrix", "factor"), cl)
diff --git a/R/fit_helpers.R b/R/fit_helpers.R
index d4fbdf6b8..70613b476 100644
--- a/R/fit_helpers.R
+++ b/R/fit_helpers.R
@@ -147,6 +147,10 @@ form_xy <- function(object, control, env,
 
   check_outcome(env$y, object)
 
+  if (object$engine == "xgboost" && object$mode == "classification" && is.factor(env$y)) {
+    attr(env$y, "col_name") <- all.vars(env$formula[[2]])
+  }
+
   res <- xy_xy(
     object = object,
     env = env, #weights!
diff --git a/tests/testthat/test_convert_data.R b/tests/testthat/test_convert_data.R
index e84bb4812..03fd9a05e 100644
--- a/tests/testthat/test_convert_data.R
+++ b/tests/testthat/test_convert_data.R
@@ -378,7 +378,11 @@ test_that("numeric x and y, matrix composition", {
     remove_intercept = TRUE
   )
   expect_equal(format_x_for_test(expected$x, df = FALSE), observed$x)
-  expect_equal(mtcars$mpg, observed$y)
+
+  expected_y <- as.matrix(mtcars$mpg)
+  names(expected_y) <- NULL
+  colnames(expected_y) <- "mpg"
+  expect_equal(expected_y, observed$y)
 
   new_obs <- .convert_form_to_xy_new(observed,
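
Usage sketch (illustrative only): with this change, the xgboost engine's `validation` argument can be either the existing proportion on [0, 1) or a data frame that carries every predictor column plus the named outcome column. The `car_train`/`car_val` split below is a hypothetical placeholder, not part of the patch.

library(parsnip)

# hypothetical train/validation split of a built-in data set
car_train <- mtcars[1:25, ]
car_val   <- mtcars[26:32, ]  # must contain all predictor columns and the outcome (`mpg`)

boost_tree(mode = "regression", trees = 200) %>%
  set_engine("xgboost", validation = car_val) %>%  # previously only a proportion on [0, 1) was accepted
  fit(mpg ~ ., data = car_train)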