#'
#' Will undo a step_epi_YeoJohnson transformation.
#'
#' @inheritParams layer_population_scaling
#' @param yj_params Internal. A data frame of parameters to be used for
#'   inverting the transformation.
#' @param by A (possibly named) character vector of variables to join by.
#'
#' @return an updated `frosting` postprocessor
#' @export
4137# ' # Compare to the original data.
4238# ' jhu %>% filter(time_value == "2021-12-31")
4339# ' forecast(wf)
layer_epi_YeoJohnson <- function(frosting, ..., yj_params = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
  # `yj_params` is optional; when supplied it must be a tibble with at least
  # one row.
  checkmate::assert_tibble(yj_params, min.rows = 1, null.ok = TRUE)

  # Capture the column selectors in `...` as quosures and append the new layer
  # to the frosting's sequence of operations.
  add_layer(
    frosting,
    layer_epi_YeoJohnson_new(
      yj_params = yj_params,
      by = by,
      terms = dplyr::enquos(...),
      id = id
    )
  )
}
5753
# Internal constructor: wraps the arguments into an "epi_YeoJohnson" layer
# object via `layer()`.
layer_epi_YeoJohnson_new <- function(yj_params, by, terms, id) {
  layer("epi_YeoJohnson", yj_params = yj_params, by = by, terms = terms, id = id)
}
6157
6258# ' @export
6359# ' @importFrom workflows extract_preprocessor
6460slather.layer_epi_YeoJohnson <- function (object , components , workflow , new_data , ... ) {
6561 rlang :: check_dots_empty()
6662
67- # Get the lambdas from the layer or from the workflow.
68- lambdas <- object $ lambdas %|| % get_lambdas_in_layer(workflow )
63+ # TODO: We will error if we don't have a workflow. Write a check later.
6964
70- # If the by is not specified, try to infer it from the lambdas.
65+ # Get the yj_params from the layer or from the workflow.
66+ yj_params <- object $ yj_params %|| % get_yj_params_in_layer(workflow )
67+
68+ # If the by is not specified, try to infer it from the yj_params.
7169 if (is.null(object $ by )) {
7270 # Assume `layer_predict` has calculated the prediction keys and other
7371 # layers don't change the prediction key colnames:
7472 prediction_key_colnames <- names(components $ keys )
7573 lhs_potential_keys <- prediction_key_colnames
76- rhs_potential_keys <- colnames(select(lambdas , - starts_with(" lambda_ " )))
74+ rhs_potential_keys <- colnames(select(yj_params , - starts_with(" .yj_param_ " )))
7775 object $ by <- intersect(lhs_potential_keys , rhs_potential_keys )
7876 suggested_min_keys <- setdiff(lhs_potential_keys , " time_value" )
7977 if (! all(suggested_min_keys %in% object $ by )) {
@@ -95,16 +93,16 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
9593 object $ by <- object $ by %|| %
9694 intersect(
9795 epi_keys_only(components $ predictions ),
98- colnames(select(lambdas , - starts_with(" .lambda_ " )))
96+ colnames(select(yj_params , - starts_with(" .yj_param_ " )))
9997 )
10098 joinby <- list (x = names(object $ by ) %|| % object $ by , y = object $ by )
10199 hardhat :: validate_column_names(components $ predictions , joinby $ x )
102- hardhat :: validate_column_names(lambdas , joinby $ y )
100+ hardhat :: validate_column_names(yj_params , joinby $ y )
103101
104- # Join the lambdas .
102+ # Join the yj_params .
105103 components $ predictions <- inner_join(
106104 components $ predictions ,
107- lambdas ,
105+ yj_params ,
108106 by = object $ by ,
109107 relationship = " many-to-one" ,
110108 unmatched = c(" error" , " drop" )
@@ -115,7 +113,7 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
115113 col_names <- names(pos )
116114
117115 # The `object$terms` is where the user specifies the columns they want to
118- # untransform. We need to match the outcomes with their lambda columns in our
116+ # untransform. We need to match the outcomes with their yj_param columns in our
119117 # parameter table and then apply the inverse transformation.
120118 if (identical(col_names , " .pred" )) {
121119 # In this case, we don't get a hint for the outcome column name, so we need
@@ -130,8 +128,7 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
130128 magrittr :: extract(, 2 )
131129
132130 components $ predictions <- components $ predictions %> %
133- rowwise() %> %
134- mutate(.pred : = yj_inverse(.pred , !! sym(paste0(" .lambda_" , outcome_cols ))))
131+ mutate(.pred : = yj_inverse(.pred , !! sym(paste0(" .yj_param_" , outcome_cols ))))
135132 } else if (identical(col_names , character (0 ))) {
136133 # Wish I could suggest `all_outcomes()` here, but currently it's the same as
137134 # not specifying any terms. I don't want to spend time with dealing with
@@ -146,10 +143,10 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
146143 )
147144 } else {
148145 # In this case, we assume that the user has specified the columns they want
149- # transformed here. We then need to determine the lambda columns for each of
146+ # transformed here. We then need to determine the yj_param columns for each of
150147 # these columns. That is, we need to convert a vector of column names like
151148 # c(".pred_ahead_1_case_rate", ".pred_ahead_7_case_rate") to
152- # c("lambda_ahead_1_case_rate ", "lambda_ahead_7_case_rate ").
149+ # c(".yj_param_ahead_1_case_rate ", ".yj_param_ahead_7_case_rate ").
153150 original_outcome_cols <- stringr :: str_match(col_names , " .pred_ahead_\\ d+_(.*)" )[, 2 ]
154151 outcomes_wout_ahead <- stringr :: str_match(names(components $ mold $ outcomes ), " ahead_\\ d+_(.*)" )[, 2 ]
155152 if (any(original_outcome_cols %nin % outcomes_wout_ahead )) {
@@ -163,34 +160,37 @@ slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data,
163160
164161 for (i in seq_along(col_names )) {
165162 col <- col_names [i ]
166- lambda_col <- paste0(" .lambda_ " , original_outcome_cols [i ])
163+ yj_param_col <- paste0(" .yj_param_ " , original_outcome_cols [i ])
167164 components $ predictions <- components $ predictions %> %
168- rowwise() %> %
169- mutate(!! sym(col ) : = yj_inverse(!! sym(col ), !! sym(lambda_col )))
165+ mutate(!! sym(col ) : = yj_inverse(!! sym(col ), !! sym(yj_param_col )))
170166 }
171167 }
172168
173- # Remove the lambda columns.
169+ # Remove the yj_param columns.
174170 components $ predictions <- components $ predictions %> %
175- select(- any_of(starts_with(" .lambda_ " ))) %> %
171+ select(- any_of(starts_with(" .yj_param_ " ))) %> %
176172 ungroup()
177173 components
178174}
179175
#' @export
print.layer_epi_YeoJohnson <- function(x, width = max(20, options()$width - 30), ...) {
  # Header shown in front of the selected terms when the layer is printed.
  title <- "Yeo-Johnson transformation (see `yj_params` object for values) on "
  print_layer(x$terms, title = title, width = width)
}
185181
186182# Inverse Yeo-Johnson transformation
187183#
188- # Inverse of `yj_transform` in step_yeo_johnson.R. Note that this function is
189- # vectorized in x, but not in lambda.
184+ # Inverse of `yj_transform` in step_yeo_johnson.R.
190185yj_inverse <- function (x , lambda , eps = 0.001 ) {
191- if (is.na(lambda )) {
186+ if (any( is.na(lambda ) )) {
192187 return (x )
193188 }
189+ if (length(x ) > 1 && length(lambda ) == 1 ) {
190+ lambda <- rep(lambda , length(x ))
191+ } else if (length(x ) != length(lambda )) {
192+ cli :: cli_abort(" Length of `x` must be equal to length of `lambda`." , call = rlang :: caller_fn())
193+ }
194194 if (! inherits(x , " tbl_df" ) || is.data.frame(x )) {
195195 x <- unlist(x , use.names = FALSE )
196196 } else {
@@ -199,52 +199,58 @@ yj_inverse <- function(x, lambda, eps = 0.001) {
199199 }
200200 }
201201
202- dat_neg <- x < 0
203- ind_neg <- list (is = which(dat_neg ), not = which(! dat_neg ))
204- not_neg <- ind_neg [[" not" ]]
205- is_neg <- ind_neg [[" is" ]]
206-
# Local helper of `yj_inverse`: inverse Yeo-Johnson transform for the
# non-negative part of the data, vectorised over both `x` and `lambda`.
# `eps` is read from the enclosing `yj_inverse()` call. `x` and `lambda` must
# have the same length; `yj_inverse()` returns early when any `lambda` is NA,
# so no NA reaches this helper.
nn_inv_trans <- function(x, lambda) {
  out <- double(length(x))
  # Lambdas within `eps` of zero use the exponential branch.
  sm_lambdas <- abs(lambda) < eps
  if (any(sm_lambdas)) {
    out[sm_lambdas] <- exp(x[sm_lambdas]) - 1
  }
  # Everything else uses the power branch.
  if (!all(sm_lambdas)) {
    x <- x[!sm_lambdas]
    lambda <- lambda[!sm_lambdas]
    out[!sm_lambdas] <- (lambda * x + 1)^(1 / lambda) - 1
  }
  out
}
216- }
217215
# Local helper of `yj_inverse`: inverse Yeo-Johnson transform for the
# negative part of the data, vectorised over both `x` and `lambda`.
# `eps` is read from the enclosing `yj_inverse()` call. `x` and `lambda` must
# have the same length; `yj_inverse()` returns early when any `lambda` is NA,
# so no NA reaches this helper.
ng_inv_trans <- function(x, lambda) {
  out <- double(length(x))
  # Lambdas within `eps` of two use the exponential branch.
  near2_lambdas <- abs(lambda - 2) < eps
  if (any(near2_lambdas)) {
    out[near2_lambdas] <- -(exp(-x[near2_lambdas]) - 1)
  }
  # Everything else uses the power branch.
  if (!all(near2_lambdas)) {
    x <- x[!near2_lambdas]
    lambda <- lambda[!near2_lambdas]
    out[!near2_lambdas] <- -(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
  }
  out
}
227229
230+ dat_neg <- x < 0
231+ not_neg <- which(! dat_neg )
232+ is_neg <- which(dat_neg )
233+
228234 if (length(not_neg ) > 0 ) {
229- x [not_neg ] <- nn_inv_trans(x [not_neg ], lambda )
235+ x [not_neg ] <- nn_inv_trans(x [not_neg ], lambda [ not_neg ] )
230236 }
231237
232238 if (length(is_neg ) > 0 ) {
233- x [is_neg ] <- ng_inv_trans(x [is_neg ], lambda )
239+ x [is_neg ] <- ng_inv_trans(x [is_neg ], lambda [ is_neg ] )
234240 }
235241 x
236242}
237243
# Look up the Yeo-Johnson parameters stored on the recipe of a workflow.
#
# @param workflow Object from which `hardhat::extract_recipe()` can pull the
#   recipe.
# @return The `yj_params` element of the first recipe step that inherits from
#   "step_epi_YeoJohnson".
# Aborts when the recipe has no epi_YeoJohnson step.
get_yj_params_in_layer <- function(workflow) {
  this_recipe <- hardhat::extract_recipe(workflow)
  if (!(this_recipe %>% recipes::detect_step("epi_YeoJohnson"))) {
    cli::cli_abort(
      "`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.",
      call = rlang::caller_env()
    )
  }
  # Take the parameters from the first matching step only.
  for (step in this_recipe$steps) {
    if (inherits(step, "step_epi_YeoJohnson")) {
      yj_params <- step$yj_params
      break
    }
  }
  yj_params
}
0 commit comments