Commit beb1ff9
Author: mingbo_pb

SPARK-27406: UnsafeArrayData serialization breaks when two machines have different Oops sizes

1 parent 18b36ee · commit beb1ff9
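Background: an UnsafeArrayData instance addresses raw memory through a baseObject reference plus a baseOffset. Before this patch the class relied on default Java serialization, which writes those fields verbatim. For a byte[]-backed array, baseOffset is derived from the JVM's array base offset, and that value depends on the size of ordinary object pointers ("Oops"): a JVM running with compressed oops lays out array headers differently from one running without them, so an offset captured on one machine can point at the wrong bytes on another.

A minimal sketch of the platform dependence (illustrative only, not part of this commit; the class name OopsOffsetDemo is invented):

import java.lang.reflect.Field;
import sun.misc.Unsafe;

public class OopsOffsetDemo {
  public static void main(String[] args) throws Exception {
    // Grab the Unsafe singleton reflectively (it is not publicly constructible).
    Field f = Unsafe.class.getDeclaredField("theUnsafe");
    f.setAccessible(true);
    Unsafe unsafe = (Unsafe) f.get(null);
    // On a typical 64-bit HotSpot JVM this prints 16 with -XX:+UseCompressedOops
    // (the default) and 24 with -XX:-UseCompressedOops. A baseOffset serialized
    // under one setting is therefore wrong under the other.
    System.out.println(unsafe.arrayBaseOffset(byte[].class));
  }
}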

File tree: 2 files changed (+54, -2 lines)


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 38 additions & 1 deletion
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -30,6 +34,8 @@
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
  *
@@ -52,7 +58,7 @@
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 
-public final class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData implements Externalizable {
   public static int calculateHeaderPortionInBytes(int numFields) {
     return (int)calculateHeaderPortionInBytes((long)numFields);
   }
@@ -485,4 +491,35 @@ public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
   public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
     return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
   }
+
+
+  public byte[] getBytes() {
+    if (baseObject instanceof byte[]
+      && baseOffset == Platform.BYTE_ARRAY_OFFSET
+      && (((byte[]) baseObject).length == sizeInBytes)) {
+      return (byte[]) baseObject;
+    } else {
+      byte[] bytes = new byte[sizeInBytes];
+      Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+      return bytes;
+    }
+  }
+
+  @Override
+  public void writeExternal(ObjectOutput out) throws IOException {
+    byte[] bytes = getBytes();
+    out.writeInt(bytes.length);
+    out.writeInt(this.numElements);
+    out.write(bytes);
+  }
+
+  @Override
+  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = in.readInt();
+    this.numElements = in.readInt();
+    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
+    this.baseObject = new byte[sizeInBytes];
+    in.readFully((byte[]) baseObject);
+  }
 }
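The fix avoids shipping any machine-specific address: writeExternal serializes a compact byte image of the array (via getBytes) together with its length and element count, and readExternal rebuilds the instance on a freshly allocated byte[] anchored at the local BYTE_ARRAY_OFFSET. The offset is recomputed on the receiving JVM, never read from the stream.

A round-trip sketch using plain JDK object streams (illustrative; it assumes the Spark Catalyst classes are on the classpath, and the class name RoundTripDemo is invented):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
import org.apache.spark.unsafe.Platform;

public class RoundTripDemo {
  public static void main(String[] args) throws Exception {
    // Build a one-element long array at a deliberately non-zero offset
    // inside a larger buffer, the situation the old code mishandled.
    byte[] backing = new byte[64];
    long offset = Platform.BYTE_ARRAY_OFFSET + 8;
    Platform.putLong(backing, offset, 1L);          // header: numElements = 1
    UnsafeArrayData array = new UnsafeArrayData();
    array.pointTo(backing, offset, 24);             // 8B count + 8B null bits + 8B value
    array.setLong(0, 19285L);

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
      oos.writeObject(array);                       // dispatches to writeExternal
    }
    try (ObjectInputStream ois =
        new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
      UnsafeArrayData copy = (UnsafeArrayData) ois.readObject();  // readExternal
      // The value survives even though the copy lives in a fresh, compact
      // byte[] at this JVM's own base offset.
      System.out.println(copy.getLong(0));          // 19285
    }
  }
}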

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala

Lines changed: 16 additions & 1 deletion
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.util
 
 import java.time.ZoneId
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class UnsafeArraySuite extends SparkFunSuite {
@@ -210,4 +212,17 @@ class UnsafeArraySuite extends SparkFunSuite {
     val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
     assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
   }
+
+  test("unsafe java serialization") {
+    val offset = 32
+    val data = new Array[Byte](1024)
+    Platform.putLong(data, offset, 1)
+    val arrayData = new UnsafeArrayData()
+    arrayData.pointTo(data, offset, data.length)
+    arrayData.setLong(0, 19285)
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData))
+    assert(arrayDataSer.getLong(0) == 19285)
+    assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
+  }
 }
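The regression test reproduces the scenario inside a single JVM: it points an UnsafeArrayData at offset 32 of a 1024-byte buffer, round-trips it through Spark's JavaSerializer, and asserts both that the element value survives and that the deserialized copy is backed by a plain byte[] of exactly sizeInBytes (1024). The second assertion is the crux: it shows readExternal rebuilt the object on a compact local buffer rather than trusting an offset read from the stream.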
