default engine changes for #513 #515

Merged
merged 4 commits into from
Jun 24, 2021
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,5 +1,9 @@
# parsnip (development version)

* Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513)

* The default engine for `multinom_reg()` was changed to `nnet`.

* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between the formula and matrix interfaces are now exported for developer use (#508).

* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).
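
The NEWS entries above describe three ways to end up with an engine on a specification. The sketch below is only an illustration of those entries, assuming a parsnip release that includes this change is installed; the `penalty` and `nlambda` values are arbitrary, and actually fitting the `glmnet` specifications would additionally require the glmnet package.

```r
# Assumes a parsnip version that includes the default-engine change.
library(parsnip)

# 1. The default engine is filled in when the model is defined.
spec_default <- linear_reg()
spec_default$engine

# The changed default for multinom_reg() applies the same way.
spec_multi <- multinom_reg()
spec_multi$engine

# 2. The engine can be declared in the model specification function.
spec_glmnet <- linear_reg(penalty = 0.1, engine = "glmnet")

# 3. set_engine() is still required for engine-specific arguments
#    (nlambda is a glmnet argument, used here only as an example).
spec_args <- set_engine(linear_reg(penalty = 0.1), "glmnet", nlambda = 50)
```

Code that relied on `multinom_reg()` being fitted with `glmnet` can name that engine explicitly, as in the second and third forms.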
8 changes: 6 additions & 2 deletions R/boost_tree.R
@@ -29,9 +29,12 @@
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are "unknown", "regression", or
#' "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"xgboost"`.
#' @param mtry A number for the number (or proportion) of predictors that will
#' be randomly sampled at each split when creating the tree models (`xgboost`
#' only).
@@ -92,6 +95,7 @@

boost_tree <-
function(mode = "unknown",
engine = "xgboost",
mtry = NULL, trees = NULL, min_n = NULL,
tree_depth = NULL, learn_rate = NULL,
loss_reduction = NULL,
@@ -114,7 +118,7 @@ boost_tree <-
eng_args = NULL,
mode,
method = NULL,
engine = NULL
engine = engine
)
}
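
The pattern repeated in every file of this diff is the same: the model function gains an `engine` argument with a default, and the hard-coded `engine = NULL` in the constructor call becomes `engine = engine`. The following base R sketch shows that forwarding in isolation. All `toy_*` names are hypothetical stand-ins and are not parsnip's internal API.

```r
# Hypothetical stand-in for the internal spec constructor; not parsnip code.
toy_model_spec <- function(cls, args, mode, engine) {
  structure(
    list(args = args, mode = mode, engine = engine),
    class = c(cls, "toy_model_spec")
  )
}

# Before the change: the engine slot is always NULL at definition time.
toy_boost_tree_old <- function(mode = "unknown", trees = NULL) {
  toy_model_spec("toy_boost_tree", list(trees = trees), mode, engine = NULL)
}

# After the change: the argument (and its default) is forwarded to the spec.
toy_boost_tree_new <- function(mode = "unknown", engine = "xgboost",
                               trees = NULL) {
  toy_model_spec("toy_boost_tree", list(trees = trees), mode, engine = engine)
}

old_spec <- toy_boost_tree_old()
new_spec <- toy_boost_tree_new()

is.null(old_spec$engine)                   # TRUE
new_spec$engine                            # "xgboost"
toy_boost_tree_new(engine = "C5.0")$engine # "C5.0"
```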

10 changes: 7 additions & 3 deletions R/decision_tree.R
@@ -21,9 +21,12 @@
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are "unknown", "regression", or
#' "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"rpart"`.
#' @param cost_complexity A positive number for the cost/complexity
#' parameter (a.k.a. `Cp`) used by CART models (`rpart` only).
#' @param tree_depth An integer for maximum depth of the tree.
@@ -69,7 +72,8 @@
#' @export

decision_tree <-
function(mode = "unknown", cost_complexity = NULL, tree_depth = NULL, min_n = NULL) {
function(mode = "unknown", engine = "rpart", cost_complexity = NULL,
tree_depth = NULL, min_n = NULL) {

args <- list(
cost_complexity = enquo(cost_complexity),
@@ -83,7 +87,7 @@ decision_tree <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

8 changes: 6 additions & 2 deletions R/linear_reg.R
@@ -16,8 +16,11 @@
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' The only possible value for this model is "regression".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"lm"`.
#' @param penalty A non-negative number representing the total
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
#' For `keras` models, this corresponds to purely L2 regularization
@@ -70,6 +73,7 @@
#' @importFrom purrr map_lgl
linear_reg <-
function(mode = "regression",
engine = "lm",
penalty = NULL,
mixture = NULL) {

@@ -84,7 +88,7 @@ linear_reg <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

8 changes: 6 additions & 2 deletions R/logistic_reg.R
@@ -16,8 +16,11 @@
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' The only possible value for this model is "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"glm"`.
#' @param penalty A non-negative number representing the total
#' amount of regularization (`glmnet`, `LiblineaR`, `keras`, and `spark` only).
#' For `keras` models, this corresponds to purely L2 regularization
@@ -69,6 +72,7 @@
#' @importFrom purrr map_lgl
logistic_reg <-
function(mode = "classification",
engine = "glm",
penalty = NULL,
mixture = NULL) {

@@ -83,7 +87,7 @@ logistic_reg <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

9 changes: 6 additions & 3 deletions R/mars.R
@@ -22,9 +22,12 @@
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are "unknown", "regression", or
#' "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"earth"`.
#' @param num_terms The number of features that will be retained in the
#' final model, including the intercept.
#' @param prod_degree The highest possible interaction degree.
@@ -45,7 +48,7 @@
#' mars(mode = "regression", num_terms = 5)
#' @export
mars <-
function(mode = "unknown",
function(mode = "unknown", engine = "earth",
num_terms = NULL, prod_degree = NULL, prune_method = NULL) {

args <- list(
@@ -60,7 +63,7 @@ mars <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

9 changes: 6 additions & 3 deletions R/mlp.R
@@ -27,9 +27,12 @@
#' If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are "unknown", "regression", or
#' "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"nnet"`.
#' @param hidden_units An integer for the number of units in the hidden layer.
#' @param penalty A non-negative numeric value for the amount of weight
#' decay.
@@ -63,7 +66,7 @@
#' @export

mlp <-
function(mode = "unknown",
function(mode = "unknown", engine = "nnet",
hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL,
activation = NULL) {

@@ -81,7 +84,7 @@ mlp <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

10 changes: 7 additions & 3 deletions R/multinom_reg.R
@@ -16,8 +16,11 @@
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' The only possible value for this model is "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"nnet"`.
#' @param penalty A non-negative number representing the total
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
#' For `keras` models, this corresponds to purely L2 regularization
@@ -33,7 +36,7 @@
#' The model can be created using the `fit()` function using the
#' following _engines_:
#' \itemize{
#' \item \pkg{R}: `"glmnet"` (the default), `"nnet"`
#' \item \pkg{R}: `"nnet"` (the default), `"glmnet"`
#' \item \pkg{Spark}: `"spark"`
#' \item \pkg{keras}: `"keras"`
#' }
@@ -64,6 +67,7 @@
#' @importFrom purrr map_lgl
multinom_reg <-
function(mode = "classification",
engine = "nnet",
penalty = NULL,
mixture = NULL) {

@@ -78,7 +82,7 @@ multinom_reg <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

9 changes: 6 additions & 3 deletions R/nearest_neighbor.R
@@ -23,10 +23,12 @@
#' here (`NULL`), the values are taken from the underlying model
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are `"unknown"`, `"regression"`, or
#' `"classification"`.
#'
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"kknn"`.
#' @param neighbors A single integer for the number of neighbors
#' to consider (often called `k`). For \pkg{kknn}, a value of 5
#' is used if `neighbors` is not specified.
@@ -57,6 +59,7 @@
#'
#' @export
nearest_neighbor <- function(mode = "unknown",
engine = "kknn",
neighbors = NULL,
weight_func = NULL,
dist_power = NULL) {
@@ -72,7 +75,7 @@ nearest_neighbor <- function(mode = "unknown",
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

15 changes: 10 additions & 5 deletions R/proportional_hazards.R
@@ -16,8 +16,11 @@
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are "unknown", or "censored regression".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"survival"`.
#' @inheritParams linear_reg
#'
#' @details
@@ -29,9 +32,11 @@
#' show_engines("proportional_hazards")
#' @keywords internal
#' @export
proportional_hazards <- function(mode = "censored regression",
penalty = NULL,
mixture = NULL) {
proportional_hazards <- function(
mode = "censored regression",
engine = "survival",
penalty = NULL,
mixture = NULL) {

args <- list(
penalty = enquo(penalty),
@@ -44,7 +49,7 @@ proportional_hazards <- function(mode = "censored regression",
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

9 changes: 6 additions & 3 deletions R/rand_forest.R
@@ -20,9 +20,12 @@
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' Possible values for this model are "unknown", "regression", or
#' "classification".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"ranger"`.
#' @param mtry An integer for the number of predictors that will
#' be randomly sampled at each split when creating the tree models.
#' @param trees An integer for the number of trees contained in
@@ -63,7 +66,7 @@
#' @export

rand_forest <-
function(mode = "unknown", mtry = NULL, trees = NULL, min_n = NULL) {
function(mode = "unknown", engine = "ranger", mtry = NULL, trees = NULL, min_n = NULL) {

args <- list(
mtry = enquo(mtry),
@@ -77,7 +80,7 @@ rand_forest <-
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}
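
One consequence of the new signatures, visible directly in this diff, is that `engine` now sits in the second position, ahead of the tuning arguments. A call that passed those arguments by position would bind differently after the change. The base R sketch below mirrors only the argument order of `rand_forest()` before and after; the `toy_rf_*` functions are hypothetical and do not call parsnip.

```r
# Hypothetical toy functions that copy only the argument order from the diff.
toy_rf_old <- function(mode = "unknown", mtry = NULL, trees = NULL,
                       min_n = NULL) {
  list(mode = mode, engine = NULL, mtry = mtry)
}

toy_rf_new <- function(mode = "unknown", engine = "ranger", mtry = NULL,
                       trees = NULL, min_n = NULL) {
  list(mode = mode, engine = engine, mtry = mtry)
}

# A positional call written against the old signature sets mtry...
old_call <- toy_rf_old("regression", 3)
old_call$mtry            # 3

# ...but the same call binds 3 to `engine` under the new signature.
new_call <- toy_rf_new("regression", 3)
new_call$engine          # 3
is.null(new_call$mtry)   # TRUE

# Naming the arguments behaves the same under either argument order.
safe_call <- toy_rf_new(mode = "regression", mtry = 3)
safe_call$engine         # "ranger"
```

Passing everything after `mode` by name sidesteps the issue.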

9 changes: 6 additions & 3 deletions R/surv_reg.R
@@ -31,8 +31,11 @@
#' `strata` function cannot be used. To achieve the same effect,
#' the extra parameter roles can be used (as described above).
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' The only possible value for this model is "regression".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"survival"`.
#' @param dist A character string for the outcome distribution. "weibull" is
#' the default.
#' @details
@@ -65,7 +68,7 @@
#'
#' @keywords internal
#' @export
surv_reg <- function(mode = "regression", dist = NULL) {
surv_reg <- function(mode = "regression", engine = "survival", dist = NULL) {

lifecycle::deprecate_soft("0.1.6", "surv_reg()", "survival_reg()")

@@ -79,7 +82,7 @@ surv_reg <- function(mode = "regression", dist = NULL) {
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}

9 changes: 6 additions & 3 deletions R/survival_reg.R
@@ -14,8 +14,11 @@
#' functions. If parameters need to be modified, `update()` can be used
#' in lieu of recreating the object from scratch.
#'
#' @param mode A single character string for the type of model.
#' @param mode A single character string for the prediction outcome mode.
#' The only possible value for this model is "censored regression".
#' @param engine A single character string specifying what computational engine
#' to use for fitting. Possible engines are listed below. The default for this
#' model is `"survival"`.
#' @param dist A character string for the outcome distribution. "weibull" is
#' the default.
#' @details
@@ -34,7 +37,7 @@
#' survival_reg(dist = varying())
#' @keywords internal
#' @export
survival_reg <- function(mode = "censored regression", dist = NULL) {
survival_reg <- function(mode = "censored regression", engine = "survival", dist = NULL) {

args <- list(
dist = enquo(dist)
@@ -46,7 +49,7 @@ survival_reg <- function(mode = "censored regression", dist = NULL) {
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL
engine = engine
)
}
