Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ trait AliasHelper {
a.copy(child = trimAliases(a.child))(
exprId = a.exprId,
qualifier = a.qualifier,
explicitMetadata = Some(a.metadata))
explicitMetadata = Some(a.metadata),
deniedMetadataKeys = a.deniedMetadataKeys)
case a: MultiAlias =>
a.copy(child = trimAliases(a.child))
case other => trimAliases(other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
* fully qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
* @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
* @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when
* inheriting the metadata from the child.
*/
case class Alias(child: Expression, name: String)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifier: Seq[String] = Seq.empty,
val explicitMetadata: Option[Metadata] = None)
val explicitMetadata: Option[Metadata] = None,
val deniedMetadataKeys: Seq[String] = Seq.empty)
extends UnaryExpression with NamedExpression {

// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
Expand All @@ -167,7 +170,11 @@ case class Alias(child: Expression, name: String)(
override def metadata: Metadata = {
explicitMetadata.getOrElse {
child match {
case named: NamedExpression => named.metadata
case named: NamedExpression =>
val builder = new MetadataBuilder().withMetadata(named.metadata)
deniedMetadataKeys.foreach(builder.remove)
builder.build()

case _ => Metadata.empty
}
}
Expand All @@ -194,7 +201,7 @@ case class Alias(child: Expression, name: String)(
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix"

override protected final def otherCopyArgs: Seq[AnyRef] = {
exprId :: qualifier :: explicitMetadata :: Nil
exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil
}

override def hashCode(): Int = {
Expand All @@ -205,7 +212,7 @@ case class Alias(child: Expression, name: String)(
override def equals(other: Any): Boolean = other match {
case a: Alias =>
name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
explicitMetadata == a.explicitMetadata
explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys
case _ => false
}

Expand Down
5 changes: 4 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,10 @@ class Column(val expr: Expression) extends Logging {
* @since 2.0.0
*/
def name(alias: String): Column = withExpr {
Alias(normalizedExpr(), alias)()
// SPARK-33536: Once it is converted to an attribute, the Alias is no longer a column reference.
// The denied metadata keys strip the column-reference-related metadata from the Alias, so
// that it is not caught as a column reference in DetectAmbiguousSelfJoin.
Alias(expr, alias)(deniedMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))
}

/**
Expand Down
39 changes: 23 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ class Dataset[T] private[sql](
case _ =>
queryExecution.analyzed
}
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) &&
plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) {
plan.setTagValue(Dataset.DATASET_ID_TAG, id)
}
plan
Expand Down Expand Up @@ -259,15 +260,16 @@ class Dataset[T] private[sql](
private[sql] def resolve(colName: String): NamedExpression = {
val resolver = sparkSession.sessionState.analyzer.resolver
queryExecution.analyzed.resolveQuoted(colName, resolver)
.getOrElse {
val fields = schema.fieldNames
val extraMsg = if (fields.exists(resolver(_, colName))) {
s"; did you mean to quote the `$colName` column?"
} else ""
val fieldsStr = fields.mkString(", ")
val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
throw new AnalysisException(errorMsg)
}
.getOrElse(throw resolveException(colName, schema.fieldNames))
}

/**
 * Creates (without throwing) the [[AnalysisException]] that reports an unresolvable column.
 *
 * @param colName the column name that could not be resolved
 * @param fields  the candidate field names listed in the error message
 * @return an exception whose message lists `fields` and, when some field matches `colName`
 *         under the session analyzer's resolver, also suggests quoting the column name
 */
private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
  // A field that matches the requested name under the resolver triggers the quoting hint.
  val matchesSomeField =
    fields.exists(field => sparkSession.sessionState.analyzer.resolver(field, colName))
  val quoteHint =
    if (matchesSomeField) {
      s"; did you mean to quote the `$colName` column?"
    } else {
      ""
    }
  val candidates = fields.mkString(", ")
  new AnalysisException(
    s"""Cannot resolve column name "$colName" among (${candidates})${quoteHint}""")
}

private[sql] def numericColumns: Seq[Expression] = {
Expand Down Expand Up @@ -1083,26 +1085,31 @@ class Dataset[T] private[sql](
}

// If left/right have no output set intersection, return the plan.
val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
val lanalyzed = this.queryExecution.analyzed
val ranalyzed = right.queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
return withPlan(plan)
}

// Otherwise, find the trivially true predicates and automatically resolves them to both sides.
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
val resolver = sparkSession.sessionState.analyzer.resolver
val cond = plan.condition.map { _.transform {
case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
if a.sameRef(b) =>
catalyst.expressions.EqualTo(
withPlan(plan.left).resolve(a.name),
withPlan(plan.right).resolve(b.name))
plan.left.resolveQuoted(a.name, resolver)
.getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
plan.right.resolveQuoted(b.name, resolver)
.getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
if a.sameRef(b) =>
catalyst.expressions.EqualNullSafe(
withPlan(plan.left).resolve(a.name),
withPlan(plan.right).resolve(b.name))
plan.left.resolveQuoted(a.name, resolver)
.getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
plan.right.resolveQuoted(b.name, resolver)
.getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
}}

withPlan {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData

class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
Expand Down Expand Up @@ -219,4 +220,32 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple))
}
}

// Regression test for SPARK-33071/SPARK-33536: a join between `emp1` and a dataset that was
// itself derived from `emp1` is handed to `assertAmbiguousSelfJoin`.
test("SPARK-33071/SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " +
"to not break DetectAmbiguousSelfJoin") {
// Two small TestData datasets; `emp2` holds the first three rows of `emp1`.
val emp1 = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop"),
TestData(4, "IT")).toDS()
val emp2 = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop")).toDS()
// `emp3` is derived from `emp1`: it joins `emp1` with `emp2` on `key` and keeps only the
// columns selected through `emp1("*")`.
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
// Left-outer join of `emp1` with `emp3` on `key`, selecting emp1's columns plus emp3's `key`
// aliased as `e2`.
// NOTE(review): `assertAmbiguousSelfJoin` is defined outside this view; presumably it
// evaluates the query and expects the ambiguous-self-join failure. Keep the query expression
// inside the call in case the helper takes it by name. TODO confirm.
assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
"left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
}

// Checks that calling `show` on a DataFrame leaves the DATASET_ID_TAG value on its logical
// plan unchanged.
test("df.show() should also not change dataset_id of LogicalPlan") {
  val rows = Seq[TestData](
    TestData(1, "sales"),
    TestData(2, "personnel"),
    TestData(3, "develop"),
    TestData(4, "IT"))
  val df = rows.toDF()

  // Capture the tag value before and after `show`, then compare the two.
  val idBeforeShow = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)
  df.show(0)
  val idAfterShow = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG)

  assert(idBeforeShow === idAfterShow)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,9 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean
class ColumnarAlias(child: ColumnarExpression, name: String)(
override val exprId: ExprId = NamedExpression.newExprId,
override val qualifier: Seq[String] = Seq.empty,
override val explicitMetadata: Option[Metadata] = None)
extends Alias(child, name)(exprId, qualifier, explicitMetadata)
override val explicitMetadata: Option[Metadata] = None,
override val deniedMetadataKeys: Seq[String] = Seq.empty)
extends Alias(child, name)(exprId, qualifier, explicitMetadata, deniedMetadataKeys)
with ColumnarExpression {

override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch)
Expand Down Expand Up @@ -711,7 +712,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] {
def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match {
case a: Alias =>
new ColumnarAlias(replaceWithColumnarExpression(a.child),
a.name)(a.exprId, a.qualifier, a.explicitMetadata)
a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.deniedMetadataKeys)
case att: AttributeReference =>
new ColumnarAttributeReference(att.name, att.dataType, att.nullable,
att.metadata)(att.exprId, att.qualifier)
Expand Down