diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 9e4da36b67ee6..98aba7b78fd4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1409,9 +1409,14 @@ class SessionCatalog(
         Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
       if (clsForUDAF.isAssignableFrom(clazz)) {
         val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
-        val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
-          .newInstance(input,
-            clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
+        val e = cls.getConstructor(
+          classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int], classOf[Option[String]])
+          .newInstance(
+            input,
+            clazz.getConstructor().newInstance().asInstanceOf[Object],
+            Int.box(1),
+            Int.box(1),
+            Some(name))
           .asInstanceOf[ImplicitCastInputTypes]
 
         // Check input argument size
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index fa9b7d2c4252d..83f77a6abd490 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -1088,4 +1088,6 @@ trait ComplexTypeMergingExpression extends Expression {
  * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
  * and Hive function wrappers.
 */
-trait UserDefinedExpression
+trait UserDefinedExpression {
+  def name: String
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 011a3503dd0c2..4086e7698e7b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -57,7 +57,9 @@ case class ScalaUDF(
 
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
-  override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
+  override def toString: String = s"$name(${children.mkString(", ")})"
+
+  override def name: String = udfName.getOrElse("UDF")
 
   override lazy val canonicalized: Expression = {
     // SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 7567ed63359a9..4c165680d428b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -82,7 +82,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   @deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
     " via the functions.udaf(agg) method.", "3.0.0")
   def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
-    def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
+    def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf, udafName = Some(name))
     functionRegistry.createOrReplaceTempFunction(name, builder)
     udaf
   }
@@ -109,15 +109,15 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 2.2.0
    */
   def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
-    udf match {
+    udf.withName(name) match {
       case udaf: UserDefinedAggregator[_, _, _] =>
         def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
         functionRegistry.createOrReplaceTempFunction(name, builder)
-        udf
-      case _ =>
-        def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
+        udaf
+      case other =>
+        def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr
         functionRegistry.createOrReplaceTempFunction(name, builder)
-        udf
+        other
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 41e247a02759b..e6851a9af739f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -325,7 +325,8 @@ case class ScalaUDAF(
     children: Seq[Expression],
     udaf: UserDefinedAggregateFunction,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
+    inputAggBufferOffset: Int = 0,
+    udafName: Option[String] = None)
   extends ImperativeAggregate
   with NonSQLExpression
   with Logging
@@ -447,10 +448,12 @@ case class ScalaUDAF(
   }
 
   override def toString: String = {
-    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+    s"""$nodeName(${children.mkString(",")})"""
   }
 
-  override def nodeName: String = udaf.getClass.getSimpleName
+  override def nodeName: String = name
+
+  override def name: String = udafName.getOrElse(udaf.getClass.getSimpleName)
 }
 
 case class ScalaAggregator[IN, BUF, OUT](
@@ -461,7 +464,8 @@ case class ScalaAggregator[IN, BUF, OUT](
     nullable: Boolean = true,
     isDeterministic: Boolean = true,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
+    inputAggBufferOffset: Int = 0,
+    aggregatorName: Option[String] = None)
   extends TypedImperativeAggregate[BUF]
   with NonSQLExpression
   with UserDefinedExpression
@@ -513,7 +517,9 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   override def toString: String = s"""${nodeName}(${children.mkString(",")})"""
 
-  override def nodeName: String = agg.getClass.getSimpleName
+  override def nodeName: String = name
+
+  override def name: String = aggregatorName.getOrElse(agg.getClass.getSimpleName)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 80dd3cf8bc840..03dc9abf081fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
   def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
     val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
     val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
-    ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
+    ScalaAggregator(
+      exprs, aggregator, iEncoder, bEncoder, nullable, deterministic, aggregatorName = name)
   }
 
   override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 861a001b190aa..a090eba430061 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -321,26 +321,34 @@ object IntegratedUDFTestUtils extends SQLHelper {
   *   casted_col.cast(df.schema("col").dataType)
   * }}}
   */
-  case class TestScalaUDF(name: String) extends TestUDF {
-    private[IntegratedUDFTestUtils] lazy val udf = new SparkUserDefinedFunction(
-      (input: Any) => if (input == null) {
-        null
-      } else {
-        input.toString
-      },
-      StringType,
-      inputEncoders = Seq.fill(1)(None),
-      name = Some(name)) {
-
-      override def apply(exprs: Column*): Column = {
-        assert(exprs.length == 1, "Defined UDF only has one column")
-        val expr = exprs.head.expr
-        assert(expr.resolved, "column should be resolved to use the same type " +
-          "as input. Try df(name) or df.col(name)")
-        Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
-      }
+  class TestInternalScalaUDF(name: String) extends SparkUserDefinedFunction(
+    (input: Any) => if (input == null) {
+      null
+    } else {
+      input.toString
+    },
+    StringType,
+    inputEncoders = Seq.fill(1)(None),
+    name = Some(name)) {
+
+    override def apply(exprs: Column*): Column = {
+      assert(exprs.length == 1, "Defined UDF only has one column")
+      val expr = exprs.head.expr
+      assert(expr.resolved, "column should be resolved to use the same type " +
+        "as input. Try df(name) or df.col(name)")
+      Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
     }
 
+    override def withName(name: String): TestInternalScalaUDF = {
+      // "withName" should be overridden to return TestInternalScalaUDF. Otherwise, the
+      // current object is sliced and the overridden "apply" is not invoked.
+      new TestInternalScalaUDF(name)
+    }
+  }
+
+  case class TestScalaUDF(name: String) extends TestUDF {
+    private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name)
+
     def apply(exprs: Column*): Column = udf(exprs: _*)
 
     val prettyName: String = "Scala UDF"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index e9b99ad002a66..7d3faaef2cd47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -26,15 +26,18 @@ import scala.collection.mutable.{ArrayBuffer, WrappedArray}
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.api.java._
-import org.apache.spark.sql.catalyst.encoders.OuterScopes
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
+import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
-import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.functions.{lit, struct, udf}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, SparkUserDefinedFunction, UserDefinedAggregateFunction}
+import org.apache.spark.sql.functions.{lit, struct, udaf, udf}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData._
@@ -798,4 +801,47 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .select(myUdf(Column("col"))),
       Row(ArrayBuffer(100)))
   }
+
+  test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
+    spark.udf.register("udf34388", udf((value: Int) => value > 2))
+    spark.sessionState.catalog.lookupFunction(
+      FunctionIdentifier("udf34388"), Seq(Literal(1))) match {
+      case udf: ScalaUDF => assert(udf.name === "udf34388")
+    }
+  }
+
+  test("SPARK-34388: UDF name is propagated with registration for ScalaAggregator") {
+    val agg = new Aggregator[Long, Long, Long] {
+      override def zero: Long = 0L
+      override def reduce(b: Long, a: Long): Long = a + b
+      override def merge(b1: Long, b2: Long): Long = b1 + b2
+      override def finish(reduction: Long): Long = reduction
+      override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+      override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+    }
+
+    spark.udf.register("agg34388", udaf(agg))
+    spark.sessionState.catalog.lookupFunction(
+      FunctionIdentifier("agg34388"), Seq(Literal(1))) match {
+      case agg: ScalaAggregator[_, _, _] => assert(agg.name === "agg34388")
+    }
+  }
+
+  test("SPARK-34388: UDF name is propagated with registration for ScalaUDAF") {
+    val udaf = new UserDefinedAggregateFunction {
+      def inputSchema: StructType = new StructType().add("a", LongType)
+      def bufferSchema: StructType = new StructType().add("product", LongType)
+      def dataType: DataType = LongType
+      def deterministic: Boolean = true
+      def initialize(buffer: MutableAggregationBuffer): Unit = {}
+      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {}
+      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}
+      def evaluate(buffer: Row): Any = buffer.getLong(0)
+    }
+    spark.udf.register("udaf34388", udaf)
+    spark.sessionState.catalog.lookupFunction(
+      FunctionIdentifier("udaf34388"), Seq(Literal(1))) match {
+      case udaf: ScalaUDAF => assert(udaf.name === "udaf34388")
+    }
+  }
 }
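
For reviewers, a minimal standalone sketch of the user-visible behaviour this patch gives and the new tests assert: the name passed to spark.udf.register is what the planned expression reports, instead of the generic "UDF" (or, for ScalaUDAF/ScalaAggregator, the implementing class's simple name). This is not part of the patch; the object name Spark34388Demo and the UDF name "gt2" are made up for illustration.

// Scala sketch, run against a build containing this change.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object Spark34388Demo extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("SPARK-34388 demo")
    .getOrCreate()

  // register(name, udf) now applies udf.withName(name) before building the
  // expression, so "gt2" reaches ScalaUDF.name.
  spark.udf.register("gt2", udf((value: Int) => value > 2))

  val df = spark.sql("SELECT gt2(1)")
  // ScalaUDF.toString is now s"$name(...)", so the analyzed plan renders the
  // call as "gt2(1)" where it previously rendered the generic "UDF(1)".
  println(df.queryExecution.analyzed)

  spark.stop()
}

The same pattern applies to UserDefinedAggregateFunction and Aggregator registrations, where the fallback remains the class's simple name when no registration name is available (udafName/aggregatorName default to None).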