Skip to content

Commit 3269bd7

Browse files
committed
fix a bug
1 parent 6445289 commit 3269bd7

File tree

7 files changed

+35
-2
lines changed

7 files changed

+35
-2
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
* element in `values` region. We can get the length of this element by subtracting next offset.
3838
* Note that offset can be negative, which means this element is null.
3939
*
40-
* In ghe `values` region, we store the content of elements. As we can get length info, so elements
40+
* In the `values` region, we store the content of elements. As we can get length info, so elements
4141
* can be variable-length.
4242
*
4343
* Note that when we write out this array, we should write out the `numElements` at first 4 bytes,
@@ -315,4 +315,19 @@ public void writeToMemory(Object target, long targetOffset) {
315315
sizeInBytes
316316
);
317317
}
318+
319+
@Override
320+
public UnsafeArrayData copy() {
321+
UnsafeArrayData arrayCopy = new UnsafeArrayData();
322+
final byte[] arrayDataCopy = new byte[sizeInBytes];
323+
PlatformDependent.copyMemory(
324+
baseObject,
325+
baseOffset,
326+
arrayDataCopy,
327+
PlatformDependent.BYTE_ARRAY_OFFSET,
328+
sizeInBytes
329+
);
330+
arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
331+
return arrayCopy;
332+
}
318333
}

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,9 @@ public ArrayData keyArray() {
5858
public ArrayData valueArray() {
5959
return values;
6060
}
61+
62+
@Override
63+
public UnsafeMapData copy() {
64+
return new UnsafeMapData(keys.copy(), values.copy());
65+
}
6166
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
379379
case _: MapType => "UnsafeMapData"
380380
case _ => ctx.javaType(elementType)
381381
}
382+
val copy = elementType match {
383+
// We reuse the buffer during conversion, so we need to copy it before processing the next element.
384+
case _: StructType | _: ArrayType | _: MapType => ".copy()"
385+
case _ => ""
386+
}
382387

383388
s"""
384389
final $unsafeType[] $elements = new $unsafeType[$numElements];
385390
for (int $index = 0; $index < $numElements; $index++) {
386391
${convertedElement.code}
387392
if (!${convertedElement.isNull}) {
388-
$elements[$index] = ${convertedElement.primitive};
393+
$elements[$index] = ${convertedElement.primitive}$copy;
389394
$numBytes += $elementSize;
390395
}
391396
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
2222

2323
override def numElements(): Int = keyArray.numElements()
2424

25+
override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
26+
2527
// We need to check equality of map type in tests.
2628
override def equals(o: Any): Boolean = {
2729
if (!o.isInstanceOf[ArrayBasedMapData]) {

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
2424
abstract class ArrayData extends SpecializedGetters with Serializable {
2525
def numElements(): Int
2626

27+
def copy(): ArrayData
28+
2729
def toBooleanArray(): Array[Boolean] = {
2830
val size = numElements()
2931
val values = new Array[Boolean](size)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class GenericArrayData(private[sql] val array: Array[Any])
2626

2727
override def genericGet(ordinal: Int): Any = array(ordinal)
2828

29+
override def copy(): ArrayData = new GenericArrayData(array.clone())
30+
2931
// todo: Array is invariant in scala, maybe use toSeq instead?
3032
override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T])
3133

sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ abstract class MapData extends Serializable {
2525

2626
def valueArray(): ArrayData
2727

28+
def copy(): MapData
29+
2830
def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = {
2931
val length = numElements()
3032
val keys = keyArray()

0 commit comments

Comments
 (0)