23 changes: 23 additions & 0 deletions R/pkg/R/DataFrame.R
@@ -1437,6 +1437,29 @@ dapplyInternal <- function(x, func, schema) {
schema <- structType(schema)
}

arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
if (arrowEnabled) {
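# This is a hack to avoid the CRAN check. Arrow is not on CRAN yet. See ARROW-3204.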
requireNamespace1 <- requireNamespace
if (!requireNamespace1("arrow", quietly = TRUE)) {
stop("'arrow' package should be installed.")
}
# Currently, Arrow optimization does not support the raw type.
# Also, it does not support an explicit float type set by users.
if (inherits(schema, "structType")) {
if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "FloatType"))) {
stop("Arrow optimization with dapply does not support FloatType yet.")
}
if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "BinaryType"))) {
stop("Arrow optimization with dapply does not support BinaryType yet.")
}
} else if (is.null(schema)) {
stop(paste0("Arrow optimization does not support 'dapplyCollect' yet. Please disable ",
"Arrow optimization or use 'collect' and 'dapply' APIs instead."))
} else {
stop("'schema' should be a DDL-formatted string or a structType.")
}
}

packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)

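For context, a minimal usage sketch of the path this check guards (not part of the diff), assuming a running SparkR session with spark.sql.execution.arrow.enabled set to "true" and the 'arrow' package installed:

df <- createDataFrame(mtcars)
# Takes the Arrow-optimized dapply path when the conf is enabled.
collected <- collect(dapply(df, function(rdf) rdf, schema(df)))
# A FloatType or BinaryType field in the schema, or calling dapplyCollect(), would
# hit the stop() branches added above instead.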
16 changes: 10 additions & 6 deletions R/pkg/R/deserialize.R
@@ -247,17 +247,21 @@ readDeserializeInArrow <- function(inputCon) {
batches <- RecordBatchStreamReader(arrowData)$batches()

# Read all grouped batches. Tibble -> data.frame conversion is cheap.
data <- lapply(batches, function(batch) as.data.frame(as_tibble(batch)))

# Read keys to map to each grouped batch.
keys <- readMultipleObjects(inputCon)

list(keys = keys, data = data)
lapply(batches, function(batch) as.data.frame(as_tibble(batch)))
} else {
stop("'arrow' package should be installed.")
}
}

readDeserializeWithKeysInArrow <- function(inputCon) {
data <- readDeserializeInArrow(inputCon)

keys <- readMultipleObjects(inputCon)

# Read keys to map to each grouped batch later.
list(keys = keys, data = data)
}
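As an illustration (not part of the diff), the shape of the list returned by readDeserializeWithKeysInArrow(), with made-up values standing in for real keys and Arrow batches:

dataWithKeys <- list(
  keys = list(1L, 2L),                                   # one key per grouped batch
  data = list(data.frame(x = 1:2), data.frame(x = 3:4))  # one data.frame per batch
)
# worker.R pairs keys[[i]] with data[[i]] in the grouped (mode == 2) path.
stopifnot(length(dataWithKeys$keys) == length(dataWithKeys$data))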

readRowList <- function(obj) {
# readRowList is meant for use inside an lapply. As a result, it is
# necessary to open a standalone connection for the row and consume
15 changes: 15 additions & 0 deletions R/pkg/R/serialize.R
@@ -220,3 +220,18 @@ writeArgs <- function(con, args) {
}
}
}

writeSerializeInArrow <- function(conn, df) {
# This is a hack to avoid the CRAN check. Arrow is not on CRAN yet. See ARROW-3204.
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)

# There seems to be no way to send each batch in streaming format over a socket
# connection. See ARROW-4512.
# So, the whole Arrow streaming-formatted binary is written at once for now.
writeRaw(conn, write_arrow(df, raw()))
} else {
stop("'arrow' package should be installed.")
}
}
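As a rough sketch (not part of the diff, with illustrative names only), the length-prefixed framing that writeRaw() is assumed to apply to the Arrow payload can be mimicked in base R:

writeFramedRaw <- function(con, bytes) {
  writeBin(length(bytes), con, endian = "big")  # 4-byte big-endian length prefix
  writeBin(bytes, con)                          # Arrow streaming-format payload follows
}
con <- rawConnection(raw(0), "r+")   # in-memory stand-in for the worker socket
writeFramedRaw(con, as.raw(1:16))    # any raw payload; the real code passes write_arrow() output
close(con)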
33 changes: 15 additions & 18 deletions R/pkg/inst/worker/worker.R
@@ -76,6 +76,8 @@ outputResult <- function(serializer, output, outputCon) {
SparkR:::writeRawSerialize(outputCon, output)
} else if (serializer == "row") {
SparkR:::writeRowSerialize(outputCon, output)
} else if (serializer == "arrow") {
SparkR:::writeSerializeInArrow(outputCon, output)
} else {
# write lines one-by-one with flag
lapply(output, function(line) SparkR:::writeString(outputCon, line))
@@ -172,9 +174,15 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readMultipleObjects(inputCon)
} else if (deserializer == "arrow" && mode == 2) {
dataWithKeys <- SparkR:::readDeserializeInArrow(inputCon)
dataWithKeys <- SparkR:::readDeserializeWithKeysInArrow(inputCon)
keys <- dataWithKeys$keys
data <- dataWithKeys$data
} else if (deserializer == "arrow" && mode == 1) {
data <- SparkR:::readDeserializeInArrow(inputCon)
# See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
# rbind.fill might be an alternative to make this faster if plyr is installed.
# Also note that 'dapply' applies the function to each partition.
data <- do.call("rbind", data)
}

# Timing reading input data for execution
@@ -192,7 +200,7 @@ if (isEmpty != 0) {
output <- compute(mode, partition, serializer, deserializer, keys[[i]],
colNames, computeFunc, data[[i]])
computeElap <- elapsedSecs()
if (deserializer == "arrow") {
if (serializer == "arrow") {
outputs[[length(outputs) + 1L]] <- output
} else {
outputResult(serializer, output, outputCon)
@@ -202,22 +210,11 @@
outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap)
}

if (deserializer == "arrow") {
# This is a hack to avoid the CRAN check. Arrow is not on CRAN yet. See ARROW-3204.
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE)
# See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
# rbind.fill might be an alternative to make it faster if plyr is installed.
combined <- do.call("rbind", outputs)

# Likewise, there seems to be no way to send each batch in streaming format over a socket
# connection. See ARROW-4512.
# So, it writes the whole Arrow streaming-formatted binary at once for now.
SparkR:::writeRaw(outputCon, write_arrow(combined, raw()))
} else {
stop("'arrow' package should be installed.")
}
if (serializer == "arrow") {
# See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html
# rbind.fill might be an alternative to make it faster if plyr is installed.
combined <- do.call("rbind", outputs)
SparkR:::writeSerializeInArrow(outputCon, combined)
}
}
} else {
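A plain-R illustration (not part of the diff, made-up values) of the combine step in the serializer == "arrow" branch above:

outputs <- list(data.frame(k = "a", n = 2L), data.frame(k = "b", n = 3L))
combined <- do.call("rbind", outputs)  # plyr::rbind.fill could be a faster alternative
nrow(combined)                         # 2 rows, written out in a single writeSerializeInArrow() call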
99 changes: 99 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -3300,6 +3300,105 @@ test_that("dapplyCollect() on DataFrame with a binary column", {

})

test_that("dapply() Arrow optimization", {
skip_if_not_installed("arrow")
df <- createDataFrame(mtcars)

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
tryCatch({
ret <- dapply(df,
function(rdf) {
stopifnot(class(rdf) == "data.frame")
rdf
},
schema(df))
expected <- collect(ret)
},
finally = {
# Resetting the conf back to its original value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
ret <- dapply(df,
function(rdf) {
stopifnot(class(rdf) == "data.frame")
# mtcars' hp is always greater than 50.
stopifnot(all(rdf$hp > 50))
rdf
},
schema(df))
actual <- collect(ret)
expect_equal(actual, expected)
expect_equal(count(ret), nrow(mtcars))
},
finally = {
# Resetting the conf back to its original value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("dapply() Arrow optimization - type specification", {
skip_if_not_installed("arrow")
# Note that regular dapply() does not seem to support dates and timestamps,
# whereas Arrow-optimized dapply() does.
rdf <- data.frame(list(list(a = 1,
b = "a",
c = TRUE,
d = 1.1,
e = 1L)))
# numPartitions is intentionally set to 8 to test empty partitions as well.
df <- createDataFrame(rdf, numPartitions = 8)

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false")
tryCatch({
ret <- dapply(df, function(rdf) { rdf }, schema(df))
expected <- collect(ret)
},
finally = {
# Resetting the conf back to its original value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
ret <- dapply(df, function(rdf) { rdf }, schema(df))
actual <- collect(ret)
expect_equal(actual, expected)
},
finally = {
# Resetting the conf back to its original value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("dapply() Arrow optimization - type specification (date and timestamp)", {
skip_if_not_installed("arrow")
rdf <- data.frame(list(list(a = as.Date("1990-02-24"),
b = as.POSIXct("1990-02-24 12:34:56"))))
df <- createDataFrame(rdf)

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
ret <- dapply(df, function(rdf) { rdf }, schema(df))
expect_equal(collect(ret), rdf)
},
finally = {
# Resetting the conf back to its original value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("repartition by columns on DataFrame", {
# The tasks here launch R workers with shuffles. So, we decrease the number of shuffle
# partitions to reduce the number of the tasks to speed up the test. This is particularly
@@ -123,16 +123,25 @@ object MapPartitionsInR {
schema: StructType,
encoder: ExpressionEncoder[Row],
child: LogicalPlan): LogicalPlan = {
val deserialized = CatalystSerde.deserialize(child)(encoder)
val mapped = MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
CatalystSerde.generateObjAttr(RowEncoder(schema)),
deserialized)
CatalystSerde.serialize(mapped)(RowEncoder(schema))
if (SQLConf.get.arrowEnabled) {
MapPartitionsInRWithArrow(
func,
packageNames,
broadcastVars,
encoder.schema,
schema.toAttributes,
child)
} else {
val deserialized = CatalystSerde.deserialize(child)(encoder)
CatalystSerde.serialize(MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
CatalystSerde.generateObjAttr(RowEncoder(schema)),
deserialized))(RowEncoder(schema))
}
}
}

@@ -154,6 +163,28 @@ case class MapPartitionsInR(
outputObjAttr, child)
}

/**
* Similar to `MapPartitionsInR` but serializes and deserializes input/output in
* Arrow format.
*
* This is somewhat similar to `org.apache.spark.sql.execution.python.ArrowEvalPython`.
*/
case class MapPartitionsInRWithArrow(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
inputSchema: StructType,
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
// This operator always needs all columns of its child, even those it doesn't reference.
override def references: AttributeSet = child.outputSet

override protected def stringArgs: Iterator[Any] = Iterator(
inputSchema, StructType.fromAttributes(output), child)

override val producedAttributes = AttributeSet(output)
}

object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
@@ -599,6 +599,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.FlatMapGroupsInRWithArrow(f, p, b, is, ot, key, grouping, child) =>
execution.FlatMapGroupsInRWithArrowExec(
f, p, b, is, ot, key, grouping, planLater(child)) :: Nil
case logical.MapPartitionsInRWithArrow(f, p, b, is, ot, child) =>
execution.MapPartitionsInRWithArrowExec(
f, p, b, is, ot, planLater(child)) :: Nil
case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>