
Commit 16e4954

[SPARK-29503][SQL] Copy result row from RowWriter in GenerateUnsafeProjection when expr is lambdaFunction in MapObject
Parent: eb8c420

2 files changed: +47 −1 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 14 additions & 1 deletion
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 import org.apache.spark.sql.types._
 
 /**
@@ -301,6 +302,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     // Evaluate all the subexpression.
     val evalSubexpr = ctx.subexprFunctionsCode
 
+
     val writeExpressions = writeExpressionsToBuffer(
       ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
 
@@ -310,8 +312,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         |$evalSubexpr
         |$writeExpressions
       """.stripMargin
+
+    val runInsideLoop = expressions.exists {
+      case e: LambdaVariable => true
+      case _ => false
+    }
+    val extractValueCode = if (runInsideLoop) {
+      s"$rowWriter.getRow().copy()"
+    } else {
+      s"$rowWriter.getRow()"
+    }
+
     // `rowWriter` is declared as a class field, so we can access it directly in methods.
-    ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
+    ExprCode(code, FalseLiteral, JavaCode.expression(extractValueCode, classOf[UnsafeRow]))
   }
 
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
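
Why the fix copies the row: GenerateUnsafeProjection writes every result into a single UnsafeRowWriter whose backing buffer is reused across invocations, so when the projection runs once per element inside a MapObjects loop (i.e. its expressions contain a LambdaVariable) and the caller keeps each returned row, all of the kept rows point at the same buffer. The sketch below is not part of this commit; it assumes Spark's catalyst classes are on the classpath and the object name is made up, but it should show the aliasing that the guarded getRow().copy() avoids.

// Minimal sketch of the buffer-reuse aliasing (not part of this commit).
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

object RowWriterAliasingSketch {
  def main(args: Array[String]): Unit = {
    val writer = new UnsafeRowWriter(1)

    // Simulate a projection evaluated once per collection element, as MapObjects
    // does with a lambda body, while the caller keeps every result row.
    val kept: Seq[UnsafeRow] = (1 to 3).map { value =>
      writer.reset()        // the same backing buffer is reused for every element
      writer.write(0, value)
      writer.getRow()       // without copy(), every kept row aliases that buffer
    }
    println(kept.map(_.getInt(0)))    // all three read the last value written: Vector(3, 3, 3)

    // Copying the result row, as the patched code does when a LambdaVariable is
    // present, gives each kept row its own backing bytes.
    val copied: Seq[UnsafeRow] = (1 to 3).map { value =>
      writer.reset()
      writer.write(0, value)
      writer.getRow().copy()
    }
    println(copied.map(_.getInt(0)))  // Vector(1, 2, 3)
  }
}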

sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala

Lines changed: 33 additions & 0 deletions
@@ -17,9 +17,15 @@
 
 package org.apache.spark.sql
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.ArrayType
 
 /**
  * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
@@ -64,6 +70,33 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession {
     val ds100_5 = Seq(S100_5()).toDS()
     ds100_5.rdd.count
   }
+
+  test("SPARK-29503 nest unsafe struct inside safe array") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+      val exampleDS = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")
+
+      // items: Seq[Int] => items.map { item => Seq(Struct(item)) }
+      val result = exampleDS.select(
+        new Column(MapObjects(
+          (item: Expression) => array(struct(new Column(item))).expr,
+          $"items".expr,
+          exampleDS.schema("items").dataType.asInstanceOf[ArrayType].elementType
+        )) as "items"
+      ).collect()
+
+      def getValueInsideDepth(result: Row, index: Int): Int = {
+        // expected output:
+        // WrappedArray([WrappedArray(WrappedArray([1]), WrappedArray([2]), WrappedArray([3]))])
+        result.getSeq[mutable.WrappedArray[_]](0)(index)(0)
+          .asInstanceOf[GenericRowWithSchema].getInt(0)
+      }
+
+      assert(result.size === 1)
+      assert(getValueInsideDepth(result.head, 0) === 1)
+      assert(getValueInsideDepth(result.head, 1) === 2)
+      assert(getValueInsideDepth(result.head, 2) === 3)
+    }
+  }
 }
 
 class S100(
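
For anyone who wants to poke at the scenario interactively, here is a rough spark-shell sketch of what the new test exercises. It assumes a spark-shell session (so the spark session object and its implicits are in scope) and leans on the internal MapObjects API exactly as the test does; the values in the comments mirror the test's assertions, not additional guarantees.

// Rough spark-shell sketch of the scenario covered by the new test.
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.functions.{array, struct}
import org.apache.spark.sql.types.IntegerType
import spark.implicits._

spark.conf.set("spark.sql.codegen.wholeStage", "false")  // same conf the test pins

val df = Seq(Seq(1, 2, 3)).toDF("items")

// items: Seq[Int] => items.map { item => Seq(Struct(item)) }, as in the test
val mapped = MapObjects(
  (item: Expression) => array(struct(new Column(item))).expr,
  df("items").expr,
  IntegerType)

df.select(new Column(mapped) as "items").collect().foreach(println)
// With the copy in place each inner struct keeps its own element: [1], [2], [3].
// Without it, the structs produced inside the loop could all alias the row
// writer's buffer and show only the last value written.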
