
Commit 3352803

mingbo_pb authored and cloud-fan committed
[SPARK-27406][SQL] UnsafeArrayData serialization breaks when two machi…
This PR is the branch-2.4 version for #24317.

Closes #24324 from pengbo/SPARK-27406-branch-2.4.

Authored-by: mingbo_pb <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 53658ab

File tree

2 files changed: +54 -2 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java

Lines changed: 38 additions & 1 deletion
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -30,6 +34,8 @@
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
  *
@@ -52,7 +58,7 @@
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 
-public final class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData implements Externalizable {
 
   public static int calculateHeaderPortionInBytes(int numFields) {
     return (int)calculateHeaderPortionInBytes((long)numFields);
@@ -523,4 +529,35 @@ public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
   public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
     return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
   }
+
+
+  public byte[] getBytes() {
+    if (baseObject instanceof byte[]
+        && baseOffset == Platform.BYTE_ARRAY_OFFSET
+        && (((byte[]) baseObject).length == sizeInBytes)) {
+      return (byte[]) baseObject;
+    } else {
+      byte[] bytes = new byte[sizeInBytes];
+      Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+      return bytes;
+    }
+  }
+
+  @Override
+  public void writeExternal(ObjectOutput out) throws IOException {
+    byte[] bytes = getBytes();
+    out.writeInt(bytes.length);
+    out.writeInt(this.numElements);
+    out.write(bytes);
+  }
+
+  @Override
+  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = in.readInt();
+    this.numElements = in.readInt();
+    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
+    this.baseObject = new byte[sizeInBytes];
+    in.readFully((byte[]) baseObject);
+  }
 }
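
The new getBytes() method compacts the array before it is written out: when the UnsafeArrayData already owns its entire backing byte[] (on-heap, zero extra offset, exact size) that array is returned as-is; otherwise the sizeInBytes region starting at baseOffset is copied into a fresh array. A minimal standalone sketch of that compaction logic, using a relative offset and plain System.arraycopy in place of Spark's Platform.copyMemory (the method and parameter names here are illustrative, not Spark API):

  // Illustrative sketch of the compaction getBytes() performs; `base`,
  // `offsetIntoBase`, and `sizeInBytes` are hypothetical stand-ins for
  // baseObject/baseOffset/sizeInBytes when the data lives on-heap.
  static byte[] compact(byte[] base, int offsetIntoBase, int sizeInBytes) {
    if (offsetIntoBase == 0 && base.length == sizeInBytes) {
      return base;                      // already a tight fit: no copy needed
    }
    byte[] bytes = new byte[sizeInBytes];
    System.arraycopy(base, offsetIntoBase, bytes, 0, sizeInBytes);
    return bytes;
  }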

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala

Lines changed: 16 additions & 1 deletion
@@ -17,11 +17,13 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class UnsafeArraySuite extends SparkFunSuite {
@@ -204,4 +206,17 @@ class UnsafeArraySuite extends SparkFunSuite {
     val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
     assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
   }
+
+  test("unsafe java serialization") {
+    val offset = 32
+    val data = new Array[Byte](1024)
+    Platform.putLong(data, offset, 1)
+    val arrayData = new UnsafeArrayData()
+    arrayData.pointTo(data, offset, data.length)
+    arrayData.setLong(0, 19285)
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData))
+    assert(arrayDataSer.getLong(0) == 19285)
+    assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
+  }
 }
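
The JavaSerializer used in this test wraps java.io object streams, so the round trip it exercises is, at bottom, the standard Externalizable protocol that the writeExternal/readExternal pair above implements: write a length prefix plus the raw bytes, then rebuild an on-heap buffer with readFully. A self-contained sketch of that protocol with a hypothetical BytesHolder class (not Spark code):

  import java.io.*;

  // Hypothetical class (not Spark code) using the same Externalizable
  // protocol as UnsafeArrayData: a length prefix followed by raw bytes.
  public final class BytesHolder implements Externalizable {
    private byte[] bytes = new byte[0];

    public BytesHolder() {}                     // required by Externalizable

    public BytesHolder(byte[] bytes) { this.bytes = bytes; }

    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
      out.writeInt(bytes.length);               // length prefix, as in the diff above
      out.write(bytes);
    }

    @Override
    public void readExternal(ObjectInput in) throws IOException {
      bytes = new byte[in.readInt()];           // rebuild an on-heap buffer
      in.readFully(bytes);                      // loop until every byte is read
    }

    public static void main(String[] args) throws Exception {
      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
        oos.writeObject(new BytesHolder(new byte[] {1, 2, 3}));
      }
      try (ObjectInputStream ois =
          new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
        BytesHolder back = (BytesHolder) ois.readObject();
        System.out.println(back.bytes.length);  // prints 3
      }
    }
  }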
