Skip to content

Commit ee6301a

Browse files
clarkfitzg and shivaram
authored and committed
[SPARK-16785] R dapply doesn't return array or raw columns
Fixed bug in `dapplyCollect` by changing the `compute` function of `worker.R` to explicitly handle raw (binary) vectors (cc shivaram). Tested with unit tests. Author: Clark Fitzgerald <[email protected]>. Closes #14783 from clarkfitzg/SPARK-16785. (cherry picked from commit 9fccde4) Signed-off-by: Shivaram Venkataraman <[email protected]>
1 parent 796577b commit ee6301a

File tree

5 files changed

+72
-1
lines changed

5 files changed

+72
-1
lines changed

R/pkg/R/SQLContext.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,10 @@ getDefaultSqlSource <- function() {
202202
# TODO(davies): support sampling and infer type from NA
203203
createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
204204
sparkSession <- getSparkSession()
205+
205206
if (is.data.frame(data)) {
207+
# Convert data into a list of rows. Each row is a list.
208+
206209
# get the names of columns, they will be put into RDD
207210
if (is.null(schema)) {
208211
schema <- names(data)
@@ -227,6 +230,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
227230
args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE)
228231
data <- do.call(mapply, append(args, data))
229232
}
233+
230234
if (is.list(data)) {
231235
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
232236
rdd <- parallelize(sc, data)

R/pkg/R/utils.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,18 @@ isMasterLocal <- function(master) {
697697
isSparkRShell <- function() {
698698
grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE)
699699
}
700+
701+
# rbind a list of rows with raw (binary) columns
702+
#
703+
# @param inputData a list of rows, with each row a list
704+
# @return data.frame with raw columns as lists
705+
rbindRaws <- function(inputData){
706+
row1 <- inputData[[1]]
707+
rawcolumns <- ("raw" == sapply(row1, class))
708+
709+
listmatrix <- do.call(rbind, inputData)
710+
# A dataframe with all list columns
711+
out <- as.data.frame(listmatrix)
712+
out[!rawcolumns] <- lapply(out[!rawcolumns], unlist)
713+
out
714+
}

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,27 @@ test_that("dapply() and dapplyCollect() on a DataFrame", {
22622262
expect_identical(expected, result)
22632263
})
22642264

2265+
test_that("dapplyCollect() on DataFrame with a binary column", {
2266+
2267+
df <- data.frame(key = 1:3)
2268+
df$bytes <- lapply(df$key, serialize, connection = NULL)
2269+
2270+
df_spark <- createDataFrame(df)
2271+
2272+
result1 <- collect(df_spark)
2273+
expect_identical(df, result1)
2274+
2275+
result2 <- dapplyCollect(df_spark, function(x) x)
2276+
expect_identical(df, result2)
2277+
2278+
# A data.frame with a single column of bytes
2279+
scb <- subset(df, select = "bytes")
2280+
scb_spark <- createDataFrame(scb)
2281+
result <- dapplyCollect(scb_spark, function(x) x)
2282+
expect_identical(scb, result)
2283+
2284+
})
2285+
22652286
test_that("repartition by columns on DataFrame", {
22662287
df <- createDataFrame(
22672288
list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)),

R/pkg/inst/tests/testthat/test_utils.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,27 @@ test_that("overrideEnvs", {
182182
expect_equal(config[["param_only"]], "blah")
183183
expect_equal(config[["config_only"]], "ok")
184184
})
185+
186+
test_that("rbindRaws", {
187+
188+
# Mixed Column types
189+
r <- serialize(1:5, connection = NULL)
190+
r1 <- serialize(1, connection = NULL)
191+
r2 <- serialize(letters, connection = NULL)
192+
r3 <- serialize(1:10, connection = NULL)
193+
inputData <- list(list(1L, r1, "a", r), list(2L, r2, "b", r),
194+
list(3L, r3, "c", r))
195+
expected <- data.frame(V1 = 1:3)
196+
expected$V2 <- list(r1, r2, r3)
197+
expected$V3 <- c("a", "b", "c")
198+
expected$V4 <- list(r, r, r)
199+
result <- rbindRaws(inputData)
200+
expect_equal(expected, result)
201+
202+
# Single binary column
203+
input <- list(list(r1), list(r2), list(r3))
204+
expected <- subset(expected, select = "V2")
205+
result <- setNames(rbindRaws(input), "V2")
206+
expect_equal(expected, result)
207+
208+
})

R/pkg/inst/worker/worker.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ compute <- function(mode, partition, serializer, deserializer, key,
3636
# available since R 3.2.4. So we set the global option here.
3737
oldOpt <- getOption("stringsAsFactors")
3838
options(stringsAsFactors = FALSE)
39-
inputData <- do.call(rbind.data.frame, inputData)
39+
40+
# Handle binary data types
41+
if ("raw" %in% sapply(inputData[[1]], class)) {
42+
inputData <- SparkR:::rbindRaws(inputData)
43+
} else {
44+
inputData <- do.call(rbind.data.frame, inputData)
45+
}
46+
4047
options(stringsAsFactors = oldOpt)
4148

4249
names(inputData) <- colNames

0 commit comments

Comments
 (0)