22 changes: 16 additions & 6 deletions python/pyspark/sql/types.py
@@ -483,7 +483,9 @@ def __init__(self, fields=None):
             self.names = [f.name for f in fields]
             assert all(isinstance(f, StructField) for f in fields),\
                 "fields should be a list of StructField"
-        self._needSerializeAnyField = any(f.needConversion() for f in self)
+        # Precalculated list of fields that need conversion with fromInternal/toInternal functions
+        self._needConversion = [f.needConversion() for f in self]
+        self._needSerializeAnyField = any(self._needConversion)
 
     def add(self, field, data_type=None, nullable=True, metadata=None):
         """
@@ -528,7 +530,9 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
                 data_type_f = data_type
             self.fields.append(StructField(field, data_type_f, nullable, metadata))
             self.names.append(field)
-        self._needSerializeAnyField = any(f.needConversion() for f in self)
+        # Precalculated list of fields that need conversion with fromInternal/toInternal functions
+        self._needConversion = [f.needConversion() for f in self]
+        self._needSerializeAnyField = any(self._needConversion)
         return self
 
     def __iter__(self):
@@ -590,13 +594,17 @@ def toInternal(self, obj):
             return
 
         if self._needSerializeAnyField:
+            # Only calling toInternal function for fields that need conversion
             if isinstance(obj, dict):
-                return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
+                return tuple(f.toInternal(obj.get(n)) if c else obj.get(n)
+                             for n, f, c in zip(self.names, self.fields, self._needConversion))
             elif isinstance(obj, (tuple, list)):
-                return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
+                return tuple(f.toInternal(v) if c else v
+                             for f, v, c in zip(self.fields, obj, self._needConversion))
             elif hasattr(obj, "__dict__"):
                 d = obj.__dict__
-                return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields))
+                return tuple(f.toInternal(d.get(n)) if c else d.get(n)
+                             for n, f, c in zip(self.names, self.fields, self._needConversion))
             else:
                 raise ValueError("Unexpected tuple %r with StructType" % obj)
         else:
@@ -619,7 +627,9 @@ def fromInternal(self, obj):
             # it's already converted by pickler
             return obj
         if self._needSerializeAnyField:
-            values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
+            # Only calling fromInternal function for fields that need conversion
+            values = [f.fromInternal(v) if c else v
+                      for f, v, c in zip(self.fields, obj, self._needConversion)]
         else:
             values = obj
         return _create_row(self.names, values)
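For reference, a minimal standalone sketch of what the precomputed flags buy. This example is not part of the patch: it reads the private _needConversion and _needSerializeAnyField attributes purely for illustration and assumes a PySpark build that includes this change.

# Illustrative only -- not part of the patch. Pokes at private attributes
# added above; field names and values are made up for the example.
import datetime

from pyspark.sql.types import LongType, StructField, StructType, TimestampType

schema = StructType([
    StructField("id", LongType(), False),
    StructField("clicks", LongType(), False),
    StructField("ts", TimestampType(), False),
])

# The per-field flags are now computed once per schema, not once per row:
print(schema._needConversion)         # [False, False, True]
print(schema._needSerializeAnyField)  # True

row = (1, 42, datetime.datetime(2015, 9, 22, 12, 0, 0))
# Only the timestamp field goes through toInternal (datetime -> microseconds;
# the exact value depends on the local timezone). The long fields pass through.
print(schema.toInternal(row))

Since needConversion() depends only on the schema, hoisting it into __init__/add() removes a per-field method call from every row serialized or deserialized through toInternal()/fromInternal(), which is the hot path for wide schemas.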