@@ -20,7 +20,8 @@ package org.apache.spark.sql
2020import scala .beans .{BeanInfo , BeanProperty }
2121
2222import org .apache .spark .rdd .RDD
23- import org .apache .spark .sql .catalyst .CatalystTypeConverters
23+ import org .apache .spark .sql .catalyst .{CatalystTypeConverters , InternalRow }
24+ import org .apache .spark .sql .catalyst .expressions .GenericInternalRow
2425import org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData }
2526import org .apache .spark .sql .execution .datasources .parquet .ParquetTest
2627import org .apache .spark .sql .functions ._
@@ -71,6 +72,77 @@ object UDT {
7172
7273}
7374
75+ // object and classes to test SPARK-19311
76+
77+ // Trait/Interface for base type
78+ sealed trait IExampleBaseType extends Serializable {
79+ def field : Int
80+ }
81+
82+ // Trait/Interface for derived type
83+ sealed trait IExampleSubType extends IExampleBaseType
84+
85+ // a base class
86+ class ExampleBaseClass (override val field : Int ) extends IExampleBaseType
87+
88+ // a derived class
89+ class ExampleSubClass (override val field : Int )
90+ extends ExampleBaseClass (field) with IExampleSubType
91+
92+ // UDT for base class
93+ class ExampleBaseTypeUDT extends UserDefinedType [IExampleBaseType ] {
94+
95+ override def sqlType : StructType = {
96+ StructType (Seq (
97+ StructField (" intfield" , IntegerType , nullable = false )))
98+ }
99+
100+ override def serialize (obj : IExampleBaseType ): InternalRow = {
101+ val row = new GenericInternalRow (1 )
102+ row.setInt(0 , obj.field)
103+ row
104+ }
105+
106+ override def deserialize (datum : Any ): IExampleBaseType = {
107+ datum match {
108+ case row : InternalRow =>
109+ require(row.numFields == 1 ,
110+ " ExampleBaseTypeUDT requires row with length == 1" )
111+ val field = row.getInt(0 )
112+ new ExampleBaseClass (field)
113+ }
114+ }
115+
116+ override def userClass : Class [IExampleBaseType ] = classOf [IExampleBaseType ]
117+ }
118+
119+ // UDT for derived class
120+ private [spark] class ExampleSubTypeUDT extends UserDefinedType [IExampleSubType ] {
121+
122+ override def sqlType : StructType = {
123+ StructType (Seq (
124+ StructField (" intfield" , IntegerType , nullable = false )))
125+ }
126+
127+ override def serialize (obj : IExampleSubType ): InternalRow = {
128+ val row = new GenericInternalRow (1 )
129+ row.setInt(0 , obj.field)
130+ row
131+ }
132+
133+ override def deserialize (datum : Any ): IExampleSubType = {
134+ datum match {
135+ case row : InternalRow =>
136+ require(row.numFields == 1 ,
137+ " ExampleSubTypeUDT requires row with length == 1" )
138+ val field = row.getInt(0 )
139+ new ExampleSubClass (field)
140+ }
141+ }
142+
143+ override def userClass : Class [IExampleSubType ] = classOf [IExampleSubType ]
144+ }
145+
74146class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
75147 import testImplicits ._
76148
@@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
194266 // call `collect` to make sure this query can pass analysis.
195267 pointsRDD.as[MyLabeledPoint ].map(_.copy(label = 2.0 )).collect()
196268 }
269+
270+ test(" SPARK-19311: UDFs disregard UDT type hierarchy" ) {
271+ UDTRegistration .register(classOf [IExampleBaseType ].getName,
272+ classOf [ExampleBaseTypeUDT ].getName)
273+ UDTRegistration .register(classOf [IExampleSubType ].getName,
274+ classOf [ExampleSubTypeUDT ].getName)
275+
276+ // UDF that returns a base class object
277+ sqlContext.udf.register(" doUDF" , (param : Int ) => {
278+ new ExampleBaseClass (param)
279+ }: IExampleBaseType )
280+
281+ // UDF that returns a derived class object
282+ sqlContext.udf.register(" doSubTypeUDF" , (param : Int ) => {
283+ new ExampleSubClass (param)
284+ }: IExampleSubType )
285+
286+ // UDF that takes a base class object as parameter
287+ sqlContext.udf.register(" doOtherUDF" , (obj : IExampleBaseType ) => {
288+ obj.field
289+ }: Int )
290+
291+ // this worked already before the fix SPARK-19311:
292+ // return type of doUDF equals parameter type of doOtherUDF
293+ sql(" SELECT doOtherUDF(doUDF(41))" )
294+
295+ // this one passes only with the fix SPARK-19311:
296+ // return type of doSubUDF is a subtype of the parameter type of doOtherUDF
297+ sql(" SELECT doOtherUDF(doSubTypeUDF(42))" )
298+ }
299+
197300}
0 commit comments