Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.codegen;

import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
Expand Down Expand Up @@ -259,4 +260,158 @@ public void write(int ordinal, CalendarInterval input) {
// move the cursor forward.
holder.cursor += 16;
}

private void writePrimitiveArray(Object input, int offset, int elementSize, int length) {
Platform.copyMemory(input, offset, holder.buffer, startingOffset + headerInBytes, elementSize * length);
}

public void writePrimitiveBooleanArray(ArrayData arrayData) {
boolean[] input = arrayData.toBooleanArray();
int length = input.length;
int offset = Platform.BYTE_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 1, length);
}

public void writePrimitiveByteArray(ArrayData arrayData) {
byte[] input = arrayData.toByteArray();
int length = input.length;
int offset = Platform.BYTE_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 1, length);
}

public void writePrimitiveShortArray(ArrayData arrayData) {
short[] input = arrayData.toShortArray();
int length = input.length;
int offset = Platform.SHORT_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 2, length);
}

public void writePrimitiveIntArray(ArrayData arrayData) {
int[] input = arrayData.toIntArray();
int length = input.length;
int offset = Platform.INT_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 4, length);
}

public void writePrimitiveLongArray(ArrayData arrayData) {
long[] input = arrayData.toLongArray();
int length = input.length;
int offset = Platform.LONG_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 8, length);
}

public void writePrimitiveFloatArray(ArrayData arrayData) {
float[] input = arrayData.toFloatArray();
int length = input.length;
int offset = Platform.FLOAT_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 4, length);
}

public void writePrimitiveDoubleArray(ArrayData arrayData) {
double[] input = arrayData.toDoubleArray();
int length = input.length;
int offset = Platform.DOUBLE_ARRAY_OFFSET;
writePrimitiveArray(input, offset, 8, length);
}

/** uncomment this if SPARK-16043 is merged

public void writePrimitiveBooleanArray(ArrayData arrayData) {
if (arrayData instanceof GenericBooleanArrayData) {
boolean[] input = ((GenericBooleanArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.BOOLEAN_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putBoolean(holder.buffer, holder.cursor + i, arrayData.getBoolean(i));
}
}
}

public void writePrimitiveByteArray(ArrayData arrayData) {
if (arrayData instanceof GenericByteArrayData) {
byte[] input = ((GenericByteArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putByte(holder.buffer, holder.cursor + i, arrayData.getByte(i));
}
}
}

public void writePrimitiveShortArray(ArrayData arrayData) {
if (arrayData instanceof GenericShortArrayData) {
short[] input = ((GenericShortArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.SHORT_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putShort(holder.buffer, holder.cursor + i, arrayData.getShort(i));
}
}
}

public void writePrimitiveIntArray(ArrayData arrayData) {
if (arrayData instanceof GenericIntArrayData) {
int[] input = ((GenericIntArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.INT_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putInt(holder.buffer, holder.cursor + i, arrayData.getInt(i));
}
}
}

public void writePrimitiveLongArray(ArrayData arrayData) {
if (arrayData instanceof GenericLongArrayData) {
long[] input = ((GenericLongArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.LONG_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putLong(holder.buffer, holder.cursor + i, arrayData.getLong(i));
}
}
}

public void writePrimitiveFloatArray(ArrayData arrayData) {
if (arrayData instanceof GenericFloatArrayData) {
float[] input = ((GenericFloatArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.FLOAT_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putFloat(holder.buffer, holder.cursor + i, arrayData.getFloat(i));
}
}
}

public void writePrimitiveDoubleArray(ArrayData arrayData) {
if (arrayData instanceof GenericDoubleArrayData) {
double[] input = ((GenericDoubleArrayData)arrayData).primitiveArray();
int length = input.length;
Platform.copyMemory(input, Platform.DOUBLE_ARRAY_OFFSET,
holder.buffer, startingOffset + headerInBytes, length);
} else {
int length = arrayData.numElements();
for (int i = 0; i < length; i++) {
Platform.putFloat(holder.buffer, holder.cursor + i, arrayData.getFloat(i));
}
}
}
*/
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case a @ ArrayType(et, _) =>
case a @ ArrayType(et, cn) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
${writeArrayToBuffer(ctx, input.value, et, cn, bufferHolder)}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

Expand Down Expand Up @@ -171,6 +171,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
input: String,
elementType: DataType,
containsNull: Boolean,
bufferHolder: String): String = {
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
val arrayWriter = ctx.freshName("arrayWriter")
Expand Down Expand Up @@ -202,10 +203,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case a @ ArrayType(et, _) =>
case a @ ArrayType(et, cn) =>
s"""
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, element, et, bufferHolder)}
${writeArrayToBuffer(ctx, element, et, cn, bufferHolder)}
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

Expand All @@ -224,22 +225,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}

val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
val typeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
val storeElements = if (containsNull) {
s"""
for (int $index = 0; $index < $numElements; $index++) {
if ($input.isNullAt($index)) {
$arrayWriter.setNull${typeName}($index);
} else {
final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
}
}
"""
} else {
if (ctx.isPrimitiveType(jt)) {
s"$arrayWriter.writePrimitive${typeName}Array($input);"
} else {
s"""
for (int $index = 0; $index < $numElements; $index++) {
final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
}
"""
}
}

s"""
if ($input instanceof UnsafeArrayData) {
${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
} else {
final int $numElements = $input.numElements();
$arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);

for (int $index = 0; $index < $numElements; $index++) {
if ($input.isNullAt($index)) {
$arrayWriter.setNull$primitiveTypeName($index);
} else {
final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
}
}
$storeElements
}
"""
}
Expand Down Expand Up @@ -271,11 +289,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Remember the current cursor so that we can write numBytes of key array later.
final int $tmpCursor = $bufferHolder.cursor;

${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
${writeArrayToBuffer(ctx, keys, keyType, false, bufferHolder)}
// Write the numBytes of key array into the first 8 bytes.
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);

${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
${writeArrayToBuffer(ctx, values, valueType, true, bufferHolder)}
}
"""
}
Expand Down
Loading