Skip to content

Commit 337c16d

Browse files
committed
[SQL] Miscellaneous SQL/DF expression changes.
SPARK-8201 conditional function: if SPARK-8205 conditional function: nvl SPARK-8208 math function: ceiling SPARK-8210 math function: degrees SPARK-8211 math function: radians SPARK-8219 math function: negative SPARK-8216 math function: rename log -> ln SPARK-8222 math function: alias power / pow SPARK-8225 math function: alias sign / signum SPARK-8228 conditional function: isnull SPARK-8229 conditional function: isnotnull SPARK-8250 string function: alias lower/lcase SPARK-8251 string function: alias upper / ucase Author: Reynold Xin <[email protected]> Closes apache#6754 from rxin/expressions-misc and squashes the following commits: 35fce15 [Reynold Xin] Removed println. 2647067 [Reynold Xin] Promote to string type. 3c32bbc [Reynold Xin] Fixed if. de827ac [Reynold Xin] Fixed style b201cd4 [Reynold Xin] Removed if. 6b21a9b [Reynold Xin] [SQL] Miscellaneous SQL/DF expression changes.
1 parent 7914c72 commit 337c16d

File tree

7 files changed

+175
-27
lines changed

7 files changed

+175
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,43 +84,51 @@ object FunctionRegistry {
8484
type FunctionBuilder = Seq[Expression] => Expression
8585

8686
val expressions: Map[String, FunctionBuilder] = Map(
87-
// Non aggregate functions
87+
// misc non-aggregate functions
8888
expression[Abs]("abs"),
8989
expression[CreateArray]("array"),
9090
expression[Coalesce]("coalesce"),
9191
expression[Explode]("explode"),
92+
expression[If]("if"),
93+
expression[IsNull]("isnull"),
94+
expression[IsNotNull]("isnotnull"),
95+
expression[Coalesce]("nvl"),
9296
expression[Rand]("rand"),
9397
expression[Randn]("randn"),
9498
expression[CreateStruct]("struct"),
9599
expression[Sqrt]("sqrt"),
96100

97-
// Math functions
101+
// math functions
98102
expression[Acos]("acos"),
99103
expression[Asin]("asin"),
100104
expression[Atan]("atan"),
101105
expression[Atan2]("atan2"),
102106
expression[Cbrt]("cbrt"),
103107
expression[Ceil]("ceil"),
108+
expression[Ceil]("ceiling"),
104109
expression[Cos]("cos"),
105110
expression[EulerNumber]("e"),
106111
expression[Exp]("exp"),
107112
expression[Expm1]("expm1"),
108113
expression[Floor]("floor"),
109114
expression[Hypot]("hypot"),
110-
expression[Log]("log"),
115+
expression[Log]("ln"),
111116
expression[Log10]("log10"),
112117
expression[Log1p]("log1p"),
118+
expression[UnaryMinus]("negative"),
113119
expression[Pi]("pi"),
114120
expression[Log2]("log2"),
115121
expression[Pow]("pow"),
122+
expression[Pow]("power"),
116123
expression[Rint]("rint"),
124+
expression[Signum]("sign"),
117125
expression[Signum]("signum"),
118126
expression[Sin]("sin"),
119127
expression[Sinh]("sinh"),
120128
expression[Tan]("tan"),
121129
expression[Tanh]("tanh"),
122-
expression[ToDegrees]("todegrees"),
123-
expression[ToRadians]("toradians"),
130+
expression[ToDegrees]("degrees"),
131+
expression[ToRadians]("radians"),
124132

125133
// aggregate functions
126134
expression[Average]("avg"),
@@ -132,10 +140,12 @@ object FunctionRegistry {
132140
expression[Sum]("sum"),
133141

134142
// string functions
143+
expression[Lower]("lcase"),
135144
expression[Lower]("lower"),
136145
expression[StringLength]("length"),
137146
expression[Substring]("substr"),
138147
expression[Substring]("substring"),
148+
expression[Upper]("ucase"),
139149
expression[Upper]("upper")
140150
)
141151

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ object HiveTypeCoercion {
5858
case _ => None
5959
}
6060

61+
/** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */
62+
private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
63+
findTightestCommonTypeOfTwo(left, right).orElse((left, right) match {
64+
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
65+
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
66+
case _ => None
67+
})
68+
}
69+
6170
/**
6271
* Find the tightest common type of a set of types by continuously applying
6372
* `findTightestCommonTypeOfTwo` on these types.
@@ -91,6 +100,7 @@ trait HiveTypeCoercion {
91100
StringToIntegralCasts ::
92101
FunctionArgumentConversion ::
93102
CaseWhenCoercion ::
103+
IfCoercion ::
94104
Division ::
95105
PropagateTypes ::
96106
ExpectedInputConversion ::
@@ -652,6 +662,26 @@ trait HiveTypeCoercion {
652662
}
653663
}
654664

665+
/**
666+
* Coerces the type of different branches of If statement to a common type.
667+
*/
668+
object IfCoercion extends Rule[LogicalPlan] {
669+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
670+
// Find tightest common type for If, if the true value and false value have different types.
671+
case i @ If(pred, left, right) if left.dataType != right.dataType =>
672+
findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType =>
673+
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
674+
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
675+
i.makeCopy(Array(pred, newLeft, newRight))
676+
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
677+
678+
// Convert If(null literal, _, _) into boolean type.
679+
// In the optimizer, we should short-circuit this directly into false value.
680+
case i @ If(pred, left, right) if pred.dataType == NullType =>
681+
i.makeCopy(Array(Literal.create(null, BooleanType), left, right))
682+
}
683+
}
684+
655685
/**
656686
* Casts types according to the expected input types for Expressions that have the trait
657687
* `ExpectsInputTypes`.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ class HiveTypeCoercionSuite extends PlanTest {
134134
:: Nil))
135135
}
136136

137+
test("type coercion for If") {
138+
val rule = new HiveTypeCoercion { }.IfCoercion
139+
ruleTest(rule,
140+
If(Literal(true), Literal(1), Literal(1L)),
141+
If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
142+
)
143+
144+
ruleTest(rule,
145+
If(Literal.create(null, NullType), Literal(1), Literal(1)),
146+
If(Literal.create(null, BooleanType), Literal(1), Literal(1))
147+
)
148+
}
149+
137150
test("type coercion for CaseKeyWhen") {
138151
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
139152
ruleTest(cwc,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,52 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
22-
import org.apache.spark.sql.types.{IntegerType, BooleanType}
22+
import org.apache.spark.sql.types._
2323

2424

2525
class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2626

27+
test("if") {
28+
val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)](
29+
(true, 1, 2, 1),
30+
(false, 1, 2, 2),
31+
(null, 1, 2, 2),
32+
(true, null, 2, null),
33+
(false, 1, null, null),
34+
(null, null, 2, 2),
35+
(null, 1, null, null)
36+
)
37+
38+
// dataType must match T.
39+
def testIf(convert: (Integer => Any), dataType: DataType): Unit = {
40+
for ((predicate, trueValue, falseValue, expected) <- testcases) {
41+
val trueValueConverted = if (trueValue == null) null else convert(trueValue)
42+
val falseValueConverted = if (falseValue == null) null else convert(falseValue)
43+
val expectedConverted = if (expected == null) null else convert(expected)
44+
45+
checkEvaluation(
46+
If(Literal.create(predicate, BooleanType),
47+
Literal.create(trueValueConverted, dataType),
48+
Literal.create(falseValueConverted, dataType)),
49+
expectedConverted)
50+
}
51+
}
52+
53+
testIf(_ == 1, BooleanType)
54+
testIf(_.toShort, ShortType)
55+
testIf(identity, IntegerType)
56+
testIf(_.toLong, LongType)
57+
58+
testIf(_.toFloat, FloatType)
59+
testIf(_.toDouble, DoubleType)
60+
testIf(Decimal(_), DecimalType.Unlimited)
61+
62+
testIf(identity, DateType)
63+
testIf(_.toLong, TimestampType)
64+
65+
testIf(_.toString, StringType)
66+
}
67+
2768
test("case when") {
2869
val row = create_row(null, false, true, "a", "b", "c")
2970
val c1 = 'a.boolean.at(0)

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest {
185185
checkAnswer(
186186
nullStrings.toDF.where($"s".isNull),
187187
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
188+
189+
checkAnswer(
190+
ctx.sql("select isnull(null), isnull(1)"),
191+
Row(true, false))
188192
}
189193

190194
test("isNotNull") {
191195
checkAnswer(
192196
nullStrings.toDF.where($"s".isNotNull),
193197
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
198+
199+
checkAnswer(
200+
ctx.sql("select isnotnull(null), isnotnull('a')"),
201+
Row(false, true))
194202
}
195203

196204
test("===") {
@@ -393,6 +401,10 @@ class ColumnExpressionSuite extends QueryTest {
393401
testData.select(upper(lit(null))),
394402
(1 to 100).map(n => Row(null))
395403
)
404+
405+
checkAnswer(
406+
ctx.sql("SELECT upper('aB'), ucase('cDe')"),
407+
Row("AB", "CDE"))
396408
}
397409

398410
test("lower") {
@@ -410,6 +422,10 @@ class ColumnExpressionSuite extends QueryTest {
410422
testData.select(lower(lit(null))),
411423
(1 to 100).map(n => Row(null))
412424
)
425+
426+
checkAnswer(
427+
ctx.sql("SELECT lower('aB'), lcase('cDe')"),
428+
Row("ab", "cde"))
413429
}
414430

415431
test("monotonicallyIncreasingId") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,20 @@ class DataFrameFunctionsSuite extends QueryTest {
110110
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
111111
}
112112

113-
test("length") {
113+
test("if function") {
114+
val df = Seq((1, 2)).toDF("a", "b")
115+
checkAnswer(
116+
df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"),
117+
Row("one", "not_one"))
118+
}
119+
120+
test("nvl function") {
121+
checkAnswer(
122+
ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
123+
Row("x", "y", null))
124+
}
125+
126+
test("string length function") {
114127
checkAnswer(
115128
nullStrings.select(strlen($"s"), strlen("s")),
116129
nullStrings.collect().toSeq.map { r =>
@@ -127,18 +140,4 @@ class DataFrameFunctionsSuite extends QueryTest {
127140
Row(l)
128141
})
129142
}
130-
131-
test("log2 functions test") {
132-
val df = Seq((1, 2)).toDF("a", "b")
133-
checkAnswer(
134-
df.select(log2("b") + log2("a")),
135-
Row(1))
136-
137-
checkAnswer(
138-
ctx.sql("SELECT LOG2(8)"),
139-
Row(3))
140-
checkAnswer(
141-
ctx.sql("SELECT LOG2(null)"),
142-
Row(null))
143-
}
144143
}

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import org.apache.spark.sql.functions._
21+
import org.apache.spark.sql.functions.{log => logarithm}
2122

2223

2324
private object MathExpressionsTestData {
@@ -151,20 +152,31 @@ class MathExpressionsSuite extends QueryTest {
151152
testOneToOneMathFunction(tanh, math.tanh)
152153
}
153154

154-
test("toDeg") {
155+
test("toDegrees") {
155156
testOneToOneMathFunction(toDegrees, math.toDegrees)
157+
checkAnswer(
158+
ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
159+
Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
160+
)
156161
}
157162

158-
test("toRad") {
163+
test("toRadians") {
159164
testOneToOneMathFunction(toRadians, math.toRadians)
165+
checkAnswer(
166+
ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
167+
Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
168+
)
160169
}
161170

162171
test("cbrt") {
163172
testOneToOneMathFunction(cbrt, math.cbrt)
164173
}
165174

166-
test("ceil") {
175+
test("ceil and ceiling") {
167176
testOneToOneMathFunction(ceil, math.ceil)
177+
checkAnswer(
178+
ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
179+
Row(0.0, 1.0, 2.0))
168180
}
169181

170182
test("floor") {
@@ -183,12 +195,21 @@ class MathExpressionsSuite extends QueryTest {
183195
testOneToOneMathFunction(expm1, math.expm1)
184196
}
185197

186-
test("signum") {
198+
test("signum / sign") {
187199
testOneToOneMathFunction[Double](signum, math.signum)
200+
201+
checkAnswer(
202+
ctx.sql("SELECT sign(10), signum(-11)"),
203+
Row(1, -1))
188204
}
189205

190-
test("pow") {
206+
test("pow / power") {
191207
testTwoToOneMathFunction(pow, pow, math.pow)
208+
209+
checkAnswer(
210+
ctx.sql("SELECT pow(1, 2), power(2, 1)"),
211+
Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
212+
)
192213
}
193214

194215
test("hypot") {
@@ -199,8 +220,12 @@ class MathExpressionsSuite extends QueryTest {
199220
testTwoToOneMathFunction(atan2, atan2, math.atan2)
200221
}
201222

202-
test("log") {
223+
test("log / ln") {
203224
testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
225+
checkAnswer(
226+
ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
227+
Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
228+
)
204229
}
205230

206231
test("log10") {
@@ -211,4 +236,18 @@ class MathExpressionsSuite extends QueryTest {
211236
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
212237
}
213238

239+
test("log2") {
240+
val df = Seq((1, 2)).toDF("a", "b")
241+
checkAnswer(
242+
df.select(log2("b") + log2("a")),
243+
Row(1))
244+
245+
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
246+
}
247+
248+
test("negative") {
249+
checkAnswer(
250+
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
251+
Row(-1, 0, 1))
252+
}
214253
}

0 commit comments

Comments
 (0)