Commit 726f6d3

amanomer authored and cloud-fan committed
[SPARK-30184][SQL] Implement a helper method for aliasing functions
### What changes were proposed in this pull request?

This PR uses `expressionWithAlias` for the remaining functions for which an alias name can be used: `Average`, `First`, `Last`, `ApproximatePercentile`, `StddevSamp`, `VarianceSamp`. PR #26712 introduced `expressionWithAlias`.

### Why are the changes needed?

The error message is wrong when an alias name is used for the functions mentioned above.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

Manually

Closes #26808 from amanomer/fncAlias.

Lead-authored-by: Aman Omer <[email protected]>
Co-authored-by: Aman Omer <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent dea1823 · commit 726f6d3

12 files changed: +61 -85 lines changed
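
The mechanism behind the patch is small: registration can now stamp the name a function was registered under onto the built expression node as a `TreeNodeTag`, and `prettyName`/`nodeName` read that tag before falling back to the canonical name. Below is a minimal, self-contained Scala sketch of that pattern — `Tagged` and `FakeAverage` are stand-ins invented for illustration, mimicking Spark's real `TreeNodeTag`/`setTagValue`/`getTagValue` API rather than using it.

```scala
import scala.collection.mutable

object FuncAliasDemo extends App {

  // A tag is just a typed key; the real one lives in
  // org.apache.spark.sql.catalyst.trees.TreeNodeTag.
  case class TreeNodeTag[T](name: String)

  trait Tagged {
    private val tags = mutable.Map.empty[TreeNodeTag[_], Any]
    def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = tags(tag) = value
    def getTagValue[T](tag: TreeNodeTag[T]): Option[T] =
      tags.get(tag).map(_.asInstanceOf[T])
  }

  // Mirrors FunctionRegistry.FUNC_ALIAS in the diff below.
  val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName")

  // Stand-in for an aggregate such as Average: prettyName prefers the tag.
  class FakeAverage extends Tagged {
    def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("avg")
  }

  val registeredAsAvg = new FakeAverage
  println(registeredAsAvg.prettyName)  // avg

  val registeredAsMean = new FakeAverage
  // This is what expression[Average]("mean", true) now does in the registry.
  registeredAsMean.setTagValue(FUNC_ALIAS, "mean")
  println(registeredAsMean.prettyName) // mean
}
```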

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 19 additions & 51 deletions

```diff
@@ -22,7 +22,6 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -31,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.xml._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.types._
 
 
@@ -193,6 +193,8 @@ object FunctionRegistry {
 
   type FunctionBuilder = Seq[Expression] => Expression
 
+  val FUNC_ALIAS = TreeNodeTag[String]("functionAliasName")
+
   // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite
   val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
     // misc non-aggregate functions
@@ -289,35 +291,35 @@ object FunctionRegistry {
     expression[CovPopulation]("covar_pop"),
     expression[CovSample]("covar_samp"),
     expression[First]("first"),
-    expression[First]("first_value"),
+    expression[First]("first_value", true),
     expression[Kurtosis]("kurtosis"),
     expression[Last]("last"),
-    expression[Last]("last_value"),
+    expression[Last]("last_value", true),
     expression[Max]("max"),
     expression[MaxBy]("max_by"),
-    expression[Average]("mean"),
+    expression[Average]("mean", true),
     expression[Min]("min"),
     expression[MinBy]("min_by"),
     expression[Percentile]("percentile"),
     expression[Skewness]("skewness"),
     expression[ApproximatePercentile]("percentile_approx"),
-    expression[ApproximatePercentile]("approx_percentile"),
-    expression[StddevSamp]("std"),
-    expression[StddevSamp]("stddev"),
+    expression[ApproximatePercentile]("approx_percentile", true),
+    expression[StddevSamp]("std", true),
+    expression[StddevSamp]("stddev", true),
     expression[StddevPop]("stddev_pop"),
     expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
-    expression[VarianceSamp]("variance"),
+    expression[VarianceSamp]("variance", true),
     expression[VariancePop]("var_pop"),
    expression[VarianceSamp]("var_samp"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
     expression[CountMinSketchAgg]("count_min_sketch"),
-    expressionWithAlias[BoolAnd]("every"),
-    expressionWithAlias[BoolAnd]("bool_and"),
-    expressionWithAlias[BoolOr]("any"),
-    expressionWithAlias[BoolOr]("some"),
-    expressionWithAlias[BoolOr]("bool_or"),
+    expression[BoolAnd]("every", true),
+    expression[BoolAnd]("bool_and"),
+    expression[BoolOr]("any", true),
+    expression[BoolOr]("some", true),
+    expression[BoolOr]("bool_or"),
 
     // string functions
     expression[Ascii]("ascii"),
@@ -573,7 +575,7 @@
   val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet
 
   /** See usage above. */
-  private def expression[T <: Expression](name: String)
+  private def expression[T <: Expression](name: String, setAlias: Boolean = false)
      (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
 
    // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main
@@ -619,7 +621,9 @@
        throw new AnalysisException(invalidArgumentsMsg)
      }
      try {
-        f.newInstance(expressions : _*).asInstanceOf[Expression]
+        val exp = f.newInstance(expressions : _*).asInstanceOf[Expression]
+        if (setAlias) exp.setTagValue(FUNC_ALIAS, name)
+        exp
      } catch {
        // the exception is an invocation exception. To get a meaningful message, we need the
        // cause.
@@ -631,42 +635,6 @@
    (name, (expressionInfo[T](name), builder))
  }
 
-  private def expressionWithAlias[T <: Expression](name: String)
-      (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {
-    val constructors = tag.runtimeClass.getConstructors
-      .filter(_.getParameterTypes.head == classOf[String])
-    assert(constructors.length == 1)
-    val builder = (expressions: Seq[Expression]) => {
-      val params = classOf[String] +: Seq.fill(expressions.size)(classOf[Expression])
-      val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
-        val validParametersCount = constructors
-          .filter(_.getParameterTypes.tail.forall(_ == classOf[Expression]))
-          .map(_.getParameterCount - 1).distinct.sorted
-        val invalidArgumentsMsg = if (validParametersCount.length == 0) {
-          s"Invalid arguments for function $name"
-        } else {
-          val expectedNumberOfParameters = if (validParametersCount.length == 1) {
-            validParametersCount.head.toString
-          } else {
-            validParametersCount.init.mkString("one of ", ", ", " and ") +
-              validParametersCount.last
-          }
-          s"Invalid number of arguments for function $name. " +
-            s"Expected: $expectedNumberOfParameters; Found: ${expressions.size}"
-        }
-        throw new AnalysisException(invalidArgumentsMsg)
-      }
-      try {
-        f.newInstance(name.toString +: expressions: _*).asInstanceOf[Expression]
-      } catch {
-        // the exception is an invocation exception. To get a meaningful message, we need the
-        // cause.
-        case e: Exception => throw new AnalysisException(e.getCause.getMessage)
-      }
-    }
-    (name, (expressionInfo[T](name), builder))
-  }
-
  /**
   * Creates a function registry lookup entry for cast aliases (SPARK-16730).
   * For example, if name is "int", and dataType is IntegerType, this means int(x) would become
```
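
As a hedged illustration of the registry change above: entries registered with `setAlias = true` (for example `expression[Average]("mean", true)`) make the builder stamp `FUNC_ALIAS` on the node, so the name the user typed should surface in error messages and derived column names. A hypothetical local session (assuming a vanilla `SparkSession` on the classpath; treat the expected output as illustrative, not guaranteed):

```scala
import org.apache.spark.sql.SparkSession

object AliasNameDemo extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("func-alias-demo")
    .getOrCreate()

  spark.range(5).createOrReplaceTempView("t")

  // "mean" is registered with setAlias = true, so the user-typed name should
  // survive into the derived output column, e.g. mean(id) rather than avg(id).
  spark.sql("SELECT mean(id) FROM t").printSchema()

  spark.stop()
}
```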

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 4 additions & 2 deletions

```diff
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
 import com.google.common.primitives.{Doubles, Ints, Longs}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
@@ -185,7 +185,8 @@ case class ApproximatePercentile(
     if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
   }
 
-  override def prettyName: String = "percentile_approx"
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("percentile_approx")
 
   override def serialize(obj: PercentileDigest): Array[Byte] = {
     ApproximatePercentile.serializer.serialize(obj)
@@ -321,4 +322,5 @@ object ApproximatePercentile {
   }
 
   val serializer: PercentileDigestSerializer = new PercentileDigestSerializer
+
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
   since = "1.0.0")
 case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
-  override def prettyName: String = "avg"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
   override def children: Seq[Expression] = child :: Nil
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala

Lines changed: 4 additions & 2 deletions

```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -174,7 +175,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
       If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
   }
 
-  override def prettyName: String = "stddev_samp"
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp")
 }
 
 // Compute the population variance of a column
@@ -215,7 +217,7 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
       If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
   }
 
-  override def prettyName: String = "var_samp"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp")
 }
 
 @ExpressionDescription(
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala

Lines changed: 4 additions & 2 deletions

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -113,5 +113,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
 
   override lazy val evaluateExpression: AttributeReference = first
 
-  override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first")
+
+  override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala

Lines changed: 4 additions & 2 deletions

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -111,5 +111,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
 
   override lazy val evaluateExpression: AttributeReference = last
 
-  override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
+  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last")
+
+  override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
 }
```
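
Both `First` and `Last` now route `toString` through `prettyName`, so a tagged alias should appear in plan strings as well. A standalone sketch of just that string-building step (the `render` helper is invented for illustration, not Spark code):

```scala
object FirstToStringDemo extends App {
  // Models the toString pattern First/Last share after this change: the
  // alias flows from prettyName into the rendered plan string.
  def render(prettyName: String, child: String, ignoreNulls: Boolean): String =
    s"$prettyName($child)${if (ignoreNulls) " ignore nulls" else ""}"

  println(render("first", "a", ignoreNulls = false))      // first(a)
  println(render("first_value", "a", ignoreNulls = true)) // first_value(a) ignore nulls
}
```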

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala

Lines changed: 5 additions & 5 deletions

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -52,8 +52,8 @@ abstract class UnevaluableBooleanAggBase(arg: Expression)
       false
   """,
   since = "3.0.0")
-case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) {
-  override def nodeName: String = funcName
+case class BoolAnd(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
+  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_and")
 }
 
 @ExpressionDescription(
@@ -68,6 +68,6 @@ case class BoolAnd(funcName: String, arg: Expression) extends UnevaluableBoolean
       false
   """,
   since = "3.0.0")
-case class BoolOr(funcName: String, arg: Expression) extends UnevaluableBooleanAggBase(arg) {
-  override def nodeName: String = funcName
+case class BoolOr(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
+  override def nodeName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("bool_or")
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -47,8 +47,8 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case e: RuntimeReplaceable => e.child
     case CountIf(predicate) => Count(new NullIf(predicate, Literal.FalseLiteral))
-    case BoolOr(_, arg) => Max(arg)
-    case BoolAnd(_, arg) => Min(arg)
+    case BoolOr(arg) => Max(arg)
+    case BoolAnd(arg) => Min(arg)
   }
 }
 
```
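
Dropping the `funcName` constructor field also simplifies every pattern match over `BoolAnd`/`BoolOr`, as the `ReplaceExpressions` hunk above shows. A toy, self-contained model of that rewrite (invented case classes, not Spark's):

```scala
object ReplaceDemo extends App {
  // Toy expression tree modeling the ReplaceExpressions rule: bool_or(x) is
  // a max over booleans and bool_and(x) a min, since false < true.
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class BoolOr(arg: Expr) extends Expr
  case class BoolAnd(arg: Expr) extends Expr
  case class Max(arg: Expr) extends Expr
  case class Min(arg: Expr) extends Expr

  // With the alias in a tag rather than a constructor field, the patterns
  // need no wildcard for the name: BoolOr(arg), not BoolOr(_, arg).
  def replace(e: Expr): Expr = e match {
    case BoolOr(arg)  => Max(arg)
    case BoolAnd(arg) => Min(arg)
    case other        => other
  }

  println(replace(BoolOr(Col("b"))))  // Max(Col(b))
  println(replace(BoolAnd(Col("b")))) // Min(Col(b))
}
```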

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -153,8 +153,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(Sum(Symbol("stringField")))
     assertSuccess(Average(Symbol("stringField")))
     assertSuccess(Min(Symbol("arrayField")))
-    assertSuccess(new BoolAnd("bool_and", Symbol("booleanField")))
-    assertSuccess(new BoolOr("bool_or", Symbol("booleanField")))
+    assertSuccess(new BoolAnd(Symbol("booleanField")))
+    assertSuccess(new BoolOr(Symbol("booleanField")))
 
     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
```

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 1 addition & 1 deletion

```diff
@@ -128,7 +128,7 @@ NULL 1
 SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a)
 FROM testData
 -- !query 13 schema
-struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,var_samp(CAST(a AS DOUBLE)):double,stddev_samp(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
+struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,variance(CAST(a AS DOUBLE)):double,stddev(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
 -- !query 13 output
 -0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7
```
