Commit f0d880e

davies authored and marmbrus committed
[SPARK-2674] [SQL] [PySpark] support datetime type for SchemaRDD
Python datetime and time values will be converted into java.util.Calendar after serialization, and then into java.sql.Timestamp during inferSchema(). In javaToPython(), a Timestamp will be converted back into a Calendar, and then into a Python datetime after pickling.

Author: Davies Liu <[email protected]>

Closes #1601 from davies/date and squashes the following commits:

f0599b0 [Davies Liu] remove tests for sets and tuple in sql, fix list of list
c9d607a [Davies Liu] convert datetype for runtime
709d40d [Davies Liu] remove brackets
96db384 [Davies Liu] support datetime type for SchemaRDD
1 parent e364348 commit f0d880e
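
As context for the diffs below: the conversion chain described in the commit message means a Python datetime now survives a round trip through inferSchema(). The following is a minimal sketch modeled on the new doctest, assuming a Spark 1.0-era PySpark shell where `sc` and `sqlCtx` are already constructed:

    from datetime import datetime

    # A Python datetime is pickled to the JVM as java.util.Calendar, inferred as
    # TimestampType by inferSchema(), converted back to Calendar in javaToPython(),
    # and unpickled as a datetime on the Python side.
    rdd = sc.parallelize([{"id": 1, "time": datetime(2010, 1, 1, 1, 1, 1)}])
    srdd = sqlCtx.inferSchema(rdd)
    print(srdd.map(lambda x: (x.id, x.time)).collect()[0])
    # expected: (1, datetime.datetime(2010, 1, 1, 1, 1, 1))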

File tree: 4 files changed, +68 -44 lines

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 2 additions & 2 deletions
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
-      // TODO: Figure out why flatMap is necessay for pyspark
       iter.flatMap { row =>
         unpickle.loads(row) match {
+          // in case of objects are pickled in batch mode
           case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
-          // Incase the partition doesn't have a collection
+          // not in batch mode
           case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
         }
       }

python/pyspark/sql.py

Lines changed: 12 additions & 10 deletions
@@ -47,12 +47,14 @@ def __init__(self, sparkContext, sqlContext=None):
         ...
         ValueError:...
 
-        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
-        ... "boolean" : True}])
+        >>> from datetime import datetime
+        >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
+        ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},
+        ... "list": [1, 2, 3]}])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
-        ... x.boolean))
+        ... x.boolean, x.time, x.dict["a"], x.list))
         >>> srdd.collect()[0]
-        (1, u'string', 1.0, 1, True)
+        (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3])
         """
         self._sc = sparkContext
         self._jsc = self._sc._jsc
@@ -88,13 +90,13 @@ def inferSchema(self, rdd):
 
         >>> from array import array
         >>> srdd = sqlCtx.inferSchema(nestedRdd1)
-        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
-        ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        >>> srdd.collect() == [{"f1" : [1, 2], "f2" : {"row1" : 1.0}},
+        ... {"f1" : [2, 3], "f2" : {"row2" : 2.0}}]
         True
 
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
-        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
-        ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : [1, 2]},
+        ... {"f1" : [[2, 3], [3, 4]], "f2" : [2, 3]}]
         True
         """
         if (rdd.__class__ is SchemaRDD):
@@ -509,8 +511,8 @@ def _test():
         {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
         {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
     globs['nestedRdd2'] = sc.parallelize([
-        {"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
-        {"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
+        {"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
+        {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
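
The doctest updates above capture two behavior points: 'i'-typed arrays come back from a collect as plain Python lists, and sets and tuples are no longer exercised by the SQL doctests. A minimal sketch of the array case, reusing the nestedRdd1 data and assuming the same shell environment (`sc`, `sqlCtx`) as the doctests:

    from array import array

    # Values stored as array('i', ...) are inferred as ArrayType columns and
    # collected back as plain Python lists.
    rdd = sc.parallelize([{"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}])
    srdd = sqlCtx.inferSchema(rdd)
    print(srdd.collect() == [{"f1": [1, 2], "f2": {"row1": 1.0}}])
    # expected: True, matching the updated doctest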

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 37 additions & 3 deletions
@@ -352,8 +352,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: java.lang.Long => LongType
       case c: java.lang.Double => DoubleType
       case c: java.lang.Boolean => BooleanType
+      case c: java.math.BigDecimal => DecimalType
+      case c: java.sql.Timestamp => TimestampType
+      case c: java.util.Calendar => TimestampType
       case c: java.util.List[_] => ArrayType(typeFor(c.head))
-      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
       case c: java.util.Map[_, _] =>
         val (key, value) = c.head
         MapType(typeFor(key), typeFor(value))
@@ -362,11 +364,43 @@ class SQLContext(@transient val sparkContext: SparkContext)
         ArrayType(typeFor(elem))
       case c => throw new Exception(s"Object of type $c cannot be used")
     }
-    val schema = rdd.first().map { case (fieldName, obj) =>
+    val firstRow = rdd.first()
+    val schema = firstRow.map { case (fieldName, obj) =>
       AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
 
-    val rowRdd = rdd.mapPartitions { iter =>
+    def needTransform(obj: Any): Boolean = obj match {
+      case c: java.util.List[_] => true
+      case c: java.util.Map[_, _] => true
+      case c if c.getClass.isArray => true
+      case c: java.util.Calendar => true
+      case c => false
+    }
+
+    // convert JList, JArray into Seq, convert JMap into Map
+    // convert Calendar into Timestamp
+    def transform(obj: Any): Any = obj match {
+      case c: java.util.List[_] => c.map(transform).toSeq
+      case c: java.util.Map[_, _] => c.map {
+        case (key, value) => (key, transform(value))
+      }.toMap
+      case c if c.getClass.isArray =>
+        c.asInstanceOf[Array[_]].map(transform).toSeq
+      case c: java.util.Calendar =>
+        new java.sql.Timestamp(c.getTime().getTime())
+      case c => c
+    }
+
+    val need = firstRow.exists {case (key, value) => needTransform(value)}
+    val transformed = if (need) {
+      rdd.mapPartitions { iter =>
+        iter.map {
+          m => m.map {case (key, value) => (key, transform(value))}
+        }
+      }
+    } else rdd
+
+    val rowRdd = transformed.mapPartitions { iter =>
       iter.map { map =>
         new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
       }
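
The SQLContext.scala hunk above adds a pre-pass over the incoming RDD of maps: if the first row contains Java lists, maps, arrays, or java.util.Calendar values, every row is rewritten recursively (Calendar becoming java.sql.Timestamp) before the GenericRow is built. As a rough Python analogue of that recursion pattern only, not the Spark code itself, with a hypothetical `transform` helper and `datetime.date` standing in for Calendar:

    from datetime import date, datetime

    def transform(obj):
        # Recurse into lists and dicts, normalize date-like leaves,
        # and pass every other value through unchanged.
        if isinstance(obj, list):
            return [transform(x) for x in obj]
        if isinstance(obj, dict):
            return {k: transform(v) for k, v in obj.items()}
        if isinstance(obj, date) and not isinstance(obj, datetime):
            return datetime(obj.year, obj.month, obj.day)
        return obj

    print(transform({"f1": [[1, 2], [2, 3]], "when": date(2010, 1, 1)}))
    # the date leaf becomes datetime.datetime(2010, 1, 1, 0, 0); nested lists are rebuilt unchanged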

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 17 additions & 29 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType}
+import org.apache.spark.sql.catalyst.types.{DataType, ArrayType, BooleanType, StructType, MapType}
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.api.java.JavaRDD
 
@@ -376,39 +376,27 @@ class SchemaRDD(
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    def toJava(obj: Any, dataType: DataType): Any = dataType match {
+      case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+      case array: ArrayType => obj match {
+        case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
+        case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
+        case arr if arr != null && arr.getClass.isArray =>
+          arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+        case other => other
+      }
+      case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+        case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+      }.asJava
+      // Pyrolite can handle Timestamp
+      case other => obj
+    }
     def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
       val fields = structType.fields.map(field => (field.name, field.dataType))
       val map: JMap[String, Any] = new java.util.HashMap
       row.zip(fields).foreach {
-        case (obj, (attrName, dataType)) =>
-          dataType match {
-            case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct))
-            case array @ ArrayType(struct: StructType) =>
-              val arrayValues = obj match {
-                case seq: Seq[Any] =>
-                  seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava
-                case list: JList[_] =>
-                  list.map(element => rowToMap(element.asInstanceOf[Row], struct))
-                case set: JSet[_] =>
-                  set.map(element => rowToMap(element.asInstanceOf[Row], struct))
-                case arr if arr != null && arr.getClass.isArray =>
-                  arr.asInstanceOf[Array[Any]].map {
-                    element => rowToMap(element.asInstanceOf[Row], struct)
-                  }
-                case other => other
-              }
-              map.put(attrName, arrayValues)
-            case array: ArrayType => {
-              val arrayValues = obj match {
-                case seq: Seq[Any] => seq.asJava
-                case other => other
-              }
-              map.put(attrName, arrayValues)
-            }
-            case other => map.put(attrName, obj)
-          }
+        case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
       }
-
       map
     }
 
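
With the conversion consolidated into the recursive toJava() above, nested arrays and maps in a SchemaRDD collect back into plain Python lists and dicts, while Timestamp values are left for Pyrolite to pickle. A minimal sketch mirroring the nestedRdd2 doctest, under the same shell assumptions (`sc`, `sqlCtx`):

    # Nested list-of-lists values are converted element by element by toJava()
    # before pickling, so they compare equal to plain Python lists after collect().
    rdd = sc.parallelize([{"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
                          {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
    srdd = sqlCtx.inferSchema(rdd)
    print(srdd.collect() == [{"f1": [[1, 2], [2, 3]], "f2": [1, 2]},
                             {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}])
    # expected: True, matching the updated inferSchema doctest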
