20 changes: 10 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -145,11 +145,10 @@ case class Window(
         // Construct the ordering. This is used to compare the result of current value projection
         // to the result of bound value projection. This is done manually because we want to use
         // Code Generation (if it is enabled).
-        val (sortExprs, schema) = exprs.map { case e =>
-          val ref = AttributeReference("ordExpr", e.dataType, e.nullable)()
-          (SortOrder(ref, e.direction), ref)
-        }.unzip
-        val ordering = newOrdering(sortExprs, schema)
+        val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+          SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+        }
+        val ordering = newOrdering(sortExprs, Nil)
         RangeBoundOrdering(ordering, current, bound)
       case RowFrame => RowBoundOrdering(offset)
     }
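Why this hunk matters: the old code built the range-frame ordering out of freshly created AttributeReferences, and an AttributeReference is matched by its exprId, which comes from a per-JVM counter, so references created on an executor need not line up with the ones the plan was built from. A BoundReference resolves purely by ordinal and carries no such identity. A minimal plain-Scala sketch of that distinction (BoundRef and IdRef are hypothetical stand-ins, not Spark's classes):

object OrdinalVsIdRefs {
  type Row = IndexedSeq[Any]

  // Resolves by position: ordinal 0 means "first column" on every JVM.
  case class BoundRef(ordinal: Int) {
    def eval(row: Row): Any = row(ordinal)
  }

  // Resolves through an id from a per-JVM counter; two JVMs can hand the
  // same id to different expressions, so this is unsafe to ship around.
  case class IdRef(name: String, exprId: Long)

  def main(args: Array[String]): Unit = {
    val row: Row = IndexedSeq(42L, "a")
    assert(BoundRef(0).eval(row) == 42L) // stable wherever the row is evaluated
  }
}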
@@ -205,14 +204,15 @@ case class Window(
    */
   private[this] def createResultProjection(
       expressions: Seq[Expression]): MutableProjection = {
-    val unboundToAttr = expressions.map {
-      e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
+    val references = expressions.zipWithIndex.map { case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
     }
-    val unboundToAttrMap = unboundToAttr.toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
     newMutableProjection(
       projectList ++ patchedWindowExpression,
-      child.output ++ unboundToAttr.map(_._2))()
+      child.output)()
Contributor: Is it true that the input row of this projection has more elements than child.output? Maybe this subtle change is not very easy to understand?

Contributor (author): Currently, we put all the window results on the right side of child.output, and patchedWindowExpression points at them.

   }
 
   protected override def doExecute(): RDD[InternalRow] = {
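To make the layout described in the author's reply concrete: the projection's input row is the child's row with one window result appended per window expression, so the i-th result sits at ordinal child.output.size + i, which is exactly where the patched BoundReference points. A plain-Scala sketch (the helper names are hypothetical, not Spark's API):

object ResultRowLayout {
  type Row = IndexedSeq[Any]

  // Input to the result projection: child columns first, window results after.
  def joinedRow(childRow: Row, windowResults: Row): Row = childRow ++ windowResults

  def main(args: Array[String]): Unit = {
    val childRow: Row = IndexedSeq("key", 10L) // child.output.size == 2
    val windowResults: Row = IndexedSeq(1L)    // e.g. one row-number value
    val input = joinedRow(childRow, windowResults)
    assert(input(2 + 0) == 1L) // BoundReference(child.output.size + 0, ...) lands here
  }
}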
sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala

@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.sql.{SQLContext, QueryTest}
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
 import org.apache.spark.sql.types.DecimalType
@@ -107,6 +108,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-11009 fix wrong result of Window function in cluster mode") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_11009.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -320,3 +331,31 @@ object SPARK_9757 extends QueryTest {
     }
   }
 }
+
+object SPARK_11009 extends QueryTest {
+  import org.apache.spark.sql.functions._
+
+  protected var sqlContext: SQLContext = _
+
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.ui.enabled", "false")
+        .set("spark.sql.shuffle.partitions", "100"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    sqlContext = hiveContext
+
+    try {
+      val df = sqlContext.range(1 << 20)
+      val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B"))
+      val ws = Window.partitionBy(df2("A")).orderBy(df2("B"))
+      val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0")
+      assert(df3.rdd.count() === 0)
Contributor: We have to throw an exception here. Otherwise, even if this assertion fails, the test will pass (because we are running an application here).

Contributor: Will fix it while merging.

+    } finally {
+      sparkContext.stop()
+    }
+  }
+}
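Following up on the exchange above about the assert: a sketch of what the replacement check could look like (reusing df3 from the diff; the exact wording of the merged fix is assumed, not quoted):

// Throw an explicit exception so a failure crashes the submitted application
// and is surfaced by runSparkSubmit, per the review comment above.
val invalidRows = df3.rdd.count()
if (invalidRows != 0) {
  throw new RuntimeException(s"df3 should have 0 rows, but got $invalidRows.")
}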