Skip to content

Commit 83d65ac

Browse files
author
Davies Liu
committed
fix bug in StructType
1 parent 55bb86e commit 83d65ac

File tree

4 files changed: 22 additions and 19 deletions

4 files changed

+22
-19
lines changed

python/pyspark/sql/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio):
295295
if isinstance(schema, (list, tuple)):
296296
for i, name in enumerate(schema):
297297
struct.fields[i].name = name
298+
struct.names[i] = name
298299
schema = struct
299300

300301
elif isinstance(schema, StructType):
@@ -325,6 +326,7 @@ def _createFromLocal(self, data, schema):
325326
if isinstance(schema, (list, tuple)):
326327
for i, name in enumerate(schema):
327328
struct.fields[i].name = name
329+
struct.names[i] = name
328330
schema = struct
329331

330332
elif isinstance(schema, StructType):

python/pyspark/sql/types.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def __init__(self, fields=None):
455455
self.names = [f.name for f in fields]
456456
assert all(isinstance(f, StructField) for f in fields),\
457457
"fields should be a list of StructField"
458-
self._needSerializeFields = None
458+
self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
459459

460460
def add(self, field, data_type=None, nullable=True, metadata=None):
461461
"""
@@ -498,6 +498,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
498498
data_type_f = data_type
499499
self.fields.append(StructField(field, data_type_f, nullable, metadata))
500500
self.names.append(field)
501+
self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
501502
return self
502503

503504
def simpleString(self):
@@ -523,12 +524,9 @@ def toInternal(self, obj):
523524
if obj is None:
524525
return
525526

526-
if self._needSerializeFields is None:
527-
self._needSerializeFields = any(f.needConversion() for f in self.fields)
528-
529-
if self._needSerializeFields:
527+
if self._needSerializeAnyField:
530528
if isinstance(obj, dict):
531-
return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields))
529+
return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
532530
elif isinstance(obj, (tuple, list)):
533531
return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
534532
else:
@@ -547,7 +545,10 @@ def fromInternal(self, obj):
547545
if isinstance(obj, Row):
548546
# it's already converted by pickler
549547
return obj
550-
values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
548+
if self._needSerializeAnyField:
549+
values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
550+
else:
551+
values = obj
551552
return _create_row(self.names, values)
552553

553554

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,16 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
8989
*
9090
* Note: This can only be accessed via Python UDF, or accessed as serialized object.
9191
*/
92-
class PythonUserDefinedType(val sqlType: DataType, pyClass: String) extends UserDefinedType[Any] {
93-
override def pyUDT = pyClass
94-
def serialize(obj: Any): Any = obj
95-
def deserialize(datam: Any): Any = datam
96-
override private[sql] def jsonValue: JValue = {
97-
("type" -> "udt") ~
98-
("class" -> "") ~
99-
("pyClass" -> pyUDT) ~
100-
("sqlType" -> sqlType.jsonValue)
101-
}
102-
def userClass: java.lang.Class[Any] = null
92+
private[sql] class PythonUserDefinedType(val sqlType: DataType, pyClass: String)
93+
extends UserDefinedType[Any] {
94+
95+
/* The Python UDT class */
96+
override def pyUDT: String = pyClass
97+
98+
/* The serialization is handled by UDT class in Python */
99+
override def serialize(obj: Any): Any = obj
100+
override def deserialize(datam: Any): Any = datam
101+
102+
/* There is no Java class for Python UDT */
103+
override def userClass: java.lang.Class[Any] = null
103104
}

sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ object EvaluatePython {
271271
pickler.save(row.values(i))
272272
i += 1
273273
}
274-
row.values.foreach(pickler.save)
275274
out.write(Opcodes.TUPLE)
276275
out.write(Opcodes.REDUCE)
277276
}

0 commit comments

Comments
 (0)