From fcb10712eb2535eb0c4fc52178e1f1f9811dc6e4 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 11 May 2021 14:53:23 -0600 Subject: [PATCH 1/4] tryCatch for class probabilities in `augment()` --- R/augment.R | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/R/augment.R b/R/augment.R index 434e2b69b..09ea39977 100644 --- a/R/augment.R +++ b/R/augment.R @@ -68,12 +68,17 @@ augment.model_fit <- function(x, new_data, ...) { } } } else if (x$spec$mode == "classification") { - new_data <- - new_data %>% - dplyr::bind_cols( - predict(x, new_data = new_data, type = "class"), + new_data <- dplyr::bind_cols( + new_data, + predict(x, new_data = new_data, type = "class") + ) + tryCatch( + new_data <- dplyr::bind_cols( + new_data, predict(x, new_data = new_data, type = "prob") - ) + ), + error = function(cnd) cnd + ) } else { rlang::abort(paste("Unknown mode:", x$spec$mode)) } From 1e525bb4fe8ee2f0488f86d047f5fab6a46306c6 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 11 May 2021 14:58:34 -0600 Subject: [PATCH 2/4] Test that can augment models without class probability support --- tests/testthat/test-augment.R | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/testthat/test-augment.R b/tests/testthat/test-augment.R index 9ed9eb65c..c2c545c0e 100644 --- a/tests/testthat/test-augment.R +++ b/tests/testthat/test-augment.R @@ -76,3 +76,18 @@ test_that('classification models', { }) + +test_that('augment for model without class probabilities', { + skip_if_not_installed("LiblineaR") + + data(two_class_dat, package = "modeldata") + x <- svm_linear(mode = "classification") %>% set_engine("LiblineaR") + cls_form <- x %>% fit(Class ~ ., data = two_class_dat) + + expect_equal( + colnames(augment(cls_form, head(two_class_dat))), + c("A", "B", "Class", ".pred_class") + ) + expect_equal(nrow(augment(cls_form, head(two_class_dat))), 6) + +}) From b2b054d4add85cd09c46d187476b362003507d5e Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Wed, 12 May 2021 17:11:54 -0600 Subject: [PATCH 3/4] Use helper functions to check whether model supports class / prob predictions --- R/augment.R | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/R/augment.R b/R/augment.R index 09ea39977..d382c8175 100644 --- a/R/augment.R +++ b/R/augment.R @@ -68,19 +68,28 @@ augment.model_fit <- function(x, new_data, ...) { } } } else if (x$spec$mode == "classification") { - new_data <- dplyr::bind_cols( - new_data, - predict(x, new_data = new_data, type = "class") - ) - tryCatch( + if (has_class_preds(x)) { + new_data <- dplyr::bind_cols( + new_data, + predict(x, new_data = new_data, type = "class") + ) + } + if (has_class_probs(x)) { new_data <- dplyr::bind_cols( new_data, predict(x, new_data = new_data, type = "prob") - ), - error = function(cnd) cnd - ) + ) + } } else { rlang::abort(paste("Unknown mode:", x$spec$mode)) } as_tibble(new_data) } + +has_class_preds <- function(x) { + any(names(x$spec$method$pred) == "class") +} + +has_class_probs <- function(x) { + any(names(x$spec$method$pred) == "prob") +} From 426fd3d995c70a365e1765cafed98756700ea52d Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 13 May 2021 10:34:40 -0400 Subject: [PATCH 4/4] modified checking method for correct pred type --- R/aaa_models.R | 8 ++++++-- R/augment.R | 18 ++++++------------ man/augment.Rd | 5 +++-- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 8235bddf7..2a40e3d11 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -286,9 +286,13 @@ check_pred_info <- function(pred_obj, type) { invisible(NULL) } -check_spec_pred_type <- function(object, type) { +spec_has_pred_type <- function(object, type) { possible_preds <- names(object$spec$method$pred) - if (!any(possible_preds == type)) { + any(possible_preds == type) +} +check_spec_pred_type <- function(object, type) { + if (!spec_has_pred_type(object, type)) { + possible_preds <- names(object$spec$method$pred) rlang::abort(c( glue::glue("No {type} prediction method available for this model."), glue::glue("Value for `type` should be one of: ", diff --git a/R/augment.R b/R/augment.R index d382c8175..fa264f69e 100644 --- a/R/augment.R +++ b/R/augment.R @@ -6,8 +6,9 @@ #' [fit()] and `new_data` contains the outcome column, a `.resid` column is #' also added. #' -#' For classification models, the results include a column called `.pred_class` -#' as well as class probability columns named `.pred_{level}`. +#' For classification models, the results can include a column called +#' `.pred_class` as well as class probability columns named `.pred_{level}`. +#' This depends on what type of prediction types are available for the model. #' @param x A `model_fit` object produced by [fit()] or [fit_xy()]. #' @param new_data A data frame or matrix. #' @param ... Not currently used. @@ -56,6 +57,7 @@ #' augment.model_fit <- function(x, new_data, ...) { if (x$spec$mode == "regression") { + check_spec_pred_type(x, "numeric") new_data <- new_data %>% dplyr::bind_cols( @@ -68,13 +70,13 @@ augment.model_fit <- function(x, new_data, ...) { } } } else if (x$spec$mode == "classification") { - if (has_class_preds(x)) { + if (spec_has_pred_type(x, "class")) { new_data <- dplyr::bind_cols( new_data, predict(x, new_data = new_data, type = "class") ) } - if (has_class_probs(x)) { + if (spec_has_pred_type(x, "prob")) { new_data <- dplyr::bind_cols( new_data, predict(x, new_data = new_data, type = "prob") @@ -85,11 +87,3 @@ augment.model_fit <- function(x, new_data, ...) { } as_tibble(new_data) } - -has_class_preds <- function(x) { - any(names(x$spec$method$pred) == "class") -} - -has_class_probs <- function(x) { - any(names(x$spec$method$pred) == "prob") -} diff --git a/man/augment.Rd b/man/augment.Rd index 9bda10cec..fe3bd4982 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -21,8 +21,9 @@ For regression models, a \code{.pred} column is added. If \code{x} was created u \code{\link[=fit]{fit()}} and \code{new_data} contains the outcome column, a \code{.resid} column is also added. -For classification models, the results include a column called \code{.pred_class} -as well as class probability columns named \verb{.pred_\{level\}}. +For classification models, the results can include a column called +\code{.pred_class} as well as class probability columns named \verb{.pred_\{level\}}. +This depends on what type of prediction types are available for the model. } \examples{ car_trn <- mtcars[11:32,]