Skip to content

Commit b4b0a61

Browse files
author
Sun Rui
committed
[SPARK-11774][SPARKR] Implement struct(), encode(), decode() functions in SparkR.
1 parent 3d28081 commit b4b0a61

File tree

4 files changed

+94
-1
lines changed

4 files changed

+94
-1
lines changed

R/pkg/NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,10 @@ exportMethods("%in%",
134134
"datediff",
135135
"dayofmonth",
136136
"dayofyear",
137+
"decode",
137138
"dense_rank",
138139
"desc",
140+
"encode",
139141
"endsWith",
140142
"exp",
141143
"explode",
@@ -225,6 +227,7 @@ exportMethods("%in%",
225227
"stddev",
226228
"stddev_pop",
227229
"stddev_samp",
230+
"struct",
228231
"sqrt",
229232
"startsWith",
230233
"substr",

R/pkg/R/functions.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,40 @@ setMethod("dayofyear",
357357
column(jc)
358358
})
359359

360+
#' decode
#'
#' Decodes a binary column into a string column using the provided character set
#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
#'
#' @param x a Column holding the binary values to decode.
#' @param charset a character string naming the character set to decode with.
#' @return a Column built from the result of the JVM-side \code{decode} function.
#' @rdname decode
#' @name decode
#' @family string_funcs
#' @export
#' @examples \dontrun{decode(df$c, "UTF-8")}
setMethod("decode",
          signature(x = "Column", charset = "character"),
          function(x, charset) {
            # Hand the underlying Java column reference and the charset to the
            # static Scala function, then wrap what comes back.
            column(callJStatic("org.apache.spark.sql.functions", "decode", x@jc, charset))
          })
376+
377+
#' encode
#'
#' Encodes a string column into a binary column using the provided character set
#' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
#'
#' @param x a Column holding the string values to encode.
#' @param charset a character string naming the character set to encode with.
#' @return a Column built from the result of the JVM-side \code{encode} function.
#' @rdname encode
#' @name encode
#' @family string_funcs
#' @export
#' @examples \dontrun{encode(df$c, "UTF-8")}
setMethod("encode",
          signature(x = "Column", charset = "character"),
          function(x, charset) {
            # Hand the underlying Java column reference and the charset to the
            # static Scala function, then wrap what comes back.
            column(callJStatic("org.apache.spark.sql.functions", "encode", x@jc, charset))
          })
393+
360394
#' exp
361395
#'
362396
#' Computes the exponential of the given value.
@@ -1039,6 +1073,31 @@ setMethod("stddev_samp",
10391073
column(jc)
10401074
})
10411075

1076+
#' struct
#'
#' Creates a new struct column that composes multiple input columns.
#'
#' @param x the first field of the struct: either a Column or a character
#'          column name.
#' @param ... the remaining fields, given as Columns when \code{x} is a Column
#'            or as column names when \code{x} is a character string.
#' @return a Column built from the result of the JVM-side \code{struct} function.
#' @rdname struct
#' @name struct
#' @family normal_funcs
#' @export
#' @examples
#' \dontrun{
#' struct(df$c, df$d)
#' struct("col1", "col2")
#' }
setMethod("struct",
          signature(x = "characterOrColumn"),
          function(x, ...) {
            # inherits() instead of class(x) == "Column": the test always yields
            # a single TRUE/FALSE and classes extending Column are accepted too.
            if (inherits(x, "Column")) {
              # Column form: pass the Java references of every column as one list.
              jcols <- lapply(list(x, ...), function(col) { col@jc })
              jc <- callJStatic("org.apache.spark.sql.functions", "struct", jcols)
            } else {
              # Name form: first name on its own, the remaining names as a list.
              jc <- callJStatic("org.apache.spark.sql.functions", "struct", x, list(...))
            }
            column(jc)
          })
1100+
10421101
#' sqrt
10431102
#'
10441103
#' Computes the square root of the specified float value.

R/pkg/R/generics.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,10 +744,18 @@ setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") })
744744
#' @export
745745
setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
746746

747+
#' @rdname decode
748+
#' @export
749+
# Declares the S4 generic decode(x, charset); the body only hands off to
# standardGeneric() so that dispatch picks the matching method.
setGeneric("decode", function(x, charset) { standardGeneric("decode") })
750+
747751
#' @rdname dense_rank
748752
#' @export
749753
setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") })
750754

755+
#' @rdname encode
756+
#' @export
757+
# Declares the S4 generic encode(x, charset); the body only hands off to
# standardGeneric() so that dispatch picks the matching method.
setGeneric("encode", function(x, charset) { standardGeneric("encode") })
758+
751759
#' @rdname explode
752760
#' @export
753761
setGeneric("explode", function(x) { standardGeneric("explode") })
@@ -1001,6 +1009,10 @@ setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") })
10011009
#' @export
10021010
setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") })
10031011

1012+
#' @rdname struct
1013+
#' @export
1014+
# Declares the S4 generic struct(x, ...); extra arguments travel through `...`
# and the body only hands off to standardGeneric() for method dispatch.
setGeneric("struct", function(x, ...) { standardGeneric("struct") })
1015+
10041016
#' @rdname substring_index
10051017
#' @export
10061018
setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") })

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -928,8 +928,27 @@ test_that("column functions", {
928928

929929
# Test that stats::lag is working
930930
expect_equal(length(lag(ldeaths, 12)), 72)
931+
932+
# Test struct()
933+
df <- createDataFrame(sqlContext, list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c"))
934+
result <- collect(select(df, struct("a", "c")))
935+
expected <- data.frame(row.names = 1:2)
936+
expected$"struct(a,c)" <- list(listToStruct(list(a = 1L, c = 3L)), listToStruct(list(a = 4L, c = 6L)))
937+
expect_equal(result, expected)
938+
939+
result <- collect(select(df, struct(df$a, df$b)))
940+
expected <- data.frame(row.names = 1:2)
941+
expected$"struct(a,b)" <- list(listToStruct(list(a = 1L, b = 2L)), listToStruct(list(a = 4L, b = 5L)))
942+
expect_equal(result, expected)
943+
944+
# Test encode(), decode()
945+
bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c))
946+
df <- createDataFrame(sqlContext, list(list("大千世界", "utf-8", bytes)), schema = c("a", "b", "c"))
947+
result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8")))
948+
expect_equal(result[[1]][[1]], bytes)
949+
expect_equal(result[[2]], "大千世界")
931950
})
932-
#
951+
933952
test_that("column binary mathfunctions", {
934953
lines <- c("{\"a\":1, \"b\":5}",
935954
"{\"a\":2, \"b\":6}",

0 commit comments

Comments
 (0)