Skip to content

Commit 1307432

Browse files
gmoehler and cmonkey
authored and committed
[SPARK-19311][SQL] fix UDT hierarchy issue
## What changes were proposed in this pull request? acceptsType() in UDT will not only accept the same type but also all derived types. ## How was this patch tested? Manual test using a set of generated UDTs, fixing acceptsType() in my user-defined types. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: gmoehler <[email protected]> Closes apache#16660 from gmoehler/master.
1 parent d1b17f3 commit 1307432

File tree

2 files changed

+110
-3
lines changed

2 files changed

+110
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala

Lines changed: 6 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
7878
*/
7979
override private[spark] def asNullable: UserDefinedType[UserType] = this
8080

/**
 * Returns true when `dataType` can be written where this UDT is expected:
 * either the exact same UDT implementation class, or a UDT whose user class
 * is a subtype of this UDT's user class (SPARK-19311 type-hierarchy support).
 * Non-UDT types are never accepted.
 */
override private[sql] def acceptsType(dataType: DataType): Boolean = dataType match {
  case other: UserDefinedType[_] =>
    // Same UDT class, or the other UDT wraps a subclass of our user class
    // (isAssignableFrom checks `other.userClass` is this.userClass or a subtype).
    this.getClass == other.getClass ||
      this.userClass.isAssignableFrom(other.userClass)
  case _ => false
}
8387

8488
override def sql: String = sqlType.sql
8589

sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala

Lines changed: 104 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -20,7 +20,8 @@ package org.apache.spark.sql
2020
import scala.beans.{BeanInfo, BeanProperty}
2121

2222
import org.apache.spark.rdd.RDD
23-
import org.apache.spark.sql.catalyst.CatalystTypeConverters
23+
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
24+
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
2425
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
2526
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
2627
import org.apache.spark.sql.functions._
@@ -71,6 +72,77 @@ object UDT {
7172

7273
}
7374

75+
// Objects and classes exercising SPARK-19311 (UDT type hierarchies).

// Trait/interface for the base type: every instance exposes a single int field.
sealed trait IExampleBaseType extends Serializable {
  def field: Int
}
81+
82+
// Trait/interface for the derived type; inherits `field` from the base trait.
sealed trait IExampleSubType extends IExampleBaseType
84+
85+
// A concrete base class carrying the int field required by the base trait.
class ExampleBaseClass(override val field: Int) extends IExampleBaseType
87+
88+
// A derived class: extends the base class and also implements the sub-type trait.
class ExampleSubClass(override val field: Int)
  extends ExampleBaseClass(field) with IExampleSubType
91+
92+
// UDT for the base type: maps IExampleBaseType to a single non-nullable int column.
class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {

  // Catalyst representation: one int field.
  override def sqlType: StructType =
    StructType(StructField("intfield", IntegerType, nullable = false) :: Nil)

  // Pack the user object's field into a one-column internal row.
  override def serialize(obj: IExampleBaseType): InternalRow = {
    val result = new GenericInternalRow(1)
    result.setInt(0, obj.field)
    result
  }

  // Unpack a one-column internal row back into a base-class instance.
  override def deserialize(datum: Any): IExampleBaseType = datum match {
    case row: InternalRow =>
      require(row.numFields == 1,
        "ExampleBaseTypeUDT requires row with length == 1")
      new ExampleBaseClass(row.getInt(0))
  }

  override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
}
118+
119+
// UDT for the derived type; mirrors ExampleBaseTypeUDT but yields ExampleSubClass.
private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {

  // Catalyst representation: one int field.
  override def sqlType: StructType =
    StructType(StructField("intfield", IntegerType, nullable = false) :: Nil)

  // Pack the user object's field into a one-column internal row.
  override def serialize(obj: IExampleSubType): InternalRow = {
    val result = new GenericInternalRow(1)
    result.setInt(0, obj.field)
    result
  }

  // Unpack a one-column internal row back into a sub-class instance.
  override def deserialize(datum: Any): IExampleSubType = datum match {
    case row: InternalRow =>
      require(row.numFields == 1,
        "ExampleSubTypeUDT requires row with length == 1")
      new ExampleSubClass(row.getInt(0))
  }

  override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
}
145+
74146
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
75147
import testImplicits._
76148

@@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
194266
// call `collect` to make sure this query can pass analysis.
195267
pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect()
196268
}
269+
270+
test("SPARK-19311: UDFs disregard UDT type hierarchy") {
  // Register UDTs for both the base interface and the derived interface.
  UDTRegistration.register(classOf[IExampleBaseType].getName,
    classOf[ExampleBaseTypeUDT].getName)
  UDTRegistration.register(classOf[IExampleSubType].getName,
    classOf[ExampleSubTypeUDT].getName)

  // UDF that returns a base class object.
  sqlContext.udf.register("doUDF", (param: Int) => {
    new ExampleBaseClass(param)
  }: IExampleBaseType)

  // UDF that returns a derived class object.
  sqlContext.udf.register("doSubTypeUDF", (param: Int) => {
    new ExampleSubClass(param)
  }: IExampleSubType)

  // UDF that takes a base class object as parameter.
  sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => {
    obj.field
  }: Int)

  // This worked already before the fix for SPARK-19311:
  // the return type of doUDF equals the parameter type of doOtherUDF.
  sql("SELECT doOtherUDF(doUDF(41))")

  // This one passes only with the fix for SPARK-19311: the return type of
  // doSubTypeUDF is a subtype of the parameter type of doOtherUDF.
  sql("SELECT doOtherUDF(doSubTypeUDF(42))")
}
299+
197300
}

0 commit comments

Comments
 (0)