Skip to content

Commit b77a4d6

Browse files
committed
modify function inferAdditionalConstraints to avoid producing non-converging set of constraints
1 parent aef506e commit b77a4d6

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,36 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
7474
* additional constraint of the form `b = 5`
7575
*/
7676
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
77+
// Collect alias from expressions to avoid producing non-converging set of constraints
78+
// for recursive functions.
79+
// For more details, infer https://issues.apache.org/jira/browse/SPARK-17733
80+
val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
81+
case a: Alias => (a.toAttribute, a.child)
82+
})
83+
7784
var inferredConstraints = Set.empty[Expression]
7885
constraints.foreach {
7986
case eq @ EqualTo(l: Attribute, r: Attribute) =>
8087
inferredConstraints ++= (constraints - eq).map(_ transform {
81-
case a: Attribute if a.semanticEquals(l) => r
88+
case a: Attribute if a.semanticEquals(l) && !isRecursiveDeduction(a, r, aliasMap) => r
8289
})
8390
inferredConstraints ++= (constraints - eq).map(_ transform {
84-
case a: Attribute if a.semanticEquals(r) => l
91+
case a: Attribute if a.semanticEquals(r) && !isRecursiveDeduction(l, a, aliasMap) => l
8592
})
8693
case _ => // No inference
8794
}
8895
inferredConstraints -- constraints
8996
}
9097

98+
private def isRecursiveDeduction(
99+
left: Attribute,
100+
right: Attribute,
101+
aliasMap: AttributeMap[Expression]): Boolean = {
102+
val leftExpression = aliasMap.getOrElse(left, left)
103+
val rightExpression = aliasMap.getOrElse(right, right)
104+
leftExpression.containsChild(rightExpression) || rightExpression.containsChild(leftExpression)
105+
}
106+
91107
/**
92108
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
93109
* example, if this set contains the expression `a = 2` then that expression is guaranteed to

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2678,4 +2678,45 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
26782678
}
26792679
}
26802680
}
2681+
2682+
test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
2683+
withTempView("tmpv") {
2684+
spark.range(10).toDF("a").createTempView("tmpv")
2685+
2686+
// Just ensure the following query will successfully execute complete.
2687+
assert(sql(
2688+
"""
2689+
|SELECT
2690+
| *
2691+
|FROM (
2692+
| SELECT
2693+
| COALESCE(t1.a, t2.a) AS int_col,
2694+
| t1.a,
2695+
| t2.a AS b
2696+
| FROM tmpv t1
2697+
| CROSS JOIN tmpv t2
2698+
|) t1
2699+
|INNER JOIN tmpv t2
2700+
|ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
2701+
""".stripMargin).count() > 0
2702+
)
2703+
2704+
//sql("CREATE TEMPORARY VIEW foo(a) AS VALUES (CAST(-993 AS BIGINT))")
2705+
2706+
/*sql(
2707+
"""
2708+
|SELECT
2709+
|*
2710+
|FROM (
2711+
| SELECT
2712+
| COALESCE(t1.a, t2.a) AS int_col,
2713+
| t1.a,
2714+
| t2.a AS b
2715+
| FROM foo t1
2716+
| CROSS JOIN foo t2
2717+
|) t1
2718+
|INNER JOIN foo t2 ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
2719+
""".stripMargin).collect()*/
2720+
}
2721+
}
26812722
}

0 commit comments

Comments
 (0)