
Commit 6987c06

Davies Liu authored and yhuai committed
[SPARK-11009] [SQL] fix wrong result of Window function in cluster mode
Currently, all window functions can generate wrong results in cluster mode. The root cause is that `AttributeReference` is created on executors, so the expression ID it receives may not be unique with respect to the IDs created on the driver.

Here is a script that reproduces the problem (run in a local cluster):

```python
from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql.window import Window
from pyspark.sql.functions import rowNumber

sqlContext = HiveContext(SparkContext())
sqlContext.setConf("spark.sql.shuffle.partitions", "3")
df = sqlContext.range(1 << 20)
df2 = df.select((df.id % 1000).alias("A"), (df.id / 1000).alias("B"))
ws = Window.partitionBy(df2.A).orderBy(df2.B)
df3 = df2.select(df2.A, df2.B, rowNumber().over(ws).alias("rn")).filter("rn < 0")
assert df3.count() == 0
```

Author: Davies Liu <[email protected]>
Author: Yin Huai <[email protected]>

Closes #9050 from davies/wrong_window.
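Why does creating an `AttributeReference` on an executor break anything? Catalyst draws expression IDs from a counter local to each JVM, so the driver and each executor count independently. The sketch below models that failure mode with simplified stand-ins, not Spark's actual `NamedExpression` machinery:

```scala
import java.util.concurrent.atomic.AtomicLong

// Simplified stand-in for Catalyst's per-JVM expression-ID counter.
class JvmExprIds {
  private val curId = new AtomicLong()
  def newId(): Long = curId.getAndIncrement()
}

object ExprIdClashDemo extends App {
  val driver = new JvmExprIds   // the driver JVM's counter
  val executor = new JvmExprIds // an executor JVM's counter

  val driverAttr = driver.newId()     // ID 0, given to some column during planning
  val executorAttr = executor.newId() // also ID 0, given to "ordExpr"/"windowResult"

  // Two unrelated expressions now share an ID, so any transformation that
  // matches attributes by ID can silently substitute the wrong column.
  assert(driverAttr == executorAttr)
  println(s"clash: driver=$driverAttr executor=$executorAttr")
}
```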
1 parent 626aab7 commit 6987c06

2 files changed (+51, -10 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala

Lines changed: 10 additions & 10 deletions
```diff
@@ -145,11 +145,10 @@ case class Window(
         // Construct the ordering. This is used to compare the result of current value projection
         // to the result of bound value projection. This is done manually because we want to use
         // Code Generation (if it is enabled).
-        val (sortExprs, schema) = exprs.map { case e =>
-          val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
-          (SortOrder(ref, e.direction), ref)
-        }.unzip
-        val ordering = newOrdering(sortExprs, schema)
+        val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+          SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+        }
+        val ordering = newOrdering(sortExprs, Nil)
         RangeBoundOrdering(ordering, current, bound)
       case RowFrame => RowBoundOrdering(offset)
     }
```
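This first hunk replaces the executor-created `AttributeReference`s with `BoundReference`s, which carry no expression ID at all: they resolve by ordinal against the input row, so they mean the same thing in every JVM (which is also why `newOrdering` can now take `Nil` as its input schema, as nothing is left to resolve). A minimal model of that semantics, using illustrative stand-in types rather than Spark's `InternalRow`:

```scala
// Toy input row; Spark's real counterpart is InternalRow.
case class SimpleRow(values: Array[Any]) {
  def get(i: Int): Any = values(i)
}

// Positional reference: evaluation is a plain index lookup, no expression IDs.
case class BoundRef(ordinal: Int) {
  def eval(row: SimpleRow): Any = row.get(ordinal)
}

object BoundRefDemo extends App {
  val row = SimpleRow(Array[Any](42L, "spark"))
  println(BoundRef(0).eval(row)) // 42
  println(BoundRef(1).eval(row)) // spark
}
```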
```diff
@@ -205,14 +204,15 @@ case class Window(
    */
   private[this] def createResultProjection(
       expressions: Seq[Expression]): MutableProjection = {
-    val unboundToAttr = expressions.map {
-      e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
+    val references = expressions.zipWithIndex.map{ case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
     }
-    val unboundToAttrMap = unboundToAttr.toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ unboundToAttr.map(_._2))()
+      child.output)()
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
```
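The second hunk applies the same idea to the result projection. Its input row is the child's output with the window results appended on the right, so window expression `i` lives at ordinal `child.output.size + i`, and the projection no longer needs placeholder attributes in its schema. A small sketch of the index arithmetic, with made-up column names:

```scala
object ResultLayoutDemo extends App {
  // Hypothetical example: the child produces 3 columns, and the window
  // operator appends 2 window results to the right of them.
  val childOutput = Seq("A", "B", "C")
  val windowResults = Seq("rn", "rank")

  // input row layout: [ A, B, C | rn, rank ]
  val ordinals = windowResults.indices.map(childOutput.size + _)
  println(ordinals.mkString(", ")) // 3, 4
}
```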

sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala

Lines changed: 41 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.sql.{SQLContext, QueryTest}
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
 import org.apache.spark.sql.types.DecimalType
```
```diff
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
```
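The `local-cluster[2,1,1024]` master string is what makes this a genuine regression test: it launches 2 workers, each with 1 core and 1024 MB, so executors run in JVMs separate from the driver and keep their own expression-ID counters. A toy parse of the format (the real parsing lives inside `SparkContext`; this snippet is only illustrative):

```scala
object LocalClusterDemo extends App {
  // Illustrative regex for "local-cluster[numWorkers,coresPerWorker,memoryMB]".
  val LocalCluster = """local-cluster\[(\d+),(\d+),(\d+)\]""".r
  "local-cluster[2,1,1024]" match {
    case LocalCluster(workers, cores, mem) =>
      println(s"$workers workers, $cores core(s) and $mem MB each")
  }
}
```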
```diff
@@ -320,3 +331,33 @@ object SPARK_9757 extends QueryTest {
     }
   }
 }
+
+object SPARK_11009 extends QueryTest {
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "100"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    try {
+      val df = sqlContext.range(1 << 20)
+      val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
+      val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+      val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
+      if (df3.rdd.count() != 0) {
+        throw new Exception("df3 should have 0 output row.")
+      }
+    } finally {
+      sparkContext.stop()
+    }
+  }
+}
```
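Note that the repro has to go through `spark-submit` with a `local-cluster` master: in plain `local` mode the driver and executors share a single JVM, and therefore a single expression-ID counter, so the IDs never collide and the test would pass even without the fix. With correct results, `rowNumber()` starts at 1, so the `rn < 0` filter must select nothing.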
