diff --git a/DESCRIPTION b/DESCRIPTION index 1c1a08492..bfa80fed1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.2.1.9000 +Version: 1.2.1.9001 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/R/bag_tree.R b/R/bag_tree.R index d36997083..e80fc200a 100644 --- a/R/bag_tree.R +++ b/R/bag_tree.R @@ -85,9 +85,7 @@ update.bag_tree <- # ------------------------------------------------------------------------------ #' @export -check_args.bag_tree <- function(object) { - if (object$engine == "C5.0" && object$mode == "regression") - stop("C5.0 is classification only.", call. = FALSE) +check_args.bag_tree <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/boost_tree.R b/R/boost_tree.R index c7227eadc..799afee1e 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -164,23 +164,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.boost_tree <- function(object) { +check_args.boost_tree <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$trees) && args$trees < 0) { - rlang::abort("`trees` should be >= 1.") - } - if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) { - rlang::abort("`sample_size` should be within [0,1].") - } - if (is.numeric(args$tree_depth) && args$tree_depth < 0) { - rlang::abort("`tree_depth` should be >= 1.") - } - if (is.numeric(args$min_n) && args$min_n < 0) { - rlang::abort("`min_n` should be >= 1.") - } - + check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees") + check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size") + check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth") + check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n") + invisible(object) } diff --git a/R/c5_rules.R b/R/c5_rules.R index 7e5040f5d..904e2cb56 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -111,32 +111,23 @@ update.C5_rules <- # make work in different places #' @export -check_args.C5_rules <- function(object) { +check_args.C5_rules <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$trees)) { - if (length(args$trees) > 1) { - rlang::abort("Only a single value of `trees` is used.") - } - msg <- "The number of trees should be >= 1 and <= 100. Truncating the value." - if (args$trees > 100) { - object$args$trees <- - rlang::new_quosure(100L, env = rlang::empty_env()) - rlang::warn(msg) - } - if (args$trees < 1) { - object$args$trees <- - rlang::new_quosure(1L, env = rlang::empty_env()) - rlang::warn(msg) - } + check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n") + check_number_whole(args$tree, allow_null = TRUE, call = call, arg = "tree") + msg <- "The number of trees should be {.code >= 1} and {.code <= 100}" + if (!(is.null(args$trees)) && args$trees > 100) { + object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 100.")) } - if (is.numeric(args$min_n)) { - if (length(args$min_n) > 1) { - rlang::abort("Only a single `min_n`` value is used.") - } + if (!(is.null(args$trees)) && args$trees < 1) { + object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 1.")) } + invisible(object) } diff --git a/R/cubist_rules.R b/R/cubist_rules.R index 5273a2665..a8cc5407c 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -135,44 +135,36 @@ update.cubist_rules <- # make work in different places #' @export -check_args.cubist_rules <- function(object) { +check_args.cubist_rules <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$committees)) { - if (length(args$committees) > 1) { - rlang::abort("Only a single committee member is used.") - } - msg <- "The number of committees should be >= 1 and <= 100. Truncating the value." - if (args$committees > 100) { - object$args$committees <- - rlang::new_quosure(100L, env = rlang::empty_env()) - rlang::warn(msg) - } - if (args$committees < 1) { - object$args$committees <- - rlang::new_quosure(1L, env = rlang::empty_env()) - rlang::warn(msg) - } + check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees") - } - if (is.numeric(args$neighbors)) { - if (length(args$neighbors) > 1) { - rlang::abort("Only a single neighbors value is used.") - } - msg <- "The number of neighbors should be >= 0 and <= 9. Truncating the value." - if (args$neighbors > 9) { - object$args$neighbors <- - rlang::new_quosure(9L, env = rlang::empty_env()) - rlang::warn(msg) - } - if (args$neighbors < 0) { - object$args$neighbors <- - rlang::new_quosure(0L, env = rlang::empty_env()) - rlang::warn(msg) + msg <- "The number of committees should be {.code >= 1} and {.code <= 100}." + if (!(is.null(args$committees)) && args$committees > 100) { + object$args$committees <- + rlang::new_quosure(100L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 100.")) } + if (!(is.null(args$committees)) && args$committees < 1) { + object$args$committees <- + rlang::new_quosure(1L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 1.")) + } + + check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors") + msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}." + if (!(is.null(args$neighbors)) && args$neighbors > 9) { + object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 9.")) } + if (!(is.null(args$neighbors)) && args$neighbors < 0) { + object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 0.")) + } + invisible(object) } diff --git a/R/decision_tree.R b/R/decision_tree.R index 98c13d26a..8266fe806 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -128,9 +128,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.decision_tree <- function(object) { - if (object$engine == "C5.0" && object$mode == "regression") - rlang::abort("C5.0 is classification only.") +check_args.decision_tree <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/discrim_flexible.R b/R/discrim_flexible.R index 0f5c0162c..8b2826b1a 100644 --- a/R/discrim_flexible.R +++ b/R/discrim_flexible.R @@ -85,21 +85,14 @@ update.discrim_flexible <- # ------------------------------------------------------------------------------ #' @export -check_args.discrim_flexible <- function(object) { +check_args.discrim_flexible <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$prod_degree) && args$prod_degree < 0) - stop("`prod_degree` should be >= 1", call. = FALSE) - - if (is.numeric(args$num_terms) && args$num_terms < 0) - stop("`num_terms` should be >= 1", call. = FALSE) - - if (!is.character(args$prune_method) && - !is.null(args$prune_method) && - !is.character(args$prune_method)) - stop("`prune_method` should be a single string value", call. = FALSE) - + check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree") + check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms") + check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method") + invisible(object) } diff --git a/R/discrim_linear.R b/R/discrim_linear.R index 54d618a9c..22eff4ea5 100644 --- a/R/discrim_linear.R +++ b/R/discrim_linear.R @@ -80,13 +80,11 @@ update.discrim_linear <- # ------------------------------------------------------------------------------ #' @export -check_args.discrim_linear <- function(object) { +check_args.discrim_linear <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { - stop("The amount of regularization should be >= 0", call. = FALSE) - } + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/R/discrim_regularized.R b/R/discrim_regularized.R index ddf327d22..e6f209c4c 100644 --- a/R/discrim_regularized.R +++ b/R/discrim_regularized.R @@ -95,18 +95,13 @@ update.discrim_regularized <- # ------------------------------------------------------------------------------ #' @export -check_args.discrim_regularized <- function(object) { +check_args.discrim_regularized <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$frac_common_cov) && - (args$frac_common_cov < 0 | args$frac_common_cov > 1)) { - stop("The common covariance fraction should be between zero and one", call. = FALSE) - } - if (is.numeric(args$frac_identity) && - (args$frac_identity < 0 | args$frac_identity > 1)) { - stop("The identity matrix fraction should be between zero and one", call. = FALSE) - } + check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov") + check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity") + invisible(object) } diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 29002a41d..7e2437434 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -4,7 +4,7 @@ # data to formula/data objects and so on. form_form <- - function(object, control, env, ...) { + function(object, control, env, ..., call = rlang::caller_env()) { if (inherits(env$data, "data.frame")) { check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object) @@ -32,7 +32,7 @@ form_form <- } # evaluate quoted args once here to check them - object <- check_args(object) + object <- check_args(object, call = call) # sub in arguments to actual syntax for corresponding engine object <- translate(object, engine = object$engine) @@ -60,7 +60,12 @@ form_form <- res } -xy_xy <- function(object, env, control, target = "none", ...) { +xy_xy <- function(object, + env, + control, + target = "none", + ..., + call = rlang::caller_env()) { if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark")) rlang::abort("spark objects can only be used with the formula interface to `fit()`") @@ -83,7 +88,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { } # evaluate quoted args once here to check them - object <- check_args(object) + object <- check_args(object, call = call) # sub in arguments to actual syntax for corresponding engine object <- translate(object, engine = object$engine) @@ -114,7 +119,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { } form_xy <- function(object, control, env, - target = "none", ...) { + target = "none", ..., call = rlang::caller_env()) { encoding_info <- get_encoding(class(object)[1]) %>% @@ -138,7 +143,8 @@ form_xy <- function(object, control, env, object = object, env = env, #weights! control = control, - target = target + target = target, + call = call ) data_obj$y_var <- all.vars(rlang::f_lhs(env$formula)) data_obj$x <- NULL diff --git a/R/linear_reg.R b/R/linear_reg.R index 0d4625223..0b7b636b4 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -106,16 +106,12 @@ update.linear_reg <- # ------------------------------------------------------------------------------ #' @export -check_args.linear_reg <- function(object) { +check_args.linear_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") - if (is.numeric(args$mixture) && length(args$mixture) > 1) - rlang::abort("Only one value of `mixture` is allowed.") + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 58b5de93e..40ccf5dc5 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -135,25 +135,32 @@ update.logistic_reg <- # ------------------------------------------------------------------------------ #' @export -check_args.logistic_reg <- function(object) { +check_args.logistic_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") - if (is.numeric(args$mixture) && length(args$mixture) > 1) - rlang::abort("Only one value of `mixture` is allowed.") + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") if (object$engine == "LiblineaR") { - if(is.numeric(args$mixture) && !args$mixture %in% 0:1) - rlang::abort(c("For the LiblineaR engine, mixture must be 0 or 1.", - "Choose a pure ridge model with `mixture = 0`.", - "Choose a pure lasso model with `mixture = 1`.", - "The Liblinear engine does not support other values.")) - if(all(is.numeric(args$penalty)) && !all(args$penalty > 0)) - rlang::abort("For the LiblineaR engine, penalty must be > 0.") + if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { + cli::cli_abort( + c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\ + not {args$mixture}.", + "i" = "Choose a pure ridge model with {.code mixture = 0} or \\ + a pure lasso model with {.code mixture = 1}.", + "!" = "The {.pkg Liblinear} engine does not support other values."), + call = call + ) + } + + if ((!is.null(args$penalty)) && args$penalty == 0) { + cli::cli_abort( + "For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\ + not 0.", + call = call + ) + } } invisible(object) diff --git a/R/mars.R b/R/mars.R index 59e4f0f63..f4ad10ae7 100644 --- a/R/mars.R +++ b/R/mars.R @@ -105,20 +105,13 @@ translate.mars <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.mars <- function(object) { +check_args.mars <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$prod_degree) && args$prod_degree < 0) - rlang::abort("`prod_degree` should be >= 1.") - - if (is.numeric(args$num_terms) && args$num_terms < 0) - rlang::abort("`num_terms` should be >= 1.") - - if (!is_varying(args$prune_method) && - !is.null(args$prune_method) && - !is.character(args$prune_method)) - rlang::abort("`prune_method` should be a single string value.") + check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree") + check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms") + check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method") invisible(object) } diff --git a/R/misc.R b/R/misc.R index fa591aecc..b564870fd 100644 --- a/R/misc.R +++ b/R/misc.R @@ -285,12 +285,12 @@ show_fit <- function(model, eng) { #' @export #' @keywords internal #' @rdname add_on_exports -check_args <- function(object) { +check_args <- function(object, call = rlang::caller_env()) { UseMethod("check_args") } #' @export -check_args.default <- function(object) { +check_args.default <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/mlp.R b/R/mlp.R index 8b5ba1b59..fa3cf097c 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -126,21 +126,20 @@ translate.mlp <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.mlp <- function(object) { +check_args.mlp <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$penalty)) - if (args$penalty < 0) - rlang::abort("The amount of weight decay must be >= 0.") + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal(args$dropout, min = 0, max = 1, allow_null = TRUE, call = call, arg = "dropout") - if (is.numeric(args$dropout)) - if (args$dropout < 0 | args$dropout >= 1) - rlang::abort("The dropout proportion must be on [0, 1).") - - if (is.numeric(args$penalty) & is.numeric(args$dropout)) - if (args$dropout > 0 & args$penalty > 0) - rlang::abort("Both weight decay and dropout should not be specified.") + if (is.numeric(args$penalty) && is.numeric(args$dropout) && + args$dropout > 0 && args$penalty > 0) { + cli::cli_abort( + "Both weight decay and dropout should not be specified.", + call = call + ) + } invisible(object) } diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 8b0d887e0..1a8f0e8a1 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -100,14 +100,12 @@ update.multinom_reg <- # ------------------------------------------------------------------------------ #' @export -check_args.multinom_reg <- function(object) { +check_args.multinom_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 0976d2907..37daaea4c 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -90,26 +90,16 @@ update.nearest_neighbor <- function(object, ) } - -positive_int_scalar <- function(x) { - (length(x) == 1) && (x > 0) && (x %% 1 == 0) -} - # ------------------------------------------------------------------------------ #' @export -check_args.nearest_neighbor <- function(object) { +check_args.nearest_neighbor <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { - rlang::abort("`neighbors` must be a length 1 positive integer.") - } - - if (is.character(args$weight_func) && length(args$weight_func) > 1) { - rlang::abort("The length of `weight_func` must be 1.") - } - + check_number_whole(args$neighbors, min = 0, allow_null = TRUE, call = call, arg = "neighbors") + check_string(args$weight_func, allow_null = TRUE, call = call, arg = "weight_func") + invisible(object) } diff --git a/R/pls.R b/R/pls.R index c9eea560f..3a2bc13d7 100644 --- a/R/pls.R +++ b/R/pls.R @@ -87,13 +87,11 @@ update.pls <- # ------------------------------------------------------------------------------ #' @export -check_args.pls <- function(object) { +check_args.pls <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$num_comp) && args$num_comp < 1) { - rlang::abort("`num_comp` should be >= 0.") - } + check_number_whole(args$num_comp, min = 0, allow_null = TRUE, call = call, arg = "num_comp") invisible(object) } diff --git a/R/poisson_reg.R b/R/poisson_reg.R index e538f29ee..e3201aca1 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -101,16 +101,12 @@ translate.poisson_reg <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.poisson_reg <- function(object) { +check_args.poisson_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") - if (is.numeric(args$mixture) && length(args$mixture) > 1) - rlang::abort("Only one value of `mixture` is allowed.") + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/R/rand_forest.R b/R/rand_forest.R index 425b9d9fb..2a6aed042 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -161,7 +161,7 @@ translate.rand_forest <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.rand_forest <- function(object) { +check_args.rand_forest <- function(object, call = rlang::caller_env()) { # move translate checks here? invisible(object) } diff --git a/R/surv_reg.R b/R/surv_reg.R index e0d321536..85638840d 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -83,7 +83,7 @@ translate.surv_reg <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.surv_reg <- function(object) { +check_args.surv_reg <- function(object, call = rlang::caller_env()) { if (object$engine == "flexsurv") { diff --git a/R/survival_reg.R b/R/survival_reg.R index 6b270d825..d34781ee9 100644 --- a/R/survival_reg.R +++ b/R/survival_reg.R @@ -82,7 +82,7 @@ translate.survival_reg <- function(x, engine = x$engine, ...) { } #' @export -check_args.survival_reg <- function(object) { +check_args.survival_reg <- function(object, call = rlang::caller_env()) { if (object$engine == "flexsurv") { diff --git a/R/svm_linear.R b/R/svm_linear.R index b888758f3..2548cd499 100644 --- a/R/svm_linear.R +++ b/R/svm_linear.R @@ -140,7 +140,7 @@ translate.svm_linear <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.svm_linear <- function(object) { +check_args.svm_linear <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/svm_poly.R b/R/svm_poly.R index 028a09294..4acd1afe8 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -134,7 +134,7 @@ translate.svm_poly <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.svm_poly <- function(object) { +check_args.svm_poly <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/svm_rbf.R b/R/svm_rbf.R index af6bde862..ba8abf272 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -139,7 +139,7 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.svm_rbf <- function(object) { +check_args.svm_rbf <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/man/add_on_exports.Rd b/man/add_on_exports.Rd index 8ecfcdad4..69760d481 100644 --- a/man/add_on_exports.Rd +++ b/man/add_on_exports.Rd @@ -19,7 +19,7 @@ null_value(x) show_fit(model, eng) -check_args(object) +check_args(object, call = rlang::caller_env()) update_dot_check(...) diff --git a/tests/testthat/_snaps/boost_tree.md b/tests/testthat/_snaps/boost_tree.md index 890681236..f84c5d509 100644 --- a/tests/testthat/_snaps/boost_tree.md +++ b/tests/testthat/_snaps/boost_tree.md @@ -15,3 +15,51 @@ Computational engine: C5.0 +# bad input + + Code + boost_tree(mode = "bogus") + Condition + Error in `boost_tree()`: + ! "bogus" is not a known mode for model `boost_tree()`. + +# check_args() works + + Code + spec <- boost_tree(trees = -1) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `trees` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- boost_tree(sample_size = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `sample_size` must be a number between 0 and 1 or `NULL`, not the number -10. + +--- + + Code + spec <- boost_tree(tree_depth = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `tree_depth` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. + +--- + + Code + spec <- boost_tree(min_n = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `min_n` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. + diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md index 4626c9656..cce6764c5 100644 --- a/tests/testthat/_snaps/linear_reg.md +++ b/tests/testthat/_snaps/linear_reg.md @@ -23,3 +23,21 @@ Error in `predict()`: ! Please use `new_data` instead of `newdata`. +# check_args() works + + Code + spec <- linear_reg(mixture = -1) %>% set_engine("lm") %>% set_mode("regression") + fit(spec, compounds ~ ., hpc) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- linear_reg(penalty = -1) %>% set_engine("lm") %>% set_mode("regression") + fit(spec, compounds ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md index ed336d9e3..9bb671332 100644 --- a/tests/testthat/_snaps/logistic_reg.md +++ b/tests/testthat/_snaps/logistic_reg.md @@ -29,3 +29,45 @@ Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred +# check_args() works + + Code + spec <- logistic_reg(mixture = -1) %>% set_engine("glm") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- logistic_reg(penalty = -1) %>% set_engine("glm") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- logistic_reg(mixture = 0.5) %>% set_engine("LiblineaR") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + x For the LiblineaR engine, mixture must be 0 or 1, not 0.5. + i Choose a pure ridge model with `mixture = 0` or a pure lasso model with `mixture = 1`. + ! The Liblinear engine does not support other values. + +--- + + Code + spec <- logistic_reg(penalty = 0) %>% set_engine("LiblineaR") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! For the LiblineaR engine, `penalty` must be `> 0`, not 0. + diff --git a/tests/testthat/_snaps/mars.md b/tests/testthat/_snaps/mars.md index a1596440e..64cc504a1 100644 --- a/tests/testthat/_snaps/mars.md +++ b/tests/testthat/_snaps/mars.md @@ -22,3 +22,33 @@ Error in `multi_predict()`: ! Please use `new_data` instead of `newdata`. +# check_args() works + + Code + spec <- mars(prod_degree = 0) %>% set_engine("earth") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `prod_degree` must be a whole number larger than or equal to 1 or `NULL`, not the number 0. + +--- + + Code + spec <- mars(num_terms = 0) %>% set_engine("earth") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `num_terms` must be a whole number larger than or equal to 1 or `NULL`, not the number 0. + +--- + + Code + spec <- mars(prune_method = 2) %>% set_engine("earth") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `prune_method` must be a single string or `NULL`, not the number 2. + diff --git a/tests/testthat/_snaps/mlp.md b/tests/testthat/_snaps/mlp.md index bfdd3bfb2..03b192b7a 100644 --- a/tests/testthat/_snaps/mlp.md +++ b/tests/testthat/_snaps/mlp.md @@ -15,3 +15,31 @@ Computational engine: nnet +# check_args() works + + Code + spec <- mlp(penalty = -1) %>% set_engine("keras") %>% set_mode("classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- mlp(dropout = -1) %>% set_engine("keras") %>% set_mode("classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `dropout` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- mlp(dropout = 1, penalty = 3) %>% set_engine("keras") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! Both weight decay and dropout should not be specified. + diff --git a/tests/testthat/_snaps/multinom_reg.md b/tests/testthat/_snaps/multinom_reg.md index da29212d1..0f01bb547 100644 --- a/tests/testthat/_snaps/multinom_reg.md +++ b/tests/testthat/_snaps/multinom_reg.md @@ -15,3 +15,23 @@ Computational engine: glmnet +# check_args() works + + Code + spec <- multinom_reg(mixture = -1) %>% set_engine("keras") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- multinom_reg(penalty = -1) %>% set_engine("keras") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + diff --git a/tests/testthat/_snaps/nearest_neighbor.md b/tests/testthat/_snaps/nearest_neighbor.md index c4d6d63f3..401c06c51 100644 --- a/tests/testthat/_snaps/nearest_neighbor.md +++ b/tests/testthat/_snaps/nearest_neighbor.md @@ -15,3 +15,23 @@ Computational engine: kknn +# check_args() works + + Code + spec <- nearest_neighbor(neighbors = -1) %>% set_engine("kknn") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `neighbors` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- nearest_neighbor(weight_func = 2) %>% set_engine("kknn") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `weight_func` must be a single string or `NULL`, not the number 2. + diff --git a/tests/testthat/test-bart.R b/tests/testthat/test-bart.R new file mode 100644 index 000000000..162cb0d5c --- /dev/null +++ b/tests/testthat/test-bart.R @@ -0,0 +1,4 @@ +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index f92216870..a099d64bf 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -11,15 +11,7 @@ test_that('updating', { }) test_that('bad input', { - expect_error(boost_tree(mode = "bogus")) - expect_error({ - bt <- boost_tree(trees = -1) %>% set_engine("xgboost") - fit(bt, class ~ ., hpc) - }) - expect_error({ - bt <- boost_tree(min_n = -10) %>% set_engine("xgboost") - fit(bt, class ~ ., hpc) - }) + expect_snapshot(error = TRUE, boost_tree(mode = "bogus")) expect_message(translate(boost_tree(mode = "classification"), engine = NULL)) expect_error(translate(boost_tree(formula = y ~ x))) }) @@ -54,3 +46,44 @@ test_that('boost_tree can be fit with 1 predictor if validation is used', { fit(spec, mpg ~ disp, data = mtcars) ) }) + +test_that("check_args() works", { + skip_if_not_installed("xgboost") + + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(trees = -1) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(sample_size = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(tree_depth = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(min_n = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) diff --git a/tests/testthat/test_decision_tree.R b/tests/testthat/test_decision_tree.R index bb1391299..85dd3ac1f 100644 --- a/tests/testthat/test_decision_tree.R +++ b/tests/testthat/test_decision_tree.R @@ -69,3 +69,8 @@ test_that('argument checks for data dimensions', { expect_equal(args$min_instances_per_node, rlang::expr(min_rows(1000, x))) }) + +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_gen_additive_model.R b/tests/testthat/test_gen_additive_model.R index 3e58605a2..0b6a76ea8 100644 --- a/tests/testthat/test_gen_additive_model.R +++ b/tests/testthat/test_gen_additive_model.R @@ -98,3 +98,8 @@ test_that('classification', { expect_equal(f_ci[[".pred_lower_Class2"]], lower) }) + +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 21a93bae3..637ec8075 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -12,8 +12,6 @@ test_that('updating', { test_that('bad input', { expect_error(linear_reg(mode = "classification")) - # expect_error(linear_reg(penalty = -1)) - # expect_error(linear_reg(mixture = -1)) expect_error(translate(linear_reg(), engine = "wat?")) expect_error(translate(linear_reg(), engine = NULL)) expect_error(translate(linear_reg(formula = y ~ x))) @@ -342,3 +340,23 @@ test_that('lm can handle rankdeficient predictions', { expect_identical(names(preds), ".pred") }) +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- linear_reg(mixture = -1) %>% + set_engine("lm") %>% + set_mode("regression") + fit(spec, compounds ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- linear_reg(penalty = -1) %>% + set_engine("lm") %>% + set_mode("regression") + fit(spec, compounds ~ ., hpc) + } + ) +}) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index 8a6434747..3b46e023f 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -249,3 +249,41 @@ test_that('liblinear probabilities', { }) +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(mixture = -1) %>% + set_engine("glm") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(penalty = -1) %>% + set_engine("glm") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(mixture = 0.5) %>% + set_engine("LiblineaR") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(penalty = 0) %>% + set_engine("LiblineaR") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) +}) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index dbd5a04fb..2a9d419cd 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -216,3 +216,35 @@ test_that('classification', { expect_equal(parsnip_pred$.pred_good, earth_pred) }) + +test_that("check_args() works", { + skip_if_not_installed("earth") + + expect_snapshot( + error = TRUE, + { + spec <- mars(prod_degree = 0) %>% + set_engine("earth") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mars(num_terms = 0) %>% + set_engine("earth") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mars(prune_method = 2) %>% + set_engine("earth") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index ec9dafad9..4e9a52b4e 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -1,3 +1,4 @@ +hpc <- hpc_data[1:150, c(2:5, 8)] test_that('updating', { expect_snapshot( @@ -48,3 +49,34 @@ test_that("more activations for brulee", { expect_true(inherits(fit$fit, "brulee_mlp")) }) +test_that("check_args() works", { + skip_if_not_installed("keras") + + expect_snapshot( + error = TRUE, + { + spec <- mlp(penalty = -1) %>% + set_engine("keras") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mlp(dropout = -1) %>% + set_engine("keras") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mlp(dropout = 1, penalty = 3) %>% + set_engine("keras") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 18b25be9a..0cec39200 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -15,3 +15,26 @@ test_that('bad input', { expect_error(multinom_reg(penalty = 0.1) %>% set_engine()) expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) }) + +test_that('check_args() works', { + skip_if_not_installed("keras") + + expect_snapshot( + error = TRUE, + { + spec <- multinom_reg(mixture = -1) %>% + set_engine("keras") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- multinom_reg(penalty = -1) %>% + set_engine("keras") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 784d957da..c834dc570 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -10,3 +10,28 @@ test_that('bad input', { expect_error(nearest_neighbor(mode = "reallyunknown")) expect_error(nearest_neighbor() %>% set_engine( NULL)) }) + +test_that('check_args() works', { + skip_if_not_installed("kknn") + + hpc <- hpc_data[1:150, c(2:5, 8)] + + expect_snapshot( + error = TRUE, + { + spec <- nearest_neighbor(neighbors = -1) %>% + set_engine("kknn") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- nearest_neighbor(weight_func = 2) %>% + set_engine("kknn") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index 7b31c8a80..2f112c5f3 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -124,5 +124,7 @@ test_that('null_model printing', { ) }) - - +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index e74b5e7c7..b05633146 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -14,3 +14,7 @@ test_that('bad input', { expect_error(translate(rand_forest(mode = "classification", ytest = 2))) }) +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index 0b875c093..31c8bfc41 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -370,5 +370,7 @@ test_that('linear svm classification prediction: kernlab', { }) - - +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index 910bf978f..dda70dc52 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -185,3 +185,8 @@ test_that('svm poly classification probabilities', { parsnip_xy_probs <- predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob") expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs)) }) + +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index db9d5d1f5..8e0c4bc4c 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -195,4 +195,7 @@ test_that('svm rbf classification probabilities', { expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs)) }) - +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +})