From 6c2909a7ed19d3f77c21a464e41f4160b822cfe0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 26 Feb 2015 23:39:59 -0800 Subject: [PATCH 1/4] fix incorrect DataType.__eq__ --- python/pyspark/sql.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index aa5af1bd4049..2867d3d61250 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -36,6 +36,7 @@ import warnings import json import re +import weakref from array import array from operator import itemgetter from itertools import imap @@ -68,8 +69,7 @@ def __hash__(self): return hash(str(self)) def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) + return isinstance(other, self.__class__) and self.jsonValue() == other.jsonValue() def __ne__(self, other): return not self.__eq__(other) @@ -781,8 +781,25 @@ def _merge_type(a, b): return a +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ + + if not _need_converter(dataType): + return lambda x: x + if isinstance(dataType, ArrayType): conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) @@ -800,6 +817,7 @@ def _create_converter(dataType): # dataType must be StructType names = [f.name for f in dataType.fields] converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) def convert_struct(obj): if obj is None: @@ -822,7 +840,10 @@ def convert_struct(obj): else: raise ValueError("Unexpected obj: %s" % obj) - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) return convert_struct @@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType): _verify_type(v, f.dataType) -_cached_cls = {} +_cached_cls = weakref.WeakValueDictionary() def _restore_object(dataType, obj): From b57610758620a9230cb1c32a6586193821c5eeb0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 27 Feb 2015 08:40:13 -0800 Subject: [PATCH 2/4] fix tests --- python/pyspark/sql.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 2867d3d61250..b43003b9f924 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -251,9 +251,9 @@ def __init__(self, elementType, containsNull=True): :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. - >>> ArrayType(StringType) == ArrayType(StringType, True) + >>> ArrayType(StringType()) == ArrayType(StringType(), True) True - >>> ArrayType(StringType, False) == ArrayType(StringType) + >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ self.elementType = elementType @@ -298,11 +298,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True): :param valueContainsNull: indicates whether values contains null values. - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) False """ self.keyType = keyType @@ -351,11 +351,11 @@ def __init__(self, name, dataType, nullable=True, metadata=None): to simple type that can be serialized to JSON automatically - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) False """ self.name = name @@ -393,13 +393,13 @@ class StructType(DataType): def __init__(self, fields): """Creates a StructType - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) >>> struct1 == struct2 False """ From 9b4dadc04ade3564c6b4b483ad5775112fdbe12e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 27 Feb 2015 12:27:35 -0800 Subject: [PATCH 3/4] fix __eq__ of singleton --- python/pyspark/sql.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index b43003b9f924..6e8e2b404efb 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -69,7 +69,7 @@ def __hash__(self): return hash(str(self)) def __eq__(self, other): - return isinstance(other, self.__class__) and self.jsonValue() == other.jsonValue() + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -105,10 +105,6 @@ class PrimitiveType(DataType): __metaclass__ = PrimitiveTypeSingleton - def __eq__(self, other): - # because they should be the same object - return self is other - class NullType(PrimitiveType): @@ -499,6 +495,9 @@ def __eq__(self, other): def _parse_datatype_json_string(json_string): """Parses the given data type JSON string. + + >>> import pickle + >>> LongType() == pickle.loads(pickle.dumps(LongType())) >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) From 65c222f9bcf3a27ec80594faac1e520e4c4e2d3b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 27 Feb 2015 13:43:02 -0800 Subject: [PATCH 4/4] Update sql.py --- python/pyspark/sql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 6e8e2b404efb..4410925ba0f8 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -498,6 +498,7 @@ def _parse_datatype_json_string(json_string): >>> import pickle >>> LongType() == pickle.loads(pickle.dumps(LongType())) + True >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json())