--- a/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -25,6 +25,10 @@
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
 
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
@@ -58,7 +62,7 @@
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 
-public final class UnsafeArrayData extends ArrayData implements Externalizable {
+public final class UnsafeArrayData extends ArrayData implements Externalizable, KryoSerializable {
   public static int calculateHeaderPortionInBytes(int numFields) {
     return (int)calculateHeaderPortionInBytes((long)numFields);
   }
@@ -492,22 +496,9 @@ public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
     return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
   }
 
-
-  public byte[] getBytes() {
-    if (baseObject instanceof byte[]
-      && baseOffset == Platform.BYTE_ARRAY_OFFSET
-      && (((byte[]) baseObject).length == sizeInBytes)) {
-      return (byte[]) baseObject;
-    } else {
-      byte[] bytes = new byte[sizeInBytes];
-      Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
-      return bytes;
-    }
-  }
-
   @Override
   public void writeExternal(ObjectOutput out) throws IOException {
-    byte[] bytes = getBytes();
+    byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
     out.writeInt(bytes.length);
     out.writeInt(this.numElements);
     out.write(bytes);
@@ -522,4 +513,22 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
     this.baseObject = new byte[sizeInBytes];
     in.readFully((byte[]) baseObject);
   }
+
+  @Override
+  public void write(Kryo kryo, Output output) {
+    byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
+    output.writeInt(bytes.length);
+    output.writeInt(this.numElements);
+    output.write(bytes);
+  }
+
+  @Override
+  public void read(Kryo kryo, Input input) {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = input.readInt();
+    this.numElements = input.readInt();
+    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
+    this.baseObject = new byte[sizeInBytes];
+    input.read((byte[]) baseObject);
+  }
 }
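
For context, a minimal round-trip sketch of the new Kryo path, serializing an UnsafeArrayData with Kryo directly rather than through Spark's KryoSerializer wrapper. It assumes a fromPrimitiveArray(long[]) overload exists alongside the double[] one shown above; the buffer size and explicit registration are illustrative:

    import com.esotericsoftware.kryo.Kryo;
    import com.esotericsoftware.kryo.io.Input;
    import com.esotericsoftware.kryo.io.Output;

    // Build a small array backed by an on-heap byte[].
    UnsafeArrayData array = UnsafeArrayData.fromPrimitiveArray(new long[] {19285L});

    Kryo kryo = new Kryo();
    kryo.register(UnsafeArrayData.class);   // registration is mandatory by default in Kryo 5

    Output output = new Output(4096);       // fixed buffer, ample for this sketch
    kryo.writeObject(output, array);        // dispatches to write(Kryo, Output) above
    output.close();

    Input input = new Input(output.toBytes());
    UnsafeArrayData copy = kryo.readObject(input, UnsafeArrayData.class);
    // read(Kryo, Input) rebuilt the array on a fresh byte[] at BYTE_ARRAY_OFFSET.
    assert copy.getLong(0) == 19285L;

Spark's KryoSerializer exercises exactly this dispatch, as the new tests below show.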

--- /dev/null
+++ b/org/apache/spark/sql/catalyst/expressions/UnsafeDataUtils.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * General utilities available for unsafe data
+ */
+final class UnsafeDataUtils {
+
+  private UnsafeDataUtils() {
+  }
+
+  public static byte[] getBytes(Object baseObject, long baseOffset, int sizeInBytes) {
+    if (baseObject instanceof byte[]
+        && baseOffset == Platform.BYTE_ARRAY_OFFSET
+        && (((byte[]) baseObject).length == sizeInBytes)) {
+      return (byte[]) baseObject;
+    }
+    byte[] bytes = new byte[sizeInBytes];
+    Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET,
+      sizeInBytes);
+    return bytes;
+  }
+}
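
The new helper consolidates byte extraction that was previously duplicated in UnsafeRow and UnsafeArrayData and is now shared with UnsafeMapData. A sketch of its two branches, assuming the caller sits in the same package (the class is package-private); buffer sizes are illustrative:

    // Fast path: the base object is exactly the backing byte[] -> returned as-is, no copy.
    byte[] exact = new byte[32];
    byte[] same = UnsafeDataUtils.getBytes(exact, Platform.BYTE_ARRAY_OFFSET, 32);
    assert same == exact;

    // Slow path: the data is a 32-byte slice of a larger buffer -> defensive copy.
    byte[] larger = new byte[64];
    byte[] slice = UnsafeDataUtils.getBytes(larger, Platform.BYTE_ARRAY_OFFSET + 8, 32);
    assert slice != larger && slice.length == 32;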

--- a/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,11 +17,21 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import java.nio.ByteBuffer;
-
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
 import org.apache.spark.sql.catalyst.util.MapData;
 import org.apache.spark.unsafe.Platform;
 
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.nio.ByteBuffer;
+
+import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
+
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
@@ -33,7 +43,7 @@
  * elements, otherwise the behavior is undefined.
  */
 // TODO: Use a more efficient format which doesn't depend on unsafe array.
-public final class UnsafeMapData extends MapData {
+public final class UnsafeMapData extends MapData implements Externalizable, KryoSerializable {
 
   private Object baseObject;
   private long baseOffset;
@@ -123,4 +133,36 @@ public UnsafeMapData copy() {
     mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return mapCopy;
   }
+
+  @Override
+  public void writeExternal(ObjectOutput out) throws IOException {
+    byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
+    out.writeInt(bytes.length);
+    out.write(bytes);
+  }
+
+  @Override
+  public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = in.readInt();
+    this.baseObject = new byte[sizeInBytes];
+    in.readFully((byte[]) baseObject);
+    pointTo(baseObject, baseOffset, sizeInBytes);
+  }
+
+  @Override
+  public void write(Kryo kryo, Output output) {
+    byte[] bytes = UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
+    output.writeInt(bytes.length);
+    output.write(bytes);
+  }
+
+  @Override
+  public void read(Kryo kryo, Input input) {
+    this.baseOffset = BYTE_ARRAY_OFFSET;
+    this.sizeInBytes = input.readInt();
+    this.baseObject = new byte[sizeInBytes];
+    input.read((byte[]) baseObject);
+    pointTo(baseObject, baseOffset, sizeInBytes);
+  }
 }
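
Unlike UnsafeArrayData, the map serializes no element count: pointTo re-reads the leading 8-byte key-array size and rebuilds the key and value arrays from the restored buffer, so sizeInBytes is the only extra field written. A Java-serialization round-trip sketch; `map` stands in for an UnsafeMapData built as in UnsafeMapSuite below, and exception handling is elided:

    import java.io.*;

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
      oos.writeObject(map);                  // Externalizable -> writeExternal above
    }
    try (ObjectInputStream ois =
        new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
      UnsafeMapData copy = (UnsafeMapData) ois.readObject();
      assert copy.numElements() == map.numElements();  // structure restored via pointTo
    }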

--- a/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -541,14 +541,7 @@ public boolean equals(Object other) {
    * Returns the underlying bytes for this UnsafeRow.
    */
   public byte[] getBytes() {
-    if (baseObject instanceof byte[] && baseOffset == Platform.BYTE_ARRAY_OFFSET
-      && (((byte[]) baseObject).length == sizeInBytes)) {
-      return (byte[]) baseObject;
-    } else {
-      byte[] bytes = new byte[sizeInBytes];
-      Platform.copyMemory(baseObject, baseOffset, bytes, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
-      return bytes;
-    }
+    return UnsafeDataUtils.getBytes(baseObject, baseOffset, sizeInBytes);
   }
 
   // This is for debugging

--- a/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
 import java.time.ZoneId
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
@@ -60,6 +60,16 @@ class UnsafeArraySuite extends SparkFunSuite {
   val doubleMultiDimArray = Array(
     Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3))
 
+  val serialArray = {
+    val offset = 32
+    val data = new Array[Byte](1024)
+    Platform.putLong(data, offset, 1)
+    val arrayData = new UnsafeArrayData()
+    arrayData.pointTo(data, offset, data.length)
+    arrayData.setLong(0, 19285)
+    arrayData
+  }
+
   test("read array") {
     val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
       toRow(booleanArray).getArray(0)
@@ -214,14 +224,15 @@
   }
 
   test("unsafe java serialization") {
-    val offset = 32
-    val data = new Array[Byte](1024)
-    Platform.putLong(data, offset, 1)
-    val arrayData = new UnsafeArrayData()
-    arrayData.pointTo(data, offset, data.length)
-    arrayData.setLong(0, 19285)
     val ser = new JavaSerializer(new SparkConf).newInstance()
-    val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(arrayData))
+    val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(serialArray))
     assert(arrayDataSer.getLong(0) == 19285)
     assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
   }
+
+  test("unsafe Kryo serialization") {
+    val ser = new KryoSerializer(new SparkConf).newInstance()
+    val arrayDataSer = ser.deserialize[UnsafeArrayData](ser.serialize(serialArray))
+    assert(arrayDataSer.getLong(0) == 19285)
+    assert(arrayDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
+  }
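
One note on the serialArray fixture above: assuming calculateHeaderPortionInBytes follows the usual layout of an 8-byte element count plus a null bitmap rounded up to whole 8-byte words, the header for a single element is 16 bytes. That is why Platform.putLong(data, offset, 1) stores the element count and setLong(0, 19285) lands 16 bytes past the base offset:

    int numElements = 1;
    long headerInBytes = 8 + ((numElements + 63) / 64) * 8;  // 8B count + 8B null bits = 16
    // elementOffset = baseOffset + headerInBytes, so element 0 sits at offset + 16.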

--- /dev/null
+++ b/org/apache/spark/sql/catalyst/util/UnsafeMapSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData}
+import org.apache.spark.unsafe.Platform
+
+class UnsafeMapSuite extends SparkFunSuite {
+
+  val unsafeMapData = {
+    val offset = 32
+    val keyArraySize = 256
+    val baseObject = new Array[Byte](1024)
+    Platform.putLong(baseObject, offset, keyArraySize)
+
+    val unsafeMap = new UnsafeMapData
+    Platform.putLong(baseObject, offset + 8, 1)
+    val keyArray = new UnsafeArrayData()
+    keyArray.pointTo(baseObject, offset + 8, keyArraySize)
+    keyArray.setLong(0, 19285)
+
+    val valueArray = new UnsafeArrayData()
+    Platform.putLong(baseObject, offset + 8 + keyArray.getSizeInBytes, 1)
+    valueArray.pointTo(baseObject, offset + 8 + keyArray.getSizeInBytes, keyArraySize)
+    valueArray.setLong(0, 19285)
+    unsafeMap.pointTo(baseObject, offset, baseObject.length)
+    unsafeMap
+  }
+
+  test("unsafe java serialization") {
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    val mapDataSer = ser.deserialize[UnsafeMapData](ser.serialize(unsafeMapData))
+    assert(mapDataSer.numElements() == 1)
+    assert(mapDataSer.keyArray().getInt(0) == 19285)
+    assert(mapDataSer.valueArray().getInt(0) == 19285)
+    assert(mapDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
+  }
+
+  test("unsafe Kryo serialization") {
+    val ser = new KryoSerializer(new SparkConf).newInstance()
+    val mapDataSer = ser.deserialize[UnsafeMapData](ser.serialize(unsafeMapData))
+    assert(mapDataSer.numElements() == 1)
+    assert(mapDataSer.keyArray().getInt(0) == 19285)
+    assert(mapDataSer.valueArray().getInt(0) == 19285)
+    assert(mapDataSer.getBaseObject.asInstanceOf[Array[Byte]].length == 1024)
+  }
+}