Skip to content

Commit 6704103

Browse files
HyukjinKwon authored and dongjoon-hyun committed
[SPARK-31146][SQL] Leverage the helper method for aliasing in built-in SQL expressions
### What changes were proposed in this pull request? This PR is kind of a followup of #26808. It leverages the helper method for aliasing in built-in SQL expressions to use the alias as its output column name where it's applicable. - `Expression`, `UnaryMathExpression` and `BinaryMathExpression` search the alias in the tags by default. - When the naming is different in its implementation, it has to be overwritten for the expression specifically. E.g., `CallMethodViaReflection`, `Remainder`, `CurrentTimestamp`, `FormatString` and `XPathDouble`. This PR fixes the aliases of the functions below: | class | alias | |--------------------------|------------------| |`Rand` |`random` | |`Ceil` |`ceiling` | |`Remainder` |`mod` | |`Pow` |`pow` | |`Signum` |`sign` | |`Chr` |`char` | |`Length` |`char_length` | |`Length` |`character_length`| |`FormatString` |`printf` | |`Substring` |`substr` | |`Upper` |`ucase` | |`XPathDouble` |`xpath_number` | |`DayOfMonth` |`day` | |`CurrentTimestamp` |`now` | |`Size` |`cardinality` | |`Sha1` |`sha` | |`CallMethodViaReflection` |`java_method` | Note: `EqualTo`, `=` and `==` aliases were excluded because it's unable to leverage this helper method. It should fix the parser. Note: this PR also excludes some instances such as `ToDegrees`, `ToRadians`, `UnaryMinus` and `UnaryPositive` that needs an explicit name overwritten to make the scope of this PR smaller. ### Why are the changes needed? To respect expression name. ### Does this PR introduce any user-facing change? Yes, it will change the output column name. ### How was this patch tested? Manually tested, and unittests were added. Closes #27901 from HyukjinKwon/31146. Authored-by: HyukjinKwon <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 3ce1dff commit 6704103

File tree

17 files changed

+70
-55
lines changed

17 files changed

+70
-55
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ object FunctionRegistry {
218218
expression[PosExplode]("posexplode"),
219219
expressionGeneratorOuter[PosExplode]("posexplode_outer"),
220220
expression[Rand]("rand"),
221-
expression[Rand]("random"),
221+
expression[Rand]("random", true),
222222
expression[Randn]("randn"),
223223
expression[Stack]("stack"),
224224
expression[CaseWhen]("when"),
@@ -235,7 +235,7 @@ object FunctionRegistry {
235235
expression[BRound]("bround"),
236236
expression[Cbrt]("cbrt"),
237237
expression[Ceil]("ceil"),
238-
expression[Ceil]("ceiling"),
238+
expression[Ceil]("ceiling", true),
239239
expression[Cos]("cos"),
240240
expression[Cosh]("cosh"),
241241
expression[Conv]("conv"),
@@ -252,20 +252,20 @@ object FunctionRegistry {
252252
expression[Log1p]("log1p"),
253253
expression[Log2]("log2"),
254254
expression[Log]("ln"),
255-
expression[Remainder]("mod"),
255+
expression[Remainder]("mod", true),
256256
expression[UnaryMinus]("negative"),
257257
expression[Pi]("pi"),
258258
expression[Pmod]("pmod"),
259259
expression[UnaryPositive]("positive"),
260-
expression[Pow]("pow"),
260+
expression[Pow]("pow", true),
261261
expression[Pow]("power"),
262262
expression[ToRadians]("radians"),
263263
expression[Rint]("rint"),
264264
expression[Round]("round"),
265265
expression[ShiftLeft]("shiftleft"),
266266
expression[ShiftRight]("shiftright"),
267267
expression[ShiftRightUnsigned]("shiftrightunsigned"),
268-
expression[Signum]("sign"),
268+
expression[Signum]("sign", true),
269269
expression[Signum]("signum"),
270270
expression[Sin]("sin"),
271271
expression[Sinh]("sinh"),
@@ -323,12 +323,12 @@ object FunctionRegistry {
323323

324324
// string functions
325325
expression[Ascii]("ascii"),
326-
expression[Chr]("char"),
326+
expression[Chr]("char", true),
327327
expression[Chr]("chr"),
328328
expression[Base64]("base64"),
329329
expression[BitLength]("bit_length"),
330-
expression[Length]("char_length"),
331-
expression[Length]("character_length"),
330+
expression[Length]("char_length", true),
331+
expression[Length]("character_length", true),
332332
expression[ConcatWs]("concat_ws"),
333333
expression[Decode]("decode"),
334334
expression[Elt]("elt"),
@@ -351,7 +351,7 @@ object FunctionRegistry {
351351
expression[JsonTuple]("json_tuple"),
352352
expression[ParseUrl]("parse_url"),
353353
expression[StringLocate]("position"),
354-
expression[FormatString]("printf"),
354+
expression[FormatString]("printf", true),
355355
expression[RegExpExtract]("regexp_extract"),
356356
expression[RegExpReplace]("regexp_replace"),
357357
expression[StringRepeat]("repeat"),
@@ -364,21 +364,21 @@ object FunctionRegistry {
364364
expression[SoundEx]("soundex"),
365365
expression[StringSpace]("space"),
366366
expression[StringSplit]("split"),
367-
expression[Substring]("substr"),
367+
expression[Substring]("substr", true),
368368
expression[Substring]("substring"),
369369
expression[Left]("left"),
370370
expression[Right]("right"),
371371
expression[SubstringIndex]("substring_index"),
372372
expression[StringTranslate]("translate"),
373373
expression[StringTrim]("trim"),
374-
expression[Upper]("ucase"),
374+
expression[Upper]("ucase", true),
375375
expression[UnBase64]("unbase64"),
376376
expression[Unhex]("unhex"),
377377
expression[Upper]("upper"),
378378
expression[XPathList]("xpath"),
379379
expression[XPathBoolean]("xpath_boolean"),
380380
expression[XPathDouble]("xpath_double"),
381-
expression[XPathDouble]("xpath_number"),
381+
expression[XPathDouble]("xpath_number", true),
382382
expression[XPathFloat]("xpath_float"),
383383
expression[XPathInt]("xpath_int"),
384384
expression[XPathLong]("xpath_long"),
@@ -393,7 +393,7 @@ object FunctionRegistry {
393393
expression[DateAdd]("date_add"),
394394
expression[DateFormatClass]("date_format"),
395395
expression[DateSub]("date_sub"),
396-
expression[DayOfMonth]("day"),
396+
expression[DayOfMonth]("day", true),
397397
expression[DayOfYear]("dayofyear"),
398398
expression[DayOfMonth]("dayofmonth"),
399399
expression[FromUnixTime]("from_unixtime"),
@@ -404,7 +404,7 @@ object FunctionRegistry {
404404
expression[Month]("month"),
405405
expression[MonthsBetween]("months_between"),
406406
expression[NextDay]("next_day"),
407-
expression[CurrentTimestamp]("now"),
407+
expression[CurrentTimestamp]("now", true),
408408
expression[Quarter]("quarter"),
409409
expression[Second]("second"),
410410
expression[ParseToTimestamp]("to_timestamp"),
@@ -445,7 +445,7 @@ object FunctionRegistry {
445445
expression[MapConcat]("map_concat"),
446446
expression[Size]("size"),
447447
expression[Slice]("slice"),
448-
expression[Size]("cardinality"),
448+
expression[Size]("cardinality", true),
449449
expression[ArraysZip]("arrays_zip"),
450450
expression[SortArray]("sort_array"),
451451
expression[Shuffle]("shuffle"),
@@ -478,7 +478,7 @@ object FunctionRegistry {
478478
expression[Uuid]("uuid"),
479479
expression[Murmur3Hash]("hash"),
480480
expression[XxHash64]("xxhash64"),
481-
expression[Sha1]("sha"),
481+
expression[Sha1]("sha", true),
482482
expression[Sha1]("sha1"),
483483
expression[Sha2]("sha2"),
484484
expression[SparkPartitionID]("spark_partition_id"),
@@ -488,7 +488,7 @@ object FunctionRegistry {
488488
expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
489489
expression[CurrentDatabase]("current_database"),
490490
expression[CallMethodViaReflection]("reflect"),
491-
expression[CallMethodViaReflection]("java_method"),
491+
expression[CallMethodViaReflection]("java_method", true),
492492
expression[SparkVersion]("version"),
493493
expression[TypeOf]("typeof"),
494494

@@ -590,7 +590,9 @@ object FunctionRegistry {
590590
if (varargCtor.isDefined) {
591591
// If there is an apply method that accepts Seq[Expression], use that one.
592592
try {
593-
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
593+
val exp = varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
594+
if (setAlias) exp.setTagValue(FUNC_ALIAS, name)
595+
exp
594596
} catch {
595597
// the exception is an invocation exception. To get a meaningful message, we need the
596598
// cause.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.lang.reflect.{Method, Modifier}
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
23+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
2424
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2525
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2626
import org.apache.spark.sql.types._
@@ -55,7 +55,7 @@ import org.apache.spark.util.Utils
5555
case class CallMethodViaReflection(children: Seq[Expression])
5656
extends Expression with CodegenFallback {
5757

58-
override def prettyName: String = "reflect"
58+
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("reflect")
5959

6060
override def checkInputDataTypes(): TypeCheckResult = {
6161
if (children.size < 2) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.util.Locale
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
23+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
2525
import org.apache.spark.sql.catalyst.expressions.codegen._
2626
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -258,7 +258,8 @@ abstract class Expression extends TreeNode[Expression] {
258258
* Returns a user-facing string representation of this expression's name.
259259
* This should usually match the name of the function in SQL.
260260
*/
261-
def prettyName: String = nodeName.toLowerCase(Locale.ROOT)
261+
def prettyName: String =
262+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(nodeName.toLowerCase(Locale.ROOT))
262263

263264
protected def flatArguments: Iterator[Any] = stringArgs.flatMap {
264265
case t: Iterable[_] => t

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,5 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
117117

118118
override lazy val evaluateExpression: AttributeReference = first
119119

120-
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("first")
121-
122120
override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
123121
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,5 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
115115

116116
override lazy val evaluateExpression: AttributeReference = last
117117

118-
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("last")
119-
120118
override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
121119
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
21-
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
21+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2424
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
@@ -457,6 +457,18 @@ case class Remainder(left: Expression, right: Expression) extends DivModLike {
457457

458458
override def symbol: String = "%"
459459
override def decimalMethod: String = "remainder"
460+
override def toString: String = {
461+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(sqlOperator) match {
462+
case operator if operator == sqlOperator => s"($left $sqlOperator $right)"
463+
case funcName => s"$funcName($left, $right)"
464+
}
465+
}
466+
override def sql: String = {
467+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(sqlOperator) match {
468+
case operator if operator == sqlOperator => s"(${left.sql} $sqlOperator ${right.sql})"
469+
case funcName => s"$funcName(${left.sql}, ${right.sql})"
470+
}
471+
}
460472

461473
private lazy val mod: (Any, Any) => Any = dataType match {
462474
// special cases to make float/double primitive types faster

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.commons.text.StringEscapeUtils
2929
import org.apache.spark.SparkUpgradeException
3030
import org.apache.spark.sql.AnalysisException
3131
import org.apache.spark.sql.catalyst.InternalRow
32+
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
3233
import org.apache.spark.sql.catalyst.expressions.codegen._
3334
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
3435
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter}
@@ -99,7 +100,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
99100

100101
override def eval(input: InternalRow): Any = currentTimestamp()
101102

102-
override def prettyName: String = "current_timestamp"
103+
override def prettyName: String =
104+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_timestamp")
103105
}
104106

105107
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.{lang => jl}
2121
import java.util.Locale
2222

2323
import org.apache.spark.sql.catalyst.InternalRow
24-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
24+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
2525
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -62,8 +62,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
6262
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
6363
override def dataType: DataType = DoubleType
6464
override def nullable: Boolean = true
65-
override def toString: String = s"$name($child)"
66-
override def prettyName: String = name
65+
override def toString: String = s"$prettyName($child)"
66+
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(name)
6767

6868
protected override def nullSafeEval(input: Any): Any = {
6969
f(input.asInstanceOf[Double])
@@ -115,9 +115,9 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
115115

116116
override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
117117

118-
override def toString: String = s"$name($left, $right)"
118+
override def toString: String = s"$prettyName($left, $right)"
119119

120-
override def prettyName: String = name
120+
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(name)
121121

122122
override def dataType: DataType = DoubleType
123123

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer
2727
import org.apache.commons.codec.binary.{Base64 => CommonsBase64}
2828

2929
import org.apache.spark.sql.catalyst.InternalRow
30-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
30+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
3131
import org.apache.spark.sql.catalyst.expressions.codegen._
3232
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
3333
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
@@ -1450,7 +1450,7 @@ case class ParseUrl(children: Seq[Expression])
14501450
// scalastyle:on line.size.limit
14511451
case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
14521452

1453-
require(children.nonEmpty, "format_string() should take at least 1 argument")
1453+
require(children.nonEmpty, s"$prettyName() should take at least 1 argument")
14541454

14551455
override def foldable: Boolean = children.forall(_.foldable)
14561456
override def nullable: Boolean = children(0).nullable
@@ -1517,7 +1517,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
15171517
}""")
15181518
}
15191519

1520-
override def prettyName: String = "format_string"
1520+
override def prettyName: String = getTagValue(
1521+
FunctionRegistry.FUNC_ALIAS).getOrElse("format_string")
15211522
}
15221523

15231524
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.xml
1919

20-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
20+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -160,7 +160,8 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
160160
""")
161161
// scalastyle:on line.size.limit
162162
case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
163-
override def prettyName: String = "xpath_double"
163+
override def prettyName: String =
164+
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("xpath_double")
164165
override def dataType: DataType = DoubleType
165166

166167
override def nullSafeEval(xml: Any, path: Any): Any = {

0 commit comments

Comments
 (0)