diff --git a/R/aaa_models.R b/R/aaa_models.R index 8235bddf7..2a40e3d11 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -286,9 +286,13 @@ check_pred_info <- function(pred_obj, type) { invisible(NULL) } -check_spec_pred_type <- function(object, type) { +spec_has_pred_type <- function(object, type) { possible_preds <- names(object$spec$method$pred) - if (!any(possible_preds == type)) { + any(possible_preds == type) +} +check_spec_pred_type <- function(object, type) { + if (!spec_has_pred_type(object, type)) { + possible_preds <- names(object$spec$method$pred) rlang::abort(c( glue::glue("No {type} prediction method available for this model."), glue::glue("Value for `type` should be one of: ", diff --git a/R/augment.R b/R/augment.R index 434e2b69b..fa264f69e 100644 --- a/R/augment.R +++ b/R/augment.R @@ -6,8 +6,9 @@ #' [fit()] and `new_data` contains the outcome column, a `.resid` column is #' also added. #' -#' For classification models, the results include a column called `.pred_class` -#' as well as class probability columns named `.pred_{level}`. +#' For classification models, the results can include a column called +#' `.pred_class` as well as class probability columns named `.pred_{level}`. +#' This depends on what type of prediction types are available for the model. #' @param x A `model_fit` object produced by [fit()] or [fit_xy()]. #' @param new_data A data frame or matrix. #' @param ... Not currently used. @@ -56,6 +57,7 @@ #' augment.model_fit <- function(x, new_data, ...) { if (x$spec$mode == "regression") { + check_spec_pred_type(x, "numeric") new_data <- new_data %>% dplyr::bind_cols( @@ -68,12 +70,18 @@ augment.model_fit <- function(x, new_data, ...) { } } } else if (x$spec$mode == "classification") { - new_data <- - new_data %>% - dplyr::bind_cols( - predict(x, new_data = new_data, type = "class"), + if (spec_has_pred_type(x, "class")) { + new_data <- dplyr::bind_cols( + new_data, + predict(x, new_data = new_data, type = "class") + ) + } + if (spec_has_pred_type(x, "prob")) { + new_data <- dplyr::bind_cols( + new_data, predict(x, new_data = new_data, type = "prob") ) + } } else { rlang::abort(paste("Unknown mode:", x$spec$mode)) } diff --git a/man/augment.Rd b/man/augment.Rd index 9bda10cec..fe3bd4982 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -21,8 +21,9 @@ For regression models, a \code{.pred} column is added. If \code{x} was created u \code{\link[=fit]{fit()}} and \code{new_data} contains the outcome column, a \code{.resid} column is also added. -For classification models, the results include a column called \code{.pred_class} -as well as class probability columns named \verb{.pred_\{level\}}. +For classification models, the results can include a column called +\code{.pred_class} as well as class probability columns named \verb{.pred_\{level\}}. +This depends on what type of prediction types are available for the model. } \examples{ car_trn <- mtcars[11:32,] diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index 9ed9eb65c..c2c545c0e 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -76,3 +76,18 @@ test_that('classification models', { }) + +test_that('augment for model without class probabilities', { + skip_if_not_installed("LiblineaR") + + data(two_class_dat, package = "modeldata") + x <- svm_linear(mode = "classification") %>% set_engine("LiblineaR") + cls_form <- x %>% fit(Class ~ ., data = two_class_dat) + + expect_equal( + colnames(augment(cls_form, head(two_class_dat))), + c("A", "B", "Class", ".pred_class") + ) + expect_equal(nrow(augment(cls_form, head(two_class_dat))), 6) + +})