Skip to content

Commit 24972d2

Browse files
committed
wip
1 parent a1b8b5f commit 24972d2

11 files changed

+84
-95
lines changed

R/new_epipredict_steps/step_yeo_johnson.R

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#' `step_YeoJohnson2()` creates a *specification* of a recipe step that will
44
#' transform data using a Yeo-Johnson transformation. This fork works with panel
55
#' data and is meant for epidata.
6+
#' TODO: Do an edit pass on this docstring.
67
#'
78
#' @inheritParams step_center
89
#' @param lambdas A numeric vector of transformation values. This
@@ -69,11 +70,21 @@
6970
#' tidy(yj_transform, number = 1)
7071
#' tidy(yj_estimates, number = 1)
7172
step_YeoJohnson2 <-
72-
function(recipe, ..., role = NA, trained = FALSE,
73-
lambdas = NULL, na_lambda_fill = 1 / 4, limits = c(-5, 5), num_unique = 5,
74-
na_rm = TRUE,
75-
skip = FALSE,
76-
id = rand_id("YeoJohnson2")) {
73+
function(
74+
recipe,
75+
...,
76+
role = NA,
77+
trained = FALSE,
78+
lambdas = NULL,
79+
na_lambda_fill = 1 / 4,
80+
limits = c(-5, 5),
81+
num_unique = 5,
82+
na_rm = TRUE,
83+
skip = FALSE,
84+
id = rand_id("YeoJohnson2")
85+
) {
86+
# TODO: Add arg validations.
87+
# TODO: Improve arg names.
7788
add_step(
7889
recipe,
7990
step_YeoJohnson2_new(
@@ -115,17 +126,18 @@ prep.step_YeoJohnson2 <- function(x, training, info = NULL, ...) {
115126
recipes:::check_number_whole(x$num_unique, args = "num_unique")
116127
recipes:::check_bool(x$na_rm, arg = "na_rm")
117128
if (!is.numeric(x$limits) || any(is.na(x$limits)) || length(x$limits) != 2) {
118-
cli::cli_abort("{.arg limits} should be a numeric vector with two values,
119-
not {.obj_type_friendly {x$limits}}")
129+
cli::cli_abort(
130+
"{.arg limits} should be a numeric vector with two values,
131+
not {.obj_type_friendly {x$limits}}"
132+
)
120133
}
121134

122-
x$limits <- sort(x$limits)
123-
124135
values <- training %>%
125-
group_by(geo_value) %>%
126-
summarise(across(all_of(col_names), ~ estimate_yj(.x, x$limits, x$num_unique, x$na_rm))) %>%
127-
ungroup() %>%
128-
rename_with(~ paste0("lambda_", .x), -geo_value)
136+
summarise(
137+
across(all_of(col_names), ~ estimate_yj(.x, x$limits, x$num_unique, x$na_rm)),
138+
.by = key_colnames(training, exclude = "time_value")
139+
) %>%
140+
rename_with(~ paste0("lambda_", .x), -all_of(key_colnames(training, exclude = "time_value")))
129141

130142
# Check for NAs in any of the lambda_ columns
131143
for (col in col_names) {
@@ -137,17 +149,12 @@ prep.step_YeoJohnson2 <- function(x, training, info = NULL, ...) {
137149
),
138150
call = rlang::caller_fn()
139151
)
140-
values <- values %>%
141-
mutate(
142-
!!sym(paste0("lambda_", col)) := ifelse(
143-
is.na(!!sym(paste0("lambda_", col))),
144-
x$na_lambda_fill,
145-
!!sym(paste0("lambda_", col))
146-
)
147-
)
148152
}
149153
}
150154

155+
values <- values %>%
156+
mutate(across(starts_with("lambda_"), \(col) ifelse(is.na(col), x$na_lambda_fill, col)))
157+
151158
step_YeoJohnson2_new(
152159
terms = x$terms,
153160
role = x$role,
@@ -168,11 +175,12 @@ bake.step_YeoJohnson2 <- function(object, new_data, ...) {
168175
col_names <- object$terms %>% purrr::map_chr(rlang::as_name)
169176
check_new_data(col_names, object, new_data)
170177

171-
new_data %<>% left_join(object$lambdas, by = "geo_value")
178+
new_data %<>% left_join(object$lambdas, by = key_colnames(new_data, exclude = "time_value"))
172179
for (col in col_names) {
173180
new_data <- new_data %>%
174181
rowwise() %>%
175182
mutate(!!col := yj_transform(!!sym(col), !!sym(paste0("lambda_", col))))
183+
# mutate(across(col_names, ~ yj_transform(.x, !!sym(paste0("lambda_", .x)))))
176184
}
177185
new_data %>%
178186
select(-starts_with("lambda_")) %>%
@@ -260,11 +268,7 @@ yj_obj <- function(lam, dat, ind_neg, const) {
260268
#' @keywords internal
261269
#' @rdname recipes-internal
262270
#' @export
263-
estimate_yj <- function(dat,
264-
limits = c(-5, 5),
265-
num_unique = 5,
266-
na_rm = TRUE,
267-
call = caller_env(2)) {
271+
estimate_yj <- function(dat, limits = c(-5, 5), num_unique = 5, na_rm = TRUE, call = caller_env(2)) {
268272
na_rows <- which(is.na(dat))
269273
if (length(na_rows) > 0) {
270274
if (na_rm) {
@@ -305,7 +309,7 @@ estimate_yj <- function(dat,
305309
lam
306310
}
307311

308-
309-
#' @rdname tidy.recipe
310-
#' @export
311-
tidy.step_YeoJohnson2 <- tidy.step_BoxCox2
312+
# #
313+
# #' @rdname tidy.recipe
314+
# #' @export
315+
# tidy.step_YeoJohnson2 <- tidy.step_BoxCox2

renv/activate.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,12 @@ local({
135135

136136
# R help links
137137
pattern <- "`\\?(renv::(?:[^`])+)`"
138-
replacement <- "`\033]8;;x-r-help:\\1\a?\\1\033]8;;\a`"
138+
replacement <- "`\033]8;;ide:help:\\1\a?\\1\033]8;;\a`"
139139
text <- gsub(pattern, replacement, text, perl = TRUE)
140140

141141
# runnable code
142142
pattern <- "`(renv::(?:[^`])+)`"
143-
replacement <- "`\033]8;;x-r-run:\\1\a\\1\033]8;;\a`"
143+
replacement <- "`\033]8;;ide:run:\\1\a\\1\033]8;;\a`"
144144
text <- gsub(pattern, replacement, text, perl = TRUE)
145145

146146
# return ansified text

tests/testthat/test-daily-weekly-archive.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
# Works correctly if you have exactly one version where the previous Friday data
44
# is the latest so it is ignored and the week before THAT is summed (10-27 to

tests/testthat/test-data-whitening.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22
real_ex <- epidatasets::covid_case_death_rates %>%
33
as_tibble() %>%
44
mutate(source = "same") %>%

tests/testthat/test-forecaster-utils.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
test_that("sanitize_args_predictors_trainer", {
44
epi_data <- epidatasets::covid_case_death_rates

tests/testthat/test-forecasters-basics.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22
testthat::local_edition(3)
33
# TODO better way to do this than copypasta
44
forecasters <- list(

tests/testthat/test-forecasters-data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
testthat::skip("Optional, long-running tests skipped.")
44

tests/testthat/test-latency_adjusting.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
test_that("extend_ahead", {
44
# testing that POSIXct converts correctly (as well as basic types)

tests/testthat/test-step-training-window.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
data <- tribble(
44
~geo_value, ~time_value, ~version, ~value,

tests/testthat/test-transforms.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
source(here::here("R", "load_all.R"))
1+
suppressPackageStartupMessages(source(here::here("R", "load_all.R")))
22

33
n_days <- 20
44
removed_date <- 10

0 commit comments

Comments
 (0)