diff --git a/R/aaa_models.R b/R/aaa_models.R index 112d5b225..2a5f6017e 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -127,21 +127,27 @@ check_model_doesnt_exist <- function(model) { } check_mode_val <- function(mode) { - if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) + if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) { rlang::abort("Please supply a character string for a mode (e.g. `'regression'`).") + } invisible(NULL) } # check if class and mode are compatible check_spec_mode_val <- function(cls, mode) { spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) - if (!(mode %in% spec_modes)) - rlang::abort( - glue::glue( - "`mode` should be one of: ", - glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") - ) + compatible_modes <- + glue::glue( + "`mode` should be one of: ", + glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ") ) + + if (is.null(mode)) { + rlang::abort(compatible_modes) + } else if (!(mode %in% spec_modes)) { + rlang::abort(compatible_modes) + } + invisible(NULL) } diff --git a/R/arguments.R b/R/arguments.R index 3a4d887b3..ed2b995a7 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -76,17 +76,10 @@ set_args <- function(object, ...) { #' @rdname set_args #' @export set_mode <- function(object, mode) { - if (is.null(mode)) - return(object) - mode <- mode[1] - if (!(any(all_modes == mode))) { - rlang::abort( - glue::glue( - "`mode` should be one of ", - glue::glue_collapse(glue::glue("'{all_modes}'"), sep = ", ") - ) - ) + if (rlang::is_missing(mode)) { + mode <- NULL } + mode <- mode[1] check_spec_mode_val(class(object)[1], mode) object$mode <- mode object diff --git a/R/engines.R b/R/engines.R index 913cf8a1e..1babb0186 100644 --- a/R/engines.R +++ b/R/engines.R @@ -15,6 +15,10 @@ check_engine <- function(object) { if (is.null(object$engine)) { object$engine <- avail_eng[1] rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`")) + } else { + if (!is.character(object$engine) | length(object$engine) != 1) { + rlang::abort("`engine` should be a single character value.") + } } if (!(object$engine %in% avail_eng)) { rlang::abort( @@ -91,18 +95,20 @@ set_engine <- function(object, engine, ...) { if (!inherits(object, "model_spec")) { rlang::abort("`object` should have class 'model_spec'.") } - if (!is.character(engine) | length(engine) != 1) - rlang::abort("`engine` should be a single character value.") - if (engine == "liquidSVM") { + + if (rlang::is_missing(engine)) { + engine <- NULL + } + object$engine <- engine + object <- check_engine(object) + + if (object$engine == "liquidSVM") { lifecycle::deprecate_soft( "0.1.6", "set_engine(engine = 'cannot be liquidSVM')", details = "The liquidSVM package is no longer available on CRAN.") } - object$engine <- engine - object <- check_engine(object) - new_model_spec( cls = class(object)[1], args = object$args, diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index 63fa17e57..ca5b3e650 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -111,8 +111,8 @@ test_that('updating', { }) test_that('bad input', { + expect_warning(translate(mars(mode = "regression") %>% set_engine())) expect_error(translate(mars() %>% set_engine("wat?"))) - expect_error(translate(mars(mode = "regression") %>% set_engine())) expect_error(translate(mars(formula = y ~ x))) }) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index 6a1b0037d..81118532a 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -109,7 +109,7 @@ test_that('updating', { test_that('bad input', { expect_error(multinom_reg(mode = "regression")) - expect_error(translate(multinom_reg() %>% set_engine("wat?"))) - expect_error(translate(multinom_reg() %>% set_engine())) - expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) + expect_error(translate(multinom_reg(penalty = 0.1) %>% set_engine("wat?"))) + expect_warning(multinom_reg(penalty = 0.1) %>% set_engine()) + expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class))) }) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 29b3f7f35..a550fa4ba 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -122,5 +122,5 @@ test_that('updating', { test_that('bad input', { expect_error(nearest_neighbor(mode = "reallyunknown")) - expect_error(translate(nearest_neighbor() %>% set_engine( NULL))) + expect_warning(nearest_neighbor() %>% set_engine( NULL)) }) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index aaf77532a..9a5af12bf 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -32,8 +32,8 @@ test_that('engine arguments', { }) test_that('bad input', { + expect_warning(translate(null_model(mode = "regression") %>% set_engine())) expect_error(translate(null_model() %>% set_engine("wat?"))) - expect_error(translate(null_model(mode = "regression") %>% set_engine())) expect_error(translate(null_model(formula = y ~ x))) expect_warning( translate( diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index 1be3aa7e8..b8dbf361f 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -192,9 +192,9 @@ test_that('updating', { }) test_that('bad input', { + expect_warning(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) expect_error(rand_forest(mode = "time series")) expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?"))) - expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL))) expect_error(translate(rand_forest(mode = "classification", ytest = 2))) }) diff --git a/tests/testthat/test_surv_reg.R b/tests/testthat/test_surv_reg.R index 424f48a9b..1947bed41 100644 --- a/tests/testthat/test_surv_reg.R +++ b/tests/testthat/test_surv_reg.R @@ -85,7 +85,7 @@ test_that('bad input', { expect_error(surv_reg(mode = ", classification")) expect_error(translate(surv_reg() %>% set_engine("wat"))) - expect_error(translate(surv_reg() %>% set_engine(NULL))) + expect_warning(translate(surv_reg() %>% set_engine(NULL))) }) test_that("deprecation warning", { diff --git a/tests/testthat/test_svm_linear.R b/tests/testthat/test_svm_linear.R index 24d4ced4a..f878aa95f 100644 --- a/tests/testthat/test_svm_linear.R +++ b/tests/testthat/test_svm_linear.R @@ -104,8 +104,8 @@ test_that('updating', { }) test_that('bad input', { + expect_warning(translate(svm_linear(mode = "regression") %>% set_engine( NULL))) expect_error(svm_linear(mode = "reallyunknown")) - expect_error(translate(svm_linear(mode = "regression") %>% set_engine( NULL))) expect_error(translate(svm_linear(mode = "regression") %>% set_engine("LiblineaR", type = 3))) expect_error(translate(svm_linear(mode = "classification") %>% set_engine("LiblineaR", type = 11))) }) diff --git a/tests/testthat/test_svm_liquidsvm.R b/tests/testthat/test_svm_liquidsvm.R index 373ff4c12..2a966b8e9 100644 --- a/tests/testthat/test_svm_liquidsvm.R +++ b/tests/testthat/test_svm_liquidsvm.R @@ -77,5 +77,5 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_error(translate(svm_rbf() %>% set_engine( NULL))) + expect_warning(svm_rbf() %>% set_engine( NULL)) }) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index fef562ccd..fd403560f 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -106,7 +106,7 @@ test_that('updating', { test_that('bad input', { expect_error(svm_poly(mode = "reallyunknown")) - expect_error(translate(svm_poly() %>% set_engine( NULL))) + expect_warning(svm_poly() %>% set_engine(NULL)) }) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index 059754fbf..9541f1792 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -87,7 +87,7 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) + expect_warning(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) }) # ------------------------------------------------------------------------------