From 63ea68708eca33920a0da1844e7ccbf3a424d2d5 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 28 Apr 2019 10:54:02 +0200 Subject: [PATCH 1/2] [SPARK-27604][SQL] Enhance constant propagation --- .../sql/catalyst/optimizer/expressions.scala | 211 ++++++++++-------- .../optimizer/ConstantPropagationSuite.scala | 100 ++++++++- .../sql-tests/results/explain.sql.out | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 96 +++++++- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 4 +- 6 files changed, 309 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index bd400f86ea2c1..ecd47fee1cc2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import scala.collection.mutable.{ArrayBuffer, Stack} +import scala.collection.mutable.{ArrayBuffer, Map, Stack} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -55,119 +55,138 @@ object ConstantFolding extends Rule[LogicalPlan] { } /** - * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding + * Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding * value in conjunctive [[Expression Expressions]] * eg. * {{{ - * SELECT * FROM table WHERE i = 5 AND j = i + 3 - * ==> SELECT * FROM table WHERE i = 5 AND j = 8 + * i = 5 AND j = i + 3 => ... i = 5 AND j = 8 + * abs(i) = 5 AND j <= abs(i) + 3 => ... 
abs(i) = 5 AND j <= 8 * }}} * * Approach used: - * - Populate a mapping of attribute => constant value by looking at all the equals predicates - * - Using this mapping, replace occurrence of the attributes with the corresponding constant values - * in the AND node. + * - Populate a mapping of expression => constant value by looking at all the deterministic equals + * predicates + * - Using this mapping, replace occurrence of the expressions with the corresponding constant + * values in the AND node. */ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f: Filter => - val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true) - if (newCondition.isDefined) { - f.copy(condition = newCondition.get) - } else { - f - } - } + case f: Filter => f.mapExpressions(e => traverse(e, Some(false))._1) - type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] + // Constant propagation can remove equalities from [[Join]] conditions as they don't add any + // real value, but [[ExtractEquiJoinKeys]] is not prepared to handle that situation. + // SPARK-30598 can solve this issue. + case j: Join => j + + case o => o.mapExpressions(e => traverse(e, None)._1) + } /** - * Traverse a condition as a tree and replace attributes with constant values. - * - On matching [[And]], recursively traverse each children and get propagated mappings. - * If the current node is not child of another [[And]], replace all occurrences of the - * attributes with the corresponding constant values. - * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping - * of attribute => constant. - * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. - * - Otherwise, stop traversal and propagate empty mapping. 
- * @param condition condition to be traversed - * @param replaceChildren whether to replace attributes with constant values in children - * @param nullIsFalse whether a boolean expression result can be considered to false e.g. in the - * case of `WHERE e`, null result of expression `e` means the same as if it - * resulted false + * Traverse a condition as a tree and replace expressions with constant values. + * - On matching [[EqualTo]] or [[EqualNullSafe]], recursively traverse left and right children + * and propagate the expression => constant mapping. + * - On matching [[And]], recursively traverse left subtree and collect propagated mapping to + * replace expressions to constants in right subtree. Then recursively traverse right subtree + * and collect propagated mapping to replace expressions to constants in left subtree. + * - Otherwise, recursively traverse each child, propagate empty mapping. + * - During expression tree traversal track a boolean context that controls if constant + * propagation of a nullable expression can be safely applied. + * - E.g. in the case of `WHERE a = c AND f(a)` or `IF(a = c AND f(a), ..., ...)` where `a` is a + * nullable expression and `c` is a constant the null result of `a = c AND f(a)` means the + * same as if it resulted `false` therefore constant propagation can be safely applied (`a = c + * AND f(a)` => `a = c AND f(c)`). This context is represented by `Some(false)`. + * - In the case of `SELECT a = c AND f(a)` the `null` result really means `null`. In this + * context constant propagation can't be applied safely. This context is represented by + * `None`. + * - There is also a 3rd context due to an enclosing `Not` in which the context flips. E.g. + * constant propagation can't be applied on `WHERE NOT(a = c AND f(a))` but can be again on + * `WHERE NOT(IF(..., NOT(a = c AND f(a)), ...)`. This context is represented by `Some(true)`.
+ * @param expression expression to be traversed + * @param nullValue the boolean a null result may be treated as, or `None` if null must be kept * @return A tuple including: - * 1. Option[Expression]: optional changed condition after traversal - * 2. EqualityPredicates: propagated mapping of attribute => constant + * 1. Expression: possibly changed expression after traversal + * 2. Map[Expression, Literal]: propagated mapping of expression => constant */ - private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) - : (Option[Expression], EqualityPredicates) = - condition match { - case e @ EqualTo(left: AttributeReference, right: Literal) - if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) - case e @ EqualTo(left: Literal, right: AttributeReference) - if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) - if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) - if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) - case a: And => - val (newLeft, equalityPredicatesLeft) = - traverse(a.left, replaceChildren = false, nullIsFalse) - val (newRight, equalityPredicatesRight) = - traverse(a.right, replaceChildren = false, nullIsFalse) - val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight - val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { - Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), - replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) - } else { - if (newLeft.isDefined || newRight.isDefined) { - Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) - } else { - None - } + private def traverse( + expression: Expression, + nullValue: Option[Boolean] = None): (Expression, Map[Expression, Literal]) = + expression match { + case et
@ EqualTo(left, right: Literal) if safeToReplace(left, nullValue) => + (et.mapChildren(traverse(_)._1), Map(left.canonicalized -> right)) + case et @ EqualTo(left: Literal, right) if safeToReplace(right, nullValue) => + (et.mapChildren(traverse(_)._1), Map(right.canonicalized -> left)) + case ens @ EqualNullSafe(left, right: Literal) if safeToReplace(left, nullValue) => + (ens.mapChildren(traverse(_)._1), Map(left.canonicalized -> right)) + case ens @ EqualNullSafe(left: Literal, right) if safeToReplace(right, nullValue) => + (ens.mapChildren(traverse(_)._1), Map(right.canonicalized -> left)) + case a @ And(left, right) => + val (newLeft, equalityPredicatesLeft) = traverse(left, nullValue) + val replacedRight = replaceConstants(right, equalityPredicatesLeft) + val (replacedNewRight, equalityPredicatesRight) = traverse(replacedRight, nullValue) + val replacedNewLeft = replaceConstants(newLeft, equalityPredicatesRight) + val newAnd = a.withNewChildren(Seq(replacedNewLeft, replacedNewRight)) + (newAnd, equalityPredicatesLeft ++= equalityPredicatesRight) + case o: Or => (o.mapChildren(traverse(_, nullValue)._1), Map.empty) + case n: Not => (n.mapChildren(traverse(_, nullValue.map(!_))._1), Map.empty) + case i @ If(predicate, trueValue, falseValue) => + val newPredicate = traverse(predicate, Some(false))._1 + val newTrueValue = traverse(trueValue, nullValue)._1 + val newFalseValue = traverse(falseValue, nullValue)._1 + val newIf = i.withNewChildren(Seq(newPredicate, newTrueValue, newFalseValue)) + (newIf, Map.empty) + case cw @ CaseWhen(branches, elseValue) => + val newBranches = branches.flatMap { + case (w, t) => Seq(traverse(w, Some(false))._1, traverse(t, nullValue)._1) } - (newSelf, equalityPredicates) - case o: Or => - // Ignore the EqualityPredicates from children since they are only propagated through And. 
- val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse) - val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse) - val newSelf = if (newLeft.isDefined || newRight.isDefined) { - Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) - } else { - None - } - (newSelf, Seq.empty) - case n: Not => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) - (newChild.map(Not), Seq.empty) - case _ => (None, Seq.empty) + val newElseValue = elseValue.map(traverse(_, nullValue)._1) + val newCaseWhen = cw.withNewChildren(newBranches ++ newElseValue) + (newCaseWhen, Map.empty) + case af @ ArrayFilter(argument, lf: LambdaFunction) => + val newArgument = traverse(argument, nullValue)._1 + val newLF: LambdaFunction = traverseLambdaFunction(lf, false) + val newArrayFilter = af.withNewChildren(Seq(newArgument, newLF)) + (newArrayFilter, Map.empty) + case ae @ ArrayExists(argument, lf: LambdaFunction) => + val newArgument = traverse(argument, nullValue)._1 + val newLF: LambdaFunction = traverseLambdaFunction(lf, + SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC)) + val newArrayExists = ae.withNewChildren(Seq(newArgument, newLF)) + (newArrayExists, Map.empty) + case mf @ MapFilter(argument, lf: LambdaFunction) => + val newArgument = traverse(argument, nullValue)._1 + val newLF: LambdaFunction = traverseLambdaFunction(lf, false) + val newMapFilter = mf.withNewChildren(Seq(newArgument, newLF)) + (newMapFilter, Map.empty) + + // Actually most of the expressions could propagate nullValue safely. + // We use these few in tests. 
+ case a: Alias => (a.mapChildren(traverse(_, nullValue)._1), Map.empty) + case ca: CreateArray => (ca.mapChildren(traverse(_, nullValue)._1), Map.empty) + case gai: GetArrayItem => (gai.mapChildren(traverse(_, nullValue)._1), Map.empty) + case cm: CreateMap => (cm.mapChildren(traverse(_, nullValue)._1), Map.empty) + case cmv: GetMapValue => (cmv.mapChildren(traverse(_, nullValue)._1), Map.empty) + + // Stay on the safe side and don't propagate nullValue. + case o => (o.mapChildren(traverse(_)._1), Map.empty) } - // We need to take into account if an attribute is nullable and the context of the conjunctive - // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be - // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing - // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a - // null result of the enclosed expression means. - private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = - !ar.nullable || nullIsFalse - - private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) - : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } - condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) - } + private def traverseLambdaFunction(lf: LambdaFunction, threeValuedLogic: Boolean) = { + val newFunction = traverse(lf.function, if (threeValuedLogic) None else Some(false))._1 + lf.withNewChildren(newFunction +: lf.arguments).asInstanceOf[LambdaFunction] } + + private def safeToReplace(expression : Expression, nullValue: Option[Boolean]) = + !expression.foldable && 
expression.deterministic && + (!expression.nullable || nullValue.contains(false)) + + private def replaceConstants(expression: Expression, constants: Map[Expression, Literal]) = + if (constants.isEmpty) { + expression + } else { + expression transform { + case e if constants.contains(e.canonicalized) => constants(e.canonicalized) + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 171ac4e3091c3..caa44b60af517 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf /** * Unit tests for constant propagation in expressions. 
@@ -40,12 +41,13 @@ class ConstantPropagationSuite extends PlanTest { BooleanSimplification) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int.notNull) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int.notNull, 'e.int.notNull) private val columnA = 'a private val columnB = 'b private val columnC = 'c private val columnD = 'd + private val columnE = 'e test("basic test") { val query = testRelation @@ -154,13 +156,11 @@ class ConstantPropagationSuite extends PlanTest { test("conflicting equality predicates") { val query = testRelation - .select(columnA) .where( columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) + .analyze - val correctAnswer = testRelation - .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + val correctAnswer = testRelation.where(Literal.FalseLiteral) comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,4 +186,94 @@ class ConstantPropagationSuite extends PlanTest { .analyze comparePlans(Optimize.execute(query2), correctAnswer2) } + + test("Constant propagation in conflicting equalities") { + val query = testRelation + .select(columnA) + .where(columnA === Literal(1) && columnA === Literal(2)) + .analyze + val correctAnswer = testRelation + .select(columnA) + .where(Literal.FalseLiteral) + .analyze + comparePlans(Optimize.execute(query), correctAnswer) + } + + test("Enhanced constant propagation") { + def testSelect(expression: Expression, expected: Expression): Unit = { + val plan = testRelation.select(expression.as("x")).analyze + val expectedPlan = testRelation.select(expected.as("x")).analyze + comparePlans(Optimize.execute(plan), expectedPlan) + } + + def testFilter(expression: Expression, expected: Expression): Unit = { + val plan = testRelation.select(columnA).where(expression).analyze + val expectedPlan = testRelation.select(columnA).where(expected).analyze + 
comparePlans(Optimize.execute(plan), expectedPlan) + } + + val nullable = + abs(columnA) === Literal(1) && columnB === Literal(1) && abs(columnA) <= columnB + val reducedNullable = abs(columnA) === Literal(1) && columnB === Literal(1) + + val nonNullable = + abs(columnD) === Literal(1) && columnE === Literal(1) && abs(columnD) <= columnE + val reducedNonNullable = abs(columnD) === Literal(1) && columnE === Literal(1) + + val expression = nullable || nonNullable + val partlyReduced = nullable || reducedNonNullable + val reduced = reducedNullable || reducedNonNullable + + val simplifiedNegatedNullable = + abs(columnA) =!= Literal(1) || columnB =!= Literal(1) || abs(columnA) > columnB + val reducedSimplifiedNegatedNullable = abs(columnA) =!= Literal(1) || columnB =!= Literal(1) + + val reducedSimplifiedNegatedNonNullable = abs(columnD) =!= Literal(1) || columnE =!= Literal(1) + + val partlyReducedSimplifiedNegated = + simplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable + val reducedSimplifiedNegated = + reducedSimplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable + + testSelect(expression, partlyReduced) + testSelect(If(expression, expression, expression), + If(reduced, partlyReduced, partlyReduced)) + testSelect(CaseWhen(Seq((expression, expression)), expression), + CaseWhen(Seq((reduced, partlyReduced)), partlyReduced)) + testSelect(ArrayFilter(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)), + ArrayFilter(CreateArray(Seq(partlyReduced)), LambdaFunction(reduced, Nil))) + Seq(true, false).foreach { tvl => + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> s"$tvl") { + testSelect(ArrayExists(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)), + ArrayExists(CreateArray(Seq(partlyReduced)), + LambdaFunction(if (tvl) partlyReduced else reduced, Nil))) + } + } + testSelect(MapFilter(CreateMap(Seq(expression, expression)), LambdaFunction(expression, Nil)), + MapFilter(CreateMap(Seq(partlyReduced, 
partlyReduced)), LambdaFunction(reduced, Nil))) + testSelect(Not(If(expression, Not(expression), Not(expression))), + Not(If(reduced, partlyReducedSimplifiedNegated, partlyReducedSimplifiedNegated))) + + testFilter(expression, reduced) + testFilter(If(expression, expression, expression), + If(reduced, reduced, reduced)) + testFilter(CaseWhen(Seq((expression, expression)), expression), + CaseWhen(Seq((reduced, reduced)), reduced)) + testFilter( + GetArrayItem(ArrayFilter(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)), 1), + GetArrayItem(ArrayFilter(CreateArray(Seq(reduced)), LambdaFunction(reduced, Nil)), 1)) + Seq(true, false).foreach { tvl => + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> s"$tvl") { + testFilter(ArrayExists(CreateArray(Seq(expression)), LambdaFunction(expression, Nil)), + ArrayExists(CreateArray(Seq(reduced)), + LambdaFunction(if (tvl) partlyReduced else reduced, Nil))) + } + } + testFilter( + GetMapValue(MapFilter(CreateMap(Seq(expression, expression)), + LambdaFunction(expression, Nil)), true), + GetMapValue(MapFilter(CreateMap(Seq(reduced, reduced)), LambdaFunction(reduced, Nil)), true)) + testFilter(Not(If(expression, Not(expression), Not(expression))), + Not(If(reduced, reducedSimplifiedNegated, reducedSimplifiedNegated))) + } } diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 756c14f28a657..74c012d2b1d10 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -397,7 +397,7 @@ Subquery:1 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery Output: [key#x, val#x] Batched: true Location [not included in comparison]/{warehouse_dir}/explain_temp2] -PushedFilters: [IsNotNull(key), IsNotNull(val), EqualTo(val,2)] +PushedFilters: [IsNotNull(key), EqualTo(val,2)] ReadSchema: struct (6) ColumnarToRow 
[codegen id : 1] @@ -405,7 +405,7 @@ Input: [key#x, val#x] (7) Filter [codegen id : 1] Input : [key#x, val#x] -Condition : (((isnotnull(key#x) AND isnotnull(val#x)) AND (key#x = Subquery scalar-subquery#x, [id=#x])) AND (val#x = 2)) +Condition : ((isnotnull(key#x) AND (key#x = Subquery scalar-subquery#x, [id=#x])) AND (val#x = 2)) (8) Project [codegen id : 1] Output : [key#x] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d2d58a83ded5d..03fadb82576bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,8 +31,8 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.Uuid -import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.catalyst.expressions.{If, Uuid} +import org.apache.spark.sql.catalyst.optimizer.{ConstantPropagation, ConvertToLocalRelation} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2298,6 +2298,98 @@ class DataFrameSuite extends QueryTest fail("emptyDataFrame should be foldable") } } + + test("Enhanced constant propagation") { + val testRelation = + Seq[(Integer, Integer, Int, Int)]( + (null, null, 1, 1), + (null, null, 2, 2), + (1, 1, 1, 1), + (1, 1, 2, 2), + (2, 2, 1, 1), + (2, 2, 2, 2) + ).toDF("a", "b", "d", "e") + + val columnA = $"a" + val columnB = $"b" + val columnD = $"d" + val columnE = $"e" + + def testSelect(column: Column, expected: Column): Unit = { + val expectedRows = 
testRelation.select(expected).collect() + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ConstantPropagation.ruleName) { + val df = testRelation.select(column) + checkAnswer(df, expectedRows) + } + } + + def testFilter(column: Column, expected: Column): Unit = { + val expectedRows = testRelation.filter(expected).collect() + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ConstantPropagation.ruleName) { + val df = testRelation.filter(column) + checkAnswer(df, expectedRows) + } + } + + val nullable = abs(columnA) === 1 && columnB === 1 && abs(columnA) <= columnB + val reducedNullable = abs(columnA) === 1 && columnB === 1 + + val nonNullable = abs(columnD) === 1 && columnE === 1 && abs(columnD) <= columnE + val reducedNonNullable = abs(columnD) === 1 && columnE === 1 + + val column = rand() < 0 || nullable || nonNullable + val partlyReduced = rand() < 0 || nullable || reducedNonNullable + val reduced = rand() < 0 || reducedNullable || reducedNonNullable + + val simplifiedNegatedNullable = abs(columnA) =!= 1 || columnB =!= 1 || abs(columnA) > columnB + val reducedSimplifiedNegatedNullable = abs(columnA) =!= 1 || columnB =!= 1 + + val reducedSimplifiedNegatedNonNullable = abs(columnD) =!= 1 || columnE =!= 1 + + val partlyReducedSimplifiedNegated = + rand() >= 0 && simplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable + val reducedSimplifiedNegated = + rand() >= 0 && reducedSimplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable + + testSelect(column, partlyReduced) + testSelect(new Column(If(column.expr, column.expr, column.expr)), + new Column(If(reduced.expr, partlyReduced.expr, partlyReduced.expr))) + testSelect(when(column, column).otherwise(column), + when(reduced, partlyReduced).otherwise(partlyReduced)) + testSelect(filter(array(column), _ => column), + filter(array(partlyReduced), _ => reduced)) + Seq(true, false).foreach { tvl => + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> s"$tvl") { + 
testSelect(exists(array(column), _ => column), + exists(array(partlyReduced), _ => if (tvl) partlyReduced else reduced)) + } + } + testSelect(map_filter(map(coalesce(column, lit(false)), column), (_, _) => column), + map_filter(map(coalesce(partlyReduced, lit(false)), partlyReduced), (_, _) => reduced)) + testSelect(!new Column(If(column.expr, (!column).expr, (!column).expr)), + !new Column(If(reduced.expr, partlyReducedSimplifiedNegated.expr, + partlyReducedSimplifiedNegated.expr))) + + testFilter(column, reduced) + testFilter(new Column(If(column.expr, column.expr, column.expr)), + new Column(If(reduced.expr, reduced.expr, reduced.expr))) + testFilter(when(column, column).otherwise(column), + when(reduced, reduced).otherwise(reduced)) + testFilter( + filter(array(column), _ => column)(0), + filter(array(reduced), _ => reduced)(0)) + Seq(true, false).foreach { tvl => + withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> s"$tvl") { + testFilter(exists(array(column), _ => column), + exists(array(reduced), _ => if (tvl) partlyReduced else reduced)) + } + } + testFilter( + map_filter(map(coalesce(column, lit(false)), column), (_, _) => column)(true), + map_filter(map(coalesce(reduced, lit(false)), reduced), (_, _) => reduced)(true)) + testFilter(!new Column(If(column.expr, (!column).expr, (!column).expr)), + !new Column(If(reduced.expr, reducedSimplifiedNegated.expr, reducedSimplifiedNegated.expr))) + } } case class GroupByKey(a: Int, b: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 94ce3559bb44b..db6dc2ec2065b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -183,7 +183,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { withTempView("testPushed") { val exp = sql("select 
* from testPushed where key = 15").queryExecution.sparkPlan - assert(exp.toString.contains("PushedFilters: [IsNotNull(key), EqualTo(key,15)]")) + assert(exp.toString.contains("PushedFilters: [EqualTo(key,15)]")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 812305ba24403..7330005db82bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -198,7 +198,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre "when checking partition values") } // Only the filters that do not contain the partition column should be pushed down - checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + checkDataFilters(Set(EqualTo("c1", 1))) } test("partitioned table - case insensitive") { @@ -225,7 +225,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre "when checking partition values") } // Only the filters that do not contain the partition column should be pushed down - checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + checkDataFilters(Set(EqualTo("c1", 1))) } } From 6d994abcf76b6ee48e8d6dd635a894b931bf8725 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Feb 2020 11:47:19 +0100 Subject: [PATCH 2/2] fix UTs --- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 03fadb82576bf..c2f6803bfb72a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2337,9 +2337,9 @@ 
class DataFrameSuite extends QueryTest val nonNullable = abs(columnD) === 1 && columnE === 1 && abs(columnD) <= columnE val reducedNonNullable = abs(columnD) === 1 && columnE === 1 - val column = rand() < 0 || nullable || nonNullable - val partlyReduced = rand() < 0 || nullable || reducedNonNullable - val reduced = rand() < 0 || reducedNullable || reducedNonNullable + val column = nullable || nonNullable + val partlyReduced = nullable || reducedNonNullable + val reduced = reducedNullable || reducedNonNullable val simplifiedNegatedNullable = abs(columnA) =!= 1 || columnB =!= 1 || abs(columnA) > columnB val reducedSimplifiedNegatedNullable = abs(columnA) =!= 1 || columnB =!= 1 @@ -2347,9 +2347,9 @@ class DataFrameSuite extends QueryTest val reducedSimplifiedNegatedNonNullable = abs(columnD) =!= 1 || columnE =!= 1 val partlyReducedSimplifiedNegated = - rand() >= 0 && simplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable + simplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable val reducedSimplifiedNegated = - rand() >= 0 && reducedSimplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable + reducedSimplifiedNegatedNullable && reducedSimplifiedNegatedNonNullable testSelect(column, partlyReduced) testSelect(new Column(If(column.expr, column.expr, column.expr)),