Skip to content

Commit 4a22bce

Browse files
Cheolsoo Parkrxin
authored andcommitted
[SPARK-8572] [SQL] Type coercion for ScalaUDFs
Implemented type coercion for udf arguments in Scala. The changes include- * Add `with ExpectsInputTypes ` to `ScalaUDF` class. * Pass down argument types info from `UDFRegistration` and `functions`. With this patch, the example query in [SPARK-8572](https://issues.apache.org/jira/browse/SPARK-8572) no longer throws a type cast error at runtime. Also added a unit test to `UDFSuite` in which a decimal type is passed to a udf that expects an int. Author: Cheolsoo Park <[email protected]> Closes #7203 from piaozhexiu/SPARK-8572 and squashes the following commits: 2d0ed15 [Cheolsoo Park] Incorporate comments dce1efd [Cheolsoo Park] Fix unit tests and update the codegen script 066deed [Cheolsoo Park] Type coercion for udf inputs
1 parent e92c24d commit 4a22bce

File tree

6 files changed

+93
-42
lines changed

6 files changed

+93
-42
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ object HiveTypeCoercion {
680680
// Skip nodes who's children have not been resolved yet.
681681
case e if !e.childrenResolved => e
682682

683-
case e: ExpectsInputTypes =>
683+
case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) =>
684684
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
685685
// If we cannot do the implicit cast, just use the original input.
686686
implicitCast(in, expected).getOrElse(in)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@ import org.apache.spark.sql.types.DataType
2424
* User-defined function.
2525
* @param dataType Return type of function.
2626
*/
27-
case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression])
28-
extends Expression {
27+
case class ScalaUDF(
28+
function: AnyRef,
29+
dataType: DataType,
30+
children: Seq[Expression],
31+
inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes {
2932

3033
override def nullable: Boolean = true
3134

0 commit comments

Comments
 (0)