Commit 3579003
[SPARK-6247][SQL] Fix resolution of ambiguous joins caused by new aliases
We need to handle ambiguous `exprId`s that are produced by new aliases as well as those caused by leaf nodes (`MultiInstanceRelation`). Attempting to fix this revealed a bug in `equals` for `Alias`, as these objects were comparing equal even when the expression ids did not match. Additionally, `LocalRelation` did not correctly provide statistics, and some tests in `catalyst` and `hive` were not using the helper functions for comparing plans.

Based on apache#4991 by chenghao-intel

Author: Michael Armbrust <[email protected]>

Closes apache#5062 from marmbrus/selfJoins and squashes the following commits:

8e9b84b [Michael Armbrust] check qualifier too
8038a36 [Michael Armbrust] handle aggs too
0b9c687 [Michael Armbrust] fix more tests
c3c574b [Michael Armbrust] revert change.
725f1ab [Michael Armbrust] add statistics
a925d08 [Michael Armbrust] check for conflicting attributes in join resolution
b022ef7 [Michael Armbrust] Handle project aliases.
d8caa40 [Michael Armbrust] test case: SPARK-6247
f9c67c2 [Michael Armbrust] Check for duplicate attributes in join resolution.
898af73 [Michael Armbrust] Fix Alias equality.
1 parent a6ee2f7 commit 3579003
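
For illustration, a rough sketch (not part of the commit) of the kind of self join that previously failed to resolve; it mirrors the new "self join with alias in agg" test added to SQLQuerySuite below, and the table name `df` is only an example:

  import org.apache.spark.sql.test.TestSQLContext
  import org.apache.spark.sql.test.TestSQLContext.implicits._
  import org.apache.spark.sql.functions._

  // The aggregate introduces a new Alias ("strCount"), so both sides of the
  // self join carry the same expression id for that column.
  Seq(1, 2, 3)
    .map(i => (i, i.toString))
    .toDF("int", "str")
    .groupBy("str")
    .agg($"str", count("str").as("strCount"))
    .registerTempTable("df")

  // Before this commit the analyzer only deduplicated MultiInstanceRelation
  // leaves, so the conflicting alias left the join ambiguous; with this change
  // the aliases on one side are re-created with fresh expression ids.
  TestSQLContext.sql(
    """SELECT x.str, SUM(x.strCount)
      |FROM df x JOIN df y ON x.str = y.str
      |GROUP BY x.str
    """.stripMargin).collect()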

File tree

9 files changed: 96 additions, 12 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 26 additions & 4 deletions
@@ -237,22 +237,33 @@ class Analyzer(catalog: Catalog,
       // Special handling for cases when self-join introduce duplicate expression ids.
       case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
         val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+        logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")

-        val (oldRelation, newRelation, attributeRewrites) = right.collect {
+        val (oldRelation, newRelation) = right.collect {
+          // Handle base relations that might appear more than once.
           case oldVersion: MultiInstanceRelation
               if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
             val newVersion = oldVersion.newInstance()
-            val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output))
-            (oldVersion, newVersion, newAttributes)
+            (oldVersion, newVersion)
+
+          // Handle projects that create conflicting aliases.
+          case oldVersion @ Project(projectList, _)
+              if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+            (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
+
+          case oldVersion @ Aggregate(_, aggregateExpressions, _)
+              if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+            (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
         }.head // Only handle first case found, others will be fixed on the next pass.

+        val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
         val newRight = right transformUp {
           case r if r == oldRelation => newRelation
+        } transformUp {
           case other => other transformExpressions {
             case a: Attribute => attributeRewrites.get(a).getOrElse(a)
           }
         }
-
         j.copy(right = newRight)

       case q: LogicalPlan =>

@@ -272,6 +283,17 @@
         }
       }

+    def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+      expressions.map {
+        case a: Alias => Alias(a.child, a.name)()
+        case other => other
+      }
+    }
+
+    def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
+      AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
+    }
+
     /**
      * Returns true if `exprs` contains a [[Star]].
      */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 6 additions & 0 deletions
@@ -124,6 +124,12 @@ case class Alias(child: Expression, name: String)
   override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"

   override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
+
+  override def equals(other: Any): Boolean = other match {
+    case a: Alias =>
+      name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers
+    case _ => false
+  }
 }

 /**
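
As a quick illustration of the behavior this override changes, a minimal sketch (assuming a Scala console with catalyst on the classpath; `Literal` is just a convenient child expression):

  import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}

  // Each Alias(...)() call generates a fresh ExprId in the second (curried)
  // parameter list, which default case-class equality ignores.
  val a1 = Alias(Literal(1), "x")()
  val a2 = Alias(Literal(1), "x")()

  a1 == a2  // previously true despite the different exprIds; now false
  a1 == a1  // still true: same name, child, exprId, and qualifiers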

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala

Lines changed: 3 additions & 0 deletions
@@ -54,4 +54,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Row] = Nil)
       otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
     case _ => false
   }
+
+  override lazy val statistics =
+    Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length)
 }
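
A hedged example of what the new statistics yield (a sketch, assuming `IntegerType.defaultSize` is 4 bytes and using the catalyst test DSL for the attributes):

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

  // Two int columns, three rows: sizeInBytes = (4 + 4) * 3 = 24.
  val relation = LocalRelation(Seq('a.int, 'b.int), Seq(Row(1, 2), Row(3, 4), Row(5, 6)))
  relation.statistics.sizeInBytes  // 24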

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 7 additions & 0 deletions
@@ -108,6 +108,13 @@ case class Join(
         left.output ++ right.output
     }
   }
+
+  def selfJoinResolved = left.outputSet.intersect(right.outputSet).isEmpty
+
+  // Joins are only resolved if they don't introduce ambiguious expression ids.
+  override lazy val resolved: Boolean = {
+    childrenResolved && !expressions.exists(!_.resolved) && selfJoinResolved
+  }
 }

 case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
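
A small sketch (using the catalyst test DSL, purely for illustration) of what the new guard means for a raw self join:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.plans.Inner
  import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation}

  val r = LocalRelation('a.int, 'b.string)

  // Both sides expose the same attributes (same exprIds), so selfJoinResolved
  // is false and the join reports itself as unresolved until the analyzer
  // rewrites one side with fresh expression ids.
  Join(r, r, Inner, None).resolved  // false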

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 6 additions & 4 deletions
@@ -17,13 +17,13 @@

 package org.apache.spark.sql.catalyst.analysis

-import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.plans.PlanTest

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.types._

-class HiveTypeCoercionSuite extends FunSuite {
+class HiveTypeCoercionSuite extends PlanTest {

   test("tightest common bound for types") {
     def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {

@@ -106,7 +106,8 @@ class HiveTypeCoercionSuite extends FunSuite {
     val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
     def ruleTest(initial: Expression, transformed: Expression) {
       val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-      assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+      comparePlans(
+        booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)),
         Project(Seq(Alias(transformed, "a")()), testRelation))
     }
     // Remove superflous boolean -> boolean casts.

@@ -119,7 +120,8 @@ class HiveTypeCoercionSuite extends FunSuite {
     val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
     def ruleTest(initial: Expression, transformed: Expression) {
       val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
-      assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+      comparePlans(
+        fac(Project(Seq(Alias(initial, "a")()), testRelation)),
         Project(Seq(Alias(transformed, "a")()), testRelation))
     }
     ruleTest(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 9 additions & 2 deletions
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.plans

 import org.scalatest.FunSuite

-import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{NoRelation, Filter, LogicalPlan}
 import org.apache.spark.sql.catalyst.util._

 /**

@@ -36,6 +36,8 @@ class PlanTest extends FunSuite {
     plan transformAllExpressions {
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+      case a: Alias =>
+        Alias(a.child, a.name)(exprId = ExprId(0))
     }
   }

@@ -50,4 +52,9 @@
        |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
      """.stripMargin)
   }
+
+  /** Fails the test if the two expressions do not match */
+  protected def compareExpressions(e1: Expression, e2: Expression): Unit = {
+    comparePlans(Filter(e1, NoRelation), Filter(e2, NoRelation))
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,8 @@

 package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Project, NoRelation}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._

@@ -311,7 +313,9 @@ class ColumnExpressionSuite extends QueryTest {
   }

   test("lift alias out of cast") {
-    assert(col("1234").as("name").cast("int").expr === col("1234").cast("int").as("name").expr)
+    compareExpressions(
+      col("1234").as("name").cast("int").expr,
+      col("1234").cast("int").as("name").expr)
   }

   test("columns can be compared") {

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 31 additions & 0 deletions
@@ -36,6 +36,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   import org.apache.spark.sql.test.TestSQLContext.implicits._
   val sqlCtx = TestSQLContext

+  test("self join with aliases") {
+    Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df")
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT x.str, COUNT(*)
+          |FROM df x JOIN df y ON x.str = y.str
+          |GROUP BY x.str
+        """.stripMargin),
+      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+  }
+
+  test("self join with alias in agg") {
+    Seq(1,2,3)
+      .map(i => (i, i.toString))
+      .toDF("int", "str")
+      .groupBy("str")
+      .agg($"str", count("str").as("strCount"))
+      .registerTempTable("df")
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT x.str, SUM(x.strCount)
+          |FROM df x JOIN df y ON x.str = y.str
+          |GROUP BY x.str
+        """.stripMargin),
+      Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+  }
+
   test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
     checkAnswer(
       sql("SELECT a FROM testData2 SORT BY a"),

sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 3 additions & 1 deletion
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.plans

-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.scalatest.FunSuite

@@ -38,6 +38,8 @@ class PlanTest extends FunSuite {
     plan transformAllExpressions {
       case a: AttributeReference =>
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
+      case a: Alias =>
+        Alias(a.child, a.name)(exprId = ExprId(0))
     }
   }
