-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-12879][SQL] improve the unsafe row writing framework #10809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | |
| case _ => false | ||
| } | ||
|
|
||
| private val rowWriterClass = classOf[UnsafeRowWriter].getName | ||
| private val arrayWriterClass = classOf[UnsafeArrayWriter].getName | ||
|
|
||
| // TODO: if the nullability of field is correct, we can use it to save null check. | ||
| private def writeStructToBuffer( | ||
| ctx: CodegenContext, | ||
|
|
@@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | |
| row: String, | ||
| inputs: Seq[ExprCode], | ||
| inputTypes: Seq[DataType], | ||
| bufferHolder: String): String = { | ||
| bufferHolder: String, | ||
| isTopLevel: Boolean = false): String = { | ||
| val rowWriterClass = classOf[UnsafeRowWriter].getName | ||
| val rowWriter = ctx.freshName("rowWriter") | ||
| ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();") | ||
| ctx.addMutableState(rowWriterClass, rowWriter, | ||
| s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") | ||
|
|
||
| val resetWriter = if (isTopLevel) { | ||
| // For top level row writer, it always writes to the beginning of the global buffer holder, | ||
| // which means its fixed-size region always in the same position, so we don't need to call | ||
| // `reset` to set up its fixed-size region every time. | ||
| if (inputs.map(_.isNull).forall(_ == "false")) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Even if the expression is not nullable,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I followed https://github.com/apache/spark/pull/10333/files#diff-90b107e5c61791e17d5b4b25021b89fdR138 to do this, is there a better approach?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we pass in the expressions ? |
||
| // If all fields are not nullable, which means the null bits never changes, then we don't | ||
| // need to clear it out every time. | ||
| "" | ||
| } else { | ||
| s"$rowWriter.zeroOutNullBytes();" | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here I made a different decision compared to the unsafe parquet reader. We can clear out the null bits at the beginning, and call
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense to me. |
||
| } | ||
| } else { | ||
| s"$rowWriter.reset();" | ||
| } | ||
|
|
||
| val writeFields = inputs.zip(inputTypes).zipWithIndex.map { | ||
| case ((input, dataType), index) => | ||
|
|
@@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | |
| $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor); | ||
| """ | ||
|
|
||
| case _ if ctx.isPrimitiveType(dt) => | ||
| s""" | ||
| $rowWriter.write($index, ${input.value}); | ||
| """ | ||
|
|
||
| case t: DecimalType => | ||
| s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" | ||
|
|
||
|
|
@@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | |
| } | ||
|
|
||
| s""" | ||
| $rowWriter.initialize($bufferHolder, ${inputs.length}); | ||
| $resetWriter | ||
| ${ctx.splitExpressions(row, writeFields)} | ||
| """.trim | ||
| } | ||
|
|
@@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | |
| input: String, | ||
| elementType: DataType, | ||
| bufferHolder: String): String = { | ||
| val arrayWriterClass = classOf[UnsafeArrayWriter].getName | ||
| val arrayWriter = ctx.freshName("arrayWriter") | ||
| ctx.addMutableState(arrayWriterClass, arrayWriter, | ||
| s"this.$arrayWriter = new $arrayWriterClass();") | ||
|
|
@@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | |
| val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) | ||
| val exprTypes = expressions.map(_.dataType) | ||
|
|
||
| val numVarLenFields = exprTypes.count { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since it's easy to grow the buffer, we don't need this optimization.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is used to avoid calling reset and setTotalSize(), still useful. nvm. |
||
| case dt if UnsafeRow.isFixedLength(dt) => false | ||
| // TODO: consider large decimal and interval type | ||
| case _ => true | ||
| } | ||
|
|
||
| val result = ctx.freshName("result") | ||
| ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") | ||
| val bufferHolder = ctx.freshName("bufferHolder") | ||
|
|
||
| val holder = ctx.freshName("holder") | ||
| val holderClass = classOf[BufferHolder].getName | ||
| ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") | ||
| ctx.addMutableState(holderClass, holder, | ||
| s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") | ||
|
|
||
| val resetBufferHolder = if (numVarLenFields == 0) { | ||
| "" | ||
| } else { | ||
| s"$holder.reset();" | ||
| } | ||
| val updateRowSize = if (numVarLenFields == 0) { | ||
| "" | ||
| } else { | ||
| s"$result.setTotalSize($holder.totalSize());" | ||
| } | ||
|
|
||
| // Evaluate all the subexpression. | ||
| val evalSubexpr = ctx.subexprFunctions.mkString("\n") | ||
|
|
||
| val writeExpressions = | ||
| writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) | ||
|
|
||
| val code = | ||
| s""" | ||
| $bufferHolder.reset(); | ||
| $resetBufferHolder | ||
| $evalSubexpr | ||
| ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} | ||
|
|
||
| $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); | ||
| $writeExpressions | ||
| $updateRowSize | ||
| """ | ||
| ExprCode(code, "false", result) | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also comment that we should either call
unsafeRow.pointTo() or unsafeRow.setTotalSize()?