
Commit be375fc

cloud-fan authored and davies committed
[SPARK-12879] [SQL] improve the unsafe row writing framework
As we begin to use the unsafe row writing framework (`BufferHolder` and `UnsafeRowWriter`) in more and more places (`UnsafeProjection`, `UnsafeRowParquetRecordReader`, `GenerateColumnAccessor`, etc.), we should document it better and make it easier to use.

This PR abstracts the technique used in `UnsafeRowParquetRecordReader`: avoid unnecessary operations as much as possible. For example, do not always point the row to the buffer at the end; we only need to update the size of the row. If all fields are of primitive type, we can even skip the row size update. With this abstraction we can apply the technique to more places easily.

A local benchmark shows `UnsafeProjection` is up to 1.7x faster after this PR:

**old version**

```
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz
unsafe projection:              Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-------------------------------------------------------------------------------
single long                          2616.04           102.61         1.00 X
single nullable long                 3032.54            88.52         0.86 X
primitive types                      9121.05            29.43         0.29 X
nullable primitive types            12410.60            21.63         0.21 X
```

**new version**

```
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz
unsafe projection:              Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-------------------------------------------------------------------------------
single long                          1533.34           175.07         1.00 X
single nullable long                 2306.73           116.37         0.66 X
primitive types                      8403.93            31.94         0.18 X
nullable primitive types            12448.39            21.56         0.12 X
```

For a single non-nullable long (the best case), we get about a 1.7x speedup. Even when the field is nullable, we still get a 1.3x speedup. For the other cases the boost is smaller, as the saved operations are only a small proportion of the whole process.

The benchmark code is included in this PR.

Author: Wenchen Fan <[email protected]>

Closes #10809 from cloud-fan/unsafe-projection.
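The per-record savings hinge on the fixed-size region of an `UnsafeRow`: null bits plus one 8-byte slot per field. A minimal sketch of that arithmetic, mirroring `UnsafeRow.calculateBitSetWidthInBytes` as used in this PR (the `FixedSizeRegion` class itself is illustrative, not Spark's):

```java
// Sketch of the UnsafeRow fixed-size region arithmetic that the new
// BufferHolder constructor relies on (simplified model, not a Spark class).
public class FixedSizeRegion {

    // One null bit per field, rounded up to whole 8-byte words
    // (mirrors UnsafeRow.calculateBitSetWidthInBytes).
    static int bitSetWidthInBytes(int numFields) {
        return ((numFields + 63) / 64) * 8;
    }

    // Null-bit words plus one 8-byte slot per field: a fixed-length value
    // is stored inline in its slot; a variable-length value's slot holds a
    // reference into the variable-length region.
    static int fixedSize(int numFields) {
        return bitSetWidthInBytes(numFields) + 8 * numFields;
    }

    public static void main(String[] args) {
        System.out.println(fixedSize(1));   // 16: one null-bit word + 1 slot
        System.out.println(fixedSize(65));  // 536: two null-bit words + 65 slots
    }
}
```

Because this size depends only on the field count, an all-fixed-length row never needs its total size recomputed per record, which is exactly the step the generated code skips.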
1 parent 6f0f1d9 commit be375fc

File tree

7 files changed: +258 −78 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java

Lines changed: 25 additions & 19 deletions
```diff
@@ -21,24 +21,40 @@
 import org.apache.spark.unsafe.Platform;
 
 /**
- * A helper class to manage the row buffer when construct unsafe rows.
+ * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and
+ * automatically re-point the unsafe row to it.
+ *
+ * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written
+ * to the data buffer directly and no extra copy is needed. There should be only one instance of
+ * this class per writing program, so that the memory segment/data buffer can be reused. Note that
+ * for each incoming record, we should call `reset` of BufferHolder instance before write the record
+ * and reuse the data buffer.
+ *
+ * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
+ * the size of the result row, after writing a record to the buffer. However, we can skip this step
+ * if the fields of row are all fixed-length, as the size of result row is also fixed.
  */
 public class BufferHolder {
   public byte[] buffer;
   public int cursor = Platform.BYTE_ARRAY_OFFSET;
+  private final UnsafeRow row;
+  private final int fixedSize;
 
-  public BufferHolder() {
-    this(64);
+  public BufferHolder(UnsafeRow row) {
+    this(row, 64);
   }
 
-  public BufferHolder(int size) {
-    buffer = new byte[size];
+  public BufferHolder(UnsafeRow row, int initialSize) {
+    this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
+    this.buffer = new byte[fixedSize + initialSize];
+    this.row = row;
+    this.row.pointTo(buffer, buffer.length);
   }
 
   /**
-   * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer.
+   * Grows the buffer by at least neededSize and points the row to the buffer.
    */
-  public void grow(int neededSize, UnsafeRow row) {
+  public void grow(int neededSize) {
     final int length = totalSize() + neededSize;
     if (buffer.length < length) {
       // This will not happen frequently, because the buffer is re-used.
@@ -50,22 +66,12 @@ public void grow(int neededSize, UnsafeRow row) {
         Platform.BYTE_ARRAY_OFFSET,
         totalSize());
       buffer = tmp;
-      if (row != null) {
-        row.pointTo(buffer, length * 2);
-      }
+      row.pointTo(buffer, buffer.length);
     }
   }
 
-  public void grow(int neededSize) {
-    grow(neededSize, null);
-  }
-
   public void reset() {
-    cursor = Platform.BYTE_ARRAY_OFFSET;
-  }
-  public void resetTo(int offset) {
-    assert(offset <= buffer.length);
-    cursor = Platform.BYTE_ARRAY_OFFSET + offset;
+    cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
   }
 
   public int totalSize() {
```
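The core of the new `BufferHolder` is that the holder, not the caller, keeps the row pointed at the current buffer: `grow` no longer takes the row as a parameter, and `reset` preserves the fixed-size region. A simplified, self-contained model of that pattern (these are illustrative classes, not Spark's, and `RowView` stands in for `UnsafeRow`):

```java
import java.util.Arrays;

// Simplified model of the grow-and-repoint technique in the new BufferHolder:
// the holder owns the byte buffer and re-points its row whenever the buffer
// is reallocated, so grow() never needs the row passed in at call sites.
public class GrowDemo {

    static class RowView {                 // stands in for UnsafeRow
        byte[] base;
        int sizeInBytes;
        void pointTo(byte[] base, int sizeInBytes) {
            this.base = base;
            this.sizeInBytes = sizeInBytes;
        }
    }

    static class SimpleBufferHolder {
        byte[] buffer;
        int cursor = 0;                    // Spark offsets by Platform.BYTE_ARRAY_OFFSET
        private final RowView row;
        private final int fixedSize;

        SimpleBufferHolder(RowView row, int numFields, int initialSize) {
            // null-bit words + one 8-byte slot per field
            this.fixedSize = ((numFields + 63) / 64) * 8 + 8 * numFields;
            this.buffer = new byte[fixedSize + initialSize];
            this.row = row;
            row.pointTo(buffer, buffer.length);
        }

        void grow(int neededSize) {
            if (buffer.length < cursor + neededSize) {
                // Rare in steady state: the buffer is reused per record and
                // quickly reaches a high-water mark; doubling amortizes copies.
                buffer = Arrays.copyOf(buffer, (cursor + neededSize) * 2);
                row.pointTo(buffer, buffer.length);  // keep the row valid
            }
        }

        // Start a new record: keep the fixed-size region in place and rewind
        // the cursor to the start of the variable-length region.
        void reset() {
            cursor = fixedSize;
        }
    }

    public static void main(String[] args) {
        RowView row = new RowView();
        SimpleBufferHolder holder = new SimpleBufferHolder(row, 2, 32);
        holder.reset();
        holder.grow(1024);                              // forces a reallocation
        System.out.println(row.base == holder.buffer);  // true: row re-pointed
    }
}
```

This is why the overloaded `grow(int, UnsafeRow)` and the null check inside it could be deleted: there is exactly one row per holder, fixed at construction time.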

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java

Lines changed: 38 additions & 20 deletions
```diff
@@ -26,38 +26,56 @@
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
- * A helper class to write data into global row buffer using `UnsafeRow` format,
- * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ * A helper class to write data into global row buffer using `UnsafeRow` format.
+ *
+ * It will remember the offset of row buffer which it starts to write, and move the cursor of row
+ * buffer while writing. If new data(can be the input record if this is the outermost writer, or
+ * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be
+ * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the
+ * `startingOffset` and clear out null bits.
+ *
+ * Note that if this is the outermost writer, which means we will always write from the very
+ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call
+ * `zeroOutNullBytes` before writing new data.
  */
 public class UnsafeRowWriter {
 
-  private BufferHolder holder;
+  private final BufferHolder holder;
   // The offset of the global buffer where we start to write this row.
   private int startingOffset;
-  private int nullBitsSize;
-  private UnsafeRow row;
+  private final int nullBitsSize;
+  private final int fixedSize;
 
-  public void initialize(BufferHolder holder, int numFields) {
+  public UnsafeRowWriter(BufferHolder holder, int numFields) {
     this.holder = holder;
-    this.startingOffset = holder.cursor;
     this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+    this.fixedSize = nullBitsSize + 8 * numFields;
+    this.startingOffset = holder.cursor;
+  }
+
+  /**
+   * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null
+   * bits. This should be called before we write a new nested struct to the row buffer.
+   */
+  public void reset() {
+    this.startingOffset = holder.cursor;
 
     // grow the global buffer to make sure it has enough space to write fixed-length data.
-    final int fixedSize = nullBitsSize + 8 * numFields;
-    holder.grow(fixedSize, row);
+    holder.grow(fixedSize);
     holder.cursor += fixedSize;
 
-    // zero-out the null bits region
+    zeroOutNullBytes();
+  }
+
+  /**
+   * Clears out null bits. This should be called before we write a new row to row buffer.
+   */
+  public void zeroOutNullBytes() {
     for (int i = 0; i < nullBitsSize; i += 8) {
       Platform.putLong(holder.buffer, startingOffset + i, 0L);
     }
   }
 
-  public void initialize(UnsafeRow row, BufferHolder holder, int numFields) {
-    initialize(holder, numFields);
-    this.row = row;
-  }
-
   private void zeroOutPaddingBytes(int numBytes) {
     if ((numBytes & 0x07) > 0) {
       Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
@@ -98,7 +116,7 @@ public void alignToWords(int numBytes) {
 
     if (remainder > 0) {
       final int paddingBytes = 8 - remainder;
-      holder.grow(paddingBytes, row);
+      holder.grow(paddingBytes);
 
       for (int i = 0; i < paddingBytes; i++) {
         Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
@@ -161,7 +179,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
       }
     } else {
       // grow the global buffer before writing data.
-      holder.grow(16, row);
+      holder.grow(16);
 
       // zero-out the bytes
       Platform.putLong(holder.buffer, holder.cursor, 0L);
@@ -193,7 +211,7 @@ public void write(int ordinal, UTF8String input) {
     final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
 
     // grow the global buffer before writing data.
-    holder.grow(roundedSize, row);
+    holder.grow(roundedSize);
 
     zeroOutPaddingBytes(numBytes);
 
@@ -214,7 +232,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) {
     final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
 
     // grow the global buffer before writing data.
-    holder.grow(roundedSize, row);
+    holder.grow(roundedSize);
 
     zeroOutPaddingBytes(numBytes);
 
@@ -230,7 +248,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) {
 
   public void write(int ordinal, CalendarInterval input) {
     // grow the global buffer before writing data.
-    holder.grow(16, row);
+    holder.grow(16);
 
     // Write the months and microseconds fields of Interval to the variable length portion.
     Platform.putLong(holder.buffer, holder.cursor, input.months);
```
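The variable-length write paths above all follow the same shape: grow the buffer by the word-rounded size, zero the padding, copy the bytes, and record where they landed in the field's fixed 8-byte slot. A sketch of the two pieces of arithmetic involved (the `VarLenWrite` class is illustrative; only `roundToWord` and the offset/size packing mirror what `ByteArrayMethods.roundNumberOfBytesToNearestWord` and `UnsafeRowWriter` do):

```java
// Sketch of the variable-length write arithmetic in UnsafeRowWriter
// (simplified model, not Spark's code).
public class VarLenWrite {

    // Round a byte count up to the next multiple of 8
    // (mirrors ByteArrayMethods.roundNumberOfBytesToNearestWord).
    static int roundToWord(int numBytes) {
        int remainder = numBytes & 0x07;
        return numBytes + ((8 - remainder) & 0x07);
    }

    // Pack the field's fixed slot for variable-length data: the high 32 bits
    // hold the offset from the row start, the low 32 bits hold the size.
    static long offsetAndSize(int offsetFromRowStart, int size) {
        return ((long) offsetFromRowStart << 32) | (long) size;
    }

    public static void main(String[] args) {
        System.out.println(roundToWord(5));    // 8
        System.out.println(roundToWord(16));   // 16
        long slot = offsetAndSize(24, 5);
        System.out.println(slot >>> 32);       // 24 (offset)
        System.out.println((int) slot);        // 5  (size)
    }
}
```

Word alignment is what keeps every `Platform.putLong` on an 8-byte boundary, and the packed slot is why variable-length fields still consume exactly one slot in the fixed-size region.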

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 49 additions & 17 deletions
```diff
@@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case _ => false
   }
 
-  private val rowWriterClass = classOf[UnsafeRowWriter].getName
-  private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
-
   // TODO: if the nullability of field is correct, we can use it to save null check.
   private def writeStructToBuffer(
       ctx: CodegenContext,
@@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       row: String,
       inputs: Seq[ExprCode],
       inputTypes: Seq[DataType],
-      bufferHolder: String): String = {
+      bufferHolder: String,
+      isTopLevel: Boolean = false): String = {
+    val rowWriterClass = classOf[UnsafeRowWriter].getName
     val rowWriter = ctx.freshName("rowWriter")
-    ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
+    ctx.addMutableState(rowWriterClass, rowWriter,
+      s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+
+    val resetWriter = if (isTopLevel) {
+      // For top level row writer, it always writes to the beginning of the global buffer holder,
+      // which means its fixed-size region always in the same position, so we don't need to call
+      // `reset` to set up its fixed-size region every time.
+      if (inputs.map(_.isNull).forall(_ == "false")) {
+        // If all fields are not nullable, which means the null bits never changes, then we don't
+        // need to clear it out every time.
+        ""
+      } else {
+        s"$rowWriter.zeroOutNullBytes();"
+      }
+    } else {
+      s"$rowWriter.reset();"
+    }
 
     val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
       case ((input, dataType), index) =>
@@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
             $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
           """
 
-        case _ if ctx.isPrimitiveType(dt) =>
-          s"""
-            $rowWriter.write($index, ${input.value});
-          """
-
         case t: DecimalType =>
           s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});"
 
@@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     }
 
     s"""
-      $rowWriter.initialize($bufferHolder, ${inputs.length});
+      $resetWriter
       ${ctx.splitExpressions(row, writeFields)}
     """.trim
   }
@@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       input: String,
       elementType: DataType,
       bufferHolder: String): String = {
+    val arrayWriterClass = classOf[UnsafeArrayWriter].getName
     val arrayWriter = ctx.freshName("arrayWriter")
     ctx.addMutableState(arrayWriterClass, arrayWriter,
       s"this.$arrayWriter = new $arrayWriterClass();")
@@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
     val exprTypes = expressions.map(_.dataType)
 
+    val numVarLenFields = exprTypes.count {
+      case dt if UnsafeRow.isFixedLength(dt) => false
+      // TODO: consider large decimal and interval type
+      case _ => true
+    }
+
     val result = ctx.freshName("result")
     ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
-    val bufferHolder = ctx.freshName("bufferHolder")
+
+    val holder = ctx.freshName("holder")
     val holderClass = classOf[BufferHolder].getName
-    ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+    ctx.addMutableState(holderClass, holder,
+      s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
+
+    val resetBufferHolder = if (numVarLenFields == 0) {
+      ""
+    } else {
+      s"$holder.reset();"
+    }
+    val updateRowSize = if (numVarLenFields == 0) {
+      ""
+    } else {
+      s"$result.setTotalSize($holder.totalSize());"
+    }
 
     // Evaluate all the subexpression.
     val evalSubexpr = ctx.subexprFunctions.mkString("\n")
 
+    val writeExpressions =
+      writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
+
     val code =
       s"""
-        $bufferHolder.reset();
+        $resetBufferHolder
         $evalSubexpr
-        ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
-
-        $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
+        $writeExpressions
+        $updateRowSize
       """
     ExprCode(code, "false", result)
   }
```
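The `resetWriter` branch in the diff above is the heart of the speedup: top-level writers always start at the beginning of the buffer, so they can skip `reset`, and if no output field is nullable they skip per-record work entirely. A sketch of that decision as a plain Java helper (hypothetical names; the real code is Scala that splices these strings into generated Java source):

```java
// Sketch of the resetWriter decision added to writeExpressionsToBuffer
// (illustrative helper, not Spark's code).
public class ResetDecision {

    static String resetCode(boolean isTopLevel, boolean anyNullable) {
        if (isTopLevel) {
            // The top-level writer always starts at the beginning of the
            // global buffer, so its fixed-size region never moves; at most
            // the null bits need clearing per record.
            return anyNullable ? "rowWriter.zeroOutNullBytes();" : "";
        }
        // A nested writer starts wherever the cursor currently is, so it must
        // re-establish startingOffset and its fixed-size region every time.
        return "rowWriter.reset();";
    }

    public static void main(String[] args) {
        System.out.println(resetCode(true, false).isEmpty());  // true: best case, no per-record work
        System.out.println(resetCode(true, true));             // zeroOutNullBytes only
        System.out.println(resetCode(false, true));            // full reset for nested structs
    }
}
```

Combined with the `numVarLenFields == 0` checks that drop `holder.reset()` and `setTotalSize`, the generated code for a single non-nullable long reduces to a bare field write, which matches the 1.7x best-case speedup reported in the commit message.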
