Skip to content

Commit bc125e9

Browse files
authored
Merge pull request #485 from tidymodels/glmnet-fail-penalty
Add error for glmnet models if penalty is not exactly 1
2 parents 01168ca + 2d9c1d3 commit bc125e9

17 files changed

+70
-118
lines changed

R/engines.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ load_libs <- function(x, quiet, attach = FALSE) {
8282
#' @examples
8383
#' # First, set general arguments using the standardized names
8484
#' mod <-
85-
#' logistic_reg(mixture = 1/3) %>%
85+
#' logistic_reg(penalty = 0.01, mixture = 1/3) %>%
8686
#' # now say how you want to fit the model and another other options
8787
#' set_engine("glmnet", nlambda = 10)
8888
#' translate(mod, engine = "glmnet")

R/linear_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
112112
# Since the `fit` information is gone for the penalty, we need to have an
113113
# evaluated value for the parameter.
114114
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
115+
check_glmnet_penalty(x)
115116
}
116117

117118
x

R/logistic_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
115115
# Since the `fit` information is gone for the penalty, we need to have an
116116
# evaluated value for the parameter.
117117
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
118+
check_glmnet_penalty(x)
118119
}
119120

120121
if (engine == "LiblineaR") {

R/misc.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,4 +323,13 @@ stan_conf_int <- function(object, newdata) {
323323
rlang::eval_tidy(fn)
324324
}
325325

326-
326+
check_glmnet_penalty <- function(x) {
327+
if (length(x$args$penalty) != 1) {
328+
rlang::abort(c(
329+
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
330+
glue::glue("There are {length(x$args$penalty)} values for `penalty`."),
331+
"To try multiple values for total regularization, use the tune package.",
332+
"To predict multiple penalties, use `multi_predict()`"
333+
))
334+
}
335+
}

R/translate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
#' translate(lm_spec, engine = "spark")
3939
#'
4040
#' # with a placeholder for an unknown argument value:
41-
#' translate(linear_reg(mixture = varying()), engine = "glmnet")
41+
#' translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet")
4242
#'
4343
#' @export
4444

man/contr_one_hot.Rd

Lines changed: 6 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/linear_reg.Rd

Lines changed: 4 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/logistic_reg.Rd

Lines changed: 4 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/multinom_reg.Rd

Lines changed: 4 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/linear-reg.Rmd

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@ Engines may have pre-set default arguments when executing the model fit call. Fo
1010
```{r lm-reg}
1111
linear_reg() %>%
1212
set_engine("lm") %>%
13-
set_mode("regression") %>%
1413
translate()
1514
```
1615

1716
## glmnet
1817

1918
```{r glmnet-csl}
20-
linear_reg() %>%
19+
linear_reg(penalty = 0.1) %>%
2120
set_engine("glmnet") %>%
22-
set_mode("regression") %>%
2321
translate()
2422
```
2523

@@ -37,7 +35,6 @@ penalty results.
3735
```{r stan-reg}
3836
linear_reg() %>%
3937
set_engine("stan") %>%
40-
set_mode("regression") %>%
4138
translate()
4239
```
4340

@@ -55,7 +52,6 @@ returned.
5552
```{r spark-reg}
5653
linear_reg() %>%
5754
set_engine("spark") %>%
58-
set_mode("regression") %>%
5955
translate()
6056
```
6157

@@ -64,7 +60,6 @@ linear_reg() %>%
6460
```{r keras-reg}
6561
linear_reg() %>%
6662
set_engine("keras") %>%
67-
set_mode("regression") %>%
6863
translate()
6964
```
7065

0 commit comments

Comments
 (0)