From 8ad9cd0e172dfbb632ac69079421705589a38fcc Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 1 Apr 2024 11:08:47 -0700 Subject: [PATCH 01/24] pass calls around for check_args() --- R/bag_tree.R | 2 +- R/boost_tree.R | 2 +- R/c5_rules.R | 2 +- R/cubist_rules.R | 2 +- R/decision_tree.R | 2 +- R/discrim_flexible.R | 2 +- R/discrim_linear.R | 2 +- R/discrim_regularized.R | 2 +- R/fit_helpers.R | 13 +++++++++---- R/linear_reg.R | 2 +- R/logistic_reg.R | 2 +- R/mars.R | 2 +- R/misc.R | 4 ++-- R/mlp.R | 2 +- R/multinom_reg.R | 2 +- R/nearest_neighbor.R | 2 +- R/pls.R | 2 +- R/poisson_reg.R | 2 +- R/rand_forest.R | 2 +- R/surv_reg.R | 2 +- R/survival_reg.R | 2 +- R/svm_linear.R | 2 +- R/svm_poly.R | 2 +- R/svm_rbf.R | 2 +- 24 files changed, 33 insertions(+), 28 deletions(-) diff --git a/R/bag_tree.R b/R/bag_tree.R index d36997083..0bed9625f 100644 --- a/R/bag_tree.R +++ b/R/bag_tree.R @@ -85,7 +85,7 @@ update.bag_tree <- # ------------------------------------------------------------------------------ #' @export -check_args.bag_tree <- function(object) { +check_args.bag_tree <- function(object, call = rlang::caller_env()) { if (object$engine == "C5.0" && object$mode == "regression") stop("C5.0 is classification only.", call. = FALSE) invisible(object) diff --git a/R/boost_tree.R b/R/boost_tree.R index c7227eadc..48decd8f9 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -164,7 +164,7 @@ translate.boost_tree <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.boost_tree <- function(object) { +check_args.boost_tree <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/c5_rules.R b/R/c5_rules.R index 7e5040f5d..3a81b2858 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -111,7 +111,7 @@ update.C5_rules <- # make work in different places #' @export -check_args.C5_rules <- function(object) { +check_args.C5_rules <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/cubist_rules.R b/R/cubist_rules.R index 5273a2665..60452be67 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -135,7 +135,7 @@ update.cubist_rules <- # make work in different places #' @export -check_args.cubist_rules <- function(object) { +check_args.cubist_rules <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/decision_tree.R b/R/decision_tree.R index 98c13d26a..5b04e4446 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -128,7 +128,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.decision_tree <- function(object) { +check_args.decision_tree <- function(object, call = rlang::caller_env()) { if (object$engine == "C5.0" && object$mode == "regression") rlang::abort("C5.0 is classification only.") invisible(object) diff --git a/R/discrim_flexible.R b/R/discrim_flexible.R index 0f5c0162c..44ca60d30 100644 --- a/R/discrim_flexible.R +++ b/R/discrim_flexible.R @@ -85,7 +85,7 @@ update.discrim_flexible <- # ------------------------------------------------------------------------------ #' @export -check_args.discrim_flexible <- function(object) { +check_args.discrim_flexible <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/discrim_linear.R b/R/discrim_linear.R index 54d618a9c..6ccf89cac 100644 --- a/R/discrim_linear.R +++ b/R/discrim_linear.R @@ -80,7 +80,7 @@ update.discrim_linear <- # ------------------------------------------------------------------------------ #' @export -check_args.discrim_linear <- function(object) { +check_args.discrim_linear <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/discrim_regularized.R b/R/discrim_regularized.R index ddf327d22..d622c32b0 100644 --- a/R/discrim_regularized.R +++ b/R/discrim_regularized.R @@ -95,7 +95,7 @@ update.discrim_regularized <- # ------------------------------------------------------------------------------ #' @export -check_args.discrim_regularized <- function(object) { +check_args.discrim_regularized <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index ec4ddf426..0b00f8541 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -4,7 +4,7 @@ # data to formula/data objects and so on. form_form <- - function(object, control, env, ...) { + function(object, control, env, ..., call = rlang::caller_env()) { if (inherits(env$data, "data.frame")) { check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object) @@ -32,7 +32,7 @@ form_form <- } # evaluate quoted args once here to check them - object <- check_args(object) + object <- check_args(object, call = call) # sub in arguments to actual syntax for corresponding engine object <- translate(object, engine = object$engine) @@ -70,7 +70,12 @@ form_form <- res } -xy_xy <- function(object, env, control, target = "none", ...) { +xy_xy <- function(object, + env, + control, + target = "none", + ..., + call = rlang::caller_env()) { if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark")) rlang::abort("spark objects can only be used with the formula interface to `fit()`") @@ -93,7 +98,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { } # evaluate quoted args once here to check them - object <- check_args(object) + object <- check_args(object, call = call) # sub in arguments to actual syntax for corresponding engine object <- translate(object, engine = object$engine) diff --git a/R/linear_reg.R b/R/linear_reg.R index 0d4625223..fa598e00a 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -106,7 +106,7 @@ update.linear_reg <- # ------------------------------------------------------------------------------ #' @export -check_args.linear_reg <- function(object) { +check_args.linear_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 58b5de93e..90953fe6d 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -135,7 +135,7 @@ update.logistic_reg <- # ------------------------------------------------------------------------------ #' @export -check_args.logistic_reg <- function(object) { +check_args.logistic_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/mars.R b/R/mars.R index 59e4f0f63..60589fff8 100644 --- a/R/mars.R +++ b/R/mars.R @@ -105,7 +105,7 @@ translate.mars <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.mars <- function(object) { +check_args.mars <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/misc.R b/R/misc.R index fa591aecc..b564870fd 100644 --- a/R/misc.R +++ b/R/misc.R @@ -285,12 +285,12 @@ show_fit <- function(model, eng) { #' @export #' @keywords internal #' @rdname add_on_exports -check_args <- function(object) { +check_args <- function(object, call = rlang::caller_env()) { UseMethod("check_args") } #' @export -check_args.default <- function(object) { +check_args.default <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/mlp.R b/R/mlp.R index 8b5ba1b59..45ee65261 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -126,7 +126,7 @@ translate.mlp <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.mlp <- function(object) { +check_args.mlp <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 8b0d887e0..f07747d46 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -100,7 +100,7 @@ update.multinom_reg <- # ------------------------------------------------------------------------------ #' @export -check_args.multinom_reg <- function(object) { +check_args.multinom_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 0976d2907..26966a2e4 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -98,7 +98,7 @@ positive_int_scalar <- function(x) { # ------------------------------------------------------------------------------ #' @export -check_args.nearest_neighbor <- function(object) { +check_args.nearest_neighbor <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/pls.R b/R/pls.R index c9eea560f..14653093b 100644 --- a/R/pls.R +++ b/R/pls.R @@ -87,7 +87,7 @@ update.pls <- # ------------------------------------------------------------------------------ #' @export -check_args.pls <- function(object) { +check_args.pls <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/poisson_reg.R b/R/poisson_reg.R index e538f29ee..427a204db 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -101,7 +101,7 @@ translate.poisson_reg <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.poisson_reg <- function(object) { +check_args.poisson_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) diff --git a/R/rand_forest.R b/R/rand_forest.R index 425b9d9fb..2a6aed042 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -161,7 +161,7 @@ translate.rand_forest <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.rand_forest <- function(object) { +check_args.rand_forest <- function(object, call = rlang::caller_env()) { # move translate checks here? invisible(object) } diff --git a/R/surv_reg.R b/R/surv_reg.R index e0d321536..85638840d 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -83,7 +83,7 @@ translate.surv_reg <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.surv_reg <- function(object) { +check_args.surv_reg <- function(object, call = rlang::caller_env()) { if (object$engine == "flexsurv") { diff --git a/R/survival_reg.R b/R/survival_reg.R index 6b270d825..d34781ee9 100644 --- a/R/survival_reg.R +++ b/R/survival_reg.R @@ -82,7 +82,7 @@ translate.survival_reg <- function(x, engine = x$engine, ...) { } #' @export -check_args.survival_reg <- function(object) { +check_args.survival_reg <- function(object, call = rlang::caller_env()) { if (object$engine == "flexsurv") { diff --git a/R/svm_linear.R b/R/svm_linear.R index b888758f3..2548cd499 100644 --- a/R/svm_linear.R +++ b/R/svm_linear.R @@ -140,7 +140,7 @@ translate.svm_linear <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.svm_linear <- function(object) { +check_args.svm_linear <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/svm_poly.R b/R/svm_poly.R index 028a09294..4acd1afe8 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -134,7 +134,7 @@ translate.svm_poly <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.svm_poly <- function(object) { +check_args.svm_poly <- function(object, call = rlang::caller_env()) { invisible(object) } diff --git a/R/svm_rbf.R b/R/svm_rbf.R index af6bde862..ba8abf272 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -139,7 +139,7 @@ translate.svm_rbf <- function(x, engine = x$engine, ...) { # ------------------------------------------------------------------------------ #' @export -check_args.svm_rbf <- function(object) { +check_args.svm_rbf <- function(object, call = rlang::caller_env()) { invisible(object) } From 2bdd04a3bdcd2f3ca90ba918987f9ec9d4c86bbe Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 1 Apr 2024 12:36:57 -0700 Subject: [PATCH 02/24] switch to {cli} in all check_args() methods --- R/bag_tree.R | 9 +++++++-- R/boost_tree.R | 24 ++++++++++-------------- R/c5_rules.R | 23 ++++++++++++++--------- R/cubist_rules.R | 34 ++++++++++++++++++++-------------- R/decision_tree.R | 9 +++++++-- R/discrim_flexible.R | 18 +++++++----------- R/discrim_linear.R | 5 ++++- R/discrim_regularized.R | 13 +++++-------- R/linear_reg.R | 15 +++++++++------ R/logistic_reg.R | 36 ++++++++++++++++++++++++------------ R/mars.R | 16 ++++++---------- R/mlp.R | 21 +++++++++++---------- R/multinom_reg.R | 12 ++++++++---- R/nearest_neighbor.R | 18 +++++------------- R/pls.R | 5 ++--- R/poisson_reg.R | 15 +++++++++------ 16 files changed, 148 insertions(+), 125 deletions(-) diff --git a/R/bag_tree.R b/R/bag_tree.R index 0bed9625f..b58ded365 100644 --- a/R/bag_tree.R +++ b/R/bag_tree.R @@ -86,8 +86,13 @@ update.bag_tree <- #' @export check_args.bag_tree <- function(object, call = rlang::caller_env()) { - if (object$engine == "C5.0" && object$mode == "regression") - stop("C5.0 is classification only.", call. = FALSE) + if (object$engine == "C5.0" && object$mode != "classification") { + cli::cli_abort( + "The engine {.pkg C5.0} only supports the mode {.val classification}, \\ + {.val {object$mode}} was requested.", + call = call + ) + } invisible(object) } diff --git a/R/boost_tree.R b/R/boost_tree.R index 48decd8f9..d0a3ab9a8 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -167,20 +167,16 @@ translate.boost_tree <- function(x, engine = x$engine, ...) { check_args.boost_tree <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - - if (is.numeric(args$trees) && args$trees < 0) { - rlang::abort("`trees` should be >= 1.") - } - if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) { - rlang::abort("`sample_size` should be within [0,1].") - } - if (is.numeric(args$tree_depth) && args$tree_depth < 0) { - rlang::abort("`tree_depth` should be >= 1.") - } - if (is.numeric(args$min_n) && args$min_n < 0) { - rlang::abort("`min_n` should be >= 1.") - } - + trees <- args$trees + sample_size <- args$sample_size + tree_depth <- args$tree_depth + min_n <- args$min_n + + check_number_whole(trees, min = 0, allow_null = TRUE, call = call) + check_number_decimal(sample_size, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_whole(tree_depth, min = 0, allow_null = TRUE, call = call) + check_number_whole(min_n, min = 0, allow_null = TRUE, call = call) + invisible(object) } diff --git a/R/c5_rules.R b/R/c5_rules.R index 3a81b2858..2ad3d3154 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -117,24 +117,29 @@ check_args.C5_rules <- function(object, call = rlang::caller_env()) { if (is.numeric(args$trees)) { if (length(args$trees) > 1) { - rlang::abort("Only a single value of `trees` is used.") + cli::cli_abort( + "Only a single value of {.arg trees} should be passed, \\ + not {length(args$trees)}." + ) } - msg <- "The number of trees should be >= 1 and <= 100. Truncating the value." + + msg <- "The number of trees should be {.code >= 1} and {.code <= 100}" if (args$trees > 100) { - object$args$trees <- - rlang::new_quosure(100L, env = rlang::empty_env()) - rlang::warn(msg) + object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 100.")) } if (args$trees < 1) { - object$args$trees <- - rlang::new_quosure(1L, env = rlang::empty_env()) - rlang::warn(msg) + object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 1.")) } } if (is.numeric(args$min_n)) { if (length(args$min_n) > 1) { - rlang::abort("Only a single `min_n`` value is used.") + cli::cli_abort( + "Only a single value of {.arg min_n} should be passed, \\ + not {length(args$min_n)}." + ) } } invisible(object) diff --git a/R/cubist_rules.R b/R/cubist_rules.R index 60452be67..4ab85f599 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -141,35 +141,41 @@ check_args.cubist_rules <- function(object, call = rlang::caller_env()) { if (is.numeric(args$committees)) { if (length(args$committees) > 1) { - rlang::abort("Only a single committee member is used.") + cli::cli_abort( + "Only a single value of {.arg committees} should be passed, \\ + not {length(args$committees)}." + ) } - msg <- "The number of committees should be >= 1 and <= 100. Truncating the value." + + msg <- "The number of committees should be {.code >= 1} and {.code <= 100}." if (args$committees > 100) { object$args$committees <- rlang::new_quosure(100L, env = rlang::empty_env()) - rlang::warn(msg) - } + cli::cli_warn(c(msg, "Truncating to 100.")) + } if (args$committees < 1) { object$args$committees <- rlang::new_quosure(1L, env = rlang::empty_env()) - rlang::warn(msg) - } + cli::cli_warn(c(msg, "Truncating to 100.")) + } } if (is.numeric(args$neighbors)) { if (length(args$neighbors) > 1) { - rlang::abort("Only a single neighbors value is used.") + cli::cli_abort( + "Only a single value of {.arg neighbors} should be passed, \\ + not {length(args$neighbors)}." + ) } - msg <- "The number of neighbors should be >= 0 and <= 9. Truncating the value." + + msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}." if (args$neighbors > 9) { - object$args$neighbors <- - rlang::new_quosure(9L, env = rlang::empty_env()) - rlang::warn(msg) + object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 9.")) } if (args$neighbors < 0) { - object$args$neighbors <- - rlang::new_quosure(0L, env = rlang::empty_env()) - rlang::warn(msg) + object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 0.")) } } diff --git a/R/decision_tree.R b/R/decision_tree.R index 5b04e4446..1e8192b85 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -129,8 +129,13 @@ translate.decision_tree <- function(x, engine = x$engine, ...) { #' @export check_args.decision_tree <- function(object, call = rlang::caller_env()) { - if (object$engine == "C5.0" && object$mode == "regression") - rlang::abort("C5.0 is classification only.") + if (object$engine == "C5.0" && object$mode != "classification") { + cli::cli_abort( + "The engine {.pkg C5.0} only supports the mode {.val classification}, \\ + {.val {object$mode}} was requested.", + call = call + ) + } invisible(object) } diff --git a/R/discrim_flexible.R b/R/discrim_flexible.R index 44ca60d30..a8b3dba92 100644 --- a/R/discrim_flexible.R +++ b/R/discrim_flexible.R @@ -88,18 +88,14 @@ update.discrim_flexible <- check_args.discrim_flexible <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + prod_degree <- args$prod_degree + num_terms <- args$num_terms + prune_method <- args$prune_method - if (is.numeric(args$prod_degree) && args$prod_degree < 0) - stop("`prod_degree` should be >= 1", call. = FALSE) - - if (is.numeric(args$num_terms) && args$num_terms < 0) - stop("`num_terms` should be >= 1", call. = FALSE) - - if (!is.character(args$prune_method) && - !is.null(args$prune_method) && - !is.character(args$prune_method)) - stop("`prune_method` should be a single string value", call. = FALSE) - + check_number_whole(prod_degree, min = 1, allow_null = TRUE, call = call) + check_number_whole(num_terms, min = 1, allow_null = TRUE, call = call) + check_string(prune_method, allow_empty = FALSE, allow_null = TRUE, call = call) + invisible(object) } diff --git a/R/discrim_linear.R b/R/discrim_linear.R index 6ccf89cac..88c0379b3 100644 --- a/R/discrim_linear.R +++ b/R/discrim_linear.R @@ -85,7 +85,10 @@ check_args.discrim_linear <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { - stop("The amount of regularization should be >= 0", call. = FALSE) + cli::cli_abort( + "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", + call = call + ) } invisible(object) diff --git a/R/discrim_regularized.R b/R/discrim_regularized.R index d622c32b0..011d26e6c 100644 --- a/R/discrim_regularized.R +++ b/R/discrim_regularized.R @@ -98,15 +98,12 @@ update.discrim_regularized <- check_args.discrim_regularized <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + frac_common_cov <- args$frac_common_cov + frac_identity <- args$frac_identity - if (is.numeric(args$frac_common_cov) && - (args$frac_common_cov < 0 | args$frac_common_cov > 1)) { - stop("The common covariance fraction should be between zero and one", call. = FALSE) - } - if (is.numeric(args$frac_identity) && - (args$frac_identity < 0 | args$frac_identity > 1)) { - stop("The identity matrix fraction should be between zero and one", call. = FALSE) - } + check_number_decimal(frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(frac_identity, min = 0, max = 1, allow_null = TRUE, call = call) + invisible(object) } diff --git a/R/linear_reg.R b/R/linear_reg.R index fa598e00a..2fbf31ea4 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -109,13 +109,16 @@ update.linear_reg <- check_args.linear_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + mixture <- args$mixture - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") - if (is.numeric(args$mixture) && length(args$mixture) > 1) - rlang::abort("Only one value of `mixture` is allowed.") + check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + + if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { + cli::cli_abort( + "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", + call = call + ) + } invisible(object) } diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 90953fe6d..21e0eb680 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -138,22 +138,34 @@ update.logistic_reg <- check_args.logistic_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + mixture <- args$mixture + + check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") - if (is.numeric(args$mixture) && length(args$mixture) > 1) - rlang::abort("Only one value of `mixture` is allowed.") + cli::cli_abort( + "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", + call = call + ) if (object$engine == "LiblineaR") { - if(is.numeric(args$mixture) && !args$mixture %in% 0:1) - rlang::abort(c("For the LiblineaR engine, mixture must be 0 or 1.", - "Choose a pure ridge model with `mixture = 0`.", - "Choose a pure lasso model with `mixture = 1`.", - "The Liblinear engine does not support other values.")) - if(all(is.numeric(args$penalty)) && !all(args$penalty > 0)) - rlang::abort("For the LiblineaR engine, penalty must be > 0.") + if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { + cli::cli_abort( + "For the {.pkg LiblineaR} engine, mixture must be 0 or 1,\\ + not {arg$mixture}.\\ + Choose a pure ridge model with {.code mixture = 0}.\\ + Choose a pure lasso model with {.code mixture = 1}.\\ + The {.pkg Liblinear} engine does not support other values.", + call = call + ) + } + + if (all(is.numeric(args$penalty)) && !all(args$penalty > 0)) { + cli::cli_abort( + "For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}.", + call = call + ) + } } invisible(object) diff --git a/R/mars.R b/R/mars.R index 60589fff8..9a955588d 100644 --- a/R/mars.R +++ b/R/mars.R @@ -108,17 +108,13 @@ translate.mars <- function(x, engine = x$engine, ...) { check_args.mars <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + prod_degree <- args$prod_degree + num_terms <- args$num_terms + prune_method <- args$prune_method - if (is.numeric(args$prod_degree) && args$prod_degree < 0) - rlang::abort("`prod_degree` should be >= 1.") - - if (is.numeric(args$num_terms) && args$num_terms < 0) - rlang::abort("`num_terms` should be >= 1.") - - if (!is_varying(args$prune_method) && - !is.null(args$prune_method) && - !is.character(args$prune_method)) - rlang::abort("`prune_method` should be a single string value.") + check_number_whole(prod_degree, min = 1, allow_null = TRUE, call = call) + check_number_whole(num_terms, min = 1, allow_null = TRUE, call = call) + check_string(prune_method, allow_empty = FALSE, allow_null = TRUE, call = call) invisible(object) } diff --git a/R/mlp.R b/R/mlp.R index 45ee65261..e882efba6 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -129,18 +129,19 @@ translate.mlp <- function(x, engine = x$engine, ...) { check_args.mlp <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + penalty <- args$penalty + dropout <- args$dropout - if (is.numeric(args$penalty)) - if (args$penalty < 0) - rlang::abort("The amount of weight decay must be >= 0.") + check_number_decimal(penalty, min = 0, allow_null = TRUE, call = call) + check_number_decimal(dropout, min = 0, max = 1, allow_null = TRUE, call = call) - if (is.numeric(args$dropout)) - if (args$dropout < 0 | args$dropout >= 1) - rlang::abort("The dropout proportion must be on [0, 1).") - - if (is.numeric(args$penalty) & is.numeric(args$dropout)) - if (args$dropout > 0 & args$penalty > 0) - rlang::abort("Both weight decay and dropout should not be specified.") + if (is.numeric(args$penalty) && is.numeric(args$dropout) && + args$dropout > 0 && args$penalty > 0) { + cli::cli_abort( + "Both weight decay and dropout should not be specified.", + call = call + ) + } invisible(object) } diff --git a/R/multinom_reg.R b/R/multinom_reg.R index f07747d46..6bd9f6559 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -103,12 +103,16 @@ update.multinom_reg <- check_args.multinom_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + mixture <- args$mixture - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") + check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { + cli::cli_abort( + "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", + call = call + ) + } invisible(object) } diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 26966a2e4..7bc7eb9b5 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -90,26 +90,18 @@ update.nearest_neighbor <- function(object, ) } - -positive_int_scalar <- function(x) { - (length(x) == 1) && (x > 0) && (x %% 1 == 0) -} - # ------------------------------------------------------------------------------ #' @export check_args.nearest_neighbor <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + neighbors <- args$neighbors + weight_func <- args$weight_func - if (is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { - rlang::abort("`neighbors` must be a length 1 positive integer.") - } - - if (is.character(args$weight_func) && length(args$weight_func) > 1) { - rlang::abort("The length of `weight_func` must be 1.") - } - + check_number_whole(neighbors, min = 0, allow_null = TRUE, call = call) + check_string(weight_func, allow_null = TRUE, call = call) + invisible(object) } diff --git a/R/pls.R b/R/pls.R index 14653093b..1fda1066b 100644 --- a/R/pls.R +++ b/R/pls.R @@ -90,10 +90,9 @@ update.pls <- check_args.pls <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + num_comp <- args$num_comp - if (is.numeric(args$num_comp) && args$num_comp < 1) { - rlang::abort("`num_comp` should be >= 0.") - } + check_number_whole(num_comp, min = 0, allow_null = TRUE, call = call) invisible(object) } diff --git a/R/poisson_reg.R b/R/poisson_reg.R index 427a204db..09f8b8691 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -104,13 +104,16 @@ translate.poisson_reg <- function(x, engine = x$engine, ...) { check_args.poisson_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) + mixture <- args$mixture - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - rlang::abort("The amount of regularization should be >= 0.") - if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) - rlang::abort("The mixture proportion should be within [0,1].") - if (is.numeric(args$mixture) && length(args$mixture) > 1) - rlang::abort("Only one value of `mixture` is allowed.") + check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + + if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { + cli::cli_abort( + "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", + call = call + ) + } invisible(object) } From 7e95fc31be6b1cbb97cdca1287d3ca6f1cd7ecec Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 1 Apr 2024 16:38:43 -0700 Subject: [PATCH 03/24] devtools::document() --- man/add_on_exports.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/add_on_exports.Rd b/man/add_on_exports.Rd index 8ecfcdad4..69760d481 100644 --- a/man/add_on_exports.Rd +++ b/man/add_on_exports.Rd @@ -19,7 +19,7 @@ null_value(x) show_fit(model, eng) -check_args(object) +check_args(object, call = rlang::caller_env()) update_dot_check(...) From ecafc3f4809ddd7983a204cbd05d141967f4c5e6 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 5 Apr 2024 12:30:44 -0700 Subject: [PATCH 04/24] update snapshots for check_args() --- tests/testthat/_snaps/args_and_modes.new.md | 149 ++++++++++++++++++++ tests/testthat/_snaps/boost_tree.md | 48 +++++++ tests/testthat/_snaps/boost_tree.new.md | 64 +++++++++ tests/testthat/_snaps/decision_tree.md | 40 ------ tests/testthat/test_boost_tree.R | 46 ++++-- tests/testthat/test_decision_tree.R | 24 +++- 6 files changed, 317 insertions(+), 54 deletions(-) create mode 100644 tests/testthat/_snaps/args_and_modes.new.md create mode 100644 tests/testthat/_snaps/boost_tree.new.md delete mode 100644 tests/testthat/_snaps/decision_tree.md diff --git a/tests/testthat/_snaps/args_and_modes.new.md b/tests/testthat/_snaps/args_and_modes.new.md new file mode 100644 index 000000000..fb6bb6430 --- /dev/null +++ b/tests/testthat/_snaps/args_and_modes.new.md @@ -0,0 +1,149 @@ +# can't set a mode that isn't allowed by the model spec + + Code + set_mode(linear_reg(), "classification") + Condition + Error in `set_mode()`: + ! "classification" is not a known mode for model `linear_reg()`. + +# unavailable modes for an engine and vice-versa + + Code + decision_tree() %>% set_mode("regression") %>% set_engine("C5.0") + Condition + Error in `set_engine()`: + ! Available modes for engine C5.0 are: "unknown" and "classification". + +--- + + Code + decision_tree(mode = "regression", engine = "C5.0") + Condition + Error in `decision_tree()`: + ! Available modes for engine C5.0 are: "unknown" and "classification". + +--- + + Code + decision_tree() %>% set_engine("C5.0") %>% set_mode("regression") + Condition + Error in `new_model_spec()`: + ! The engine C5.0 only supports the mode "classification", "unknown" was requested. + +--- + + Code + decision_tree(engine = NULL) %>% set_engine("C5.0") %>% set_mode("regression") + Condition + Error in `if (object$engine == "C5.0" && object$mode != "classification") ...`: + ! missing value where TRUE/FALSE needed + +--- + + Code + decision_tree(engine = NULL) %>% set_mode("regression") %>% set_engine("C5.0") + Condition + Error in `if (object$engine == "C5.0" && object$mode != "classification") ...`: + ! missing value where TRUE/FALSE needed + +--- + + Code + proportional_hazards() %>% set_mode("regression") + Condition + Error in `set_mode()`: + ! "regression" is not a known mode for model `proportional_hazards()`. + +--- + + Code + linear_reg() %>% set_mode() + Condition + Error in `set_mode()`: + ! Available modes for model type linear_reg are: "unknown" and "regression". + +--- + + Code + linear_reg(engine = "boop") + Condition + Error in `linear_reg()`: + x Engine "boop" is not supported for `linear_reg()` + i See `show_engines("linear_reg")`. + +--- + + Code + linear_reg() %>% set_engine() + Condition + Error in `set_engine()`: + ! Missing engine. Possible mode/engine combinations are: regression {lm, glm, glmnet, stan, spark, keras, brulee}. + +--- + + Code + proportional_hazards() %>% set_engine() + Condition + Error in `set_engine()`: + ! No known engines for `proportional_hazards()`. + +# set_* functions error when input isn't model_spec + + Code + set_mode(mtcars, "regression") + Condition + Error in `set_mode()`: + ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. + +--- + + Code + set_args(mtcars, blah = "blah") + Condition + Error in `set_args()`: + ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. + +--- + + Code + bag_tree %>% set_mode("classification") + Condition + Error in `set_mode()`: + ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + +--- + + Code + bag_tree %>% set_engine("rpart") + Condition + Error in `set_engine()`: + ! `set_engine()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + +--- + + Code + bag_tree %>% set_args(boop = "bop") + Condition + Error in `set_args()`: + ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + +--- + + Code + 1L %>% set_args(mode = "classification") + Condition + Error in `set_args()`: + ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `integer` object. + +--- + + Code + bag_tree %>% set_mode("classification") + Condition + Error in `set_mode()`: + ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. + i Did you mistakenly pass `model_function` rather than `model_function()`? + diff --git a/tests/testthat/_snaps/boost_tree.md b/tests/testthat/_snaps/boost_tree.md index 890681236..b5753ba3b 100644 --- a/tests/testthat/_snaps/boost_tree.md +++ b/tests/testthat/_snaps/boost_tree.md @@ -15,3 +15,51 @@ Computational engine: C5.0 +# bad input + + Code + boost_tree(mode = "bogus") + Condition + Error in `boost_tree()`: + ! "bogus" is not a known mode for model `boost_tree()`. + +--- + + Code + bt <- boost_tree(trees = -1) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(bt, class ~ ., hpc) + Condition + Error in `form_xy()`: + ! `trees` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + bt <- boost_tree(sample_size = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(bt, class ~ ., hpc) + Condition + Error in `form_xy()`: + ! `sample_size` must be a number between 0 and 1 or `NULL`, not the number -10. + +--- + + Code + bt <- boost_tree(tree_depth = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(bt, class ~ ., hpc) + Condition + Error in `form_xy()`: + ! `tree_depth` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. + +--- + + Code + bt <- boost_tree(min_n = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + fit(bt, class ~ ., hpc) + Condition + Error in `form_xy()`: + ! `min_n` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. + diff --git a/tests/testthat/_snaps/boost_tree.new.md b/tests/testthat/_snaps/boost_tree.new.md new file mode 100644 index 000000000..043c52c3d --- /dev/null +++ b/tests/testthat/_snaps/boost_tree.new.md @@ -0,0 +1,64 @@ +# bad input + + Code + boost_tree(mode = "bogus") + Condition + Error in `boost_tree()`: + ! "bogus" is not a known mode for model `boost_tree()`. + +--- + + Code + bt <- boost_tree(trees = -1) %>% set_engine("xgboost") %>% set_mode( + "classification") + Condition + Error in `new_model_spec()`: + ! `trees` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. + Code + fit(bt, class ~ ., hpc) + Condition + Error: + ! object 'bt' not found + +--- + + Code + bt <- boost_tree(sample_size = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + Condition + Error in `new_model_spec()`: + ! `sample_size` must be a number between 0 and 1 or `NULL`, not the number -10. + Code + fit(bt, class ~ ., hpc) + Condition + Error: + ! object 'bt' not found + +--- + + Code + bt <- boost_tree(tree_depth = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + Condition + Error in `new_model_spec()`: + ! `tree_depth` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. + Code + fit(bt, class ~ ., hpc) + Condition + Error: + ! object 'bt' not found + +--- + + Code + bt <- boost_tree(min_n = -10) %>% set_engine("xgboost") %>% set_mode( + "classification") + Condition + Error in `new_model_spec()`: + ! `min_n` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. + Code + fit(bt, class ~ ., hpc) + Condition + Error: + ! object 'bt' not found + diff --git a/tests/testthat/_snaps/decision_tree.md b/tests/testthat/_snaps/decision_tree.md deleted file mode 100644 index 0a6aa9dc2..000000000 --- a/tests/testthat/_snaps/decision_tree.md +++ /dev/null @@ -1,40 +0,0 @@ -# updating - - Code - decision_tree(cost_complexity = 0.1) %>% set_engine("rpart", model = FALSE) %>% - update(cost_complexity = tune(), model = tune()) - Output - Decision Tree Model Specification (unknown mode) - - Main Arguments: - cost_complexity = tune() - - Engine-Specific Arguments: - model = tune() - - Computational engine: rpart - - -# bad input - - "bogus" is not a known mode for model `decision_tree()`. - ---- - - Please set the mode in the model specification. - ---- - - Please set the mode in the model specification. - ---- - - Code - try(translate(decision_tree(), engine = NULL), silent = TRUE) - Message - Used `engine = 'rpart'` for translation. - ---- - - unused argument (formula = y ~ x) - diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index f92216870..ceb26ea69 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -11,15 +11,43 @@ test_that('updating', { }) test_that('bad input', { - expect_error(boost_tree(mode = "bogus")) - expect_error({ - bt <- boost_tree(trees = -1) %>% set_engine("xgboost") - fit(bt, class ~ ., hpc) - }) - expect_error({ - bt <- boost_tree(min_n = -10) %>% set_engine("xgboost") - fit(bt, class ~ ., hpc) - }) + expect_snapshot(error = TRUE, boost_tree(mode = "bogus")) + expect_snapshot( + error = TRUE, + { + bt <- boost_tree(trees = -1) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(bt, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + bt <- boost_tree(sample_size = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(bt, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + bt <- boost_tree(tree_depth = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(bt, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + bt <- boost_tree(min_n = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(bt, class ~ ., hpc) + } + ) expect_message(translate(boost_tree(mode = "classification"), engine = NULL)) expect_error(translate(boost_tree(formula = y ~ x))) }) diff --git a/tests/testthat/test_decision_tree.R b/tests/testthat/test_decision_tree.R index bb1391299..b1133f8e4 100644 --- a/tests/testthat/test_decision_tree.R +++ b/tests/testthat/test_decision_tree.R @@ -1,4 +1,12 @@ -hpc <- hpc_data[1:150, c(2:5, 8)] +library(parsnip) + +data(hpc_data, package = "modeldata") + +bt <- decision_tree(tree_depth = c(1, 5, 10)) %>% + set_engine("rpart") %>% + set_mode("classification") + +fit(bt, class ~ ., hpc_data) # ------------------------------------------------------------------------------ @@ -12,10 +20,16 @@ test_that('updating', { test_that('bad input', { expect_snapshot_error(decision_tree(mode = "bogus")) - expect_snapshot_error({ - bt <- decision_tree(cost_complexity = -1) %>% set_engine("rpart") - fit(bt, class ~ ., hpc) - }) + expect_snapshot( + error = TRUE, + { + bt <- decision_tree(tree_depth = "six") %>% + set_engine("rpart") %>% + set_mode("classification") + + fit(bt, class ~ ., hpc) + } + ) expect_snapshot_error({ bt <- decision_tree(min_n = 0) %>% set_engine("rpart") fit(bt, class ~ ., hpc) From 3ae0eeb9f00411604f52dc049c3f108a5173f94d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 14:00:11 -0700 Subject: [PATCH 05/24] revert changes --- tests/testthat/_snaps/args_and_modes.new.md | 149 -------------------- tests/testthat/_snaps/boost_tree.new.md | 64 --------- tests/testthat/_snaps/decision_tree.md | 40 ++++++ tests/testthat/test_decision_tree.R | 24 +--- 4 files changed, 45 insertions(+), 232 deletions(-) delete mode 100644 tests/testthat/_snaps/args_and_modes.new.md delete mode 100644 tests/testthat/_snaps/boost_tree.new.md create mode 100644 tests/testthat/_snaps/decision_tree.md diff --git a/tests/testthat/_snaps/args_and_modes.new.md b/tests/testthat/_snaps/args_and_modes.new.md deleted file mode 100644 index fb6bb6430..000000000 --- a/tests/testthat/_snaps/args_and_modes.new.md +++ /dev/null @@ -1,149 +0,0 @@ -# can't set a mode that isn't allowed by the model spec - - Code - set_mode(linear_reg(), "classification") - Condition - Error in `set_mode()`: - ! "classification" is not a known mode for model `linear_reg()`. - -# unavailable modes for an engine and vice-versa - - Code - decision_tree() %>% set_mode("regression") %>% set_engine("C5.0") - Condition - Error in `set_engine()`: - ! Available modes for engine C5.0 are: "unknown" and "classification". - ---- - - Code - decision_tree(mode = "regression", engine = "C5.0") - Condition - Error in `decision_tree()`: - ! Available modes for engine C5.0 are: "unknown" and "classification". - ---- - - Code - decision_tree() %>% set_engine("C5.0") %>% set_mode("regression") - Condition - Error in `new_model_spec()`: - ! The engine C5.0 only supports the mode "classification", "unknown" was requested. - ---- - - Code - decision_tree(engine = NULL) %>% set_engine("C5.0") %>% set_mode("regression") - Condition - Error in `if (object$engine == "C5.0" && object$mode != "classification") ...`: - ! missing value where TRUE/FALSE needed - ---- - - Code - decision_tree(engine = NULL) %>% set_mode("regression") %>% set_engine("C5.0") - Condition - Error in `if (object$engine == "C5.0" && object$mode != "classification") ...`: - ! missing value where TRUE/FALSE needed - ---- - - Code - proportional_hazards() %>% set_mode("regression") - Condition - Error in `set_mode()`: - ! "regression" is not a known mode for model `proportional_hazards()`. - ---- - - Code - linear_reg() %>% set_mode() - Condition - Error in `set_mode()`: - ! Available modes for model type linear_reg are: "unknown" and "regression". - ---- - - Code - linear_reg(engine = "boop") - Condition - Error in `linear_reg()`: - x Engine "boop" is not supported for `linear_reg()` - i See `show_engines("linear_reg")`. - ---- - - Code - linear_reg() %>% set_engine() - Condition - Error in `set_engine()`: - ! Missing engine. Possible mode/engine combinations are: regression {lm, glm, glmnet, stan, spark, keras, brulee}. - ---- - - Code - proportional_hazards() %>% set_engine() - Condition - Error in `set_engine()`: - ! No known engines for `proportional_hazards()`. - -# set_* functions error when input isn't model_spec - - Code - set_mode(mtcars, "regression") - Condition - Error in `set_mode()`: - ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. - ---- - - Code - set_args(mtcars, blah = "blah") - Condition - Error in `set_args()`: - ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `data.frame` object. - ---- - - Code - bag_tree %>% set_mode("classification") - Condition - Error in `set_mode()`: - ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. - i Did you mistakenly pass `model_function` rather than `model_function()`? - ---- - - Code - bag_tree %>% set_engine("rpart") - Condition - Error in `set_engine()`: - ! `set_engine()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. - i Did you mistakenly pass `model_function` rather than `model_function()`? - ---- - - Code - bag_tree %>% set_args(boop = "bop") - Condition - Error in `set_args()`: - ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. - i Did you mistakenly pass `model_function` rather than `model_function()`? - ---- - - Code - 1L %>% set_args(mode = "classification") - Condition - Error in `set_args()`: - ! `set_args()` expected a model specification to be supplied to the `object` argument, but received a(n) `integer` object. - ---- - - Code - bag_tree %>% set_mode("classification") - Condition - Error in `set_mode()`: - ! `set_mode()` expected a model specification to be supplied to the `object` argument, but received a(n) `function` object. - i Did you mistakenly pass `model_function` rather than `model_function()`? - diff --git a/tests/testthat/_snaps/boost_tree.new.md b/tests/testthat/_snaps/boost_tree.new.md deleted file mode 100644 index 043c52c3d..000000000 --- a/tests/testthat/_snaps/boost_tree.new.md +++ /dev/null @@ -1,64 +0,0 @@ -# bad input - - Code - boost_tree(mode = "bogus") - Condition - Error in `boost_tree()`: - ! "bogus" is not a known mode for model `boost_tree()`. - ---- - - Code - bt <- boost_tree(trees = -1) %>% set_engine("xgboost") %>% set_mode( - "classification") - Condition - Error in `new_model_spec()`: - ! `trees` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. - Code - fit(bt, class ~ ., hpc) - Condition - Error: - ! object 'bt' not found - ---- - - Code - bt <- boost_tree(sample_size = -10) %>% set_engine("xgboost") %>% set_mode( - "classification") - Condition - Error in `new_model_spec()`: - ! `sample_size` must be a number between 0 and 1 or `NULL`, not the number -10. - Code - fit(bt, class ~ ., hpc) - Condition - Error: - ! object 'bt' not found - ---- - - Code - bt <- boost_tree(tree_depth = -10) %>% set_engine("xgboost") %>% set_mode( - "classification") - Condition - Error in `new_model_spec()`: - ! `tree_depth` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. - Code - fit(bt, class ~ ., hpc) - Condition - Error: - ! object 'bt' not found - ---- - - Code - bt <- boost_tree(min_n = -10) %>% set_engine("xgboost") %>% set_mode( - "classification") - Condition - Error in `new_model_spec()`: - ! `min_n` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. - Code - fit(bt, class ~ ., hpc) - Condition - Error: - ! object 'bt' not found - diff --git a/tests/testthat/_snaps/decision_tree.md b/tests/testthat/_snaps/decision_tree.md new file mode 100644 index 000000000..0a6aa9dc2 --- /dev/null +++ b/tests/testthat/_snaps/decision_tree.md @@ -0,0 +1,40 @@ +# updating + + Code + decision_tree(cost_complexity = 0.1) %>% set_engine("rpart", model = FALSE) %>% + update(cost_complexity = tune(), model = tune()) + Output + Decision Tree Model Specification (unknown mode) + + Main Arguments: + cost_complexity = tune() + + Engine-Specific Arguments: + model = tune() + + Computational engine: rpart + + +# bad input + + "bogus" is not a known mode for model `decision_tree()`. + +--- + + Please set the mode in the model specification. + +--- + + Please set the mode in the model specification. + +--- + + Code + try(translate(decision_tree(), engine = NULL), silent = TRUE) + Message + Used `engine = 'rpart'` for translation. + +--- + + unused argument (formula = y ~ x) + diff --git a/tests/testthat/test_decision_tree.R b/tests/testthat/test_decision_tree.R index b1133f8e4..bb1391299 100644 --- a/tests/testthat/test_decision_tree.R +++ b/tests/testthat/test_decision_tree.R @@ -1,12 +1,4 @@ -library(parsnip) - -data(hpc_data, package = "modeldata") - -bt <- decision_tree(tree_depth = c(1, 5, 10)) %>% - set_engine("rpart") %>% - set_mode("classification") - -fit(bt, class ~ ., hpc_data) +hpc <- hpc_data[1:150, c(2:5, 8)] # ------------------------------------------------------------------------------ @@ -20,16 +12,10 @@ test_that('updating', { test_that('bad input', { expect_snapshot_error(decision_tree(mode = "bogus")) - expect_snapshot( - error = TRUE, - { - bt <- decision_tree(tree_depth = "six") %>% - set_engine("rpart") %>% - set_mode("classification") - - fit(bt, class ~ ., hpc) - } - ) + expect_snapshot_error({ + bt <- decision_tree(cost_complexity = -1) %>% set_engine("rpart") + fit(bt, class ~ ., hpc) + }) expect_snapshot_error({ bt <- decision_tree(min_n = 0) %>% set_engine("rpart") fit(bt, class ~ ., hpc) From fd2a8c135ee6a7512ec612b196839dac9c5e54c5 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 14:00:26 -0700 Subject: [PATCH 06/24] delete unreachable code --- R/bag_tree.R | 7 ------- R/decision_tree.R | 7 ------- 2 files changed, 14 deletions(-) diff --git a/R/bag_tree.R b/R/bag_tree.R index b58ded365..e80fc200a 100644 --- a/R/bag_tree.R +++ b/R/bag_tree.R @@ -86,13 +86,6 @@ update.bag_tree <- #' @export check_args.bag_tree <- function(object, call = rlang::caller_env()) { - if (object$engine == "C5.0" && object$mode != "classification") { - cli::cli_abort( - "The engine {.pkg C5.0} only supports the mode {.val classification}, \\ - {.val {object$mode}} was requested.", - call = call - ) - } invisible(object) } diff --git a/R/decision_tree.R b/R/decision_tree.R index 1e8192b85..8266fe806 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -129,13 +129,6 @@ translate.decision_tree <- function(x, engine = x$engine, ...) { #' @export check_args.decision_tree <- function(object, call = rlang::caller_env()) { - if (object$engine == "C5.0" && object$mode != "classification") { - cli::cli_abort( - "The engine {.pkg C5.0} only supports the mode {.val classification}, \\ - {.val {object$mode}} was requested.", - call = call - ) - } invisible(object) } From 8a83a4261c6d07bb6fdef1661d6394c526eb3f6a Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 15:29:22 -0700 Subject: [PATCH 07/24] pass call argument through form_xy() --- R/fit_helpers.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 3cc023ea1..7e2437434 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -119,7 +119,7 @@ xy_xy <- function(object, } form_xy <- function(object, control, env, - target = "none", ...) { + target = "none", ..., call = rlang::caller_env()) { encoding_info <- get_encoding(class(object)[1]) %>% @@ -143,7 +143,8 @@ form_xy <- function(object, control, env, object = object, env = env, #weights! control = control, - target = target + target = target, + call = call ) data_obj$y_var <- all.vars(rlang::f_lhs(env$formula)) data_obj$x <- NULL From 3c361b9577fa26c98c5cbed3e3742f74973def65 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 15:29:29 -0700 Subject: [PATCH 08/24] fix typo --- R/logistic_reg.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 21e0eb680..06a2ca392 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -152,7 +152,7 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { cli::cli_abort( "For the {.pkg LiblineaR} engine, mixture must be 0 or 1,\\ - not {arg$mixture}.\\ + not {args$mixture}.\\ Choose a pure ridge model with {.code mixture = 0}.\\ Choose a pure lasso model with {.code mixture = 1}.\\ The {.pkg Liblinear} engine does not support other values.", From f67f14b0c8b655aabd5d6f0d640a8ef2420984a1 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 15:29:42 -0700 Subject: [PATCH 09/24] add all tests for check_args() --- tests/testthat/_snaps/boost_tree.md | 26 ++++---- tests/testthat/_snaps/linear_reg.md | 18 ++++++ tests/testthat/_snaps/logistic_reg.md | 40 ++++++++++++ tests/testthat/_snaps/mars.md | 30 +++++++++ tests/testthat/_snaps/mlp.md | 28 +++++++++ tests/testthat/_snaps/multinom_reg.md | 20 ++++++ tests/testthat/_snaps/nearest_neighbor.md | 20 ++++++ tests/testthat/test_boost_tree.R | 75 ++++++++++++----------- tests/testthat/test_decision_tree.R | 5 ++ tests/testthat/test_linear_reg.R | 22 ++++++- tests/testthat/test_logistic_reg.R | 39 ++++++++++++ tests/testthat/test_mars.R | 30 +++++++++ tests/testthat/test_mlp.R | 30 +++++++++ tests/testthat/test_multinom_reg.R | 21 +++++++ tests/testthat/test_nearest_neighbor.R | 23 +++++++ tests/testthat/test_rand_forest.R | 4 ++ tests/testthat/test_svm_linear.R | 6 +- tests/testthat/test_svm_poly.R | 5 ++ tests/testthat/test_svm_rbf.R | 5 +- 19 files changed, 393 insertions(+), 54 deletions(-) diff --git a/tests/testthat/_snaps/boost_tree.md b/tests/testthat/_snaps/boost_tree.md index b5753ba3b..f84c5d509 100644 --- a/tests/testthat/_snaps/boost_tree.md +++ b/tests/testthat/_snaps/boost_tree.md @@ -23,43 +23,43 @@ Error in `boost_tree()`: ! "bogus" is not a known mode for model `boost_tree()`. ---- +# check_args() works Code - bt <- boost_tree(trees = -1) %>% set_engine("xgboost") %>% set_mode( + spec <- boost_tree(trees = -1) %>% set_engine("xgboost") %>% set_mode( "classification") - fit(bt, class ~ ., hpc) + fit(spec, class ~ ., hpc) Condition - Error in `form_xy()`: + Error in `fit()`: ! `trees` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. --- Code - bt <- boost_tree(sample_size = -10) %>% set_engine("xgboost") %>% set_mode( + spec <- boost_tree(sample_size = -10) %>% set_engine("xgboost") %>% set_mode( "classification") - fit(bt, class ~ ., hpc) + fit(spec, class ~ ., hpc) Condition - Error in `form_xy()`: + Error in `fit()`: ! `sample_size` must be a number between 0 and 1 or `NULL`, not the number -10. --- Code - bt <- boost_tree(tree_depth = -10) %>% set_engine("xgboost") %>% set_mode( + spec <- boost_tree(tree_depth = -10) %>% set_engine("xgboost") %>% set_mode( "classification") - fit(bt, class ~ ., hpc) + fit(spec, class ~ ., hpc) Condition - Error in `form_xy()`: + Error in `fit()`: ! `tree_depth` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. --- Code - bt <- boost_tree(min_n = -10) %>% set_engine("xgboost") %>% set_mode( + spec <- boost_tree(min_n = -10) %>% set_engine("xgboost") %>% set_mode( "classification") - fit(bt, class ~ ., hpc) + fit(spec, class ~ ., hpc) Condition - Error in `form_xy()`: + Error in `fit()`: ! `min_n` must be a whole number larger than or equal to 0 or `NULL`, not the number -10. diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md index 4626c9656..af9999167 100644 --- a/tests/testthat/_snaps/linear_reg.md +++ b/tests/testthat/_snaps/linear_reg.md @@ -23,3 +23,21 @@ Error in `predict()`: ! Please use `new_data` instead of `newdata`. +# check_args() works + + Code + spec <- linear_reg(mixture = -1) %>% set_engine("lm") %>% set_mode("regression") + fit(spec, compounds ~ ., hpc) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- linear_reg(penalty = -1) %>% set_engine("lm") %>% set_mode("regression") + fit(spec, compounds ~ ., hpc) + Condition + Error in `fit()`: + ! The amount of regularization, `penalty`, should be `>= 0`. + diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md index ed336d9e3..71aaa8535 100644 --- a/tests/testthat/_snaps/logistic_reg.md +++ b/tests/testthat/_snaps/logistic_reg.md @@ -29,3 +29,43 @@ Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred +# check_args() works + + Code + spec <- logistic_reg(mixture = -1) %>% set_engine("glm") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- logistic_reg(penalty = -1) %>% set_engine("glm") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! The amount of regularization, `penalty`, should be `>= 0`. + +--- + + Code + spec <- logistic_reg(mixture = 0.5) %>% set_engine("LiblineaR") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! For the LiblineaR engine, mixture must be 0 or 1,not 0.5.Choose a pure ridge model with `mixture = 0`.Choose a pure lasso model with `mixture = 1`.The Liblinear engine does not support other values. + +--- + + Code + spec <- logistic_reg(penalty = 0) %>% set_engine("LiblineaR") %>% set_mode( + "classification") + fit(spec, Class ~ ., lending_club) + Condition + Error in `fit()`: + ! For the LiblineaR engine, `penalty` must be `> 0`. + diff --git a/tests/testthat/_snaps/mars.md b/tests/testthat/_snaps/mars.md index a1596440e..64cc504a1 100644 --- a/tests/testthat/_snaps/mars.md +++ b/tests/testthat/_snaps/mars.md @@ -22,3 +22,33 @@ Error in `multi_predict()`: ! Please use `new_data` instead of `newdata`. +# check_args() works + + Code + spec <- mars(prod_degree = 0) %>% set_engine("earth") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `prod_degree` must be a whole number larger than or equal to 1 or `NULL`, not the number 0. + +--- + + Code + spec <- mars(num_terms = 0) %>% set_engine("earth") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `num_terms` must be a whole number larger than or equal to 1 or `NULL`, not the number 0. + +--- + + Code + spec <- mars(prune_method = 2) %>% set_engine("earth") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `prune_method` must be a single string or `NULL`, not the number 2. + diff --git a/tests/testthat/_snaps/mlp.md b/tests/testthat/_snaps/mlp.md index bfdd3bfb2..4e9422a65 100644 --- a/tests/testthat/_snaps/mlp.md +++ b/tests/testthat/_snaps/mlp.md @@ -15,3 +15,31 @@ Computational engine: nnet +# check_args() works + + Code + spec <- mlp(penalty = -1) %>% set_engine("nnet") %>% set_mode("classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- mlp(dropout = -1) %>% set_engine("nnet") %>% set_mode("classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `dropout` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- mlp(dropout = 1, penalty = 3) %>% set_engine("nnet") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! Both weight decay and dropout should not be specified. + diff --git a/tests/testthat/_snaps/multinom_reg.md b/tests/testthat/_snaps/multinom_reg.md index da29212d1..f601c8a88 100644 --- a/tests/testthat/_snaps/multinom_reg.md +++ b/tests/testthat/_snaps/multinom_reg.md @@ -15,3 +15,23 @@ Computational engine: glmnet +# check_args() works + + Code + spec <- multinom_reg(mixture = -1) %>% set_engine("nnet") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `mixture` must be a number between 0 and 1 or `NULL`, not the number -1. + +--- + + Code + spec <- multinom_reg(penalty = -1) %>% set_engine("nnet") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! The amount of regularization, `penalty`, should be `>= 0`. + diff --git a/tests/testthat/_snaps/nearest_neighbor.md b/tests/testthat/_snaps/nearest_neighbor.md index c4d6d63f3..401c06c51 100644 --- a/tests/testthat/_snaps/nearest_neighbor.md +++ b/tests/testthat/_snaps/nearest_neighbor.md @@ -15,3 +15,23 @@ Computational engine: kknn +# check_args() works + + Code + spec <- nearest_neighbor(neighbors = -1) %>% set_engine("kknn") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `neighbors` must be a whole number larger than or equal to 0 or `NULL`, not the number -1. + +--- + + Code + spec <- nearest_neighbor(weight_func = 2) %>% set_engine("kknn") %>% set_mode( + "classification") + fit(spec, class ~ ., hpc) + Condition + Error in `fit()`: + ! `weight_func` must be a single string or `NULL`, not the number 2. + diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index ceb26ea69..7a5338ede 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -12,42 +12,6 @@ test_that('updating', { test_that('bad input', { expect_snapshot(error = TRUE, boost_tree(mode = "bogus")) - expect_snapshot( - error = TRUE, - { - bt <- boost_tree(trees = -1) %>% - set_engine("xgboost") %>% - set_mode("classification") - fit(bt, class ~ ., hpc) - } - ) - expect_snapshot( - error = TRUE, - { - bt <- boost_tree(sample_size = -10) %>% - set_engine("xgboost") %>% - set_mode("classification") - fit(bt, class ~ ., hpc) - } - ) - expect_snapshot( - error = TRUE, - { - bt <- boost_tree(tree_depth = -10) %>% - set_engine("xgboost") %>% - set_mode("classification") - fit(bt, class ~ ., hpc) - } - ) - expect_snapshot( - error = TRUE, - { - bt <- boost_tree(min_n = -10) %>% - set_engine("xgboost") %>% - set_mode("classification") - fit(bt, class ~ ., hpc) - } - ) expect_message(translate(boost_tree(mode = "classification"), engine = NULL)) expect_error(translate(boost_tree(formula = y ~ x))) }) @@ -82,3 +46,42 @@ test_that('boost_tree can be fit with 1 predictor if validation is used', { fit(spec, mpg ~ disp, data = mtcars) ) }) + +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(trees = -1) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(sample_size = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(tree_depth = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- boost_tree(min_n = -10) %>% + set_engine("xgboost") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) \ No newline at end of file diff --git a/tests/testthat/test_decision_tree.R b/tests/testthat/test_decision_tree.R index bb1391299..12c32c824 100644 --- a/tests/testthat/test_decision_tree.R +++ b/tests/testthat/test_decision_tree.R @@ -69,3 +69,8 @@ test_that('argument checks for data dimensions', { expect_equal(args$min_instances_per_node, rlang::expr(min_rows(1000, x))) }) + +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 21a93bae3..8381bf4df 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -12,8 +12,6 @@ test_that('updating', { test_that('bad input', { expect_error(linear_reg(mode = "classification")) - # expect_error(linear_reg(penalty = -1)) - # expect_error(linear_reg(mixture = -1)) expect_error(translate(linear_reg(), engine = "wat?")) expect_error(translate(linear_reg(), engine = NULL)) expect_error(translate(linear_reg(formula = y ~ x))) @@ -342,3 +340,23 @@ test_that('lm can handle rankdeficient predictions', { expect_identical(names(preds), ".pred") }) +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- linear_reg(mixture = -1) %>% + set_engine("lm") %>% + set_mode("regression") + fit(spec, compounds ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- linear_reg(penalty = -1) %>% + set_engine("lm") %>% + set_mode("regression") + fit(spec, compounds ~ ., hpc) + } + ) +}) \ No newline at end of file diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index 8a6434747..77c09f77f 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -249,3 +249,42 @@ test_that('liblinear probabilities', { }) +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(mixture = -1) %>% + set_engine("glm") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(penalty = -1) %>% + set_engine("glm") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(mixture = 0.5) %>% + set_engine("LiblineaR") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- logistic_reg(penalty = 0) %>% + set_engine("LiblineaR") %>% + set_mode("classification") + fit(spec, Class ~ ., lending_club) + } + ) + +}) \ No newline at end of file diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index dbd5a04fb..69e7c153b 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -216,3 +216,33 @@ test_that('classification', { expect_equal(parsnip_pred$.pred_good, earth_pred) }) + +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- mars(prod_degree = 0) %>% + set_engine("earth") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mars(num_terms = 0) %>% + set_engine("earth") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mars(prune_method = 2) %>% + set_engine("earth") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) \ No newline at end of file diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index ec9dafad9..19cc6239d 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -1,3 +1,4 @@ +hpc <- hpc_data[1:150, c(2:5, 8)] test_that('updating', { expect_snapshot( @@ -48,3 +49,32 @@ test_that("more activations for brulee", { expect_true(inherits(fit$fit, "brulee_mlp")) }) +test_that("check_args() works", { + expect_snapshot( + error = TRUE, + { + spec <- mlp(penalty = -1) %>% + set_engine("nnet") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mlp(dropout = -1) %>% + set_engine("nnet") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- mlp(dropout = 1, penalty = 3) %>% + set_engine("nnet") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) \ No newline at end of file diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 18b25be9a..1751a2cb1 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -15,3 +15,24 @@ test_that('bad input', { expect_error(multinom_reg(penalty = 0.1) %>% set_engine()) expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) }) + +test_that('check_args() works', { + expect_snapshot( + error = TRUE, + { + spec <- multinom_reg(mixture = -1) %>% + set_engine("nnet") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- multinom_reg(penalty = -1) %>% + set_engine("nnet") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) \ No newline at end of file diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 784d957da..0546fc1d6 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -1,3 +1,5 @@ +hpc <- hpc_data[1:150, c(2:5, 8)] + test_that('updating', { expect_snapshot( nearest_neighbor(neighbors = 5) %>% @@ -10,3 +12,24 @@ test_that('bad input', { expect_error(nearest_neighbor(mode = "reallyunknown")) expect_error(nearest_neighbor() %>% set_engine( NULL)) }) + +test_that('check_args() works', { + expect_snapshot( + error = TRUE, + { + spec <- nearest_neighbor(neighbors = -1) %>% + set_engine("kknn") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) + expect_snapshot( + error = TRUE, + { + spec <- nearest_neighbor(weight_func = 2) %>% + set_engine("kknn") %>% + set_mode("classification") + fit(spec, class ~ ., hpc) + } + ) +}) \ No newline at end of file diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index e74b5e7c7..0de228d44 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -14,3 +14,7 @@ test_that('bad input', { expect_error(translate(rand_forest(mode = "classification", ytest = 2))) }) +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index 0b875c093..eb0663f14 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -370,5 +370,7 @@ test_that('linear svm classification prediction: kernlab', { }) - - +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index 910bf978f..2dc67db6c 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -185,3 +185,8 @@ test_that('svm poly classification probabilities', { parsnip_xy_probs <- predict(cls_xy_form, hpc_no_m[ind, -5], type = "prob") expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs)) }) + +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index db9d5d1f5..758885a6c 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -195,4 +195,7 @@ test_that('svm rbf classification probabilities', { expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs)) }) - +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file From 9130aac3a5d941655c95e9fa2ec9a0492b62a095 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 16:17:35 -0700 Subject: [PATCH 10/24] use skip_if_not_installed() --- tests/testthat/_snaps/mlp.md | 6 +++--- tests/testthat/_snaps/multinom_reg.md | 4 ++-- tests/testthat/test_boost_tree.R | 2 ++ tests/testthat/test_mars.R | 2 ++ tests/testthat/test_mlp.R | 8 +++++--- tests/testthat/test_multinom_reg.R | 6 ++++-- tests/testthat/test_nearest_neighbor.R | 2 ++ 7 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/testthat/_snaps/mlp.md b/tests/testthat/_snaps/mlp.md index 4e9422a65..03b192b7a 100644 --- a/tests/testthat/_snaps/mlp.md +++ b/tests/testthat/_snaps/mlp.md @@ -18,7 +18,7 @@ # check_args() works Code - spec <- mlp(penalty = -1) %>% set_engine("nnet") %>% set_mode("classification") + spec <- mlp(penalty = -1) %>% set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) Condition Error in `fit()`: @@ -27,7 +27,7 @@ --- Code - spec <- mlp(dropout = -1) %>% set_engine("nnet") %>% set_mode("classification") + spec <- mlp(dropout = -1) %>% set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) Condition Error in `fit()`: @@ -36,7 +36,7 @@ --- Code - spec <- mlp(dropout = 1, penalty = 3) %>% set_engine("nnet") %>% set_mode( + spec <- mlp(dropout = 1, penalty = 3) %>% set_engine("keras") %>% set_mode( "classification") fit(spec, class ~ ., hpc) Condition diff --git a/tests/testthat/_snaps/multinom_reg.md b/tests/testthat/_snaps/multinom_reg.md index f601c8a88..002f2121e 100644 --- a/tests/testthat/_snaps/multinom_reg.md +++ b/tests/testthat/_snaps/multinom_reg.md @@ -18,7 +18,7 @@ # check_args() works Code - spec <- multinom_reg(mixture = -1) %>% set_engine("nnet") %>% set_mode( + spec <- multinom_reg(mixture = -1) %>% set_engine("keras") %>% set_mode( "classification") fit(spec, class ~ ., hpc) Condition @@ -28,7 +28,7 @@ --- Code - spec <- multinom_reg(penalty = -1) %>% set_engine("nnet") %>% set_mode( + spec <- multinom_reg(penalty = -1) %>% set_engine("keras") %>% set_mode( "classification") fit(spec, class ~ ., hpc) Condition diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index 7a5338ede..4d0f135d7 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -48,6 +48,8 @@ test_that('boost_tree can be fit with 1 predictor if validation is used', { }) test_that("check_args() works", { + skip_if_not_installed("xgboost") + expect_snapshot( error = TRUE, { diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index 69e7c153b..fade45d2e 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -218,6 +218,8 @@ test_that('classification', { }) test_that("check_args() works", { + skip_if_not_installed("earth") + expect_snapshot( error = TRUE, { diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index 19cc6239d..731a20b01 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -50,11 +50,13 @@ test_that("more activations for brulee", { }) test_that("check_args() works", { + skip_if_not_installed("keras") + expect_snapshot( error = TRUE, { spec <- mlp(penalty = -1) %>% - set_engine("nnet") %>% + set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) } @@ -63,7 +65,7 @@ test_that("check_args() works", { error = TRUE, { spec <- mlp(dropout = -1) %>% - set_engine("nnet") %>% + set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) } @@ -72,7 +74,7 @@ test_that("check_args() works", { error = TRUE, { spec <- mlp(dropout = 1, penalty = 3) %>% - set_engine("nnet") %>% + set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) } diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 1751a2cb1..7752d75e0 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -17,11 +17,13 @@ test_that('bad input', { }) test_that('check_args() works', { + skip_if_not_installed("keras") + expect_snapshot( error = TRUE, { spec <- multinom_reg(mixture = -1) %>% - set_engine("nnet") %>% + set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) } @@ -30,7 +32,7 @@ test_that('check_args() works', { error = TRUE, { spec <- multinom_reg(penalty = -1) %>% - set_engine("nnet") %>% + set_engine("keras") %>% set_mode("classification") fit(spec, class ~ ., hpc) } diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 0546fc1d6..456b7a955 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -14,6 +14,8 @@ test_that('bad input', { }) test_that('check_args() works', { + skip_if_not_installed("kknn") + expect_snapshot( error = TRUE, { From eeb2a8127d7f03f3b1cec7bd5a48d77f8dfa2d71 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 16:40:41 -0700 Subject: [PATCH 11/24] increase package version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1c1a08492..bfa80fed1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.2.1.9000 +Version: 1.2.1.9001 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), From 4191c3c3fec40208085dd501d9456941ed972be2 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 16:50:17 -0700 Subject: [PATCH 12/24] pass call in check_args.C5_rules() --- R/c5_rules.R | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/R/c5_rules.R b/R/c5_rules.R index 2ad3d3154..44d8150d6 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -119,7 +119,8 @@ check_args.C5_rules <- function(object, call = rlang::caller_env()) { if (length(args$trees) > 1) { cli::cli_abort( "Only a single value of {.arg trees} should be passed, \\ - not {length(args$trees)}." + not {length(args$trees)}.", + call = call ) } @@ -138,7 +139,8 @@ check_args.C5_rules <- function(object, call = rlang::caller_env()) { if (length(args$min_n) > 1) { cli::cli_abort( "Only a single value of {.arg min_n} should be passed, \\ - not {length(args$min_n)}." + not {length(args$min_n)}.", + call = call ) } } From 2ac8d17e4f1b6f2237a36a80e47ce88a984773a1 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 17:09:31 -0700 Subject: [PATCH 13/24] pass calls in check_args.cubist_rules() --- R/cubist_rules.R | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/R/cubist_rules.R b/R/cubist_rules.R index 4ab85f599..9037748ab 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -143,7 +143,8 @@ check_args.cubist_rules <- function(object, call = rlang::caller_env()) { if (length(args$committees) > 1) { cli::cli_abort( "Only a single value of {.arg committees} should be passed, \\ - not {length(args$committees)}." + not {length(args$committees)}.", + call = call ) } @@ -164,7 +165,8 @@ check_args.cubist_rules <- function(object, call = rlang::caller_env()) { if (length(args$neighbors) > 1) { cli::cli_abort( "Only a single value of {.arg neighbors} should be passed, \\ - not {length(args$neighbors)}." + not {length(args$neighbors)}.", + call = call ) } From d9e9f06f409fbf4a1814e674a179265d66cffb88 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 9 Apr 2024 17:56:01 -0700 Subject: [PATCH 14/24] add tests for more models --- tests/testthat/test-bart.R | 4 ++++ tests/testthat/test_gen_additive_model.R | 5 +++++ tests/testthat/test_nullmodel.R | 6 ++++-- 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 tests/testthat/test-bart.R diff --git a/tests/testthat/test-bart.R b/tests/testthat/test-bart.R new file mode 100644 index 000000000..245f18f7a --- /dev/null +++ b/tests/testthat/test-bart.R @@ -0,0 +1,4 @@ +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file diff --git a/tests/testthat/test_gen_additive_model.R b/tests/testthat/test_gen_additive_model.R index 3e58605a2..adaa60560 100644 --- a/tests/testthat/test_gen_additive_model.R +++ b/tests/testthat/test_gen_additive_model.R @@ -98,3 +98,8 @@ test_that('classification', { expect_equal(f_ci[[".pred_lower_Class2"]], lower) }) + +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index 7b31c8a80..28e1da8a7 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -124,5 +124,7 @@ test_that('null_model printing', { ) }) - - +test_that("check_args() works", { + # Here for completeness, no checking is done + expect_true(TRUE) +}) \ No newline at end of file From 6db51f970013d67a29f8fb29e34d772ba08609ed Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 09:40:39 -0700 Subject: [PATCH 15/24] use arg argument in check_* functions --- R/boost_tree.R | 14 +++++--------- R/discrim_flexible.R | 9 +++------ R/discrim_regularized.R | 6 ++---- R/linear_reg.R | 3 +-- R/logistic_reg.R | 3 +-- R/mars.R | 9 +++------ R/mlp.R | 6 ++---- R/multinom_reg.R | 3 +-- R/nearest_neighbor.R | 6 ++---- R/pls.R | 3 +-- R/poisson_reg.R | 3 +-- 11 files changed, 22 insertions(+), 43 deletions(-) diff --git a/R/boost_tree.R b/R/boost_tree.R index d0a3ab9a8..799afee1e 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -167,15 +167,11 @@ translate.boost_tree <- function(x, engine = x$engine, ...) { check_args.boost_tree <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - trees <- args$trees - sample_size <- args$sample_size - tree_depth <- args$tree_depth - min_n <- args$min_n - - check_number_whole(trees, min = 0, allow_null = TRUE, call = call) - check_number_decimal(sample_size, min = 0, max = 1, allow_null = TRUE, call = call) - check_number_whole(tree_depth, min = 0, allow_null = TRUE, call = call) - check_number_whole(min_n, min = 0, allow_null = TRUE, call = call) + + check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees") + check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size") + check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth") + check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n") invisible(object) } diff --git a/R/discrim_flexible.R b/R/discrim_flexible.R index a8b3dba92..8b2826b1a 100644 --- a/R/discrim_flexible.R +++ b/R/discrim_flexible.R @@ -88,13 +88,10 @@ update.discrim_flexible <- check_args.discrim_flexible <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - prod_degree <- args$prod_degree - num_terms <- args$num_terms - prune_method <- args$prune_method - check_number_whole(prod_degree, min = 1, allow_null = TRUE, call = call) - check_number_whole(num_terms, min = 1, allow_null = TRUE, call = call) - check_string(prune_method, allow_empty = FALSE, allow_null = TRUE, call = call) + check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree") + check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms") + check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method") invisible(object) } diff --git a/R/discrim_regularized.R b/R/discrim_regularized.R index 011d26e6c..e6f209c4c 100644 --- a/R/discrim_regularized.R +++ b/R/discrim_regularized.R @@ -98,11 +98,9 @@ update.discrim_regularized <- check_args.discrim_regularized <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - frac_common_cov <- args$frac_common_cov - frac_identity <- args$frac_identity - check_number_decimal(frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call) - check_number_decimal(frac_identity, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov") + check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity") invisible(object) } diff --git a/R/linear_reg.R b/R/linear_reg.R index 2fbf31ea4..12d5c556d 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -109,9 +109,8 @@ update.linear_reg <- check_args.linear_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - mixture <- args$mixture - check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { cli::cli_abort( diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 06a2ca392..9fc779762 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -138,9 +138,8 @@ update.logistic_reg <- check_args.logistic_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - mixture <- args$mixture - check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) cli::cli_abort( diff --git a/R/mars.R b/R/mars.R index 9a955588d..f4ad10ae7 100644 --- a/R/mars.R +++ b/R/mars.R @@ -108,13 +108,10 @@ translate.mars <- function(x, engine = x$engine, ...) { check_args.mars <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - prod_degree <- args$prod_degree - num_terms <- args$num_terms - prune_method <- args$prune_method - check_number_whole(prod_degree, min = 1, allow_null = TRUE, call = call) - check_number_whole(num_terms, min = 1, allow_null = TRUE, call = call) - check_string(prune_method, allow_empty = FALSE, allow_null = TRUE, call = call) + check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree") + check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms") + check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method") invisible(object) } diff --git a/R/mlp.R b/R/mlp.R index e882efba6..fa3cf097c 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -129,11 +129,9 @@ translate.mlp <- function(x, engine = x$engine, ...) { check_args.mlp <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - penalty <- args$penalty - dropout <- args$dropout - check_number_decimal(penalty, min = 0, allow_null = TRUE, call = call) - check_number_decimal(dropout, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") + check_number_decimal(args$dropout, min = 0, max = 1, allow_null = TRUE, call = call, arg = "dropout") if (is.numeric(args$penalty) && is.numeric(args$dropout) && args$dropout > 0 && args$penalty > 0) { diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 6bd9f6559..cb074fc4a 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -103,9 +103,8 @@ update.multinom_reg <- check_args.multinom_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - mixture <- args$mixture - check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { cli::cli_abort( diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 7bc7eb9b5..37daaea4c 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -96,11 +96,9 @@ update.nearest_neighbor <- function(object, check_args.nearest_neighbor <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - neighbors <- args$neighbors - weight_func <- args$weight_func - check_number_whole(neighbors, min = 0, allow_null = TRUE, call = call) - check_string(weight_func, allow_null = TRUE, call = call) + check_number_whole(args$neighbors, min = 0, allow_null = TRUE, call = call, arg = "neighbors") + check_string(args$weight_func, allow_null = TRUE, call = call, arg = "weight_func") invisible(object) } diff --git a/R/pls.R b/R/pls.R index 1fda1066b..3a2bc13d7 100644 --- a/R/pls.R +++ b/R/pls.R @@ -90,9 +90,8 @@ update.pls <- check_args.pls <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - num_comp <- args$num_comp - check_number_whole(num_comp, min = 0, allow_null = TRUE, call = call) + check_number_whole(args$num_comp, min = 0, allow_null = TRUE, call = call, arg = "num_comp") invisible(object) } diff --git a/R/poisson_reg.R b/R/poisson_reg.R index 09f8b8691..790517e26 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -104,9 +104,8 @@ translate.poisson_reg <- function(x, engine = x$engine, ...) { check_args.poisson_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - mixture <- args$mixture - check_number_decimal(mixture, min = 0, max = 1, allow_null = TRUE, call = call) + check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { cli::cli_abort( From dbc783993ae86fe4be52f9f529888bc37a9a3bfb Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 09:45:04 -0700 Subject: [PATCH 16/24] Update R/cubist_rules.R Co-authored-by: Simon P. Couch --- R/cubist_rules.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/cubist_rules.R b/R/cubist_rules.R index 9037748ab..e61e0eb8a 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -157,7 +157,7 @@ check_args.cubist_rules <- function(object, call = rlang::caller_env()) { if (args$committees < 1) { object$args$committees <- rlang::new_quosure(1L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 100.")) + cli::cli_warn(c(msg, "Truncating to 1.")) } } From 3277863af13952425834f00254ee614583c209bb Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 10:13:59 -0700 Subject: [PATCH 17/24] use more check_* functions --- R/c5_rules.R | 36 +++++++++--------------------- R/cubist_rules.R | 58 ++++++++++++++++++------------------------------ 2 files changed, 31 insertions(+), 63 deletions(-) diff --git a/R/c5_rules.R b/R/c5_rules.R index 44d8150d6..904e2cb56 100644 --- a/R/c5_rules.R +++ b/R/c5_rules.R @@ -115,35 +115,19 @@ check_args.C5_rules <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$trees)) { - if (length(args$trees) > 1) { - cli::cli_abort( - "Only a single value of {.arg trees} should be passed, \\ - not {length(args$trees)}.", - call = call - ) - } - - msg <- "The number of trees should be {.code >= 1} and {.code <= 100}" - if (args$trees > 100) { - object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 100.")) - } - if (args$trees < 1) { - object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 1.")) - } + check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n") + check_number_whole(args$tree, allow_null = TRUE, call = call, arg = "tree") + msg <- "The number of trees should be {.code >= 1} and {.code <= 100}" + if (!(is.null(args$trees)) && args$trees > 100) { + object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 100.")) } - if (is.numeric(args$min_n)) { - if (length(args$min_n) > 1) { - cli::cli_abort( - "Only a single value of {.arg min_n} should be passed, \\ - not {length(args$min_n)}.", - call = call - ) - } + if (!(is.null(args$trees)) && args$trees < 1) { + object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 1.")) } + invisible(object) } diff --git a/R/cubist_rules.R b/R/cubist_rules.R index e61e0eb8a..a8cc5407c 100644 --- a/R/cubist_rules.R +++ b/R/cubist_rules.R @@ -139,48 +139,32 @@ check_args.cubist_rules <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$committees)) { - if (length(args$committees) > 1) { - cli::cli_abort( - "Only a single value of {.arg committees} should be passed, \\ - not {length(args$committees)}.", - call = call - ) - } - - msg <- "The number of committees should be {.code >= 1} and {.code <= 100}." - if (args$committees > 100) { - object$args$committees <- - rlang::new_quosure(100L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 100.")) - } - if (args$committees < 1) { - object$args$committees <- - rlang::new_quosure(1L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 1.")) - } + check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees") - } - if (is.numeric(args$neighbors)) { - if (length(args$neighbors) > 1) { - cli::cli_abort( - "Only a single value of {.arg neighbors} should be passed, \\ - not {length(args$neighbors)}.", - call = call - ) + msg <- "The number of committees should be {.code >= 1} and {.code <= 100}." + if (!(is.null(args$committees)) && args$committees > 100) { + object$args$committees <- + rlang::new_quosure(100L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 100.")) } + if (!(is.null(args$committees)) && args$committees < 1) { + object$args$committees <- + rlang::new_quosure(1L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 1.")) + } - msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}." - if (args$neighbors > 9) { - object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 9.")) - } - if (args$neighbors < 0) { - object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env()) - cli::cli_warn(c(msg, "Truncating to 0.")) - } + check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors") + msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}." + if (!(is.null(args$neighbors)) && args$neighbors > 9) { + object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 9.")) + } + if (!(is.null(args$neighbors)) && args$neighbors < 0) { + object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env()) + cli::cli_warn(c(msg, "Truncating to 0.")) } + invisible(object) } From 02be9dc139e6936c0999bef07997613109d68e76 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 10:31:21 -0700 Subject: [PATCH 18/24] use check_ functions for penalty --- R/discrim_linear.R | 7 +------ R/linear_reg.R | 8 +------- R/logistic_reg.R | 7 +------ R/multinom_reg.R | 7 +------ R/poisson_reg.R | 8 +------- tests/testthat/_snaps/linear_reg.md | 2 +- tests/testthat/_snaps/logistic_reg.md | 2 +- tests/testthat/_snaps/multinom_reg.md | 2 +- 8 files changed, 8 insertions(+), 35 deletions(-) diff --git a/R/discrim_linear.R b/R/discrim_linear.R index 88c0379b3..22eff4ea5 100644 --- a/R/discrim_linear.R +++ b/R/discrim_linear.R @@ -84,12 +84,7 @@ check_args.discrim_linear <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { - cli::cli_abort( - "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", - call = call - ) - } + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/R/linear_reg.R b/R/linear_reg.R index 12d5c556d..0b7b636b4 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -111,13 +111,7 @@ check_args.linear_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { - cli::cli_abort( - "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", - call = call - ) - } + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 9fc779762..587b652d5 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -140,12 +140,7 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) - cli::cli_abort( - "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", - call = call - ) + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") if (object$engine == "LiblineaR") { if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { diff --git a/R/multinom_reg.R b/R/multinom_reg.R index cb074fc4a..1a8f0e8a1 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -105,13 +105,8 @@ check_args.multinom_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { - cli::cli_abort( - "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", - call = call - ) - } invisible(object) } diff --git a/R/poisson_reg.R b/R/poisson_reg.R index 790517e26..e3201aca1 100644 --- a/R/poisson_reg.R +++ b/R/poisson_reg.R @@ -106,13 +106,7 @@ check_args.poisson_reg <- function(object, call = rlang::caller_env()) { args <- lapply(object$args, rlang::eval_tidy) check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture") - - if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) { - cli::cli_abort( - "The amount of regularization, {.arg penalty}, should be {.code >= 0}.", - call = call - ) - } + check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty") invisible(object) } diff --git a/tests/testthat/_snaps/linear_reg.md b/tests/testthat/_snaps/linear_reg.md index af9999167..cce6764c5 100644 --- a/tests/testthat/_snaps/linear_reg.md +++ b/tests/testthat/_snaps/linear_reg.md @@ -39,5 +39,5 @@ fit(spec, compounds ~ ., hpc) Condition Error in `fit()`: - ! The amount of regularization, `penalty`, should be `>= 0`. + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md index 71aaa8535..1172a36ad 100644 --- a/tests/testthat/_snaps/logistic_reg.md +++ b/tests/testthat/_snaps/logistic_reg.md @@ -47,7 +47,7 @@ fit(spec, Class ~ ., lending_club) Condition Error in `fit()`: - ! The amount of regularization, `penalty`, should be `>= 0`. + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. --- diff --git a/tests/testthat/_snaps/multinom_reg.md b/tests/testthat/_snaps/multinom_reg.md index 002f2121e..0f01bb547 100644 --- a/tests/testthat/_snaps/multinom_reg.md +++ b/tests/testthat/_snaps/multinom_reg.md @@ -33,5 +33,5 @@ fit(spec, class ~ ., hpc) Condition Error in `fit()`: - ! The amount of regularization, `penalty`, should be `>= 0`. + ! `penalty` must be a number larger than or equal to 0 or `NULL`, not the number -1. From 61c452d4917250bcb51367d60640e45b1ddee0ba Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 10:38:46 -0700 Subject: [PATCH 19/24] break up message into multiple lines --- R/logistic_reg.R | 10 +++++----- tests/testthat/_snaps/logistic_reg.md | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 587b652d5..5f2a26328 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -145,11 +145,11 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { if (object$engine == "LiblineaR") { if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { cli::cli_abort( - "For the {.pkg LiblineaR} engine, mixture must be 0 or 1,\\ - not {args$mixture}.\\ - Choose a pure ridge model with {.code mixture = 0}.\\ - Choose a pure lasso model with {.code mixture = 1}.\\ - The {.pkg Liblinear} engine does not support other values.", + c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1,\\ + not {args$mixture}.", + "i" = "Choose a pure ridge model with {.code mixture = 0} or \\ + a pure lasso model with {.code mixture = 1}.", + "!" = "The {.pkg Liblinear} engine does not support other values."), call = call ) } diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md index 1172a36ad..ecbc90264 100644 --- a/tests/testthat/_snaps/logistic_reg.md +++ b/tests/testthat/_snaps/logistic_reg.md @@ -57,7 +57,9 @@ fit(spec, Class ~ ., lending_club) Condition Error in `fit()`: - ! For the LiblineaR engine, mixture must be 0 or 1,not 0.5.Choose a pure ridge model with `mixture = 0`.Choose a pure lasso model with `mixture = 1`.The Liblinear engine does not support other values. + x For the LiblineaR engine, mixture must be 0 or 1,not 0.5. + i Choose a pure ridge model with `mixture = 0` or a pure lasso model with `mixture = 1`. + ! The Liblinear engine does not support other values. --- From b1e2f5a2228db1ba269c5712a76899b5068c9206 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 10:43:56 -0700 Subject: [PATCH 20/24] better LiblineaR specific penalty error --- R/logistic_reg.R | 5 +++-- tests/testthat/_snaps/logistic_reg.md | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 5f2a26328..80d82f511 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -154,9 +154,10 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { ) } - if (all(is.numeric(args$penalty)) && !all(args$penalty > 0)) { + if ((!is.null(args$penalty)) && args$penalty <= 0) { cli::cli_abort( - "For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}.", + "For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\ + not {args$penalty}.", call = call ) } diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md index ecbc90264..4047e3bab 100644 --- a/tests/testthat/_snaps/logistic_reg.md +++ b/tests/testthat/_snaps/logistic_reg.md @@ -69,5 +69,5 @@ fit(spec, Class ~ ., lending_club) Condition Error in `fit()`: - ! For the LiblineaR engine, `penalty` must be `> 0`. + ! For the LiblineaR engine, `penalty` must be `> 0`, not 0. From eceac2dbe9a877e36808fc35e93469143e15e01c Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 10:49:05 -0700 Subject: [PATCH 21/24] move data inside test_that() --- tests/testthat/test_nearest_neighbor.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 456b7a955..2aba15810 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -1,5 +1,3 @@ -hpc <- hpc_data[1:150, c(2:5, 8)] - test_that('updating', { expect_snapshot( nearest_neighbor(neighbors = 5) %>% @@ -15,6 +13,8 @@ test_that('bad input', { test_that('check_args() works', { skip_if_not_installed("kknn") + + hpc <- hpc_data[1:150, c(2:5, 8)] expect_snapshot( error = TRUE, From c9a5ba535e30c4855210da49ff0b4f47db7025d6 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 10:52:03 -0700 Subject: [PATCH 22/24] adding newline at end of file --- tests/testthat/test-bart.R | 2 +- tests/testthat/test_boost_tree.R | 2 +- tests/testthat/test_decision_tree.R | 2 +- tests/testthat/test_gen_additive_model.R | 2 +- tests/testthat/test_linear_reg.R | 2 +- tests/testthat/test_logistic_reg.R | 3 +-- tests/testthat/test_mars.R | 2 +- tests/testthat/test_mlp.R | 2 +- tests/testthat/test_multinom_reg.R | 2 +- tests/testthat/test_nearest_neighbor.R | 2 +- tests/testthat/test_nullmodel.R | 2 +- tests/testthat/test_rand_forest.R | 2 +- tests/testthat/test_svm_linear.R | 2 +- tests/testthat/test_svm_poly.R | 2 +- tests/testthat/test_svm_rbf.R | 2 +- 15 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/testthat/test-bart.R b/tests/testthat/test-bart.R index 245f18f7a..162cb0d5c 100644 --- a/tests/testthat/test-bart.R +++ b/tests/testthat/test-bart.R @@ -1,4 +1,4 @@ test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index 4d0f135d7..a099d64bf 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -86,4 +86,4 @@ test_that("check_args() works", { fit(spec, class ~ ., hpc) } ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_decision_tree.R b/tests/testthat/test_decision_tree.R index 12c32c824..85dd3ac1f 100644 --- a/tests/testthat/test_decision_tree.R +++ b/tests/testthat/test_decision_tree.R @@ -73,4 +73,4 @@ test_that('argument checks for data dimensions', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_gen_additive_model.R b/tests/testthat/test_gen_additive_model.R index adaa60560..0b6a76ea8 100644 --- a/tests/testthat/test_gen_additive_model.R +++ b/tests/testthat/test_gen_additive_model.R @@ -102,4 +102,4 @@ test_that('classification', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 8381bf4df..637ec8075 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -359,4 +359,4 @@ test_that("check_args() works", { fit(spec, compounds ~ ., hpc) } ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index 77c09f77f..3b46e023f 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -286,5 +286,4 @@ test_that("check_args() works", { fit(spec, Class ~ ., lending_club) } ) - -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index fade45d2e..2a9d419cd 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -247,4 +247,4 @@ test_that("check_args() works", { fit(spec, class ~ ., hpc) } ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index 731a20b01..4e9a52b4e 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -79,4 +79,4 @@ test_that("check_args() works", { fit(spec, class ~ ., hpc) } ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 7752d75e0..0cec39200 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -37,4 +37,4 @@ test_that('check_args() works', { fit(spec, class ~ ., hpc) } ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 2aba15810..c834dc570 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -34,4 +34,4 @@ test_that('check_args() works', { fit(spec, class ~ ., hpc) } ) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index 28e1da8a7..2f112c5f3 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -127,4 +127,4 @@ test_that('null_model printing', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index 0de228d44..b05633146 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -17,4 +17,4 @@ test_that('bad input', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index eb0663f14..31c8bfc41 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -373,4 +373,4 @@ test_that('linear svm classification prediction: kernlab', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index 2dc67db6c..dda70dc52 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -189,4 +189,4 @@ test_that('svm poly classification probabilities', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index 758885a6c..8e0c4bc4c 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -198,4 +198,4 @@ test_that('svm rbf classification probabilities', { test_that("check_args() works", { # Here for completeness, no checking is done expect_true(TRUE) -}) \ No newline at end of file +}) From fa40f0a3b7c302f9fa638c17e198835e860170c1 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 11:05:50 -0700 Subject: [PATCH 23/24] be more specific in testing of LiblineaR penalty --- R/logistic_reg.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 80d82f511..d0eee6043 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -154,10 +154,10 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { ) } - if ((!is.null(args$penalty)) && args$penalty <= 0) { + if ((!is.null(args$penalty)) && args$penalty == 0) { cli::cli_abort( "For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\ - not {args$penalty}.", + not 0.", call = call ) } From 21c0e911b57a64e7e8c562a719c7e66be7065322 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 10 Apr 2024 11:08:41 -0700 Subject: [PATCH 24/24] fix space typo --- R/logistic_reg.R | 4 ++-- tests/testthat/_snaps/logistic_reg.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index d0eee6043..40ccf5dc5 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -145,7 +145,7 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { if (object$engine == "LiblineaR") { if (is.numeric(args$mixture) && !args$mixture %in% 0:1) { cli::cli_abort( - c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1,\\ + c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\ not {args$mixture}.", "i" = "Choose a pure ridge model with {.code mixture = 0} or \\ a pure lasso model with {.code mixture = 1}.", @@ -157,7 +157,7 @@ check_args.logistic_reg <- function(object, call = rlang::caller_env()) { if ((!is.null(args$penalty)) && args$penalty == 0) { cli::cli_abort( "For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\ - not 0.", + not 0.", call = call ) } diff --git a/tests/testthat/_snaps/logistic_reg.md b/tests/testthat/_snaps/logistic_reg.md index 4047e3bab..9bb671332 100644 --- a/tests/testthat/_snaps/logistic_reg.md +++ b/tests/testthat/_snaps/logistic_reg.md @@ -57,7 +57,7 @@ fit(spec, Class ~ ., lending_club) Condition Error in `fit()`: - x For the LiblineaR engine, mixture must be 0 or 1,not 0.5. + x For the LiblineaR engine, mixture must be 0 or 1, not 0.5. i Choose a pure ridge model with `mixture = 0` or a pure lasso model with `mixture = 1`. ! The Liblinear engine does not support other values.