Skip to content

Commit e8a5fa3

Browse files
committed
properly handle different datasets with common lineage
1 parent 92cb513 commit e8a5fa3

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ private[sql] object Dataset {
7474
qe.assertAnalyzed()
7575
new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
7676
}
77+
78+
// Metadata key under which a resolved attribute stores the hash code of the Dataset it came from
79+
private val DATASET_ID = "dataset.hash"
7780
}
7881

7982
/**
@@ -217,11 +220,22 @@ class Dataset[T] private[sql](
217220
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
218221

219222
private[sql] def resolve(colName: String): NamedExpression = {
220-
queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
221-
.getOrElse {
223+
val resolved = queryExecution.analyzed.resolveQuoted(colName,
224+
sparkSession.sessionState.analyzer.resolver).getOrElse {
222225
throw new AnalysisException(
223226
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
224227
}
228+
// We store in the metadata a reference to the Dataset the attribute comes from, because
229+
// it is useful to determine what this attribute is really referencing when performing
230+
// self-joins (or joins between datasets with common lineage) and the join condition contains
231+
// ambiguous references.
232+
resolved match {
233+
case a: AttributeReference =>
234+
val mBuilder = new MetadataBuilder()
235+
mBuilder.withMetadata(a.metadata).putLong(Dataset.DATASET_ID, this.hashCode().toLong)
236+
a.withMetadata(mBuilder.build())
237+
case other => other
238+
}
225239
}
226240

227241
private[sql] def numericColumns: Seq[Expression] = {
@@ -1002,9 +1016,16 @@ class Dataset[T] private[sql](
10021016
val cond = plan.condition.map { _.transform {
10031017
case e @ catalyst.expressions.BinaryComparison(a: AttributeReference, b: AttributeReference)
10041018
if a.sameRef(b) =>
1005-
e.withNewChildren(Seq(
1006-
withPlan(plan.left).resolve(a.name),
1007-
withPlan(plan.right).resolve(b.name)))
1019+
val bReferencesThis = b.metadata.contains(Dataset.DATASET_ID) &&
1020+
b.metadata.getLong(Dataset.DATASET_ID) == hashCode()
1021+
val aReferencesRight = a.metadata.contains(Dataset.DATASET_ID) &&
1022+
a.metadata.getLong(Dataset.DATASET_ID) == right.hashCode()
1023+
val newChildren = if (bReferencesThis && aReferencesRight) {
1024+
Seq(withPlan(plan.right).resolve(a.name), withPlan(plan.left).resolve(b.name))
1025+
} else {
1026+
Seq(withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name))
1027+
}
1028+
e.withNewChildren(newChildren)
10081029
}}
10091030

10101031
withPlan {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,20 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
290290

291291
test("SPARK-24385: Resolve ambiguity in self-joins with operators different from EqualsTo") {
292292
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
293-
val df = spark.range(10)
294-
// these should not throw any exception
293+
val df = spark.range(2)
294+
295+
// These should not throw any exception.
295296
df.join(df, df("id") >= df("id")).queryExecution.optimizedPlan
296297
df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
297298
df.join(df, df("id") <= df("id")).queryExecution.optimizedPlan
298299
df.join(df, df("id") > df("id")).queryExecution.optimizedPlan
299300
df.join(df, df("id") < df("id")).queryExecution.optimizedPlan
301+
302+
// Check we properly resolve columns when datasets are different but they share a common
303+
// lineage.
304+
val df1 = df.groupBy("id").count()
305+
val df2 = df.groupBy("id").sum("id")
306+
checkAnswer(df1.join(df2, df2("id") < df1("id")), Seq(Row(1, 1, 0, 0)))
300307
}
301308
}
302309
}

0 commit comments

Comments
 (0)