Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.unsafe.types.UTF8String

/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
Expand Down Expand Up @@ -73,4 +75,21 @@ object InternalRow {

/** Returns an empty [[InternalRow]]. */
val empty = apply()

/**
* Copies the given value if it's string/struct/array/map type.
*/
def copyValue(value: Any): Any = {
if (value.isInstanceOf[UTF8String]) {
value.asInstanceOf[UTF8String].clone()
} else if (value.isInstanceOf[InternalRow]) {
value.asInstanceOf[InternalRow].copy()
} else if (value.isInstanceOf[ArrayData]) {
value.asInstanceOf[ArrayData].copy()
} else if (value.isInstanceOf[MapData]) {
value.asInstanceOf[MapData].copy()
} else {
value
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,9 @@ final class SpecificMutableRow(val values: Array[MutableValue])
val newValues = new Array[Any](values.length)
var i = 0
while (i < values.length) {
newValues(i) = values(i).boxed
newValues(i) = InternalRow.copyValue(values(i).boxed)
i += 1
}

new GenericInternalRow(newValues)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ abstract class Collect extends ImperativeAggregate {
}

override def update(b: MutableRow, input: InternalRow): Unit = {
buffer += child.eval(input)
buffer += InternalRow.copyValue(child.eval(input))
}

override def merge(buffer: MutableRow, input: InternalRow): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
* Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
*
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
*
* Note that, the input row may be produced by unsafe projection and it may not be safe to cache
* some fields of the input row, as the values can be changed unexpectedly.
*/
def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit

Expand All @@ -322,6 +325,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
*
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
* Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
*
* Note that, the input row may be produced by unsafe projection and it may not be safe to cache
* some fields of the input row, as the values can be changed unexpectedly.
*/
def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ class CodegenContext {
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
// The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
case StringType => s"$row.update($ordinal, $value.clone())"
// InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy it to avoid
// keeping a "pointer" to a memory region which is out of control from the updated row.
case _: StructType => s"$row.update($ordinal, $value.copy())"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If $row could do the copy inside update, then we do need to do the copy here, right?

Maybe it's time to check all the MutableRow, MutableProjection, to see where is the best place to do the copy.

case _: ArrayType => s"$row.update($ordinal, $value.copy())"
case _: MapType => s"$row.update($ordinal, $value.copy())"
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
case _ => s"$row.update($ordinal, $value)"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,15 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow

override def numFields: Int = values.length

override def copy(): GenericInternalRow = this
override def copy(): GenericInternalRow = {
val newValues = new Array[Any](values.length)
var i = 0
while (i < values.length) {
newValues(i) = InternalRow.copyValue(values(i))
i += 1
}
new GenericInternalRow(newValues)
}
}

class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericInternalRow {
Expand All @@ -249,5 +257,13 @@ class GenericMutableRow(values: Array[Any]) extends MutableRow with BaseGenericI

override def update(i: Int, value: Any): Unit = { values(i) = value }

override def copy(): InternalRow = new GenericInternalRow(values.clone())
override def copy(): InternalRow = {
val newValues = new Array[Any](values.length)
var i = 0
while (i < values.length) {
newValues(i) = InternalRow.copyValue(values(i))
i += 1
}
new GenericInternalRow(newValues)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {

def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray))

override def copy(): ArrayData = new GenericArrayData(array.clone())
override def copy(): ArrayData = {
val newValues = new Array[Any](array.length)
var i = 0
while (i < array.length) {
newValues(i) = InternalRow.copyValue(array(i))
i += 1
}
new GenericArrayData(newValues)
}

override def numElements(): Int = array.length

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers {
externalRow should be theSameInstanceAs externalRow.copy()
}

it("copy should return same ref for internal rows") {
internalRow should be theSameInstanceAs internalRow.copy()
}

it("toSeq should not expose internal state for external rows") {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also fix the external row? Why copy should return same ref?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy returned the same ref because it is supposed to be immutable. See #10553 for more context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah got it, but it's not true for internal row right? It can be mutable so it's safe to remove this test.

Copy link
Contributor

@hvanhovell hvanhovell Sep 15, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well MutableRow is mutable, so it shouldn't hold for those. The only exception is GenericInternalRow.

That being said, I don't mind if you remove/modify the test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as InternalRow can have mutable implementation, InternalRow is not immutable anymore, because it can have a struct field, whose value can be a MutableRow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline: I don't see the point of having an immutable GenericInternalRow if we cannot guarantee its immutability. We could just make every InternalRow a mutable one, and simplify the class structure in the process. I am not sure if we should make that part of the current PR though.

val modifiedValues = modifyValues(externalRow.toSeq)
externalRow.toSeq should not equal modifiedValues
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ComplexDataSuite extends SparkFunSuite {

test("inequality tests for MapData") {
def u(str: String): UTF8String = UTF8String.fromString(str)

// test data
val testMap1 = Map(u("key1") -> 1)
val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
val testMap3 = Map(u("key1") -> 1)
val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

// ArrayBasedMapData
val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
assert(testArrayMap1 !== testArrayMap3)
assert(testArrayMap2 !== testArrayMap4)

// UnsafeMapData
val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
val row = new GenericMutableRow(1)
def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
row.update(0, map)
val unsafeRow = unsafeConverter.apply(row)
unsafeRow.getMap(0).copy
}
assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
}

test("GenericInternalRow.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a")))

val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
val copiedGenericRow = genericRow.copy()
assert(copiedGenericRow.getString(0) == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied internal row should not be changed externally.
assert(copiedGenericRow.getString(0) == "a")
}

test("GenericMutableRow.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a")))

val mutableRow = new GenericMutableRow(Array[Any](unsafeRow.getUTF8String(0)))
val copiedMutableRow = mutableRow.copy()
assert(copiedMutableRow.getString(0) == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied internal row should not be changed externally.
assert(copiedMutableRow.getString(0) == "a")
}

test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a")))

val mutableRow = new SpecificMutableRow(Seq(StringType))
mutableRow(0) = unsafeRow.getUTF8String(0)
val copiedMutableRow = mutableRow.copy()
assert(copiedMutableRow.getString(0) == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied internal row should not be changed externally.
assert(copiedMutableRow.getString(0) == "a")
}

test("GenericArrayData.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(UTF8String.fromString("a")))

val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
val copiedGenericArray = genericArray.copy()
assert(copiedGenericArray.getUTF8String(0).toString == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied array data should not be changed externally.
assert(copiedGenericArray.getUTF8String(0).toString == "a")
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,23 @@ class GeneratedProjectionSuite extends SparkFunSuite {
assert(unsafe1 === unsafe3)
assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7))
}

test("MutableProjection should not cache content from the input row") {
val mutableProj = GenerateMutableProjection.generate(
Seq(BoundReference(0, new StructType().add("i", IntegerType), true)))
val mutableRow = new GenericMutableRow(1)
mutableProj.target(mutableRow)

val unsafeProj = GenerateUnsafeProjection.generate(
Seq(BoundReference(0, new StructType().add("i", IntegerType), true)))
val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(1)))

mutableProj.apply(unsafeRow)
assert(mutableRow.getStruct(0, 1).getInt(0) == 1)

// Even if the input row of the mutable projection has been changed, the target mutable row
// should keep same.
unsafeProj.apply(InternalRow(InternalRow(2)))
assert(mutableRow.getStruct(0, 1).getInt(0) == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,6 @@ class SortBasedAggregationIterator(
// The aggregation buffer used by the sort-based aggregation.
private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer

// This safe projection is used to turn the input row into safe row. This is necessary
// because the input row may be produced by unsafe projection in child operator and all the
// produced rows share one byte array. However, when we update the aggregate buffer according to
// the input row, we may cache some values from input row, e.g. `Max` will keep the max value from
// input row via MutableProjection, `CollectList` will keep all values in an array via
// ImperativeAggregate framework. These values may get changed unexpectedly if the underlying
// unsafe projection update the shared byte array. By applying a safe projection to the input row,
// we can cut down the connection from input row to the shared byte array, and thus it's safe to
// cache values from input row while updating the aggregation buffer.
private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

protected def initialize(): Unit = {
if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer)
Expand All @@ -119,7 +108,7 @@ class SortBasedAggregationIterator(
// We create a variable to track if we see the next group.
var findNextPartition = false
// firstRowInNextGroup is the first row of this group. We first process it.
processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup))
processRow(sortBasedAggregationBuffer, firstRowInNextGroup)

// The search will stop when we see the next group or there is no
// input row left in the iter.
Expand All @@ -130,7 +119,7 @@ class SortBasedAggregationIterator(

// Check if the current row belongs the current input row.
if (currentGroupingKey == groupingKey) {
processRow(sortBasedAggregationBuffer, safeProj(currentRow))
processRow(sortBasedAggregationBuffer, currentRow)
} else {
// We find a new group.
findNextPartition = true
Expand Down