@@ -32,6 +32,9 @@ trait CatalystConf {
def optimizerInSetConversionThreshold: Int
def maxCaseBranchesForCodegen: Int

def maxDepthForCNFNormalization: Int
def maxPredicateNumberForCNFNormalization: Int

def runSQLonFile: Boolean

def warehousePath: String
@@ -60,6 +63,8 @@ case class SimpleCatalystConf(
optimizerMaxIterations: Int = 100,
optimizerInSetConversionThreshold: Int = 10,
maxCaseBranchesForCodegen: Int = 20,
maxDepthForCNFNormalization: Int = 10,
maxPredicateNumberForCNFNormalization: Int = 20,
runSQLonFile: Boolean = true,
crossJoinEnabled: Boolean = false,
warehousePath: String = "/user/hive/warehouse")
@@ -100,6 +100,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
ReorderAssociativeOperator,
LikeSimplification,
BooleanSimplification,
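// Rewrites Filter conditions into conjunctive normal form so that the pushdown
// rules in this batch can handle individual conjuncts.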
CNFNormalization(conf),
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.analysis._
@@ -132,6 +133,57 @@ case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
}
}

/**
* Converts the predicates of [[Filter]] operators to conjunctive normal form (CNF).
*/
case class CNFNormalization(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {
/**
* Converts a predicate expression to conjunctive normal form (CNF). The `depth` parameter
* bounds the recursion so that deeply nested predicates are left unconverted.
*/
private def toCNF(predicate: Expression, depth: Int = 0): Expression = {
if (depth > conf.maxDepthForCNFNormalization) {
return predicate
}
// For a predicate like (A && B) || (C && D) || E, the conversion proceeds as follows:
// 1. Split the disjunction: (A && B) || (C && D) || E => (A && B), (C && D), E
// 2. Split the first disjunct into conjuncts: (A && B) => A, B
// 3. Distribute each remaining disjunct over the accumulated predicates:
// 3.a. for (C && D), generate (A || C), (B || C), (A || D), (B || D)
// 3.b. for E, generate ((A || C) || E), ((B || C) || E), ((A || D) || E), ((B || D) || E)
// 4. Recursively convert each resulting predicate, with increased depth.
// 5. Concatenate them with `AND`:
// ((A || C) || E) && ((B || C) || E) && ((A || D) || E) && ((B || D) || E)
val disjunctives = splitDisjunctivePredicates(predicate)
var finalPredicates = splitConjunctivePredicates(disjunctives.head)
disjunctives.tail.foreach { cond =>
val predicates = new ArrayBuffer[Expression]()
splitConjunctivePredicates(cond).foreach { p =>
predicates ++= finalPredicates.map(Or(_, p))
}
finalPredicates = predicates.toSeq
}
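// Recursively convert each clause. A clause semantically equal to the input predicate is
// returned as-is, which guarantees termination of the recursion.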
val cnf = finalPredicates.map { p =>
if (p.semanticEquals(predicate)) {
p
} else {
toCNF(p, depth + 1)
}
}
// To avoid the expression-explosion problem of CNF conversion, fall back to the original
// predicate if the number of resulting clauses exceeds the configured threshold.
if (cnf.length > conf.maxPredicateNumberForCNFNormalization) {
predicate
} else {
cnf.reduce(And)
}
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(condition, _) => f.copy(condition = toCNF(condition))
}
}
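For intuition, here is a minimal standalone sketch of the same distribution law on a toy
boolean AST. The `Expr` mini-language below is illustrative only (it is not a Catalyst API)
and omits the depth and clause-count guards that the rule above adds:

object CNFSketch {
  sealed trait Expr
  final case class Var(name: String) extends Expr
  final case class Not(e: Expr) extends Expr
  final case class And(l: Expr, r: Expr) extends Expr
  final case class Or(l: Expr, r: Expr) extends Expr

  // Push Or below And: (a && b) || c => (a || c) && (b || c).
  // Assumes negations have already been pushed to the leaves, as the
  // De Morgan rewrites in BooleanSimplification do for the rule above.
  def toCNF(e: Expr): Expr = e match {
    case And(l, r) => And(toCNF(l), toCNF(r))
    case Or(l, r) =>
      (toCNF(l), toCNF(r)) match {
        case (And(a, b), c) => And(toCNF(Or(a, c)), toCNF(Or(b, c)))
        case (c, And(a, b)) => And(toCNF(Or(c, a)), toCNF(Or(c, b)))
        case (a, b)         => Or(a, b)
      }
    case other => other
  }
}

// Example: toCNF(Or(And(Var("a"), Var("b")), Var("c")))
//   returns And(Or(Var("a"), Var("c")), Or(Var("b"), Var("c"))).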

/**
* Simplifies boolean expressions:
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.SimpleCatalystConf

class CNFNormalizationSuite extends SparkFunSuite with PredicateHelper {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation,
ConstantFolding,
BooleanSimplification,
CNFNormalization(SimpleCatalystConf(true)),
PruneFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int, 'e.int)

private def checkCondition(input: Expression, expected: Expression): Unit = {
val actual = Optimize.execute(testRelation.where(input).analyze)
val correctAnswer = Optimize.execute(testRelation.where(expected).analyze)

val resultFilterExpression = actual.collectFirst { case f: Filter => f.condition }.get
val expectedFilterExpression = correctAnswer.collectFirst { case f: Filter => f.condition }.get

assert(resultFilterExpression.semanticEquals(expectedFilterExpression))
}

private val a = Literal(1) < 'a
private val b = Literal(1) < 'b
private val c = Literal(1) < 'c
private val d = Literal(1) < 'd
private val e = Literal(1) < 'e
private val f = !a

test("a || b => a || b") {
checkCondition(a || b, a || b)
}

test("a && b && c => a && b && c") {
checkCondition(a && b && c, a && b && c)
}

test("a && !(b || c) => a && !b && !c") {
checkCondition(a && !(b || c), a && !b && !c)
}

test("a && b || c => (a || c) && (b || c)") {
checkCondition(a && b || c, (a || c) && (b || c))
}

test("a && b || f => (a || f) && (b || f)") {
checkCondition(a && b || f, b || f)
}

test("(a && b) || (c && d) => (c || a) && (c || b) && ((d || a) && (d || b))") {
checkCondition((a && b) || (c && d), (a || c) && (b || c) && (a || d) && (b || d))
}

test("(a && b) || !(c && d) => (a || !c || !d) && (b || !c || !d)") {
checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d))
}

test("a || b || c && d => (a || b || c) && (a || b || d)") {
checkCondition(a || b || c && d, (a || b || c) && (a || b || d))
}

test("a || (b && c || d) => (a || b || d) && (a || c || d)") {
checkCondition(a || (b && c || d), (a || b || d) && (a || c || d))
}

test("a || !(b && c || d) => (a || !b || !c) && (a || !d)") {
checkCondition(a || !(b && c || d), (a || !b || !c) && (a || !d))
}

test("a && (b && c || d && e) => a && (b || d) && (c || d) && (b || e) && (c || e)") {
val input = a && (b && c || d && e)
val expected = a && (b || d) && (c || d) && (b || e) && (c || e)
checkCondition(input, expected)
}

test("a && !(b && c || d && e) => a && (!b || !c) && (!d || !e)") {
checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e))
}

test(
"a || (b && c || d && e) => (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)") {
val input = a || (b && c || d && e)
val expected = (a || b || d) && (a || c || d) && (a || b || e) && (a || c || e)
checkCondition(input, expected)
}

test(
"a || !(b && c || d && e) => (a || !b || !c) && (a || !d || !e)") {
checkCondition(a || !(b && c || d && e), (a || !b || !c) && (a || !d || !e))
}

test("a && b && c || !(d && e) => (a || !d || !e) && (b || !d || !e) && (c || !d || !e)") {
val input = a && b && c || !(d && e)
val expected = (a || !d || !e) && (b || !d || !e) && (c || !d || !e)
checkCondition(input, expected)
}

test(
"a && b && c || d && e && f => " +
"(a || d) && (a || e) && (a || f) && (b || d) && " +
"(b || e) && (b || f) && (c || d) && (c || e) && (c || f)") {
val input = (a && b && c) || (d && e && f)
val expected = (a || d) && (a || e) && (a || f) &&
(b || d) && (b || e) && (b || f) &&
(c || d) && (c || e) && (c || f)
checkCondition(input, expected)
}

test("((a && b) || (c && d)) || e") {
val input = ((a && b) || (c && d)) || e
val expected = ((a || c) || e) && ((a || d) || e) && ((b || c) || e) && ((b || d) || e)
checkCondition(input, expected)
}

test("CNF normalization exceeds max predicate numbers") {
val input = (1 to 100).map(i => Literal(i) < 'c).reduce(And) ||
(1 to 10).map(i => Literal(i) < 'a).reduce(And)
val analyzed = testRelation.where(input).analyze
val optimized = Optimize.execute(analyzed)
val resultFilterExpression = optimized.collectFirst { case f: Filter => f.condition }.get
val expectedFilterExpression = analyzed.collectFirst { case f: Filter => f.condition }.get
assert(resultFilterExpression.semanticEquals(expectedFilterExpression))
}
}
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.types.IntegerType

class FilterPushdownSuite extends PlanTest {
@@ -37,6 +38,7 @@ class FilterPushdownSuite extends PlanTest {
CombineFilters,
PushDownPredicate,
BooleanSimplification,
CNFNormalization(SimpleCatalystConf(true)),
PushPredicateThroughJoin,
CollapseProject) :: Nil
}
@@ -1018,4 +1020,46 @@ class FilterPushdownSuite extends PlanTest {

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

test("push down filters that are not be able to pushed down after simplification") {
// The following predicate ('a === 2 || 'a === 3) && ('c > 10 || 'a === 2)
// will be simplified as ('a == 2) || ('c > 10 && 'a == 3).
// In its original form, ('a === 2 || 'a === 3) can be pushed down.
// But the simplified one can't.
val originalQuery = testRelation
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10) // this predicate can't be pushed down.
.where(('a === 2 || 'a === 3) && ('c > 10 || 'a === 2))

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a === 2 || 'a === 3)
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10).analyze

comparePlans(optimized, correctAnswer)
}

test("disjunctive predicates which are able to pushdown should be pushed down after converted") {
// (('a === 2) || ('c > 10 || 'a === 3)) can't be pushdown due to the disjunctive form.
// However, its conjunctive normal form can be pushdown.
val originalQuery = testRelation
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10)
.where(('a === 2) || ('c > 10 && 'a === 3))

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a === 2 || 'a === 3)
.select('a, 'b, ('c + 1) as 'cc)
.groupBy('a)('a, count('cc) as 'c)
.where('c > 10).analyze

comparePlans(optimized, correctAnswer)
}
}
@@ -441,6 +441,18 @@ object SQLConf {
.intConf
.createWithDefault(20)

val MAX_DEPTH_CNF_PREDICATE = SQLConfigBuilder("spark.sql.expression.cnf.maxDepth")
.internal()
.doc("The maximum depth of converting recursively filter predicates to CNF normalization.")
.intConf
.createWithDefault(10)

val MAX_PREDICATE_NUMBER_CNF_PREDICATE = SQLConfigBuilder("spark.sql.expression.cnf.maxNumber")
.internal()
.doc("The maximum number of predicates in the CNF normalization of filter predicates")
.intConf
.createWithDefault(20)

val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
@@ -689,6 +701,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)

def maxDepthForCNFNormalization: Int = getConf(MAX_DEPTH_CNF_PREDICATE)

def maxPredicateNumberForCNFNormalization: Int = getConf(MAX_PREDICATE_NUMBER_CNF_PREDICATE)

def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)

def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
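Both entries are marked internal, but they can still be set at runtime. A minimal, hedged
sketch, assuming a live SparkSession named `spark` (the keys come from the definitions
above; the values are arbitrary examples):

// Raise the limits for workloads whose filter predicates are wide but shallow.
spark.conf.set("spark.sql.expression.cnf.maxDepth", "12")
spark.conf.set("spark.sql.expression.cnf.maxNumber", "40")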