diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 304b438c84ba4..72e35770d865b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -47,6 +48,42 @@ trait Predicate extends Expression {
   override def dataType: DataType = BooleanType
 }
 
+object Predicate extends PredicateHelper {
+  /**
+   * Converts `predicate` into conjunctive normal form (CNF). If `maybeThreshold` is defined
+   * and the CNF result has more than `maybeThreshold * predicate.size` nodes, the original
+   * predicate is returned unchanged to guard against exponential blow-up.
+   */
+  def toCNF(predicate: Expression, maybeThreshold: Option[Double] = None): Expression = {
+    val cnf = new CNFExecutor(predicate).execute(predicate)
+    val threshold = maybeThreshold.map(predicate.size * _).getOrElse(Double.MaxValue)
+    if (cnf.size > threshold) predicate else cnf
+  }
+
+  private class CNFNormalization(input: Expression)
+    extends Rule[Expression] {
+
+    override def apply(tree: Expression): Expression = {
+      import org.apache.spark.sql.catalyst.dsl.expressions._
+
+      tree transformDown {
+        // Eliminates double negation
+        case Not(Not(e)) => e
+        // De Morgan's laws
+        case Not(a And b) => !a || !b
+        case Not(a Or b) => !a && !b
+        // Distributes Or over And to pull conjunctions up to the top level
+        case a Or (b And c) => (a || b) && (a || c)
+        case (a And b) Or c => (a || c) && (b || c)
+      }
+    }
+  }
+
+  private class CNFExecutor(input: Expression) extends RuleExecutor[Expression] {
+    override protected val batches: Seq[Batch] =
+      Batch("CNFNormalization", FixedPoint.Unlimited, new CNFNormalization(input)) :: Nil
+  }
+}
 
 trait PredicateHelper {
   protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f6088695a9276..275aad5e9ed54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -58,6 +58,7 @@ object DefaultOptimizer extends Optimizer {
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
+      CNFNormalization,
       RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
@@ -583,6 +584,17 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Rewrites [[Filter]] conditions into conjunctive normal form (CNF) so that more conjuncts
+ * can be recognized and handled individually. Falls back to the original condition when the
+ * CNF form grows more than 10x larger (see [[Predicate.toCNF]]).
+ */
+object CNFNormalization extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(condition, _) => f.copy(condition = Predicate.toCNF(condition, Some(10)))
+  }
+}
+
 /**
  * Combines two adjacent [[Filter]] operators into one, merging the
  * conditions into one conjunctive predicate.
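For intuition, the rewrite above can be exercised outside Catalyst. The following is a minimal, self-contained Scala sketch of the same five cases on a toy boolean ADT; `Expr`, `CNF.step`, and `CNF.toCNF` are illustrative stand-ins, not Spark APIs:

```scala
// Toy stand-in for Catalyst's Expression tree (illustration only).
sealed trait Expr
case class Var(name: String) extends Expr
case class Not(e: Expr) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

object CNF {
  // One top-down pass applying the same cases as CNFNormalization.
  def step(e: Expr): Expr = e match {
    case Not(Not(x))      => step(x)                              // double negation
    case Not(And(a, b))   => Or(step(Not(a)), step(Not(b)))       // De Morgan
    case Not(Or(a, b))    => And(step(Not(a)), step(Not(b)))      // De Morgan
    case Or(a, And(b, c)) => And(step(Or(a, b)), step(Or(a, c)))  // distribute Or over And
    case Or(And(a, b), c) => And(step(Or(a, c)), step(Or(b, c)))  // distribute Or over And
    case Not(x)           => Not(step(x))
    case And(a, b)        => And(step(a), step(b))
    case Or(a, b)         => Or(step(a), step(b))
    case v: Var           => v
  }

  // Iterate until nothing changes, mirroring FixedPoint.Unlimited: a single
  // pass can expose new redexes (e.g. distribution may create a fresh Or over And).
  def toCNF(e: Expr): Expr = {
    val next = step(e)
    if (next == e) e else toCNF(next)
  }
}

// a || (b && c)  ==>  (a || b) && (a || c)
// println(CNF.toCNF(Or(Var("a"), And(Var("b"), Var("c")))))
```

This also shows why the rule runs under `FixedPoint.Unlimited` rather than a fixed iteration budget: the number of passes needed depends on the nesting depth of the input predicate.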
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index f80d2a93241d1..06bbed9973fbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -55,6 +55,11 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
   /** A strategy that runs until fix point or maxIterations times, whichever comes first. */
   case class FixedPoint(maxIterations: Int) extends Strategy
 
+  object FixedPoint {
+    /** A [[FixedPoint]] strategy with no upper bound on the number of iterations. */
+    val Unlimited: FixedPoint = FixedPoint(-1)
+  }
+
   /** A batch of rules. */
   protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
 
@@ -95,7 +100,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
         result
       }
       iteration += 1
-      if (iteration > batch.strategy.maxIterations) {
+      if (batch.strategy.maxIterations > 0 && iteration > batch.strategy.maxIterations) {
         // Only log if this is a rule that is supposed to run more than once.
         if (iteration != 2) {
           logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index c97dc2d8be7e6..da8afd71e77da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -558,6 +558,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
     case _ => JNull
   }
+
+  /** Returns the total number of nodes in this tree. */
+  def size: Int = 1 + children.map(_.size).sum
 }
 
 object TreeNode {
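The 10x fallback in `CNFNormalization` is not just defensive: distributing `Or` over `And` can grow a predicate exponentially, e.g. the CNF of `(a1 && b1) || ... || (an && bn)` contains 2^n clauses. Below is a sketch of the guard reusing the toy ADT above; `toCNFBounded` and this `size` helper are hypothetical names, while the real logic is `Predicate.toCNF` built on the new `TreeNode.size`:

```scala
// Counts nodes the same way as TreeNode.size: one per tree node.
def size(e: Expr): Int = e match {
  case Var(_)    => 1
  case Not(x)    => 1 + size(x)
  case And(a, b) => 1 + size(a) + size(b)
  case Or(a, b)  => 1 + size(a) + size(b)
}

// Mirrors Predicate.toCNF's fallback: keep the CNF form only if it stays
// within `threshold` times the original node count (the optimizer passes 10).
def toCNFBounded(e: Expr, maybeThreshold: Option[Double] = None): Expr = {
  val cnf = CNF.toCNF(e)
  val threshold = maybeThreshold.map(size(e) * _).getOrElse(Double.MaxValue)
  if (size(cnf) > threshold) e else cnf
}
```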