diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index eba95c5c8b90..b28294e56bab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -73,4 +75,21 @@ object InternalRow { /** Returns an empty [[InternalRow]]. */ val empty = apply() + + /** + * Copies the given value if it's string/struct/array/map type. + */ + def copyValue(value: Any): Any = { + if (value.isInstanceOf[UTF8String]) { + value.asInstanceOf[UTF8String].clone() + } else if (value.isInstanceOf[InternalRow]) { + value.asInstanceOf[InternalRow].copy() + } else if (value.isInstanceOf[ArrayData]) { + value.asInstanceOf[ArrayData].copy() + } else if (value.isInstanceOf[MapData]) { + value.asInstanceOf[MapData].copy() + } else { + value + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 61ca7272dfa6..1834fed7ee16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -225,10 +225,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) val newValues = new Array[Any](values.length) var i = 0 while (i < values.length) { - newValues(i) = values(i).boxed + newValues(i) = InternalRow.copyValue(values(i).boxed) i += 1 } - new GenericInternalRow(newValues) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 896ff61b2309..c950222096d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -65,7 +65,7 @@ abstract class Collect extends ImperativeAggregate { } override def update(b: MutableRow, input: InternalRow): Unit = { - buffer += child.eval(input) + buffer += InternalRow.copyValue(child.eval(input)) } override def merge(buffer: MutableRow, input: InternalRow): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index b5c0844fbf31..81db98cb6675 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -313,6 +313,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit @@ -322,6 +325,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 33b9b804fc60..06da145d86e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -277,6 +277,11 @@ class CodegenContext { case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) case StringType => s"$row.update($ordinal, $value.clone())" + // InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy it to avoid + // keeping a "pointer" to a memory region which is out of control from the updated row. + case _: StructType => s"$row.update($ordinal, $value.copy())" + case _: ArrayType => s"$row.update($ordinal, $value.copy())" + case _: MapType => s"$row.update($ordinal, $value.copy())" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) case _ => s"$row.update($ordinal, $value)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 73dceb35ac50..220ceefeee34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -230,7 +230,15 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def numFields: Int = values.length - override def copy(): GenericInternalRow = this + override def copy(): GenericInternalRow = { + val newValues = new Array[Any](values.length) + var i = 0 + while (i < values.length) { + newValues(i) = InternalRow.copyValue(values(i)) + i += 1 + } + new GenericInternalRow(newValues) + } } class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow { @@ -249,5 +257,13 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI override def update(i: Int, value: Any): Unit = { values(i) = value } - override def copy(): InternalRow = new GenericInternalRow(values.clone()) + override def copy(): InternalRow = { + val newValues = new Array[Any](values.length) + var i = 0 + while (i < values.length) { + newValues(i) = InternalRow.copyValue(values(i)) + i += 1 + } + new GenericInternalRow(newValues) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 7ee9581b63af..4a6086dc5709 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + override def copy(): ArrayData = { + val newValues = new Array[Any](array.length) + var i = 0 + while (i < array.length) { + newValues(i) = InternalRow.copyValue(array(i)) + i += 1 + } + new GenericArrayData(newValues) + } override def numElements(): Int = array.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index c9c9599e7f46..25699de33d71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for internal rows") { - internalRow should be theSameInstanceAs internalRow.copy() - } - it("toSeq should not expose internal state for external rows") { val modifiedValues = modifyValues(externalRow.toSeq) externalRow.toSeq should not equal modifiedValues diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexDataSuite.scala new file mode 100644 index 000000000000..5dcb1f177ab5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexDataSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ComplexDataSuite extends SparkFunSuite { + + test("inequality tests for MapData") { + def u(str: String): UTF8String = UTF8String.fromString(str) + + // test data + val testMap1 = Map(u("key1") -> 1) + val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) + val testMap3 = Map(u("key1") -> 1) + val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) + + // ArrayBasedMapData + val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) + val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) + val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) + val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) + assert(testArrayMap1 !== testArrayMap3) + assert(testArrayMap2 !== testArrayMap4) + + // UnsafeMapData + val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) + val row = new GenericMutableRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) + assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) + } + + test("GenericInternalRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a"))) + + val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericRow = genericRow.copy() + assert(copiedGenericRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedGenericRow.getString(0) == "a") + } + + test("GenericMutableRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a"))) + + val mutableRow = new GenericMutableRow(Array[Any](unsafeRow.getUTF8String(0))) + val copiedMutableRow = mutableRow.copy() + assert(copiedMutableRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedMutableRow.getString(0) == "a") + } + + test("SpecificMutableRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a"))) + + val mutableRow = new SpecificMutableRow(Seq(StringType)) + mutableRow(0) = unsafeRow.getUTF8String(0) + val copiedMutableRow = mutableRow.copy() + assert(copiedMutableRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedMutableRow.getString(0) == "a") + } + + test("GenericArrayData.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a"))) + + val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericArray = genericArray.copy() + assert(copiedGenericArray.getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied array data should not be changed externally. + assert(copiedGenericArray.getUTF8String(0).toString == "a") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala deleted file mode 100644 index 0f1264c7c326..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.collection._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class MapDataSuite extends SparkFunSuite { - - test("inequality tests") { - def u(str: String): UTF8String = UTF8String.fromString(str) - - // test data - val testMap1 = Map(u("key1") -> 1) - val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) - val testMap3 = Map(u("key1") -> 1) - val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) - - // ArrayBasedMapData - val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) - val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) - val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) - val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) - assert(testArrayMap1 !== testArrayMap3) - assert(testArrayMap2 !== testArrayMap4) - - // UnsafeMapData - val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericMutableRow(1) - def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { - row.update(0, map) - val unsafeRow = unsafeConverter.apply(row) - unsafeRow.getMap(0).copy - } - assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) - assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index b69b74b4240b..21cdfaabb979 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -122,4 +122,23 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafe1 === unsafe3) assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) } + + test("MutableProjection should not cache content from the input row") { + val mutableProj = GenerateMutableProjection.generate( + Seq(BoundReference(0, new StructType().add("i", IntegerType), true))) + val mutableRow = new GenericMutableRow(1) + mutableProj.target(mutableRow) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", IntegerType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(1))) + + mutableProj.apply(unsafeRow) + assert(mutableRow.getStruct(0, 1).getInt(0) == 1) + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(2))) + assert(mutableRow.getStruct(0, 1).getInt(0) == 1) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index c2b1ef0fe3c2..cdcd28d7fdad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,17 +86,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer - // This safe projection is used to turn the input row into safe row. This is necessary - // because the input row may be produced by unsafe projection in child operator and all the - // produced rows share one byte array. However, when we update the aggregate buffer according to - // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from - // input row via MutableProjection, `CollectList` will keep all values in an array via - // ImperativeAggregate framework. These values may get changed unexpectedly if the underlying - // unsafe projection update the shared byte array. By applying a safe projection to the input row, - // we can cut down the connection from input row to the shared byte array, and thus it's safe to - // cache values from input row while updating the aggregation buffer. - private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -119,7 +108,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -130,7 +119,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, safeProj(currentRow)) + processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true