Skip to content

Commit f6139e6

Browse files
author
Davies Liu
committed
Defer the evaluation of expresssions in Project
1 parent e58c8a6 commit f6139e6

File tree

4 files changed

+27
-23
lines changed

4 files changed

+27
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,9 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
6363
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
6464
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
6565
val oev = ctx.currentVars(ordinal)
66-
// assert(oev.code == "", s"$this has not been evaluated yet.")
6766
ev.isNull = oev.isNull
6867
ev.value = oev.value
69-
""
68+
oev.code
7069
} else if (nullable) {
7170
s"""
7271
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ trait CodegenSupport extends SparkPlan {
8080
ctx.freshNamePrefix = variablePrefix
8181
waitForSubqueries()
8282
s"""
83-
|/*** PRODUCE: ${commentSafe(this.simpleString)} */
83+
|/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
8484
|${doProduce(ctx)}
8585
""".stripMargin
8686
}
@@ -142,9 +142,10 @@ trait CodegenSupport extends SparkPlan {
142142
evaluateVars
143143
}
144144

145-
protected def commentSafe(s: String): String = {
146-
s.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
147-
}
145+
/**
146+
* The subset of inputSet those should be evaluated before this plan.
147+
*/
148+
def usedInputs: AttributeSet = references
148149

149150
/**
150151
* Consume the columns generated from it's child, call doConsume() or emit the rows.
@@ -167,8 +168,8 @@ trait CodegenSupport extends SparkPlan {
167168
}
168169
s"""
169170
|
170-
|/*** CONSUME: ${commentSafe(this.simpleString)} */
171-
|${evaluateRequiredVariables(child.output, inputVars, references)}
171+
|/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
172+
|${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
172173
|${doConsume(ctx, inputVars)}
173174
""".stripMargin
174175
}
@@ -292,11 +293,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
292293
}
293294

294295
/** Codegened pipeline for:
295-
<<<<<<< HEAD
296-
* ${commentSafe(plan.treeString.trim)}
297-
=======
298296
* ${toCommentSafeString(plan.treeString.trim)}
299-
>>>>>>> 00461bb911c31aff9c945a14e23df2af4c280c23
300297
*/
301298
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
302299

@@ -358,11 +355,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
358355
val colExprs = output.zipWithIndex.map { case (attr, i) =>
359356
BoundReference(i, attr.dataType, attr.nullable)
360357
}
358+
val evaluateInputs = evaluateVariables(input)
361359
// generate the code to create a UnsafeRow
362360
ctx.currentVars = input
363361
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
364362
s"""
365-
|${evaluateVariables(input)}
363+
|$evaluateInputs
366364
|${code.code.trim}
367365
|currentRows.add(${code.value}.copy());
368366
""".stripMargin

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,7 @@ case class TungstenAggregate(
116116
// all the mode of aggregate expressions
117117
private val modes = aggregateExpressions.map(_.mode).distinct
118118

119-
override def references: AttributeSet = {
120-
AttributeSet(groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap {
121-
case AggregateExpression(f, Final | PartialMerge, _) => f.inputAggBufferAttributes
122-
case AggregateExpression(f, Partial | Complete, _) => f.references
123-
})
124-
child.outputSet
125-
}
119+
override def usedInputs: AttributeSet = inputSet
126120

127121
override def supportCodegen: Boolean = {
128122
// ImperativeAggregate is not supported right now
@@ -387,25 +381,28 @@ case class TungstenAggregate(
387381
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
388382
BoundReference(i, e.dataType, e.nullable).gen(ctx)
389383
}
384+
val evaluateKeyVars = evaluateVariables(keyVars)
390385
ctx.INPUT_ROW = bufferTerm
391386
val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
392387
BoundReference(i, e.dataType, e.nullable).gen(ctx)
393388
}
389+
val evaluateBufferVars = evaluateVariables(bufferVars)
394390
// evaluate the aggregation result
395391
ctx.currentVars = bufferVars
396392
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
397393
BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
398394
}
395+
val evaluateAggResults = evaluateVariables(aggResults)
399396
// generate the final result
400397
ctx.currentVars = keyVars ++ aggResults
401398
val inputAttrs = groupingAttributes ++ aggregateAttributes
402399
val resultVars = resultExpressions.map { e =>
403400
BindReferences.bindReference(e, inputAttrs).gen(ctx)
404401
}
405402
s"""
406-
${evaluateVariables(keyVars)}
407-
${evaluateVariables(bufferVars)}
408-
${evaluateVariables(aggResults)}
403+
$evaluateKeyVars
404+
$evaluateBufferVars
405+
$evaluateAggResults
409406
${consume(ctx, resultVars)}
410407
"""
411408

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
3939
child.asInstanceOf[CodegenSupport].produce(ctx, this)
4040
}
4141

42+
override def usedInputs: AttributeSet = {
43+
// only the attributes those are used at least twice should be evaluated before this plan,
44+
// otherwise we could defer the evaluation until output attribute is actually used.
45+
val usedExprIds = projectList.flatMap(_.collect {
46+
case a: Attribute => a.exprId
47+
})
48+
val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
49+
references.filter(a => usedMoreThanOnce.contains(a.exprId))
50+
}
51+
4252
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
4353
val exprs = projectList.map(x =>
4454
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))

0 commit comments

Comments
 (0)