Skip to content

Commit 04ff99a

Browse files
committed
fix tests
1 parent 67e138d commit 04ff99a

File tree

4 files changed

+19
-33
lines changed

4 files changed

+19
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]]
3030
/**
3131
* Extracts the output property from a given child.
3232
*/
33-
def extractConstraintsFromChild(child: QueryPlan[PlanType]): Seq[Expression] = {
33+
def extractConstraintsFromChild(child: QueryPlan[PlanType]): Set[Expression] = {
3434
child.constraints.filter(_.references.subsetOf(outputSet))
3535
}
3636

3737
/**
3838
* An sequence of expressions that describes the data property of the output rows of this
3939
* operator. For example, if the output of this operator is column `a`, an example `constraints`
40-
* can be `Seq(a > 10, a < 20)`.
40+
* can be `Set(a > 10, a < 20)`.
4141
*/
42-
def constraints: Seq[Expression] = Nil
42+
def constraints: Set[Expression] = Set.empty
4343

4444
/**
4545
* Returns the set of attributes that are output by this node.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ abstract class UnaryNode extends LogicalPlan with PredicateHelper {
306306

307307
override def children: Seq[LogicalPlan] = child :: Nil
308308

309-
override def constraints: Seq[Expression] = {
309+
override def constraints: Set[Expression] = {
310310
extractConstraintsFromChild(child)
311311
}
312312
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ case class Generate(
8989
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
9090
override def output: Seq[Attribute] = child.output
9191

92-
override def constraints: Seq[Expression] = {
92+
override def constraints: Set[Expression] = {
9393
val newConstraint = splitConjunctivePredicates(condition).filter(
94-
_.references.subsetOf(outputSet))
94+
_.references.subsetOf(outputSet)).toSet
9595
newConstraint.union(extractConstraintsFromChild(child))
9696
}
9797
}
@@ -103,9 +103,9 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar
103103
leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
104104
}
105105

106-
protected def leftConstraints: Seq[Expression] = extractConstraintsFromChild(left)
106+
protected def leftConstraints: Set[Expression] = extractConstraintsFromChild(left)
107107

108-
protected def rightConstraints: Seq[Expression] = {
108+
protected def rightConstraints: Set[Expression] = {
109109
require(left.output.size == right.output.size)
110110
val attributeRewrites = AttributeMap(left.output.zip(right.output))
111111
extractConstraintsFromChild(right).map(_ transform {
@@ -135,7 +135,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef
135135
Statistics(sizeInBytes = sizeInBytes)
136136
}
137137

138-
override def constraints: Seq[Expression] = {
138+
override def constraints: Set[Expression] = {
139139
leftConstraints.intersect(rightConstraints)
140140
}
141141
}
@@ -147,7 +147,7 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
147147
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
148148
}
149149

150-
override def constraints: Seq[Expression] = {
150+
override def constraints: Set[Expression] = {
151151
leftConstraints.union(rightConstraints)
152152
}
153153
}
@@ -156,7 +156,7 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
156156
/** We don't use right.output because those rows get excluded from the set. */
157157
override def output: Seq[Attribute] = left.output
158158

159-
override def constraints: Seq[Expression] = leftConstraints
159+
override def constraints: Set[Expression] = leftConstraints
160160
}
161161

162162
case class Join(
@@ -180,7 +180,7 @@ case class Join(
180180
}
181181
}
182182

183-
override def constraints: Seq[Expression] = {
183+
override def constraints: Set[Expression] = {
184184
joinType match {
185185
case LeftSemi =>
186186
extractConstraintsFromChild(left)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ package org.apache.spark.sql.catalyst.plans
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.analysis._
22-
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.plans.logical._
24-
import org.apache.spark.sql.catalyst.util._
2522
import org.apache.spark.sql.catalyst.dsl.expressions._
2623
import org.apache.spark.sql.catalyst.dsl.plans._
24+
import org.apache.spark.sql.catalyst.expressions._
25+
import org.apache.spark.sql.catalyst.plans.logical._
2726

2827
/**
2928
* This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly
@@ -75,27 +74,14 @@ class LogicalPlanSuite extends SparkFunSuite {
7574
}
7675

7776
test("propagating constraint in filter") {
78-
79-
def resolve(plan: LogicalPlan, constraints: Seq[String]): Seq[Expression] = {
80-
Seq(plan.resolve(constraints.map(_.toString), caseInsensitiveResolution).get)
81-
}
82-
8377
val tr = LocalRelation('a.int, 'b.string, 'c.int)
78+
def resolveColumn(columnName: String): Expression =
79+
tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
8480
assert(tr.analyze.constraints.isEmpty)
8581
assert(tr.select('a.attr).analyze.constraints.isEmpty)
86-
assert(tr.where('a.attr > 10).analyze.constraints.zip(Seq('a.attr > 10))
87-
.forall(e => e._1.semanticEquals(e._2)))
88-
/*
89-
assert(tr.where('a.attr > 10).analyze.constraints == resolve(tr.where('a.attr > 10).analyze,
90-
Seq("a > 10")))
91-
*/
92-
/*
93-
assert(logicalPlan.constraints ==
94-
Seq(logicalPlan.resolve(Seq('a > 10), caseInsensitiveResolution))
95-
assert(tr.where('a.attr > 10).select('c.attr).analyze.constraints.get == ('a > 10))
96-
assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
97-
.analyze.constraints.get == And('a > 10, 'c < 100))
82+
assert(tr.where('a.attr > 10).analyze.constraints == Set(resolveColumn("a") > 10))
9883
assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty)
99-
*/
84+
assert(tr.where('a.attr > 10).select('c.attr, 'a.attr).where('c.attr < 100)
85+
.analyze.constraints == Set(resolveColumn("a") > 10, resolveColumn("c") < 100))
10086
}
10187
}

0 commit comments

Comments
 (0)