Skip to content

Commit 573a2c9

Browse files
damnMeddlingKid authored and rxin committed
[SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns.
## What changes were proposed in this pull request?

This PR adds equality operators to UDT classes so that they can be correctly tested for dataType equality during union operations. This was previously causing `AnalysisException: u"unresolved operator 'Union;"` when trying to unionAll two DataFrames with UDT columns, as below.

```
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql import types

schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)])
a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema)
b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema)
c = a.unionAll(b)
```

## How was this patch tested?

Tested using two unit tests in sql/test.py and the DataFrameSuite. Additional information here: https://issues.apache.org/jira/browse/SPARK-13410

rxin

Author: Franklyn D'souza <[email protected]>

Closes #11333 from damnMeddlingKid/udt-union-patch.
1 parent 0784e02 commit 573a2c9

File tree

4 files changed

+50
-1
lines changed

4 files changed

+50
-1
lines changed

python/pyspark/sql/tests.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,24 @@ def test_parquet_with_udt(self):
601601
point = df1.head().point
602602
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
603603

604+
def test_unionAll_with_udt(self):
605+
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
606+
row1 = (1.0, ExamplePoint(1.0, 2.0))
607+
row2 = (2.0, ExamplePoint(3.0, 4.0))
608+
schema = StructType([StructField("label", DoubleType(), False),
609+
StructField("point", ExamplePointUDT(), False)])
610+
df1 = self.sqlCtx.createDataFrame([row1], schema)
611+
df2 = self.sqlCtx.createDataFrame([row2], schema)
612+
613+
result = df1.unionAll(df2).orderBy("label").collect()
614+
self.assertEqual(
615+
result,
616+
[
617+
Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
618+
Row(label=2.0, point=ExamplePoint(3.0, 4.0))
619+
]
620+
)
621+
604622
def test_column_operators(self):
605623
ci = self.df.key
606624
cs = self.df.value

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
8484

8585
override private[sql] def acceptsType(dataType: DataType) =
8686
this.getClass == dataType.getClass
87+
88+
override def equals(other: Any): Boolean = other match {
89+
case that: UserDefinedType[_] => this.acceptsType(that)
90+
case _ => false
91+
}
8792
}
8893

8994
/**
@@ -110,4 +115,9 @@ private[sql] class PythonUserDefinedType(
110115
("serializedClass" -> serializedPyClass) ~
111116
("sqlType" -> sqlType.jsonValue)
112117
}
118+
119+
override def equals(other: Any): Boolean = other match {
120+
case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
121+
case _ => false
122+
}
113123
}

sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
2626
* @param y y coordinate
2727
*/
2828
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
29-
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
29+
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {
30+
override def equals(other: Any): Boolean = other match {
31+
case that: ExamplePoint => this.x == that.x && this.y == that.y
32+
case _ => false
33+
}
34+
}
3035

3136
/**
3237
* User-defined type for [[ExamplePoint]].

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
516516
}
517517
}
518518

519+
test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
520+
val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0))))
521+
val schema1 = StructType(Array(StructField("label", IntegerType, false),
522+
StructField("point", new ExamplePointUDT(), false)))
523+
val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0))))
524+
val schema2 = StructType(Array(StructField("label", IntegerType, false),
525+
StructField("point", new ExamplePointUDT(), false)))
526+
val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
527+
val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
528+
529+
checkAnswer(
530+
df1.unionAll(df2).orderBy("label"),
531+
Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))
532+
)
533+
}
534+
519535
ignore("show") {
520536
// This test case is intentionally ignored, but is kept to make sure it compiles correctly
521537
testData.select($"*").show()

0 commit comments

Comments
 (0)