
Commit 1bb3631
[SPARK-5454] More robust handling of self joins
Also I fix a bunch of bad output in test cases.

Author: Michael Armbrust <[email protected]>

Closes #4520 from marmbrus/selfJoin and squashes the following commits:

4f4a85c [Michael Armbrust] comments
49c8e26 [Michael Armbrust] fix tests
6fc38de [Michael Armbrust] fix style
55d64b3 [Michael Armbrust] fix dataframe selfjoins

(cherry picked from commit a60d2b7)
Signed-off-by: Michael Armbrust <[email protected]>
1 parent 72adfc5 commit 1bb3631
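
For context, the failing shape looks like this in DataFrame code (a sketch, not part of the diff; `testData` mirrors the fixture used in the new test below, and it assumes the TestSQLContext implicits are in scope):

    val df = testData.select(testData("key"))
    // Both branches of the join descend from the same plan, so their attributes
    // initially share expression ids; this used to make the condition ambiguous.
    val joined = df.as('df1).join(df.as('df2), $"df1.key" === $"df2.key")
    // With this patch the analyzer clones the right side's relation via
    // MultiInstanceRelation.newInstance(), minting fresh expression ids.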

7 files changed (+40, -30 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 22 additions & 4 deletions
@@ -53,14 +53,11 @@ class Analyzer(catalog: Catalog,
   val extendedRules: Seq[Rule[LogicalPlan]] = Nil

   lazy val batches: Seq[Batch] = Seq(
-    Batch("MultiInstanceRelations", Once,
-      NewRelationInstances),
     Batch("Resolution", fixedPoint,
-      ResolveReferences ::
       ResolveRelations ::
+      ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolveSortReferences ::
-      NewRelationInstances ::
       ImplicitGenerate ::
       ResolveFunctions ::
       GlobalAggregates ::
@@ -285,6 +282,27 @@ class Analyzer(catalog: Catalog,
         }
       )

+      // Special handling for cases when self-join introduce duplicate expression ids.
+      case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
+        val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+
+        val (oldRelation, newRelation, attributeRewrites) = right.collect {
+          case oldVersion: MultiInstanceRelation
+              if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+            val newVersion = oldVersion.newInstance()
+            val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output))
+            (oldVersion, newVersion, newAttributes)
+        }.head // Only handle first case found, others will be fixed on the next pass.
+
+        val newRight = right transformUp {
+          case r if r == oldRelation => newRelation
+          case other => other transformExpressions {
+            case a: Attribute => attributeRewrites.get(a).getOrElse(a)
+          }
+        }
+
+        j.copy(right = newRight)
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressionsUp {
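
The rewrite above hinges on AttributeMap, which keys attributes by expression id. A minimal sketch of that mechanic, using only Catalyst expression types (illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
    import org.apache.spark.sql.types.IntegerType

    val oldAttr = AttributeReference("key", IntegerType, nullable = true)()
    val newAttr = oldAttr.newInstance()   // same name, fresh exprId
    val rewrites = AttributeMap(Seq(oldAttr -> newAttr))
    // The rule's `attributeRewrites.get(a).getOrElse(a)` swaps only the
    // conflicting attributes; anything absent from the map passes through.

Because each pass replaces only the first conflicting relation (the `.head`), plans containing several self joins converge over successive iterations of the fixed-point Resolution batch.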

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala

Lines changed: 1 addition & 20 deletions
@@ -26,28 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  * produced by distinct operators in a query tree as this breaks the guarantee that expression
  * ids, which are used to differentiate attributes, are unique.
  *
- * Before analysis, all operators that include this trait will be asked to produce a new version
+ * During analysis, operators that include this trait may be asked to produce a new version
  * of itself with globally unique expression ids.
  */
 trait MultiInstanceRelation {
   def newInstance(): this.type
 }
-
-/**
- * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so
- * that each instance has unique expression ids for the attributes produced.
- */
-object NewRelationInstances extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val localRelations = plan collect { case l: MultiInstanceRelation => l}
-    val multiAppearance = localRelations
-      .groupBy(identity[MultiInstanceRelation])
-      .filter { case (_, ls) => ls.size > 1 }
-      .map(_._1)
-      .toSet
-
-    plan transform {
-      case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance()
-    }
-  }
-}
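
For reference, implementors of this trait typically rebuild the node with freshly minted attribute ids. A hedged sketch of that pattern (ExampleRelation is hypothetical; LocalRelation does something analogous):

    import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LeafNode

    // newInstance() re-creates the node so every output attribute carries a
    // fresh exprId, exactly what the analyzer's new Join case relies on.
    case class ExampleRelation(output: Seq[AttributeReference])
      extends LeafNode with MultiInstanceRelation {
      override def newInstance() =
        ExampleRelation(output.map(_.newInstance())).asInstanceOf[this.type]
    }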

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 1 addition & 3 deletions
@@ -33,11 +33,9 @@ class PlanTest extends FunSuite {
    * we must normalize them to check if two different queries are identical.
    */
   protected def normalizeExprIds(plan: LogicalPlan) = {
-    val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id))
-    val minId = if (list.isEmpty) 0 else list.min
     plan transformAllExpressions {
       case a: AttributeReference =>
-        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId))
+        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
     }
   }
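
With this change expression ids are erased outright instead of shifted to a zero base, so plan comparison no longer depends on the order in which the analyzer allocated ids. A small illustration (assumed, not from the patch):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
    import org.apache.spark.sql.types.IntegerType

    val a1 = AttributeReference("key", IntegerType, nullable = true)(exprId = ExprId(1))
    val a2 = AttributeReference("key", IntegerType, nullable = true)(exprId = ExprId(2))
    // a1 == a2 is false because attribute equality includes the exprId; after
    // normalizeExprIds both are rebuilt with ExprId(0) and compare equal.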

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 2 additions & 0 deletions
@@ -122,6 +122,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     case _ =>
   }

+  @transient
   protected[sql] val cacheManager = new CacheManager(this)

   /**
@@ -159,6 +160,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    *   DataTypes.StringType);
    * }}}
    */
+  @transient
   val udf: UDFRegistration = new UDFRegistration(this)

   /** Returns true if the table is currently cached in-memory. */
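
These annotations keep non-serializable helper objects out of SQLContext's serialized form, e.g. when the context is captured by a closure. A generic sketch of the @transient mechanism (plain Java serialization, not Spark code):

    import java.io._

    class Holder extends Serializable {
      @transient val cache = new Object  // not Serializable; skipped on write
      val name = "kept"
    }

    val buffer = new ByteArrayOutputStream()
    new ObjectOutputStream(buffer).writeObject(new Holder)  // would throw without @transient
    val restored = new ObjectInputStream(
      new ByteArrayInputStream(buffer.toByteArray)).readObject().asInstanceOf[Holder]
    assert(restored.name == "kept")
    assert(restored.cache == null)  // transient field comes back unset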

sql/core/src/test/resources/log4j.properties

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ log4j.appender.FA.Threshold = INFO
 log4j.additivity.parquet.hadoop.ParquetRecordReader=false
 log4j.logger.parquet.hadoop.ParquetRecordReader=OFF

+log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false
+log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF
+
 log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
 log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 10 additions & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.test.TestSQLContext.sql


 class DataFrameSuite extends QueryTest {
@@ -88,6 +89,15 @@ class DataFrameSuite extends QueryTest {
       testData.collect().toSeq)
   }

+  test("self join") {
+    val df1 = testData.select(testData("key")).as('df1)
+    val df2 = testData.select(testData("key")).as('df2)
+
+    checkAnswer(
+      df1.join(df2, $"df1.key" === $"df2.key"),
+      sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
+  }
+
   test("selectExpr") {
     checkAnswer(
       testData.selectExpr("abs(key)", "value"),

sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 1 addition & 3 deletions
@@ -35,11 +35,9 @@ class PlanTest extends FunSuite {
    * we must normalize them to check if two different queries are identical.
    */
   protected def normalizeExprIds(plan: LogicalPlan) = {
-    val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id))
-    val minId = if (list.isEmpty) 0 else list.min
     plan transformAllExpressions {
       case a: AttributeReference =>
-        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId))
+        AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
     }
   }
