Skip to content

Commit ae16fb8

Browse files
authored
Merge pull request #515 from tidymodels/add-default-engine
default engine changes for #513
2 parents 255b67e + 713bc94 commit ae16fb8

38 files changed

+241
-117
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# parsnip (development version)
22

3+
* Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513)
4+
5+
* The default engine for `multinom_reg()` was changed to `nnet`.
6+
37
* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).
48

59
* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).

R/boost_tree.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@
2929
#' functions. If parameters need to be modified, `update()` can be used
3030
#' in lieu of recreating the object from scratch.
3131
#'
32-
#' @param mode A single character string for the type of model.
32+
#' @param mode A single character string for the prediction outcome mode.
3333
#' Possible values for this model are "unknown", "regression", or
3434
#' "classification".
35+
#' @param engine A single character string specifying what computational engine
36+
#' to use for fitting. Possible engines are listed below. The default for this
37+
#' model is `"xgboost"`.
3538
#' @param mtry A number for the number (or proportion) of predictors that will
3639
#' be randomly sampled at each split when creating the tree models (`xgboost`
3740
#' only).
@@ -92,6 +95,7 @@
9295

9396
boost_tree <-
9497
function(mode = "unknown",
98+
engine = "xgboost",
9599
mtry = NULL, trees = NULL, min_n = NULL,
96100
tree_depth = NULL, learn_rate = NULL,
97101
loss_reduction = NULL,
@@ -114,7 +118,7 @@ boost_tree <-
114118
eng_args = NULL,
115119
mode,
116120
method = NULL,
117-
engine = NULL
121+
engine = engine
118122
)
119123
}
120124

R/decision_tree.R

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
#' functions. If parameters need to be modified, `update()` can be used
2222
#' in lieu of recreating the object from scratch.
2323
#'
24-
#' @param mode A single character string for the type of model.
24+
#' @param mode A single character string for the prediction outcome mode.
2525
#' Possible values for this model are "unknown", "regression", or
2626
#' "classification".
27+
#' @param engine A single character string specifying what computational engine
28+
#' to use for fitting. Possible engines are listed below. The default for this
29+
#' model is `"rpart"`.
2730
#' @param cost_complexity A positive number for the the cost/complexity
2831
#' parameter (a.k.a. `Cp`) used by CART models (`rpart` only).
2932
#' @param tree_depth An integer for maximum depth of the tree.
@@ -69,7 +72,8 @@
6972
#' @export
7073

7174
decision_tree <-
72-
function(mode = "unknown", cost_complexity = NULL, tree_depth = NULL, min_n = NULL) {
75+
function(mode = "unknown", engine = "rpart", cost_complexity = NULL,
76+
tree_depth = NULL, min_n = NULL) {
7377

7478
args <- list(
7579
cost_complexity = enquo(cost_complexity),
@@ -83,7 +87,7 @@ decision_tree <-
8387
eng_args = NULL,
8488
mode = mode,
8589
method = NULL,
86-
engine = NULL
90+
engine = engine
8791
)
8892
}
8993

R/linear_reg.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#' here (`NULL`), the values are taken from the underlying model
1717
#' functions. If parameters need to be modified, `update()` can be used
1818
#' in lieu of recreating the object from scratch.
19-
#' @param mode A single character string for the type of model.
19+
#' @param mode A single character string for the prediction outcome mode.
2020
#' The only possible value for this model is "regression".
21+
#' @param engine A single character string specifying what computational engine
22+
#' to use for fitting. Possible engines are listed below. The default for this
23+
#' model is `"lm"`.
2124
#' @param penalty A non-negative number representing the total
2225
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
2326
#' For `keras` models, this corresponds to purely L2 regularization
@@ -70,6 +73,7 @@
7073
#' @importFrom purrr map_lgl
7174
linear_reg <-
7275
function(mode = "regression",
76+
engine = "lm",
7377
penalty = NULL,
7478
mixture = NULL) {
7579

@@ -84,7 +88,7 @@ linear_reg <-
8488
eng_args = NULL,
8589
mode = mode,
8690
method = NULL,
87-
engine = NULL
91+
engine = engine
8892
)
8993
}
9094

R/logistic_reg.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#' here (`NULL`), the values are taken from the underlying model
1717
#' functions. If parameters need to be modified, `update()` can be used
1818
#' in lieu of recreating the object from scratch.
19-
#' @param mode A single character string for the type of model.
19+
#' @param mode A single character string for the prediction outcome mode.
2020
#' The only possible value for this model is "classification".
21+
#' @param engine A single character string specifying what computational engine
22+
#' to use for fitting. Possible engines are listed below. The default for this
23+
#' model is `"glm"`.
2124
#' @param penalty A non-negative number representing the total
2225
#' amount of regularization (`glmnet`, `LiblineaR`, `keras`, and `spark` only).
2326
#' For `keras` models, this corresponds to purely L2 regularization
@@ -69,6 +72,7 @@
6972
#' @importFrom purrr map_lgl
7073
logistic_reg <-
7174
function(mode = "classification",
75+
engine = "glm",
7276
penalty = NULL,
7377
mixture = NULL) {
7478

@@ -83,7 +87,7 @@ logistic_reg <-
8387
eng_args = NULL,
8488
mode = mode,
8589
method = NULL,
86-
engine = NULL
90+
engine = engine
8791
)
8892
}
8993

R/mars.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
#' functions. If parameters need to be modified, `update()` can be used
2323
#' in lieu of recreating the object from scratch.
2424
#'
25-
#' @param mode A single character string for the type of model.
25+
#' @param mode A single character string for the prediction outcome mode.
2626
#' Possible values for this model are "unknown", "regression", or
2727
#' "classification".
28+
#' @param engine A single character string specifying what computational engine
29+
#' to use for fitting. Possible engines are listed below. The default for this
30+
#' model is `"earth"`.
2831
#' @param num_terms The number of features that will be retained in the
2932
#' final model, including the intercept.
3033
#' @param prod_degree The highest possible interaction degree.
@@ -45,7 +48,7 @@
4548
#' mars(mode = "regression", num_terms = 5)
4649
#' @export
4750
mars <-
48-
function(mode = "unknown",
51+
function(mode = "unknown", engine = "earth",
4952
num_terms = NULL, prod_degree = NULL, prune_method = NULL) {
5053

5154
args <- list(
@@ -60,7 +63,7 @@ mars <-
6063
eng_args = NULL,
6164
mode = mode,
6265
method = NULL,
63-
engine = NULL
66+
engine = engine
6467
)
6568
}
6669

R/mlp.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@
2727
#' If parameters need to be modified, `update()` can be used
2828
#' in lieu of recreating the object from scratch.
2929
#'
30-
#' @param mode A single character string for the type of model.
30+
#' @param mode A single character string for the prediction outcome mode.
3131
#' Possible values for this model are "unknown", "regression", or
3232
#' "classification".
33+
#' @param engine A single character string specifying what computational engine
34+
#' to use for fitting. Possible engines are listed below. The default for this
35+
#' model is `"nnet"`.
3336
#' @param hidden_units An integer for the number of units in the hidden model.
3437
#' @param penalty A non-negative numeric value for the amount of weight
3538
#' decay.
@@ -63,7 +66,7 @@
6366
#' @export
6467

6568
mlp <-
66-
function(mode = "unknown",
69+
function(mode = "unknown", engine = "nnet",
6770
hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL,
6871
activation = NULL) {
6972

@@ -81,7 +84,7 @@ mlp <-
8184
eng_args = NULL,
8285
mode = mode,
8386
method = NULL,
84-
engine = NULL
87+
engine = engine
8588
)
8689
}
8790

R/multinom_reg.R

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#' here (`NULL`), the values are taken from the underlying model
1717
#' functions. If parameters need to be modified, `update()` can be used
1818
#' in lieu of recreating the object from scratch.
19-
#' @param mode A single character string for the type of model.
19+
#' @param mode A single character string for the prediction outcome mode.
2020
#' The only possible value for this model is "classification".
21+
#' @param engine A single character string specifying what computational engine
22+
#' to use for fitting. Possible engines are listed below. The default for this
23+
#' model is `"nnet"`.
2124
#' @param penalty A non-negative number representing the total
2225
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
2326
#' For `keras` models, this corresponds to purely L2 regularization
@@ -33,7 +36,7 @@
3336
#' The model can be created using the `fit()` function using the
3437
#' following _engines_:
3538
#' \itemize{
36-
#' \item \pkg{R}: `"glmnet"` (the default), `"nnet"`
39+
#' \item \pkg{R}: `"nnet"` (the default), `"glmnet"`
3740
#' \item \pkg{Spark}: `"spark"`
3841
#' \item \pkg{keras}: `"keras"`
3942
#' }
@@ -64,6 +67,7 @@
6467
#' @importFrom purrr map_lgl
6568
multinom_reg <-
6669
function(mode = "classification",
70+
engine = "nnet",
6771
penalty = NULL,
6872
mixture = NULL) {
6973

@@ -78,7 +82,7 @@ multinom_reg <-
7882
eng_args = NULL,
7983
mode = mode,
8084
method = NULL,
81-
engine = NULL
85+
engine = engine
8286
)
8387
}
8488

R/nearest_neighbor.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
#' here (`NULL`), the values are taken from the underlying model
2424
#' functions. If parameters need to be modified, `update()` can be used
2525
#' in lieu of recreating the object from scratch.
26-
#' @param mode A single character string for the type of model.
26+
#' @param mode A single character string for the prediction outcome mode.
2727
#' Possible values for this model are `"unknown"`, `"regression"`, or
2828
#' `"classification"`.
29-
#'
29+
#' @param engine A single character string specifying what computational engine
30+
#' to use for fitting. Possible engines are listed below. The default for this
31+
#' model is `"kknn"`.
3032
#' @param neighbors A single integer for the number of neighbors
3133
#' to consider (often called `k`). For \pkg{kknn}, a value of 5
3234
#' is used if `neighbors` is not specified.
@@ -57,6 +59,7 @@
5759
#'
5860
#' @export
5961
nearest_neighbor <- function(mode = "unknown",
62+
engine = "kknn",
6063
neighbors = NULL,
6164
weight_func = NULL,
6265
dist_power = NULL) {
@@ -72,7 +75,7 @@ nearest_neighbor <- function(mode = "unknown",
7275
eng_args = NULL,
7376
mode = mode,
7477
method = NULL,
75-
engine = NULL
78+
engine = engine
7679
)
7780
}
7881

R/proportional_hazards.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
#' functions. If parameters need to be modified, `update()` can be used
1717
#' in lieu of recreating the object from scratch.
1818
#'
19-
#' @param mode A single character string for the type of model.
19+
#' @param mode A single character string for the prediction outcome mode.
2020
#' Possible values for this model are "unknown", or "censored regression".
21+
#' @param engine A single character string specifying what computational engine
22+
#' to use for fitting. Possible engines are listed below. The default for this
23+
#' model is `"survival"`.
2124
#' @inheritParams linear_reg
2225
#'
2326
#' @details
@@ -29,9 +32,11 @@
2932
#' show_engines("proportional_hazards")
3033
#' @keywords internal
3134
#' @export
32-
proportional_hazards <- function(mode = "censored regression",
33-
penalty = NULL,
34-
mixture = NULL) {
35+
proportional_hazards <- function(
36+
mode = "censored regression",
37+
engine = "survival",
38+
penalty = NULL,
39+
mixture = NULL) {
3540

3641
args <- list(
3742
penalty = enquo(penalty),
@@ -44,7 +49,7 @@ proportional_hazards <- function(mode = "censored regression",
4449
eng_args = NULL,
4550
mode = mode,
4651
method = NULL,
47-
engine = NULL
52+
engine = engine
4853
)
4954
}
5055

0 commit comments

Comments
 (0)