|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.execution |
19 | 19 |
|
| 20 | +import org.apache.spark.rdd.RDD |
20 | 21 | import org.apache.spark.sql.Row |
21 | | -import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} |
| 22 | +import org.apache.spark.sql.catalyst.InternalRow |
| 23 | +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} |
22 | 24 | import org.apache.spark.sql.test.TestSQLContext |
| 25 | +import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType} |
| 26 | +import org.apache.spark.unsafe.types.UTF8String |
23 | 27 |
|
24 | 28 | class RowFormatConvertersSuite extends SparkPlanTest { |
25 | 29 |
|
@@ -87,4 +91,36 @@ class RowFormatConvertersSuite extends SparkPlanTest { |
87 | 91 | input.map(Row.fromTuple) |
88 | 92 | ) |
89 | 93 | } |
| 94 | + |
| 95 | + test("SPARK-9683: we should deep copy UTF8String when convert unsafe row to safe row") { |
| 96 | + SparkPlan.currentContext.set(TestSQLContext) |
| 97 | + val schema = ArrayType(StringType) |
| 98 | + val rows = (1 to 100).map { i => |
| 99 | + InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) |
| 100 | + } |
| 101 | + val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) |
| 102 | + |
| 103 | + val plan = |
| 104 | + DummyPlan( |
| 105 | + ConvertToSafe( |
| 106 | + ConvertToUnsafe(relation))) |
| 107 | + assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +case class DummyPlan(child: SparkPlan) extends UnaryNode { |
| 112 | + |
| 113 | + override protected def doExecute(): RDD[InternalRow] = { |
| 114 | + child.execute().mapPartitions { iter => |
| 115 | + // cache all strings to make sure we have deep copied UTF8String inside incoming |
| 116 | + // safe InternalRow. |
| 117 | + val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] |
| 118 | + iter.foreach { row => |
| 119 | + strings += row.getArray(0).getUTF8String(0) |
| 120 | + } |
| 121 | + strings.map(InternalRow(_)).iterator |
| 122 | + } |
| 123 | + } |
| 124 | + |
| 125 | + override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) |
90 | 126 | } |
0 commit comments