Skip to content

Commit 3c32bbc

Browse files
committed
Fixed if.
1 parent de827ac commit 3c32bbc

File tree

4 files changed

+78
-2
lines changed

4 files changed

+78
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ object FunctionRegistry {
8989
expression[CreateArray]("array"),
9090
expression[Coalesce]("coalesce"),
9191
expression[Explode]("explode"),
92-
// expression[If]("if"), TODO: turn this on after adding rules to auto cast types.
92+
expression[If]("if"),
9393
expression[IsNull]("isnull"),
9494
expression[IsNotNull]("isnotnull"),
9595
expression[Coalesce]("nvl"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ trait HiveTypeCoercion {
9191
StringToIntegralCasts ::
9292
FunctionArgumentConversion ::
9393
CaseWhenCoercion ::
94+
IfCoercion ::
9495
Division ::
9596
PropagateTypes ::
9697
ExpectedInputConversion ::
@@ -652,6 +653,27 @@ trait HiveTypeCoercion {
652653
}
653654
}
654655

656+
/**
 * Coerces the true and false branches of an [[If]] expression to a common type, and gives a
 * `NullType` predicate an explicit `BooleanType`.
 *
 * Expressions for which no common type can be found are left unchanged.
 */
object IfCoercion extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    // Find tightest common type for If, if the true value and false value have different types.
    case i @ If(pred, left, right) if left.dataType != right.dataType =>
      findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
        // Only wrap a branch in a Cast when its type actually changes.
        val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
        val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
        i.makeCopy(Array(pred, newLeft, newRight))
      }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.

    // Convert If(null literal, _, _) into boolean type.
    // In the optimizer, we should short-circuit this directly into false value.
    // NOTE(review): when the branches also differ in type, the case above matches first, so
    // this rewrite only happens if the rule is applied to the expression again.
    case i @ If(pred, left, right) if pred.dataType == NullType =>
      i.makeCopy(Array(Literal.create(null, BooleanType), left, right))
  }
}
676+
655677
/**
656678
* Casts types according to the expected input types for Expressions that have the trait
657679
* `ExpectsInputTypes`.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ class HiveTypeCoercionSuite extends PlanTest {
134134
:: Nil))
135135
}
136136

137+
test("type coercion for If") {
  val ifCoercion = new HiveTypeCoercion {}.IfCoercion

  // Branches of different numeric types: the Int branch is expected to be cast to LongType.
  val mixedBranches = If(Literal(true), Literal(1), Literal(1L))
  val widenedBranches = If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
  ruleTest(ifCoercion, mixedBranches, widenedBranches)

  // A NullType predicate is expected to become a null literal typed as BooleanType.
  val nullTypedPredicate = If(Literal.create(null, NullType), Literal(1), Literal(1))
  val booleanTypedPredicate = If(Literal.create(null, BooleanType), Literal(1), Literal(1))
  ruleTest(ifCoercion, nullTypedPredicate, booleanTypedPredicate)
}
149+
137150
test("type coercion for CaseKeyWhen") {
138151
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
139152
ruleTest(cwc,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,52 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
22-
import org.apache.spark.sql.types.{IntegerType, BooleanType}
22+
import org.apache.spark.sql.types._
2323

2424

2525
class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2626

27+
test("if") {
  // Each row is (predicate, true value, false value, expected result); null stands for SQL NULL.
  val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)](
    (true, 1, 2, 1),
    (false, 1, 2, 2),
    (null, 1, 2, 2),
    (true, null, 2, null),
    (false, 1, null, null),
    (null, null, 2, 2),
    (null, 1, null, null)
  )

  // Runs every test case after mapping its integer values through `convert`.
  // `dataType` must be the type of the values that `convert` produces.
  def testIf(convert: (Integer => Any), dataType: DataType): Unit = {
    // Nulls are passed through untouched so that they stay null in the target type.
    def convertOrNull(value: Integer): Any = if (value == null) null else convert(value)

    testcases.foreach { case (predicate, trueValue, falseValue, expected) =>
      val ifExpression = If(
        Literal.create(predicate, BooleanType),
        Literal.create(convertOrNull(trueValue), dataType),
        Literal.create(convertOrNull(falseValue), dataType))

      checkEvaluation(ifExpression, convertOrNull(expected))
    }
  }

  // Boolean and integral types.
  testIf(_ == 1, BooleanType)
  testIf(_.toShort, ShortType)
  testIf(identity, IntegerType)
  testIf(_.toLong, LongType)

  // Fractional and decimal types.
  testIf(_.toFloat, FloatType)
  testIf(_.toDouble, DoubleType)
  testIf(Decimal(_), DecimalType.Unlimited)

  // Date and timestamp types.
  testIf(identity, DateType)
  testIf(_.toLong, TimestampType)

  // String type.
  testIf(_.toString, StringType)
}
67+
2768
test("case when") {
2869
val row = create_row(null, false, true, "a", "b", "c")
2970
val c1 = 'a.boolean.at(0)

0 commit comments

Comments
 (0)