diff --git a/R/aaa_models.R b/R/aaa_models.R index 112d5b225..957d6a029 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -280,12 +280,27 @@ check_pred_info <- function(pred_obj, type) { invisible(NULL) } +check_spec_pred_type <- function(object, type) { + possible_preds <- names(object$spec$method$pred) + if (!any(possible_preds == type)) { + rlang::abort(c( + glue::glue("No {type} prediction method available for this model."), + glue::glue("Value for `type` should be one of: ", + glue::glue_collapse(glue::glue("'{possible_preds}'"), sep = ", ")) + )) + } + invisible(NULL) +} + + check_pkg_val <- function(pkg) { - if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) + if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) { rlang::abort("Please supply a single character value for the package name.") + } invisible(NULL) } + check_interface_val <- function(x) { exp_interf <- c("data.frame", "formula", "matrix") if (length(x) != 1 || !(x %in% exp_interf)) { diff --git a/R/predict_class.R b/R/predict_class.R index 6e777b329..e96d460b1 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -12,8 +12,7 @@ predict_class.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "classification") rlang::abort("`predict.model_fit()` is for predicting factor outcomes.") - if (!any(names(object$spec$method$pred) == "class")) - rlang::abort("No class prediction module defined for this model.") + check_spec_pred_type(object, "class") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 6dcb4d897..5d8f2cf05 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -9,8 +9,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "classification") rlang::abort("`predict.model_fit()` is for predicting factor outcomes.") - if (!any(names(object$spec$method$pred) == "prob")) - rlang::abort("No class probability module defined for this model.") + check_spec_pred_type(object, "prob") + if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_hazard.R b/R/predict_hazard.R index e090cb85d..d70217cba 100644 --- a/R/predict_hazard.R +++ b/R/predict_hazard.R @@ -7,8 +7,7 @@ predict_hazard.model_fit <- function(object, new_data, .time, ...) { - if (is.null(object$spec$method$pred$hazard)) - rlang::abort("No hazard prediction method defined for this engine.") + check_spec_pred_type(object, "hazard") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_interval.R b/R/predict_interval.R index 7dd8562c1..018228cdd 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -10,8 +10,7 @@ #' @export predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$pred$conf_int)) - rlang::abort("No confidence interval method defined for this engine.") + check_spec_pred_type(object, "conf_int") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") @@ -58,8 +57,7 @@ predict_confint <- function(object, ...) # @export predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$pred$pred_int)) - rlang::abort("No prediction interval method defined for this engine.") + check_spec_pred_type(object, "pred_int") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_linear_pred.R b/R/predict_linear_pred.R index c0a7abc92..62af94f44 100644 --- a/R/predict_linear_pred.R +++ b/R/predict_linear_pred.R @@ -6,8 +6,7 @@ #' @export predict_linear_pred.model_fit <- function(object, new_data, ...) { - if (!any(names(object$spec$method$pred) == "linear_pred")) - rlang::abort("No prediction module defined for this model.") + check_spec_pred_type(object, "linear_pred") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 38fce1a39..17d61a614 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -10,8 +10,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) { "Use `predict_class()` or `predict_classprob()` for ", "classification models.")) - if (!any(names(object$spec$method$pred) == "numeric")) - rlang::abort("No prediction module defined for this model.") + check_spec_pred_type(object, "numeric") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 9d671af24..7965069e6 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -9,8 +9,7 @@ predict_quantile.model_fit <- function(object, new_data, quantile = (1:9)/10, ...) { - if (is.null(object$spec$method$pred$quantile)) - rlang::abort("No quantile prediction method defined for this engine.") + check_spec_pred_type(object, "quantile") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_raw.R b/R/predict_raw.R index 22e59b792..cf391a063 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -13,8 +13,7 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { c(object$spec$method$pred$raw$args, opts) } - if (!any(names(object$spec$method$pred) == "raw")) - rlang::abort("No raw prediction module defined for this model.") + check_spec_pred_type(object, "raw") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_survival.R b/R/predict_survival.R index a3075fd66..536d800eb 100644 --- a/R/predict_survival.R +++ b/R/predict_survival.R @@ -7,8 +7,7 @@ predict_survival.model_fit <- function(object, new_data, .time, ...) { - if (is.null(object$spec$method$pred$survival)) - rlang::abort("No survival prediction method defined for this engine.") + check_spec_pred_type(object, "survival") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/predict_time.R b/R/predict_time.R index fc96c6207..cdedda287 100644 --- a/R/predict_time.R +++ b/R/predict_time.R @@ -10,8 +10,7 @@ predict_time.model_fit <- function(object, new_data, ...) { "Use `predict_class()` or `predict_classprob()` for ", "classification models.")) - if (!any(names(object$spec$method$pred) == "time")) - rlang::abort("No prediction module defined for this model.") + check_spec_pred_type(object, "time") if (inherits(object$fit, "try-error")) { rlang::warn("Model fit failed; cannot make predictions.") diff --git a/R/svm_linear_data.R b/R/svm_linear_data.R index 2778eff1c..25da85477 100644 --- a/R/svm_linear_data.R +++ b/R/svm_linear_data.R @@ -123,27 +123,6 @@ set_pred( ) ) ) -set_pred( - model = "svm_linear", - eng = "LiblineaR", - mode = "classification", - type = "prob", - value = list( - pre = function(x, object) { - rlang::abort( - paste0("The LiblineaR engine does not support class probabilities ", - "for any `svm` models.") - ) - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = expr(as.matrix(new_data)) - ) - ) -) set_pred( model = "svm_linear", eng = "LiblineaR", diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index 24d4ced4a..cf6875079 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -280,12 +280,12 @@ test_that('linear svm classification prediction: LiblineaR', { expect_error( predict(cls_form, hpc_no_m[ind, -5], type = "prob"), - "The LiblineaR engine does not support class probabilities" + "No prob prediction method available for this model" ) expect_error( predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob"), - "The LiblineaR engine does not support class probabilities" + "No prob prediction method available for this model" ) })