Skip to content

Commit 255b67e

Browse files
authored
Merge pull request #512 from mdancho84/gammodels
add gen_additive_mod
2 parents f7ba069 + c2c2c3d commit 255b67e

File tree

10 files changed

+659
-28
lines changed

10 files changed

+659
-28
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Suggests:
5454
nlme,
5555
modeldata,
5656
LiblineaR,
57-
Matrix
57+
Matrix,
58+
mgcv
5859
Remotes:
5960
topepo/C5.0

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
S3method(augment,model_fit)
44
S3method(fit,model_spec)
5+
S3method(fit_xy,gen_additive_mod)
56
S3method(fit_xy,model_spec)
67
S3method(glance,model_fit)
78
S3method(has_multi_predict,default)
@@ -46,6 +47,7 @@ S3method(predict_time,model_fit)
4647
S3method(print,boost_tree)
4748
S3method(print,control_parsnip)
4849
S3method(print,decision_tree)
50+
S3method(print,gen_additive_mod)
4951
S3method(print,linear_reg)
5052
S3method(print,logistic_reg)
5153
S3method(print,mars)
@@ -76,6 +78,7 @@ S3method(tidy,nullmodel)
7678
S3method(translate,boost_tree)
7779
S3method(translate,decision_tree)
7880
S3method(translate,default)
81+
S3method(translate,gen_additive_mod)
7982
S3method(translate,linear_reg)
8083
S3method(translate,logistic_reg)
8184
S3method(translate,mars)
@@ -91,6 +94,7 @@ S3method(type_sum,model_fit)
9194
S3method(type_sum,model_spec)
9295
S3method(update,boost_tree)
9396
S3method(update,decision_tree)
97+
S3method(update,gen_additive_mod)
9498
S3method(update,linear_reg)
9599
S3method(update,logistic_reg)
96100
S3method(update,mars)
@@ -139,6 +143,7 @@ export(fit.model_spec)
139143
export(fit_control)
140144
export(fit_xy)
141145
export(fit_xy.model_spec)
146+
export(gen_additive_mod)
142147
export(get_dependency)
143148
export(get_encoding)
144149
export(get_fit)
@@ -298,6 +303,7 @@ importFrom(stats,na.omit)
298303
importFrom(stats,na.pass)
299304
importFrom(stats,predict)
300305
importFrom(stats,qnorm)
306+
importFrom(stats,qt)
301307
importFrom(stats,quantile)
302308
importFrom(stats,setNames)
303309
importFrom(stats,terms)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).
66

7+
* A model function (`gen_additive_mod()`) was added for generalized additive models.
8+
79
# parsnip 0.1.6
810

911
## Model Specification Changes

R/aaa.R

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,59 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
3030
res
3131
}
3232

33+
# ------------------------------------------------------------------------------
34+
35+
#' @importFrom stats qt
36+
# used by logistic_reg() and gen_additive_mod()
37+
logistic_lp_to_conf_int <- function(results, object) {
38+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
39+
const <-
40+
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
41+
trans <- object$fit$family$linkinv
42+
res_2 <-
43+
tibble(
44+
lo = trans(results$fit - const * results$se.fit),
45+
hi = trans(results$fit + const * results$se.fit)
46+
)
47+
res_1 <- res_2
48+
res_1$lo <- 1 - res_2$hi
49+
res_1$hi <- 1 - res_2$lo
50+
lo_nms <- paste0(".pred_lower_", object$lvl)
51+
hi_nms <- paste0(".pred_upper_", object$lvl)
52+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
53+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
54+
res <- bind_cols(res_1, res_2)
55+
56+
if (object$spec$method$pred$conf_int$extras$std_error)
57+
res$.std_error <- results$se.fit
58+
res
59+
}
60+
61+
# used by gen_additive_mod()
62+
linear_lp_to_conf_int <-
63+
function(results, object) {
64+
hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2
65+
const <-
66+
stats::qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE)
67+
trans <- object$fit$family$linkinv
68+
res <-
69+
tibble(
70+
.pred_lower = trans(results$fit - const * results$se.fit),
71+
.pred_upper = trans(results$fit + const * results$se.fit)
72+
)
73+
# In case of inverse or other links
74+
if (any(res$.pred_upper < res$.pred_lower)) {
75+
nms <- names(res)
76+
res <- res[, 2:1]
77+
names(res) <- nms
78+
}
79+
80+
if (object$spec$method$pred$conf_int$extras$std_error) {
81+
res$.std_error <- results$se.fit
82+
}
83+
res
84+
}
85+
3386
# ------------------------------------------------------------------------------
3487
# nocov
3588

R/gen_additive_mod.R

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# gen_additive_mod() - General Interface to Linear GAM Models
2+
# - backend: gam
3+
# - prediction:
4+
# - mode = "regression" (default) uses
5+
# - mode = "classification"
6+
7+
#' Generalized additive models (GAMs)
8+
#'
9+
#' `gen_additive_mod()` defines a model that can use smoothed functions of
10+
#' numeric predictors in a generalized linear model.
11+
#'
12+
#' There are different ways to fit this model. See the engine-specific pages
13+
#' for more details
14+
#'
15+
#' More information on how `parsnip` is used for modeling is at
16+
#' \url{https://www.tidymodels.org}.
17+
#'
18+
#' @inheritParams boost_tree
19+
#' @param select_features TRUE or FALSE. If this is TRUE then can add an
20+
#' extra penalty to each term so that it can be penalized to zero.
21+
#' This means that the smoothing parameter estimation that is part of
22+
#' fitting can completely remove terms from the model. If the corresponding
23+
#' smoothing parameter is estimated as zero then the extra penalty has no effect.
24+
#' Use `adjust_deg_free` to increase level of penalization.
25+
#' @param adjust_deg_free If `select_features = TRUE`, then acts as a multiplier for smoothness.
26+
#' Increase this beyond 1 to produce smoother models.
27+
#'
28+
#'
29+
#' @return
30+
#' A `parsnip` model specification
31+
#'
32+
#' @details
33+
#'
34+
#' This function only defines what _type_ of model is being fit. Once an engine
35+
#' is specified, the _method_ to fit the model is also defined.
36+
#'
37+
#' The model is not trained or fit until the [fit.model_spec()] function is used
38+
#' with the data.
39+
#'
40+
#' __gam__
41+
#'
42+
#' This engine uses [mgcv::gam()] and has the following parameters,
43+
#' which can be modified through the [set_engine()] function.
44+
#'
45+
#' ``` {r echo=F}
46+
#' str(mgcv::gam)
47+
#' ```
48+
#'
49+
#' @section Fit Details:
50+
#'
51+
#' __MGCV Formula Interface__
52+
#'
53+
#' Fitting GAMs is accomplished using parameters including:
54+
#'
55+
#' - [mgcv::s()]: GAM spline smooths
56+
#' - [mgcv::te()]: GAM tensor product smooths
57+
#'
58+
#' These are applied in the `fit()` function:
59+
#'
60+
#' ``` r
61+
#' fit(value ~ s(date_mon, k = 12) + s(date_num), data = df)
62+
#' ```
63+
#'
64+
#' @references \url{https://www.tidymodels.org},
65+
#' [_Tidy Models with R_](https://tmwr.org)
66+
#' @examples
67+
#'
68+
#' #show_engines("gen_additive_mod")
69+
#'
70+
#' #gen_additive_mod()
71+
#'
72+
#'
73+
#' @export
74+
gen_additive_mod <- function(mode = "unknown",
75+
select_features = NULL,
76+
adjust_deg_free = NULL) {
77+
78+
args <- list(
79+
select_features = rlang::enquo(select_features),
80+
adjust_deg_free = rlang::enquo(adjust_deg_free)
81+
)
82+
83+
new_model_spec(
84+
"gen_additive_mod",
85+
args = args,
86+
eng_args = NULL,
87+
mode = mode,
88+
method = NULL,
89+
engine = NULL
90+
)
91+
92+
}
93+
94+
#' @export
95+
print.gen_additive_mod <- function(x, ...) {
96+
cat("GAM Specification (", x$mode, ")\n\n", sep = "")
97+
model_printer(x, ...)
98+
99+
if(!is.null(x$method$fit$args)) {
100+
cat("Model fit template:\n")
101+
print(show_call(x))
102+
}
103+
104+
invisible(x)
105+
}
106+
107+
#' @export
108+
#' @rdname parsnip_update
109+
#' @importFrom stats update
110+
#' @inheritParams gen_additive_mod
111+
update.gen_additive_mod <- function(object,
112+
select_features = NULL,
113+
adjust_deg_free = NULL,
114+
parameters = NULL,
115+
fresh = FALSE, ...) {
116+
117+
update_dot_check(...)
118+
119+
if (!is.null(parameters)) {
120+
parameters <- check_final_param(parameters)
121+
}
122+
123+
args <- list(
124+
select_features = rlang::enquo(select_features),
125+
adjust_deg_free = rlang::enquo(adjust_deg_free)
126+
)
127+
128+
args <- update_main_parameters(args, parameters)
129+
130+
if (fresh) {
131+
object$args <- args
132+
} else {
133+
null_args <- purrr::map_lgl(args, null_value)
134+
if (any(null_args))
135+
args <- args[!null_args]
136+
if (length(args) > 0)
137+
object$args[names(args)] <- args
138+
}
139+
140+
new_model_spec(
141+
"gen_additive_mod",
142+
args = object$args,
143+
eng_args = object$eng_args,
144+
mode = object$mode,
145+
method = NULL,
146+
engine = object$engine
147+
)
148+
}
149+
150+
151+
#' @export
152+
translate.gen_additive_mod <- function(x, engine = x$engine, ...) {
153+
if (is.null(engine)) {
154+
message("Used `engine = 'mgcv'` for translation.")
155+
engine <- "gam"
156+
}
157+
x <- translate.default(x, engine, ...)
158+
159+
x
160+
}
161+
162+
#' @export
163+
#' @keywords internal
164+
fit_xy.gen_additive_mod <- function(object, ...) {
165+
rlang::abort("`fit()` must be used with GAM models (due to its use of formulas).")
166+
}

0 commit comments

Comments
 (0)