[SPARK-6624][SQL]Add CNF Normalization as part of optimization #8200
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -58,6 +58,7 @@ object DefaultOptimizer extends Optimizer { | |
| ConstantFolding, | ||
| LikeSimplification, | ||
| BooleanSimplification, | ||
| CNFNormalization, | ||
| RemoveDispensableExpressions, | ||
| SimplifyFilters, | ||
| SimplifyCasts, | ||
|
|
@@ -463,6 +464,24 @@ object OptimizeIn extends Rule[LogicalPlan] { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Convert an expression into its conjunctive normal form (CNF), i.e. AND of ORs. | ||
| * For example, a && b || c is normalized to (a || c) && (b || c) by this method. | ||
| * | ||
| * Refer to https://en.wikipedia.org/wiki/Conjunctive_normal_form for more information | ||
| */ | ||
| object CNFNormalization extends Rule[LogicalPlan] { | ||
| def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
| case q: LogicalPlan => q transformExpressionsUp { | ||
| // reverse Or with its And child to eliminate (And under Or) occurrence | ||
| case Or(And(innerLhs, innerRhs), rhs) => | ||
| And(Or(innerLhs, rhs), Or(innerRhs, rhs)) | ||
| case Or(lhs, And(innerLhs, innerRhs)) => | ||
| And(Or(lhs, innerLhs), Or(lhs, innerRhs)) | ||
| } | ||
| } | ||
| } | ||
|
Contributor
We should probably include/move those
||
|
|
||
| /** | ||
| * Simplifies boolean expressions: | ||
| * 1. Simplifies expressions whose answer can be determined without evaluating both sides. | ||
|
|
@@ -489,13 +508,13 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { | |
| case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) | ||
| case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) | ||
| case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) | ||
| // (a || b) && (a || c) => a || (b && c) | ||
| // (a || b) && (a || b || c) => a || b | ||
|
|
||
| case _ => | ||
| // 1. Split left and right to get the disjunctive predicates, | ||
| // i.e. lhs = (a, b), rhs = (a, c) | ||
 | // 2. Find the common predicates between lhsSet and rhsSet, i.e. common = (a) | ||
 | // 3. Remove the common predicates from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c) | ||
| // 4. Apply the formula, get the optimized predicate: common || (ldiff && rdiff) | ||
| // i.e. lhs = (a, b), rhs = (a, b, c) | ||
 | // 2. Find the common predicates between lhsSet and rhsSet, i.e. common = (a, b) | ||
| // 3. If lhsSet or rhsSet contains only the common part, i.e. ldiff = () | ||
| // apply the formula, get the optimized predicate: common | ||
| val lhs = splitDisjunctivePredicates(left) | ||
| val rhs = splitDisjunctivePredicates(right) | ||
| val common = lhs.filter(e => rhs.exists(e.semanticEquals(_))) | ||
|
|
@@ -509,9 +528,8 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { | |
| // (a || b || c || ...) && (a || b) => (a || b) | ||
| common.reduce(Or) | ||
| } else { | ||
| // (a || b || c || ...) && (a || b || d || ...) => | ||
| // ((c || ...) && (d || ...)) || a || b | ||
| (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or) | ||
 | // both sides contain remaining parts, so we cannot simplify here | ||
| and | ||
| } | ||
| } | ||
| } // end of And(left, right) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.sql.catalyst.analysis._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.optimizer._ | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
|
|
||
| class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper { | ||
|
|
||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
| val batches = | ||
| Batch("AnalysisNodes", Once, | ||
| EliminateSubQueries) :: | ||
| Batch("Constant Folding", FixedPoint(50), | ||
| NullPropagation, | ||
| ConstantFolding, | ||
| BooleanSimplification, | ||
| CNFNormalization, | ||
| SimplifyFilters) :: Nil | ||
| } | ||
|
|
||
| val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int, 'e.int) | ||
|
|
||
| // Take in expression without And, normalize to its leftmost representation to be comparable | ||
| private def normalizeOrExpression(expression: Expression): Expression = { | ||
| val elements = ArrayBuffer.empty[Expression] | ||
| expression.foreachUp { | ||
| case Or(lhs, rhs) => | ||
| if (!lhs.isInstanceOf[Or]) elements += lhs | ||
| if (!rhs.isInstanceOf[Or]) elements += rhs | ||
| case _ => // do nothing | ||
| } | ||
| if (!expression.isInstanceOf[Or]) { | ||
| elements += expression | ||
| } | ||
| elements.sortBy(_.toString).reduce(Or) | ||
| } | ||
|
|
||
| private def checkCondition(input: Expression, expected: Expression): Unit = { | ||
| val actual = Optimize.execute(testRelation.where(input).analyze) | ||
| val correctAnswer = Optimize.execute(testRelation.where(expected).analyze) | ||
|
|
||
| val resultFilterExpression = actual.collectFirst { case f: Filter => f.condition }.get | ||
| val expectedFilterExpression = correctAnswer.collectFirst { case f: Filter => f.condition }.get | ||
|
|
||
| val exprs = splitConjunctivePredicates(resultFilterExpression) | ||
| .map(normalizeOrExpression).sortBy(_.toString) | ||
| val expectedExprs = splitConjunctivePredicates(expectedFilterExpression) | ||
| .map(normalizeOrExpression).sortBy(_.toString) | ||
|
|
||
| assert(exprs === expectedExprs) | ||
| } | ||
|
|
||
| private val a = Literal(1) < 'a | ||
| private val b = Literal(1) < 'b | ||
| private val c = Literal(1) < 'c | ||
| private val d = Literal(1) < 'd | ||
| private val e = Literal(1) < 'e | ||
| private val f = ! a | ||
|
|
||
| test("a || b => a || b") { | ||
| checkCondition(a || b, a || b) | ||
| } | ||
|
|
||
| test("a && b && c => a && b && c") { | ||
| checkCondition(a && b && c, a && b && c) | ||
| } | ||
|
|
||
| test("a && !(b || c) => a && !b && !c") { | ||
| checkCondition(a && !(b || c), a && !b && !c) | ||
| } | ||
|
|
||
| test("a && b || c => (a || c) && (b || c)") { | ||
| checkCondition(a && b || c, (a || c) && (b || c)) | ||
| } | ||
|
|
||
| test("a && b || f => (a || f) && (b || f)") { | ||
| checkCondition(a && b || f, (a || f) && (b || f)) | ||
| } | ||
|
|
||
| test("(a && b) || (c && d) => (c || a) && (c || b) && ((d || a) && (d || b))") { | ||
| checkCondition((a && b) || (c && d), (a || c) && (b || c) && (a || d) && (b || d)) | ||
| } | ||
|
|
||
| test("(a && b) || !(c && d) => (a || !c || !d) && (b || !c || !d)") { | ||
| checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d)) | ||
| } | ||
|
|
||
| test("a || b || c && d => (a || b || c) && (a || b || d)") { | ||
| checkCondition(a || b || c && d, (a || b || c) && (a || b || d)) | ||
| } | ||
|
|
||
| test("a || (b && c || d) => (a || b || d) && (a || c || d)") { | ||
| checkCondition(a || (b && c || d), (a || b || d) && (a || c || d)) | ||
| } | ||
|
|
||
| test("a || !(b && c || d) => (a || !b || !c) && (a || !d)") { | ||
| checkCondition(a || !(b && c || d), (a || !b || !c) && (a || !d)) | ||
| } | ||
|
|
||
| test("a && (b && c || d && e) => a && (b || d) && (c || d) && (b || e) && (c || e)") { | ||
| val input = a && (b && c || d && e) | ||
| val expected = a && (b || d) && (c || d) && (b || e) && (c || e) | ||
| checkCondition(input, expected) | ||
| } | ||
|
|
||
| test("a && !(b && c || d && e) => a && (!b || !c) && (!d || !e)") { | ||
| checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e)) | ||
| } | ||
|
|
||
| test( | ||
| "a || (b && c || d && e) => (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)") { | ||
| val input = a || (b && c || d && e) | ||
| val expected = (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e) | ||
| checkCondition(input, expected) | ||
| } | ||
|
|
||
| test( | ||
| "a || !(b && c || d && e) => (a || !b || !c) && (a || !d || !e)") { | ||
| checkCondition(a || !(b && c || d && e), (a || !b || !c) && (a || !d || !e)) | ||
| } | ||
|
|
||
| test("a && b && c || !(d && e) => (a || !d || !e) && (b || !d || !e) && (c || !d || !e)") { | ||
| val input = a && b && c || !(d && e) | ||
| val expected = (a || !d || !e) && (b || !d || !e) && (c || !d || !e) | ||
| checkCondition(input, expected) | ||
| } | ||
|
|
||
| test( | ||
| "a && b && c || d && e && f => " + | ||
| "(a || d) && (a || e) && (a || f) && (b || d) && " + | ||
| "(b || e) && (b || f) && (c || d) && (c || e) && (c || f)") { | ||
| val input = a && b && c || d && e && f | ||
| val expected = (a || d) && (a || e) && (a || f) && | ||
| (b || d) && (b || e) && (b || f) && | ||
| (c || d) && (c || e) && (c || f) | ||
| checkCondition(input, expected) | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,12 +81,6 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { | |
| checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) | ||
|
|
||
| checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2) | ||
|
|
||
| checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) | ||
|
|
||
| checkCondition( | ||
| ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), | ||
| ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) | ||
| } | ||
|
|
||
| test("a && (!a || b)") { | ||
|
|
@@ -116,13 +110,4 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { | |
| testRelation.where('a > 2 && ('b > 3 || 'b < 5))) | ||
| comparePlans(actual, expected) | ||
|
Contributor
Can we add this case with the CNF rule enabled?
||
| } | ||
|
|
||
| test("(a || b) && (a || c) => a || (b && c) when case insensitive") { | ||
| val plan = caseInsensitiveAnalyzer.execute( | ||
| testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) | ||
| val actual = Optimize.execute(plan) | ||
| val expected = caseInsensitiveAnalyzer.execute( | ||
| testRelation.where('a > 2 || ('b > 3 && 'b < 5))) | ||
| comparePlans(actual, expected) | ||
| } | ||
| } | ||
@marmbrus @nongli do we want to do this for all expressions? If we do, maybe we should have a feature flag for this?
Actually come to think of it, it'd be great to be able to turn on/off optimization rules for testing. Most of these can be undocumented.
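One way to picture the on/off switch is a rule executor that skips rules by name. The following is only a sketch under assumptions: `RuleToggleSketch`, its `Rule` and `Executor` types, and the `disabled` set are hypothetical and are not Catalyst's `Rule` or `RuleExecutor`.

```scala
// Hypothetical sketch: none of these names are part of Catalyst.
object RuleToggleSketch {
  // A named rewrite over some plan type T.
  trait Rule[T] {
    def name: String
    def apply(plan: T): T
  }

  // Runs the rules in order, skipping any rule whose name is in `disabled`.
  final class Executor[T](rules: Seq[Rule[T]], disabled: Set[String] = Set.empty[String]) {
    def execute(plan: T): T =
      rules.filterNot(r => disabled.contains(r.name)).foldLeft(plan)((p, r) => r(p))
  }

  // Toy rules over Int, standing in for plan rewrites.
  object Twice extends Rule[Int] {
    val name = "Twice"
    def apply(plan: Int): Int = plan * 2
  }

  object Increment extends Rule[Int] {
    val name = "Increment"
    def apply(plan: Int): Int = plan + 1
  }
}
```

With both rules enabled, `execute(1)` applies `Twice` and then `Increment`; passing `Set("Twice")` as `disabled` leaves only `Increment`, which is the kind of switch a test could flip.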
This seems scary to do in general without some kind of bounding. The transformation can explode the number of expressions. Is there an easy way we can cap this?
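To make the blow-up concrete, here is a small self-contained sketch. It uses a toy expression type rather than Catalyst's `Expression`, and the names `CnfBlowupSketch`, `toCnf`, `clauseCount` and `disjunctionOfPairs` are made up for illustration. It applies the same two distribution cases as the rule in this PR until no `And` remains under an `Or`, then counts the resulting conjuncts.

```scala
// Toy model, not Catalyst: shows how distributing Or over And multiplies clauses.
object CnfBlowupSketch {
  sealed trait Expr
  final case class Atom(name: String) extends Expr
  final case class And(left: Expr, right: Expr) extends Expr
  final case class Or(left: Expr, right: Expr) extends Expr

  // Full conversion: normalize the children first, then push Or below And.
  def toCnf(e: Expr): Expr = e match {
    case And(l, r) => And(toCnf(l), toCnf(r))
    case Or(l, r)  => distribute(toCnf(l), toCnf(r))
    case atom      => atom
  }

  // The two rewrite cases from the rule, repeated until neither side is an And.
  private def distribute(l: Expr, r: Expr): Expr = (l, r) match {
    case (And(a, b), _) => And(distribute(a, r), distribute(b, r))
    case (_, And(a, b)) => And(distribute(l, a), distribute(l, b))
    case _              => Or(l, r)
  }

  // Number of top-level conjuncts (clauses) in an expression.
  def clauseCount(e: Expr): Int = e match {
    case And(l, r) => clauseCount(l) + clauseCount(r)
    case _         => 1
  }

  // (x1 && y1) || (x2 && y2) || ... || (xn && yn)
  def disjunctionOfPairs(n: Int): Expr =
    (1 to n).map(i => And(Atom(s"x$i"), Atom(s"y$i")): Expr).reduce(Or(_, _))
}
```

For a disjunction of `n` two-way conjunctions, every clause of the result picks one atom from each pair, so `clauseCount(toCnf(disjunctionOfPairs(n)))` is `2^n`: 8 clauses for `n = 3` and 1024 for `n = 10`.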
How about using the following heuristic solution to prevent exponential explosion:
Add a size method to TreeNode, which returns the size (total number of nodes) of a tree. For example, we can stop if the size of the converted predicate is 10 times larger than the original one.
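A minimal sketch of that heuristic, on a toy expression type rather than Catalyst's `TreeNode`. The names `CappedCnfSketch`, `normalize` and `maxExpansion` are assumptions for illustration. Note that this version builds the full CNF before measuring it, so it bounds the result but not the work; a real rule would probably want to bail out during conversion.

```scala
// Toy model, not Catalyst: keep the CNF only if it is at most maxExpansion times larger.
object CappedCnfSketch {
  sealed trait Expr {
    // Total number of nodes in the tree.
    def size: Int = this match {
      case And(l, r) => 1 + l.size + r.size
      case Or(l, r)  => 1 + l.size + r.size
      case _         => 1
    }
  }
  final case class Atom(name: String) extends Expr
  final case class And(left: Expr, right: Expr) extends Expr
  final case class Or(left: Expr, right: Expr) extends Expr

  // Full CNF conversion by distributing Or over And.
  def toCnf(e: Expr): Expr = e match {
    case And(l, r) => And(toCnf(l), toCnf(r))
    case Or(l, r)  => distribute(toCnf(l), toCnf(r))
    case atom      => atom
  }

  private def distribute(l: Expr, r: Expr): Expr = (l, r) match {
    case (And(a, b), _) => And(distribute(a, r), distribute(b, r))
    case (_, And(a, b)) => And(distribute(l, a), distribute(l, b))
    case _              => Or(l, r)
  }

  // Returns the CNF form, or the untouched input when the CNF exceeds the cap.
  // maxExpansion = 10 mirrors the "10 times larger" example; it is not a tuned value.
  def normalize(e: Expr, maxExpansion: Int = 10): Expr = {
    val converted = toCnf(e)
    if (converted.size > maxExpansion.toLong * e.size) e else converted
  }
}
```

Under this cap, `a && b || c` (5 nodes) still becomes `(a || c) && (b || c)` (7 nodes), while a disjunction of eight two-way conjunctions (31 nodes, 4095 nodes in CNF) is returned unchanged.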
(I wonder how traditional RDBMSs cope with the CNF exponential expansion issue.)
Capping the size seems reasonable. We need to make sure it continues to work even if the pass is rerun (respects the original limit).
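One possible way to respect the original limit across reruns is to compute an absolute budget once, from the original predicate, and compare every later pass against that budget rather than against the pass's own input. The sketch below assumes a toy expression type and a hypothetical `Budgeted` wrapper; where such a budget would actually live in Catalyst is left open.

```scala
// Toy model, not Catalyst: a size budget fixed up front and carried across passes.
object BudgetedCnfSketch {
  sealed trait Expr {
    // Total number of nodes in the tree.
    def size: Int = this match {
      case And(l, r) => 1 + l.size + r.size
      case Or(l, r)  => 1 + l.size + r.size
      case _         => 1
    }
  }
  final case class Atom(name: String) extends Expr
  final case class And(left: Expr, right: Expr) extends Expr
  final case class Or(left: Expr, right: Expr) extends Expr

  // One bottom-up pass: each Or node is rewritten at most once,
  // similar in spirit to a single transformExpressionsUp run of the rule.
  def step(e: Expr): Expr = e match {
    case And(l, r) => And(step(l), step(r))
    case Or(l, r) => (step(l), step(r)) match {
      case (And(a, b), rhs) => And(Or(a, rhs), Or(b, rhs))
      case (lhs, And(a, b)) => And(Or(lhs, a), Or(lhs, b))
      case (lhs, rhs)       => Or(lhs, rhs)
    }
    case atom => atom
  }

  // The expression together with the size limit derived from the ORIGINAL input.
  final case class Budgeted(expr: Expr, maxSize: Long)

  def start(e: Expr, maxExpansion: Int = 10): Budgeted =
    Budgeted(e, maxExpansion.toLong * e.size)

  // A pass is accepted only if its result still fits the original budget.
  def pass(b: Budgeted): Budgeted = {
    val next = step(b.expr)
    if (next.size > b.maxSize) b else b.copy(expr = next)
  }

  // Applies `passes` passes, as a fixed-point executor might.
  def run(e: Expr, passes: Int): Budgeted =
    Iterator.iterate(start(e))(pass).drop(passes).next()
}
```

Because the budget never changes after `start`, rerunning the pass any number of times cannot push the expression past `maxSize`, whereas a cap measured relative to each pass's input could compound from one pass to the next.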