Commit dbf0c1b

Implicitly convert Double to Float based on provided schema.
1 parent: 733015a

3 files changed: +22 −5 lines changed
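Background on the change: SparkR serializes R numeric vectors as doubles, so a column that a user-supplied schema declares as FloatType still reaches the JVM as a java.lang.Double. This commit narrows such values to java.lang.Float during deserialization, keyed off the schema's type name. A minimal standalone sketch of that narrowing (illustrative only; narrow is a hypothetical helper, not the commit's API):

    // Narrow a deserialized Double to Float when the declared type name asks
    // for it; mirrors the doConversion logic added to SerDe below.
    def narrow(data: Object, typeName: String): Object = data match {
      case d: java.lang.Double if typeName == "Float" =>
        java.lang.Float.valueOf(d.floatValue)
      case _ => data
    }

    println(narrow(java.lang.Double.valueOf(176.5), "Float")) // 176.5 as java.lang.Float
    println(narrow(java.lang.Double.valueOf(176.5), ""))      // passes through as Double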

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 8 additions & 0 deletions
@@ -126,6 +126,14 @@ test_that("create DataFrame from RDD", {
   expect_equal(columns(df2), c("name", "age", "height"))
   expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
   expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
+
+  localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+  df <- createDataFrame(sqlContext, localDF, schema)
+  expect_is(df, "DataFrame")
+  expect_equal(count(df), 3)
+  expect_equal(columns(df), c("name", "age", "height"))
+  expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
+  expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10))
 })

 test_that("convert NAs to null type in DataFrames", {

core/src/main/scala/org/apache/spark/api/r/SerDe.scala

Lines changed: 11 additions & 2 deletions
@@ -46,9 +46,18 @@ private[spark] object SerDe {
     dis.readByte().toChar
   }

-  def readObject(dis: DataInputStream): Object = {
+  def readObject(dis: DataInputStream, typeName: String = ""): Object = {
     val dataType = readObjectType(dis)
-    readTypedObject(dis, dataType)
+    val data = readTypedObject(dis, dataType)
+    doConversion(data, dataType, typeName)
+  }
+
+  def doConversion(data: Object, dataType: Char, typeName: String): Object = {
+    dataType match {
+      case 'd' if typeName == "Float" =>
+        new java.lang.Float(data.asInstanceOf[java.lang.Double])
+      case _ => data
+    }
   }

   def readTypedObject(
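Because typeName defaults to the empty string, existing readObject call sites compile unchanged; only callers that pass a type name opt into the conversion. A hedged usage sketch of the new helper ('d' is the wire tag the commit's match associates with doubles):

    // Values the R side serialized as doubles ('d') can be narrowed on request:
    val d: Object = java.lang.Double.valueOf(164.10)
    SerDe.doConversion(d, 'd', "Float") // returns a java.lang.Float
    SerDe.doConversion(d, 'd', "")      // falls through: unchanged java.lang.Double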

sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala

Lines changed: 3 additions & 3 deletions
@@ -69,20 +69,20 @@ private[r] object SQLUtils {

   def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
     val num = schema.fields.size
-    val rowRDD = rdd.map(bytesToRow)
+    val rowRDD = rdd.map(bytesToRow(_, schema))
     sqlContext.createDataFrame(rowRDD, schema)
   }

   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
     df.map(r => rowToRBytes(r))
   }

-  private[this] def bytesToRow(bytes: Array[Byte]): Row = {
+  private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
     val bis = new ByteArrayInputStream(bytes)
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
-      SerDe.readObject(dis)
+      SerDe.readObject(dis, schema.fields(i).dataType.typeName)
     }.toSeq)
   }
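End to end, createDF now threads the user-provided schema into per-field deserialization: for column i, whatever string schema.fields(i).dataType.typeName reports is what SerDe's match sees (the exact string a DataType reports as its typeName may vary across Spark versions). An illustrative sketch of that dispatch:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("height", FloatType)))

    // In bytesToRow, column i is read as:
    //   SerDe.readObject(dis, schema.fields(i).dataType.typeName)
    schema.fields.zipWithIndex.foreach { case (f, i) =>
      println(s"column $i (${f.name}) deserializes with typeName = ${f.dataType.typeName}")
    }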
