Skip to content

Commit a082f46

Browse files
Ngone51 and cloud-fan
authored and committed
[SPARK-33071][SPARK-33536][SQL] Avoid changing dataset_id of LogicalPlan in join() to not break DetectAmbiguousSelfJoin
### What changes were proposed in this pull request? Currently, `join()` uses `withPlan(logicalPlan)` for convenient to call some Dataset functions. But it leads to the `dataset_id` inconsistent between the `logicalPlan` and the original `Dataset`(because `withPlan(logicalPlan)` will create a new Dataset with the new id and reset the `dataset_id` with the new id of the `logicalPlan`). As a result, it breaks the rule `DetectAmbiguousSelfJoin`. In this PR, we propose to drop the usage of `withPlan` but use the `logicalPlan` directly so its `dataset_id` doesn't change. Besides, this PR also removes related metadata (`DATASET_ID_KEY`, `COL_POS_KEY`) when an `Alias` tries to construct its own metadata. Because the `Alias` is no longer a reference column after converting to an `Attribute`. To achieve that, we add a new field, `deniedMetadataKeys`, to indicate the metadata that needs to be removed. ### Why are the changes needed? For the query below, it returns the wrong result while it should throws ambiguous self join exception instead: ```scala val emp1 = Seq[TestData]( TestData(1, "sales"), TestData(2, "personnel"), TestData(3, "develop"), TestData(4, "IT")).toDS() val emp2 = Seq[TestData]( TestData(1, "sales"), TestData(2, "personnel"), TestData(3, "develop")).toDS() val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*")) emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer") .select(emp1.col("*"), emp3.col("key").as("e2")).show() // wrong result +---+---------+---+ |key| value| e2| +---+---------+---+ | 1| sales| 1| | 2|personnel| 2| | 3| develop| 3| | 4| IT| 4| +---+---------+---+ ``` This PR fixes the wrong behaviour. ### Does this PR introduce _any_ user-facing change? Yes, users hit the exception instead of the wrong result after this PR. ### How was this patch tested? Added a new unit test. Closes #30488 from Ngone51/fix-self-join. Authored-by: yi.wu <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 91182d6 commit a082f46

File tree

6 files changed

+73
-25
lines changed

6 files changed

+73
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ trait AliasHelper {
8989
a.copy(child = trimAliases(a.child))(
9090
exprId = a.exprId,
9191
qualifier = a.qualifier,
92-
explicitMetadata = Some(a.metadata))
92+
explicitMetadata = Some(a.metadata),
93+
deniedMetadataKeys = a.deniedMetadataKeys)
9394
case a: MultiAlias =>
9495
a.copy(child = trimAliases(a.child))
9596
case other => trimAliases(other)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
143143
* fully qualified way. Consider the examples tableName.name, subQueryAlias.name.
144144
* tableName and subQueryAlias are possible qualifiers.
145145
* @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
146+
* @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when
147+
* inheriting the metadata from the child.
146148
*/
147149
case class Alias(child: Expression, name: String)(
148150
val exprId: ExprId = NamedExpression.newExprId,
149151
val qualifier: Seq[String] = Seq.empty,
150-
val explicitMetadata: Option[Metadata] = None)
152+
val explicitMetadata: Option[Metadata] = None,
153+
val deniedMetadataKeys: Seq[String] = Seq.empty)
151154
extends UnaryExpression with NamedExpression {
152155

153156
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
@@ -167,7 +170,11 @@ case class Alias(child: Expression, name: String)(
167170
override def metadata: Metadata = {
168171
explicitMetadata.getOrElse {
169172
child match {
170-
case named: NamedExpression => named.metadata
173+
case named: NamedExpression =>
174+
val builder = new MetadataBuilder().withMetadata(named.metadata)
175+
deniedMetadataKeys.foreach(builder.remove)
176+
builder.build()
177+
171178
case _ => Metadata.empty
172179
}
173180
}
@@ -194,7 +201,7 @@ case class Alias(child: Expression, name: String)(
194201
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix"
195202

196203
override protected final def otherCopyArgs: Seq[AnyRef] = {
197-
exprId :: qualifier :: explicitMetadata :: Nil
204+
exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil
198205
}
199206

200207
override def hashCode(): Int = {
@@ -205,7 +212,7 @@ case class Alias(child: Expression, name: String)(
205212
override def equals(other: Any): Boolean = other match {
206213
case a: Alias =>
207214
name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
208-
explicitMetadata == a.explicitMetadata
215+
explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys
209216
case _ => false
210217
}
211218

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,10 @@ class Column(val expr: Expression) extends Logging {
11641164
* @since 2.0.0
11651165
*/
11661166
def name(alias: String): Column = withExpr {
1167-
Alias(normalizedExpr(), alias)()
1167+
// SPARK-33536: The Alias is no longer a column reference after converting to an attribute.
1168+
// These denied metadata keys are used to strip the column reference related metadata for
1169+
// the Alias. So it won't be caught as a column reference in DetectAmbiguousSelfJoin.
1170+
Alias(expr, alias)(deniedMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))
11681171
}
11691172

11701173
/**

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ class Dataset[T] private[sql](
231231
case _ =>
232232
queryExecution.analyzed
233233
}
234-
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
234+
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) &&
235+
plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) {
235236
plan.setTagValue(Dataset.DATASET_ID_TAG, id)
236237
}
237238
plan
@@ -259,15 +260,16 @@ class Dataset[T] private[sql](
259260
private[sql] def resolve(colName: String): NamedExpression = {
260261
val resolver = sparkSession.sessionState.analyzer.resolver
261262
queryExecution.analyzed.resolveQuoted(colName, resolver)
262-
.getOrElse {
263-
val fields = schema.fieldNames
264-
val extraMsg = if (fields.exists(resolver(_, colName))) {
265-
s"; did you mean to quote the `$colName` column?"
266-
} else ""
267-
val fieldsStr = fields.mkString(", ")
268-
val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
269-
throw new AnalysisException(errorMsg)
270-
}
263+
.getOrElse(throw resolveException(colName, schema.fieldNames))
264+
}
265+
266+
private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
267+
val extraMsg = if (fields.exists(sparkSession.sessionState.analyzer.resolver(_, colName))) {
268+
s"; did you mean to quote the `$colName` column?"
269+
} else ""
270+
val fieldsStr = fields.mkString(", ")
271+
val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
272+
new AnalysisException(errorMsg)
271273
}
272274

273275
private[sql] def numericColumns: Seq[Expression] = {
@@ -1083,26 +1085,31 @@ class Dataset[T] private[sql](
10831085
}
10841086

10851087
// If left/right have no output set intersection, return the plan.
1086-
val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
1087-
val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
1088+
val lanalyzed = this.queryExecution.analyzed
1089+
val ranalyzed = right.queryExecution.analyzed
10881090
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
10891091
return withPlan(plan)
10901092
}
10911093

10921094
// Otherwise, find the trivially true predicates and automatically resolves them to both sides.
10931095
// By the time we get here, since we have already run analysis, all attributes should've been
10941096
// resolved and become AttributeReference.
1097+
val resolver = sparkSession.sessionState.analyzer.resolver
10951098
val cond = plan.condition.map { _.transform {
10961099
case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
10971100
if a.sameRef(b) =>
10981101
catalyst.expressions.EqualTo(
1099-
withPlan(plan.left).resolve(a.name),
1100-
withPlan(plan.right).resolve(b.name))
1102+
plan.left.resolveQuoted(a.name, resolver)
1103+
.getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
1104+
plan.right.resolveQuoted(b.name, resolver)
1105+
.getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
11011106
case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
11021107
if a.sameRef(b) =>
11031108
catalyst.expressions.EqualNullSafe(
1104-
withPlan(plan.left).resolve(a.name),
1105-
withPlan(plan.right).resolve(b.name))
1109+
plan.left.resolveQuoted(a.name, resolver)
1110+
.getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
1111+
plan.right.resolveQuoted(b.name, resolver)
1112+
.getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
11061113
}}
11071114

11081115
withPlan {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window
2121
import org.apache.spark.sql.functions.{count, sum}
2222
import org.apache.spark.sql.internal.SQLConf
2323
import org.apache.spark.sql.test.SharedSparkSession
24+
import org.apache.spark.sql.test.SQLTestData.TestData
2425

2526
class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
2627
import testImplicits._
@@ -219,4 +220,32 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
219220
Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple))
220221
}
221222
}
223+
224+
test("SPARK-33071/SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " +
225+
"to not break DetectAmbiguousSelfJoin") {
226+
val emp1 = Seq[TestData](
227+
TestData(1, "sales"),
228+
TestData(2, "personnel"),
229+
TestData(3, "develop"),
230+
TestData(4, "IT")).toDS()
231+
val emp2 = Seq[TestData](
232+
TestData(1, "sales"),
233+
TestData(2, "personnel"),
234+
TestData(3, "develop")).toDS()
235+
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
236+
assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
237+
"left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
238+
}
239+
240+
test("df.show() should also not change dataset_id of LogicalPlan") {
241+
val df = Seq[TestData](
242+
TestData(1, "sales"),
243+
TestData(2, "personnel"),
244+
TestData(3, "develop"),
245+
TestData(4, "IT")).toDF()
246+
val ds_id1 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
247+
df.show(0)
248+
val ds_id2 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
249+
assert(ds_id1 === ds_id2)
250+
}
222251
}

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,9 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean
573573
class ColumnarAlias(child: ColumnarExpression, name: String)(
574574
override val exprId: ExprId = NamedExpression.newExprId,
575575
override val qualifier: Seq[String] = Seq.empty,
576-
override val explicitMetadata: Option[Metadata] = None)
577-
extends Alias(child, name)(exprId, qualifier, explicitMetadata)
576+
override val explicitMetadata: Option[Metadata] = None,
577+
override val deniedMetadataKeys: Seq[String] = Seq.empty)
578+
extends Alias(child, name)(exprId, qualifier, explicitMetadata, deniedMetadataKeys)
578579
with ColumnarExpression {
579580

580581
override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
@@ -711,7 +712,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
711712
def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match {
712713
case a: Alias =>
713714
new ColumnarAlias(replaceWithColumnarExpression(a.child),
714-
a.name)(a.exprId, a.qualifier, a.explicitMetadata)
715+
a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.deniedMetadataKeys)
715716
case att: AttributeReference =>
716717
new ColumnarAttributeReference(att.name, att.dataType, att.nullable,
717718
att.metadata)(att.exprId, att.qualifier)

0 commit comments

Comments
 (0)