org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -17,6 +17,10 @@

package org.apache.spark.sql.catalyst.expressions;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
@@ -30,6 +34,8 @@
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;

/**
 * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
 *
@@ -52,7 +58,7 @@
 * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
 */

public final class UnsafeArrayData extends ArrayData {
public final class UnsafeArrayData extends ArrayData implements Externalizable {
Contributor review comment: implement KryoSerializable also

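A minimal sketch of what that suggested Kryo support could look like, mirroring the Externalizable methods added further down in this file. It is not part of this diff: the write/read signatures come from Kryo's KryoSerializable interface, so the class would also have to declare implements KryoSerializable and import com.esotericsoftware.kryo.Kryo, com.esotericsoftware.kryo.KryoSerializable, com.esotericsoftware.kryo.io.Input, and com.esotericsoftware.kryo.io.Output. The helpers it calls (getBytes, calculateHeaderPortionInBytes, BYTE_ARRAY_OFFSET) already exist in this class.

  // Hypothetical sketch (not in this diff): Kryo counterpart of writeExternal/readExternal.
  @Override
  public void write(Kryo kryo, Output output) {
    byte[] bytes = getBytes();
    output.writeInt(bytes.length);
    output.writeInt(this.numElements);
    output.write(bytes);
  }

  @Override
  public void read(Kryo kryo, Input input) {
    this.baseOffset = BYTE_ARRAY_OFFSET;
    this.sizeInBytes = input.readInt();
    this.numElements = input.readInt();
    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
    this.baseObject = new byte[sizeInBytes];
    // readBytes fills the whole destination array, matching readFully below.
    input.readBytes((byte[]) baseObject);
  }
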
  public static int calculateHeaderPortionInBytes(int numFields) {
    return (int)calculateHeaderPortionInBytes((long)numFields);
  }
@@ -485,4 +491,35 @@ public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
  public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
    return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
  }


  /**
   * Returns the bytes holding this array's data: the backing byte[] itself when it contains
   * exactly this array's region, or a copy of that region otherwise.
   */
  public byte[] getBytes() {
    if (baseObject instanceof byte[]
      && baseOffset == Platform.BYTE_ARRAY_OFFSET
      && (((byte[]) baseObject).length == sizeInBytes)) {
      return (byte[]) baseObject;
    } else {
      byte[] bytes = new byte[sizeInBytes];
      Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
      return bytes;
    }
  }

  // Serialize as: backing byte length, element count, then the raw bytes.
  @Override
  public void writeExternal(ObjectOutput out) throws IOException {
    byte[] bytes = getBytes();
    out.writeInt(bytes.length);
    out.writeInt(this.numElements);
    out.write(bytes);
  }

  // Rebuild the array over a freshly allocated byte[]: restore the size and element count
  // first, so the element offset can be recomputed before the bytes are read back.
  @Override
  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    this.baseOffset = BYTE_ARRAY_OFFSET;
    this.sizeInBytes = in.readInt();
    this.numElements = in.readInt();
    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
    this.baseObject = new byte[sizeInBytes];
    in.readFully((byte[]) baseObject);
  }
}
org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.util

import java.time.ZoneId

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class UnsafeArraySuite extends SparkFunSuite {
@@ -210,4 +212,17 @@ class UnsafeArraySuite extends SparkFunSuite {
    val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
    assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
  }

test("unsafe java serialization") {
val offset = 32
val data = new Array[Byte](1024)
Platform.putLong(data, offset, 1)
val arrayData = new UnsafeArrayData()
arrayData.pointTo(data, offset, data.length)
arrayData.setLong(0, 19285)
val ser = new JavaSerializer(new SparkConf).newInstance()
val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData))
assert(arrayDataSer.getLong(0) == 19285)
assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
}
}