44# ' that will generate one or more new columns of derived data by "sliding"
55# ' a computation along existing data.
66# '
7- # '
87# ' @inheritParams step_epi_lag
98# ' @param .f A function in one of the following formats:
109# ' 1. An unquoted function name with no arguments, e.g., `mean`
2019# ' argument must be named `.x`. A common, though very difficult to debug
2120# ' error is using something like `function(x) mean`. This will not work
2221# ' because it returns the function mean, rather than `mean(x)`
22+ # ' @param before,after the size of the sliding window on the left and the right
23+ # ' of the center. Usually non-negative integers for data indexed by date, but
24+ # ' more restrictive in other cases (see [epiprocess::epi_slide()] for details).
25+ # ' @param prefix A character string that will be prefixed to the new column.
2326# ' @param f_name a character string of at most 20 characters that describes
2427# ' the function. This will be combined with `prefix` and the columns in `...`
2528# ' to name the result using `{prefix}{f_name}_{column}`. By default it will be determined
2629# ' automatically using `clean_f_name()`.
27- # ' @param before,after non-negative integers.
28- # ' How far `before` and `after` each `time_value` should
29- # ' the sliding window extend? Any value provided for either
30- # ' argument must be a single, non-`NA`, non-negative,
31- # ' [integer-compatible][vctrs::vec_cast] number of time steps. Endpoints of
32- # ' the window are inclusive. Common settings:
33- # ' * For trailing/right-aligned windows from `time_value - time_step(k)` to
34- # ' `time_value`, use `before=k, after=0`. This is the most likely use case
35- # ' for the purposes of forecasting.
36- # ' * For center-aligned windows from `time_value - time_step(k)` to
37- # ' `time_value + time_step(k)`, use `before=k, after=k`.
38- # ' * For leading/left-aligned windows from `time_value` to
39- # ' `time_value + time_step(k)`, use `after=k, after=0`.
4030# '
41- # ' You may also pass a [lubridate::period], like `lubridate::weeks(1)` or a
42- # ' character string that is coercible to a [lubridate::period], like
43- # ' `"2 weeks"`.
4431# ' @template step-return
4532# '
4633# ' @export
@@ -69,9 +56,8 @@ step_epi_slide <-
6956 rlang :: abort(" This recipe step can only operate on an `epi_recipe`." )
7057 }
7158 .f <- validate_slide_fun(.f )
72- arg_is_scalar(before , after )
73- before <- try_period(before )
74- after <- try_period(after )
59+ epiprocess ::: validate_slide_window_arg(before , attributes(recipe $ template )$ metadata $ time_type )
60+ epiprocess ::: validate_slide_window_arg(after , attributes(recipe $ template )$ metadata $ time_type )
7561 arg_is_chr_scalar(role , prefix , id )
7662 arg_is_lgl_scalar(skip )
7763
@@ -126,7 +112,6 @@ step_epi_slide_new <-
126112 }
127113
128114
129-
130115# ' @export
131116prep.step_epi_slide <- function (x , training , info = NULL , ... ) {
132117 col_names <- recipes :: recipes_eval_select(x $ terms , data = training , info = info )
@@ -150,7 +135,6 @@ prep.step_epi_slide <- function(x, training, info = NULL, ...) {
150135}
151136
152137
153-
154138# ' @export
155139bake.step_epi_slide <- function (object , new_data , ... ) {
156140 recipes :: check_new_data(names(object $ columns ), object , new_data )
@@ -170,12 +154,16 @@ bake.step_epi_slide <- function(object, new_data, ...) {
170154 class = " epipredict__step__name_collision_error"
171155 )
172156 }
173- if (any(vapply(c(mean , sum ), \(x ) identical(x , object $ .f ), logical (1L )))) {
174- cli_warn(
175- c(" There is an optimized version of both mean and sum. See `step_epi_slide_mean`, `step_epi_slide_sum`, or `step_epi_slide_opt`." ),
176- class = " epipredict__step_epi_slide__optimized_version"
177- )
178- }
157+ # TODO: Uncomment this whenever we make the optimized versions available.
158+ # if (any(vapply(c(mean, sum), \(x) identical(x, object$.f), logical(1L)))) {
159+ # cli_warn(
160+ # c(
161+ # "There is an optimized version of both mean and sum. See `step_epi_slide_mean`, `step_epi_slide_sum`,
162+ # or `step_epi_slide_opt`."
163+ # ),
164+ # class = "epipredict__step_epi_slide__optimized_version"
165+ # )
166+ # }
179167 epi_slide_wrapper(
180168 new_data ,
181169 object $ before ,
@@ -187,48 +175,51 @@ bake.step_epi_slide <- function(object, new_data, ...) {
187175 object $ prefix
188176 )
189177}
190- # ' wrapper to handle epi_slide particulars
178+
179+
180+ # ' Wrapper to handle epi_slide particulars
181+ # '
191182# ' @description
192183# ' This should simplify somewhat in the future when we can run `epi_slide` on
193184# ' columns. Surprisingly, lapply is several orders of magnitude faster than
194185# ' using roughly equivalent tidy select style.
186+ # '
195187# ' @param fns vector of functions, even if it's length 1.
196188# ' @param group_keys the keys to group by. likely `epi_keys[-1]` (to remove time_value)
189+ # '
197190# ' @importFrom tidyr crossing
198191# ' @importFrom dplyr bind_cols group_by ungroup
199192# ' @importFrom epiprocess epi_slide
200193# ' @keywords internal
201194epi_slide_wrapper <- function (new_data , before , after , columns , fns , fn_names , group_keys , name_prefix ) {
202195 cols_fns <- tidyr :: crossing(col_name = columns , fn_name = fn_names , fn = fns )
196+ # Iterate over the rows of cols_fns. For each row number, we will output a
197+ # transformed column. The first result returns all the original columns along
198+ # with the new column. The rest just return the new column.
203199 seq_len(nrow(cols_fns )) %> %
204- lapply( # iterate over the rows of cols_fns
205- # takes in the row number, outputs the transformed column
206- function (comp_i ) {
207- # extract values from the row
208- col_name <- cols_fns [[comp_i , " col_name" ]]
209- fn_name <- cols_fns [[comp_i , " fn_name" ]]
210- fn <- cols_fns [[comp_i , " fn" ]][[1L ]]
211- result_name <- paste(name_prefix , fn_name , col_name , sep = " _" )
212- result <- new_data %> %
213- group_by(across(all_of(group_keys ))) %> %
214- epi_slide(
215- before = before ,
216- after = after ,
217- new_col_name = result_name ,
218- f = function (slice , geo_key , ref_time_value ) {
219- fn(slice [[col_name ]])
220- }
221- ) %> %
222- ungroup()
223- # the first result needs to include all of the original columns
224- if (comp_i == 1L ) {
225- result
226- } else {
227- # everything else just needs that column transformed
228- result [result_name ]
229- }
200+ lapply(function (comp_i ) {
201+ col_name <- cols_fns [[comp_i , " col_name" ]]
202+ fn_name <- cols_fns [[comp_i , " fn_name" ]]
203+ fn <- cols_fns [[comp_i , " fn" ]][[1L ]]
204+ result_name <- paste(name_prefix , fn_name , col_name , sep = " _" )
205+ result <- new_data %> %
206+ group_by(across(all_of(group_keys ))) %> %
207+ epi_slide(
208+ before = before ,
209+ after = after ,
210+ new_col_name = result_name ,
211+ f = function (slice , geo_key , ref_time_value ) {
212+ fn(slice [[col_name ]])
213+ }
214+ ) %> %
215+ ungroup()
216+
217+ if (comp_i == 1L ) {
218+ result
219+ } else {
220+ result [result_name ]
230221 }
231- ) %> %
222+ } ) %> %
232223 bind_cols()
233224}
234225
@@ -286,33 +277,11 @@ validate_slide_fun <- function(.f) {
286277 cli_abort(" In, `step_epi_slide()`, `.f` may not be missing." )
287278 }
288279 if (rlang :: is_formula(.f , scoped = TRUE )) {
289- if (! is.null(rlang :: f_lhs(.f ))) {
290- cli_abort(" In, `step_epi_slide()`, `.f` must be a one-sided formula." )
291- }
280+ cli_abort(" In, `step_epi_slide()`, `.f` cannot be a formula." )
292281 } else if (rlang :: is_character(.f )) {
293282 .f <- rlang :: as_function(.f )
294283 } else if (! rlang :: is_function(.f )) {
295284 cli_abort(" In, `step_epi_slide()`, `.f` must be a function." )
296285 }
297286 .f
298287}
299-
300- try_period <- function (x ) {
301- err <- is.na(x )
302- if (! err ) {
303- if (is.numeric(x )) {
304- err <- ! rlang :: is_integerish(x ) || x < 0
305- } else {
306- x <- lubridate :: as.period(x )
307- err <- is.na(x )
308- }
309- }
310- if (err ) {
311- cli_abort(paste(
312- " The value supplied to `before` or `after` must be a non-negative integer" ,
313- " a {.cls lubridate::period} or a character scalar that can be coerced" ,
314- ' as a {.cls lubridate::period}, e.g., `"1 week"`.'
315- ), )
316- }
317- x
318- }
0 commit comments