Skip to content

Commit e57d6b5

Browse files
cloud-fan authored and davies committed
[SPARK-9683] [SQL] copy UTF8String when convert unsafe array/map to safe
When we convert unsafe row to safe row, we will do copy if the column is struct or string type. However, the string inside unsafe array/map are not copied, which may cause problems. Author: Wenchen Fan <[email protected]> Closes apache#7990 from cloud-fan/copy and squashes the following commits: c13d1e3 [Wenchen Fan] change test name fe36294 [Wenchen Fan] we should deep copy UTF8String when convert unsafe row to safe row
1 parent 15bd6f3 commit e57d6b5

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2121
import org.apache.spark.sql.types._
22+
import org.apache.spark.unsafe.types.UTF8String
2223

2324
case class FromUnsafe(child: Expression) extends UnaryExpression
2425
with ExpectsInputTypes with CodegenFallback {
@@ -52,6 +53,8 @@ case class FromUnsafe(child: Expression) extends UnaryExpression
5253
}
5354
new GenericArrayData(result)
5455

56+
case StringType => value.asInstanceOf[UTF8String].clone()
57+
5558
case MapType(kt, vt, _) =>
5659
val map = value.asInstanceOf[UnsafeMapData]
5760
val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData]

sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717

1818
package org.apache.spark.sql.execution
1919

20+
import org.apache.spark.rdd.RDD
2021
import org.apache.spark.sql.Row
21-
import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull}
22+
import org.apache.spark.sql.catalyst.InternalRow
23+
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
2224
import org.apache.spark.sql.test.TestSQLContext
25+
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
26+
import org.apache.spark.unsafe.types.UTF8String
2327

2428
class RowFormatConvertersSuite extends SparkPlanTest {
2529

@@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest {
8791
input.map(Row.fromTuple)
8892
)
8993
}
94+
95+
// Regression test for SPARK-9683: strings stored inside unsafe arrays/maps must be
// deep-copied when a row is converted back to the safe format, otherwise every
// buffered row ends up pointing at the same reused unsafe memory.
test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
  SparkPlan.currentContext.set(TestSQLContext)
  val columnType = ArrayType(StringType)
  val inputRows = (1 to 100).map { n =>
    InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(n.toString))))
  }
  val scan = LocalTableScan(Seq(AttributeReference("t", columnType)()), inputRows)

  // Round-trip through the unsafe row format, then let DummyPlan buffer all
  // strings before re-emitting them — stale references would surface here.
  val plan = DummyPlan(ConvertToSafe(ConvertToUnsafe(scan)))
  val observed = plan.execute().collect().map(_.getUTF8String(0).toString)
  assert(observed === (1 to 100).map(_.toString))
}
109+
}
110+
111+
// Test-only plan node: eagerly buffers every string produced by its child before
// re-emitting them as fresh rows. The round-trip only yields correct values when
// the incoming safe rows carry deep copies of their UTF8Strings; rows still
// backed by shared unsafe memory would be clobbered before they are read back.
case class DummyPlan(child: SparkPlan) extends UnaryNode {

  override protected def doExecute(): RDD[InternalRow] = {
    child.execute().mapPartitions { rows =>
      // Materialize every string first, then emit — this forces all UTF8Strings
      // to outlive the iteration over the incoming rows.
      val buffered = rows.map(_.getArray(0).getUTF8String(0)).toBuffer
      buffered.iterator.map(InternalRow(_))
    }
  }

  override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)())
}

0 commit comments

Comments (0)