Skip to content

Commit 8cf4a1f

Browse files
adrian-wangmarmbrus
authored andcommitted
[SPARK-5262] [SPARK-5244] [SQL] add coalesce in SQLParser and widen types for parameters of coalesce
I'll add test case in #4040 Author: Daoyuan Wang <[email protected]> Closes #4057 from adrian-wang/coal and squashes the following commits: 4d0111a [Daoyuan Wang] address Yin's comments c393e18 [Daoyuan Wang] fix rebase conflicts e47c03a [Daoyuan Wang] add coalesce in parser c74828d [Daoyuan Wang] cast types for coalesce
1 parent 1b56f1d commit 8cf4a1f

File tree

6 files changed

+65
-0
lines changed

6 files changed

+65
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class SqlParser extends AbstractSparkSQLParser {
5050
protected val CACHE = Keyword("CACHE")
5151
protected val CASE = Keyword("CASE")
5252
protected val CAST = Keyword("CAST")
53+
protected val COALESCE = Keyword("COALESCE")
5354
protected val COUNT = Keyword("COUNT")
5455
protected val DECIMAL = Keyword("DECIMAL")
5556
protected val DESC = Keyword("DESC")
@@ -295,6 +296,7 @@ class SqlParser extends AbstractSparkSQLParser {
295296
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
296297
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
297298
{ case s ~ p ~ l => Substring(s, p, l) }
299+
| COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) }
298300
| SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
299301
| ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
300302
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,22 @@ trait HiveTypeCoercion {
503503
// Hive lets you do aggregation of timestamps... for some reason
504504
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
505505
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
506+
507+
// Coalesce should return the first non-null value, which could be any column
508+
// from the list. So we need to make sure the return type is deterministic and
509+
// compatible with every child column.
510+
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
511+
val dt: Option[DataType] = Some(NullType)
512+
val types = es.map(_.dataType)
513+
val rt = types.foldLeft(dt)((r, c) => r match {
514+
case None => None
515+
case Some(d) => findTightestCommonType(d, c)
516+
})
517+
rt match {
518+
case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
519+
case None =>
520+
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
521+
}
506522
}
507523
}
508524

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,31 @@ class HiveTypeCoercionSuite extends FunSuite {
114114
// Stringify boolean when casting to string.
115115
ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
116116
}
117+
118+
test("coalesce casts") {
119+
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
120+
def ruleTest(initial: Expression, transformed: Expression) {
121+
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
122+
assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) ==
123+
Project(Seq(Alias(transformed, "a")()), testRelation))
124+
}
125+
ruleTest(
126+
Coalesce(Literal(1.0)
127+
:: Literal(1)
128+
:: Literal(1.0, FloatType)
129+
:: Nil),
130+
Coalesce(Cast(Literal(1.0), DoubleType)
131+
:: Cast(Literal(1), DoubleType)
132+
:: Cast(Literal(1.0, FloatType), DoubleType)
133+
:: Nil))
134+
ruleTest(
135+
Coalesce(Literal(1L)
136+
:: Literal(1)
137+
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
138+
:: Nil),
139+
Coalesce(Cast(Literal(1L), DecimalType())
140+
:: Cast(Literal(1), DecimalType())
141+
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
142+
:: Nil))
143+
}
117144
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
8888
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
8989
}
9090

91+
test("Add Parser of SQL COALESCE()") {
92+
checkAnswer(
93+
sql("""SELECT COALESCE(1, 2)"""),
94+
Row(1))
95+
checkAnswer(
96+
sql("SELECT COALESCE(null, 1, 1.5)"),
97+
Row(1.toDouble))
98+
checkAnswer(
99+
sql("SELECT COALESCE(null, null, null)"),
100+
Row(null))
101+
}
102+
91103
test("SPARK-3176 Added Parser of SQL LAST()") {
92104
checkAnswer(
93105
sql("SELECT LAST(n) FROM lowerCaseData"),

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
965965

966966
/* Case insensitive matches */
967967
val ARRAY = "(?i)ARRAY".r
968+
val COALESCE = "(?i)COALESCE".r
968969
val COUNT = "(?i)COUNT".r
969970
val AVG = "(?i)AVG".r
970971
val SUM = "(?i)SUM".r
@@ -1140,6 +1141,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
11401141
Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType))
11411142
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
11421143
Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
1144+
case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr))
11431145

11441146
/* UDFs - Must be last otherwise will preempt built in functions */
11451147
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
5757
}
5858
assert(numEquals === 1)
5959
}
60+
61+
test("COALESCE with different types") {
62+
intercept[RuntimeException] {
63+
TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
64+
}
65+
}
6066
}

0 commit comments

Comments
 (0)