Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import warnings
import json
import re
import weakref
from array import array
from operator import itemgetter
from itertools import imap
Expand Down Expand Up @@ -68,8 +69,7 @@ def __hash__(self):
return hash(str(self))

def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

def __ne__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -105,10 +105,6 @@ class PrimitiveType(DataType):

__metaclass__ = PrimitiveTypeSingleton

def __eq__(self, other):
# because they should be the same object
return self is other


class NullType(PrimitiveType):

Expand Down Expand Up @@ -251,9 +247,9 @@ def __init__(self, elementType, containsNull=True):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.

>>> ArrayType(StringType) == ArrayType(StringType, True)
>>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
>>> ArrayType(StringType, False) == ArrayType(StringType)
>>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
self.elementType = elementType
Expand Down Expand Up @@ -298,11 +294,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
:param valueContainsNull: indicates whether values contains
null values.

>>> (MapType(StringType, IntegerType)
... == MapType(StringType, IntegerType, True))
>>> (MapType(StringType(), IntegerType())
... == MapType(StringType(), IntegerType(), True))
True
>>> (MapType(StringType, IntegerType, False)
... == MapType(StringType, FloatType))
>>> (MapType(StringType(), IntegerType(), False)
... == MapType(StringType(), FloatType()))
False
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the PR opened against master added typechecking asserts here. Should we also add them in this branch, or is there a reason why we should omit them here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we cannot cherry-pick the patch from master, I need to redo all of this work on the 1.2/1.1 branches, so I'd like to keep the changes minimal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough; just wanted to check.

self.keyType = keyType
Expand Down Expand Up @@ -351,11 +347,11 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
to simple type that can be serialized to JSON
automatically

>>> (StructField("f1", StringType, True)
... == StructField("f1", StringType, True))
>>> (StructField("f1", StringType(), True)
... == StructField("f1", StringType(), True))
True
>>> (StructField("f1", StringType, True)
... == StructField("f2", StringType, True))
>>> (StructField("f1", StringType(), True)
... == StructField("f2", StringType(), True))
False
"""
self.name = name
Expand Down Expand Up @@ -393,13 +389,13 @@ class StructType(DataType):
def __init__(self, fields):
"""Creates a StructType

>>> struct1 = StructType([StructField("f1", StringType, True)])
>>> struct2 = StructType([StructField("f1", StringType, True)])
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
>>> struct1 = StructType([StructField("f1", StringType, True)])
>>> struct2 = StructType([StructField("f1", StringType, True),
... [StructField("f2", IntegerType, False)]])
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct2 = StructType([StructField("f1", StringType(), True),
... StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
"""
Expand Down Expand Up @@ -499,6 +495,10 @@ def __eq__(self, other):

def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.

>>> import pickle
>>> LongType() == pickle.loads(pickle.dumps(LongType()))
True
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
Expand Down Expand Up @@ -781,8 +781,25 @@ def _merge_type(a, b):
return a


def _need_converter(dataType):
    """Return True if values of ``dataType`` need a converter.

    Struct and null types always require conversion; container types
    (arrays and maps) require it exactly when their element/key/value
    types do.
    """
    if isinstance(dataType, (StructType, NullType)):
        return True
    if isinstance(dataType, ArrayType):
        return _need_converter(dataType.elementType)
    if isinstance(dataType, MapType):
        # A map needs converting if either side of the mapping does.
        return any(_need_converter(t) for t in (dataType.keyType, dataType.valueType))
    return False


def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """

if not _need_converter(dataType):
return lambda x: x

if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
Expand All @@ -800,6 +817,7 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)

def convert_struct(obj):
if obj is None:
Expand All @@ -822,7 +840,10 @@ def convert_struct(obj):
else:
raise ValueError("Unexpected obj: %s" % obj)

return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
if convert_fields:
return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
else:
return tuple([d.get(name) for name in names])

return convert_struct

Expand Down Expand Up @@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType):
_verify_type(v, f.dataType)


_cached_cls = {}
_cached_cls = weakref.WeakValueDictionary()


def _restore_object(dataType, obj):
Expand Down