
Commit 72a572f

cloud-fan authored and HyukjinKwon committed
[SPARK-26323][SQL] Scala UDF should still check input types even if some inputs are of type Any
## What changes were proposed in this pull request?

For Scala UDFs, when checking input nullability, we skip inputs of type `Any` and only check the inputs that provide nullability info. We should do the same when checking input types.

## How was this patch tested?

New tests.

Closes #23275 from cloud-fan/udf.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 29a7d2d commit 72a572f

File tree

7 files changed (+175, -184 lines)
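Before the per-file diffs, a hedged illustration of the behaviour described in the commit message (not code from this patch; the session setup, column names and data are assumptions): a Scala UDF can mix a parameter whose Catalyst type is known with an `Any` parameter, and after this change the typed parameter is still type-checked and implicitly cast while the `Any` parameter is skipped.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object MixedInputUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-26323 sketch").getOrCreate()
    import spark.implicits._

    // ScalaReflection cannot derive a schema for `Any`, so the second input is treated as
    // AnyDataType; the first input still carries IntegerType and gets the usual type check.
    val tag = udf((i: Int, v: Any) => s"$i:$v")

    val df = Seq(("1", 1.5), ("2", 2.5)).toDF("a", "b")
    // "a" is a string column: it should be implicitly cast to int for the typed parameter,
    // while "b" is passed to the UDF untouched for the `Any` parameter.
    df.select(tag($"a", $"b")).show()
  }
}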

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 12 additions & 1 deletion
@@ -882,7 +882,18 @@ object TypeCoercion {

      case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
        val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
-          implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
+          // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works.
+          // In the future we should create types like `AbstractArrayType`, so that Scala UDF can
+          // accept inputs of array type of arbitrary element type.
+          if (expected == AnyDataType) {
+            in
+          } else {
+            implicitCast(
+              in,
+              udfInputToCastType(in.dataType, expected.asInstanceOf[DataType])
+            ).getOrElse(in)
+          }
        }
        udf.withNewChildren(children)
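The comment added in the hunk above notes that `AnyDataType` is only expected at the top level of a UDF input type. A minimal, hedged model of the per-input decision in plain Scala (`None` stands in for the `AnyDataType` wildcard; this is an illustration, not Spark's rule itself):

import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StringType}

object PerInputCoercionModel {
  // Stand-in for Spark's implicitCast: just report the type the child would be cast to.
  def coerceIfNeeded(actual: DataType, expected: Option[DataType]): DataType =
    expected match {
      case None                   => actual // AnyDataType wildcard: leave the input as-is
      case Some(t) if t == actual => actual // already the expected type, no cast needed
      case Some(t)                => t      // would wrap the child in a Cast to t
    }

  def main(args: Array[String]): Unit = {
    val actualTypes  = Seq(StringType, DoubleType)
    val expectedOpts = Seq(Some(IntegerType), None)
    println(actualTypes.zip(expectedOpts).map((coerceIfNeeded _).tupled))
    // List(IntegerType, DoubleType): the typed input is coerced, the `Any` input is untouched
  }
}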

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AbstractDataType, DataType}

/**
 * User-defined function.
@@ -48,7 +48,7 @@ case class ScalaUDF(
    dataType: DataType,
    children: Seq[Expression],
    inputsNullSafe: Seq[Boolean],
-    inputTypes: Seq[DataType] = Nil,
+    inputTypes: Seq[AbstractDataType] = Nil,
    udfName: Option[String] = None,
    nullable: Boolean = true,
    udfDeterministic: Boolean = true)
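A hedged note on the signature widening above: `DataType` extends `AbstractDataType` and `Seq` is covariant, so existing call sites that pass a concrete `Seq[DataType]` still compile; the wider element type merely also admits the `AnyDataType` wildcard. The object below is only an illustration.

import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, StringType}

object SignatureWideningNote {
  def takesAbstract(types: Seq[AbstractDataType]): Int = types.length

  def main(args: Array[String]): Unit = {
    val concrete: Seq[DataType] = Seq(IntegerType, StringType)
    println(takesAbstract(concrete)) // 2: Seq[DataType] <: Seq[AbstractDataType]
  }
}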

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ private[sql] object TypeCollection {
/**
 * An `AbstractDataType` that matches any concrete data types.
 */
-protected[sql] object AnyDataType extends AbstractDataType {
+protected[sql] object AnyDataType extends AbstractDataType with Serializable {

  // Note that since AnyDataType matches any concrete types, defaultConcreteType should never
  // be invoked.
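The diff does not say why `Serializable` is mixed in; a plausible reading (an assumption, not stated in the commit) is that `AnyDataType` can now sit inside a ScalaUDF's `inputTypes` and therefore travel with expressions that are serialized to executors. A tiny, self-contained demo of what the mixin changes for a plain Scala object (the object names here are made up):

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object SerializableMixinDemo {
  object PlainSingleton                              // a plain object is not java.io.Serializable
  object SerializableSingleton extends Serializable  // mirrors `AnyDataType ... with Serializable`

  def javaSerializes(o: AnyRef): Boolean =
    try { new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(o); true }
    catch { case _: java.io.NotSerializableException => false }

  def main(args: Array[String]): Unit = {
    println(javaSerializes(PlainSingleton))        // false
    println(javaSerializes(SerializableSingleton)) // true
  }
}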

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 96 additions & 120 deletions
Large diffs are not rendered by default.
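The UDFRegistration diff is not rendered above; it routes the `register` overloads through the same `SparkUserDefinedFunction` constructor. A hedged usage sketch (function name, SQL and data are assumptions) showing that a registered Scala UDF's declared input types still drive implicit casts on the SQL side:

import org.apache.spark.sql.SparkSession

object RegisterUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("register sketch").getOrCreate()

    // The UDF declares an Int parameter, so its argument is type-checked and coerced.
    spark.udf.register("plusOne", (i: Int) => i + 1)

    // The string literal should be implicitly cast to int to satisfy the declared input type.
    spark.sql("SELECT plusOne('41') AS answer").show()
  }
}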

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 24 additions & 33 deletions
@@ -20,8 +20,8 @@ package org.apache.spark.sql.expressions
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.types.{AnyDataType, DataType}

/**
 * A user-defined function. To create one, use the `udf` functions in `functions`.
@@ -88,68 +88,59 @@ sealed abstract class UserDefinedFunction {
private[sql] case class SparkUserDefinedFunction(
    f: AnyRef,
    dataType: DataType,
-    inputTypes: Option[Seq[DataType]],
-    nullableTypes: Option[Seq[Boolean]],
+    inputSchemas: Seq[Option[ScalaReflection.Schema]],
    name: Option[String] = None,
    nullable: Boolean = true,
    deterministic: Boolean = true) extends UserDefinedFunction {

  @scala.annotation.varargs
  override def apply(exprs: Column*): Column = {
-    // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
-    // and `nullableTypes` is always set.
-    if (inputTypes.isDefined) {
-      assert(inputTypes.get.length == nullableTypes.get.length)
-    }
+    Column(createScalaUDF(exprs.map(_.expr)))
+  }
+
+  private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = {
+    // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type
+    // check and null check for them.
+    val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))

-    val inputsNullSafe = nullableTypes.getOrElse {
+    val inputsNullSafe = if (inputSchemas.isEmpty) {
+      // This is for backward compatibility of `functions.udf(AnyRef, DataType)`. We need to
+      // do reflection of the lambda function object and see if its arguments are nullable or not.
+      // This doesn't work for Scala 2.12 and we should consider removing this workaround, as Spark
+      // uses Scala 2.12 by default since 3.0.
      ScalaReflection.getParameterTypeNullability(f)
+    } else {
+      inputSchemas.map(_.map(_.nullable).getOrElse(true))
    }

-    Column(ScalaUDF(
+    ScalaUDF(
      f,
      dataType,
-      exprs.map(_.expr),
+      exprs,
      inputsNullSafe,
-      inputTypes.getOrElse(Nil),
+      inputTypes,
      udfName = name,
      nullable = nullable,
-      udfDeterministic = deterministic))
+      udfDeterministic = deterministic)
  }

-  override def withName(name: String): UserDefinedFunction = {
+  override def withName(name: String): SparkUserDefinedFunction = {
    copy(name = Option(name))
  }

-  override def asNonNullable(): UserDefinedFunction = {
+  override def asNonNullable(): SparkUserDefinedFunction = {
    if (!nullable) {
      this
    } else {
      copy(nullable = false)
    }
  }

-  override def asNondeterministic(): UserDefinedFunction = {
+  override def asNondeterministic(): SparkUserDefinedFunction = {
    if (!deterministic) {
      this
    } else {
      copy(deterministic = false)
    }
  }
}
-
-private[sql] object SparkUserDefinedFunction {
-
-  def create(
-      f: AnyRef,
-      dataType: DataType,
-      inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
-    val inputTypes = if (inputSchemas.contains(None)) {
-      None
-    } else {
-      Some(inputSchemas.map(_.get.dataType))
-    }
-    val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
-    SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes)
-  }
-}
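A self-contained, hedged sketch of the mapping the new `createScalaUDF` performs, using stand-ins for the catalyst-internal `ScalaReflection.Schema` and `AnyDataType` (the stand-in names are not Spark APIs): a missing schema contributes the wildcard for the type check and `nullable = true` for the null check.

import org.apache.spark.sql.types.{DataType, IntegerType}

object InputSchemaMappingSketch {
  final case class Schema(dataType: DataType, nullable: Boolean) // stand-in for ScalaReflection.Schema
  case object AnyDataTypeStandIn                                 // stand-in for AnyDataType

  def inputTypes(schemas: Seq[Option[Schema]]): Seq[Any] =
    schemas.map(_.map(_.dataType).getOrElse(AnyDataTypeStandIn))

  def inputsNullSafe(schemas: Seq[Option[Schema]]): Seq[Boolean] =
    schemas.map(_.map(_.nullable).getOrElse(true))

  def main(args: Array[String]): Unit = {
    val schemas = Seq(Some(Schema(IntegerType, nullable = false)), None)
    println(inputTypes(schemas))     // List(IntegerType, AnyDataTypeStandIn)
    println(inputsNullSafe(schemas)) // List(false, true)
  }
}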

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 25 additions & 27 deletions
@@ -3874,7 +3874,7 @@ object functions {
    |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
    | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    | val inputSchemas = $inputSchemas
-    | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    | val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    | if (nullable) udf else udf.asNonNullable()
    |}""".stripMargin)
  }
@@ -3897,7 +3897,7 @@ object functions {
    | */
    |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
    | val func = f$anyCast.call($anyParams)
-    | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
+    | SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
    |}""".stripMargin)
  }

@@ -3919,7 +3919,7 @@ object functions {
  def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -3935,7 +3935,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -3951,7 +3951,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -3967,7 +3967,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -3983,7 +3983,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -3999,7 +3999,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -4015,7 +4015,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -4031,7 +4031,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -4047,7 +4047,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -4063,7 +4063,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -4079,7 +4079,7 @@ object functions {
  def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
    val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
    if (nullable) udf else udf.asNonNullable()
  }

@@ -4098,7 +4098,7 @@ object functions {
   */
  def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF0[Any]].call()
-    SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None))
+    SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
  }

  /**
@@ -4112,7 +4112,7 @@ object functions {
   */
  def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(1)(None))
  }

  /**
@@ -4126,7 +4126,7 @@ object functions {
   */
  def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(2)(None))
  }

  /**
@@ -4140,7 +4140,7 @@ object functions {
   */
  def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(3)(None))
  }

  /**
@@ -4154,7 +4154,7 @@ object functions {
   */
  def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(4)(None))
  }

  /**
@@ -4168,7 +4168,7 @@ object functions {
   */
  def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(5)(None))
  }

  /**
@@ -4182,7 +4182,7 @@ object functions {
   */
  def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(6)(None))
  }

  /**
@@ -4196,7 +4196,7 @@ object functions {
   */
  def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(7)(None))
  }

  /**
@@ -4210,7 +4210,7 @@ object functions {
   */
  def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(8)(None))
  }

  /**
@@ -4224,7 +4224,7 @@ object functions {
   */
  def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(9)(None))
  }

  /**
@@ -4238,7 +4238,7 @@ object functions {
   */
  def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
    val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(10)(None))
  }

  // scalastyle:on parameter.number
@@ -4257,9 +4257,7 @@ object functions {
   * @since 2.0.0
   */
  def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
-    // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently
-    // unavailable. We may need to create type-safe overloaded versions of udf() methods.
-    SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None)
+    SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
  }

  /**
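The final hunk touches the untyped `udf(f: AnyRef, dataType: DataType)` overload, which now passes `inputSchemas = Nil`. A hedged usage sketch of that overload (data and names are assumptions, and newer Spark versions may restrict this untyped form): only the return type is declared, so no input is type-checked or coerced and input nullability falls back to reflection over the function object.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

object UntypedUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("untyped udf sketch").getOrCreate()
    import spark.implicits._

    // Only the return type is known to Spark; the Int parameter is invisible to the analyzer.
    val plusOne = udf((i: Int) => i + 1, IntegerType)

    val df = Seq(1, 2, 3).toDF("n")
    df.select(plusOne($"n")).show()
  }
}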
