From 92cb513416c5dd0e9fa690c25cfae0565471a5e1 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 29 May 2018 14:34:57 +0200
Subject: [PATCH 1/4] [SPARK-24385][SQL] Resolve self-join condition ambiguity
 for all BinaryComparisons

---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  6 +++---
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  | 12 ++++++++++++
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index abb5ae53f4d73..0f2de323a0a9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1000,11 +1000,11 @@ class Dataset[T] private[sql](
     // By the time we get here, since we have already run analysis, all attributes should've been
     // resolved and become AttributeReference.
     val cond = plan.condition.map { _.transform {
-      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+      case e @ catalyst.expressions.BinaryComparison(a: AttributeReference, b: AttributeReference)
           if a.sameRef(b) =>
-        catalyst.expressions.EqualTo(
+        e.withNewChildren(Seq(
           withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
+          withPlan(plan.right).resolve(b.name)))
     }}
 
     withPlan {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 0d9eeabb397a1..7a57863f78356 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -287,4 +287,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan
     }
   }
+
+  test("SPARK-24385: Resolve ambiguity in self-joins with operators different from EqualTo") {
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
+      val df = spark.range(10)
+      // these should not throw any exception
+      df.join(df, df("id") >= df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") <= df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") > df("id")).queryExecution.optimizedPlan
+      df.join(df, df("id") < df("id")).queryExecution.optimizedPlan
+    }
+  }
 }
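A minimal sketch of the failure mode this first patch addresses, assuming a
spark-shell session (where `spark` is predefined); the exact error text may
vary by version:

    val df = spark.range(10)
    // Before this patch only EqualTo conditions were re-resolved against the
    // two join branches, so a comparison like `<` kept pointing at a single
    // attribute (`id#0L < id#0L`) and, with crossJoin.enabled=false, the
    // analyzer rejected the query as an implicit cartesian product.
    // With the patch, each side resolves against its own branch:
    df.join(df, df("id") < df("id")).explain()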
From e8a5fa33d56187a6e30e81ba9439cd097fff5b2c Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 30 May 2018 11:54:44 +0200
Subject: [PATCH 2/4] properly handle different datasets with common lineage

---
 .../scala/org/apache/spark/sql/Dataset.scala       | 31 ++++++++++++++++---
 .../apache/spark/sql/DataFrameJoinSuite.scala      | 11 +++++--
 2 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0f2de323a0a9b..c4df603d7f120 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,6 +74,9 @@ private[sql] object Dataset {
     qe.assertAnalyzed()
     new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
   }
+
+  // String used as the key in the metadata of resolved attributes
+  private val DATASET_ID = "dataset.hash"
 }
 
 /**
@@ -217,11 +220,22 @@ class Dataset[T] private[sql](
   @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
 
   private[sql] def resolve(colName: String): NamedExpression = {
-    queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver)
-      .getOrElse {
+    val resolved = queryExecution.analyzed.resolveQuoted(colName,
+      sparkSession.sessionState.analyzer.resolver).getOrElse {
         throw new AnalysisException(
           s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
       }
+    // We add to the attribute's metadata a reference to the Dataset it comes from, because this
+    // helps determine what the attribute is really referencing when performing self-joins
+    // (or joins between datasets with a common lineage) and the join condition contains
+    // ambiguous references.
+    resolved match {
+      case a: AttributeReference =>
+        val mBuilder = new MetadataBuilder()
+        mBuilder.withMetadata(a.metadata).putLong(Dataset.DATASET_ID, this.hashCode().toLong)
+        a.withMetadata(mBuilder.build())
+      case other => other
+    }
   }
 
   private[sql] def numericColumns: Seq[Expression] = {
@@ -1002,9 +1016,16 @@ class Dataset[T] private[sql](
     val cond = plan.condition.map { _.transform {
       case e @ catalyst.expressions.BinaryComparison(a: AttributeReference, b: AttributeReference)
           if a.sameRef(b) =>
-        e.withNewChildren(Seq(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name)))
+        val bReferencesThis = b.metadata.contains(Dataset.DATASET_ID) &&
+          b.metadata.getLong(Dataset.DATASET_ID) == hashCode()
+        val aReferencesRight = a.metadata.contains(Dataset.DATASET_ID) &&
+          a.metadata.getLong(Dataset.DATASET_ID) == right.hashCode()
+        val newChildren = if (bReferencesThis && aReferencesRight) {
+          Seq(withPlan(plan.right).resolve(a.name), withPlan(plan.left).resolve(b.name))
+        } else {
+          Seq(withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name))
+        }
+        e.withNewChildren(newChildren)
     }}
 
     withPlan {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 7a57863f78356..04eea442ac9c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -290,13 +290,20 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
 
   test("SPARK-24385: Resolve ambiguity in self-joins with operators different from EqualTo") {
     withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
-      val df = spark.range(10)
-      // these should not throw any exception
+      val df = spark.range(2)
+
+      // These should not throw any exception.
       df.join(df, df("id") >= df("id")).queryExecution.optimizedPlan
       df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
       df.join(df, df("id") <= df("id")).queryExecution.optimizedPlan
       df.join(df, df("id") > df("id")).queryExecution.optimizedPlan
       df.join(df, df("id") < df("id")).queryExecution.optimizedPlan
+
+      // Check we properly resolve columns when datasets are different but they share a common
+      // lineage.
+      val df1 = df.groupBy("id").count()
+      val df2 = df.groupBy("id").sum("id")
+      checkAnswer(df1.join(df2, df2("id") < df1("id")), Seq(Row(1, 1, 0, 0)))
     }
   }
 }
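A sketch of the common-lineage case this second patch handles, under the same
spark-shell assumption. df1 and df2 are different datasets derived from the
same parent, so their id columns share one exprId; the new `dataset.hash`
metadata is what lets `join` resolve each side of the condition against the
intended plan:

    val df  = spark.range(2)
    val df1 = df.groupBy("id").count()
    val df2 = df.groupBy("id").sum("id")
    // Without the metadata tag the condition would be resolved purely by
    // position (left, right), silently inverting the comparison below.
    // Expected result: a single row (1, 1, 0, 0).
    df1.join(df2, df2("id") < df1("id")).show()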
From b8d50570b7b172ef310fdfb12b01be1598ff5481 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 30 May 2018 13:04:15 +0200
Subject: [PATCH 3/4] fix ut failure

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c4df603d7f120..c564fef3a7c51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -463,7 +463,8 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  def schema: StructType = queryExecution.analyzed.schema
+  def schema: StructType = StructType.removeMetadata(
+    Dataset.DATASET_ID, queryExecution.analyzed.schema).asInstanceOf[StructType]
 
   /**
    * Prints the schema to the console in a nice tree format.
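The test failure fixed above comes from the new tag leaking into user-visible
schemas: once a column resolved through `df(...)` (and therefore carrying
`dataset.hash`) flows into a projection, `schema` would expose the extra
metadata and break schema comparisons. A rough illustration of the intended
behavior after the fix, again assuming spark-shell:

    val df = spark.range(1)
    val projected = df.select(df("id"))  // df("id") carries the dataset.hash tag
    // StructType.removeMetadata strips the internal key, so the user-visible
    // schemas should still compare as equal:
    assert(projected.schema("id").metadata == df.schema("id").metadata)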
From 8e6e5c0059574c1e171e589fcf533c6b5669499f Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 30 May 2018 18:44:46 +0200
Subject: [PATCH 4/4] use semanticEquals for comparing Attributes

---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala  |  5 ++++-
 .../spark/sql/catalyst/expressions/package.scala       | 10 +++++++++-
 .../src/main/scala/org/apache/spark/sql/Dataset.scala  |  4 ++--
 .../scala/org/apache/spark/sql/DataFrameSuite.scala    |  2 +-
 4 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3eaa9ecf5d075..e7ab3545c2d37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -332,7 +332,10 @@ class Analyzer(
         gid: Expression): Expression = {
       expr transform {
         case e: GroupingID =>
-          if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+          def sameExprs = e.groupByExprs.zip(groupByExprs).forall {
+            case (e1, e2) => e1.semanticEquals(e2)
+          }
+          if (e.groupByExprs.isEmpty || sameExprs) {
             Alias(gid, toPrettySQL(e))()
           } else {
             throw new AnalysisException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 8a06daa37132d..e4c24bd79f9ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -144,7 +144,15 @@ package object expressions {
     }
 
     private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = {
-      m.mapValues(_.distinct).map(identity)
+      m.mapValues { allAttrs =>
+        val buffer = new scala.collection.mutable.ListBuffer[Attribute]
+        allAttrs.foreach { a =>
+          if (!buffer.exists(_.semanticEquals(a))) {
+            buffer += a
+          }
+        }
+        buffer
+      }.map(identity)
     }
 
     /** Map to use for direct case insensitive attribute lookups. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c564fef3a7c51..fe363c53515f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -240,7 +240,7 @@ class Dataset[T] private[sql](
 
   private[sql] def numericColumns: Seq[Expression] = {
     schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
+      resolve(n.name)
     }
   }
 
@@ -2329,7 +2329,7 @@ class Dataset[T] private[sql](
     }
     val attrs = this.planWithBarrier.output
     val colsAfterDrop = attrs.filter { attr =>
-      attr != expression
+      !attr.semanticEquals(expression)
     }.map(attr => Column(attr))
     select(colsAfterDrop : _*)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1cc8cb3874c9b..ac223dbc3b0c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -862,7 +862,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
         Row(id, name, age, salary)
       }.toSeq)
     assert(df.schema.map(_.name) === Seq("id", "name", "age", "salary"))
-    assert(df("id") == person("id"))
+    assert(df("id").expr.semanticEquals(person("id").expr))
   }
 
   test("drop top level columns that contains dot") {
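The move to semanticEquals in this last patch follows from the metadata
tagging: two references to the same column can now differ only in their
`dataset.hash` entry, and `AttributeReference` equality includes metadata. A
small sketch of the distinction, same spark-shell assumption:

    val df = spark.range(1)
    val a = df("id").expr         // tagged with df's hash
    val b = df.toDF()("id").expr  // same exprId, different dataset tag
    println(a == b)               // may be false: plain equality sees the metadata
    println(a.semanticEquals(b))  // true: canonicalized comparison ignores it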