Skip to content

Commit f83e291

Browse files
committed
MutableProjection should not cache content from the input row
1 parent a91ab70 commit f83e291

File tree

11 files changed

+188
-69
lines changed

11 files changed

+188
-69
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
package org.apache.spark.sql.catalyst
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21+
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
2122
import org.apache.spark.sql.types.{DataType, StructType}
23+
import org.apache.spark.unsafe.types.UTF8String
2224

2325
/**
2426
* An abstract class for rows used internally in Spark SQL, which only contains the columns as
@@ -73,4 +75,21 @@ object InternalRow {
7375

7476
/** Returns an empty [[InternalRow]]. */
7577
val empty = apply()
78+
79+
/**
 * Returns a copy of `value` when it is string/struct/array/map data, i.e. a [[UTF8String]],
 * an [[InternalRow]], an [[ArrayData]] or a [[MapData]].
 *
 * Any other value (including null) is handed back unchanged.
 */
def copyValue(value: Any): Any = value match {
  case str: UTF8String => str.clone()
  case row: InternalRow => row.copy()
  case array: ArrayData => array.copy()
  case map: MapData => map.copy()
  // Nothing to copy for the remaining values; null also ends up here.
  case _ => value
}
7695
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,9 @@ final class SpecificMutableRow(val values: Array[MutableValue])
225225
val newValues = new Array[Any](values.length)
226226
var i = 0
227227
while (i < values.length) {
228-
newValues(i) = values(i).boxed
228+
newValues(i) = InternalRow.copyValue(values(i).boxed)
229229
i += 1
230230
}
231-
232231
new GenericInternalRow(newValues)
233232
}
234233

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ abstract class Collect extends ImperativeAggregate {
6565
}
6666

6767
override def update(b: MutableRow, input: InternalRow): Unit = {
68-
buffer += child.eval(input)
68+
buffer += InternalRow.copyValue(child.eval(input))
6969
}
7070

7171
override def merge(buffer: MutableRow, input: InternalRow): Unit = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
313313
* Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
314314
*
315315
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
316+
*
317+
* Note that, the input row may be produced by unsafe projection and it may not be safe to cache
318+
* some fields of the input row, as the values can be changed unexpectedly.
316319
*/
317320
def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit
318321

@@ -322,6 +325,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
322325
*
323326
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
324327
* Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
328+
*
329+
* Note that, the input row may be produced by unsafe projection and it may not be safe to cache
330+
* some fields of the input row, as the values can be changed unexpectedly.
325331
*/
326332
def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
327333
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ class CodegenContext {
276276
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
277277
// The UTF8String may come from UnsafeRow, otherwise clone is cheap (re-use the bytes)
278278
case StringType => s"$row.update($ordinal, $value.clone())"
279+
// InternalRow, ArrayData and MapData may come from UnsafeRow, so we should copy it to avoid
280+
// keeping a "pointer" to a memory region which is out of control from the updated row.
281+
case _: StructType => s"$row.update($ordinal, $value.copy())"
282+
case _: ArrayType => s"$row.update($ordinal, $value.copy())"
283+
case _: MapType => s"$row.update($ordinal, $value.copy())"
279284
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
280285
case _ => s"$row.update($ordinal, $value)"
281286
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,15 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow
230230

231231
override def numFields: Int = values.length
232232

233-
override def copy(): GenericInternalRow = this
233+
/**
 * Builds a new [[GenericInternalRow]] whose fields are produced by passing every field of this
 * row through `InternalRow.copyValue`.
 */
override def copy(): GenericInternalRow = {
  val fieldCount = values.length
  val copied = new Array[Any](fieldCount)
  var idx = 0
  while (idx < fieldCount) {
    copied(idx) = InternalRow.copyValue(values(idx))
    idx += 1
  }
  new GenericInternalRow(copied)
}
234242
}
235243

236244
class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow {
@@ -249,5 +257,13 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI
249257

250258
override def update(i: Int, value: Any): Unit = { values(i) = value }
251259

252-
override def copy(): InternalRow = new GenericInternalRow(values.clone())
260+
/**
 * Returns a [[GenericInternalRow]] holding the fields of this row, each one passed through
 * `InternalRow.copyValue` first.
 */
override def copy(): InternalRow = {
  val fieldCount = values.length
  val copied = new Array[Any](fieldCount)
  var idx = 0
  while (idx < fieldCount) {
    copied(idx) = InternalRow.copyValue(values(idx))
    idx += 1
  }
  new GenericInternalRow(copied)
}
253269
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
4949

5050
def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray))
5151

52-
override def copy(): ArrayData = new GenericArrayData(array.clone())
52+
/**
 * Returns a new [[GenericArrayData]] whose elements are the elements of this array, each one
 * passed through `InternalRow.copyValue`.
 */
override def copy(): ArrayData = {
  val size = array.length
  val copied = new Array[Any](size)
  var idx = 0
  while (idx < size) {
    copied(idx) = InternalRow.copyValue(array(idx))
    idx += 1
  }
  new GenericArrayData(copied)
}
5361

5462
override def numElements(): Int = array.length
5563

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import scala.collection._
21+
22+
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.sql.catalyst.InternalRow
24+
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
25+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
26+
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
27+
import org.apache.spark.unsafe.types.UTF8String
28+
29+
/**
 * Tests for complex data containers: inequality of map data instances, and independence of the
 * results of `copy()` on generic/specific rows and generic arrays from the data they were
 * copied from.
 */
class ComplexDataSuite extends SparkFunSuite {

  test("inequality tests for MapData") {
    def utf8(s: String): UTF8String = UTF8String.fromString(s)

    // Two pairs of maps; within each pair the contents are the same.
    val oneEntryA = Map(utf8("key1") -> 1)
    val twoEntriesA = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val oneEntryB = Map(utf8("key1") -> 1)
    val twoEntriesB = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val arrayOneA = ArrayBasedMapData(oneEntryA.toMap)
    val arrayTwoA = ArrayBasedMapData(twoEntriesA.toMap)
    val arrayOneB = ArrayBasedMapData(oneEntryB.toMap)
    val arrayTwoB = ArrayBasedMapData(twoEntriesB.toMap)
    assert(arrayOneA !== arrayOneB)
    assert(arrayTwoA !== arrayTwoB)

    // UnsafeMapData
    val toUnsafe =
      UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val holder = new GenericMutableRow(1)
    // Converts through a single-column row and copies the map out of the projected row.
    def asUnsafeMap(data: ArrayBasedMapData): UnsafeMapData = {
      holder.update(0, data)
      toUnsafe.apply(holder).getMap(0).copy
    }
    assert(asUnsafeMap(arrayOneA) !== asUnsafeMap(arrayOneB))
    assert(asUnsafeMap(arrayTwoA) !== asUnsafeMap(arrayTwoB))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val projection = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val projected = projection.apply(InternalRow(UTF8String.fromString("a")))

    val source = new GenericInternalRow(Array[Any](projected.getUTF8String(0)))
    val snapshot = source.copy()
    assert(snapshot.getString(0) == "a")
    // Project another row; the copy taken above must not observe the new value.
    projection.apply(InternalRow(UTF8String.fromString("b")))
    assert(snapshot.getString(0) == "a")
  }

  test("GenericMutableRow.copy return a new instance that is independent from the old one") {
    val projection = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val projected = projection.apply(InternalRow(UTF8String.fromString("a")))

    val source = new GenericMutableRow(Array[Any](projected.getUTF8String(0)))
    val snapshot = source.copy()
    assert(snapshot.getString(0) == "a")
    // Project another row; the copy taken above must not observe the new value.
    projection.apply(InternalRow(UTF8String.fromString("b")))
    assert(snapshot.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val projection = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val projected = projection.apply(InternalRow(UTF8String.fromString("a")))

    val source = new SpecificMutableRow(Seq(StringType))
    source.update(0, projected.getUTF8String(0))
    val snapshot = source.copy()
    assert(snapshot.getString(0) == "a")
    // Project another row; the copy taken above must not observe the new value.
    projection.apply(InternalRow(UTF8String.fromString("b")))
    assert(snapshot.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val projection = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val projected = projection.apply(InternalRow(UTF8String.fromString("a")))

    val source = new GenericArrayData(Array[Any](projected.getUTF8String(0)))
    val snapshot = source.copy()
    assert(snapshot.getUTF8String(0).toString == "a")
    // Project another row; the copied array must not observe the new value.
    projection.apply(InternalRow(UTF8String.fromString("b")))
    assert(snapshot.getUTF8String(0).toString == "a")
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala

Lines changed: 0 additions & 57 deletions
This file was deleted.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,23 @@ class GeneratedProjectionSuite extends SparkFunSuite {
122122
assert(unsafe1 === unsafe3)
123123
assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7))
124124
}
125+
126+
test("MutableProjection should not cache content from the input row") {
  // Single-column schema used by both projections: one struct column with an int field "i".
  val structOfInt = new StructType().add("i", IntegerType)

  val toMutable = GenerateMutableProjection.generate(Seq(BoundReference(0, structOfInt, true)))
  val target = new GenericMutableRow(1)
  toMutable.target(target)

  val toUnsafe = GenerateUnsafeProjection.generate(Seq(BoundReference(0, structOfInt, true)))
  val firstInput = toUnsafe.apply(InternalRow(InternalRow(1)))

  toMutable.apply(firstInput)
  assert(target.getStruct(0, 1).getInt(0) == 1)

  // Running the unsafe projection again changes the row that was fed to the mutable
  // projection; the target row must still hold the value it was given before.
  toUnsafe.apply(InternalRow(InternalRow(2)))
  assert(target.getStruct(0, 1).getInt(0) == 1)
}
125144
}

0 commit comments

Comments
 (0)