Skip to content

Commit d676b62

Browse files
committed
[SPARK-24051][SQL] Replace Aliases with the same exprId
1 parent 20ca208 commit d676b62

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20+
import scala.collection.mutable
2021
import scala.collection.mutable.ArrayBuffer
2122
import scala.util.Random
2223

@@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
3637
import org.apache.spark.sql.internal.SQLConf
3738
import org.apache.spark.sql.types._
3839

40+
3941
/**
4042
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
4143
* Used for testing when all relations are already filled in and the analyzer needs only
@@ -145,6 +147,8 @@ class Analyzer(
145147
ResolveHints.RemoveAllHints),
146148
Batch("Simple Sanity Check", Once,
147149
LookupFunctions),
150+
Batch("DeduplicateAliases", Once,
151+
DeduplicateAliases),
148152
Batch("Substitution", fixedPoint,
149153
CTESubstitution,
150154
WindowsSubstitution,
@@ -284,6 +288,80 @@ class Analyzer(
284288
}
285289
}
286290

291+
/**
292+
* Replaces [[Alias]] with the same exprId but different references with [[Alias]] having
293+
* different exprIds. This is a rare situation which can cause incorrect results.
294+
*/
295+
object DeduplicateAliases extends Rule[LogicalPlan] {

  /**
   * Entry point of the rule. The plan is rewritten only when at least one exprId is shared by
   * [[Alias]] expressions that are not equal to each other; otherwise it is returned as-is.
   */
  def apply(plan: LogicalPlan): LogicalPlan = {
    val dupAliases = collectAllAliasesInPlan(plan)
      .groupBy(_.exprId)
      .collect { case (eId, aliases) if containsDifferentAliases(aliases) => eId }
      .toSeq
    if (dupAliases.isEmpty) {
      plan
    } else {
      resolveConflictingAliases(plan, dupAliases, mutable.HashMap.empty[ExprId, ExprId])
    }
  }

  /**
   * Returns true when `aliases` holds at least two aliases that are not `fastEquals` to each
   * other. Comparing every element against the first one is enough: if all of them equal the
   * first, no differing pair can exist. An empty or single-element sequence yields false.
   */
  def containsDifferentAliases(aliases: Seq[Alias]): Boolean = aliases match {
    case first +: rest => rest.exists(!_.fastEquals(first))
    case _ => false
  }

  /**
   * Collects every [[Alias]] that appears directly in the project list of a [[Project]] node.
   * The child of an [[AnalysisBarrier]] is visited through an explicit recursive call.
   */
  def collectAllAliasesInPlan(plan: LogicalPlan): Seq[Alias] = {
    plan.flatMap {
      case Project(projectList, _) => projectList.collect { case a: Alias => a }
      case AnalysisBarrier(child) => collectAllAliasesInPlan(child)
      case _ => Nil
    }
  }

  /**
   * Returns true when any top-level [[Alias]] or [[AttributeReference]] in `projectList`
   * carries one of the given `exprIds`.
   */
  def containsExprIds(
      projectList: Seq[NamedExpression],
      exprIds: Seq[ExprId]): Boolean = {
    projectList.exists {
      case a: Alias => exprIds.contains(a.exprId)
      case a: AttributeReference => exprIds.contains(a.exprId)
      case _ => false
    }
  }

  /**
   * Gives a fresh exprId to every top-level [[Alias]] in `exprs` whose id is in `exprIds`, and
   * re-points every top-level [[AttributeReference]] with such an id to the most recently
   * generated replacement.
   *
   * Only the elements of `exprs` themselves are inspected; references nested inside other
   * expressions are left as they are.
   *
   * @param exprs the named expressions of a single operator
   * @param exprIds the conflicting exprIds to replace
   * @param exprIdsDictionary mutable mapping from an old exprId to its latest replacement; it
   *                          is updated in place whenever an alias is renewed
   */
  def renewConflictingAliases(
      exprs: Seq[NamedExpression],
      exprIds: Seq[ExprId],
      exprIdsDictionary: mutable.HashMap[ExprId, ExprId]): Seq[NamedExpression] = {
    exprs.map {
      case a: Alias if exprIds.contains(a.exprId) =>
        // Only the exprId must change: keep the qualifier and the explicit metadata of the
        // original alias instead of resetting them to their defaults.
        val newAlias = Alias(a.child, a.name)(
          qualifier = a.qualifier,
          explicitMetadata = a.explicitMetadata)
        // update the map with the new id to replace
        // since we are in a transformUp, all the parent nodes will see the updated map
        exprIdsDictionary(a.exprId) = newAlias.exprId
        newAlias
      case a: AttributeReference if exprIds.contains(a.exprId) =>
        // Replace with the new id. If no alias with this id has been renewed so far there is
        // nothing to re-point, so keep the attribute unchanged rather than failing with a
        // NoSuchElementException.
        exprIdsDictionary.get(a.exprId).map(a.withExprId).getOrElse(a)
      case other => other
    }
  }

  /**
   * Walks the plan bottom-up and renews the conflicting aliases found in [[Project]] and
   * [[Aggregate]] operators. The child of an [[AnalysisBarrier]] is rewritten through an
   * explicit recursive call that shares the same `exprIdsDictionary`.
   *
   * @param plan the plan to rewrite
   * @param dupAliases the exprIds shared by different aliases
   * @param exprIdsDictionary mutable mapping from old to new exprIds, shared by the whole
   *                          traversal
   */
  def resolveConflictingAliases(
      plan: LogicalPlan,
      dupAliases: Seq[ExprId],
      exprIdsDictionary: mutable.HashMap[ExprId, ExprId]): LogicalPlan = {
    plan.transformUp {
      case p @ Project(projectList, _) if containsExprIds(projectList, dupAliases) =>
        p.copy(projectList =
          renewConflictingAliases(projectList, dupAliases, exprIdsDictionary))
      case a @ Aggregate(_, aggs, _) if containsExprIds(aggs, dupAliases) =>
        a.copy(aggregateExpressions =
          renewConflictingAliases(aggs, dupAliases, exprIdsDictionary))
      case AnalysisBarrier(child) =>
        AnalysisBarrier(resolveConflictingAliases(child, dupAliases, exprIdsDictionary))
    }
  }
}
364+
287365
object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
288366
/*
289367
* GROUP BY a, b, c WITH ROLLUP

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Unio
3333
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
3434
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
3535
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
36+
import org.apache.spark.sql.expressions.Window
3637
import org.apache.spark.sql.functions._
3738
import org.apache.spark.sql.internal.SQLConf
3839
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
@@ -2265,4 +2266,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
22652266
val df = spark.range(1).select($"id", new Column(Uuid()))
22662267
checkAnswer(df, df.collect())
22672268
}
2269+
2270+
test("SPARK-24051: using the same alias can produce incorrect result") {
  val left = Seq((1, 42), (2, 99)).toDF("a", "b")
  val right = Seq(3).toDF("a").withColumn("b", lit(0))

  // The very same Column instances are selected on both sides of the union.
  val sharedColumns = Seq(
    col("a"),
    col("b").alias("b"),
    count(lit(1)).over(Window.partitionBy()).alias("n"))

  val unioned = left.select(sharedColumns: _*).union(right.select(sharedColumns: _*))

  // "n" is the row count of each side: 2 for the first DataFrame, 1 for the second.
  val expected = Seq(Row(1, 42, 2), Row(2, 99, 2), Row(3, 0, 1))
  checkAnswer(unioned, expected)
}
22682280
}

0 commit comments

Comments
 (0)