
Commit 0a6043f

dbaliafroozeh authored and hvanhovell committed
[SPARK-32755][SQL] Maintain the order of expressions in AttributeSet and ExpressionSet
### What changes were proposed in this pull request?

This PR changes `AttributeSet` and `ExpressionSet` to maintain the insertion order of their elements. More specifically, we:

- Change the underlying data structure of `AttributeSet` from `HashSet` to `LinkedHashSet`, which maintains insertion order.
- Rework `ExpressionSet`. It already uses a list to keep track of the expressions, but because it extended Scala's `immutable.Set`, operations such as `map` and `flatMap` were delegated to `immutable.Set` itself. The result of these operations was therefore no longer an instance of `ExpressionSet`, but an implementation picked by the parent class. We remove the inheritance from `immutable.Set` and implement the needed methods directly; `ExpressionSet` has very specific semantics, and it does not make sense for it to extend `immutable.Set` anyway.
- Change `PlanStabilitySuite` to not sort the attributes, so it can catch changes in the order of expressions across runs.

### Why are the changes needed?

Expression identity is based on `ExprId`, an auto-incremented number. This means the same query can yield query plans with different expression ids in different runs. `AttributeSet` and `ExpressionSet` internally used a `HashSet` as the underlying data structure and therefore could not guarantee a fixed iteration order across runs. This is problematic when we want to check for plan changes between runs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Passes `PlanStabilitySuite` after regenerating the golden files.

Closes #29598 from dbaliafroozeh/FixOrderOfExpressions.

Authored-by: Ali Afroozeh <[email protected]>
Signed-off-by: herman <[email protected]>
1 parent 95f1e95 commit 0a6043f
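The failure mode is easy to reproduce outside Spark. Below is a minimal, self-contained Scala sketch (illustrative only, not Spark code; `Attr` and `OrderDemo` are made-up names, with `Attr(id)` standing in for an attribute identified by an auto-incremented `ExprId`). A `HashSet`'s iteration order depends on the hash codes of its elements, so two runs that allocate different ids may iterate in different orders, while a `LinkedHashSet` always replays the insertion order:

```scala
import scala.collection.mutable

// Illustrative stand-in for an attribute keyed by an auto-incremented ExprId.
case class Attr(id: Long)

object OrderDemo extends App {
  // Insert the "same" attributes, then observe the iteration order of each set.
  def orders(ids: Seq[Long]): (Seq[Attr], Seq[Attr]) = {
    val hashed = new mutable.HashSet[Attr]
    val linked = new mutable.LinkedHashSet[Attr]
    ids.foreach { id => hashed += Attr(id); linked += Attr(id) }
    (hashed.toSeq, linked.toSeq)
  }

  // Two runs of the same query that happened to allocate different ExprIds.
  val (hashed1, linked1) = orders(Seq(1L, 2L, 3L))
  val (hashed2, linked2) = orders(Seq(101L, 102L, 103L))

  // The hash-based orders may disagree across the two runs ...
  println(s"HashSet:       $hashed1 vs $hashed2")
  // ... while the linked orders are both simply the insertion order.
  println(s"LinkedHashSet: $linked1 vs $linked2")
}
```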

File tree

397 files changed: +9481 −9419 lines


sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala

Lines changed: 78 additions & 18 deletions

```diff
@@ -27,6 +27,10 @@ object ExpressionSet {
     expressions.foreach(set.add)
     set
   }
+
+  def apply(): ExpressionSet = {
+    new ExpressionSet()
+  }
 }
 
 /**
@@ -53,46 +57,102 @@ object ExpressionSet {
  * This is consistent with how we define `semanticEquals` between two expressions.
  */
 class ExpressionSet protected(
-    protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
-    protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
-  extends Set[Expression] {
+    private val baseSet: mutable.Set[Expression] = new mutable.HashSet,
+    private val originals: mutable.Buffer[Expression] = new ArrayBuffer)
+  extends Iterable[Expression] {
 
   // Note: this class supports Scala 2.12. A parallel source tree has a 2.13 implementation.
 
   protected def add(e: Expression): Unit = {
     if (!e.deterministic) {
       originals += e
-    } else if (!baseSet.contains(e.canonicalized) ) {
+    } else if (!baseSet.contains(e.canonicalized)) {
       baseSet.add(e.canonicalized)
       originals += e
     }
   }
 
-  override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)
+  protected def remove(e: Expression): Unit = {
+    if (e.deterministic) {
+      baseSet --= baseSet.filter(_ == e.canonicalized)
+      originals --= originals.filter(_.canonicalized == e.canonicalized)
+    }
+  }
+
+  def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)
+
+  override def filter(p: Expression => Boolean): ExpressionSet = {
+    val newBaseSet = baseSet.filter(e => p(e.canonicalized))
+    val newOriginals = originals.filter(e => p(e.canonicalized))
+    new ExpressionSet(newBaseSet, newOriginals)
+  }
+
+  override def filterNot(p: Expression => Boolean): ExpressionSet = {
+    val newBaseSet = baseSet.filterNot(e => p(e.canonicalized))
+    val newOriginals = originals.filterNot(e => p(e.canonicalized))
+    new ExpressionSet(newBaseSet, newOriginals)
+  }
 
-  override def +(elem: Expression): ExpressionSet = {
-    val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+  def +(elem: Expression): ExpressionSet = {
+    val newSet = clone()
     newSet.add(elem)
     newSet
   }
 
-  override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
-    val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+  def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
+    val newSet = clone()
     elems.foreach(newSet.add)
     newSet
   }
 
-  override def -(elem: Expression): ExpressionSet = {
-    if (elem.deterministic) {
-      val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
-      val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized)
-      new ExpressionSet(newBaseSet, newOriginals)
-    } else {
-      new ExpressionSet(baseSet.clone(), originals.clone())
-    }
+  def -(elem: Expression): ExpressionSet = {
+    val newSet = clone()
+    newSet.remove(elem)
+    newSet
+  }
+
+  def --(elems: GenTraversableOnce[Expression]): ExpressionSet = {
+    val newSet = clone()
+    elems.foreach(newSet.remove)
+    newSet
   }
 
-  override def iterator: Iterator[Expression] = originals.iterator
+  def map(f: Expression => Expression): ExpressionSet = {
+    val newSet = new ExpressionSet()
+    this.iterator.foreach(elem => newSet.add(f(elem)))
+    newSet
+  }
+
+  def flatMap(f: Expression => Iterable[Expression]): ExpressionSet = {
+    val newSet = new ExpressionSet()
+    this.iterator.foreach(f(_).foreach(newSet.add))
+    newSet
+  }
+
+  def iterator: Iterator[Expression] = originals.iterator
+
+  def union(that: ExpressionSet): ExpressionSet = {
+    val newSet = clone()
+    that.iterator.foreach(newSet.add)
+    newSet
+  }
+
+  def subsetOf(that: ExpressionSet): Boolean = this.iterator.forall(that.contains)
+
+  def intersect(that: ExpressionSet): ExpressionSet = this.filter(that.contains)
+
+  def diff(that: ExpressionSet): ExpressionSet = this -- that
+
+  def apply(elem: Expression): Boolean = this.contains(elem)
+
+  override def equals(obj: Any): Boolean = obj match {
+    case other: ExpressionSet => this.baseSet == other.baseSet
+    case _ => false
+  }
+
+  override def hashCode(): Int = baseSet.hashCode()
+
+  override def clone(): ExpressionSet = new ExpressionSet(baseSet.clone(), originals.clone())
 
   /**
    * Returns a string containing both the post [[Canonicalize]] expressions and the original
```
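Since `ExpressionSet` no longer extends `immutable.Set`, `map`, `flatMap`, `filter`, and the set operators above are defined directly on the class and all return an `ExpressionSet` that preserves insertion order. A hypothetical usage sketch, e.g. in a REPL with the `spark-catalyst` module from this commit on the classpath (the catalyst DSL is used only for brevity):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Add, ExpressionSet, Literal}

val a = 'a.int  // an AttributeReference, via the catalyst DSL
val set = ExpressionSet(Seq(a + 1, a + 2))

// Previously `map` was delegated to immutable.Set and returned whatever Set
// implementation the parent picked; now it stays an ExpressionSet and keeps
// the first-seen order of the (non-canonicalized) originals.
val mapped: ExpressionSet = set.map(e => Add(e, Literal(1)))
mapped.foreach(println)  // iterates in insertion order
```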

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala

Lines changed: 7 additions & 3 deletions

```diff
@@ -37,7 +37,11 @@ object AttributeSet {
   val empty = apply(Iterable.empty)
 
   /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */
-  def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
+  def apply(a: Attribute): AttributeSet = {
+    val baseSet = new mutable.LinkedHashSet[AttributeEquals]
+    baseSet += new AttributeEquals(a)
+    new AttributeSet(baseSet)
+  }
 
   /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
   def apply(baseSet: Iterable[Expression]): AttributeSet = {
@@ -47,7 +51,7 @@ object AttributeSet {
   /** Constructs a new [[AttributeSet]] given a sequence of [[AttributeSet]]s. */
   def fromAttributeSets(sets: Iterable[AttributeSet]): AttributeSet = {
     val baseSet = sets.foldLeft(new mutable.LinkedHashSet[AttributeEquals]())( _ ++= _.baseSet)
-    new AttributeSet(baseSet.toSet)
+    new AttributeSet(baseSet)
   }
 }
 
@@ -62,7 +66,7 @@ object AttributeSet {
  * and also makes doing transformations hard (we always try keep older trees instead of new ones
 * when the transformation was a no-op).
 */
-class AttributeSet private (val baseSet: Set[AttributeEquals])
+class AttributeSet private (private val baseSet: mutable.LinkedHashSet[AttributeEquals])
   extends Iterable[Attribute] with Serializable {
 
   override def hashCode: Int = baseSet.hashCode()
```
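The `foldLeft` in `fromAttributeSets` relies on `LinkedHashSet` keeping first-insertion order under mutation. A small self-contained sketch of that behavior (plain Scala, with `Int` standing in for the `AttributeEquals` wrappers):

```scala
import scala.collection.mutable

// Union several LinkedHashSets left to right with ++=; duplicates keep the
// position where they were first inserted, so the combined order is stable.
val sets = Seq(mutable.LinkedHashSet(3, 1), mutable.LinkedHashSet(2, 3))
val combined = sets.foldLeft(new mutable.LinkedHashSet[Int]())(_ ++= _)
println(combined.toSeq)  // List(3, 1, 2)
```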

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala

Lines changed: 11 additions & 11 deletions

```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, ExpressionSet, PredicateHelper}
 import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -75,18 +75,18 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
    * Extracts items of consecutive inner joins and join conditions.
    * This method works for bushy trees and left/right deep trees.
    */
-  private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
+  private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], ExpressionSet) = {
     plan match {
       case Join(left, right, _: InnerLike, Some(cond), JoinHint.NONE) =>
         val (leftPlans, leftConditions) = extractInnerJoins(left)
         val (rightPlans, rightConditions) = extractInnerJoins(right)
-        (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
-          leftConditions ++ rightConditions)
+        (leftPlans ++ rightPlans, leftConditions ++ rightConditions ++
+          splitConjunctivePredicates(cond))
       case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE))
         if projectList.forall(_.isInstanceOf[Attribute]) =>
         extractInnerJoins(j)
       case _ =>
-        (Seq(plan), Set())
+        (Seq(plan), ExpressionSet())
     }
   }
 
@@ -143,15 +143,15 @@ object JoinReorderDP extends PredicateHelper with Logging {
   def search(
       conf: SQLConf,
       items: Seq[LogicalPlan],
-      conditions: Set[Expression],
+      conditions: ExpressionSet,
       output: Seq[Attribute]): LogicalPlan = {
 
     val startTime = System.nanoTime()
     // Level i maintains all found plans for i + 1 items.
     // Create the initial plans: each plan is a single item with zero cost.
     val itemIndex = items.zipWithIndex
     val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map {
-      case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set.empty, Cost(0, 0))
+      case (item, id) => Set(id) -> JoinPlan(Set(id), item, ExpressionSet(), Cost(0, 0))
     }.toMap)
 
     // Build filters from the join graph to be used by the search algorithm.
@@ -194,7 +194,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
   private def searchLevel(
       existingLevels: Seq[JoinPlanMap],
       conf: SQLConf,
-      conditions: Set[Expression],
+      conditions: ExpressionSet,
       topOutput: AttributeSet,
       filters: Option[JoinGraphInfo]): JoinPlanMap = {
 
@@ -255,7 +255,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
       oneJoinPlan: JoinPlan,
       otherJoinPlan: JoinPlan,
       conf: SQLConf,
-      conditions: Set[Expression],
+      conditions: ExpressionSet,
       topOutput: AttributeSet,
       filters: Option[JoinGraphInfo]): Option[JoinPlan] = {
 
@@ -329,7 +329,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
   case class JoinPlan(
       itemIds: Set[Int],
       plan: LogicalPlan,
-      joinConds: Set[Expression],
+      joinConds: ExpressionSet,
       planCost: Cost) {
 
     /** Get the cost of the root node of this plan tree. */
@@ -387,7 +387,7 @@ object JoinReorderDPFilters extends PredicateHelper {
   def buildJoinGraphInfo(
       conf: SQLConf,
       items: Seq[LogicalPlan],
-      conditions: Set[Expression],
+      conditions: ExpressionSet,
       itemIndex: Seq[(LogicalPlan, Int)]): Option[JoinGraphInfo] = {
 
     if (conf.joinReorderDPStarFilter) {
```
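`extractInnerJoins` now accumulates join predicates into an `ExpressionSet`, which deduplicates semantically equal predicates and, after this commit, yields them in first-seen order (note the union above was also reordered so the left conditions come first). A hypothetical REPL sketch of the deduplication, assuming `spark-catalyst` on the classpath (DSL names are illustrative):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.ExpressionSet

val (a, b) = ('a.int, 'b.int)

// `b === a` canonicalizes to the same expression as `a === b`, so only the
// first-seen form is kept, and iteration order is deterministic across runs.
val conds = ExpressionSet(Seq(a === b, b === a, a > 1))
println(conds.size)  // 2
conds.foreach(println)
```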

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 3 additions & 3 deletions

```diff
@@ -921,13 +921,13 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
   private def getAllConstraints(
       left: LogicalPlan,
       right: LogicalPlan,
-      conditionOpt: Option[Expression]): Set[Expression] = {
+      conditionOpt: Option[Expression]): ExpressionSet = {
     val baseConstraints = left.constraints.union(right.constraints)
-      .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet)
+      .union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil)))
     baseConstraints.union(inferAdditionalConstraints(baseConstraints))
   }
 
-  private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
+  private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
     val newPredicates = constraints
       .union(constructIsNotNullConstraints(constraints, plan.output))
       .filter { c =>
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 3 additions & 3 deletions

```diff
@@ -169,8 +169,8 @@ abstract class UnaryNode extends LogicalPlan {
    * Generates all valid constraints including an set of aliased constraints by replacing the
    * original constraint expressions with the corresponding alias
    */
-  protected def getAllValidConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
-    var allConstraints = child.constraints.asInstanceOf[Set[Expression]]
+  protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
+    var allConstraints = child.constraints
     projectList.foreach {
       case a @ Alias(l: Literal, _) =>
         allConstraints += EqualNullSafe(a.toAttribute, l)
@@ -187,7 +187,7 @@ abstract class UnaryNode extends LogicalPlan {
     allConstraints
   }
 
-  override protected lazy val validConstraints: Set[Expression] = child.constraints
+  override protected lazy val validConstraints: ExpressionSet = child.constraints
 }
 
 /**
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 16 additions & 18 deletions

```diff
@@ -29,16 +29,14 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
    */
   lazy val constraints: ExpressionSet = {
     if (conf.constraintPropagationEnabled) {
-      ExpressionSet(
-        validConstraints
-          .union(inferAdditionalConstraints(validConstraints))
-          .union(constructIsNotNullConstraints(validConstraints, output))
-          .filter { c =>
-            c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
-          }
-      )
+      validConstraints
+        .union(inferAdditionalConstraints(validConstraints))
+        .union(constructIsNotNullConstraints(validConstraints, output))
+        .filter { c =>
+          c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
+        }
     } else {
-      ExpressionSet(Set.empty)
+      ExpressionSet()
     }
   }
 
@@ -50,7 +48,7 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
    *
    * See [[Canonicalize]] for more details.
    */
-  protected lazy val validConstraints: Set[Expression] = Set.empty
+  protected lazy val validConstraints: ExpressionSet = ExpressionSet()
 }
 
 trait ConstraintHelper {
@@ -60,8 +58,8 @@ trait ConstraintHelper {
   * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
   * additional constraint of the form `b = 5`.
   */
-  def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
-    var inferredConstraints = Set.empty[Expression]
+  def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
+    var inferredConstraints = ExpressionSet()
     // IsNotNull should be constructed by `constructIsNotNullConstraints`.
     val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
     predicates.foreach {
@@ -79,9 +77,9 @@ trait ConstraintHelper {
   }
 
   private def replaceConstraints(
-      constraints: Set[Expression],
+      constraints: ExpressionSet,
       source: Expression,
-      destination: Expression): Set[Expression] = constraints.map(_ transform {
+      destination: Expression): ExpressionSet = constraints.map(_ transform {
     case e: Expression if e.semanticEquals(source) => destination
   })
 
@@ -91,15 +89,15 @@ trait ConstraintHelper {
   * returns a constraint of the form `isNotNull(a)`
   */
   def constructIsNotNullConstraints(
-      constraints: Set[Expression],
-      output: Seq[Attribute]): Set[Expression] = {
+      constraints: ExpressionSet,
+      output: Seq[Attribute]): ExpressionSet = {
     // First, we propagate constraints from the null intolerant expressions.
-    var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)
+    var isNotNullConstraints = constraints.flatMap(inferIsNotNullConstraints(_))
 
     // Second, we infer additional constraints from non-nullable attributes that are part of the
     // operator's output
     val nonNullableAttributes = output.filterNot(_.nullable)
-    isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet
+    isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull)
 
     isNotNullConstraints -- constraints
   }
```
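The subtraction `isNotNullConstraints -- constraints` at the end now uses `ExpressionSet`'s own `--`, so the result is again an `ExpressionSet` with semantic equality and a stable, first-seen order. A hypothetical REPL sketch, assuming `spark-catalyst` on the classpath (DSL names are illustrative):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExpressionSet, IsNotNull}

val a = 'a.int
val inferred = ExpressionSet(Seq(IsNotNull(a), IsNotNull(a + 1)))
val existing = ExpressionSet(Seq(IsNotNull(a)))

// `--` removes semantically equal elements and returns an ExpressionSet
// rather than an arbitrary immutable.Set implementation.
val fresh: ExpressionSet = inferred -- existing
fresh.foreach(println)  // only the constraint on `a + 1` remains
```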
