Skip to content

Commit 35c0a60

Browse files
NarineKshivaram
authored and committed
[SPARK-12922][SPARKR][WIP] Implement gapply() on DataFrame in SparkR
## What changes were proposed in this pull request? gapply() applies an R function on groups grouped by one or more columns of a DataFrame, and returns a DataFrame. It is like GroupedDataSet.flatMapGroups() in the Dataset API. Please, let me know what do you think and if you have any ideas to improve it. Thank you! ## How was this patch tested? Unit tests. 1. Primitive test with different column types 2. Add a boolean column 3. Compute average by a group Author: Narine Kokhlikyan <[email protected]> Author: NarineK <[email protected]> Closes #12836 from NarineK/gapply2. (cherry picked from commit 7c6c692) Signed-off-by: Shivaram Venkataraman <[email protected]>
1 parent f0279b0 commit 35c0a60

File tree

14 files changed

+540
-65
lines changed

14 files changed

+540
-65
lines changed

R/pkg/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ exportMethods("arrange",
6262
"filter",
6363
"first",
6464
"freqItems",
65+
"gapply",
6566
"group_by",
6667
"groupBy",
6768
"head",

R/pkg/R/DataFrame.R

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1181,7 +1181,7 @@ dapplyInternal <- function(x, func, schema) {
11811181
#' func should have only one parameter, to which a data.frame corresponds
11821182
#' to each partition will be passed.
11831183
#' The output of func should be a data.frame.
1184-
#' @param schema The schema of the resulting DataFrame after the function is applied.
1184+
#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
11851185
#' It must match the output of func.
11861186
#' @family SparkDataFrame functions
11871187
#' @rdname dapply
@@ -1267,6 +1267,86 @@ setMethod("dapplyCollect",
12671267
ldf
12681268
})
12691269

1270+
#' gapply
#'
#' Group the SparkDataFrame using the specified columns and apply the R function to each
#' group.
#'
#' @param x A SparkDataFrame
#' @param cols Grouping columns
#' @param func A function to be applied to each group partition specified by grouping
#'             column of the SparkDataFrame. The function `func` takes as argument
#'             a key - grouping columns and a data frame - a local R data.frame.
#'             The output of `func` is a local R data.frame.
#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
#'               The schema must match the output of `func`. It has to be defined for each
#'               output column with preferred output column name and corresponding data type.
#' @family SparkDataFrame functions
#' @rdname gapply
#' @name gapply
#' @export
#' @examples
#'
#' \dontrun{
#' # Computes the arithmetic mean of the second column by grouping
#' # on the first and third columns. Output the grouping values and the average.
#'
#' df <- createDataFrame(
#'   list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
#'   c("a", "b", "c", "d"))
#'
#' # Here our output contains three columns, the key which is a combination of two
#' # columns with data types integer and string and the mean which is a double.
#' schema <- structType(structField("a", "integer"), structField("c", "string"),
#'                      structField("avg", "double"))
#' df1 <- gapply(
#'   df,
#'   list("a", "c"),
#'   function(key, x) {
#'     y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
#'   },
#'   schema)
#' collect(df1)
#'
#' # Result
#' # ------
#' # a c avg
#' # 3 3 3.0
#' # 1 1 1.5
#'
#' # Fits linear models on iris dataset by grouping on the 'Species' column and
#' # using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length'
#' # and 'Petal_Width' as training features.
#'
#' df <- createDataFrame(iris)
#' schema <- structType(structField("(Intercept)", "double"),
#'                      structField("Sepal_Width", "double"),
#'                      structField("Petal_Length", "double"),
#'                      structField("Petal_Width", "double"))
#' df1 <- gapply(
#'   df,
#'   list(df$"Species"),
#'   function(key, x) {
#'     m <- suppressWarnings(lm(Sepal_Length ~
#'                              Sepal_Width + Petal_Length + Petal_Width, x))
#'     data.frame(t(coef(m)))
#'   }, schema)
#' collect(df1)
#'
#' # Result
#' # ---------
#' # Model (Intercept) Sepal_Width Petal_Length Petal_Width
#' # 1     0.699883    0.3303370   0.9455356   -0.1697527
#' # 2     1.895540    0.3868576   0.9083370   -0.6792238
#' # 3     2.351890    0.6548350   0.2375602    0.2521257
#' }
setMethod("gapply",
          signature(x = "SparkDataFrame"),
          function(x, cols, func, schema) {
            # Group by the requested columns first, then delegate to the
            # GroupedData method, which ships `func` and `schema` to the JVM.
            grouped <- do.call("groupBy", c(x, cols))
            gapply(grouped, func, schema)
          })
1349+
12701350
############################## RDD Map Functions ##################################
12711351
# All of the following functions mirror the existing RDD map functions, #
12721352
# but allow for use with DataFrames by first converting to an RRDD before calling #

R/pkg/R/deserialize.R

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,36 @@ readMultipleObjects <- function(inputCon) {
197197
data # this is a list of named lists now
198198
}
199199

200+
# Reads multiple continuous objects from a DataOutputStream for use by
# gapply(). There is no preceding count field, so the stream is consumed in a
# loop until end-of-stream. Each group of rows is terminated by a record of
# type "r" carrying the grouping key for that group; the next group follows
# immediately after. Returns list(keys = ..., data = ...), where data[[i]]
# holds the rows belonging to keys[[i]].
readMultipleObjectsWithKeys <- function(inputCon) {
  keys <- list()
  data <- list()
  subData <- list()
  repeat {
    # readType() returns "" once the end of the stream is reached.
    type <- readType(inputCon)
    if (type == "") {
      break
    }
    if (type == "r") {
      # Group boundary: the next object on the stream is the grouping key
      # for the rows accumulated so far.
      keyType <- readType(inputCon)
      idx <- length(data) + 1L
      data[[idx]] <- subData
      keys[[idx]] <- readTypedObject(inputCon, keyType)
      subData <- list()
    } else {
      # An ordinary row belonging to the current group.
      subData[[length(subData) + 1L]] <- readTypedObject(inputCon, type)
    }
  }
  # A list of keys and the corresponding per-group row data.
  list(keys = keys, data = data)
}
229+
200230
readRowList <- function(obj) {
201231
# readRowList is meant for use inside an lapply. As a result, it is
202232
# necessary to open a standalone connection for the row and consume

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,10 @@ setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
454454
#' @export
455455
setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
456456

457+
#' @rdname gapply
#' @export
setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
460+
457461
#' @rdname summary
458462
#' @export
459463
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })

R/pkg/R/group.R

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,65 @@ createMethods <- function() {
142142
}
143143

144144
createMethods()
145+
146+
#' gapply
#'
#' Applies an R function to each group in the input GroupedData
#'
#' @param x a GroupedData
#' @param func A function to be applied to each group partition specified by GroupedData.
#'             The function `func` takes as argument a key - grouping columns and
#'             a data frame - a local R data.frame.
#'             The output of `func` is a local R data.frame.
#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
#'               The schema must match the output of `func`. It has to be defined for each
#'               output column with preferred output column name and corresponding data type.
#' @return a SparkDataFrame
#' @rdname gapply
#' @name gapply
#' @examples
#' \dontrun{
#' # Computes the arithmetic mean of the second column by grouping
#' # on the first and third columns. Output the grouping values and the average.
#'
#' df <- createDataFrame(
#'   list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
#'   c("a", "b", "c", "d"))
#'
#' # Here our output contains three columns, the key which is a combination of two
#' # columns with data types integer and string and the mean which is a double.
#' schema <- structType(structField("a", "integer"), structField("c", "string"),
#'                      structField("avg", "double"))
#' df1 <- gapply(
#'   df,
#'   list("a", "c"),
#'   function(key, x) {
#'     y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
#'   },
#'   schema)
#' collect(df1)
#'
#' # Result
#' # ------
#' # a c avg
#' # 3 3 3.0
#' # 1 1 1.5
#' }
setMethod("gapply",
          signature(x = "GroupedData"),
          function(x, func, schema) {
            # The schema is mandatory: Spark needs it to build the result
            # SparkDataFrame. NOTE(review): the original wrapped this check in
            # try(), which caught the error and let execution continue with a
            # NULL schema; fail fast instead.
            if (is.null(schema)) {
              stop("schema cannot be NULL")
            }
            # Serialize the packages loaded on the driver so workers can load
            # them before running `func`.
            packageNamesArr <- serialize(.sparkREnv[[".packages"]],
                                         connection = NULL)
            # Collect the broadcast variables registered in this session.
            broadcastArr <- lapply(ls(.broadcastNames),
                                   function(name) { get(name, .broadcastNames) })
            # Hand the grouped data, the cleaned closure, the package list,
            # the broadcasts and the expected output schema to the JVM side,
            # which invokes `func` once per group.
            sdf <- callJStatic(
              "org.apache.spark.sql.api.r.SQLUtils",
              "gapply",
              x@sgd,
              serialize(cleanClosure(func), connection = NULL),
              packageNamesArr,
              broadcastArr,
              schema$jobj)
            dataFrame(sdf)
          })

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,6 +2146,71 @@ test_that("repartition by columns on DataFrame", {
21462146
expect_equal(nrow(df1), 2)
21472147
})
21482148

2149+
test_that("gapply() on a DataFrame", {
  df <- createDataFrame(
    list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
    c("a", "b", "c", "d"))

  # Identity function per group should reproduce the original rows.
  expected <- collect(df)
  df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df))
  actual <- collect(df1)
  expect_identical(actual, expected)

  # Computes the sum of second column by grouping on the first and third columns
  # and checks if the sum is larger than 2
  schema <- structType(structField("a", "integer"), structField("e", "boolean"))
  df2 <- gapply(
    df,
    list(df$"a", df$"c"),
    function(key, x) {
      y <- data.frame(key[1], sum(x$b) > 2)
    },
    schema)
  actual <- collect(df2)$e
  expected <- c(TRUE, TRUE)
  expect_identical(actual, expected)

  # Computes the arithmetic mean of the second column by grouping
  # on the first and third columns. Output the grouping value and the average.
  schema <- structType(structField("a", "integer"), structField("c", "string"),
                       structField("avg", "double"))
  df3 <- gapply(
    df,
    list("a", "c"),
    function(key, x) {
      y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
    },
    schema)
  actual <- collect(df3)
  actual <- actual[order(actual$a), ]
  rownames(actual) <- NULL
  # Cross-check against a local aggregate() of the same data.
  expected <- collect(select(df, "a", "b", "c"))
  expected <- data.frame(aggregate(expected$b, by = list(expected$a, expected$c), FUN = mean))
  colnames(expected) <- c("a", "c", "avg")
  expected <- expected[order(expected$a), ]
  rownames(expected) <- NULL
  expect_identical(actual, expected)

  # Groups by `Sepal_Length` and computes the average for `Sepal_Width`.
  # `cols` is deliberately passed by name ahead of the data frame to exercise
  # named-argument dispatch.
  irisDF <- suppressWarnings(createDataFrame(iris))
  schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double"))
  df4 <- gapply(
    cols = list("Sepal_Length"),
    irisDF,
    function(key, x) {
      y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE)
    },
    schema)
  actual <- collect(df4)
  actual <- actual[order(actual$Sepal_Length), ]
  rownames(actual) <- NULL
  agg_local_df <- data.frame(aggregate(iris$Sepal.Width, by = list(iris$Sepal.Length), FUN = mean),
                             stringsAsFactors = FALSE)
  colnames(agg_local_df) <- c("Sepal_Length", "Avg")
  expected <- agg_local_df[order(agg_local_df$Sepal_Length), ]
  rownames(expected) <- NULL
  expect_identical(actual, expected)
})
2213+
21492214
test_that("Window functions on a DataFrame", {
21502215
setHiveContext(sc)
21512216
df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),

0 commit comments

Comments
 (0)