Skip to content

Commit c92e457

Browse files
author
Davies Liu
committed
improve codegen
1 parent 95e1ab2 commit c92e457

File tree

6 files changed

+125
-91
lines changed

6 files changed

+125
-91
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
6262
val javaType = ctx.javaType(dataType)
6363
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
6464
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
65-
ev.isNull = ctx.currentVars(ordinal).isNull
66-
ev.value = ctx.currentVars(ordinal).value
65+
val oev = ctx.currentVars(ordinal)
66+
// assert(oev.code == "", s"$this has not been evaluated yet.")
67+
ev.isNull = oev.isNull
68+
ev.value = oev.value
6769
""
6870
} else if (nullable) {
6971
s"""

sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,10 @@ case class Expand(
187187

188188
val numOutput = metricTerm(ctx, "numOutputRows")
189189
val i = ctx.freshName("i")
190+
// these column have to declared before the loop.
191+
val evaluate = evaluateVariables(outputColumns)
190192
s"""
191-
|${outputColumns.map(_.code).mkString("\n").trim}
193+
|$evaluate
192194
|for (int $i = 0; $i < ${projections.length}; $i ++) {
193195
| switch ($i) {
194196
| ${cases.mkString("\n").trim}

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ trait CodegenSupport extends SparkPlan {
7676
def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
7777
this.parent = parent
7878
ctx.freshNamePrefix = variablePrefix
79-
doProduce(ctx)
79+
s"""
80+
|/*** PRODUCE: ${commentSafe(this.simpleString)} */
81+
|${doProduce(ctx)}
82+
""".stripMargin
8083
}
8184

8285
/**
@@ -108,6 +111,38 @@ trait CodegenSupport extends SparkPlan {
108111
parent.consumeChild(ctx, this, input, row)
109112
}
110113

114+
/**
115+
* Returns source code to evaluate all the variables, and clear the code of them, to prevent
116+
* them to be evaluated twice.
117+
*/
118+
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
119+
val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
120+
variables.foreach(_.code = "")
121+
evaluate
122+
}
123+
124+
/**
125+
* Returns source code to evaluate the variables for required attributes, and clear the code
126+
* of evaluated variables, to prevent them to be evaluated twice..
127+
*/
128+
protected def evaluateRequiredVariables(
129+
attributes: Seq[Attribute],
130+
variables: Seq[ExprCode],
131+
required: AttributeSet): String = {
132+
var evaluateVars = ""
133+
variables.zipWithIndex.foreach { case (ev, i) =>
134+
if (ev.code != "" && required.contains(attributes(i))) {
135+
evaluateVars += ev.code.trim + "\n"
136+
ev.code = ""
137+
}
138+
}
139+
evaluateVars
140+
}
141+
142+
protected def commentSafe(s: String): String = {
143+
s.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
144+
}
145+
111146
/**
112147
* Consume the columns generated from it's child, call doConsume() or emit the rows.
113148
*/
@@ -117,19 +152,22 @@ trait CodegenSupport extends SparkPlan {
117152
input: Seq[ExprCode],
118153
row: String = null): String = {
119154
ctx.freshNamePrefix = variablePrefix
120-
if (row != null) {
121-
ctx.currentVars = null
122-
ctx.INPUT_ROW = row
123-
val evals = child.output.zipWithIndex.map { case (attr, i) =>
124-
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
155+
val inputVars =
156+
if (row != null) {
157+
ctx.currentVars = null
158+
ctx.INPUT_ROW = row
159+
child.output.zipWithIndex.map { case (attr, i) =>
160+
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
161+
}
162+
} else {
163+
input
125164
}
126-
s"""
127-
| ${evals.map(_.code).mkString("\n")}
128-
| ${doConsume(ctx, evals)}
129-
""".stripMargin
130-
} else {
131-
doConsume(ctx, input)
132-
}
165+
s"""
166+
|
167+
|/*** CONSUME: ${commentSafe(this.simpleString)} */
168+
|${evaluateRequiredVariables(child.output, inputVars, references)}
169+
|${doConsume(ctx, inputVars)}
170+
""".stripMargin
133171
}
134172

135173
/**
@@ -183,13 +221,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
183221
ctx.currentVars = null
184222
val columns = exprs.map(_.gen(ctx))
185223
s"""
186-
| while (input.hasNext()) {
224+
| while (!shouldStop() && input.hasNext()) {
187225
| InternalRow $row = (InternalRow) input.next();
188-
| ${columns.map(_.code).mkString("\n").trim}
189226
| ${consume(ctx, columns).trim}
190-
| if (shouldStop()) {
191-
| return;
192-
| }
193227
| }
194228
""".stripMargin
195229
}
@@ -251,7 +285,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
251285
}
252286

253287
/** Codegened pipeline for:
254-
* ${plan.treeString.trim}
288+
* ${commentSafe(plan.treeString.trim)}
255289
*/
256290
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
257291

@@ -305,7 +339,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
305339
if (row != null) {
306340
// There is an UnsafeRow already
307341
s"""
308-
| currentRows.add($row.copy());
342+
|currentRows.add($row.copy());
309343
""".stripMargin
310344
} else {
311345
assert(input != null)
@@ -317,13 +351,14 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
317351
ctx.currentVars = input
318352
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
319353
s"""
320-
| ${code.code.trim}
321-
| currentRows.add(${code.value}.copy());
354+
|${evaluateVariables(input)}
355+
|${code.code.trim}
356+
|currentRows.add(${code.value}.copy());
322357
""".stripMargin
323358
} else {
324359
// There is no columns
325360
s"""
326-
| currentRows.add(unsafeRow);
361+
|currentRows.add(unsafeRow);
327362
""".stripMargin
328363
}
329364
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ case class TungstenAggregate(
116116
// all the mode of aggregate expressions
117117
private val modes = aggregateExpressions.map(_.mode).distinct
118118

119+
override def references: AttributeSet = {
120+
AttributeSet(groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap {
121+
case AggregateExpression(f, Final | PartialMerge, _) => f.inputAggBufferAttributes
122+
case AggregateExpression(f, Partial | Complete, _) => f.references
123+
})
124+
child.outputSet
125+
}
126+
119127
override def supportCodegen: Boolean = {
120128
// ImperativeAggregate is not supported right now
121129
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,47 +172,47 @@ case class TungstenAggregate(
164172
""".stripMargin
165173
ExprCode(ev.code + initVars, isNull, value)
166174
}
175+
val initBufVar = evaluateVariables(bufVars)
167176

168177
// generate variables for output
169-
val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
170178
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
171179
// evaluate aggregate results
172180
ctx.currentVars = bufVars
173181
val aggResults = functions.map(_.evaluateExpression).map { e =>
174-
BindReferences.bindReference(e, bufferAttrs).gen(ctx)
182+
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
175183
}
176184
// evaluate result expressions
177185
ctx.currentVars = aggResults
178186
val resultVars = resultExpressions.map { e =>
179187
BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
180188
}
181189
(resultVars, s"""
182-
| ${aggResults.map(_.code).mkString("\n")}
183-
| ${resultVars.map(_.code).mkString("\n")}
190+
| ${evaluateVariables(aggResults)}
191+
| ${evaluateVariables(resultVars)}
184192
""".stripMargin)
185193
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
186194
// output the aggregate buffer directly
187195
(bufVars, "")
188196
} else {
189197
// no aggregate function, the result should be literals
190198
val resultVars = resultExpressions.map(_.gen(ctx))
191-
(resultVars, resultVars.map(_.code).mkString("\n"))
199+
(resultVars, evaluateVariables(resultVars))
192200
}
193201

194202
val doAgg = ctx.freshName("doAggregateWithoutKey")
195203
ctx.addNewFunction(doAgg,
196204
s"""
197205
| private void $doAgg() throws java.io.IOException {
198206
| // initialize aggregation buffer
199-
| ${bufVars.map(_.code).mkString("\n")}
207+
| $initBufVar
200208
|
201209
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
202210
| }
203211
""".stripMargin)
204212

205213
val numOutput = metricTerm(ctx, "numOutputRows")
206214
s"""
207-
| if (!$initAgg) {
215+
| while (!$initAgg) {
208216
| $initAgg = true;
209217
| $doAgg();
210218
|
@@ -241,7 +249,7 @@ case class TungstenAggregate(
241249
}
242250
s"""
243251
| // do aggregate
244-
| ${aggVals.map(_.code).mkString("\n").trim}
252+
| ${evaluateVariables(aggVals)}
245253
| // update aggregation buffer
246254
| ${updates.mkString("\n").trim}
247255
""".stripMargin
@@ -252,8 +260,7 @@ case class TungstenAggregate(
252260
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
253261
.filter(_.isInstanceOf[DeclarativeAggregate])
254262
.map(_.asInstanceOf[DeclarativeAggregate])
255-
private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
256-
private val bufferSchema = StructType.fromAttributes(bufferAttributes)
263+
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
257264

258265
// The name for HashMap
259266
private var hashMapTerm: String = _
@@ -318,7 +325,7 @@ case class TungstenAggregate(
318325
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
319326
val mergeProjection = newMutableProjection(
320327
mergeExpr,
321-
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
328+
aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
322329
subexpressionEliminationEnabled)()
323330
val joinedRow = new JoinedRow()
324331

@@ -381,13 +388,13 @@ case class TungstenAggregate(
381388
BoundReference(i, e.dataType, e.nullable).gen(ctx)
382389
}
383390
ctx.INPUT_ROW = bufferTerm
384-
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
391+
val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
385392
BoundReference(i, e.dataType, e.nullable).gen(ctx)
386393
}
387394
// evaluate the aggregation result
388395
ctx.currentVars = bufferVars
389396
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
390-
BindReferences.bindReference(e, bufferAttributes).gen(ctx)
397+
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
391398
}
392399
// generate the final result
393400
ctx.currentVars = keyVars ++ aggResults
@@ -396,11 +403,9 @@ case class TungstenAggregate(
396403
BindReferences.bindReference(e, inputAttrs).gen(ctx)
397404
}
398405
s"""
399-
${keyVars.map(_.code).mkString("\n")}
400-
${bufferVars.map(_.code).mkString("\n")}
401-
${aggResults.map(_.code).mkString("\n")}
402-
${resultVars.map(_.code).mkString("\n")}
403-
406+
${evaluateVariables(keyVars)}
407+
${evaluateVariables(bufferVars)}
408+
${evaluateVariables(aggResults)}
404409
${consume(ctx, resultVars)}
405410
"""
406411

@@ -422,10 +427,7 @@ case class TungstenAggregate(
422427
val eval = resultExpressions.map{ e =>
423428
BindReferences.bindReference(e, groupingAttributes).gen(ctx)
424429
}
425-
s"""
426-
${eval.map(_.code).mkString("\n")}
427-
${consume(ctx, eval)}
428-
"""
430+
consume(ctx, eval)
429431
}
430432
}
431433

@@ -508,8 +510,8 @@ case class TungstenAggregate(
508510
ctx.currentVars = input
509511
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
510512

511-
val inputAttr = bufferAttributes ++ child.output
512-
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
513+
val inputAttr = aggregateBufferAttributes ++ child.output
514+
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
513515
ctx.INPUT_ROW = buffer
514516
// TODO: support subexpression elimination
515517
val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +559,7 @@ case class TungstenAggregate(
557559
$incCounter
558560

559561
// evaluate aggregate function
560-
${evals.map(_.code).mkString("\n").trim}
562+
${evaluateVariables(evals)}
561563
// update aggregate buffer
562564
${updates.mkString("\n").trim}
563565
"""

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,12 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
4343
val exprs = projectList.map(x =>
4444
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
4545
ctx.currentVars = input
46-
val output = exprs.map(_.gen(ctx))
46+
val resultVars = exprs.map(_.gen(ctx))
47+
// Evaluation of non-deterministic expressions can't be deferred.
48+
val nonDeterministicAttrs = projectList.zip(output).filter(!_._1.deterministic).unzip._2
4749
s"""
48-
| ${output.map(_.code).mkString("\n")}
49-
|
50-
| ${consume(ctx, output)}
50+
|${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
51+
|${consume(ctx, resultVars)}
5152
""".stripMargin
5253
}
5354

@@ -89,11 +90,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
8990
s""
9091
}
9192
s"""
92-
| ${eval.code}
93-
| if ($nullCheck ${eval.value}) {
94-
| $numOutput.add(1);
95-
| ${consume(ctx, ctx.currentVars)}
96-
| }
93+
|${eval.code}
94+
|if (!($nullCheck ${eval.value})) continue;
95+
|$numOutput.add(1);
96+
|${consume(ctx, ctx.currentVars)}
9797
""".stripMargin
9898
}
9999

@@ -224,15 +224,13 @@ case class Range(
224224
| }
225225
| }
226226
|
227-
| while (!$overflow && $checkEnd) {
227+
| while (!$overflow && $checkEnd && !shouldStop()) {
228228
| long $value = $number;
229229
| $number += ${step}L;
230230
| if ($number < $value ^ ${step}L < 0) {
231231
| $overflow = true;
232232
| }
233233
| ${consume(ctx, Seq(ev))}
234-
|
235-
| if (shouldStop()) return;
236234
| }
237235
""".stripMargin
238236
}

0 commit comments

Comments
 (0)