From 644665a2ebae4bc4a49f28152e3af00681affcc8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 25 Jul 2014 16:11:57 -0700 Subject: [PATCH 01/24] use tuple and namedtuple for schemardd --- python/pyspark/rdd.py | 8 +- python/pyspark/sql.py | 256 +++++++++++++----- .../org/apache/spark/sql/SchemaRDD.scala | 33 +-- 3 files changed, 200 insertions(+), 97 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 113a082e1672..973ddd7e94fc 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -311,9 +311,9 @@ def map(self, f, preservesPartitioning=False): >>> sorted(rdd.map(lambda x: (x, 1)).collect()) [('a', 1), ('b', 1), ('c', 1)] """ - def func(split, iterator): + def func(_, iterator): return imap(f, iterator) - return PipelinedRDD(self, func, preservesPartitioning) + return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): """ @@ -1070,7 +1070,7 @@ def func(split, iterator): if not isinstance(x, basestring): x = unicode(x) yield x.encode("utf-8") - keyed = PipelinedRDD(self, func) + keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) @@ -1268,7 +1268,7 @@ def add_shuffle_key(split, iterator): yield pack_long(split) yield outputSerializer.dumps(items) - keyed = PipelinedRDD(self, add_shuffle_key) + keyed = self.mapPartitionsWithIndex(add_shuffle_key) keyed._bypass_serializer = True with _JavaStackTrace(self.context) as st: pairRDD = self.ctx._jvm.PairwiseRDD( diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cb83e8917682..d40a52d63569 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,7 +15,14 @@ # limitations under the License. # -from pyspark.rdd import RDD, PipelinedRDD + +import sys +import types +import array +import itertools +from operator import itemgetter + +from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer from py4j.protocol import Py4JError @@ -23,6 +30,119 @@ __all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +def _extend_tree(lines): + parts = [] + + def subtree(depth): + sub = parts + while depth > 1: + sub = sub[-1] + depth -= 1 + return sub + + for line in lines: + subtree(line.count('|')).append([line]) + + return parts + +def _parse_tree(tree): + if isinstance(tree[0], basestring): + name, _type = tree[0].split(":") + name = name.split(" ")[-1] + if len(tree) == 1: + return (name, _type.strip()) + else: + return (name, _parse_tree(tree[1:])) + else: + return tuple(_parse_tree(sub) for sub in tree) + +def parse_tree_schema(tree): + lines = tree.split("\n")[1:-1] + parts = _extend_tree(lines) + return _parse_tree(parts) + +def _create_object(cls, v): + return cls(v) if v is not None else v + +_cached_cls = {} +def _restore_object(fields, obj): + cls = _cached_cls.get(fields) + if cls is None: + cls = namedtuple("Row", fields) + _cached_cls[fields] = cls + return cls(*obj) + +def create_getter(schema, i): + cls = create_cls(schema) + def getter(self): + return _create_object(cls, self[i]) + return getter + +def create_cls(schema): + # this can not be in global + from operator import itemgetter + + if isinstance(schema, list): + if not schema: + return list + cls = create_cls(schema[0]) + class List(list): + def __getitem__(self, i): + return _create_object(cls, list.__getitem__(self, i)) + def __reduce__(self): + return (list, (list(self),)) + return List + + elif 
isinstance(schema, dict): + if not schema: + return dict + vcls = create_cls(schema['value']) + class Dict(dict): + def __getitem__(self, k): + return create(vcls, dict.__getitem__(self, k)) + + # builtin types + elif not isinstance(schema, tuple): + return schema + + + class Row(tuple): + + _fields = tuple(n for n, _ in schema) + + for __i,__x in enumerate(schema): + if isinstance(__x, tuple): + __name, _type = __x + if _type and isinstance(_type, (tuple,list)): + locals()[__name] = property(create_getter(_type,__i)) + else: + locals()[__name] = property(itemgetter(__i)) + del __name, _type + else: + locals()[__x] = property(itemgetter(__i)) + del __i, __x + + def __equals__(self, x): + if type(self) != type(x): + return False + for name in self._fields: + if getattr(self, name) != getattr(x, name): + return False + return True + + def __repr__(self): + return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) + for n in self._fields)) + + def __str__(self): + return repr(self) + + def __reduce__(self): + return (_restore_object, (self._fields, tuple(self))) + + return Row + + class SQLContext: """Main entry point for SparkSQL functionality. @@ -51,8 +171,8 @@ def __init__(self, sparkContext, sqlContext=None): ... "boolean" : True}]) >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, ... x.boolean)) - >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True) + >>> srdd.collect() + [(1, u'string', 1.0, 1, True)] """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -82,20 +202,17 @@ def inferSchema(self, rdd): tuple. >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, - ... {"field1" : 3, "field2": "row3"}] - True + >>> srdd.collect()[0] + Row(field1=1, field2=u'row1') >>> from array import array >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}}, - ... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}] - True + >>> srdd.collect() + [Row(f2={u'row1': 1.0}, f1=array('i', [1, 2])), Row(f2={u'row2': 2.0}, f1=array('i', [2, 3]))] >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)}, - ... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}] - True + >>> srdd.collect() + [Row(f1=[[1, 2], [2, 3]], f3=(1, 2), f2=set([1, 2])), Row(f1=[[2, 3], [3, 4]], f3=(2, 3), f2=set([2, 3]))] """ if (rdd.__class__ is SchemaRDD): raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) @@ -153,11 +270,10 @@ def jsonFile(self, path): >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + >>> srdd2.collect() + [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ +Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=Row(field7=(u'row2',))), \ +Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] """ jschema_rdd = self._ssql_ctx.jsonFile(path) return SchemaRDD(jschema_rdd, self) @@ -170,18 +286,17 @@ def jsonRDD(self, rdd): >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql( ... 
"SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + >>> srdd2.collect() + [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ +Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=Row(field7=(u'row2',))), \ +Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] """ - def func(split, iterator): + def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) yield x.encode("utf-8") - keyed = PipelinedRDD(rdd, func) + keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) @@ -193,9 +308,8 @@ def sql(self, sqlQuery): >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"}, - ... {"f1" : 3, "f2": "row3"}] - True + >>> srdd2.collect() + [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) @@ -258,22 +372,23 @@ class LocalHiveContext(HiveContext): An in-process metadata data is created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. - >>> import os - >>> hiveCtx = LocalHiveContext(sc) - >>> try: - ... supress = hiveCtx.hql("DROP TABLE src") - ... except Exception: - ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) - >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) - >>> num = results.count() - >>> reduce_sum = results.reduce(lambda x, y: x + y) - >>> num - 500 - >>> reduce_sum - 130091 + ## disable these tests tempory + ## >>> import os + ## >>> hiveCtx = LocalHiveContext(sc) + ## >>> try: + ## ... supress = hiveCtx.hql("DROP TABLE src") + ## ... except Exception: + ## ... pass + ## >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') + ## >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + ## >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) + ## >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) + ## >>> num = results.count() + ## >>> reduce_sum = results.reduce(lambda x, y: x + y) + ## >>> num + ## 500 + ## >>> reduce_sum + ## 130091 """ def _get_hive_ctx(self): @@ -286,26 +401,11 @@ def _get_hive_ctx(self): return self._jvm.TestHiveContext(self._jsc.sc()) -# TODO: Investigate if it is more efficient to use a namedtuple. One problem is that named tuples -# are custom classes that must be generated per Schema. -class Row(dict): - """A row in L{SchemaRDD}. - - An extended L{dict} that takes a L{dict} in its constructor, and - exposes those items as fields. - - >>> r = Row({"hello" : "world", "foo" : "bar"}) - >>> r.hello - 'world' - >>> r.foo - 'bar' +# a stub type, the real type is dynamic generated. +class Row(tuple): + """ + A row in L{SchemaRDD}. 
The fields in it can be accessed like attributes. """ - - def __init__(self, d): - d.update(self.__dict__) - self.__dict__ = d - dict.__init__(self, d) - class SchemaRDD(RDD): """An RDD of L{Row} objects that has an associated schema. @@ -328,7 +428,8 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_cached = False self.is_checkpointed = False self.ctx = self.sql_ctx._sc - self._jrdd_deserializer = self.ctx.serializer + # the _jrdd is created by javaToPython(), serialized by pickle + self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -338,7 +439,7 @@ def _jrdd(self): L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._toPython()._jrdd + self._lazy_jrdd = self._jschema_rdd.javaToPython() return self._lazy_jrdd @property @@ -410,16 +511,23 @@ def count(self): """ return self._jschema_rdd.count() - def _toPython(self): - # We have to import the Row class explicitly, so that the reference Pickler has is - # pyspark.sql.Row instead of __main__.Row - from pyspark.sql import Row - jrdd = self._jschema_rdd.javaToPython() - # TODO: This is inefficient, we should construct the Python Row object - # in Java land in the javaToPython function. May require a custom - # pickle serializer in Pyrolite - return RDD(jrdd, self._sc, BatchedSerializer( - PickleSerializer())).map(lambda d: Row(d)) + def collect(self): + rows = RDD.collect(self) + schema = parse_tree_schema(self._jschema_rdd.schemaString()) + cls = create_cls(schema) + return map(cls, rows) + + # convert Row in JavaSchemaRDD into namedtuple, let access fields easier + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) + + schema = parse_tree_schema(self._jschema_rdd.schemaString()) + def applySchema(_, it): + cls = create_cls(schema) + return itertools.imap(cls, it) + + objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) + return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) # We override the default cache/persist/checkpoint behavior as we want to cache the underlying # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 31d27bb4f057..14b07b9a025c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -376,47 +376,42 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. 
*/ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - def rowToMap(row: Row, structType: StructType): JMap[String, Any] = { - val fields = structType.fields.map(field => (field.name, field.dataType)) - val map: JMap[String, Any] = new java.util.HashMap - row.zip(fields).foreach { - case (obj, (attrName, dataType)) => + def rowToArray(row: Row, structType: StructType): Array[Any] = { + val fields = structType.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => dataType match { - case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) + case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct) case array @ ArrayType(struct: StructType) => - val arrayValues = obj match { + obj match { case seq: Seq[Any] => - seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava + seq.map(element => rowToArray(element.asInstanceOf[Row], struct)).asJava case list: JList[_] => - list.map(element => rowToMap(element.asInstanceOf[Row], struct)) + list.map(element => rowToArray(element.asInstanceOf[Row], struct)) case set: JSet[_] => - set.map(element => rowToMap(element.asInstanceOf[Row], struct)) + set.map(element => rowToArray(element.asInstanceOf[Row], struct)) case arr if arr != null && arr.getClass.isArray => arr.asInstanceOf[Array[Any]].map { - element => rowToMap(element.asInstanceOf[Row], struct) + element => rowToArray(element.asInstanceOf[Row], struct) } case other => other } - map.put(attrName, arrayValues) case array: ArrayType => { - val arrayValues = obj match { + obj match { case seq: Seq[Any] => seq.asJava case other => other } - map.put(attrName, arrayValues) } - case other => map.put(attrName, obj) + case other => obj } - } - - map + }.toArray } val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => - rowToMap(row, rowSchema) + rowToArray(row, rowSchema) }.grouped(10).map(batched => pickle.dumps(batched.toArray)) } } From a435b5a8a6aa6e8302eb427396e0e50e4b70a8dd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 29 Jul 2014 18:01:21 -0700 Subject: [PATCH 02/24] add docs and code refactor --- python/pyspark/sql.py | 111 +++++++++++++++++++++++++++--------------- 1 file changed, 73 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 64814bd41b7e..16181075a025 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -21,6 +21,7 @@ import array import itertools from operator import itemgetter +from collections import namedtuple from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer @@ -30,6 +31,7 @@ __all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +# FIXME: these will be updated to use new API def _extend_tree(lines): parts = [] @@ -61,31 +63,46 @@ def parse_tree_schema(tree): parts = _extend_tree(lines) return _parse_tree(parts) -def _create_object(cls, v): - return cls(v) if v is not None else v -_cached_cls = {} +_cached_namedtuples = {} def _restore_object(fields, obj): - cls = _cached_cls.get(fields) + """ Restore namedtuple object during unpickling. 
""" + cls = _cached_namedtuples.get(fields) if cls is None: cls = namedtuple("Row", fields) - _cached_cls[fields] = cls + _cached_namedtuples[fields] = cls return cls(*obj) -def create_getter(schema, i): - cls = create_cls(schema) - def getter(self): - return _create_object(cls, self[i]) - return getter +def _create_object(cls, v): + """ Create an customized object with class `cls`. """ + return cls(v) if v is not None else v -def create_cls(schema): +def _create_getter(schema, i): + """ Create a getter for item `i` with schema """ + # TODO: cache created class + cls = _create_cls(schema) + if cls: + def getter(self): + return _create_object(cls, self[i]) + return getter + return itemgetter(i) + +def _create_cls(schema): + """ + Create an class by schama + + The created class is similar to namedtuple, but can have nested schema. + """ # this can not be in global from operator import itemgetter + # TODO: update to new DataType if isinstance(schema, list): if not schema: - return list - cls = create_cls(schema[0]) + return + cls = _create_cls(schema[0]) + if not cls: + return class List(list): def __getitem__(self, i): return _create_object(cls, list.__getitem__(self, i)) @@ -95,26 +112,29 @@ def __reduce__(self): elif isinstance(schema, dict): if not schema: - return dict - vcls = create_cls(schema['value']) + return + vcls = _create_cls(schema['value']) + if not vcls: + return class Dict(dict): def __getitem__(self, k): - return create(vcls, dict.__getitem__(self, k)) + return _create_object(vcls, dict.__getitem__(self, k)) + return Dict # builtin types elif not isinstance(schema, tuple): - return schema + return class Row(tuple): - + """ Row in SchemaRDD """ _fields = tuple(n for n, _ in schema) for __i,__x in enumerate(schema): if isinstance(__x, tuple): __name, _type = __x if _type and isinstance(_type, (tuple,list)): - locals()[__name] = property(create_getter(_type,__i)) + locals()[__name] = property(_create_getter(_type,__i)) else: locals()[__name] = property(itemgetter(__i)) del __name, _type @@ -374,23 +394,23 @@ class LocalHiveContext(HiveContext): An in-process metadata data is created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. - ## disable these tests tempory - ## >>> import os - ## >>> hiveCtx = LocalHiveContext(sc) - ## >>> try: - ## ... supress = hiveCtx.hql("DROP TABLE src") - ## ... except Exception: - ## ... pass - ## >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') - ## >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - ## >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) - ## >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) - ## >>> num = results.count() - ## >>> reduce_sum = results.reduce(lambda x, y: x + y) - ## >>> num - ## 500 - ## >>> reduce_sum - ## 130091 + disable these tests tempory + >>> import os + >>> hiveCtx = LocalHiveContext(sc) + >>> try: + ... supress = hiveCtx.hql("DROP TABLE src") + ... except Exception: + ... 
pass + >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') + >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) + >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) + >>> num = results.count() + >>> reduce_sum = results.reduce(lambda x, y: x + y) + >>> num + 500 + >>> reduce_sum + 130091 """ def _get_hive_ctx(self): @@ -514,18 +534,33 @@ def count(self): return self._jschema_rdd.count() def collect(self): + """ + Return a list that contains all of the rows in this RDD. + + Each object in the list is on Row, the fields can be accessed as + attributes. + """ rows = RDD.collect(self) schema = parse_tree_schema(self._jschema_rdd.schemaString()) - cls = create_cls(schema) + cls = _create_cls(schema) return map(cls, rows) # convert Row in JavaSchemaRDD into namedtuple, let access fields easier def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. + + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithIndex(f).sum() + 6 + """ rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) schema = parse_tree_schema(self._jschema_rdd.schemaString()) def applySchema(_, it): - cls = create_cls(schema) + cls = _create_cls(schema) return itertools.imap(cls, it) objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) From bc6e9e16d652418d0c43998815c48541de322cdf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 2014 13:33:25 -0700 Subject: [PATCH 03/24] switch to new Schema API --- python/pyspark/sql.py | 183 +++++++----------- .../org/apache/spark/sql/SQLContext.scala | 9 +- 2 files changed, 77 insertions(+), 115 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 373daef58f2c..eb48859578be 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -475,45 +475,17 @@ def _parse_datatype_string(datatype_string): return StructType(fields) -# FIXME: these will be updated to use new API -def _extend_tree(lines): - parts = [] - - def subtree(depth): - sub = parts - while depth > 1: - sub = sub[-1] - depth -= 1 - return sub - - for line in lines: - subtree(line.count('|')).append([line]) - - return parts - -def _parse_tree(tree): - if isinstance(tree[0], basestring): - name, _type = tree[0].split(":") - name = name.split(" ")[-1] - if len(tree) == 1: - return (name, _type.strip()) - else: - return (name, _parse_tree(tree[1:])) - else: - return tuple(_parse_tree(sub) for sub in tree) - -def parse_tree_schema(tree): - lines = tree.split("\n")[1:-1] - parts = _extend_tree(lines) - return _parse_tree(parts) - _cached_namedtuples = {} + def _restore_object(fields, obj): """ Restore namedtuple object during unpickling. """ cls = _cached_namedtuples.get(fields) if cls is None: cls = namedtuple("Row", fields) + def __reduce__(self): + return (_restore_object, (fields, tuple(self))) + cls.__reduce__ = __reduce__ _cached_namedtuples[fields] = cls return cls(*obj) @@ -521,78 +493,75 @@ def _create_object(cls, v): """ Create an customized object with class `cls`. 
""" return cls(v) if v is not None else v -def _create_getter(schema, i): +def _create_getter(dt, i): """ Create a getter for item `i` with schema """ # TODO: cache created class - cls = _create_cls(schema) - if cls: - def getter(self): - return _create_object(cls, self[i]) - return getter - return itemgetter(i) - -def _create_cls(schema): + cls = _create_cls(dt) + def getter(self): + return _create_object(cls, self[i]) + return getter + +def _has_struct(dt): + if isinstance(dt, StructType): + return True + elif isinstance(dt, ArrayType): + return _has_struct(dt.elementType) + elif isinstance(dt, MapType): + return _has_struct(dt.valueType) + return False + +def _create_cls(dataType): """ - Create an class by schama + Create an class by dataType The created class is similar to namedtuple, but can have nested schema. """ # this can not be in global + from pyspark.sql import _has_struct, _create_getter from operator import itemgetter + # TODO: update to new DataType - if isinstance(schema, list): - if not schema: - return - cls = _create_cls(schema[0]) - if not cls: - return + if isinstance(dataType, ArrayType): + cls = _create_cls(dataType.elementType) class List(list): def __getitem__(self, i): return _create_object(cls, list.__getitem__(self, i)) + def __repr__(self): + return "[%s]" % (", ".join(repr(self[i]) + for i in range(len(self)))) def __reduce__(self): + # the nested struct can be reduced by itself return (list, (list(self),)) return List - elif isinstance(schema, dict): - if not schema: - return - vcls = _create_cls(schema['value']) - if not vcls: - return + elif isinstance(dataType, MapType): + vcls = _create_cls(dataType.valueType) class Dict(dict): def __getitem__(self, k): return _create_object(vcls, dict.__getitem__(self, k)) + def __repr__(self): + return "{%s}" % (", ".join("%r: %r" % (k, self[k]) + for k in self)) + def __reduce__(self): + return (dict, (dict(self),)) return Dict - # builtin types - elif not isinstance(schema, tuple): - return - + elif not isinstance(dataType, StructType): + raise Exception("unexpected data type: %s" % dataType) class Row(tuple): """ Row in SchemaRDD """ - _fields = tuple(n for n, _ in schema) - - for __i,__x in enumerate(schema): - if isinstance(__x, tuple): - __name, _type = __x - if _type and isinstance(_type, (tuple,list)): - locals()[__name] = property(_create_getter(_type,__i)) - else: - locals()[__name] = property(itemgetter(__i)) - del __name, _type - else: - locals()[__x] = property(itemgetter(__i)) - del __i, __x + _fields = tuple(f.name for f in dataType.fields) - def __equals__(self, x): - if type(self) != type(x): - return False - for name in self._fields: - if getattr(self, name) != getattr(x, name): - return False - return True + # use local vars begins with "_" + for _i,_f in enumerate(dataType.fields): + if _has_struct(_f.dataType): + _getter = property(_create_getter(_f.dataType, _i)) + else: + _getter = property(itemgetter(_i)) + locals()[_f.name] = _getter + del _i, _f, _getter def __repr__(self): return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) @@ -698,9 +667,8 @@ def applySchema(self, rdd, schema): >>> srdd = sqlCtx.applySchema(rdd, schema) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, - ... 
{"field1" : 3, "field2": "row3"}] - True + >>> srdd2.collect() + [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2'), Row(field1=3, field2=u'row3')] >>> from datetime import datetime >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0, ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2}, @@ -716,7 +684,7 @@ def applySchema(self, rdd, schema): ... StructField("null", DoubleType(), True)]) >>> srdd = sqlCtx.applySchema(rdd, schema).map( ... lambda x: ( - ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null)) + ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> srdd.collect()[0] (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) """ @@ -773,17 +741,16 @@ def jsonFile(self, path, schema=None): ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ -Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=Row(field7=(u'row2',))), \ +Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") - >>> srdd4.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + >>> srdd4.collect() + [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ +Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ +Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -793,11 +760,8 @@ def jsonFile(self, path, schema=None): >>> sqlCtx.registerRDDAsTable(srdd5, "table3") >>> srdd6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") - >>> srdd6.collect() == [ - ... {"f1": "row1", "f2": None, "f3": None}, - ... {"f1": None, "f2": [10, 11], "f3": 10}, - ... {"f1": "row3", "f2": [], "f3": None}] - True + >>> srdd6.collect() + [Row(f1=u'row1', f2=None, f3=None), Row(f1=None, f2=[10, 11], f3=10), Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: jschema_rdd = self._ssql_ctx.jsonFile(path) @@ -818,17 +782,16 @@ def jsonRDD(self, rdd, schema=None): ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ -Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=Row(field7=(u'row2',))), \ +Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") - >>> srdd4.collect() == [ - ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, - ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, - ... 
{"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] - True + >>> srdd4.collect() + [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ +Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ +Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -838,11 +801,8 @@ def jsonRDD(self, rdd, schema=None): >>> sqlCtx.registerRDDAsTable(srdd5, "table3") >>> srdd6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") - >>> srdd6.collect() == [ - ... {"f1": "row1", "f2": None, "f3": None}, - ... {"f1": None, "f2": [10, 11], "f3": 10}, - ... {"f1": "row3", "f2": [], "f3": None}] - True + >>> srdd6.collect() + [Row(f1=u'row1', f2=None, f3=None), Row(f1=None, f2=[10, 11], f3=10), Row(f1=u'row3', f2=[], f3=None)] """ def func(iterator): for x in iterator: @@ -1080,8 +1040,7 @@ def collect(self): attributes. """ rows = RDD.collect(self) - schema = parse_tree_schema(self._jschema_rdd.schemaString()) - cls = _create_cls(schema) + cls = _create_cls(self.schema()) return map(cls, rows) # convert Row in JavaSchemaRDD into namedtuple, let access fields easier @@ -1097,11 +1056,13 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False): """ rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - schema = parse_tree_schema(self._jschema_rdd.schemaString()) + schema = self.schema() + import pickle + pickle.loads(pickle.dumps(schema)) def applySchema(_, it): cls = _create_cls(schema) return itertools.imap(cls, it) - + objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) @@ -1171,6 +1132,9 @@ def _test(): import doctest from array import array from pyspark.context import SparkContext + # let doctest run in pyspark.sql, so DataTypes can be picklable + import pyspark.sql + from pyspark.sql import SQLContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -1195,7 +1159,8 @@ def _test(): globs['nestedRdd2'] = sc.parallelize([ {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(pyspark.sql, + globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 86338752a21c..4bd358f8bf72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -503,12 +503,9 @@ class SQLContext(@transient val sparkContext: SparkContext) } row - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val converted = c.map { - case (key, value) => - (convert(key, keyType), convert(value, valueType)) - } - JMapWrapper(converted) + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { + case (key, value) => (convert(key, keyType), convert(value, valueType)) + }.toMap case (c, ArrayType(elementType, _)) if c.getClass.isArray => val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType)) From 182fb46079e1e732042d3b3f87bcb7731d29ce63 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 
2014 13:48:56 -0700 Subject: [PATCH 04/24] refactor --- python/pyspark/sql.py | 160 +++++++++++++----------------------------- 1 file changed, 48 insertions(+), 112 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index eb48859578be..6278a053f048 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -35,7 +35,26 @@ "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +class DataType(object): + """Spark SQL DataType""" + + def __repr__(self): + return self.__class__.__name__ + + def __hash__(self): + return hash(repr(self)) + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" + _instances = {} def __call__(cls): @@ -44,140 +63,91 @@ def __call__(cls): return cls._instances[cls] -class StringType(object): +class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" + + __metaclass__ = PrimitiveTypeSingleton + + +class StringType(PrimitiveType): """Spark SQL StringType The data type representing string values. - """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "StringType" - -class BinaryType(object): +class BinaryType(PrimitiveType): """Spark SQL BinaryType The data type representing bytearray values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "BinaryType" -class BooleanType(object): +class BooleanType(PrimitiveType): """Spark SQL BooleanType The data type representing bool values. - """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "BooleanType" - -class TimestampType(object): +class TimestampType(PrimitiveType): """Spark SQL TimestampType The data type representing datetime.datetime values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "TimestampType" -class DecimalType(object): +class DecimalType(PrimitiveType): """Spark SQL DecimalType The data type representing decimal.Decimal values. - """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "DecimalType" - -class DoubleType(object): +class DoubleType(PrimitiveType): """Spark SQL DoubleType The data type representing float values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "DoubleType" -class FloatType(object): +class FloatType(PrimitiveType): """Spark SQL FloatType The data type representing single precision floating-point values. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "FloatType" -class ByteType(object): +class ByteType(PrimitiveType): """Spark SQL ByteType The data type representing int values with 1 singed byte. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "ByteType" -class IntegerType(object): +class IntegerType(PrimitiveType): """Spark SQL IntegerType The data type representing int values. - """ - __metaclass__ = PrimitiveTypeSingleton - def __repr__(self): - return "IntegerType" - -class LongType(object): +class LongType(PrimitiveType): """Spark SQL LongType The data type representing long values. If the any value is beyond the range of [-9223372036854775808, 9223372036854775807], please use DecimalType. 
- """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "LongType" -class ShortType(object): +class ShortType(PrimitiveType): """Spark SQL ShortType The data type representing int values with 2 signed bytes. - """ - __metaclass__ = PrimitiveTypeSingleton - - def __repr__(self): - return "ShortType" -class ArrayType(object): +class ArrayType(DataType): """Spark SQL ArrayType The data type representing list values. @@ -201,19 +171,12 @@ def __init__(self, elementType, containsNull=False): self.containsNull = containsNull def __repr__(self): - return "ArrayType(" + self.elementType.__repr__() + "," + \ - str(self.containsNull).lower() + ")" + return "ArrayType(%r,%s)" % (self.elementType, + str(self.containsNull).lower()) - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.elementType == other.elementType and - self.containsNull == other.containsNull) - - def __ne__(self, other): - return not self.__eq__(other) -class MapType(object): +class MapType(DataType): """Spark SQL MapType The data type representing dict values. @@ -241,21 +204,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True): self.valueContainsNull = valueContainsNull def __repr__(self): - return "MapType(" + self.keyType.__repr__() + "," + \ - self.valueType.__repr__() + "," + \ - str(self.valueContainsNull).lower() + ")" - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.keyType == other.keyType and - self.valueType == other.valueType and - self.valueContainsNull == other.valueContainsNull) - - def __ne__(self, other): - return not self.__eq__(other) + return "MapType(%r,%r,%s)" % (self.keyType, self.valueType, + str(self.valueContainsNull).lower()) -class StructField(object): +class StructField(DataType): """Spark SQL StructField Represents a field in a StructType. @@ -281,21 +234,11 @@ def __init__(self, name, dataType, nullable): self.nullable = nullable def __repr__(self): - return "StructField(" + self.name + "," + \ - self.dataType.__repr__() + "," + \ - str(self.nullable).lower() + ")" - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.name == other.name and - self.dataType == other.dataType and - self.nullable == other.nullable) + return "StructField(%s,%r,%s)" % (self.name, self.dataType, + str(self.nullable).lower()) - def __ne__(self, other): - return not self.__eq__(other) - -class StructType(object): +class StructType(DataType): """Spark SQL StructType The data type representing namedtuple values. 
@@ -318,15 +261,8 @@ def __init__(self, fields): self.fields = fields def __repr__(self): - return "StructType(List(" + \ - ",".join([field.__repr__() for field in self.fields]) + "))" - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.fields == other.fields) - - def __ne__(self, other): - return not self.__eq__(other) + return ("StructType(List(%s))" % + ",".join(repr(field) for field in self.fields)) def _parse_datatype_list(datatype_list_string): From 2cc2d4546be846f5af3b7c8b6af0c6269e6990c5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 2014 13:59:22 -0700 Subject: [PATCH 05/24] refactor --- python/pyspark/sql.py | 50 ++++++------------------------------------- 1 file changed, 7 insertions(+), 43 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 6278a053f048..c9ee3a8e1e20 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -289,6 +289,10 @@ def _parse_datatype_list(datatype_list_string): return datatype_list +_all_primitive_types = dict((k, v) for k, v in globals().iteritems() + if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) + + def _parse_datatype_string(datatype_string): """Parses the given data type string. @@ -296,27 +300,7 @@ def _parse_datatype_string(datatype_string): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) ... python_datatype = _parse_datatype_string(scala_datatype.toString()) ... return datatype == python_datatype - >>> check_datatype(StringType()) - True - >>> check_datatype(BinaryType()) - True - >>> check_datatype(BooleanType()) - True - >>> check_datatype(TimestampType()) - True - >>> check_datatype(DecimalType()) - True - >>> check_datatype(DoubleType()) - True - >>> check_datatype(FloatType()) - True - >>> check_datatype(ByteType()) - True - >>> check_datatype(IntegerType()) - True - >>> check_datatype(LongType()) - True - >>> check_datatype(ShortType()) + >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True >>> # Simple ArrayType. 
>>> simple_arraytype = ArrayType(StringType(), True) @@ -357,28 +341,8 @@ def _parse_datatype_string(datatype_string): left_bracket_index = len(datatype_string) type_or_field = datatype_string[:left_bracket_index] rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip() - if type_or_field == "StringType": - return StringType() - elif type_or_field == "BinaryType": - return BinaryType() - elif type_or_field == "BooleanType": - return BooleanType() - elif type_or_field == "TimestampType": - return TimestampType() - elif type_or_field == "DecimalType": - return DecimalType() - elif type_or_field == "DoubleType": - return DoubleType() - elif type_or_field == "FloatType": - return FloatType() - elif type_or_field == "ByteType": - return ByteType() - elif type_or_field == "IntegerType": - return IntegerType() - elif type_or_field == "LongType": - return LongType() - elif type_or_field == "ShortType": - return ShortType() + if type_or_field in _all_primitive_types: + return _all_primitive_types[type_or_field]() elif type_or_field == "ArrayType": last_comma_index = rest_part.rfind(",") containsNull = True From d69d3976f9291bd5070dafeff9743b3d31cc6be7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 2014 14:15:41 -0700 Subject: [PATCH 06/24] refactor --- python/pyspark/sql.py | 50 ++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c9ee3a8e1e20..22004dd60ef0 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -335,14 +335,16 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_maptype) True """ - left_bracket_index = datatype_string.find("(") - if left_bracket_index == -1: + index = datatype_string.find("(") + if index == -1: # It is a primitive type. - left_bracket_index = len(datatype_string) - type_or_field = datatype_string[:left_bracket_index] - rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip() + index = len(datatype_string) + type_or_field = datatype_string[:index] + rest_part = datatype_string[index+1:len(datatype_string)-1].strip() + if type_or_field in _all_primitive_types: return _all_primitive_types[type_or_field]() + elif type_or_field == "ArrayType": last_comma_index = rest_part.rfind(",") containsNull = True @@ -350,6 +352,7 @@ def _parse_datatype_string(datatype_string): containsNull = False elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) return ArrayType(elementType, containsNull) + elif type_or_field == "MapType": last_comma_index = rest_part.rfind(",") valueContainsNull = True @@ -357,6 +360,7 @@ def _parse_datatype_string(datatype_string): valueContainsNull = False keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip()) return MapType(keyType, valueType, valueContainsNull) + elif type_or_field == "StructField": first_comma_index = rest_part.find(",") name = rest_part[:first_comma_index].strip() @@ -367,6 +371,7 @@ def _parse_datatype_string(datatype_string): dataType = _parse_datatype_string( rest_part[first_comma_index+1:last_comma_index].strip()) return StructField(name, dataType, nullable) + elif type_or_field == "StructType": # rest_part should be in the format like # List(StructField(field1,IntegerType,false)). 
@@ -378,13 +383,13 @@ def _parse_datatype_string(datatype_string): _cached_namedtuples = {} -def _restore_object(fields, obj): +def _restore_object(name, fields, obj): """ Restore namedtuple object during unpickling. """ cls = _cached_namedtuples.get(fields) if cls is None: - cls = namedtuple("Row", fields) + cls = namedtuple(name, fields) def __reduce__(self): - return (_restore_object, (fields, tuple(self))) + return (_restore_object, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ _cached_namedtuples[fields] = cls return cls(*obj) @@ -395,13 +400,13 @@ def _create_object(cls, v): def _create_getter(dt, i): """ Create a getter for item `i` with schema """ - # TODO: cache created class cls = _create_cls(dt) def getter(self): return _create_object(cls, self[i]) return getter def _has_struct(dt): + """Return whether `dt` is or has StructType in it""" if isinstance(dt, StructType): return True elif isinstance(dt, ArrayType): @@ -416,22 +421,20 @@ def _create_cls(dataType): The created class is similar to namedtuple, but can have nested schema. """ - # this can not be in global - from pyspark.sql import _has_struct, _create_getter from operator import itemgetter - - # TODO: update to new DataType if isinstance(dataType, ArrayType): cls = _create_cls(dataType.elementType) class List(list): def __getitem__(self, i): + # create object with datetype return _create_object(cls, list.__getitem__(self, i)) def __repr__(self): + # call collect __repr__ for nested objects return "[%s]" % (", ".join(repr(self[i]) for i in range(len(self)))) def __reduce__(self): - # the nested struct can be reduced by itself + # pickle as dict, the nested struct can be reduced by itself return (list, (list(self),)) return List @@ -439,11 +442,14 @@ def __reduce__(self): vcls = _create_cls(dataType.valueType) class Dict(dict): def __getitem__(self, k): + # create object with datetype return _create_object(vcls, dict.__getitem__(self, k)) def __repr__(self): + # call collect __repr__ for nested objects return "{%s}" % (", ".join("%r: %r" % (k, self[k]) for k in self)) def __reduce__(self): + # pickle as dict, the nested struct can be reduced by itself return (dict, (dict(self),)) return Dict @@ -454,24 +460,24 @@ class Row(tuple): """ Row in SchemaRDD """ _fields = tuple(f.name for f in dataType.fields) + # create property for fast access # use local vars begins with "_" for _i,_f in enumerate(dataType.fields): if _has_struct(_f.dataType): - _getter = property(_create_getter(_f.dataType, _i)) + # delay creating object until accessing it + _getter = _create_getter(_f.dataType, _i) else: - _getter = property(itemgetter(_i)) - locals()[_f.name] = _getter + _getter = itemgetter(_i) + locals()[_f.name] = property(_getter) del _i, _f, _getter def __repr__(self): + # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) for n in self._fields)) - - def __str__(self): - return repr(self) - def __reduce__(self): - return (_restore_object, (self._fields, tuple(self))) + # pickle as namedtuple + return (_restore_object, ("Row", self._fields, tuple(self))) return Row From 7f6f2510f3c4bcc27e516eccbee28b9622b75f16 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 2014 17:20:26 -0700 Subject: [PATCH 07/24] address all comments --- python/pyspark/sql.py | 54 +++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 22004dd60ef0..122b4036feae 100644 --- 
a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -241,7 +241,7 @@ def __repr__(self): class StructType(DataType): """Spark SQL StructType - The data type representing namedtuple values. + The data type representing rows. A StructType object comprises a list of L{StructField}s. """ @@ -458,11 +458,11 @@ def __reduce__(self): class Row(tuple): """ Row in SchemaRDD """ - _fields = tuple(f.name for f in dataType.fields) + _FIELDS = tuple(f.name for f in dataType.fields) # create property for fast access # use local vars begins with "_" - for _i,_f in enumerate(dataType.fields): + for _i, _f in enumerate(dataType.fields): if _has_struct(_f.dataType): # delay creating object until accessing it _getter = _create_getter(_f.dataType, _i) @@ -474,10 +474,10 @@ class Row(tuple): def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self._fields)) + for n in self._FIELDS)) def __reduce__(self): # pickle as namedtuple - return (_restore_object, ("Row", self._fields, tuple(self))) + return (_restore_object, ("Row", self._FIELDS, tuple(self))) return Row @@ -645,18 +645,20 @@ def jsonFile(self, path, schema=None): >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() - [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ -Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ -Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] + >>> for r in srdd2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") - >>> srdd4.collect() - [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ -Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ -Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] + >>> for r in srdd4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -686,18 +688,20 @@ def jsonRDD(self, rdd, schema=None): >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") - >>> srdd2.collect() - [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ -Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ -Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] + >>> for r in srdd2.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( ... 
"SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") - >>> srdd4.collect() - [Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None), \ -Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]), \ -Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)] + >>> for r in srdd4.collect(): + ... print r + Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) + Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -795,7 +799,6 @@ class LocalHiveContext(HiveContext): An in-process metadata data is created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. - disable these tests tempory >>> import os >>> hiveCtx = LocalHiveContext(sc) >>> try: @@ -841,6 +844,10 @@ class SchemaRDD(RDD): implementation is an RDD composed of Java objects. Instead it is converted to a PythonRDD in the JVM, on which Python operations can be done. + + This class receives raw tuples from Java but assigns a class to it in + all its data-collection methods (mapPartitionsWithIndex, collect, take, + etc) so that PySpark sees them as Row objects with named fields. """ def __init__(self, jschema_rdd, sql_ctx): @@ -949,7 +956,8 @@ def collect(self): cls = _create_cls(self.schema()) return map(cls, rows) - # convert Row in JavaSchemaRDD into namedtuple, let access fields easier + # Convert each object in the RDD to a Row with the right class + # for this SchemaRDD, so that fields can be accessed as attributes. def mapPartitionsWithIndex(self, f, preservesPartitioning=False): """ Return a new RDD by applying a function to each partition of this RDD, From c4ddc3076e929c3cc60d81c3847d047bd2e505cd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 2014 17:34:37 -0700 Subject: [PATCH 08/24] fix conflict between name of fields and variables show an warning when name begins with `__` and ends with `__`. --- python/pyspark/sql.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 122b4036feae..2088bd0c3eef 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,6 +20,7 @@ import types import array import itertools +import warnings from operator import itemgetter from collections import namedtuple @@ -415,13 +416,30 @@ def _has_struct(dt): return _has_struct(dt.valueType) return False + +def _create_properties(fields): + """Create properties according to fields""" + ps = {} + for i, f in enumerate(fields): + name = f.name + if name.startswith("__") and name.endswith("__"): + warnings.warn("field name %s can not be accessed in Python," + "use position to access instead" % name) + continue + if _has_struct(f.dataType): + # delay creating object until accessing it + getter = _create_getter(f.dataType, i) + else: + getter = itemgetter(i) + ps[name] = property(getter) + return ps + def _create_cls(dataType): """ Create an class by dataType The created class is similar to namedtuple, but can have nested schema. 
""" - from operator import itemgetter if isinstance(dataType, ArrayType): cls = _create_cls(dataType.elementType) @@ -456,28 +474,22 @@ def __reduce__(self): elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) + + class Row(tuple): """ Row in SchemaRDD """ - _FIELDS = tuple(f.name for f in dataType.fields) + __FIELDS__ = tuple(f.name for f in dataType.fields) # create property for fast access - # use local vars begins with "_" - for _i, _f in enumerate(dataType.fields): - if _has_struct(_f.dataType): - # delay creating object until accessing it - _getter = _create_getter(_f.dataType, _i) - else: - _getter = itemgetter(_i) - locals()[_f.name] = property(_getter) - del _i, _f, _getter + locals().update(_create_properties(dataType.fields)) def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self._FIELDS)) + for n in self.__FIELDS__)) def __reduce__(self): # pickle as namedtuple - return (_restore_object, ("Row", self._FIELDS, tuple(self))) + return (_restore_object, ("Row", self.__FIELDS__, tuple(self))) return Row From b3559b4e023af04dbc7878045c889353e5020a66 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 30 Jul 2014 23:19:22 -0700 Subject: [PATCH 09/24] use generated Row instead of namedtuple so field's name can begin with "_" --- python/pyspark/sql.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 2088bd0c3eef..df5ef94bb2f5 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -22,7 +22,6 @@ import itertools import warnings from operator import itemgetter -from collections import namedtuple from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer @@ -69,6 +68,10 @@ class PrimitiveType(DataType): __metaclass__ = PrimitiveTypeSingleton + def __eq__(self, other): + # because they should be the same object + return self is other + class StringType(PrimitiveType): """Spark SQL StringType @@ -382,18 +385,19 @@ def _parse_datatype_string(datatype_string): -_cached_namedtuples = {} +_cached_cls = {} -def _restore_object(name, fields, obj): - """ Restore namedtuple object during unpickling. """ - cls = _cached_namedtuples.get(fields) +def _restore_object(fields, obj): + """ Restore object during unpickling. """ + cls = _cached_cls.get(fields) if cls is None: - cls = namedtuple(name, fields) - def __reduce__(self): - return (_restore_object, (name, fields, tuple(self))) - cls.__reduce__ = __reduce__ - _cached_namedtuples[fields] = cls - return cls(*obj) + # create a mock StructType, because nested StructType will + # be restored by itself + fs = [StructField(n, StringType, True) for n in fields] + dataType = StructType(fs) + cls = _create_cls(dataType) + _cached_cls[fields] = cls + return cls(obj) def _create_object(cls, v): """ Create an customized object with class `cls`. 
""" @@ -416,7 +420,6 @@ def _has_struct(dt): return _has_struct(dt.valueType) return False - def _create_properties(fields): """Create properties according to fields""" ps = {} @@ -474,8 +477,6 @@ def __reduce__(self): elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) - - class Row(tuple): """ Row in SchemaRDD """ __FIELDS__ = tuple(f.name for f in dataType.fields) @@ -487,9 +488,9 @@ def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) for n in self.__FIELDS__)) + def __reduce__(self): - # pickle as namedtuple - return (_restore_object, ("Row", self.__FIELDS__, tuple(self))) + return (_restore_object, (self.__FIELDS__, tuple(self))) return Row From 0eaaf560b7a6009c57de2be0030d7f830e842820 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 12:49:40 -0700 Subject: [PATCH 10/24] fix doc tests --- python/pyspark/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index df5ef94bb2f5..44b836f5291e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1062,7 +1062,7 @@ def _test(): # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import SQLContext - globs = globals().copy() + globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext('local[4]', 'PythonTest', batchSize=2) From 9d9af5595fc2fba688ea6619e9605e402971cfed Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 10:38:02 -0700 Subject: [PATCH 11/24] use arrry to applySchema and infer schema in Python --- .../apache/spark/api/python/PythonRDD.scala | 68 ++++--- python/pyspark/sql.py | 166 +++++++++++++++--- .../org/apache/spark/sql/SQLContext.scala | 32 ++-- 3 files changed, 198 insertions(+), 68 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a9d758bf998c..3abb9ffa3aa0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -25,7 +25,7 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio import scala.collection.JavaConversions._ import scala.language.existentials import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -536,25 +536,6 @@ private[spark] object PythonRDD extends Logging { file.close() } - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - * It is only used by pyspark.sql. - * TODO: Support more Python types. 
- */ - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - private def getMergedConf(confAsMap: java.util.HashMap[String, String], baseConf: Configuration): Configuration = { val conf = PythonHadoopUtil.mapToConf(confAsMap) @@ -701,6 +682,53 @@ private[spark] object PythonRDD extends Logging { } } + + /** + * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). + * This function is outdated, PySpark does not use it anymore + */ + def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { + pyRDD.rdd.mapPartitions { iter => + val unpickle = new Unpickler + iter.flatMap { row => + unpickle.loads(row) match { + // in case of objects are pickled in batch mode + case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) + // not in batch mode + case obj: JMap[String @unchecked, _] => Seq(obj.toMap) + } + } + } + } + + /** + * Convert an RDD of serialized Python tuple to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[_]] = { + pyRDD.rdd.mapPartitions { iter => + val unpickle = new Unpickler + iter.flatMap { row => + unpickle.loads(row) match { + // in case of objects are pickled in batch mode + case objs: JArrayList[_] => Try(objs.map(obj => obj match { + case list: JArrayList[_] => list.toArray // list + case obj if obj.getClass.isArray => // tuple + obj.asInstanceOf[Array[_]].toArray + })) match { + // objs is list of list or tuple + case Success(v) => v + // objs is a row, list of different objects + case Failure(e) => Seq(objs.toArray) + } + // not in batch mode + case obj if obj.getClass.isArray => // tuple + Seq(obj.asInstanceOf[Array[_]].toArray) + } + } + }.toJavaRDD() + } + /** * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by * PySpark. 
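
The two branches in `pythonToJava` above mirror the two payload shapes PySpark can send: with a `BatchedSerializer` each pickled blob is a list of row tuples (which the match sees as a `JArrayList`), otherwise it is a single tuple (the array case). A minimal Python-side sketch of those two shapes, for illustration only, using plain `pickle` rather than the real PySpark serializers:

    import pickle

    rows = [(1, u"row1"), (2, u"row2"), (3, u"row3")]

    # Batched mode: one blob carries a whole list of row tuples.
    batched_payload = pickle.dumps(rows)
    assert pickle.loads(batched_payload) == rows

    # Unbatched mode: each blob carries exactly one row tuple.
    single_payload = pickle.dumps(rows[0])
    assert pickle.loads(single_payload) == (1, u"row1")
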
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3c61feb8f98c..b1c963ccfa7f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -21,6 +21,8 @@ import array import itertools import warnings +import decimal +import datetime from operator import itemgetter import warnings @@ -39,11 +41,11 @@ class DataType(object): """Spark SQL DataType""" - def __repr__(self): + def __str__(self): return self.__class__.__name__ def __hash__(self): - return hash(repr(self)) + return hash(str(self)) def __eq__(self, other): return (isinstance(other, self.__class__) and @@ -175,8 +177,8 @@ def __init__(self, elementType, containsNull=False): self.elementType = elementType self.containsNull = containsNull - def __repr__(self): - return "ArrayType(%r,%s)" % (self.elementType, + def __str__(self): + return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) @@ -208,8 +210,8 @@ def __init__(self, keyType, valueType, valueContainsNull=True): self.valueType = valueType self.valueContainsNull = valueContainsNull - def __repr__(self): - return "MapType(%r,%r,%s)" % (self.keyType, self.valueType, + def __str__(self): + return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) @@ -238,8 +240,8 @@ def __init__(self, name, dataType, nullable): self.dataType = dataType self.nullable = nullable - def __repr__(self): - return "StructField(%s,%r,%s)" % (self.name, self.dataType, + def __str__(self): + return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) @@ -265,9 +267,9 @@ def __init__(self, fields): """ self.fields = fields - def __repr__(self): + def __str__(self): return ("StructType(List(%s))" % - ",".join(repr(field) for field in self.fields)) + ",".join(str(field) for field in self.fields)) def _parse_datatype_list(datatype_list_string): @@ -302,7 +304,7 @@ def _parse_datatype_string(datatype_string): """Parses the given data type string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) ... python_datatype = _parse_datatype_string(scala_datatype.toString()) ... 
return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) @@ -385,6 +387,112 @@ def _parse_datatype_string(datatype_string): return StructType(fields) +# Mapping Python types to Spark SQL DateType +_type_mappings = { + bool: BooleanType, + int: IntegerType, + long: LongType, + float: DoubleType, + str: StringType, + unicode: StringType, + decimal.Decimal: DecimalType, + datetime.datetime: TimestampType, + datetime.date: TimestampType, + datetime.time: TimestampType, +} + +def _inferType(obj): + """Infer the DataType from obj""" + if obj is None: + raise ValueError("Can not infer type for None") + + dataType = _type_mappings.get(type(obj)) + if dataType is not None: + return dataType() + + if isinstance(obj, dict): + if not obj: + raise ValueError("Can not infer type for empty dict") + key, value = obj.iteritems().next() + return MapType(_inferType(key), _inferType(value), True) + elif isinstance(obj, (list, array.array)): + if not obj: + raise ValueError("Can not infer type for empty list/array") + return ArrayType(_inferType(obj[0]), True) + else: + try: + return _inferSchema(obj) + except ValueError: + raise ValueError("not supported type: %s" % type(obj)) + +def _inferSchema(row): + """Infer the schema from dict/namedtuple/object""" + if isinstance(row, dict): + items = sorted(row.items()) + elif isinstance(row, tuple): + if hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) + elif all(isinstance(x, tuple) and len(x) == 2 + for x in row): + items = row + elif hasattr(row, "__dict__"): # object + items = sorted(row.__dict__.items()) + else: + raise ValueError("Can not infer schema for type: %s" % type(row)) + + fields = [StructField(k, _inferType(v), True) for k, v in items] + return StructType(fields) + +def _create_converter(obj, dataType): + """Create an converter to drop the names of fields in obj """ + if not _has_struct(dataType): + return lambda x: x + + elif isinstance(dataType, ArrayType): + conv = _create_converter(obj[0], dataType.elementType) + return lambda row: map(conv, row) + + elif isinstance(dataType, MapType): + value = obj.values()[0] + conv = _create_converter(value, dataType.valueType) + return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + + # dataType must be StructType + names = [f.name for f in dataType.fields] + + if isinstance(obj, dict): + conv = lambda o: tuple(o.get(n) for n in names) + + elif isinstance(obj, tuple): + if hasattr(obj, "_fields"): # namedtuple + conv = tuple + elif all(isinstance(x, tuple) and len(x) == 2 + for x in obj): + conv = lambda o: tuple(v for k, v in o) + + elif hasattr(obj, "__dict__"): # object + conv = lambda o: [o.__dict__.get(n, None) for n in names] + + nested = any(_has_struct(f.dataType) for f in dataType.fields) + if not nested: + return conv + + row = conv(obj) + convs = [_create_converter(v, f.dataType) + for v, f in zip(row, dataType.fields)] + def nested_conv(row): + return tuple(f(v) for f, v in zip(convs, conv(row))) + return nested_conv + +def _dropSchema(rows, schema): + """Drop all the names of fields, becoming tuples""" + iterator = iter(rows) + row = iterator.next() + converter = _create_converter(row, schema) + yield converter(row) + for i in iterator: + yield converter(i) + _cached_cls = {} @@ -532,7 +640,7 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap + self._pythonToJava = 
self._jvm.PythonRDD.pythonToJava if sqlContext: self._scala_SQLContext = sqlContext @@ -563,20 +671,24 @@ def inferSchema(self, rdd): >>> from array import array >>> srdd = sqlCtx.inferSchema(nestedRdd1) >>> srdd.collect() - [Row(f2={u'row1': 1.0}, f1=[1, 2]), Row(f2={u'row2': 2.0}, f1=[2, 3])] + [Row(f1=[1, 2], f2={u'row1': 1.0}), Row(f1=[2, 3], f2={u'row2': 2.0})] >>> srdd = sqlCtx.inferSchema(nestedRdd2) >>> srdd.collect() - [Row(f2=[1, 2], f1=[[1, 2], [2, 3]]), Row(f2=[2, 3], f1=[[2, 3], [3, 4]])] + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), Row(f1=[[2, 3], [3, 4]], f2=[2, 3])] """ if (rdd.__class__ is SchemaRDD): raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) - elif not isinstance(rdd.first(), dict): - raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" % - (SchemaRDD.__name__, rdd.first())) - jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.inferSchema(jrdd.rdd()) + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, can not infer schema") + + schema = _inferSchema(first) + rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema)) + + jrdd = self._pythonToJava(rdd._jrdd) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) return SchemaRDD(srdd, self) def applySchema(self, rdd, schema): @@ -584,15 +696,14 @@ def applySchema(self, rdd, schema): >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> srdd = sqlCtx.applySchema(rdd2, schema) >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT * from table1") >>> srdd2.collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2'), Row(field1=3, field2=u'row3')] >>> from datetime import datetime - >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0, - ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2}, - ... "list": [1, 2, 3]}]) + >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, {"b": 2}, [1, 2, 3], None)]) >>> schema = StructType([ ... StructField("byte", ByteType(), False), ... 
StructField("short", ShortType(), False), @@ -608,8 +719,8 @@ def applySchema(self, rdd, schema): >>> srdd.collect()[0] (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) """ - jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__()) + jrdd = self._pythonToJava(rdd._jrdd) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) return SchemaRDD(srdd, self) def registerRDDAsTable(self, rdd, tableName): @@ -688,7 +799,7 @@ def jsonFile(self, path, schema=None): if schema is None: jschema_rdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + scala_datatype = self._ssql_ctx.parseDataType(str(schema)) jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) @@ -739,7 +850,7 @@ def func(iterator): if schema is None: jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) + scala_datatype = self._ssql_ctx.parseDataType(str(schema)) jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) @@ -1078,6 +1189,7 @@ def _test(): {"field1": 2, "field2": "row2"}, {"field1": 3, "field2": "row3"}] ) + globs['rdd2'] = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bd358f8bf72..eac33abb69a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -413,7 +413,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Peek at the first row of the RDD and infer its schema. - * It is only used by PySpark. + * This function is outdated, PySpark does not use it anymore */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { import scala.collection.JavaConversions._ @@ -437,7 +437,10 @@ class SQLContext(@transient val sparkContext: SparkContext) case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true) }.toSeq - applySchemaToPythonRDD(rdd, StructType(fields)) + val arrayRdd = rdd.map { + m => fields.map { field => m.getOrElse(field.name, null) }.toArray + } + applySchemaToPythonRDD(arrayRdd, StructType(fields)) } /** @@ -454,7 +457,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. */ private[sql] def applySchemaToPythonRDD( - rdd: RDD[Map[String, _]], + rdd: RDD[Array[Any]], schemaString: String): SchemaRDD = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) @@ -464,10 +467,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * Apply a schema defined by the schema to an RDD. It is only used by PySpark. */ private[sql] def applySchemaToPythonRDD( - rdd: RDD[Map[String, _]], + rdd: RDD[Array[Any]], schema: StructType): SchemaRDD = { - // TODO: We should have a better implementation once we do not turn a Python side record - // to a Map. 
import scala.collection.JavaConversions._ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} @@ -520,26 +521,15 @@ class SQLContext(@transient val sparkContext: SparkContext) } val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { - rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) }) + rdd.map(m => m.zip(schema.fields).map { + case (value, field) => convert(value, field.dataType) + }) } else { rdd } val rowRdd = convertedRdd.mapPartitions { iter => - val row = new GenericMutableRow(schema.fields.length) - val fieldsWithIndex = schema.fields.zipWithIndex - iter.map { m => - // We cannot use m.values because the order of values returned by m.values may not - // match fields order. - fieldsWithIndex.foreach { - case (field, i) => - val value = - m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull - row.update(i, value) - } - - row: Row - } + iter.map { m => new GenericRow(m): Row} } new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) From f5df97ff4d68b89a3c86907449b28a2bc5322e57 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 14:32:09 -0700 Subject: [PATCH 12/24] refactor, address comments --- .../apache/spark/api/python/PythonRDD.scala | 32 +++++++++---------- python/pyspark/sql.py | 10 +++--- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 3abb9ffa3aa0..9b0ccef45259 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -705,25 +705,25 @@ private[spark] object PythonRDD extends Logging { * Convert an RDD of serialized Python tuple to Array (no recursive conversions). * It is only used by pyspark.sql. 
*/ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[_]] = { + def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { + + def toArray(obj: Any): Array[_] = { + obj match { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + } + } + pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[_] => Try(objs.map(obj => obj match { - case list: JArrayList[_] => list.toArray // list - case obj if obj.getClass.isArray => // tuple - obj.asInstanceOf[Array[_]].toArray - })) match { - // objs is list of list or tuple - case Success(v) => v - // objs is a row, list of different objects - case Failure(e) => Seq(objs.toArray) - } - // not in batch mode - case obj if obj.getClass.isArray => // tuple - Seq(obj.asInstanceOf[Array[_]].toArray) + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].map(toArray) + } else { + Seq(toArray(obj)) } } }.toJavaRDD() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b1c963ccfa7f..b221856b1932 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -640,7 +640,7 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJava + self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray if sqlContext: self._scala_SQLContext = sqlContext @@ -686,10 +686,7 @@ def inferSchema(self, rdd): schema = _inferSchema(first) rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema)) - - jrdd = self._pythonToJava(rdd._jrdd) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) - return SchemaRDD(srdd, self) + return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): """Applies the given schema to the given RDD of L{dict}s. 
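
To make the reworked inferSchema path concrete: the schema is inferred from the first row only, and every row is then reduced to a bare tuple in field order before being handed to applySchema. A simplified, self-contained sketch of that flattening idea (not the actual `_infer_schema`/`_dropSchema` helpers, just the behaviour they implement for namedtuple rows):

    from collections import namedtuple

    # Hypothetical record type used only for this illustration.
    Person = namedtuple("Person", ["name", "age"])
    rows = [Person(u"alice", 1), Person(u"bob", 2)]

    # Field names are taken from the first row; they define the schema once.
    field_names = rows[0]._fields          # ('name', 'age')

    # After that, only positional values are shipped to the JVM.
    flattened = [tuple(r) for r in rows]   # [(u'alice', 1), (u'bob', 2)]
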
@@ -719,7 +716,8 @@ def applySchema(self, rdd, schema): >>> srdd.collect()[0] (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) """ - jrdd = self._pythonToJava(rdd._jrdd) + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) + jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) return SchemaRDD(srdd, self) From 9d8447c19a65ddb6110d7d2d0726346a845fe082 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 16:16:39 -0700 Subject: [PATCH 13/24] apply schema provided by string of names the type of fields will be infered automatically --- python/pyspark/sql.py | 168 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 158 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b221856b1932..c5b8717fe84e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -42,6 +42,9 @@ class DataType(object): """Spark SQL DataType""" def __str__(self): + return repr(self) + + def __repr__(self): return self.__class__.__name__ def __hash__(self): @@ -210,7 +213,7 @@ def __init__(self, keyType, valueType, valueContainsNull=True): self.valueType = valueType self.valueContainsNull = valueContainsNull - def __str__(self): + def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) @@ -240,7 +243,7 @@ def __init__(self, name, dataType, nullable): self.dataType = dataType self.nullable = nullable - def __str__(self): + def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) @@ -267,7 +270,7 @@ def __init__(self, fields): """ self.fields = fields - def __str__(self): + def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) @@ -401,7 +404,7 @@ def _parse_datatype_string(datatype_string): datetime.time: TimestampType, } -def _inferType(obj): +def _infer_type(obj): """Infer the DataType from obj""" if obj is None: raise ValueError("Can not infer type for None") @@ -414,18 +417,18 @@ def _inferType(obj): if not obj: raise ValueError("Can not infer type for empty dict") key, value = obj.iteritems().next() - return MapType(_inferType(key), _inferType(value), True) + return MapType(_infer_type(key), _infer_type(value), True) elif isinstance(obj, (list, array.array)): if not obj: raise ValueError("Can not infer type for empty list/array") - return ArrayType(_inferType(obj[0]), True) + return ArrayType(_infer_type(obj[0]), True) else: try: - return _inferSchema(obj) + return _infer_schema(obj) except ValueError: raise ValueError("not supported type: %s" % type(obj)) -def _inferSchema(row): +def _infer_schema(row): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -440,7 +443,7 @@ def _inferSchema(row): else: raise ValueError("Can not infer schema for type: %s" % type(row)) - fields = [StructField(k, _inferType(v), True) for k, v in items] + fields = [StructField(k, _infer_type(v), True) for k, v in items] return StructType(fields) def _create_converter(obj, dataType): @@ -494,6 +497,133 @@ def _dropSchema(rows, schema): yield converter(i) +_BRAKETS = {'(':')', '[':']', '{':'}'} + +def _split_schema_abstract(s): + """ + split the schema abstract into fields + + >>> _split_schema_abstract("a b c") + ['a', 'b', 'c'] + >>> _split_schema_abstract("a(a b)") + ['a(a b)'] + >>> _split_schema_abstract("a b[] c{a b}") + ['a', 'b[]', 'c{a b}'] + >>> 
_split_schema_abstract(" ") + [] + """ + + r = [] + w = '' + brackets = [] + for c in s: + if c == ' ' and not brackets: + if w: + r.append(w) + w = '' + else: + w += c + if c in _BRAKETS: + brackets.append(c) + elif c in _BRAKETS.values(): + if not brackets or c != _BRAKETS[brackets.pop()]: + raise ValueError("unexpected " + c) + + if brackets: + raise ValueError("brackets not closed: %s" % brackets) + if w: + r.append(w) + return r + +def _parse_field_abstract(s): + """ + Parse a field in schema abstract + + >>> _parse_field_abstract("a") + StructField(a,None,true) + >>> _parse_field_abstract("b(c d)") + StructField(b,StructType(List(StructField(c,None,true),StructField(d,None,true))),true) + >>> _parse_field_abstract("a[]") + StructField(a,ArrayType(None,true),true) + >>> _parse_field_abstract("a{[]}") + StructField(a,MapType(None,ArrayType(None,true),true),true) + """ + if set(_BRAKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRAKETS if c in s)) + name = s[:idx] + return StructField(name, _parse_schema_abstract(s[idx:]), True) + else: + return StructField(s, None, True) + +def _parse_schema_abstract(s): + """ + parse abstract into schema + + >>> _parse_schema_abstract("a b c") + StructType...a...b...c... + >>> _parse_schema_abstract("a[b c] b{}") + StructType...a,ArrayType...b...c...b,MapType... + >>> _parse_schema_abstract("c{} d{a b}") + StructType...c,MapType...d,MapType...a...b... + >>> _parse_schema_abstract("a b(t)").fields[1] + StructField(b,StructType(List(StructField(t,None,true))),true) + """ + s = s.strip() + if not s: + return + + elif s.startswith('('): + return _parse_schema_abstract(s[1:-1]) + + elif s.startswith('['): + return ArrayType(_parse_schema_abstract(s[1:-1]), True) + + elif s.startswith('{'): + return MapType(None, _parse_schema_abstract(s[1:-1])) + + parts = _split_schema_abstract(s) + fields = [_parse_field_abstract(p) for p in parts] + return StructType(fields) + +def _infer_schema_type(obj, dataType): + """ + Fill the dataType with types infered from obj + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> _infer_schema_type(row, schema) + StructType...IntegerType...DoubleType...StringType... + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> _infer_schema_type(row, schema) + StructType...a,ArrayType...b,MapType(StringType,StructType...c,IntegerType... 
+ """ + if dataType is None: + return _infer_type(obj) + + if not obj: + raise ValueError("Can not infer type from empty value") + + if isinstance(dataType, ArrayType): + eType = _infer_schema_type(obj[0], dataType.elementType) + return ArrayType(eType, True) + + elif isinstance(dataType, MapType): + k, v = obj.iteritems().next() + return MapType(_infer_type(k), + _infer_schema_type(v, dataType.valueType)) + + elif isinstance(dataType, StructType): + fs = dataType.fields + assert len(fs) == len(obj), "Obj(%s) have different length with fields(%s)" % (obj, fs) + fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) + for o, f in zip(obj, fs)] + return StructType(fields) + + else: + raise ValueError("Unexpected dataType: %s" % dataType) + + _cached_cls = {} def _restore_object(fields, obj): @@ -684,7 +814,7 @@ def inferSchema(self, rdd): if not first: raise ValueError("The first row in RDD is empty, can not infer schema") - schema = _inferSchema(first) + schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema)) return self.applySchema(rdd, schema) @@ -698,6 +828,7 @@ def applySchema(self, rdd, schema): >>> srdd2 = sqlCtx.sql("SELECT * from table1") >>> srdd2.collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2'), Row(field1=3, field2=u'row3')] + >>> from datetime import datetime >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, {"b": 2}, [1, 2, 3], None)]) @@ -715,7 +846,24 @@ def applySchema(self, rdd, schema): ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> srdd.collect()[0] (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, {"b": 2}, [1, 2, 3])]) + >>> schema = "byte short float time map{} struct(b) list[]" + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> srdd.collect() + [Row(byte=127, short=-32768, float=1.0, time=..., struct=Row(b=2), list=[1, 2, 3])] + """ + + first = rdd.first() + if not isinstance(first, (tuple, list)): + raise ValueError("Can not apply schema to type: %s" % type(first)) + + if isinstance(schema, basestring): + schema = _parse_schema_abstract(schema) + schema = _infer_schema_type(first, schema) + batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) From 6b258b5ea52b4f9c6ac50c9d8eeb15781ab1d209 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 17:15:06 -0700 Subject: [PATCH 14/24] fix pep8 --- python/pyspark/sql.py | 290 ++++++++++++++++++++++++++++-------------- 1 file changed, 196 insertions(+), 94 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c5b8717fe84e..dfa9a9db4d44 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -35,7 +35,8 @@ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", + "SchemaRDD", "Row"] class DataType(object): @@ -145,8 +146,9 @@ class IntegerType(PrimitiveType): class LongType(PrimitiveType): """Spark SQL LongType - The data type representing long values. 
If the any value is beyond the range of - [-9223372036854775808, 9223372036854775807], please use DecimalType. + The data type representing long values. If the any value is + beyond the range of [-9223372036854775808, 9223372036854775807], + please use DecimalType. """ @@ -160,12 +162,13 @@ class ShortType(PrimitiveType): class ArrayType(DataType): """Spark SQL ArrayType - The data type representing list values. - An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool). + The data type representing list values. An ArrayType object + comprises two fields, elementType (a DataType) and containsNull (a bool). The field of elementType is used to specify the type of array elements. The field of containsNull is used to specify if the array has None values. """ + def __init__(self, elementType, containsNull=False): """Creates an ArrayType @@ -185,28 +188,34 @@ def __str__(self): str(self.containsNull).lower()) - class MapType(DataType): """Spark SQL MapType - The data type representing dict values. - A MapType object comprises three fields, - keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool). + The data type representing dict values. A MapType object comprises + three fields, keyType (a DataType), valueType (a DataType) and + valueContainsNull (a bool). + The field of keyType is used to specify the type of keys in the map. The field of valueType is used to specify the type of values in the map. - The field of valueContainsNull is used to specify if values of this map has None values. + The field of valueContainsNull is used to specify if values of this + map has None values. + For values of a MapType column, keys are not allowed to have None values. """ + def __init__(self, keyType, valueType, valueContainsNull=True): """Creates a MapType :param keyType: the data type of keys. :param valueType: the data type of values. - :param valueContainsNull: indicates whether values contains null values. + :param valueContainsNull: indicates whether values contains + null values. - >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True) + >>> (MapType(StringType, IntegerType) + ... == MapType(StringType, IntegerType, True)) True - >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType) + >>> (MapType(StringType, IntegerType, False) + ... == MapType(StringType, FloatType)) False """ self.keyType = keyType @@ -222,21 +231,28 @@ class StructField(DataType): """Spark SQL StructField Represents a field in a StructType. - A StructField object comprises three fields, name (a string), dataType (a DataType), - and nullable (a bool). The field of name is the name of a StructField. The field of - dataType specifies the data type of a StructField. - The field of nullable specifies if values of a StructField can contain None values. + A StructField object comprises three fields, name (a string), + dataType (a DataType) and nullable (a bool). The field of name + is the name of a StructField. The field of dataType specifies + the data type of a StructField. + + The field of nullable specifies if values of a StructField can + contain None values. """ + def __init__(self, name, dataType, nullable): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. - :param nullable: indicates whether values of this field can be null. + :param nullable: indicates whether values of this field + can be null. 
- >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) + >>> (StructField("f1", StringType, True) + ... == StructField("f1", StringType, True)) True - >>> StructField("f1", StringType, True) == StructField("f2", StringType, True) + >>> (StructField("f1", StringType, True) + ... == StructField("f2", StringType, True)) False """ self.name = name @@ -255,6 +271,7 @@ class StructType(DataType): A StructType object comprises a list of L{StructField}s. """ + def __init__(self, fields): """Creates a StructType @@ -308,7 +325,8 @@ def _parse_datatype_string(datatype_string): >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string(scala_datatype.toString()) + ... python_datatype = _parse_datatype_string( + ... scala_datatype.toString()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -341,7 +359,8 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_arraytype) True >>> # Complex MapType. - >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False) + >>> complex_maptype = MapType(complex_structtype, + ... complex_arraytype, False) >>> check_datatype(complex_maptype) True """ @@ -350,7 +369,7 @@ def _parse_datatype_string(datatype_string): # It is a primitive type. index = len(datatype_string) type_or_field = datatype_string[:index] - rest_part = datatype_string[index+1:len(datatype_string)-1].strip() + rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() if type_or_field in _all_primitive_types: return _all_primitive_types[type_or_field]() @@ -358,17 +377,19 @@ def _parse_datatype_string(datatype_string): elif type_or_field == "ArrayType": last_comma_index = rest_part.rfind(",") containsNull = True - if rest_part[last_comma_index+1:].strip().lower() == "false": + if rest_part[last_comma_index + 1:].strip().lower() == "false": containsNull = False - elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) + elementType = _parse_datatype_string( + rest_part[:last_comma_index].strip()) return ArrayType(elementType, containsNull) elif type_or_field == "MapType": last_comma_index = rest_part.rfind(",") valueContainsNull = True - if rest_part[last_comma_index+1:].strip().lower() == "false": + if rest_part[last_comma_index + 1:].strip().lower() == "false": valueContainsNull = False - keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip()) + keyType, valueType = _parse_datatype_list( + rest_part[:last_comma_index].strip()) return MapType(keyType, valueType, valueContainsNull) elif type_or_field == "StructField": @@ -376,16 +397,16 @@ def _parse_datatype_string(datatype_string): name = rest_part[:first_comma_index].strip() last_comma_index = rest_part.rfind(",") nullable = True - if rest_part[last_comma_index+1:].strip().lower() == "false": + if rest_part[last_comma_index + 1:].strip().lower() == "false": nullable = False dataType = _parse_datatype_string( - rest_part[first_comma_index+1:last_comma_index].strip()) + rest_part[first_comma_index + 1:last_comma_index].strip()) return StructField(name, dataType, nullable) elif type_or_field == "StructType": # rest_part should be in the format like # List(StructField(field1,IntegerType,false)). 
- field_list_string = rest_part[rest_part.find("(")+1:-1] + field_list_string = rest_part[rest_part.find("(") + 1:-1] fields = _parse_datatype_list(field_list_string) return StructType(fields) @@ -404,6 +425,7 @@ def _parse_datatype_string(datatype_string): datetime.time: TimestampType, } + def _infer_type(obj): """Infer the DataType from obj""" if obj is None: @@ -428,6 +450,7 @@ def _infer_type(obj): except ValueError: raise ValueError("not supported type: %s" % type(obj)) + def _infer_schema(row): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): @@ -446,6 +469,7 @@ def _infer_schema(row): fields = [StructField(k, _infer_type(v), True) for k, v in items] return StructType(fields) + def _create_converter(obj, dataType): """Create an converter to drop the names of fields in obj """ if not _has_struct(dataType): @@ -483,10 +507,13 @@ def _create_converter(obj, dataType): row = conv(obj) convs = [_create_converter(v, f.dataType) for v, f in zip(row, dataType.fields)] + def nested_conv(row): return tuple(f(v) for f, v in zip(convs, conv(row))) + return nested_conv + def _dropSchema(rows, schema): """Drop all the names of fields, becoming tuples""" iterator = iter(rows) @@ -497,7 +524,8 @@ def _dropSchema(rows, schema): yield converter(i) -_BRAKETS = {'(':')', '[':']', '{':'}'} +_BRAKETS = {'(': ')', '[': ']', '{': '}'} + def _split_schema_abstract(s): """ @@ -535,6 +563,7 @@ def _split_schema_abstract(s): r.append(w) return r + def _parse_field_abstract(s): """ Parse a field in schema abstract @@ -542,7 +571,7 @@ def _parse_field_abstract(s): >>> _parse_field_abstract("a") StructField(a,None,true) >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(List(StructField(c,None,true),StructField(d,None,true))),true) + StructField(b,StructType(...c,None,true),StructField(d... >>> _parse_field_abstract("a[]") StructField(a,ArrayType(None,true),true) >>> _parse_field_abstract("a{[]}") @@ -555,6 +584,7 @@ def _parse_field_abstract(s): else: return StructField(s, None, True) + def _parse_schema_abstract(s): """ parse abstract into schema @@ -585,6 +615,7 @@ def _parse_schema_abstract(s): fields = [_parse_field_abstract(p) for p in parts] return StructType(fields) + def _infer_schema_type(obj, dataType): """ Fill the dataType with types infered from obj @@ -596,7 +627,7 @@ def _infer_schema_type(obj, dataType): >>> row = [[1], {"key": (1, 2.0)}] >>> schema = _parse_schema_abstract("a[] b{c d}") >>> _infer_schema_type(row, schema) - StructType...a,ArrayType...b,MapType(StringType,StructType...c,IntegerType... + StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... """ if dataType is None: return _infer_type(obj) @@ -615,7 +646,8 @@ def _infer_schema_type(obj, dataType): elif isinstance(dataType, StructType): fs = dataType.fields - assert len(fs) == len(obj), "Obj(%s) have different length with fields(%s)" % (obj, fs) + assert len(fs) == len(obj), \ + "Obj(%s) have different length with fields(%s)" % (obj, fs) fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) for o, f in zip(obj, fs)] return StructType(fields) @@ -626,6 +658,7 @@ def _infer_schema_type(obj, dataType): _cached_cls = {} + def _restore_object(fields, obj): """ Restore object during unpickling. 
""" cls = _cached_cls.get(fields) @@ -633,22 +666,27 @@ def _restore_object(fields, obj): # create a mock StructType, because nested StructType will # be restored by itself fs = [StructField(n, StringType, True) for n in fields] - dataType = StructType(fs) + dataType = StructType(fs) cls = _create_cls(dataType) _cached_cls[fields] = cls return cls(obj) + def _create_object(cls, v): """ Create an customized object with class `cls`. """ return cls(v) if v is not None else v + def _create_getter(dt, i): """ Create a getter for item `i` with schema """ cls = _create_cls(dt) + def getter(self): return _create_object(cls, self[i]) + return getter + def _has_struct(dt): """Return whether `dt` is or has StructType in it""" if isinstance(dt, StructType): @@ -659,6 +697,7 @@ def _has_struct(dt): return _has_struct(dt.valueType) return False + def _create_properties(fields): """Create properties according to fields""" ps = {} @@ -676,6 +715,7 @@ def _create_properties(fields): ps[name] = property(getter) return ps + def _create_cls(dataType): """ Create an class by dataType @@ -685,32 +725,42 @@ def _create_cls(dataType): if isinstance(dataType, ArrayType): cls = _create_cls(dataType.elementType) + class List(list): + def __getitem__(self, i): # create object with datetype return _create_object(cls, list.__getitem__(self, i)) + def __repr__(self): # call collect __repr__ for nested objects return "[%s]" % (", ".join(repr(self[i]) for i in range(len(self)))) + + # pickle as dict, the nested struct can be reduced by itself def __reduce__(self): - # pickle as dict, the nested struct can be reduced by itself return (list, (list(self),)) + return List elif isinstance(dataType, MapType): vcls = _create_cls(dataType.valueType) + class Dict(dict): + def __getitem__(self, k): # create object with datetype return _create_object(vcls, dict.__getitem__(self, k)) + def __repr__(self): # call collect __repr__ for nested objects return "{%s}" % (", ".join("%r: %r" % (k, self[k]) for k in self)) + + # pickle as dict, the nested struct can be reduced by itself def __reduce__(self): - # pickle as dict, the nested struct can be reduced by itself return (dict, (dict(self),)) + return Dict elif not isinstance(dataType, StructType): @@ -718,6 +768,7 @@ def __reduce__(self): class Row(tuple): """ Row in SchemaRDD """ + __FIELDS__ = tuple(f.name for f in dataType.fields) # create property for fast access @@ -759,13 +810,13 @@ def __init__(self, sparkContext, sqlContext=None): ValueError:... >>> from datetime import datetime - >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L, - ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1}, - ... "list": [1, 2, 3]}]) - >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long, - ... x.boolean, x.time, x.dict["a"], x.list)) + >>> allTypes = sc.parallelize([{"int": 1, "string": "string", + ... "double": 1.0, "long": 1L, "boolean": True, "list": [1, 2, 3], + ... "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},}]) + >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, + ... 
x.double, x.long, x.boolean, x.time, x.dict["a"], x.list)) >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) + (1, u'string', 1.0, 1, True, ...(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -801,25 +852,40 @@ def inferSchema(self, rdd): >>> from array import array >>> srdd = sqlCtx.inferSchema(nestedRdd1) >>> srdd.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), Row(f1=[2, 3], f2={u'row2': 2.0})] + [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] >>> srdd = sqlCtx.inferSchema(nestedRdd2) >>> srdd.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), Row(f1=[[2, 3], [3, 4]], f2=[2, 3])] + [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] """ if (rdd.__class__ is SchemaRDD): raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) first = rdd.first() if not first: - raise ValueError("The first row in RDD is empty, can not infer schema") + raise ValueError("The first row in RDD is empty, " + "can not infer schema") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema)) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): - """Applies the given schema to the given RDD of L{dict}s. + """ + Applies the given schema to the given RDD of L{tuple} or L{list}s. + + + The schema could be a StructType or string, such as "name value". + The schema can have nested struct (struct, list, map). + + If schema is a string, the fields are seperated by space. + Each field can be followed by composit type immediately + (without space), for example: + + "name address(city zipcode) items[] props{key value}" + + which will be filled with infered datetype from first row, so you + not have empty value in the first row. >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) @@ -827,32 +893,36 @@ def applySchema(self, rdd, schema): >>> sqlCtx.registerRDDAsTable(srdd, "table1") >>> srdd2 = sqlCtx.sql("SELECT * from table1") >>> srdd2.collect() - [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2'), Row(field1=3, field2=u'row3')] + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import datetime - >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, {"b": 2}, [1, 2, 3], None)]) + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, {"b": 2}, [1, 2, 3], None)]) >>> schema = StructType([ ... StructField("byte", ByteType(), False), ... StructField("short", ShortType(), False), ... StructField("float", FloatType(), False), ... StructField("time", TimestampType(), False), - ... StructField("map", MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False), + ... StructField("map", + ... MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", + ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) >>> srdd = sqlCtx.applySchema(rdd, schema).map( - ... lambda x: ( - ... x.byte, x.short, x.float, x.time, x.map["a"], x.struct.b, x.list, x.null)) + ... lambda x: (x.byte, x.short, x.float, x.time, + ... 
x.map["a"], x.struct.b, x.list, x.null)) >>> srdd.collect()[0] - (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - >>> rdd = sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, {"b": 2}, [1, 2, 3])]) + >>> rdd = sc.parallelize([(127, -32768, 1.0, + ... datetime(2010, 1, 1, 1, 1, 1), + ... {"a": 1}, {"b": 2}, [1, 2, 3])]) >>> schema = "byte short float time map{} struct(b) list[]" >>> srdd = sqlCtx.applySchema(rdd, schema) >>> srdd.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., struct=Row(b=2), list=[1, 2, 3])] + [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ @@ -900,10 +970,15 @@ def parquetFile(self, path): return SchemaRDD(jschema_rdd, self) def jsonFile(self, path, schema=None): - """Loads a text file storing one JSON object per line as a L{SchemaRDD}. + """ + Loads a text file storing one JSON object per line as a + L{SchemaRDD}. + + If the schema is provided, applies the given schema to this + JSON dataset. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine the schema. + Otherwise, it goes through the entire dataset once to determine + the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -915,32 +990,36 @@ def jsonFile(self, path, schema=None): >>> srdd1 = sqlCtx.jsonFile(jsonFile) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") >>> for r in srdd2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") >>> for r in srdd4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", ... StructType([ - ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) >>> sqlCtx.registerRDDAsTable(srdd5, "table3") >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... 
"field3.field5[0] as f3 from table3") >>> srdd6.collect() - [Row(f1=u'row1', f2=None, f3=None), Row(f1=None, f2=[10, 11], f3=10), Row(f1=u'row3', f2=[], f3=None)] + [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: jschema_rdd = self._ssql_ctx.jsonFile(path) @@ -952,39 +1031,47 @@ def jsonFile(self, path, schema=None): def jsonRDD(self, rdd, schema=None): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine the schema. + If the schema is provided, applies the given schema to this + JSON dataset. + + Otherwise, it goes through the entire dataset once to determine + the schema. >>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table1") >>> for r in srdd2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " + ... "field6 as f4 from table2") >>> for r in srdd4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22, field5=[10, 11]), f4=[Row(field7=u'row2')]) + Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", ... StructType([ - ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + ... StructField("field5", + ... ArrayType(IntegerType(), False), True)]), False)]) >>> srdd5 = sqlCtx.jsonRDD(json, schema) >>> sqlCtx.registerRDDAsTable(srdd5, "table3") >>> srdd6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + ... "SELECT field2 AS f1, field3.field5 as f2, " + ... "field3.field5[0] as f3 from table3") >>> srdd6.collect() - [Row(f1=u'row1', f2=None, f3=None), Row(f1=None, f2=[10, 11], f3=10), Row(f1=u'row3', f2=[], f3=None)] + [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] """ + def func(iterator): for x in iterator: if not isinstance(x, basestring): @@ -1045,7 +1132,8 @@ def _ssql_ctx(self): self._scala_HiveContext = self._get_hive_ctx() return self._scala_HiveContext except Py4JError as e: - raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " + raise Exception("You must build Spark with Hive. " + "Export 'SPARK_HIVE=true' and run " "sbt/sbt assembly", e) def _get_hive_ctx(self): @@ -1053,13 +1141,15 @@ def _get_hive_ctx(self): def hiveql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}. + Runs a query expressed in HiveQL, returning the result as + a L{SchemaRDD}. 
""" return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) def hql(self, hqlQuery): """ - Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}. + Runs a query expressed in HiveQL, returning the result as + a L{SchemaRDD}. """ return self.hiveql(hqlQuery) @@ -1076,10 +1166,14 @@ class LocalHiveContext(HiveContext): ... supress = hiveCtx.hql("DROP TABLE src") ... except Exception: ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1) - >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1])) + >>> kv1 = os.path.join(os.environ["SPARK_HOME"], + ... 'examples/src/main/resources/kv1.txt') + >>> supress = hiveCtx.hql( + ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + >>> supress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + ... % kv1) + >>> results = hiveCtx.hql("FROM src SELECT value" + ... ).map(lambda r: int(r.value.split('_')[1])) >>> num = results.count() >>> reduce_sum = results.reduce(lambda x, y: x + y) >>> num @@ -1089,8 +1183,9 @@ class LocalHiveContext(HiveContext): """ def __init__(self, sparkContext, sqlContext=None): - HiveContext.__init__(self, sparkContext, sqlContext) - warnings.warn("LocalHiveContext is deprecated. Use HiveContext instead.", DeprecationWarning) + HiveContext.__init__(self, sparkContext, sqlContext) + warnings.warn("LocalHiveContext is deprecated. " + "Use HiveContext instead.", DeprecationWarning) def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -1108,6 +1203,7 @@ class Row(tuple): A row in L{SchemaRDD}. The fields in it can be accessed like attributes. """ + class SchemaRDD(RDD): """An RDD of L{Row} objects that has an associated schema. 
@@ -1194,7 +1290,8 @@ def saveAsTable(self, tableName): self._jschema_rdd.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" + """Returns the schema of this SchemaRDD (represented by + a L{StructType}).""" return _parse_datatype_string(self._jschema_rdd.schema().toString()) def schemaString(self): @@ -1248,6 +1345,7 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False): schema = self.schema() import pickle pickle.loads(pickle.dumps(schema)) + def applySchema(_, it): cls = _create_cls(schema) return itertools.imap(cls, it) @@ -1255,8 +1353,9 @@ def applySchema(_, it): objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) - # We override the default cache/persist/checkpoint behavior as we want to cache the underlying - # SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class + # We override the default cache/persist/checkpoint behavior + # as we want to cache the underlying SchemaRDD object in the JVM, + # not the PythonRDD checkpointed by the super class def cache(self): self.is_cached = True self._jschema_rdd.cache() @@ -1311,7 +1410,8 @@ def subtract(self, other, numPartitions=None): if numPartitions is None: rdd = self._jschema_rdd.subtract(other._jschema_rdd) else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions) + rdd = self._jschema_rdd.subtract(other._jschema_rdd, + numPartitions) return SchemaRDD(rdd, self.sql_ctx) else: raise ValueError("Can only subtract another SchemaRDD") @@ -1338,8 +1438,10 @@ def _test(): globs['rdd2'] = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}' + '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' + '"field6":[{"field7": "row2"}]}', + '{"field1" : null, "field2": "row3", ' + '"field3":{"field4":33, "field5": []}}' ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) @@ -1349,8 +1451,8 @@ def _test(): globs['nestedRdd2'] = sc.parallelize([ {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) - (failure_count, test_count) = doctest.testmod(pyspark.sql, - globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod( + pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) From c79ca67c02f2ed0f9d44623a946c3d05e99ecd5a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 18:08:43 -0700 Subject: [PATCH 15/24] fix serialization of nested data --- python/pyspark/sql.py | 45 ++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index dfa9a9db4d44..d262cb37db81 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -24,6 +24,7 @@ import decimal import datetime from operator import itemgetter +import keyword import warnings from pyspark.rdd import RDD, PipelinedRDD @@ -659,16 +660,15 @@ def _infer_schema_type(obj, dataType): _cached_cls = {} -def _restore_object(fields, obj): +def _restore_object(dataType, obj): """ Restore object during unpickling. 
""" - cls = _cached_cls.get(fields) + # use id(dataType) as key to speed up lookup in dict + # Because of batched pickling, dataType will be the + # same object in mose cases. + cls = _cached_cls.get(id(dataType)) if cls is None: - # create a mock StructType, because nested StructType will - # be restored by itself - fs = [StructField(n, StringType, True) for n in fields] - dataType = StructType(fs) cls = _create_cls(dataType) - _cached_cls[fields] = cls + _cached_cls[id(dataType)] = cls return cls(obj) @@ -703,10 +703,10 @@ def _create_properties(fields): ps = {} for i, f in enumerate(fields): name = f.name - if name.startswith("__") and name.endswith("__"): + if (name.startswith("__") and name.endswith("__") + or keyword.iskeyword(name)): warnings.warn("field name %s can not be accessed in Python," - "use position to access instead" % name) - continue + "use position to access it instead" % name) if _has_struct(f.dataType): # delay creating object until accessing it getter = _create_getter(f.dataType, i) @@ -721,6 +721,21 @@ def _create_cls(dataType): Create an class by dataType The created class is similar to namedtuple, but can have nested schema. + + >>> schema = _parse_schema_abstract("a b c") + >>> row = (1, 1.0, "str") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> import pickle + >>> pickle.loads(pickle.dumps(obj)) + Row(a=1, b=1.0, c='str') + + >>> row = [[1], {"key": (1, 2.0)}] + >>> schema = _parse_schema_abstract("a[] b{c d}") + >>> schema = _infer_schema_type(row, schema) + >>> obj = _create_cls(schema)(row) + >>> pickle.loads(pickle.dumps(obj)) + Row(a=[1], b={'key': Row(c=1, d=2.0)}) """ if isinstance(dataType, ArrayType): @@ -737,9 +752,8 @@ def __repr__(self): return "[%s]" % (", ".join(repr(self[i]) for i in range(len(self)))) - # pickle as dict, the nested struct can be reduced by itself def __reduce__(self): - return (list, (list(self),)) + return list.__reduce__(self) return List @@ -757,9 +771,8 @@ def __repr__(self): return "{%s}" % (", ".join("%r: %r" % (k, self[k]) for k in self)) - # pickle as dict, the nested struct can be reduced by itself def __reduce__(self): - return (dict, (dict(self),)) + return dict.__reduce__(self) return Dict @@ -768,7 +781,7 @@ def __reduce__(self): class Row(tuple): """ Row in SchemaRDD """ - + __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) # create property for fast access @@ -780,7 +793,7 @@ def __repr__(self): for n in self.__FIELDS__)) def __reduce__(self): - return (_restore_object, (self.__FIELDS__, tuple(self))) + return (_restore_object, (self.__DATATYPE__, tuple(self))) return Row From 63de8f82e85d7dee169a9412cd506eccd47f4665 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 18:10:38 -0700 Subject: [PATCH 16/24] fix typo --- python/pyspark/sql.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d262cb37db81..cd5e463e9865 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -525,7 +525,7 @@ def _dropSchema(rows, schema): yield converter(i) -_BRAKETS = {'(': ')', '[': ']', '{': '}'} +_BRACKETS = {'(': ')', '[': ']', '{': '}'} def _split_schema_abstract(s): @@ -552,10 +552,10 @@ def _split_schema_abstract(s): w = '' else: w += c - if c in _BRAKETS: + if c in _BRACKETS: brackets.append(c) - elif c in _BRAKETS.values(): - if not brackets or c != _BRAKETS[brackets.pop()]: + elif c in _BRACKETS.values(): + if not brackets or c != 
_BRACKETS[brackets.pop()]: raise ValueError("unexpected " + c) if brackets: @@ -578,8 +578,8 @@ def _parse_field_abstract(s): >>> _parse_field_abstract("a{[]}") StructField(a,MapType(None,ArrayType(None,true),true),true) """ - if set(_BRAKETS.keys()) & set(s): - idx = min((s.index(c) for c in _BRAKETS if c in s)) + if set(_BRACKETS.keys()) & set(s): + idx = min((s.index(c) for c in _BRACKETS if c in s)) name = s[:idx] return StructField(name, _parse_schema_abstract(s[idx:]), True) else: From 353a3f20792f7980b20bef2c1ba6838f8d1a4eb6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 18:14:42 -0700 Subject: [PATCH 17/24] fix code style --- python/pyspark/sql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cd5e463e9865..3ca1e521c3ea 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -515,8 +515,8 @@ def nested_conv(row): return nested_conv -def _dropSchema(rows, schema): - """Drop all the names of fields, becoming tuples""" +def _drop_schema(rows, schema): + """ all the names of fields, becoming tuples""" iterator = iter(rows) row = iterator.next() converter = _create_converter(row, schema) @@ -880,7 +880,7 @@ def inferSchema(self, rdd): "can not infer schema") schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _dropSchema(rows, schema)) + rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): From e9c0d5c354e8f6e1f00ca7e2a86905a5741ad57a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Jul 2014 22:34:52 -0700 Subject: [PATCH 18/24] remove string typed schema --- python/pyspark/sql.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3ca1e521c3ea..b92655d4e913 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -887,18 +887,7 @@ def applySchema(self, rdd, schema): """ Applies the given schema to the given RDD of L{tuple} or L{list}s. - - The schema could be a StructType or string, such as "name value". - The schema can have nested struct (struct, list, map). - - If schema is a string, the fields are seperated by space. - Each field can be followed by composit type immediately - (without space), for example: - - "name address(city zipcode) items[] props{key value}" - - which will be filled with infered datetype from first row, so you - not have empty value in the first row. + The schema should be a StructType. >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) @@ -932,20 +921,19 @@ def applySchema(self, rdd, schema): >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), ... 
{"a": 1}, {"b": 2}, [1, 2, 3])]) - >>> schema = "byte short float time map{} struct(b) list[]" - >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> abstract = "byte short float time map{} struct(b) list[]" + >>> schema = _parse_schema_abstract(abstract) + >>> typedSchema = _infer_schema_type(rdd.first(), schema) + >>> srdd = sqlCtx.applySchema(rdd, typedSchema) >>> srdd.collect() [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] - """ first = rdd.first() if not isinstance(first, (tuple, list)): raise ValueError("Can not apply schema to type: %s" % type(first)) - if isinstance(schema, basestring): - schema = _parse_schema_abstract(schema) - schema = _infer_schema_type(first, schema) + # TODO: verify schema with first few rows batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) From 51aa1358d40beb3a5fb2c5581befac7de225ea9b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 00:45:37 -0700 Subject: [PATCH 19/24] use Row to infer schema --- python/pyspark/sql.py | 156 +++++++++++++++++++++++++++++++++--------- 1 file changed, 125 insertions(+), 31 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b92655d4e913..1d35853c7841 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -18,14 +18,14 @@ import sys import types -import array import itertools import warnings import decimal import datetime -from operator import itemgetter import keyword import warnings +from array import array +from operator import itemgetter from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer @@ -441,7 +441,7 @@ def _infer_type(obj): raise ValueError("Can not infer type for empty dict") key, value = obj.iteritems().next() return MapType(_infer_type(key), _infer_type(value), True) - elif isinstance(obj, (list, array.array)): + elif isinstance(obj, (list, array)): if not obj: raise ValueError("Can not infer type for empty list/array") return ArrayType(_infer_type(obj[0]), True) @@ -456,14 +456,20 @@ def _infer_schema(row): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) + elif isinstance(row, tuple): if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif all(isinstance(x, tuple) and len(x) == 2 - for x in row): + elif hasattr(row, "__FIELDS__"): # Row + items = zip(row.__FIELDS__, tuple(row)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in row): items = row + else: + raise ValueError("Can't infer schema from tuple") + elif hasattr(row, "__dict__"): # object items = sorted(row.__dict__.items()) + else: raise ValueError("Can not infer schema for type: %s" % type(row)) @@ -494,9 +500,12 @@ def _create_converter(obj, dataType): elif isinstance(obj, tuple): if hasattr(obj, "_fields"): # namedtuple conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 - for x in obj): + elif hasattr(obj, "__FIELDS__"): + conv = tuple + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): conv = lambda o: tuple(v for k, v in o) + else: + raise ValueError("unexpected tuple") elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] @@ -783,6 +792,7 @@ class Row(tuple): """ Row in SchemaRDD """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) + __slots__ = () # create property for fast access locals().update(_create_properties(dataType.fields)) @@ -814,7 +824,7 @@ def __init__(self, sparkContext, 
sqlContext=None): >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + TypeError:... >>> bad_rdd = sc.parallelize([1,2,3]) >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -823,9 +833,9 @@ def __init__(self, sparkContext, sqlContext=None): ValueError:... >>> from datetime import datetime - >>> allTypes = sc.parallelize([{"int": 1, "string": "string", - ... "double": 1.0, "long": 1L, "boolean": True, "list": [1, 2, 3], - ... "time": datetime(2010, 1, 1, 1, 1, 1), "dict": {"a": 1},}]) + >>> allTypes = sc.parallelize([Row(int=1, string="string", + ... double=1.0, long=1L, boolean=True, list=[1, 2, 3], + ... time=datetime(2010, 1, 1, 1, 1, 1), dict={"a": 1})]) >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, ... x.double, x.long, x.boolean, x.time, x.dict["a"], x.list)) >>> srdd.collect()[0] @@ -851,33 +861,48 @@ def _ssql_ctx(self): return self._scala_SQLContext def inferSchema(self, rdd): - """Infer and apply a schema to an RDD of L{dict}s. + """Infer and apply a schema to an RDD of L{Row}s. + + We peek at the first row of the RDD to determine the fields' names + and types. Nested collections are supported, which include array, + dict, list, Row, tuple, namedtuple, or object. - We peek at the first row of the RDD to determine the fields names - and types, and then use that to extract all the dictionaries. Nested - collections are supported, which include array, dict, list, set, and - tuple. + Each row in `rdd` should be Row object or namedtuple or objects, + using dict is deprecated. + >>> rdd = sc.parallelize( + ... [Row(field1=1, field2="row1"), + ... Row(field1=2, field2="row2"), + ... Row(field1=3, field2="row3")]) >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.collect()[0] Row(field1=1, field2=u'row1') - >>> from array import array + >>> NestedRow = Row("f1", "f2") + >>> nestedRdd1 = sc.parallelize([ + ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), + ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) >>> srdd = sqlCtx.inferSchema(nestedRdd1) >>> srdd.collect() [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] + >>> nestedRdd2 = sc.parallelize([ + ... NestedRow([[1, 2], [2, 3]], [1, 2]), + ... NestedRow([[2, 3], [3, 4]], [2, 3])]) >>> srdd = sqlCtx.inferSchema(nestedRdd2) >>> srdd.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] """ - if (rdd.__class__ is SchemaRDD): - raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__) + + if isinstance(rdd, SchemaRDD): + raise TypeError("Cannot apply schema to SchemaRDD") first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) @@ -889,6 +914,7 @@ def applySchema(self, rdd, schema): The schema should be a StructType. + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... 
StructField("field2", StringType(), False)]) >>> srdd = sqlCtx.applySchema(rdd2, schema) @@ -929,6 +955,9 @@ def applySchema(self, rdd, schema): [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ + if isinstance(rdd, SchemaRDD): + raise TypeError("Cannot apply schema to SchemaRDD") + first = rdd.first() if not isinstance(first, (tuple, list)): raise ValueError("Can not apply schema to type: %s" % type(first)) @@ -1198,12 +1227,84 @@ def _get_hive_ctx(self): return self._jvm.TestHiveContext(self._jsc.sc()) -# a stub type, the real type is dynamic generated. +def _create_row(fields, values): + row = Row(*values) + row.__FIELDS__ = fields + return row + + class Row(tuple): """ A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + + Row can be used to create a row object by using named arguments, + the fields will be sorted by names. + + >>> row = Row(name="Alice", age=11) + >>> row + Row(age=11, name='Alice') + >>> row.name, row.age + ('Alice', 11) + + Row also can be used to create another Row like class, then it + could be used to create Row objects, such as + + >>> Person = Row("name", "age") + >>> Person + + >>> Person("Alice", 11) + Row(name='Alice', age=11) """ + def __new__(self, *args, **kwargs): + if args and kwargs: + raise ValueError("Can not use both args " + "and kwargs to create Row") + if args: + # create row class or objects + return tuple.__new__(self, args) + + elif kwargs: + # create row objects + names = sorted(kwargs.keys()) + values = tuple(kwargs[n] for n in names) + row = tuple.__new__(self, values) + row.__FIELDS__ = names + return row + + else: + raise ValueError("No args or kwargs") + + + # let obect acs like class + def __call__(self, *args): + """create new Row object""" + return _create_row(self, args) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + try: + # it will be slow when it has many fields, + # but this will not be used in normal cases + idx = self.__FIELDS__.index(item) + return self[idx] + except IndexError: + raise AttributeError(item) + + def __reduce__(self): + if hasattr(self, "__FIELDS__"): + return (_create_row, (self.__FIELDS__, tuple(self))) + else: + return tuple.__reduce__(self) + + def __repr__(self): + if hasattr(self, "__FIELDS__"): + return "Row(%s)" % ", ".join("%s=%r" % (k, v) + for k, v in zip(self.__FIELDS__, self)) + else: + return "" % ", ".join(self) + class SchemaRDD(RDD): """An RDD of L{Row} objects that has an associated schema. 
@@ -1424,7 +1525,7 @@ def _test(): from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql - from pyspark.sql import SQLContext + from pyspark.sql import Row, SQLContext globs = pyspark.sql.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: @@ -1432,11 +1533,10 @@ def _test(): globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( - [{"field1": 1, "field2": "row1"}, - {"field1": 2, "field2": "row2"}, - {"field1": 3, "field2": "row3"}] + [Row(field1=1, field2="row1"), + Row(field1=2, field2="row2"), + Row(field1=3, field2="row3")] ) - globs['rdd2'] = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' @@ -1446,12 +1546,6 @@ def _test(): ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) - globs['nestedRdd1'] = sc.parallelize([ - {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}}, - {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}]) - globs['nestedRdd2'] = sc.parallelize([ - {"f1": [[1, 2], [2, 3]], "f2": [1, 2]}, - {"f1": [[2, 3], [3, 4]], "f2": [2, 3]}]) (failure_count, test_count) = doctest.testmod( pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() From 1e5b80119b5a1e2a81654d170bd3d410f6017ca7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 00:55:36 -0700 Subject: [PATCH 20/24] improve cache of classes --- python/pyspark/sql.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1d35853c7841..d72782ff0821 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -674,10 +674,15 @@ def _restore_object(dataType, obj): # use id(dataType) as key to speed up lookup in dict # Because of batched pickling, dataType will be the # same object in mose cases. - cls = _cached_cls.get(id(dataType)) + k = id(dataType) + cls = _cached_cls.get(k) if cls is None: - cls = _create_cls(dataType) - _cached_cls[id(dataType)] = cls + # use dataType as key to avoid create multiple class + cls = _cached_cls.get(dataType) + if cls is None: + cls = _create_cls(dataType) + _cached_cls[dataType] = cls + _cached_cls[k] = cls return cls(obj) From 61b22924ceba407e0d6885a8d4479fcb2279dafb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 12:18:18 -0700 Subject: [PATCH 21/24] add @deprecated to pythonToJavaMap --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 9b0ccef45259..f10fe649c0fe 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -687,6 +687,7 @@ private[spark] object PythonRDD extends Logging { * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). 
* This function is outdated, PySpark does not use it anymore */ + @deprecated def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler From abe9e6e9e138112b138ea0aead01e037dda628ae Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 14:21:28 -0700 Subject: [PATCH 22/24] address comments --- python/pyspark/sql.py | 33 ++++++++---- .../org/apache/spark/sql/SQLContext.scala | 54 ++++--------------- 2 files changed, 33 insertions(+), 54 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d72782ff0821..f683ae4a11ef 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -838,13 +838,17 @@ def __init__(self, sparkContext, sqlContext=None): ValueError:... >>> from datetime import datetime - >>> allTypes = sc.parallelize([Row(int=1, string="string", - ... double=1.0, long=1L, boolean=True, list=[1, 2, 3], - ... time=datetime(2010, 1, 1, 1, 1, 1), dict={"a": 1})]) - >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, - ... x.double, x.long, x.boolean, x.time, x.dict["a"], x.list)) - >>> srdd.collect()[0] - (1, u'string', 1.0, 1, True, ...(2010, 1, 1, 1, 1, 1), 1, [1, 2, 3]) + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), + ... time=datetime(2014, 8, 1, 14, 1, 5))]) + >>> srdd = sqlCtx.inferSchema(allTypes) + >>> srdd.registerAsTable("allTypes") + >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' + ... 'from allTypes where b and i > 0').collect() + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] + >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + ... x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -872,7 +876,10 @@ def inferSchema(self, rdd): and types. Nested collections are supported, which include array, dict, list, Row, tuple, namedtuple, or object. - Each row in `rdd` should be Row object or namedtuple or objects, + All the rows in `rdd` should have the same type with the first one, + or it will cause runtime exceptions. + + Each row could be L{pyspark.sql.Row} object or namedtuple or objects, using dict is deprecated. >>> rdd = sc.parallelize( @@ -917,8 +924,14 @@ def applySchema(self, rdd, schema): """ Applies the given schema to the given RDD of L{tuple} or L{list}s. + These tuples or lists can contain complex nested structures like + lists, maps or nested rows. + The schema should be a StructType. + It is important that the schema matches the types of the objects + in each row or exceptions could be thrown at runtime. + >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) @@ -931,7 +944,7 @@ def applySchema(self, rdd, schema): >>> from datetime import datetime >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, {"b": 2}, [1, 2, 3], None)]) + ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ ... StructField("byte", ByteType(), False), ... StructField("short", ShortType(), False), @@ -951,7 +964,7 @@ def applySchema(self, rdd, schema): >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, {"b": 2}, [1, 2, 3])]) + ... 
{"a": 1}, (2,), [1, 2, 3])]) >>> abstract = "byte short float time map{} struct(b) list[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index eac33abb69a6..dad71079c29b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -411,38 +411,6 @@ class SQLContext(@transient val sparkContext: SparkContext) """.stripMargin.trim } - /** - * Peek at the first row of the RDD and infer its schema. - * This function is outdated, PySpark does not use it anymore - */ - private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { - import scala.collection.JavaConversions._ - - def typeOfComplexValue: PartialFunction[Any, DataType] = { - case c: java.util.Calendar => TimestampType - case c: java.util.List[_] => - ArrayType(typeOfObject(c.head)) - case c: java.util.Map[_, _] => - val (key, value) = c.head - MapType(typeOfObject(key), typeOfObject(value)) - case c if c.getClass.isArray => - val elem = c.asInstanceOf[Array[_]].head - ArrayType(typeOfObject(elem)) - case c => throw new Exception(s"Object of type $c cannot be used") - } - def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue - - val firstRow = rdd.first() - val fields = firstRow.map { - case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true) - }.toSeq - - val arrayRdd = rdd.map { - m => fields.map { field => m.getOrElse(field.name, null) }.toArray - } - applySchemaToPythonRDD(arrayRdd, StructType(fields)) - } - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. 
@@ -495,27 +463,25 @@ class SQLContext(@transient val sparkContext: SparkContext) val converted = c.map { e => convert(e, elementType)} JListWrapper(converted) - case (c: java.util.Map[_, _], struct: StructType) => - val row = new GenericMutableRow(struct.fields.length) - struct.fields.zipWithIndex.foreach { - case (field, i) => - val value = convert(c.get(field.name), field.dataType) - row.update(i, value) - } - row + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any] case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { case (key, value) => (convert(key, keyType), convert(value, valueType)) }.toMap - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType)) - converted: Seq[Any] + case (c, StructType(fields)) if c.getClass.isArray => + new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + case (e, f) => convert(e, f.dataType) + }): Row + + case (c: java.util.Calendar, TimestampType) => + new java.sql.Timestamp(c.getTime().getTime()) - case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) case (c: Int, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort case (c: Double, FloatType) => c.toFloat + case (c, StringType) if !c.isInstanceOf[String] => c.toString case (c, _) => c } From 8852aafc93a786280b8fe7ba02b40a3276d5b2de Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 14:32:11 -0700 Subject: [PATCH 23/24] check type of schema --- python/pyspark/sql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f683ae4a11ef..b03254f3c4a7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -43,9 +43,6 @@ class DataType(object): """Spark SQL DataType""" - def __str__(self): - return repr(self) - def __repr__(self): return self.__class__.__name__ @@ -976,6 +973,9 @@ def applySchema(self, rdd, schema): if isinstance(rdd, SchemaRDD): raise TypeError("Cannot apply schema to SchemaRDD") + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + first = rdd.first() if not isinstance(first, (tuple, list)): raise ValueError("Can not apply schema to type: %s" % type(first)) From f1d15b6dc6e0ef0364dafc80c73058632e6796b1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 1 Aug 2014 14:59:04 -0700 Subject: [PATCH 24/24] verify schema with the first few rows --- python/pyspark/sql.py | 75 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b03254f3c4a7..f840475ffaf7 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -663,6 +663,72 @@ def _infer_schema_type(obj, dataType): raise ValueError("Unexpected dataType: %s" % dataType) +_acceptable_types = { + BooleanType: (bool,), + ByteType: (int, long), + ShortType: (int, long), + IntegerType: (int, long), + LongType: (int, long), + FloatType: (float,), + DoubleType: (float,), + DecimalType: (decimal.Decimal,), + StringType: (str, unicode), + TimestampType: (datetime.datetime, datetime.time, datetime.date), + ArrayType: (list, tuple, array), + MapType: (dict,), + StructType: (tuple, list), +} + +def _verify_type(obj, dataType): + """ + Verify the type of obj against dataType, raise an exception if + they do not match. 
+ + >>> _verify_type(None, StructType([])) + >>> _verify_type("", StringType()) + >>> _verify_type(0, IntegerType()) + >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError:... + >>> _verify_type({}, MapType(StringType(), IntegerType())) + >>> _verify_type((), StructType([])) + >>> _verify_type([], StructType([])) + >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + """ + # all objects are nullable + if obj is None: + return + + _type = type(dataType) + if _type not in _acceptable_types: + return + + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" + % (dataType, type(obj))) + + if isinstance(dataType, ArrayType): + for i in obj: + _verify_type(i, dataType.elementType) + + elif isinstance(dataType, MapType): + for k, v in obj.iteritems(): + _verify_type(k, dataType.keyType) + _verify_type(v, dataType.valueType) + + elif isinstance(dataType, StructType): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType) + + _cached_cls = {} @@ -976,11 +1042,10 @@ def applySchema(self, rdd, schema): if not isinstance(schema, StructType): raise TypeError("schema should be StructType") - first = rdd.first() - if not isinstance(first, (tuple, list)): - raise ValueError("Can not apply schema to type: %s" % type(first)) - - # TODO: verify schema with first few rows + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched)
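The check added in this last patch runs inside applySchema() before any data is shipped to the JVM: the first rows are sampled with take(10) and each one is validated against the declared StructType by _verify_type(). A rough sketch of the user-facing effect is below; it assumes the `sc` and `sqlCtx` doctest globals created in _test(), assumes the data types are importable from pyspark.sql as defined in this patch series, and the sample data and exception text are illustrative only, not part of the patch.

    # Hedged sketch of the up-front schema check, not taken from the patch above.
    from pyspark.sql import StructType, StructField, IntegerType, StringType

    schema = StructType([StructField("field1", IntegerType(), False),
                         StructField("field2", StringType(), False)])

    ok = sc.parallelize([(1, "row1"), (2, "row2")])
    srdd = sqlCtx.applySchema(ok, schema)    # sampled rows pass _verify_type, a SchemaRDD is returned

    bad = sc.parallelize([("1", "row1")])    # field1 is a str, but IntegerType only accepts int/long
    try:
        sqlCtx.applySchema(bad, schema)      # _verify_type raises TypeError for the sampled row
    except TypeError:
        pass                                 # e.g. "IntegerType can not accept object in type <type 'str'>"

Because only the first few rows are inspected, a mismatch that first appears deeper in the dataset is not caught by this up-front check.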