diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 0e99b171cabeb..23cafe14d5467 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -55,6 +55,7 @@ readTypedObject <- function(con, type) {
     "l" = readList(con),
     "e" = readEnv(con),
     "s" = readStruct(con),
+    "B" = readDouble(con),
     "n" = NULL,
     "j" = getJobj(readString(con)),
     stop(paste("Unsupported type for deserialization", type)))
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 3bbf60d9b668c..ac1e8d67d81ac 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -83,6 +83,7 @@ writeObject <- function(con, object, writeType = TRUE) {
          Date = writeDate(con, object),
          POSIXlt = writeTime(con, object),
          POSIXct = writeTime(con, object),
+         bigint = writeDouble(con, object),
          stop(paste("Unsupported type for serialization", type)))
 }
 
@@ -157,6 +158,7 @@ writeType <- function(con, class) {
                  Date = "D",
                  POSIXlt = "t",
                  POSIXct = "t",
+                 bigint = "B",
                  stop(paste("Unsupported type for serialization", class)))
   writeBin(charToRaw(type), con)
 }
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R
index b5f6f1b54fa85..f7348b96dfc0a 100644
--- a/R/pkg/inst/tests/testthat/test_Serde.R
+++ b/R/pkg/inst/tests/testthat/test_Serde.R
@@ -28,6 +28,10 @@ test_that("SerDe of primitive types", {
   expect_equal(x, 1)
   expect_equal(class(x), "numeric")
 
+  x <- callJStatic("SparkRHandler", "echo", 1380742793415240)
+  expect_equal(x, 1380742793415240)
+  expect_equal(class(x), "numeric")
+
   x <- callJStatic("SparkRHandler", "echo", TRUE)
   expect_true(x)
   expect_equal(class(x), "logical")
@@ -43,6 +47,11 @@ test_that("SerDe of list of primitive types", {
   expect_equal(x, y)
   expect_equal(class(y[[1]]), "integer")
 
+  x <- list(1380742793415240, 13807427934152401, 13807427934152402)
+  y <- callJStatic("SparkRHandler", "echo", x)
+  expect_equal(x, y)
+  expect_equal(class(y[[1]]), "numeric")
+
   x <- list(1, 2, 3)
   y <- callJStatic("SparkRHandler", "echo", x)
   expect_equal(x, y)
@@ -66,7 +75,8 @@ test_that("SerDe of list of primitive types", {
 
 test_that("SerDe of list of lists", {
   x <- list(list(1L, 2L, 3L), list(1, 2, 3),
-            list(TRUE, FALSE), list("a", "b", "c"))
+            list(TRUE, FALSE), list("a", "b", "c"),
+            list(1380742793415240, 1380742793415240))
   y <- callJStatic("SparkRHandler", "echo", x)
   expect_equal(x, y)
 
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 1a3d6df437d7e..9d030837e1c09 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -3188,6 +3188,23 @@ test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", {
   expect_equal(dbs[[1]], "default")
 })
 
+test_that("dapply with bigint type", {
+  df <- createDataFrame(
+    list(list(1380742793415240, 1, "1"), list(1380742793415240, 2, "2"),
+         list(1380742793415240, 3, "3")), c("a", "b", "c"))
+  schema <- structType(structField("a", "bigint"), structField("b", "bigint"),
+                       structField("c", "string"), structField("d", "bigint"))
+  df1 <- dapply(
+    df,
+    function(x) {
+      y <- x[x[1] > 1, ]
+      y <- cbind(y, y[1] + 1L)
+    },
+    schema)
+  result <- collect(df1)
+  expect_equal(result$a[1], 1380742793415240)
+})
+
 test_that("catalog APIs, listTables, listColumns, listFunctions", {
   tb <- listTables()
   count <- count(tables())
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index dad928cdcfd0f..f3fce575fbf76 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -84,6 +84,7 @@ private[spark] object SerDe {
       case 'l' => readList(dis, jvmObjectTracker)
       case 'D' => readDate(dis)
       case 't' => readTime(dis)
+      case 'B' => new java.lang.Double(readDouble(dis))
       case 'j' => jvmObjectTracker(JVMObjectId(readString(dis)))
       case _ =>
         if (sqlReadObject == null) {
@@ -198,6 +199,7 @@ private[spark] object SerDe {
       case 'b' => readBooleanArr(dis)
       case 'j' => readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x)))
       case 'r' => readBytesArr(dis)
+      case 'B' => readDoubleArr(dis)
       case 'a' =>
         val len = readInt(dis)
         (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray
@@ -278,6 +280,7 @@ private[spark] object SerDe {
       case "list" => dos.writeByte('l')
       case "map" => dos.writeByte('e')
       case "jobj" => dos.writeByte('j')
+      case "bigint" => dos.writeByte('B')
       case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala
index 085cc267ca74d..e752247062a1b 100644
--- a/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.api.r
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
 import org.apache.spark.SparkFunSuite
 
 class RBackendSuite extends SparkFunSuite {
@@ -28,4 +30,21 @@ class RBackendSuite extends SparkFunSuite {
     assert(tracker.get(id) === None)
     assert(tracker.size === 0)
   }
+
+  test("read and write bigint in the buffer") {
+    val bos = new ByteArrayOutputStream()
+    val dos = new DataOutputStream(bos)
+    val tracker = new JVMObjectTracker
+    SerDe.writeObject(dos, 1380742793415240L.asInstanceOf[Object],
+      tracker)
+    val buf = bos.toByteArray
+    val bis = new ByteArrayInputStream(buf)
+    val dis = new DataInputStream(bis)
+    val data = SerDe.readObject(dis, tracker)
+    assert(data.asInstanceOf[Double] === 1380742793415240L)
+    bos.close()
+    bis.close()
+    dos.close()
+    dis.close()
+  }
 }
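
Note on the encoding choice, which the patch leaves implicit: R has no native 64-bit integer type, so the new 'B' tag ships bigint values over the wire in the existing double format and surfaces them in R as numeric. That round trip is exact only while the value fits in a double's 53-bit significand; the test value 1380742793415240 (about 1.4e15) sits well below 2^53 (about 9.0e15). Below is a minimal standalone Scala sketch of that boundary; it is not part of the patch, and the object name is made up for illustration.

    // Hypothetical standalone check, not part of the patch: shows why a Long
    // encoded as a Double is lossless for the values used in the tests above,
    // and where that stops being true.
    object BigintDoubleRoundTrip {
      def main(args: Array[String]): Unit = {
        // Value used throughout the new tests; below 2^53, so every such Long
        // maps to a distinct Double and the round trip is exact.
        val testValue = 1380742793415240L
        assert(testValue.toDouble.toLong == testValue)

        // At 2^53 + 1 the significand runs out of bits: the Double rounds to
        // 2^53 and the round trip silently loses the low bit.
        val boundary = (1L << 53) + 1
        assert(boundary.toDouble.toLong != boundary)
      }
    }

Values beyond that range would need a different representation on the R side (for example strings), which is outside the scope of this change.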