diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f6088695a9276..185878d251f48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -58,6 +58,7 @@ object DefaultOptimizer extends Optimizer {
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
+      CNFNormalization,
       RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
@@ -463,6 +464,28 @@ object OptimizeIn extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Convert an expression into its conjunctive normal form (CNF), i.e. AND of ORs.
+ * For example, a && b || c is normalized to (a || c) && (b || c) by this method.
+ *
+ * Distributing Or over And duplicates the operand next to the And, so the rewrite is only
+ * applied when that operand is deterministic; duplicating e.g. rand() would evaluate it twice.
+ * Note that the CNF of an expression can be exponentially larger than the expression itself.
+ *
+ * Refer to https://en.wikipedia.org/wiki/Conjunctive_normal_form for more information
+ */
+object CNFNormalization extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case q: LogicalPlan => q transformExpressionsUp {
+      // reverse Or with its And child to eliminate (And under Or) occurrence
+      case Or(And(innerLhs, innerRhs), rhs) if rhs.deterministic =>
+        And(Or(innerLhs, rhs), Or(innerRhs, rhs))
+      case Or(lhs, And(innerLhs, innerRhs)) if lhs.deterministic =>
+        And(Or(lhs, innerLhs), Or(lhs, innerRhs))
+    }
+  }
+}
+
 /**
  * Simplifies boolean expressions:
  * 1. Simplifies expressions whose answer can be determined without evaluating both sides.
@@ -489,13 +512,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
         case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r)
         case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r)
         case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r)
-        // (a || b) && (a || c) => a || (b && c)
+        // (a || b) && (a || b || c) => a || b
         case _ =>
           // 1. Split left and right to get the disjunctive predicates,
-          // i.e. lhs = (a, b), rhs = (a, c)
-          // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
-          // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
-          // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
+          //    i.e. lhs = (a, b), rhs = (a, b, c)
+          // 2. Find the common predicates between lhsSet and rhsSet, i.e. common = (a, b)
+          // 3. If lhsSet or rhsSet contains only the common part, i.e. ldiff = ()
+          //    apply the formula, get the optimized predicate: common
           val lhs = splitDisjunctivePredicates(left)
           val rhs = splitDisjunctivePredicates(right)
           val common = lhs.filter(e => rhs.exists(e.semanticEquals(_)))
@@ -509,9 +532,8 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
               // (a || b || c || ...) && (a || b) => (a || b)
               common.reduce(Or)
             } else {
-              // (a || b || c || ...) && (a || b || d || ...) =>
-              // ((c || ...) && (d || ...)) || a || b
-              (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
+              // both sides contain remaining parts, we cannot do simplification here
+              and
             }
           }
       } // end of And(left, right)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CNFNormalizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CNFNormalizationSuite.scala
new file mode 100644
index 0000000000000..caab60c39228f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CNFNormalizationSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("AnalysisNodes", Once,
+        EliminateSubQueries) ::
+      Batch("Constant Folding", FixedPoint(50),
+        NullPropagation,
+        ConstantFolding,
+        BooleanSimplification,
+        CNFNormalization,
+        SimplifyFilters) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int, 'e.int)
+
+  // Take in expression without And, normalize to its leftmost representation to be comparable
+  private def normalizeOrExpression(expression: Expression): Expression = {
+    val elements = ArrayBuffer.empty[Expression]
+    expression.foreachUp {
+      case Or(lhs, rhs) =>
+        if (!lhs.isInstanceOf[Or]) elements += lhs
+        if (!rhs.isInstanceOf[Or]) elements += rhs
+      case _ => // do nothing
+    }
+    if (!expression.isInstanceOf[Or]) {
+      elements += expression
+    }
+    elements.sortBy(_.toString).reduce(Or)
+  }
+
+  private def checkCondition(input: Expression, expected: Expression): Unit = {
+    val actual = Optimize.execute(testRelation.where(input).analyze)
+    val correctAnswer = Optimize.execute(testRelation.where(expected).analyze)
+
+    val resultFilterExpression = actual.collectFirst { case f: Filter => f.condition }.get
+    val expectedFilterExpression = correctAnswer.collectFirst { case f: Filter => f.condition }.get
+
+    val exprs = splitConjunctivePredicates(resultFilterExpression)
+      .map(normalizeOrExpression).sortBy(_.toString)
+    val expectedExprs = splitConjunctivePredicates(expectedFilterExpression)
+      .map(normalizeOrExpression).sortBy(_.toString)
+
+    assert(exprs === expectedExprs)
+  }
+
+  private val a = Literal(1) < 'a
+  private val b = Literal(1) < 'b
+  private val c = Literal(1) < 'c
+  private val d = Literal(1) < 'd
+  private val e = Literal(1) < 'e
+  private val f = ! a
+
+  test("a || b => a || b") {
+    checkCondition(a || b, a || b)
+  }
+
+  test("a && b && c => a && b && c") {
+    checkCondition(a && b && c, a && b && c)
+  }
+
+  test("a && !(b || c) => a && !b && !c") {
+    checkCondition(a && !(b || c), a && !b && !c)
+  }
+
+  test("a && b || c => (a || c) && (b || c)") {
+    checkCondition(a && b || c, (a || c) && (b || c))
+  }
+
+  test("a && b || f => (a || f) && (b || f)") {
+    checkCondition(a && b || f, (a || f) && (b || f))
+  }
+
+  test("(a && b) || (c && d) => (c || a) && (c || b) && ((d || a) && (d || b))") {
+    checkCondition((a && b) || (c && d), (a || c) && (b || c) && (a || d) && (b || d))
+  }
+
+  test("(a && b) || !(c && d) => (a || !c || !d) && (b || !c || !d)") {
+    checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
+  }
+
+  test("a || b || c && d => (a || b || c) && (a || b || d)") {
+    checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
+  }
+
+  test("a || (b && c || d) => (a || b || d) && (a || c || d)") {
+    checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
+  }
+
+  test("a || !(b && c || d) => (a || !b || !c) && (a || !d)") {
+    checkCondition(a || !(b && c || d), (a || !b || !c) && (a || !d))
+  }
+
+  test("a && (b && c || d && e) => a && (b || d) && (c || d) && (b || e) && (c || e)") {
+    val input = a && (b && c || d && e)
+    val expected = a && (b || d) && (c || d) && (b || e) && (c || e)
+    checkCondition(input, expected)
+  }
+
+  test("a && !(b && c || d && e) => a && (!b || !c) && (!d || !e)") {
+    checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
+  }
+
+  test(
+    "a || (b && c || d && e) => (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)") {
+    val input = a || (b && c || d && e)
+    val expected = (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)
+    checkCondition(input, expected)
+  }
+
+  test(
+    "a || !(b && c || d && e) => (a || !b || !c) && (a || !d || !e)") {
+    checkCondition(a || !(b && c || d && e), (a || !b || !c) && (a || !d || !e))
+  }
+
+  test("a && b && c || !(d && e) => (a || !d || !e) && (b || !d || !e) && (c || !d || !e)") {
+    val input = a && b && c || !(d && e)
+    val expected = (a || !d || !e) && (b || !d || !e) && (c || !d || !e)
+    checkCondition(input, expected)
+  }
+
+  test(
+    "a && b && c || d && e && f => " +
+      "(a || d) && (a || e) && (a || f) && (b || d) && " +
+      "(b || e) && (b || f) && (c || d) && (c || e) && (c || f)") {
+    val input = a && b && c || d && e && f
+    val expected = (a || d) && (a || e) && (a || f) &&
+      (b || d) && (b || e) && (b || f) &&
+      (c || d) && (c || e) && (c || f)
+    checkCondition(input, expected)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index cde346e99eb17..e5132f3158bc2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -81,12 +81,6 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
     checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)
 
     checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2)
-
-    checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))
-
-    checkCondition(
-      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
-      ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5))
   }
 
   test("a && (!a || b)") {
@@ -116,13 +110,4 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
       testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
     comparePlans(actual, expected)
   }
-
-  test("(a || b) && (a || c) => a || (b && c) when case insensitive") {
-    val plan = caseInsensitiveAnalyzer.execute(
-      testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
-    val actual = Optimize.execute(plan)
-    val expected = caseInsensitiveAnalyzer.execute(
-      testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
-    comparePlans(actual, expected)
-  }
 }