Commit 6f69025

viirya authored and davies committed
[SPARK-8840] [SPARKR] Add float coercion on SparkR
JIRA: https://issues.apache.org/jira/browse/SPARK-8840

Currently the type coercion rules don't include the float type. This PR simply adds it.

Author: Liang-Chi Hsieh <[email protected]>

Closes #7280 from viirya/add_r_float_coercion and squashes the following commits:

c86dc0e [Liang-Chi Hsieh] For comments.
dbf0c1b [Liang-Chi Hsieh] Implicitly convert Double to Float based on provided schema.
733015a [Liang-Chi Hsieh] Add test case for DataFrame with float type.
30c2a40 [Liang-Chi Hsieh] Update test case.
52b5294 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into add_r_float_coercion
6f9159d [Liang-Chi Hsieh] Add another test case.
8db3244 [Liang-Chi Hsieh] schema also needs to support float. add test case.
0dcc992 [Liang-Chi Hsieh] Add float coercion on SparkR.
1 parent: 20bb10f

File tree: 5 files changed, +44 −3 lines
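The user-visible effect is that "float" becomes a legal column type in SparkR schemas. A minimal sketch of the new behavior, assuming a SparkR session of this era where a sqlContext already exists (the data and names below are illustrative, not from the commit):

    # Hypothetical session: sqlContext is assumed to be available.
    schema <- structType(structField("name", "string"),
                         structField("height", "float"))
    localDF <- data.frame(name = c("Alice", "Bob"), height = c(170.5, 176.5))
    df <- createDataFrame(sqlContext, localDF, schema)
    dtypes(df)  # list(c("name", "string"), c("height", "float"))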

R/pkg/R/deserialize.R
Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@
 # Int -> integer
 # String -> character
 # Boolean -> logical
+# Float -> double
 # Double -> double
 # Long -> double
 # Array[Byte] -> raw
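This mapping means SparkR never surfaces single-precision values directly: R has no 32-bit numeric type, so a Float read back from the JVM lands in R as a double. A small self-contained sketch (hypothetical session; the value is chosen to be exactly representable in a 32-bit float):

    # Hypothetical session: a DataFrame with one float column.
    df <- createDataFrame(sqlContext, data.frame(height = c(176.5)),
                          structType(structField("height", "float")))
    is.double(collect(df)$height)  # TRUE: Float deserializes to an R double
    collect(df)$height             # 176.5, exact because 176.5 fits a 32-bit float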

R/pkg/R/schema.R
Lines changed: 1 addition & 0 deletions

@@ -123,6 +123,7 @@ structField.character <- function(x, type, nullable = TRUE) {
   }
   options <- c("byte",
                "integer",
+               "float",
                "double",
                "numeric",
                "character",

R/pkg/inst/tests/test_sparkSQL.R
Lines changed: 26 additions & 0 deletions

@@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", {
   expect_equal(count(df), 10)
   expect_equal(columns(df), c("a", "b"))
   expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+  df <- jsonFile(sqlContext, jsonPathNa)
+  hiveCtx <- tryCatch({
+    newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+  }, error = function(err) {
+    skip("Hive is not build with SparkSQL, skipped")
+  })
+  sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)")
+  insertInto(df, "people")
+  expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16))
+  expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5))
+
+  schema <- structType(structField("name", "string"), structField("age", "integer"),
+                       structField("height", "float"))
+  df2 <- createDataFrame(sqlContext, df.toRDD, schema)
+  expect_equal(columns(df2), c("name", "age", "height"))
+  expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
+  expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
+
+  localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+  df <- createDataFrame(sqlContext, localDF, schema)
+  expect_is(df, "DataFrame")
+  expect_equal(count(df), 3)
+  expect_equal(columns(df), c("name", "age", "height"))
+  expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
+  expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10))
 })

 test_that("convert NAs to null type in DataFrames", {

core/src/main/scala/org/apache/spark/api/r/SerDe.scala
Lines changed: 4 additions & 0 deletions

@@ -179,6 +179,7 @@ private[spark] object SerDe {
   // Int -> integer
   // String -> character
   // Boolean -> logical
+  // Float -> double
   // Double -> double
   // Long -> double
   // Array[Byte] -> raw
@@ -215,6 +216,9 @@ private[spark] object SerDe {
       case "long" | "java.lang.Long" =>
         writeType(dos, "double")
         writeDouble(dos, value.asInstanceOf[Long].toDouble)
+      case "float" | "java.lang.Float" =>
+        writeType(dos, "double")
+        writeDouble(dos, value.asInstanceOf[Float].toDouble)
       case "double" | "java.lang.Double" =>
         writeType(dos, "double")
         writeDouble(dos, value.asInstanceOf[Double])

sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
Lines changed: 12 additions & 3 deletions

@@ -47,6 +47,7 @@ private[r] object SQLUtils {
     dataType match {
       case "byte" => org.apache.spark.sql.types.ByteType
       case "integer" => org.apache.spark.sql.types.IntegerType
+      case "float" => org.apache.spark.sql.types.FloatType
       case "double" => org.apache.spark.sql.types.DoubleType
       case "numeric" => org.apache.spark.sql.types.DoubleType
       case "character" => org.apache.spark.sql.types.StringType
@@ -68,20 +69,28 @@ private[r] object SQLUtils {

   def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
     val num = schema.fields.size
-    val rowRDD = rdd.map(bytesToRow)
+    val rowRDD = rdd.map(bytesToRow(_, schema))
     sqlContext.createDataFrame(rowRDD, schema)
   }

   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
     df.map(r => rowToRBytes(r))
   }

-  private[this] def bytesToRow(bytes: Array[Byte]): Row = {
+  private[this] def doConversion(data: Object, dataType: DataType): Object = {
+    data match {
+      case d: java.lang.Double if dataType == FloatType =>
+        new java.lang.Float(d)
+      case _ => data
+    }
+  }
+
+  private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
     val bis = new ByteArrayInputStream(bytes)
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
-      SerDe.readObject(dis)
+      doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
     }.toSeq)
   }

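Because R numerics always cross the SerDe boundary as java.lang.Double, the declared schema is the only signal that a column should be narrowed, which is exactly what doConversion applies per field inside bytesToRow. One observable consequence, sketched for a hypothetical session: an R double that is not exactly representable in 32 bits loses precision on the way in.

    # Hypothetical session: the R double 176.7 is narrowed to a 32-bit float
    # on the JVM because the schema declares the column as "float".
    schema <- structType(structField("height", "float"))
    df <- createDataFrame(sqlContext, data.frame(height = 176.7), schema)
    collect(df)$height == 176.7  # FALSE: the round trip returns the nearest float value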
