From 0abee45e8803a0e7717ae9b562e27fb5f549b9c8 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 11 May 2021 17:39:30 -0600 Subject: [PATCH 1/5] Adjust where engine checking happens for better errors --- R/engines.R | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/R/engines.R b/R/engines.R index 913cf8a1e..5e2953f3a 100644 --- a/R/engines.R +++ b/R/engines.R @@ -15,6 +15,9 @@ check_engine <- function(object) { if (is.null(object$engine)) { object$engine <- avail_eng[1] rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`")) + } else { + if (!is.character(object$engine) | length(object$engine) != 1) + rlang::abort("`engine` should be a single character value.") } if (!(object$engine %in% avail_eng)) { rlang::abort( @@ -91,18 +94,18 @@ set_engine <- function(object, engine, ...) { if (!inherits(object, "model_spec")) { rlang::abort("`object` should have class 'model_spec'.") } - if (!is.character(engine) | length(engine) != 1) - rlang::abort("`engine` should be a single character value.") - if (engine == "liquidSVM") { + + if (rlang::is_missing(engine)) engine <- NULL + object$engine <- engine + object <- check_engine(object) + + if (object$engine == "liquidSVM") { lifecycle::deprecate_soft( "0.1.6", "set_engine(engine = 'cannot be liquidSVM')", details = "The liquidSVM package is no longer available on CRAN.") } - object$engine <- engine - object <- check_engine(object) - new_model_spec( cls = class(object)[1], args = object$args, From a5203d3533463dad632f0f8ed3f74390e7f714ac Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 11 May 2021 17:40:52 -0600 Subject: [PATCH 2/5] Better error message for missing/null mode --- R/arguments.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/R/arguments.R b/R/arguments.R index 3a4d887b3..043a867f5 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -76,10 +76,9 @@ set_args <- function(object, ...) { #' @rdname set_args #' @export set_mode <- function(object, mode) { - if (is.null(mode)) - return(object) + if (rlang::is_missing(mode)) mode <- NULL mode <- mode[1] - if (!(any(all_modes == mode))) { + if (is.null(mode) | !(any(all_modes == mode))) { rlang::abort( glue::glue( "`mode` should be one of ", From 002f53e1e1bd388c585d61d7e48a130df922a687 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Tue, 11 May 2021 18:17:42 -0600 Subject: [PATCH 3/5] Update tests for warnings/errors in set_mode + set_engine --- tests/testthat/test_mars.R | 2 +- tests/testthat/test_multinom_reg.R | 6 +++--- tests/testthat/test_nearest_neighbor.R | 2 +- tests/testthat/test_nullmodel.R | 2 +- tests/testthat/test_rand_forest.R | 2 +- tests/testthat/test_surv_reg.R | 2 +- tests/testthat/test_svm_linear.R | 2 +- tests/testthat/test_svm_liquidsvm.R | 2 +- tests/testthat/test_svm_poly.R | 2 +- tests/testthat/test_svm_rbf.R | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index 63fa17e57..ca5b3e650 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -111,8 +111,8 @@ test_that('updating', { }) test_that('bad input', { + expect_warning(translate(mars(mode = "regression") %>% set_engine())) expect_error(translate(mars() %>% set_engine("wat?"))) - expect_error(translate(mars(mode = "regression") %>% set_engine())) expect_error(translate(mars(formula = y ~ x))) }) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 6a1b0037d..81118532a 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -109,7 +109,7 @@ test_that('updating', { test_that('bad input', { expect_error(multinom_reg(mode = "regression")) - expect_error(translate(multinom_reg() %>% set_engine("wat?"))) - expect_error(translate(multinom_reg() %>% set_engine())) - expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) + expect_error(translate(multinom_reg(penalty = 0.1) %>% set_engine("wat?"))) + expect_warning(multinom_reg(penalty = 0.1) %>% set_engine()) + expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) }) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 29b3f7f35..a550fa4ba 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -122,5 +122,5 @@ test_that('updating', { test_that('bad input', { expect_error(nearest_neighbor(mode = "reallyunknown")) - expect_error(translate(nearest_neighbor() %>% set_engine( NULL))) + expect_warning(nearest_neighbor() %>% set_engine( NULL)) }) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index aaf77532a..9a5af12bf 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -32,8 +32,8 @@ test_that('engine arguments', { }) test_that('bad input', { + expect_warning(translate(null_model(mode = "regression") %>% set_engine())) expect_error(translate(null_model() %>% set_engine("wat?"))) - expect_error(translate(null_model(mode = "regression") %>% set_engine())) expect_error(translate(null_model(formula = y ~ x))) expect_warning( translate( diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index 1be3aa7e8..b8dbf361f 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -192,9 +192,9 @@ test_that('updating', { }) test_that('bad input', { + expect_warning(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) expect_error(rand_forest(mode = "time series")) expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?"))) - expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) expect_error(translate(rand_forest(mode = "classification", ytest = 2))) }) diff --git a/tests/testthat/test_surv_reg.R b/tests/testthat/test_surv_reg.R index 424f48a9b..1947bed41 100644 --- a/tests/testthat/test_surv_reg.R +++ b/tests/testthat/test_surv_reg.R @@ -85,7 +85,7 @@ test_that('bad input', { expect_error(surv_reg(mode = ", classification")) expect_error(translate(surv_reg() %>% set_engine("wat"))) - expect_error(translate(surv_reg() %>% set_engine(NULL))) + expect_warning(translate(surv_reg() %>% set_engine(NULL))) }) test_that("deprecation warning", { diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index 24d4ced4a..f878aa95f 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -104,8 +104,8 @@ test_that('updating', { }) test_that('bad input', { + expect_warning(translate(svm_linear(mode = "regression") %>% set_engine( NULL))) expect_error(svm_linear(mode = "reallyunknown")) - expect_error(translate(svm_linear(mode = "regression") %>% set_engine( NULL))) expect_error(translate(svm_linear(mode = "regression") %>% set_engine("LiblineaR", type = 3))) expect_error(translate(svm_linear(mode = "classification") %>% set_engine("LiblineaR", type = 11))) }) diff --git a/tests/testthat/test_svm_liquidsvm.R b/tests/testthat/test_svm_liquidsvm.R index 373ff4c12..2a966b8e9 100644 --- a/tests/testthat/test_svm_liquidsvm.R +++ b/tests/testthat/test_svm_liquidsvm.R @@ -77,5 +77,5 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_error(translate(svm_rbf() %>% set_engine( NULL))) + expect_warning(svm_rbf() %>% set_engine( NULL)) }) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index fef562ccd..fd403560f 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -106,7 +106,7 @@ test_that('updating', { test_that('bad input', { expect_error(svm_poly(mode = "reallyunknown")) - expect_error(translate(svm_poly() %>% set_engine( NULL))) + expect_warning(svm_poly() %>% set_engine(NULL)) }) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index 059754fbf..9541f1792 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -87,7 +87,7 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) + expect_warning(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) }) # ------------------------------------------------------------------------------ From 806d8cb24a32a0c068cf8ac9a89a6811d1878036 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Wed, 12 May 2021 13:04:27 -0600 Subject: [PATCH 4/5] Use check_spec_mode_val for all mode checking --- R/aaa_models.R | 14 ++++++++------ R/arguments.R | 8 -------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 112d5b225..137062f2b 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -135,13 +135,15 @@ check_mode_val <- function(mode) { # check if class and mode are compatible check_spec_mode_val <- function(cls, mode) { spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) - if (!(mode %in% spec_modes)) - rlang::abort( - glue::glue( - "`mode` should be one of: ", - glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") - ) + compatible_modes <- + glue::glue( + "`mode` should be one of: ", + glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") ) + + if (is.null(mode)) rlang::abort(compatible_modes) + else if (!(mode %in% spec_modes)) rlang::abort(compatible_modes) + invisible(NULL) } diff --git a/R/arguments.R b/R/arguments.R index 043a867f5..60478d2bc 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -78,14 +78,6 @@ set_args <- function(object, ...) { set_mode <- function(object, mode) { if (rlang::is_missing(mode)) mode <- NULL mode <- mode[1] - if (is.null(mode) | !(any(all_modes == mode))) { - rlang::abort( - glue::glue( - "`mode` should be one of ", - glue::glue_collapse(glue::glue("'{all_modes}'"), sep = ", ") - ) - ) - } check_spec_mode_val(class(object)[1], mode) object$mode <- mode object From ba99fb2dcd632ad9b69a49e6dde3cf3748c0008b Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 13 May 2021 09:35:40 -0400 Subject: [PATCH 5/5] minor linting --- R/aaa_models.R | 10 +++++++--- R/arguments.R | 4 +++- R/engines.R | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 137062f2b..2a5f6017e 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -127,8 +127,9 @@ check_model_doesnt_exist <- function(model) { } check_mode_val <- function(mode) { - if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) + if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) { rlang::abort("Please supply a character string for a mode (e.g. `'regression'`).") + } invisible(NULL) } @@ -141,8 +142,11 @@ check_spec_mode_val <- function(cls, mode) { glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") ) - if (is.null(mode)) rlang::abort(compatible_modes) - else if (!(mode %in% spec_modes)) rlang::abort(compatible_modes) + if (is.null(mode)) { + rlang::abort(compatible_modes) + } else if (!(mode %in% spec_modes)) { + rlang::abort(compatible_modes) + } invisible(NULL) } diff --git a/R/arguments.R b/R/arguments.R index 60478d2bc..ed2b995a7 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -76,7 +76,9 @@ set_args <- function(object, ...) { #' @rdname set_args #' @export set_mode <- function(object, mode) { - if (rlang::is_missing(mode)) mode <- NULL + if (rlang::is_missing(mode)) { + mode <- NULL + } mode <- mode[1] check_spec_mode_val(class(object)[1], mode) object$mode <- mode diff --git a/R/engines.R b/R/engines.R index 5e2953f3a..1babb0186 100644 --- a/R/engines.R +++ b/R/engines.R @@ -16,8 +16,9 @@ check_engine <- function(object) { object$engine <- avail_eng[1] rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`")) } else { - if (!is.character(object$engine) | length(object$engine) != 1) + if (!is.character(object$engine) | length(object$engine) != 1) { rlang::abort("`engine` should be a single character value.") + } } if (!(object$engine %in% avail_eng)) { rlang::abort( @@ -95,7 +96,9 @@ set_engine <- function(object, engine, ...) { rlang::abort("`object` should have class 'model_spec'.") } - if (rlang::is_missing(engine)) engine <- NULL + if (rlang::is_missing(engine)) { + engine <- NULL + } object$engine <- engine object <- check_engine(object)