BufferHolder.java
@@ -21,24 +21,40 @@
import org.apache.spark.unsafe.Platform;

/**
* A helper class to manage the row buffer when construct unsafe rows.
* A helper class to manage the data buffer for an unsafe row. The data buffer can grow and
* automatically re-point the unsafe row to it.
*
* This class can be used to build a one-pass unsafe row writing program, i.e. data will be written
* to the data buffer directly and no extra copy is needed. There should be only one instance of
* this class per writing program, so that the memory segment/data buffer can be reused. Note that
* for each incoming record, we should call `reset` of the BufferHolder instance before writing the
* record, so that the data buffer can be reused.

Contributor: Could you also comment that we should either call unsafeRow.pointTo() or unsafeRow.setTotalSize()?

*
* Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
* the size of the result row, after writing a record to the buffer. However, we can skip this step
* if the fields of the row are all fixed-length, as the size of the result row is also fixed.
*/
public class BufferHolder {
public byte[] buffer;
public int cursor = Platform.BYTE_ARRAY_OFFSET;
private final UnsafeRow row;
private final int fixedSize;

public BufferHolder() {
this(64);
public BufferHolder(UnsafeRow row) {
this(row, 64);
}

public BufferHolder(int size) {
buffer = new byte[size];
public BufferHolder(UnsafeRow row, int initialSize) {
this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
this.buffer = new byte[fixedSize + initialSize];
this.row = row;
this.row.pointTo(buffer, buffer.length);
}

/**
* Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer.
* Grows the buffer by at least neededSize and points the row to the buffer.
*/
public void grow(int neededSize, UnsafeRow row) {
public void grow(int neededSize) {
final int length = totalSize() + neededSize;
if (buffer.length < length) {
// This will not happen frequently, because the buffer is re-used.
@@ -50,22 +66,12 @@ public void grow(int neededSize, UnsafeRow row) {
Platform.BYTE_ARRAY_OFFSET,
totalSize());
buffer = tmp;
if (row != null) {
row.pointTo(buffer, length * 2);
}
row.pointTo(buffer, buffer.length);
}
}

public void grow(int neededSize) {
grow(neededSize, null);
}

public void reset() {
cursor = Platform.BYTE_ARRAY_OFFSET;
}
public void resetTo(int offset) {
assert(offset <= buffer.length);
cursor = Platform.BYTE_ARRAY_OFFSET + offset;
cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
}

public int totalSize() {
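To make the new contract concrete, here is a minimal hand-written sketch (not part of this patch; the schema, class name, and loop are illustrative) of the write protocol described in the javadoc above: one BufferHolder per writer, a reset before each record, and a setTotalSize call because the schema has a variable-length field.

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.unsafe.types.UTF8String;

public class BufferHolderSketch {
  public static void main(String[] args) {
    // Hypothetical schema: (long id, string name) -- one fixed-length and one variable-length field.
    final UnsafeRow row = new UnsafeRow(2);
    // The holder now owns the buffer and keeps `row` pointed at it, even after growing.
    final BufferHolder holder = new BufferHolder(row);
    final UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);

    final String[] names = {"a", "bb", "ccc"};
    for (int i = 0; i < names.length; i++) {
      holder.reset();               // rewind the cursor to just after the fixed-size region
      writer.zeroOutNullBytes();    // clear the null bits before writing a new record
      writer.write(0, (long) i);                          // fixed-length field
      writer.write(1, UTF8String.fromString(names[i]));   // variable-length field
      // Needed because the schema has a variable-length field; see the javadoc above.
      row.setTotalSize(holder.totalSize());
      // `row` is valid until the next reset(); consume or copy it here.
    }
  }
}
```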
UnsafeRowWriter.java
@@ -26,38 +26,56 @@
import org.apache.spark.unsafe.types.UTF8String;

/**
* A helper class to write data into global row buffer using `UnsafeRow` format,
* used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
* A helper class to write data into global row buffer using `UnsafeRow` format.
*
* It will remember the offset of the row buffer where it starts to write, and move the cursor of
* the row buffer while writing. If new data (the input record if this is the outermost writer, or a
* nested struct if this is an inner writer) comes in, the starting cursor of the row buffer may
* change, so we need to call `UnsafeRowWriter.reset` before writing, to update the
* `startingOffset` and clear out null bits.
*
* Note that if this is the outermost writer, which means we will always write from the very
* beginning of the global row buffer, we don't need to update `startingOffset` and can just call
* `zeroOutNullBytes` before writing new data.
*/
public class UnsafeRowWriter {

private BufferHolder holder;
private final BufferHolder holder;
// The offset of the global buffer where we start to write this row.
private int startingOffset;
private int nullBitsSize;
private UnsafeRow row;
private final int nullBitsSize;
private final int fixedSize;

public void initialize(BufferHolder holder, int numFields) {
public UnsafeRowWriter(BufferHolder holder, int numFields) {
this.holder = holder;
this.startingOffset = holder.cursor;
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
this.fixedSize = nullBitsSize + 8 * numFields;
this.startingOffset = holder.cursor;
}

/**
* Resets the `startingOffset` according to the current cursor of the row buffer, and clears out the
* null bits. This should be called before we write a new nested struct to the row buffer.
*/
public void reset() {
this.startingOffset = holder.cursor;

// grow the global buffer to make sure it has enough space to write fixed-length data.
final int fixedSize = nullBitsSize + 8 * numFields;
holder.grow(fixedSize, row);
holder.grow(fixedSize);
holder.cursor += fixedSize;

// zero-out the null bits region
zeroOutNullBytes();
}

/**
* Clears out null bits. This should be called before we write a new row to row buffer.
*/
public void zeroOutNullBytes() {
for (int i = 0; i < nullBitsSize; i += 8) {
Platform.putLong(holder.buffer, startingOffset + i, 0L);
}
}

public void initialize(UnsafeRow row, BufferHolder holder, int numFields) {
initialize(holder, numFields);
this.row = row;
}

private void zeroOutPaddingBytes(int numBytes) {
if ((numBytes & 0x07) > 0) {
Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
@@ -98,7 +116,7 @@ public void alignToWords(int numBytes) {

if (remainder > 0) {
final int paddingBytes = 8 - remainder;
holder.grow(paddingBytes, row);
holder.grow(paddingBytes);

for (int i = 0; i < paddingBytes; i++) {
Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
@@ -161,7 +179,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
}
} else {
// grow the global buffer before writing data.
holder.grow(16, row);
holder.grow(16);

// zero-out the bytes
Platform.putLong(holder.buffer, holder.cursor, 0L);
@@ -193,7 +211,7 @@ public void write(int ordinal, UTF8String input) {
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);

// grow the global buffer before writing data.
holder.grow(roundedSize, row);
holder.grow(roundedSize);

zeroOutPaddingBytes(numBytes);

@@ -214,7 +232,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) {
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);

// grow the global buffer before writing data.
holder.grow(roundedSize, row);
holder.grow(roundedSize);

zeroOutPaddingBytes(numBytes);

@@ -230,7 +248,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) {

public void write(int ordinal, CalendarInterval input) {
// grow the global buffer before writing data.
holder.grow(16, row);
holder.grow(16);

// Write the months and microseconds fields of Interval to the variable length portion.
Platform.putLong(holder.buffer, holder.cursor, input.months);
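And a sketch of the nested-struct case that `reset` is documented for: an inner writer shares the same BufferHolder and re-anchors itself with `reset()` before writing each struct, while the outer writer records the struct's offset and size. This is hypothetical driver code; `setOffsetAndSize(ordinal, cursor, size)` is assumed to be the helper the generated projection uses for this, and the schema is made up.

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;

public class NestedStructSketch {
  public static void main(String[] args) {
    // Hypothetical schema: a single field of type struct<long, double>.
    final UnsafeRow row = new UnsafeRow(1);
    final BufferHolder holder = new BufferHolder(row);
    final UnsafeRowWriter rowWriter = new UnsafeRowWriter(holder, 1);
    final UnsafeRowWriter structWriter = new UnsafeRowWriter(holder, 2);

    holder.reset();
    rowWriter.zeroOutNullBytes();

    // Remember where the struct starts so the outer row can point at it later.
    final int tmpCursor = holder.cursor;
    structWriter.reset();          // re-anchor the inner writer at the current cursor
    structWriter.write(0, 42L);
    structWriter.write(1, 3.14);
    // Assumed helper (mirroring the generated code): stores the struct's offset
    // and size into field 0 of the outer row.
    rowWriter.setOffsetAndSize(0, tmpCursor, holder.cursor - tmpCursor);

    row.setTotalSize(holder.totalSize());
  }
}
```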
GenerateUnsafeProjection.scala
@@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => false
}

private val rowWriterClass = classOf[UnsafeRowWriter].getName
private val arrayWriterClass = classOf[UnsafeArrayWriter].getName

// TODO: if the nullability of field is correct, we can use it to save null check.
private def writeStructToBuffer(
ctx: CodegenContext,
@@ -73,9 +70,27 @@
row: String,
inputs: Seq[ExprCode],
inputTypes: Seq[DataType],
bufferHolder: String): String = {
bufferHolder: String,
isTopLevel: Boolean = false): String = {
val rowWriterClass = classOf[UnsafeRowWriter].getName
val rowWriter = ctx.freshName("rowWriter")
ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
ctx.addMutableState(rowWriterClass, rowWriter,
s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")

val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
// which means its fixed-size region is always in the same position, so we don't need to call
// `reset` to set up its fixed-size region every time.
if (inputs.map(_.isNull).forall(_ == "false")) {

Contributor: Even if the expression is not nullable, isNull could still be something other than "false" (not optimized yet).

Contributor: Could we pass in the expressions?

// If all fields are not nullable, which means the null bits never change, then we don't
// need to clear them out every time.
""
} else {
s"$rowWriter.zeroOutNullBytes();"

Contributor Author: Here I made a different decision compared to the unsafe Parquet reader. We can clear out the null bits at the beginning and call UnsafeRowWriter.write instead of UnsafeRow.setXXX, which saves one null-bits update. If null values are rare, this should be faster. I'll benchmark it later. cc @nongli

Contributor: Makes sense to me.

}
} else {
s"$rowWriter.reset();"
}

val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
case ((input, dataType), index) =>
@@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
"""

case _ if ctx.isPrimitiveType(dt) =>
s"""
$rowWriter.write($index, ${input.value});
"""

case t: DecimalType =>
s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});"

@@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}

s"""
$rowWriter.initialize($bufferHolder, ${inputs.length});
$resetWriter
${ctx.splitExpressions(row, writeFields)}
""".trim
}
@@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
input: String,
elementType: DataType,
bufferHolder: String): String = {
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
val arrayWriter = ctx.freshName("arrayWriter")
ctx.addMutableState(arrayWriterClass, arrayWriter,
s"this.$arrayWriter = new $arrayWriterClass();")
@@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val exprTypes = expressions.map(_.dataType)

val numVarLenFields = exprTypes.count {

Contributor: Since it's easy to grow the buffer, we don't need this optimization.

Contributor: This is used to avoid calling reset and setTotalSize(), so it's still useful. nvm.

case dt if UnsafeRow.isFixedLength(dt) => false
// TODO: consider large decimal and interval type
case _ => true
}

val result = ctx.freshName("result")
ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
val bufferHolder = ctx.freshName("bufferHolder")

val holder = ctx.freshName("holder")
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
ctx.addMutableState(holderClass, holder,
s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")

val resetBufferHolder = if (numVarLenFields == 0) {
""
} else {
s"$holder.reset();"
}
val updateRowSize = if (numVarLenFields == 0) {
""
} else {
s"$result.setTotalSize($holder.totalSize());"
}

// Evaluate all the subexpressions.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")

val writeExpressions =
writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)

val code =
s"""
$bufferHolder.reset();
$resetBufferHolder
$evalSubexpr
${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}

$result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
$writeExpressions
$updateRowSize
"""
ExprCode(code, "false", result)
}
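For intuition about the fast paths above (`resetBufferHolder`, `updateRowSize`, and `resetWriter` all becoming empty), here is a hypothetical hand-written analogue of what the generated projection reduces to for an all-fixed-length, non-nullable schema: per record, only the field writes remain.

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;

public class FixedLengthFastPathSketch {
  public static void main(String[] args) {
    // Hypothetical schema: (long, double), both non-nullable and fixed-length.
    final UnsafeRow result = new UnsafeRow(2);
    // With no variable-length fields, the generated code sizes the buffer as fixedSize + 0,
    // and it never needs to grow.
    final BufferHolder holder = new BufferHolder(result, 0);
    final UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);

    for (long i = 0; i < 3; i++) {
      // Per record: just the writes. The fixed-size region never moves, the null bits
      // never change, and the total size is constant, so reset(), zeroOutNullBytes(),
      // and setTotalSize() are all skipped.
      writer.write(0, i);
      writer.write(1, i * 1.5);
      // consume `result` here
    }
  }
}
```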