diff --git a/R/predict.R b/R/predict.R index 327bd80ef..e2a7545f2 100644 --- a/R/predict.R +++ b/R/predict.R @@ -147,7 +147,7 @@ #' @export predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) { if (inherits(object$fit, "try-error")) { - rlang::warn("Model fit failed; cannot make predictions.") + cli::cli_warn("Model fit failed; cannot make predictions.") return(NULL) } @@ -156,7 +156,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) type <- check_pred_type(object, type) if (type != "raw" && length(opts) > 0) { - rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.") + cli::cli_warn("{.arg opts} is only used with `type = 'raw'` and was ignored.") } check_pred_type_dots(object, type, ...) @@ -173,7 +173,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) linear_pred = predict_linear_pred(object = object, new_data = new_data, ...), hazard = predict_hazard(object = object, new_data = new_data, ...), raw = predict_raw(object = object, new_data = new_data, opts = opts, ...), - rlang::abort(glue::glue("I don't know about type = '{type}'")) + cli::cli_abort("Unknown prediction {.arg type} '{type}'.") ) if (!inherits(res, "tbl_spark")) { res <- switch( @@ -191,45 +191,69 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) res } -check_pred_type <- function(object, type, ...) { +check_pred_type <- function(object, type, ..., call = rlang::caller_env()) { if (is.null(type)) { type <- - switch(object$spec$mode, - regression = "numeric", - classification = "class", - "censored regression" = "time", - rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'.")) + switch( + object$spec$mode, + regression = "numeric", + classification = "class", + "censored regression" = "time", + cli::cli_abort( + "{.arg type} should be 'regression', 'censored regression', or 'classification'.", + call = call + ) + ) } if (!(type %in% pred_types)) - rlang::abort( - glue::glue( - "`type` should be one of: ", - glue_collapse(pred_types, sep = ", ", last = " and ") - ) + cli::cli_abort( + "{.arg type} should be one of:{.arg {pred_types}}", + call = call ) switch( type, "numeric" = if (object$spec$mode != "regression") { - rlang::abort("For numeric predictions, the object should be a regression model.") + cli::cli_abort( + "For numeric predictions, the object should be a regression model.", + call = call + ) }, "class" = if (object$spec$mode != "classification") { - rlang::abort("For class predictions, the object should be a classification model.") + cli::cli_abort( + "For class predictions, the object should be a classification model.", + call = call + ) }, "prob" = if (object$spec$mode != "classification") { - rlang::abort("For probability predictions, the object should be a classification model.") + cli::cli_abort( + "For probability predictions, the object should be a classification model.", + call = call + ) }, "time" = if (object$spec$mode != "censored regression") { - rlang::abort("For event time predictions, the object should be a censored regression.") + cli::cli_abort( + "For event time predictions, the object should be a censored regression.", + call = call + ) }, "survival" = if (object$spec$mode != "censored regression") { - rlang::abort("For survival probability predictions, the object should be a censored regression.") + cli::cli_abort( + "For survival probability predictions, the object should be a censored regression.", + call = call + ) }, "hazard" = if (object$spec$mode != "censored regression") { - rlang::abort("For hazard predictions, the object should be a censored regression.") + cli::cli_abort( + "For hazard predictions, the object should be a censored regression.", + call = call + ) }, "linear_pred" = if (object$spec$mode != "censored regression") { - rlang::abort("For the linear predictor, the object should be a censored regression.") + cli::cli_abort( + "For the linear predictor, the object should be a censored regression.", + call = call + ) } ) @@ -349,56 +373,57 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) other_args <- c("interval", "level", "std_error", "quantile", "time", "eval_time", "increasing") + + eval_time_types <- c("survival", "hazard") + is_pred_arg <- names(the_dots) %in% other_args if (any(!is_pred_arg)) { bad_args <- names(the_dots)[!is_pred_arg] bad_args <- paste0("`", bad_args, "`", collapse = ", ") - rlang::abort( - glue::glue( - "The ellipses are not used to pass args to the model function's ", - "predict function. These arguments cannot be used: {bad_args}", - ) + cli::cli_abort( + "The ellipses are not used to pass args to the model function's + predict function. These arguments cannot be used: {.val bad_args}", + call = call ) } # ---------------------------------------------------------------------------- # places where eval_time should not be given if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) { - rlang::abort( - paste( - "`eval_time` should only be passed to `predict()` when `type` is one of:", - paste0("'", c("survival", "hazard"), "'", collapse = ", ") - ) - ) + cli::cli_abort( + "{.arg eval_time} should only be passed to {.fn predict} when \\ + {.arg type} is one of {.or {.val {eval_time_types}}}.", + call = call + ) + + } if (any(nms == "time") & !type %in% c("survival", "hazard")) { - rlang::abort( - paste( - "'time' should only be passed to `predict()` when 'type' is one of:", - paste0("'", c("survival", "hazard"), "'", collapse = ", ") - ) + cli::cli_abort( + "{.arg time} should only be passed to {.fn predict} when {.arg type} is + one of {.or {.val {eval_time_types}}}.", + call = call ) } # when eval_time should be passed if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) { - rlang::abort( - paste( - "When using `type` values of 'survival' or 'hazard',", - "a numeric vector `eval_time` should also be given." - ) - ) + cli::cli_abort( + "When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric + vector {.arg eval_time} should also be given.", + call = call + ) } # `increasing` only applies to linear_pred for censored regression if (any(nms == "increasing") & !(type == "linear_pred" & object$spec$mode == "censored regression")) { - rlang::abort( - paste( - "The 'increasing' argument only applies to predictions of", - "type 'linear_pred' for the mode censored regression." - ) + cli::cli_abort( + "{.arg increasing} only applies to predictions of + type 'linear_pred' for the mode censored regression.", + call = call ) + } invisible(TRUE)