
Commit 25bba58

Davies Liu authored and committed
[SPARK-13404] [SQL] Create variables for input row when it's actually used
## What changes were proposed in this pull request?

This PR changes the way we generate the code for the output variables passed from a plan to its parent. Right now, they are generated before calling the parent's consume(). That is inefficient: if the parent is a Filter or Join, which could filter out most of the rows, the time spent accessing columns that the Filter or Join never uses is wasted. This PR improves on that by deferring the access of columns until they are actually used by a plan.

After this PR, a plan does not need to generate code to evaluate its output variables; it just passes the ExprCode to its parent via `consume()`. In `parent.consumeChild()`, the output from the child is checked against `usedInputs`, and the code for the columns that are part of `usedInputs` is generated before calling `doConsume()`.

This PR also changes the `if` from

```
if (cond) {
  xxx
}
```

to

```
if (!cond) continue;
xxx
```

The new form helps to reduce nested indentation across multiple levels of Filter and BroadcastHashJoin. This PR also adds some comments for operators.

## How was this patch tested?

Unit tests. Manually ran TPC-DS Q55; this PR improves performance by about 30% (scale=10, from 2.56s to 1.96s).

Author: Davies Liu <[email protected]>

Closes #11274 from davies/gen_defer.
1 parent da7bfac commit 25bba58
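
As an illustration of the two ideas above (deferred column access and the flattened `continue` control flow), here is a hedged sketch of the generated Java's shape for a scan feeding a Filter. It is not the commit's actual output: the column names, predicate, and `project`/`append` calls are placeholders, and the Java is held in Scala strings the way Spark's codegen builds its source.

```scala
// Illustrative only: the shape of the generated Java before/after this PR,
// held in Scala multiline strings as Spark's codegen does. Names are made up.
object GeneratedShapeSketch extends App {
  val before =
    """while (input.hasNext()) {
      |  InternalRow row = (InternalRow) input.next();
      |  // every output column is materialized up front, even for rows
      |  // the filter is about to drop
      |  int value1 = row.getInt(0);
      |  double value2 = row.getDouble(1);
      |  if (value1 > 0) {
      |    append(project(value1, value2));
      |  }
      |}""".stripMargin
  val after =
    """while (!shouldStop() && input.hasNext()) {
      |  InternalRow row = (InternalRow) input.next();
      |  // only the column the predicate needs is materialized here
      |  int value1 = row.getInt(0);
      |  if (!(value1 > 0)) continue;
      |  // deferred: value2 is only accessed for rows that survive the filter
      |  double value2 = row.getDouble(1);
      |  append(project(value1, value2));
      |}""".stripMargin
  println(before)
  println(after)
}
```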

File tree

9 files changed: +224 -155 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 4 additions & 3 deletions

```diff
@@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
     val javaType = ctx.javaType(dataType)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
-      ev.isNull = ctx.currentVars(ordinal).isNull
-      ev.value = ctx.currentVars(ordinal).value
-      ""
+      val oev = ctx.currentVars(ordinal)
+      ev.isNull = oev.isNull
+      ev.value = oev.value
+      oev.code
     } else if (nullable) {
       s"""
         boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 0 deletions

```diff
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
  * @param code The sequence of statements required to evaluate the expression.
+ *             It should be empty string, if `isNull` and `value` are already existed, or no code
+ *             needed to evaluate them (literals).
  * @param isNull A term that holds a boolean value representing whether the expression evaluated
  *               to null.
  * @param value A term for a (possibly primitive) value of the result of the evaluation. Not
```
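
This `code`-is-empty convention is what the BoundAttribute.scala change above relies on: when a column already lives in `ctx.currentVars`, the new branch returns the variable's pending `code` (possibly empty) instead of always returning `""`, so a deferred accessor is not dropped. A minimal sketch of that lookup, with a simplified stand-in for `ExprCode` and `currentVars` rather than Spark's real codegen context:

```scala
// Simplified model: `currentVars(i)` holds a possibly-unevaluated column.
case class ExprCode(var code: String, isNull: String, value: String)

object BoundRefSketch extends App {
  val currentVars: Array[ExprCode] =
    Array(ExprCode("int value1 = row.getInt(0);", "false", "value1"))

  // Mirrors the new branch in BoundReference: bind to the existing terms
  // and forward any accessor code that has not been emitted yet.
  def gen(ordinal: Int): ExprCode = {
    val oev = currentVars(ordinal)
    ExprCode(oev.code, oev.isNull, oev.value) // before this PR: code was ""
  }

  val ev = gen(0)
  println(ev.code)  // int value1 = row.getInt(0);  -- the deferred accessor
  println(ev.value) // value1
}
```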

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 23 additions & 18 deletions

```diff
@@ -151,17 +151,16 @@ private[sql] case class PhysicalRDD(
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))
 
     // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
     // by looking at the first value of the RDD and then calling the function which will process
     // the remaining. It is faster to return batches.
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
 
-
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns1 = exprs.map(_.gen(ctx))
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
@@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
        |   int numRows = $batch.numRows();
        |   if ($idx == 0) $numOutputRows.add(numRows);
        |
-       |   while ($idx < numRows) {
+       |   while (!shouldStop() && $idx < numRows) {
        |     InternalRow $row = $batch.getRow($idx++);
-       |     ${columns.map(_.code).mkString("\n").trim}
-       |     ${consume(ctx, columns).trim}
-       |     if (shouldStop()) return;
+       |     ${consume(ctx, columns1).trim}
        |   }
+       |   if (shouldStop()) return;
        |
        |   if (!$input.hasNext()) {
        |     $batch = null;
@@ -186,30 +184,37 @@ private[sql] case class PhysicalRDD(
        |   }
        | }""".stripMargin)
 
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns2 = exprs.map(_.gen(ctx))
+    val inputRow = if (isUnsafeRow) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
        | private void $scanRows(InternalRow $row) throws java.io.IOException {
-       |   while (true) {
+       |   boolean firstRow = true;
+       |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+       |     if (firstRow) {
+       |       firstRow = false;
+       |     } else {
+       |       $row = (InternalRow) $input.next();
+       |     }
        |     $numOutputRows.add(1);
-       |     ${columns.map(_.code).mkString("\n").trim}
-       |     ${consume(ctx, columns).trim}
-       |     if (shouldStop()) return;
-       |     if (!$input.hasNext()) break;
-       |     $row = (InternalRow)$input.next();
+       |     ${consume(ctx, columns2, inputRow).trim}
        |   }
        | }""".stripMargin)
 
+    val value = ctx.freshName("value")
     s"""
      | if ($batch != null) {
      |   $scanBatches();
      | } else if ($input.hasNext()) {
-     |   Object value = $input.next();
-     |   if (value instanceof $columnarBatchClz) {
-     |     $batch = ($columnarBatchClz)value;
+     |   Object $value = $input.next();
+     |   if ($value instanceof $columnarBatchClz) {
+     |     $batch = ($columnarBatchClz)$value;
      |     $scanBatches();
      |   } else {
-     |     $scanRows((InternalRow)value);
+     |     $scanRows((InternalRow) $value);
      |   }
      | }
    """.stripMargin
```

sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala

Lines changed: 3 additions & 1 deletion

```diff
@@ -185,8 +185,10 @@ case class Expand(
 
     val numOutput = metricTerm(ctx, "numOutputRows")
     val i = ctx.freshName("i")
+    // these column have to declared before the loop.
+    val evaluate = evaluateVariables(outputColumns)
     s"""
-       |${outputColumns.map(_.code).mkString("\n").trim}
+       |$evaluate
       |for (int $i = 0; $i < ${projections.length}; $i ++) {
       |  switch ($i) {
       |    ${cases.mkString("\n").trim}
```
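
The new comment ("these column have to declared before the loop") is load-bearing: every `case` of the generated `switch` references the same column variables, so Expand must evaluate them once, ahead of the loop, rather than deferring them into any single branch. A small illustrative analogue (values and projections are made up):

```scala
// Why Expand evaluates its output columns before the loop: each switch case
// emits a different projection, but all cases share the same variables.
object ExpandShapeDemo extends App {
  // declared once, before the loop -- mirrors the `$evaluate` block above
  val value1 = 1
  val value2 = 2
  for (i <- 0 until 2) {
    i match {
      case 0 => println(s"($value1, null)") // one projection per case
      case 1 => println(s"(null, $value2)")
    }
  }
}
```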

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 70 additions & 26 deletions

```diff
@@ -81,11 +81,14 @@ trait CodegenSupport extends SparkPlan {
     this.parent = parent
     ctx.freshNamePrefix = variablePrefix
     waitForSubqueries()
-    doProduce(ctx)
+    s"""
+       |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
+       |${doProduce(ctx)}
+     """.stripMargin
   }
 
   /**
-   * Generate the Java source code to process, should be overrided by subclass to support codegen.
+   * Generate the Java source code to process, should be overridden by subclass to support codegen.
    *
    * doProduce() usually generate the framework, for example, aggregation could generate this:
    *
@@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
    *     # call child.produce()
    *     initialized = true;
    *   }
-   *   while (hashmap.hasNext()) {
+   *   while (!shouldStop() && hashmap.hasNext()) {
    *     row = hashmap.next();
    *     # build the aggregation results
-   *     # create varialbles for results
-   *     # call consume(), wich will call parent.doConsume()
+   *     # create variables for results
+   *     # call consume(), which will call parent.doConsume()
    *   }
    */
   protected def doProduce(ctx: CodegenContext): String
@@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
   }
 
   /**
-   * Consume the columns generated from it's child, call doConsume() or emit the rows.
+   * Returns source code to evaluate all the variables, and clear the code of them, to prevent
+   * them to be evaluated twice.
+   */
+  protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
+    variables.foreach(_.code = "")
+    evaluate
+  }
+
+  /**
+   * Returns source code to evaluate the variables for required attributes, and clear the code
+   * of evaluated variables, to prevent them to be evaluated twice..
    */
+  protected def evaluateRequiredVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      required: AttributeSet): String = {
+    var evaluateVars = ""
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      if (ev.code != "" && required.contains(attributes(i))) {
+        evaluateVars += ev.code.trim + "\n"
+        ev.code = ""
+      }
+    }
+    evaluateVars
+  }
+
+  /**
+   * The subset of inputSet those should be evaluated before this plan.
+   *
+   * We will use this to insert some code to access those columns that are actually used by current
+   * plan before calling doConsume().
+   */
+  def usedInputs: AttributeSet = references
+
+  /**
+   * Consume the columns generated from its child, call doConsume() or emit the rows.
+   *
+   * An operator could generate variables for the output, or a row, either one could be null.
+   *
+   * If the row is not null, we create variables to access the columns that are actually used by
+   * current plan before calling doConsume().
+   */
   def consumeChild(
       ctx: CodegenContext,
       child: SparkPlan,
       input: Seq[ExprCode],
       row: String = null): String = {
     ctx.freshNamePrefix = variablePrefix
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+    val inputVars =
+      if (row != null) {
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        child.output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        }
+      } else {
+        input
       }
-      s"""
-         | ${evals.map(_.code).mkString("\n")}
-         | ${doConsume(ctx, evals)}
-       """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
+    s"""
+       |
+       |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+       |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
+       |${doConsume(ctx, inputVars)}
+     """.stripMargin
   }
 
   /**
@@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
    * For example, Filter will generate the code like this:
    *
    *   # code to evaluate the predicate expression, result is isNull1 and value2
-   *   if (isNull1 || value2) {
-   *     # call consume(), which will call parent.doConsume()
-   *   }
+   *   if (isNull1 || !value2) continue;
+   *   # call consume(), which will call parent.doConsume()
    */
   protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     throw new UnsupportedOperationException
@@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-       | while ($input.hasNext()) {
+       | while (!shouldStop() && $input.hasNext()) {
       |   InternalRow $row = (InternalRow) $input.next();
-       |   ${columns.map(_.code).mkString("\n").trim}
       |   ${consume(ctx, columns).trim}
-       |   if (shouldStop()) {
-       |     return;
-       |   }
       | }
     """.stripMargin
   }
@@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     val colExprs = output.zipWithIndex.map { case (attr, i) =>
       BoundReference(i, attr.dataType, attr.nullable)
     }
+    val evaluateInputs = evaluateVariables(input)
     // generate the code to create a UnsafeRow
     ctx.currentVars = input
     val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
     s"""
+      |$evaluateInputs
      |${code.code.trim}
      |append(${code.value}.copy());
    """.stripMargin.trim
```
