diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 4a71dba663b38..92401131e8b82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -93,8 +93,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { val newBranches = cw.branches.map { case (cond, value) => replaceNullWithFalse(cond) -> replaceNullWithFalse(value) } - val newElseValue = cw.elseValue.map(replaceNullWithFalse) - CaseWhen(newBranches, newElseValue) + if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) { + FalseLiteral + } else { + val newElseValue = cw.elseValue.map(replaceNullWithFalse) + CaseWhen(newBranches, newElseValue) + } case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType => If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal)) case e if e.dataType == BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 47b968f6ebdd7..f01df5e5e6768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -525,6 +525,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } else { e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue))) } + + case e @ CaseWhen(branches, None) + if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) => + Literal(null, e.dataType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 02307a52ebb89..2d826e7b55a68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -258,4 +258,15 @@ class PushFoldableIntoBranchesSuite EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None).cast(StringType), Literal("4")), CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) } + + test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { + Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition => + assertEquivalent( + EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)), + Literal.create(null, BooleanType)) + assertEquivalent( + EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)), + Literal.create(null, BooleanType)) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 5da71c31e1990..f49e6921fd46a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -380,6 +380,39 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + test("replace None of elseValue inside CaseWhen if all branches are FalseLiteral") { + val allFalseBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val allFalseCond = CaseWhen(allFalseBranches) + + val nonAllFalseBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val nonAllFalseCond = CaseWhen(nonAllFalseBranches, FalseLiteral) + + testFilter(allFalseCond, FalseLiteral) + testJoin(allFalseCond, FalseLiteral) + testDelete(allFalseCond, FalseLiteral) + testUpdate(allFalseCond, FalseLiteral) + + testFilter(nonAllFalseCond, nonAllFalseCond) + testJoin(nonAllFalseCond, nonAllFalseCond) + testDelete(nonAllFalseCond, nonAllFalseCond) + testUpdate(nonAllFalseCond, nonAllFalseCond) + } + + test("replace None of elseValue inside CaseWhen if all branches are null") { + val allNullBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> Literal.create(null, BooleanType), + (UnresolvedAttribute("i") > Literal(40)) -> Literal.create(null, BooleanType)) + val allFalseCond = CaseWhen(allNullBranches) + testFilter(allFalseCond, FalseLiteral) + testJoin(allFalseCond, FalseLiteral) + testDelete(allFalseCond, FalseLiteral) + testUpdate(allFalseCond, FalseLiteral) + } + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { test((rel, exp) => rel.where(exp), originalCond, expectedCond) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 328fc107e1c1b..1876be21dea4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -215,4 +215,12 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) } + + test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { + Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition => + assertEquivalent( + CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None), + Literal.create(null, IntegerType)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index bdbb741f24bc6..739b4052ee90d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -27,6 +27,12 @@ import org.apache.spark.sql.types.BooleanType class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSparkSession { import testImplicits._ + private def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = + df.queryExecution.executedPlan match { + case s: LocalTableScanExec => assert(s.rows.isEmpty) + case p => fail(s"$p is not LocalTableScanExec") + } + test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { withTable("t1", "t2") { Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1") @@ -64,11 +70,6 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) } - - def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match { - case s: LocalTableScanExec => assert(s.rows.isEmpty) - case p => fail(s"$p is not LocalTableScanExec") - } } test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { @@ -112,4 +113,14 @@ class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with Shared assertNoLiteralNullInPlan(q3) } } + + test("SPARK-33847: replace None of elseValue inside CaseWhen to FalseLiteral") { + withTable("t1") { + Seq((1, 1), (2, 2)).toDF("a", "b").write.saveAsTable("t1") + val t1 = spark.table("t1") + val q1 = t1.filter("(CASE WHEN a > 1 THEN 1 END) = 0") + checkAnswer(q1, Seq.empty) + checkPlanIsEmptyLocalScan(q1) + } + } }