From 35b7b229b245d26ed8a72c2dd3701e2675ec27ea Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 9 Oct 2015 14:04:41 -0700
Subject: [PATCH 1/6] fix wrong result of Window function in cluster mode

---
 .../apache/spark/sql/execution/Window.scala        | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index f8929530c503..ac7aa05b085e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -201,18 +201,17 @@ case class Window(
    * This method uses Code Generation. It can only be used on the executor side.
    *
    * @param expressions unbound ordered function expressions.
+   * @param attributes output attributes
    * @return the final resulting projection.
    */
   private[this] def createResultProjection(
-      expressions: Seq[Expression]): MutableProjection = {
-    val unboundToAttr = expressions.map {
-      e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
-    }
-    val unboundToAttrMap = unboundToAttr.toMap
+      expressions: Seq[Expression],
+      attributes: Seq[Attribute]): MutableProjection = {
+    val unboundToAttrMap = expressions.zip(attributes).toMap
     val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ unboundToAttr.map(_._2))()
+      child.output ++ attributes)()
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -247,12 +246,17 @@ case class Window(
       factories(index) = () => createFrameProcessor(frame, functions, ordinal)
     }
 
+    // AttributeReference can only be created in driver, or the id will not be unique
+    val outputAttributes = unboundExpressions.map {
+      e => AttributeReference("windowResult", e.dataType, e.nullable)()
+    }
+
     // Start processing.
     child.execute().mapPartitions { stream =>
       new Iterator[InternalRow] {
 
         // Get all relevant projections.
-        val result = createResultProjection(unboundExpressions)
+        val result = createResultProjection(unboundExpressions, outputAttributes)
         val grouping = if (child.outputsUnsafeRows) {
           UnsafeProjection.create(partitionSpec, child.output)
         } else {

From 39b99b819a5219cdb9ea4fae8e58ed9582cc10a6 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 9 Oct 2015 17:13:14 -0700
Subject: [PATCH 2/6] add a comment

---
 .../src/main/scala/org/apache/spark/sql/execution/Window.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index ac7aa05b085e..7e204ec88aa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -146,6 +146,8 @@ case class Window(
         // to the result of bound value projection. This is done manually because we want to use
         // Code Generation (if it is enabled).
         val (sortExprs, schema) = exprs.map { case e =>
+          // This AttributeReference does not need to have unique IDs, it's OK to be called
+          // in executor.
           val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
           (SortOrder(ref, e.direction), ref)
         }.unzip
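Patches 1 and 2 hinge on the fact that Catalyst expression ids come from a per-JVM counter (NamedExpression.newExprId wraps an AtomicLong), so an AttributeReference created on an executor can reuse an id the driver already handed out. A minimal plain-Scala model of that collision; the Jvm class and the id values are hypothetical, not Spark code:

    // Hypothetical model: "Jvm" stands in for one driver or executor process,
    // each with its own expression-id counter starting at 0.
    import java.util.concurrent.atomic.AtomicLong

    class Jvm {
      private val curId = new AtomicLong()
      def newExprId: Long = curId.getAndIncrement()
    }

    val driver = new Jvm
    // The driver assigns ids to the child's output attributes, e.g. A -> 0, B -> 1.
    val childOutputIds = Seq(driver.newExprId, driver.newExprId)

    val executor = new Jvm // an executor's counter starts at 0 again
    val windowResultId = executor.newExprId
    // "windowResult" receives id 0 and collides with attribute A, so a
    // projection built on the executor can bind to the wrong column.
    assert(childOutputIds.contains(windowResultId))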
From 2d55882530aaa1c852e8294b8cb11af158f37481 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 12 Oct 2015 14:57:36 -0700
Subject: [PATCH 3/6] use BoundReference

---
 .../apache/spark/sql/execution/Window.scala        | 32 ++++++++-----------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 7e204ec88aa1..55035f4bc5f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -145,13 +145,10 @@ case class Window(
         // Construct the ordering. This is used to compare the result of current value projection
         // to the result of bound value projection. This is done manually because we want to use
         // Code Generation (if it is enabled).
-        val (sortExprs, schema) = exprs.map { case e =>
-          // This AttributeReference does not need to have unique IDs, it's OK to be called
-          // in executor.
-          val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
-          (SortOrder(ref, e.direction), ref)
-        }.unzip
-        val ordering = newOrdering(sortExprs, schema)
+        val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+          SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+        }
+        val ordering = newOrdering(sortExprs, Nil)
         RangeBoundOrdering(ordering, current, bound)
       case RowFrame => RowBoundOrdering(offset)
     }
@@ -203,17 +200,19 @@ case class Window(
    * This method uses Code Generation. It can only be used on the executor side.
    *
    * @param expressions unbound ordered function expressions.
-   * @param attributes output attributes
    * @return the final resulting projection.
    */
   private[this] def createResultProjection(
-      expressions: Seq[Expression],
-      attributes: Seq[Attribute]): MutableProjection = {
-    val unboundToAttrMap = expressions.zip(attributes).toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
+      expressions: Seq[Expression]): MutableProjection = {
+    val references = expressions.zipWithIndex.map{ case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
+    }
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ attributes)()
+      child.output)()
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -248,17 +247,12 @@ case class Window(
       factories(index) = () => createFrameProcessor(frame, functions, ordinal)
     }
 
-    // AttributeReference can only be created in driver, or the id will not be unique
-    val outputAttributes = unboundExpressions.map {
-      e => AttributeReference("windowResult", e.dataType, e.nullable)()
-    }
-
     // Start processing.
     child.execute().mapPartitions { stream =>
      new Iterator[InternalRow] {
 
         // Get all relevant projections.
-        val result = createResultProjection(unboundExpressions, outputAttributes)
+        val result = createResultProjection(unboundExpressions)
         val grouping = if (child.outputsUnsafeRows) {
           UnsafeProjection.create(partitionSpec, child.output)
         } else {
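Patch 3 replaces the named attributes with BoundReference, which selects its input purely by ordinal and so carries no JVM-local id; it is the same expression no matter which process constructs it. A small sketch against Catalyst's internal API, assuming the 1.5-era signatures (BoundReference(ordinal, dataType, nullable) and InternalRow.apply):

    // Sketch only: Catalyst internals, signatures vary across Spark versions.
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.types.IntegerType

    // A BoundReference reads a fixed ordinal of whatever row it is given.
    val row = InternalRow(10, 20, 30)
    val secondColumn = BoundReference(1, IntegerType, nullable = false)
    assert(secondColumn.eval(row) == 20)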
---
 .../spark/sql/hive/HiveSparkSubmitSuite.scala      | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 5f1660b62d41..2df467b78321 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.sql.{SQLContext, QueryTest}
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
 import org.apache.spark.sql.types.DecimalType
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -320,3 +331,36 @@ object SPARK_9757 extends QueryTest {
     }
   }
 }
+
+object SPARK_11009 extends QueryTest {
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "3"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    try {
+      val df = sqlContext.range(1<<20)
+      val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
+      val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+      val df3 =
+        df2
+          .select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn"))
+          .filter("rn < 0")
+      if (df3.count() == 0) {
+        throw new Exception(s"df3 should has 0 row.")
+      }
+    } finally {
+      sparkContext.stop()
+    }
+  }
+}

From 89c140104ff931f4d5c154f754d9f43891cd600f Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 12 Oct 2015 21:59:21 -0700
Subject: [PATCH 5/6] Update HiveSparkSubmitSuite.scala

---
 .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 2df467b78321..584b7ca978d3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -356,7 +356,7 @@ object SPARK_11009 extends QueryTest {
       df2
         .select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn"))
         .filter("rn < 0")
-      if (df3.count() == 0) {
+      if (df3.rdd.count() == 0) {
        throw new Exception(s"df3 should has 0 row.")
      }
    } finally {
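The regression test relies on rowNumber() numbering rows from 1 within each partition, so "rn < 0" can only return rows when the window column is mis-bound. A hypothetical standalone version of the same query, assuming Spark 1.5 public APIs and a HiveContext (window functions required one at the time); it uses local mode, where driver and executors share one JVM, so it exercises the API but cannot reproduce the original cluster-mode bug:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.rowNumber
    import org.apache.spark.sql.hive.HiveContext

    object WindowSanityCheck {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("window-sketch").setMaster("local[2]"))
        val sqlContext = new HiveContext(sc)

        val df = sqlContext.range(1000)
        val df2 = df.select((df("id") % 10).alias("A"), (df("id") / 10).alias("B"))
        val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
        // rowNumber() starts at 1 per partition, so "rn < 0" must be empty.
        val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn"))
          .filter("rn < 0")
        assert(df3.rdd.count() == 0)
        sc.stop()
      }
    }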
From 3aec389d82363975fc4e8a17e8bf69474a70c988 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 12 Oct 2015 23:38:11 -0700
Subject: [PATCH 6/6] Update HiveSparkSubmitSuite.scala

---
 .../spark/sql/hive/HiveSparkSubmitSuite.scala      | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 584b7ca978d3..0102a58d406e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -343,22 +343,17 @@ object SPARK_11009 extends QueryTest {
     val sparkContext = new SparkContext(
       new SparkConf()
         .set("spark.ui.enabled", "false")
-        .set("spark.sql.shuffle.partitions", "3"))
+        .set("spark.sql.shuffle.partitions", "100"))
 
     val hiveContext = new TestHiveContext(sparkContext)
     sqlContext = hiveContext
 
     try {
-      val df = sqlContext.range(1<<20)
+      val df = sqlContext.range(1 << 20)
       val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
       val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
-      val df3 =
-        df2
-          .select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn"))
-          .filter("rn < 0")
-      if (df3.rdd.count() == 0) {
-        throw new Exception(s"df3 should has 0 row.")
-      }
+      val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
+      assert(df3.rdd.count() === 0)
     } finally {
       sparkContext.stop()
     }
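The suite submits against local-cluster[2,1,1024] rather than local[n] because the id collision only appears when executors run in JVMs separate from the driver, each with its own expression-id counter. A hedged sketch of what that master string sets up; using it outside spark-submit additionally assumes a local Spark build reachable via SPARK_HOME:

    import org.apache.spark.{SparkConf, SparkContext}

    // local-cluster[2,1,1024]: 2 worker processes, 1 core and 1024 MB each.
    // Each worker is a separate JVM, unlike local[n], where tasks run inside
    // the driver's JVM and share its expression-id counter, hiding the bug.
    val conf = new SparkConf()
      .setAppName("SPARK-11009-sketch")
      .setMaster("local-cluster[2,1,1024]")
    val sc = new SparkContext(conf)
    try {
      // distributed work goes here
    } finally {
      sc.stop()
    }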