@@ -389,7 +389,7 @@ class Analyzer(
a.copy(aggregateExpressions = expanded)

// Special handling for cases when self-join introduce duplicate expression ids.
-      case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
+      case j @ Join(left, right, _, _, _) if !j.selfJoinResolved =>
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")

@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

@@ -87,12 +87,12 @@ trait CheckAnalysis {
s"filter expression '${f.condition.prettyString}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")

-          case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
+          case j @ Join(_, _, _, Some(condition), _) if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.prettyString}' " +
s"of type ${condition.dataType.simpleString} is not a boolean.")

-          case j @ Join(_, _, _, Some(condition)) =>
+          case j @ Join(_, _, _, Some(condition), _) =>
def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
case p: Predicate =>
p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
@@ -190,14 +190,15 @@ trait CheckAnalysis {
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)

// Special handling for cases when self-join introduce duplicate expression ids.
-          case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
-            val conflictingAttributes = left.outputSet.intersect(right.outputSet)
-            failAnalysis(
-              s"""
-                 |Failure when resolving conflicting references in Join:
-                 |$plan
-                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
-                 |""".stripMargin)
+          case j @ Join(left, right, _, _, _)
+              if left.outputSet.intersect(right.outputSet).nonEmpty =>
Review comment (Contributor): nit: we can use if !j.selfJoinResolved

+            val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+            failAnalysis(
+              s"""
+                 |Failure when resolving conflicting references in Join:
+                 |$plan
+                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+                 |""".stripMargin)

case o if !o.resolved =>
failAnalysis(
@@ -87,6 +87,13 @@ class EquivalentExpressions {
equivalenceMap.values.map(_.toSeq).toSeq
}

+  /**
+   * Returns true if an expression equivalent to `e` has been added.
+   */
+  def contains(e: Expression): Boolean = {
+    equivalenceMap.contains(Expr(e))
+  }

/**
* Returns the state of the data structure as a string. If `all` is false, skips sets of
* equivalent expressions with cardinality 1.
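For illustration, a minimal usage sketch of the new method, assuming only the EquivalentExpressions API visible in this diff (addExpr is called by the new optimizer rule further down; `attr` stands in for some already-resolved attribute reference):

    val generated = new EquivalentExpressions
    generated.contains(IsNotNull(attr))  // false: nothing has been recorded yet
    generated.addExpr(IsNotNull(attr))
    generated.contains(IsNotNull(attr))  // true: an equivalence class for the expression now exists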
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
+import org.apache.spark.sql.catalyst.planning.ExtractNonNullableAttributes
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -52,6 +53,8 @@ object DefaultOptimizer extends Optimizer {
ProjectCollapsing,
CombineFilters,
CombineLimits,
+      // Predicate inference
+      AddJoinKeyNullabilityFilters,
// Constant folding
NullPropagation,
OptimizeIn,
@@ -233,7 +236,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
child))

// Eliminate unneeded attributes from either side of a Join.
-    case Project(projectList, Join(left, right, joinType, condition)) =>
+    case Project(projectList, Join(left, right, joinType, condition, generated)) =>
// Collect the list of all references required either above or to evaluate the condition.
val allReferences: AttributeSet =
AttributeSet(
@@ -243,15 +246,16 @@
/** Applies a projection only when the child is producing unnecessary attributes */
def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)

-      Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
+      Project(projectList,
+        Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition, generated))

// Eliminate unneeded attributes from right side of a LeftSemiJoin.
-    case Join(left, right, LeftSemi, condition) =>
+    case Join(left, right, LeftSemi, condition, generated) =>
// Collect the list of all references required to evaluate the condition.
val allReferences: AttributeSet =
condition.map(_.references).getOrElse(AttributeSet(Seq.empty))

-      Join(left, prunedChild(right, allReferences), LeftSemi, condition)
+      Join(left, prunedChild(right, allReferences), LeftSemi, condition, generated)

// Push down project through limit, so that we may have chance to push it further.
case Project(projectList, Limit(exp, child)) =>
@@ -355,6 +359,51 @@ object LikeSimplification extends Rule[LogicalPlan] {
}
}

+/**
+ * This rule adds IsNotNull predicates based on join keys. If the join condition contains a
+ * predicate of the form `a binaryOp <expr>`, then `a` cannot be null in any matching row, so
+ * this rule adds IsNotNull filters for those attributes to the join's children.
+ *
+ * To avoid repeatedly generating the same IsNotNull predicates, the join operator remembers
+ * all the expressions it generated.
+ */
+object AddJoinKeyNullabilityFilters extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case j @ Join(left, right, Inner, Some(condition), generated) => {
+      val nonNullableKeys = ExtractNonNullableAttributes.unapply(condition)
Review comment (Contributor): nit: val ExtractNonNullableAttributes(nonNullableKeys) = condition

Reply (Contributor): oh nvm, I think your version looks more clear.

+      val newKeys = nonNullableKeys.filter { generated.isEmpty || !generated.get.contains(_) }
+
+      if (newKeys.isEmpty) {
+        j
+      } else {
+        val leftKeys = newKeys.filter { canEvaluate(_, left) }
+        val rightKeys = newKeys.filter { canEvaluate(_, right) }
+
+        if (leftKeys.nonEmpty || rightKeys.nonEmpty) {
+          val newGenerated =
+            if (j.generatedExpressions.isDefined) j.generatedExpressions.get
+            else new EquivalentExpressions
+
+          var newLeft: LogicalPlan = left
+          var newRight: LogicalPlan = right
+
+          if (leftKeys.nonEmpty) {
+            newLeft = Filter(leftKeys.map(IsNotNull(_)).reduce(And), left)
+            leftKeys.foreach { e => newGenerated.addExpr(e) }
+          }
+          if (rightKeys.nonEmpty) {
+            newRight = Filter(rightKeys.map(IsNotNull(_)).reduce(And), right)
+            rightKeys.foreach { e => newGenerated.addExpr(e) }
+          }
+
+          Join(newLeft, newRight, Inner, Some(condition), Some(newGenerated))
+        } else {
+          j
+        }
+      }
+    }
+  }
+}
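For illustration, a self-contained sketch of the rewrite this rule performs, using toy case classes in place of catalyst's (all names below are hypothetical stand-ins, not Spark API):

    // For SELECT * FROM t1 JOIN t2 ON t1.key = t2.key, the inner-join condition can
    // only match rows where both keys are non-null, so an IsNotNull filter can be
    // pushed below the join on each side.
    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class EqualTo(l: Expr, r: Expr) extends Expr
    case class IsNotNull(e: Expr) extends Expr

    sealed trait Plan
    case class Scan(table: String) extends Plan
    case class Filter(cond: Expr, child: Plan) extends Plan
    case class Join(left: Plan, right: Plan, cond: Expr) extends Plan

    def addIsNotNullFilters(j: Join): Join = j.cond match {
      case EqualTo(a: Attr, b: Attr) =>
        Join(Filter(IsNotNull(a), j.left), Filter(IsNotNull(b), j.right), j.cond)
      case _ => j
    }

    // addIsNotNullFilters(Join(Scan("t1"), Scan("t2"), EqualTo(Attr("t1.key"), Attr("t2.key"))))
    // ==> Join(Filter(IsNotNull(Attr(t1.key)), Scan(t1)),
    //          Filter(IsNotNull(Attr(t2.key)), Scan(t2)), EqualTo(...))

Recording the generated expressions on the Join is what keeps the real rule from re-adding the same filters when the optimizer batch runs the rule again.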

/**
* Replaces [[Expression Expressions]] that can be statically evaluated with
* equivalent [[Literal]] values. This rule is more specific with
@@ -784,7 +833,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// push the where condition down into join filter
-    case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
+    case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, generated)) =>
val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)

@@ -797,14 +846,14 @@
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)

-          Join(newLeft, newRight, Inner, newJoinCond)
+          Join(newLeft, newRight, Inner, newJoinCond, generated)
case RightOuter =>
// push down the right side only `where` condition
val newLeft = left
val newRight = rightFilterConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = joinCondition
-          val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
+          val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, generated)

(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
@@ -814,15 +863,15 @@
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = joinCondition
-          val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
+          val newJoin = Join(newLeft, newRight, joinType, newJoinCond, generated)

(rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case FullOuter => f // DO Nothing for Full Outer Join
}

// push down the join filter into sub query scanning if applicable
-    case f @ Join(left, right, joinType, joinCondition) =>
+    case f @ Join(left, right, joinType, joinCondition, generated) =>
val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)

@@ -835,23 +884,23 @@
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = commonJoinCondition.reduceLeftOption(And)

-          Join(newLeft, newRight, joinType, newJoinCond)
+          Join(newLeft, newRight, joinType, newJoinCond, generated)
case RightOuter =>
// push down the left side only join filter for left side sub query
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
val newRight = right
val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

-          Join(newLeft, newRight, RightOuter, newJoinCond)
+          Join(newLeft, newRight, RightOuter, newJoinCond, generated)
case LeftOuter =>
// push down the right side only join filter for right sub query
val newLeft = left
val newRight = rightJoinConditions.
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)

-          Join(newLeft, newRight, LeftOuter, newJoinCond)
+          Join(newLeft, newRight, LeftOuter, newJoinCond, generated)
case FullOuter => f
}
}
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.planning

+import scala.collection.mutable

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -95,7 +97,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
(JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan)

def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
-    case join @ Join(left, right, joinType, condition) =>
+    case join @ Join(left, right, joinType, condition, _) =>
logDebug(s"Considering join on: $condition")
// Find equi-join predicates that can be evaluated before the join, and thus can be used
// as join keys.
@@ -150,21 +152,21 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {

// flatten all inner joins, which are next to each other
def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match {
-    case Join(left, right, Inner, cond) =>
+    case Join(left, right, Inner, cond, _) =>
val (plans, conditions) = flattenJoin(left)
(plans ++ Seq(right), conditions ++ cond.toSeq)

-    case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) =>
+    case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition, _)) =>
val (plans, conditions) = flattenJoin(j)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))

case _ => (Seq(plan), Seq())
}

def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) =>
+    case f @ Filter(filterCondition, j @ Join(_, _, Inner, _, _)) =>
Some(flattenJoin(f))
-    case j @ Join(_, _, Inner, _) =>
+    case j @ Join(_, _, Inner, _, _) =>
Some(flattenJoin(j))
case _ => None
}
@@ -184,3 +186,39 @@ object Unions {
case other => other :: Nil
}
}

+/**
+ * A pattern that finds all attributes in `condition` that cannot be null.
+ */
+object ExtractNonNullableAttributes extends Logging with PredicateHelper {
+  def unapply(condition: Expression): Set[Attribute] = {
+    val predicates = splitConjunctivePredicates(condition)
+
+    val result = mutable.HashSet.empty[Attribute]
+    def extract(e: Expression): Unit = e match {
+      case IsNotNull(a: Attribute) => result.add(a)
+      case BinaryComparison(a: Attribute, b: Attribute) => {
+        if (!e.isInstanceOf[EqualNullSafe]) {
+          result.add(a)
+          result.add(b)
+        }
+      }
+      case BinaryComparison(Cast(a: Attribute, _), Cast(b: Attribute, _)) => {
+        if (!e.isInstanceOf[EqualNullSafe]) {
+          result.add(a)
+          result.add(b)
+        }
+      }
+      case BinaryComparison(a: Attribute, _) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case BinaryComparison(_, a: Attribute) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case BinaryComparison(Cast(a: Attribute, _), _) =>
+        if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case BinaryComparison(_, Cast(a: Attribute, _)) =>
+        if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case Not(child) => extract(child)
+      case _ =>
+    }
+    predicates.foreach { extract(_) }
+    result.toSet
+  }
+}
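For example, given the condition t1.a = t2.b AND t1.c > 5 AND t1.d <=> t2.e, the pattern returns {t1.a, t2.b, t1.c}: the ordinary comparisons can only be satisfied when their attribute operands are non-null, while null-safe equality (<=>, EqualNullSafe) can match null inputs, so t1.d and t2.e are excluded.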
@@ -99,6 +99,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved

+  /**
+   * Returns true if the two plans are semantically equal. This should ignore state generated
+   * during planning to help the planning process.
+   * TODO: implement this as a pass that canonicalizes the plan tree instead?
+   */
+  def semanticEquals(other: LogicalPlan): Boolean = this == other
Review comment (Contributor): Oh, this is a new semantic equals. How is this different than sameResult? Maybe we should unify the naming between Expression and LogicalPlan for this concept.
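The intent, as the Join override later in this diff shows, is that two plans differing only in planner-generated state (here, the new generatedExpressions field) still compare as semantically equal even though they are no longer == after the optimizer records generated predicates.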


override protected def statePrefix = if (!resolved) "'" else super.statePrefix

/**
@@ -122,11 +122,22 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override def output: Seq[Attribute] = left.output
}

+object Join {
+  def apply(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      joinType: JoinType,
+      condition: Option[Expression]): Join = {
+    Join(left, right, joinType, condition, None)
+  }
+}

case class Join(
left: LogicalPlan,
right: LogicalPlan,
joinType: JoinType,
-    condition: Option[Expression]) extends BinaryNode {
+    condition: Option[Expression],
+    generatedExpressions: Option[EquivalentExpressions]) extends BinaryNode {
Review comment (Contributor): This is semi-public API, because I think some advanced projects do dig into catalyst, and we've never changed the signature of something as basic as Join before. Could we do this instead by fixing nullability propagation and only inserting the filter if the attribute is nullable?

Reply (Contributor, Author): I started down that path, but can you think of a way to make that handle the more general case of predicate propagation?

t1.key join t2.key where t1.key = t2.key and t1.key = 5.

How do we generate the predicate t2.key = 5? How do we make this more general?

Reply (Contributor): In an earlier version of catalyst we also had equivalence classes propagate up the logical plans. Would that give you enough information?

Reply (Contributor, Author): Equivalence classes are one thing; we can compute that no problem, I think. The issue is how to remember that t2.key = 5 was generated and not generate it again. The trick of setting nullable doesn't work here. We could maintain value constraints (where nullability is a subset).

Reply (Contributor): Couldn't Literal(5) be in the equivalence class, and we could check for that? That said, I also like the idea of more general value constraints.

Reply (Contributor, Author): Literal(5) would work for equivalence, but we want to track more than equality. If it was t1.key join t2.key where t1.key = t2.key and t1.key > 5, we'd similarly want to add t2.key > 5.

Are you suggesting we don't change the operator and walk the tree bottom-up to collect these constraints? That seems extremely expensive to do.

Reply (Contributor): I was thinking operators would propagate the set of constraints up from their children (possibly augmenting or clearing as appropriate) and we'd save it in a lazy val.

override def output: Seq[Attribute] = {
joinType match {
@@ -152,6 +163,17 @@ case class Join(
selfJoinResolved &&
condition.forall(_.dataType == BooleanType)
}

+  override def simpleString: String = s"$nodeName $joinType, $condition".trim
+
+  override def semanticEquals(other: LogicalPlan): Boolean = {
+    other match {
+      case Join(l, r, joinType, condition, _) => {
+        l == left && r == right && this.joinType == joinType && this.condition == condition
+      }
+      case _ => false
+    }
+  }
}
