@@ -332,7 +332,10 @@ class Analyzer(
       gid: Expression): Expression = {
     expr transform {
       case e: GroupingID =>
-        if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+        def sameExprs = e.groupByExprs.zip(groupByExprs).forall {
+          case (e1, e2) => e1.semanticEquals(e2)
+        }
+        if (e.groupByExprs.isEmpty || sameExprs) {
          Alias(gid, toPrettySQL(e))()
        } else {
          throw new AnalysisException(
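A note on the fix: `==` on Catalyst expressions is structural equality, so two references to the same underlying attribute can compare unequal when one of them carries extra metadata, which is exactly what this patch starts attaching in Dataset.resolve below. `semanticEquals` compares canonicalized forms instead. A minimal spark-shell sketch of the distinction, with an illustrative metadata key:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{LongType, MetadataBuilder}

val a = AttributeReference("id", LongType)()
// Same attribute (same exprId), differing only in attached metadata.
val tagged = a.withMetadata(new MetadataBuilder().putLong("some.key", 42L).build())

a == tagged              // false: metadata participates in structural equality
a.semanticEquals(tagged) // true: canonicalization ignores cosmetic differences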
@@ -144,7 +144,15 @@ package object expressions {
   }

   private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = {
-    m.mapValues(_.distinct).map(identity)
+    m.mapValues { allAttrs =>
+      val buffer = new scala.collection.mutable.ListBuffer[Attribute]
+      allAttrs.foreach { a =>
+        if (!buffer.exists(_.semanticEquals(a))) {
+          buffer += a
+        }
+      }
+      buffer
+    }.map(identity)
   }

   /** Map to use for direct case insensitive attribute lookups. */
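`Seq.distinct` also deduplicates with `==`, so attributes differing only in metadata would survive as spurious duplicates and later be reported as ambiguous. The new body keeps one representative per semantic-equivalence class. The same pattern as a hypothetical generic helper (not part of the patch), runnable in any Scala REPL:

import scala.collection.mutable.ListBuffer

// Keep the first element of each equivalence class under `equiv`, preserving
// input order -- the same loop unique() now runs with Attribute.semanticEquals.
def distinctBy[A](xs: Seq[A])(equiv: (A, A) => Boolean): Seq[A] = {
  val buffer = new ListBuffer[A]
  xs.foreach { x =>
    if (!buffer.exists(equiv(_, x))) buffer += x
  }
  buffer.toList
}

distinctBy(Seq("a", "A", "b"))((x, y) => x.equalsIgnoreCase(y)) // Seq("a", "b")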
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (31 additions, 9 deletions)

@@ -74,6 +74,9 @@ private[sql] object Dataset {
     qe.assertAnalyzed()
     new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
   }
+
+  // String used as key in metadata of resolved attributes
+  private val DATASET_ID = "dataset.hash"
 }

 /**
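`DATASET_ID` is the metadata key under which `resolve` (next hunk) records the `hashCode()` of the Dataset an attribute was resolved against, in effect tagging the attribute with the instance it came from. A minimal sketch of the `Metadata` round trip this relies on, with an illustrative value:

import org.apache.spark.sql.types.MetadataBuilder

val md = new MetadataBuilder().putLong("dataset.hash", 42L).build()
assert(md.contains("dataset.hash"))
assert(md.getLong("dataset.hash") == 42L)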
@@ -217,16 +220,27 @@ class Dataset[T] private[sql](
   @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext

   private[sql] def resolve(colName: String): NamedExpression = {
-    queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
-      .getOrElse {
-        throw new AnalysisException(
-          s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
-      }
+    val resolved = queryExecution.analyzed.resolveQuoted(colName,
+      sparkSession.sessionState.analyzer.resolver).getOrElse {
+      throw new AnalysisException(
+        s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+    }
+    // We record in the metadata a reference to the Dataset the attribute comes from, because
+    // it is useful for determining what the attribute is really referencing when a self-join
+    // (or a join between Datasets with a common lineage) has a join condition containing
+    // ambiguous references.
+    resolved match {
+      case a: AttributeReference =>
+        val mBuilder = new MetadataBuilder()
+        mBuilder.withMetadata(a.metadata).putLong(Dataset.DATASET_ID, this.hashCode().toLong)
+        a.withMetadata(mBuilder.build())
+      case other => other
+    }
   }

   private[sql] def numericColumns: Seq[Expression] = {
     schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
+      resolve(n.name)
     }
   }
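Since `Dataset.apply(colName)` resolves through this method, every column obtained as `df("col")` now carries its owner's tag. A hedged spark-shell sketch of the observable effect, assuming the patch is applied:

import org.apache.spark.sql.catalyst.expressions.AttributeReference

val df = spark.range(2)
// df("id") goes through resolve(), so the column's attribute now records
// the owning Dataset's hash in its metadata.
val attr = df("id").expr.asInstanceOf[AttributeReference]
assert(attr.metadata.getLong("dataset.hash") == df.hashCode().toLong)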
@@ -449,7 +463,8 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  def schema: StructType = queryExecution.analyzed.schema
+  def schema: StructType = StructType.removeMetadata(
+    Dataset.DATASET_ID, queryExecution.analyzed.schema).asInstanceOf[StructType]

   /**
    * Prints the schema to the console in a nice tree format.
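Stripping the key here keeps the tag an internal detail: the user-visible schema stays exactly what it was before the patch. For illustration, a hypothetical `stripKey` that removes one metadata key from the top-level fields (a stand-in for the `StructType.removeMetadata` helper the patch calls):

import org.apache.spark.sql.types.{MetadataBuilder, StructType}

def stripKey(schema: StructType, key: String): StructType = StructType(
  schema.fields.map { f =>
    // Rebuild each field's metadata without the given key.
    val cleaned = new MetadataBuilder().withMetadata(f.metadata).remove(key).build()
    f.copy(metadata = cleaned)
  })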
@@ -1000,11 +1015,18 @@ class Dataset[T] private[sql](
     // By the time we get here, since we have already run analysis, all attributes should've been
     // resolved and become AttributeReference.
     val cond = plan.condition.map { _.transform {
-      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+      case e @ catalyst.expressions.BinaryComparison(a: AttributeReference, b: AttributeReference)
           if a.sameRef(b) =>
-        catalyst.expressions.EqualTo(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
+        val bReferencesThis = b.metadata.contains(Dataset.DATASET_ID) &&
+          b.metadata.getLong(Dataset.DATASET_ID) == hashCode()
+        val aReferencesRight = a.metadata.contains(Dataset.DATASET_ID) &&
+          a.metadata.getLong(Dataset.DATASET_ID) == right.hashCode()
+        val newChildren = if (bReferencesThis && aReferencesRight) {
+          Seq(withPlan(plan.right).resolve(a.name), withPlan(plan.left).resolve(b.name))
+        } else {
+          Seq(withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name))
+        }
+        e.withNewChildren(newChildren)
     }}

     withPlan {
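Two changes here: the pattern now matches any `BinaryComparison`, so `<`, `<=`, `>`, `>=` and `<=>` get the disambiguation that only `EqualTo` received before; and because these operators are not commutative, each operand is re-resolved against the side its metadata tag says it came from, preserving the comparison's direction. The ownership test, as a hedged spark-shell sketch with the key written out literally:

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.expressions.AttributeReference

// An attribute "belongs to" one side of the join when the hash recorded at
// resolution time matches that Dataset instance.
def belongsTo(a: AttributeReference, ds: Dataset[_]): Boolean =
  a.metadata.contains("dataset.hash") &&
    a.metadata.getLong("dataset.hash") == ds.hashCode().toLong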
@@ -2307,7 +2329,7 @@ class Dataset[T] private[sql](
     }
     val attrs = this.planWithBarrier.output
     val colsAfterDrop = attrs.filter { attr =>
-      attr != expression
+      !attr.semanticEquals(expression)
     }.map(attr => Column(attr))
     select(colsAfterDrop : _*)
   }
@@ -287,4 +287,23 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan
     }
   }
+
+  test("SPARK-24385: Resolve ambiguity in self-joins with operators different from EqualTo") {
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
+      val df = spark.range(2)
+
+      // These should not throw any exception.
+      df.join(df, df("id") >= df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") <= df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") > df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") < df("id")).queryExecution.optimizedPlan
+
+      // Check that we properly resolve columns when the Datasets are different but share a
+      // common lineage.
+      val df1 = df.groupBy("id").count()
+      val df2 = df.groupBy("id").sum("id")
+      checkAnswer(df1.join(df2, df2("id") < df1("id")), Seq(Row(1, 1, 0, 0)))
+    }
+  }
 }
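The expected row in `checkAnswer` can be worked out by hand (a sketch of the data, not part of the patch):

// df has rows id = 0, 1.
// df1 = df.groupBy("id").count()   -> (id = 0, count = 1), (id = 1, count = 1)
// df2 = df.groupBy("id").sum("id") -> (id = 0, sum = 0),   (id = 1, sum = 1)
// The only pair satisfying df2("id") < df1("id") joins (id = 1, count = 1)
// with (id = 0, sum = 0), hence Row(1, 1, 0, 0).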
@@ -862,7 +862,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(id, name, age, salary)
     }.toSeq)
     assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
-    assert(df("id") == person("id"))
+    assert(df("id").expr.semanticEquals(person("id").expr))
   }

   test("drop top level columns that contains dot") {

Note: with resolve() now tagging attributes, df("id") and person("id") can differ in metadata even though they reference the same column, so the assertion compares canonical forms via semanticEquals instead of structural equality.