Skip to content

Commit 35b7b22

Browse files
author
Davies Liu
committed
fix wrong result of Window function in cluster mode
1 parent 12b7191 commit 35b7b22

File tree

1 file changed

+11
-7
lines changed
  • sql/core/src/main/scala/org/apache/spark/sql/execution

1 file changed

+11
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,17 @@ case class Window(
201201
* This method uses Code Generation. It can only be used on the executor side.
202202
*
203203
* @param expressions unbound ordered function expressions.
204+
* @param attributes output attributes
204205
* @return the final resulting projection.
205206
*/
206207
private[this] def createResultProjection(
207-
expressions: Seq[Expression]): MutableProjection = {
208-
val unboundToAttr = expressions.map {
209-
e => (e, AttributeReference("windowResult", e.dataType, e.nullable)())
210-
}
211-
val unboundToAttrMap = unboundToAttr.toMap
208+
expressions: Seq[Expression],
209+
attributes: Seq[Attribute]): MutableProjection = {
210+
val unboundToAttrMap = expressions.zip(attributes).toMap
212211
val patchedWindowExpression = windowExpression.map(_.transform(unboundToAttrMap))
213212
newMutableProjection(
214213
projectList ++ patchedWindowExpression,
215-
child.output ++ unboundToAttr.map(_._2))()
214+
child.output ++ attributes)()
216215
}
217216

218217
protected override def doExecute(): RDD[InternalRow] = {
@@ -247,12 +246,17 @@ case class Window(
247246
factories(index) = () => createFrameProcessor(frame, functions, ordinal)
248247
}
249248

249+
// AttributeReference can only be created in driver, or the id will not be unique
250+
val outputAttributes = unboundExpressions.map {
251+
e => AttributeReference("windowResult", e.dataType, e.nullable)()
252+
}
253+
250254
// Start processing.
251255
child.execute().mapPartitions { stream =>
252256
new Iterator[InternalRow] {
253257

254258
// Get all relevant projections.
255-
val result = createResultProjection(unboundExpressions)
259+
val result = createResultProjection(unboundExpressions, outputAttributes)
256260
val grouping = if (child.outputsUnsafeRows) {
257261
UnsafeProjection.create(partitionSpec, child.output)
258262
} else {

0 commit comments

Comments
 (0)