Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ object DefaultOptimizer extends Optimizer {
ConstantFolding,
LikeSimplification,
BooleanSimplification,
CNFNormalization,
RemoveDispensableExpressions,
SimplifyFilters,
SimplifyCasts,
Expand Down Expand Up @@ -463,6 +464,24 @@ object OptimizeIn extends Rule[LogicalPlan] {
}
}

/**
* Convert an expression into its conjunctive normal form (CNF), i.e. AND of ORs.
* For example, a && b || c is normalized to (a || c) && (b || c) by this method.
*
* Refer to https://en.wikipedia.org/wiki/Conjunctive_normal_form for more information
*/
object CNFNormalization extends Rule[LogicalPlan] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marmbrus @nongli do we want to do this for all expressions? If we do, maybe we should have a feature flag for this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually come to think of it, it'd be great to be able to turn on/off optimization rules for testing. Most of these can be undocumented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems scary to do in general without some kind of bounding. The transformation can explode the number of expressions. Is there an easy way we can cap this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using the following heuristic solution to prevent exponential explosion:

  1. Add a simple size method to TreeNode, which returns the size (total number of nodes) of a tree:
def size: Int = 1 + children.map(_.size).sum
  1. Gives up CNF conversion once the result predicate exceeds a predefined threshold.

For example, we can stop if the size of the converted predicate is 10 times larger than the original one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I wonder how traditional RDBMS copes with the CNF exponential expansion issue?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Capping the size seems reasonable. We need to make sure it continues to work even if the pass is rerun (respects the original limit).

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
// reverse Or with its And child to eliminate (And under Or) occurrence
case Or(And(innerLhs, innerRhs), rhs) =>
And(Or(innerLhs, rhs), Or(innerRhs, rhs))
case Or(lhs, And(innerLhs, innerRhs)) =>
And(Or(lhs, innerLhs), Or(lhs, innerRhs))
}
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably include/move those Not push-down rules added in PR #5700 defined in BooleanSimplification here, since they are essential for CNF normalization.


/**
* Simplifies boolean expressions:
* 1. Simplifies expressions whose answer can be determined without evaluating both sides.
Expand All @@ -489,13 +508,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r)
case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r)
case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r)
// (a || b) && (a || c) => a || (b && c)
// (a || b) && (a || b || c) => a || b
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(a || b) && (a || c) => a || (b && c) is just a transformation instead of optimization, it is only the case when we could eliminate one side like: (a || b) && (a || b || c) => a || b. Besides, the original transformation is opposite to CNF Normalize.

case _ =>
// 1. Split left and right to get the disjunctive predicates,
// i.e. lhs = (a, b), rhs = (a, c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
// 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff)
// i.e. lhs = (a, b), rhs = (a, b, c)
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a, b)
// 3. If lhsSet or rhsSet contains only the common part, i.e. ldiff = ()
// apply the formula, get the optimized predicate: common
val lhs = splitDisjunctivePredicates(left)
val rhs = splitDisjunctivePredicates(right)
val common = lhs.filter(e => rhs.exists(e.semanticEquals(_)))
Expand All @@ -509,9 +528,8 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
// (a || b || c || ...) && (a || b) => (a || b)
common.reduce(Or)
} else {
// (a || b || c || ...) && (a || b || d || ...) =>
// ((c || ...) && (d || ...)) || a || b
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
// both hand sides contain remaining parts, we cannot do simplification here
and
}
}
} // end of And(left, right)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
EliminateSubQueries) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation,
ConstantFolding,
BooleanSimplification,
CNFNormalization,
SimplifyFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int, 'e.int)

// Take in expression without And, normalize to its leftmost representation to be comparable
private def normalizeOrExpression(expression: Expression): Expression = {
val elements = ArrayBuffer.empty[Expression]
expression.foreachUp {
case Or(lhs, rhs) =>
if (!lhs.isInstanceOf[Or]) elements += lhs
if (!rhs.isInstanceOf[Or]) elements += rhs
case _ => // do nothing
}
if (!expression.isInstanceOf[Or]) {
elements += expression
}
elements.sortBy(_.toString).reduce(Or)
}

private def checkCondition(input: Expression, expected: Expression): Unit = {
val actual = Optimize.execute(testRelation.where(input).analyze)
val correctAnswer = Optimize.execute(testRelation.where(expected).analyze)

val resultFilterExpression = actual.collectFirst { case f: Filter => f.condition }.get
val expectedFilterExpression = correctAnswer.collectFirst { case f: Filter => f.condition }.get

val exprs = splitConjunctivePredicates(resultFilterExpression)
.map(normalizeOrExpression).sortBy(_.toString)
val expectedExprs = splitConjunctivePredicates(expectedFilterExpression)
.map(normalizeOrExpression).sortBy(_.toString)

assert(exprs === expectedExprs)
}

private val a = Literal(1) < 'a
private val b = Literal(1) < 'b
private val c = Literal(1) < 'c
private val d = Literal(1) < 'd
private val e = Literal(1) < 'e
private val f = ! a

test("a || b => a || b") {
checkCondition(a || b, a || b)
}

test("a && b && c => a && b && c") {
checkCondition(a && b && c, a && b && c)
}

test("a && !(b || c) => a && !b && !c") {
checkCondition(a && !(b || c), a && !b && !c)
}

test("a && b || c => (a || c) && (b || c)") {
checkCondition(a && b || c, (a || c) && (b || c))
}

test("a && b || f => (a || f) && (b || f)") {
checkCondition(a && b || f, (a || f) && (b || f))
}

test("(a && b) || (c && d) => (c || a) && (c || b) && ((d || a) && (d || b))") {
checkCondition((a && b) || (c && d), (a || c) && (b || c) && (a || d) && (b || d))
}

test("(a && b) || !(c && d) => (a || !c || !d) && (b || !c || !d)") {
checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
}

test("a || b || c && d => (a || b || c) && (a || b || d)") {
checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
}

test("a || (b && c || d) => (a || b || d) && (a || c || d)") {
checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
}

test("a || !(b && c || d) => (a || !b || !c) && (a || !d)") {
checkCondition(a || !(b && c || d), (a || !b || !c) && (a || !d))
}

test("a && (b && c || d && e) => a && (b || d) && (c || d) && (b || e) && (c || e)") {
val input = a && (b && c || d && e)
val expected = a && (b || d) && (c || d) && (b || e) && (c || e)
checkCondition(input, expected)
}

test("a && !(b && c || d && e) => a && (!b || !c) && (!d || !e)") {
checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
}

test(
"a || (b && c || d && e) => (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)") {
val input = a || (b && c || d && e)
val expected = (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)
checkCondition(input, expected)
}

test(
"a || !(b && c || d && e) => (a || !b || !c) && (a || !d || !e)") {
checkCondition(a || !(b && c || d && e), (a || !b || !c) && (a || !d || !e))
}

test("a && b && c || !(d && e) => (a || !d || !e) && (b || !d || !e) && (c || !d || !e)") {
val input = a && b && c || !(d && e)
val expected = (a || !d || !e) && (b || !d || !e) && (c || !d || !e)
checkCondition(input, expected)
}

test(
"a && b && c || d && e && f => " +
"(a || d) && (a || e) && (a || f) && (b || d) && " +
"(b || e) && (b || f) && (c || d) && (c || e) && (c || f)") {
val input = a && b && c || d && e && f
val expected = (a || d) && (a || e) && (a || f) &&
(b || d) && (b || e) && (b || f) &&
(c || d) && (c || e) && (c || f)
checkCondition(input, expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)

checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2)

checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))

checkCondition(
('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
('a === 'b || 'b > 3 && 'a > 3 && 'a < 5))
}

test("a && (!a || b)") {
Expand Down Expand Up @@ -116,13 +110,4 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
comparePlans(actual, expected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this case with the CNF rule enabled?

}

test("(a || b) && (a || c) => a || (b && c) when case insensitive") {
val plan = caseInsensitiveAnalyzer.execute(
testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
val actual = Optimize.execute(plan)
val expected = caseInsensitiveAnalyzer.execute(
testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
comparePlans(actual, expected)
}
}