From 55bb86eeaba233e964c66534e99faf1d6a55c164 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 16 Jul 2015 14:08:46 -0700 Subject: [PATCH 1/9] support Python UDT in __main__ (without Scala one) --- python/pyspark/sql/context.py | 99 +++++++++++-------- python/pyspark/sql/tests.py | 15 +-- python/pyspark/sql/types.py | 72 ++++++++++---- .../org/apache/spark/sql/types/DataType.scala | 10 +- .../spark/sql/types/UserDefinedType.scala | 19 ++++ .../spark/sql/test/ExamplePointUDT.scala | 2 +- 6 files changed, 145 insertions(+), 72 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c93a15badae2..4ce2994342bc 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -287,6 +287,57 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + def _createFromRDD(self, rdd, schema, samplingRatio): + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct + + elif isinstance(schema, StructType): + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None") + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + # make sure data could consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None") + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): @@ -350,49 +401,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - - if not isinstance(data, RDD): - if not isinstance(data, list): - data = list(data) - try: - # data could be list, tuple, generator ... 
- rdd = self._sc.parallelize(data) - except Exception: - raise TypeError("cannot create an RDD from type: %s" % type(data)) + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data, schema, samplingRatio) else: - rdd = data - - if schema is None or isinstance(schema, (list, tuple)): - if isinstance(data, RDD): - struct = self._inferSchema(rdd, samplingRatio) - else: - struct = self._inferSchemaFromList(data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - schema = struct - converter = _create_converter(schema) - rdd = rdd.map(converter) - - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: - raise TypeError("schema should be StructType or list or None") - - # convert python objects to sql data - rdd = rdd.map(schema.toInternal) - + rdd, schema = self._createFromLocal(data, schema) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self) + df._schema = schema + return df @since(1.3) def registerDataFrameAsTable(self, df, tableName): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 241eac45cfe3..6910df2b416b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -75,7 +75,7 @@ def sqlType(self): @classmethod def module(cls): - return 'pyspark.tests' + return '__main__' @classmethod def scalaUDT(cls): @@ -392,9 +392,8 @@ def test_convert_row_to_dict(self): self.assertEqual(1.0, row.asDict()['d']['key'].c) def test_infer_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) @@ -403,27 +402,23 @@ def test_infer_schema_with_udt(self): self.assertEqual(point, ExamplePoint(1.0, 2.0)) def test_apply_schema_with_udt(self): - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = rdd.toDF(schema) + df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_udf_with_udt(self): - from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sc.parallelize([row]).toDF() + df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f75791fad161..94d9c25b06ce 100644 --- a/python/pyspark/sql/types.py 
+++ b/python/pyspark/sql/types.py @@ -580,7 +580,7 @@ def scalaUDT(cls): """ The class name of the paired Scala UDT. """ - raise NotImplementedError("UDT must have a paired Scala UDT.") + return '' def needConversion(self): return True @@ -641,6 +641,50 @@ def __eq__(self, other): return type(self) == type(other) +class PointUDT(UserDefinedType): + """ + User-defined type (UDT) for Point. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return Point(datum[0], datum[1]) + + def __eq__(self, other): + return True + + +class Point: + """ + An example class to demonstrate UDT in Python. + """ + + __UDT__ = PointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, Point) and other.x == self.x and other.y == self.y + + _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) @@ -656,7 +700,7 @@ def _parse_datatype_json_string(json_string): ... assert datatype == pickled ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... assert datatype == python_datatype + ... assert datatype == python_datatype, str(datatype) + str(python_datatype) >>> for cls in _all_atomic_types.values(): ... check_datatype(cls()) @@ -694,9 +738,9 @@ def _parse_datatype_json_string(json_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) - >>> check_datatype(ExamplePointUDT()) + >>> check_datatype(PointUDT()) >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) + ... StructField("point", PointUDT(), False)]) >>> check_datatype(structtype_with_udt) """ return _parse_datatype_json_value(json.loads(json_string)) @@ -750,9 +794,9 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj - >>> p = ExamplePoint(1.0, 2.0) + >>> p = Point(1.0, 2.0) >>> _infer_type(p) - ExamplePointUDT + PointUDT """ if obj is None: return NullType() @@ -1084,8 +1128,8 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _verify_type(Point(1.0, 2.0), PointUDT()) + >>> _verify_type([1.0, 2.0], PointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... 
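
[Editor's note, not part of the patch] The hunks above add a Python-only UDT example (`PointUDT`/`Point`) and let `scalaUDT()` return `''` when there is no paired Scala class. Below is a minimal, self-contained sketch of how such a UDT is declared and round-tripped through a DataFrame under this patch. It assumes a running driver with `sc` (SparkContext) and `sqlContext` (SQLContext) already created; the class names simply mirror the example added in the hunk above and are illustrative, not part of the change itself.

```python
# Hedged sketch: a Python-only UDT (no Scala counterpart), assuming `sc` and
# `sqlContext` exist in the driver script (i.e. this code runs in __main__).
from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType


class PointUDT(UserDefinedType):
    @classmethod
    def sqlType(cls):
        # The underlying SQL storage type for the user class.
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        # The UDT lives in the driver script, not an importable module.
        return '__main__'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point(datum[0], datum[1])


class Point(object):
    # Attaching the UDT via __UDT__ is what lets _infer_type pick it up.
    __UDT__ = PointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y


# createDataFrame infers PointUDT for the second column, serializes each
# Point to [x, y] via toInternal, and deserializes it back on collect.
df = sqlContext.createDataFrame([(1.0, Point(1.0, 2.0))], ["label", "point"])
print(df.head().point.x)  # 1.0
```
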
@@ -1253,18 +1297,12 @@ def convert(self, obj, gateway_client): def _test(): import doctest from pyspark.context import SparkContext - # let doctest run in pyspark.sql.types, so DataTypes can be picklable - import pyspark.sql.types - from pyspark.sql import Row, SQLContext - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.types.__dict__.copy() + from pyspark.sql import SQLContext + globs = globals() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - (failure_count, test_count) = doctest.testmod( - pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 2d133eea19fe..e3e27eeeeb75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -144,10 +144,14 @@ object DataType { case JSortedObject( ("class", JString(udtClass)), - ("pyClass", _), - ("sqlType", _), + ("pyClass", JString(pyClass)), + ("sqlType", v: JValue), ("type", JString("udt"))) => - Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + if (udtClass.length > 0) { + Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + } else { + new PythonUserDefinedType(parseDataType(v), pyClass) + } } private def parseStructField(json: JValue): StructField = json match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index e47cfb4833bd..9d661ec25e88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -82,3 +82,22 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass } + +/** + * ::DeveloperApi:: + * The user defined type in Python. + * + * Note: This can only be accessed via Python UDF, or accessed as serialized object. 
+ */ +class PythonUserDefinedType(val sqlType: DataType, pyClass: String) extends UserDefinedType[Any] { + override def pyUDT = pyClass + def serialize(obj: Any): Any = obj + def deserialize(datam: Any): Any = datam + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> "") ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + def userClass: java.lang.Class[Any] = null +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 2fdd798b44bb..c807c84c3e89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -37,7 +37,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + override def pyUDT: String = "__main__.ExamplePointUDT" override def serialize(obj: Any): Seq[Double] = { obj match { From 83d65ac79ab0d699edf1dcb4abce8d820bc5c7e2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 16 Jul 2015 22:47:31 -0700 Subject: [PATCH 2/9] fix bug in StructType --- python/pyspark/sql/context.py | 2 ++ python/pyspark/sql/types.py | 15 ++++++------ .../spark/sql/types/UserDefinedType.scala | 23 ++++++++++--------- .../spark/sql/execution/pythonUDFs.scala | 1 - 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4ce2994342bc..63a27c0e7658 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -295,6 +295,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): if isinstance(schema, (list, tuple)): for i, name in enumerate(schema): struct.fields[i].name = name + struct.names[i] = name schema = struct elif isinstance(schema, StructType): @@ -325,6 +326,7 @@ def _createFromLocal(self, data, schema): if isinstance(schema, (list, tuple)): for i, name in enumerate(schema): struct.fields[i].name = name + struct.names[i] = name schema = struct elif isinstance(schema, StructType): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 94d9c25b06ce..8ebb579cbee5 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -455,7 +455,7 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeFields = None + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -498,6 +498,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) return self def simpleString(self): @@ -523,12 +524,9 @@ def toInternal(self, obj): if obj is None: return - if self._needSerializeFields is None: - self._needSerializeFields = any(f.needConversion() for f in self.fields) - - if self._needSerializeFields: + if self._needSerializeAnyField: if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, 
(tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) else: @@ -547,7 +545,10 @@ def fromInternal(self, obj): if isinstance(obj, Row): # it's already converted by pickler return obj - values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + if self._needSerializeAnyField: + values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + else: + values = obj return _create_row(self.names, values) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 9d661ec25e88..00549cb1cec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -89,15 +89,16 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * * Note: This can only be accessed via Python UDF, or accessed as serialized object. */ -class PythonUserDefinedType(val sqlType: DataType, pyClass: String) extends UserDefinedType[Any] { - override def pyUDT = pyClass - def serialize(obj: Any): Any = obj - def deserialize(datam: Any): Any = datam - override private[sql] def jsonValue: JValue = { - ("type" -> "udt") ~ - ("class" -> "") ~ - ("pyClass" -> pyUDT) ~ - ("sqlType" -> sqlType.jsonValue) - } - def userClass: java.lang.Class[Any] = null +private[sql] class PythonUserDefinedType(val sqlType: DataType, pyClass: String) + extends UserDefinedType[Any] { + + /* The Python UDT class */ + override def pyUDT: String = pyClass + + /* The serialization is handled by UDT class in Python */ + override def serialize(obj: Any): Any = obj + override def deserialize(datam: Any): Any = datam + + /* There is no Java class for Python UDT */ + override def userClass: java.lang.Class[Any] = null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6d6e67dace17..3456fba40f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -271,7 +271,6 @@ object EvaluatePython { pickler.save(row.values(i)) i += 1 } - row.values.foreach(pickler.save) out.write(Opcodes.TUPLE) out.write(Opcodes.REDUCE) } From de986d63e4cb90d26df073be1f9baf50f054ad12 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 00:12:19 -0700 Subject: [PATCH 3/9] fix test --- .../src/main/scala/org/apache/spark/sql/types/DataType.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e3e27eeeeb75..0feeea49a202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -144,13 +144,13 @@ object DataType { case JSortedObject( ("class", JString(udtClass)), - ("pyClass", JString(pyClass)), + ("pyClass", pyClass: JValue), // pyClass could be null, cannot match with JString() ("sqlType", v: JValue), ("type", JString("udt"))) => if (udtClass.length > 0) { Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } else { - new PythonUserDefinedType(parseDataType(v), pyClass) + new PythonUserDefinedType(parseDataType(v), pyClass.asInstanceOf[JString].values) } } From 
0bcb3efda449051ce9c0cf5df22a053af5c42a9a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 10:30:19 -0700 Subject: [PATCH 4/9] fix bug in mllib --- python/pyspark/mllib/linalg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 040886f71775..3f96d4ece0ef 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -778,6 +778,8 @@ def dense(elements): >>> Vectors.dense([1, 2, 3]) DenseVector([1.0, 2.0, 3.0]) """ + if isinstance(elements, (int, float)): + elements = [elements] return DenseVector(elements) @staticmethod From 316a39451e1db822e9fce11a11863c202f9dcc18 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Jul 2015 17:52:08 -0700 Subject: [PATCH 5/9] support Python UDT with UTF --- python/pyspark/cloudpickle.py | 35 +++++++- python/pyspark/sql/tests.py | 31 ++++++- python/pyspark/sql/types.py | 90 ++++++------------- .../org/apache/spark/sql/types/DataType.scala | 17 ++-- .../spark/sql/types/UserDefinedType.scala | 13 +-- .../spark/sql/test/ExamplePointUDT.scala | 2 - 6 files changed, 105 insertions(+), 83 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 9ef93071d2e7..3c445444dbd2 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ - self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj) + self.save(_load_class) + self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) + d.pop('__doc__', None) + # handle property and staticmethod + dd = {} + for k, v in d.items(): + if isinstance(v, property): + k = ('property', k) + v = (v.fget, v.fset, v.fdel, v.__doc__) + elif isinstance(v, staticmethod) and hasattr(v, '__func__'): + k = ('staticmethod', k) + v = v.__func__ + elif isinstance(v, classmethod) and hasattr(v, '__func__'): + k = ('classmethod', k) + v = v.__func__ + dd[k] = v + self.save(dd) + self.write(pickle.TUPLE2) + self.write(pickle.REDUCE) + else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -708,6 +727,20 @@ def _make_skel_func(code, closures, base_globals = None): None, None, closure) +def _load_class(cls, d): + for k, v in d.items(): + if isinstance(k, tuple): + typ, k = k + if typ == 'property': + v = property(*v) + elif typ == 'staticmethod': + v = staticmethod(v) + elif typ == 'classmethod': + v = classmethod(v) + setattr(cls, k, v) + return cls + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6910df2b416b..97badabf9a2d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -77,16 +77,20 @@ def sqlType(self): def module(cls): return '__main__' - @classmethod - def scalaUDT(cls): - return 'org.apache.spark.sql.test.ExamplePointUDT' - def serialize(self, obj): return [obj.x, obj.y] def deserialize(self, datum): return ExamplePoint(datum[0], datum[1]) + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + class ExamplePoint: """ @@ -391,6 +395,25 @@ def test_convert_row_to_dict(self): self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + + def check_datatype(datatype): + pickled = 
pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + def test_infer_schema_with_udt(self): row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df = self.sqlCtx.createDataFrame([row]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8ebb579cbee5..d0a4e9bff2b6 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -22,6 +22,7 @@ import calendar import json import re +import base64 from array import array if sys.version >= "3": @@ -31,6 +32,8 @@ from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass +from pyspark.serializers import CloudPickleSerializer + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -620,12 +623,23 @@ def json(self): return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } + if self.scalaUDT(): + assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + else: + ser = CloudPickleSerializer() + b = ser.dumps(type(self)) + schema = { + "type": "udt", + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "serializedClass": base64.b64encode(b).decode('utf8'), + "sqlType": self.sqlType().jsonValue() + } return schema @classmethod @@ -635,57 +649,17 @@ def fromJson(cls, json): pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] m = __import__(pyModule, globals(), locals(), [pyClass]) - UDT = getattr(m, pyClass) + if pyModule == '__main__' and not hasattr(m, pyClass): + s = base64.b64decode(json['serializedClass'].encode('utf-8')) + UDT = CloudPickleSerializer().loads(s) + else: + UDT = getattr(m, pyClass) return UDT() def __eq__(self, other): return type(self) == type(other) -class PointUDT(UserDefinedType): - """ - User-defined type (UDT) for Point. - """ - - @classmethod - def sqlType(self): - return ArrayType(DoubleType(), False) - - @classmethod - def module(cls): - return '__main__' - - def serialize(self, obj): - return [obj.x, obj.y] - - def deserialize(self, datum): - return Point(datum[0], datum[1]) - - def __eq__(self, other): - return True - - -class Point: - """ - An example class to demonstrate UDT in Python. 
- """ - - __UDT__ = PointUDT() - - def __init__(self, x, y): - self.x = x - self.y = y - - def __repr__(self): - return "ExamplePoint(%s,%s)" % (self.x, self.y) - - def __str__(self): - return "(%s,%s)" % (self.x, self.y) - - def __eq__(self, other): - return isinstance(other, Point) and other.x == self.x and other.y == self.y - - _atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] _all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) @@ -738,11 +712,6 @@ def _parse_datatype_json_string(json_string): >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - - >>> check_datatype(PointUDT()) - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", PointUDT(), False)]) - >>> check_datatype(structtype_with_udt) """ return _parse_datatype_json_value(json.loads(json_string)) @@ -794,10 +763,6 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj - - >>> p = Point(1.0, 2.0) - >>> _infer_type(p) - PointUDT """ if obj is None: return NullType() @@ -1129,11 +1094,6 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... - >>> _verify_type(Point(1.0, 2.0), PointUDT()) - >>> _verify_type([1.0, 2.0], PointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... """ # all objects are nullable if obj is None: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0feeea49a202..d9b6027a7112 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -142,16 +142,21 @@ object DataType { ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + // Scala/Java UDT case JSortedObject( ("class", JString(udtClass)), - ("pyClass", pyClass: JValue), // pyClass could be null, cannot match with JString() + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => + Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + + // Python UDT + case JSortedObject( + ("pyClass", JString(pyClass)), + ("serializedClass", JString(serialized)), ("sqlType", v: JValue), ("type", JString("udt"))) => - if (udtClass.length > 0) { - Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] - } else { - new PythonUserDefinedType(parseDataType(v), pyClass.asInstanceOf[JString].values) - } + new PythonUserDefinedType(parseDataType(v), pyClass, serialized) } private def parseStructField(json: JValue): StructField = json match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 00549cb1cec6..b2097cf50891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Paired Python UDT class, if exists. */ def pyUDT: String = null + /** Serialized Python UDT class, if exists. 
*/ + def serializedPyClass: String = null + /** * Convert the user type to a SQL datum * @@ -60,6 +63,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { ("type" -> "udt") ~ ("class" -> this.getClass.getName) ~ ("pyClass" -> pyUDT) ~ + ("serializedClass" -> serializedPyClass) ~ ("sqlType" -> sqlType.jsonValue) } @@ -89,11 +93,10 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { * * Note: This can only be accessed via Python UDF, or accessed as serialized object. */ -private[sql] class PythonUserDefinedType(val sqlType: DataType, pyClass: String) - extends UserDefinedType[Any] { - - /* The Python UDT class */ - override def pyUDT: String = pyClass +private[sql] class PythonUserDefinedType( + val sqlType: DataType, + override val pyUDT: String, + override val serializedPyClass: String) extends UserDefinedType[Any] { /* The serialization is handled by UDT class in Python */ override def serialize(obj: Any): Any = obj diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index c807c84c3e89..2da6b55bc3d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -37,8 +37,6 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) - override def pyUDT: String = "__main__.ExamplePointUDT" - override def serialize(obj: Any): Seq[Double] = { obj match { case p: ExamplePoint => From 63f52efe337fd81c9880da510c7cb618d2a138ee Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 20 Jul 2015 13:50:32 -0700 Subject: [PATCH 6/9] fix pylint check --- dev/lint-python | 2 +- pylintrc | 2 +- python/pyspark/shuffle.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index e02dff220eb8..b622956e9803 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -61,7 +61,7 @@ export "PATH=$PYTHONPATH:$PATH" if [ ! -d "$PYLINT_HOME" ]; then mkdir "$PYLINT_HOME" # Redirect the annoying pylint installation output. - easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" + easy_install -d "$PYLINT_HOME" pylint==1.4.4 >> "$PYLINT_INSTALL_INFO" 2>&1 easy_install_status="$?" if [ "$easy_install_status" -ne 0 ]; then diff --git a/pylintrc b/pylintrc index 061775960393..6a675770da69 100644 --- a/pylintrc +++ b/pylintrc @@ -84,7 +84,7 @@ enable= # If you would like to improve the code quality of pyspark, remove any of these disabled errors # run ./dev/lint-python and see if the errors raised by pylint can be fixed. 
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable [REPORTS] diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8fb71bac64a5..b8118bdb7ca7 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -606,7 +606,7 @@ def _open_file(self): if not os.path.exists(d): os.makedirs(d) p = os.path.join(d, str(id(self))) - self._file = open(p, "wb+", 65536) + self._file = open(p, "w+b", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) From a86e1fcf95af24635e86857ee214588ff16e96fe Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 23 Jul 2015 16:13:38 -0700 Subject: [PATCH 7/9] fix serialization --- python/pyspark/sql/types.py | 2 +- .../org/apache/spark/sql/types/UserDefinedType.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5e1a4fef302c..da973fc1eb5c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -649,7 +649,7 @@ def fromJson(cls, json): pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] m = __import__(pyModule, globals(), locals(), [pyClass]) - if pyModule == '__main__' and not hasattr(m, pyClass): + if not hasattr(m, pyClass): s = base64.b64decode(json['serializedClass'].encode('utf-8')) UDT = CloudPickleSerializer().loads(s) else: diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index b2097cf50891..4305903616bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -63,7 +63,6 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { ("type" -> "udt") ~ ("class" -> this.getClass.getName) ~ ("pyClass" -> pyUDT) ~ - ("serializedClass" -> serializedPyClass) ~ ("sqlType" -> sqlType.jsonValue) } @@ -104,4 +103,11 @@ private[sql] class PythonUserDefinedType( /* There is no Java class for Python UDT */ override def userClass: java.lang.Class[Any] = null + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("pyClass" -> pyUDT) ~ + ("serializedClass" -> serializedPyClass) ~ + ("sqlType" -> sqlType.jsonValue) + } } From dc65f19dca9affaf7a197d70c06100af2969cf9d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 29 Jul 2015 16:13:11 -0700 Subject: [PATCH 8/9] address comment --- python/pyspark/cloudpickle.py | 3 ++ python/pyspark/sql/context.py | 11 ++++- python/pyspark/sql/tests.py | 88 +++++++++++++++++++++++++++++++---- python/pyspark/sql/types.py | 5 +- 4 files changed, 93 insertions(+), 14 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 3c445444dbd2..3b647985801b 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -728,6 +728,9 @@ def _make_skel_func(code, closures, base_globals = None): def _load_class(cls, d): + """ + Loads additional properties into class `cls`. + """ for k, v in d.items(): if isinstance(k, tuple): typ, k = k diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 794d55b69066..917de24f3536 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -278,6 +278,9 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. + """ if schema is None or isinstance(schema, (list, tuple)): struct = self._inferSchema(rdd, samplingRatio) converter = _create_converter(struct) @@ -295,13 +298,17 @@ def _createFromRDD(self, rdd, schema, samplingRatio): _verify_type(row, schema) else: - raise TypeError("schema should be StructType or list or None") + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) # convert python objects to sql data rdd = rdd.map(schema.toInternal) return rdd, schema def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from an list or pandas.DataFrame, returns + the RDD and schema. 
+ """ if has_pandas and isinstance(data, pandas.DataFrame): if schema is None: schema = [str(x) for x in data.columns] @@ -324,7 +331,7 @@ def _createFromLocal(self, data, schema): _verify_type(row, schema) else: - raise TypeError("schema should be StructType or list or None") + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) # convert python objects to sql data data = [schema.toInternal(row) for row in data] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 49f7993ca2aa..7a40b4515ee1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -75,7 +75,11 @@ def sqlType(self): @classmethod def module(cls): - return '__main__' + return 'pyspark.sql.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' def serialize(self, obj): return [obj.x, obj.y] @@ -83,14 +87,6 @@ def serialize(self, obj): def deserialize(self, datum): return ExamplePoint(datum[0], datum[1]) - @staticmethod - def foo(): - pass - - @property - def props(self): - return {} - class ExamplePoint: """ @@ -110,10 +106,45 @@ def __str__(self): return "(%s,%s)" % (self.x, self.y) def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ + return isinstance(other, self.__class__) and \ other.x == self.x and other.y == self.y +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -401,6 +432,7 @@ def test_convert_row_to_dict(self): def test_udt(self): from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint def check_datatype(datatype): pickled = pickle.loads(pickle.dumps(datatype)) @@ -418,7 +450,17 @@ def check_datatype(datatype): _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _verify_type(ExamplePoint(1.0, 2.0), PythonOnlyUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_infer_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df = self.sqlCtx.createDataFrame([row]) schema = df.schema @@ -428,7 +470,17 @@ def test_infer_schema_with_udt(self): point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + 
self.assertEqual(type(field.dataType), PythonOnlyUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_apply_schema_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) @@ -436,7 +488,15 @@ def test_apply_schema_with_udt(self): point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_udf_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df = self.sqlCtx.createDataFrame([row]) self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) @@ -445,6 +505,14 @@ def test_udf_with_udt(self): udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + def test_parquet_with_udt(self): row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index aff07a0cee93..b25f3747afb3 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -585,7 +585,8 @@ def module(cls): @classmethod def scalaUDT(cls): """ - The class name of the paired Scala UDT. + The class name of the paired Scala UDT (could be '', if there + is no corresponding one). """ return '' @@ -678,7 +679,7 @@ def _parse_datatype_json_string(json_string): ... assert datatype == pickled ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... assert datatype == python_datatype, str(datatype) + str(python_datatype) + ... assert datatype == python_datatype >>> for cls in _all_atomic_types.values(): ... 
check_datatype(cls()) From 4dfd5e155ff220367fe9f7098b82209dde6acc47 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 29 Jul 2015 16:41:24 -0700 Subject: [PATCH 9/9] add tests for Python and Scala UDT --- python/pyspark/sql/tests.py | 12 ++++++++++-- python/pyspark/sql/types.py | 2 +- .../org/apache/spark/sql/test/ExamplePointUDT.scala | 2 ++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7a40b4515ee1..ebd3ea8db6a4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -456,7 +456,7 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = PythonOnlyPoint(1.0, 2.0) self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _verify_type(ExamplePoint(1.0, 2.0), PythonOnlyUDT()) + _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) def test_infer_schema_with_udt(self): @@ -514,14 +514,22 @@ def test_udf_with_udt(self): self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) def test_parquet_with_udt(self): + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) + df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b25f3747afb3..0976aea72c03 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -648,7 +648,7 @@ def jsonValue(self): @classmethod def fromJson(cls, json): - pyUDT = json["pyClass"] + pyUDT = str(json["pyClass"]) split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 2da6b55bc3d6..2fdd798b44bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -37,6 +37,8 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def sqlType: DataType = ArrayType(DoubleType, false) + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + override def serialize(obj: Any): Seq[Double] = { obj match { case p: ExamplePoint =>