From c92e45717cc2e23e5a9f552f99b7b1233d161f8b Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 19 Feb 2016 00:11:17 -0800
Subject: [PATCH 1/7] improve codegen

---
 .../catalyst/expressions/BoundAttribute.scala |  6 +-
 .../apache/spark/sql/execution/Expand.scala   |  4 +-
 .../sql/execution/WholeStageCodegen.scala     | 81 +++++++++++++------
 .../aggregate/TungstenAggregate.scala         | 52 ++++++------
 .../spark/sql/execution/basicOperators.scala  | 22 +++--
 .../execution/joins/BroadcastHashJoin.scala   | 51 ++++++------
 6 files changed, 125 insertions(+), 91 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4727ff1885ad7..62e4e3082f5cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,8 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
     val javaType = ctx.javaType(dataType)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
-      ev.isNull = ctx.currentVars(ordinal).isNull
-      ev.value = ctx.currentVars(ordinal).value
+      val oev = ctx.currentVars(ordinal)
+      // assert(oev.code == "", s"$this has not been evaluated yet.")
+      ev.isNull = oev.isNull
+      ev.value = oev.value
       ""
     } else if (nullable) {
       s"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index d26a0b74674a6..f3dd7a7ef2f5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -187,8 +187,10 @@ case class Expand(
     val numOutput = metricTerm(ctx, "numOutputRows")
     val i = ctx.freshName("i")
+    // these columns have to be declared before the loop.
+    val evaluate = evaluateVariables(outputColumns)
     s"""
-      |${outputColumns.map(_.code).mkString("\n").trim}
+      |$evaluate
      |for (int $i = 0; $i < ${projections.length}; $i ++) {
      |  switch ($i) {
      |    ${cases.mkString("\n").trim}
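The Expand hunk above leans on the evaluateVariables helper introduced in the next file: the columns are materialized once, before the loop, so that every switch branch can reference the same generated variables. A minimal sketch of that contract, using simplified stand-ins rather than the real Spark classes (the ExprCode here is a toy, and the generated statement is a hypothetical example):

    case class ExprCode(var code: String, isNull: String, value: String)

    object ExpandSketch {
      def evaluateVariables(variables: Seq[ExprCode]): String = {
        val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
        variables.foreach(_.code = "")
        evaluate
      }

      def main(args: Array[String]): Unit = {
        val col = ExprCode("int project_value1 = input.getInt(0);", "false", "project_value1")
        // First call emits the declaration; in Expand this lands before the loop,
        // so every switch branch can reference project_value1.
        println(evaluateVariables(Seq(col)))
        // Later calls emit nothing: the variable is only ever evaluated once.
        println(evaluateVariables(Seq(col)).isEmpty)  // true
      }
    }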
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 8626f54eb413c..b5fada38f5cb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -76,7 +76,10 @@ trait CodegenSupport extends SparkPlan {
   def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
     this.parent = parent
     ctx.freshNamePrefix = variablePrefix
-    doProduce(ctx)
+    s"""
+      |/*** PRODUCE: ${commentSafe(this.simpleString)} */
+      |${doProduce(ctx)}
+    """.stripMargin
   }

   /**
@@ -108,6 +111,38 @@ trait CodegenSupport extends SparkPlan {
     parent.consumeChild(ctx, this, input, row)
   }

+  /**
+   * Returns source code to evaluate all the variables, and clear their code, to prevent
+   * them from being evaluated twice.
+   */
+  protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
+    variables.foreach(_.code = "")
+    evaluate
+  }
+
+  /**
+   * Returns source code to evaluate the variables for required attributes, and clear the code
+   * of evaluated variables, to prevent them from being evaluated twice.
+   */
+  protected def evaluateRequiredVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      required: AttributeSet): String = {
+    var evaluateVars = ""
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      if (ev.code != "" && required.contains(attributes(i))) {
+        evaluateVars += ev.code.trim + "\n"
+        ev.code = ""
+      }
+    }
+    evaluateVars
+  }
+
+  protected def commentSafe(s: String): String = {
+    s.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
+  }
+
   /**
    * Consume the columns generated from it's child, call doConsume() or emit the rows.
    */
@@ -117,19 +152,22 @@ trait CodegenSupport extends SparkPlan {
       input: Seq[ExprCode],
       row: String = null): String = {
     ctx.freshNamePrefix = variablePrefix
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+    val inputVars =
+      if (row != null) {
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        child.output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        }
+      } else {
+        input
       }
-      s"""
-        | ${evals.map(_.code).mkString("\n")}
-        | ${doConsume(ctx, evals)}
-      """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
+    s"""
+      |
+      |/*** CONSUME: ${commentSafe(this.simpleString)} */
+      |${evaluateRequiredVariables(child.output, inputVars, references)}
+      |${doConsume(ctx, inputVars)}
+    """.stripMargin
   }

@@ -183,13 +221,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-      | while (input.hasNext()) {
+      | while (!shouldStop() && input.hasNext()) {
      |   InternalRow $row = (InternalRow) input.next();
-      |   ${columns.map(_.code).mkString("\n").trim}
      |   ${consume(ctx, columns).trim}
-      |   if (shouldStop()) {
-      |     return;
-      |   }
      | }
    """.stripMargin
  }

@@ -251,7 +285,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
      }

      /** Codegened pipeline for:
-       * ${plan.treeString.trim}
+       * ${commentSafe(plan.treeString.trim)}
       */
      class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

@@ -305,7 +339,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     if (row != null) {
       // There is an UnsafeRow already
       s"""
-        | currentRows.add($row.copy());
+        |currentRows.add($row.copy());
       """.stripMargin
     } else {
       assert(input != null)
@@ -317,13 +351,14 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       ctx.currentVars = input
       val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
       s"""
-        | ${code.code.trim}
-        | currentRows.add(${code.value}.copy());
+        |${evaluateVariables(input)}
+        |${code.code.trim}
+        |currentRows.add(${code.value}.copy());
       """.stripMargin
     } else {
       // There is no columns
       s"""
-        | currentRows.add(unsafeRow);
+        |currentRows.add(unsafeRow);
       """.stripMargin
     }
   }
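The commentSafe helper added above exists because plan descriptions get spliced into /*** PRODUCE ... */ comments in the generated Java source. A literal "*/" in a plan string would terminate the comment early, and a malformed \u escape is rejected by javac even inside comments. A small sketch of the escaping (a toy standalone copy, not the Spark trait):

    object CommentSafeSketch {
      def commentSafe(s: String): String =
        s.replace("*/", "\\*\\/").replace("\\u", "\\\\u")

      def main(args: Array[String]): Unit = {
        // "*/" inside the splice would otherwise terminate the generated comment:
        println(commentSafe("Filter (a */ b)"))   // Filter (a \*\/ b)
        // javac rejects a malformed \u escape even inside a comment:
        println(commentSafe("LIKE '%\\u12%'"))    // backslash doubled before the u
      }
    }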
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 852203f3743dc..f596a1a06f7ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,6 +116,14 @@ case class TungstenAggregate(
   // all the mode of aggregate expressions
   private val modes = aggregateExpressions.map(_.mode).distinct

+  override def references: AttributeSet = {
+    AttributeSet(groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap {
+      case AggregateExpression(f, Final | PartialMerge, _) => f.inputAggBufferAttributes
+      case AggregateExpression(f, Partial | Complete, _) => f.references
+    }) ++ child.outputSet
+  }
+
   override def supportCodegen: Boolean = {
     // ImperativeAggregate is not supported right now
     !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,14 +172,14 @@ case class TungstenAggregate(
       """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }
+    val initBufVar = evaluateVariables(bufVars)

     // generate variables for output
-    val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
       val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
       // evaluate result expressions
       ctx.currentVars = aggResults
@@ -179,8 +187,8 @@ case class TungstenAggregate(
         BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
       }
       (resultVars, s"""
-        | ${aggResults.map(_.code).mkString("\n")}
-        | ${resultVars.map(_.code).mkString("\n")}
+        | ${evaluateVariables(aggResults)}
+        | ${evaluateVariables(resultVars)}
       """.stripMargin)
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
       // output the aggregate buffer directly
@@ -188,7 +196,7 @@ case class TungstenAggregate(
     } else {
       // no aggregate function, the result should be literals
       val resultVars = resultExpressions.map(_.gen(ctx))
-      (resultVars, resultVars.map(_.code).mkString("\n"))
+      (resultVars, evaluateVariables(resultVars))
     }

     val doAgg = ctx.freshName("doAggregateWithoutKey")
@@ -196,7 +204,7 @@ case class TungstenAggregate(
     s"""
       | private void $doAgg() throws java.io.IOException {
       |   // initialize aggregation buffer
-      |   ${bufVars.map(_.code).mkString("\n")}
+      |   $initBufVar
      |
      |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
      | }
    """.stripMargin

     val numOutput = metricTerm(ctx, "numOutputRows")
     s"""
-      | if (!$initAgg) {
+      | while (!$initAgg) {
      |   $initAgg = true;
      |   $doAgg();
      |
@@ -241,7 +249,7 @@ case class TungstenAggregate(
     }
     s"""
      | // do aggregate
-      | ${aggVals.map(_.code).mkString("\n").trim}
+      | ${evaluateVariables(aggVals)}
      | // update aggregation buffer
      | ${updates.mkString("\n").trim}
    """.stripMargin
@@ -252,8 +260,7 @@ case class TungstenAggregate(
   private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
     .filter(_.isInstanceOf[DeclarativeAggregate])
     .map(_.asInstanceOf[DeclarativeAggregate])
-  private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
-  private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+  private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)

   // The name for HashMap
   private var hashMapTerm: String = _
@@ -318,7 +325,7 @@ case class TungstenAggregate(
     val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
     val mergeProjection = newMutableProjection(
       mergeExpr,
-      bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+      aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
       subexpressionEliminationEnabled)()
     val joinedRow = new JoinedRow()

@@ -381,13 +388,13 @@ case class TungstenAggregate(
       BoundReference(i, e.dataType, e.nullable).gen(ctx)
     }
     ctx.INPUT_ROW = bufferTerm
-    val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+    val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
       BoundReference(i, e.dataType, e.nullable).gen(ctx)
     }
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
     val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-      BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+      BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
     }
     // generate the final result
     ctx.currentVars = keyVars ++ aggResults
@@ -396,11 +403,9 @@ case class TungstenAggregate(
       BindReferences.bindReference(e, inputAttrs).gen(ctx)
     }
     s"""
-     ${keyVars.map(_.code).mkString("\n")}
-     ${bufferVars.map(_.code).mkString("\n")}
-     ${aggResults.map(_.code).mkString("\n")}
-     ${resultVars.map(_.code).mkString("\n")}
-
+     ${evaluateVariables(keyVars)}
+     ${evaluateVariables(bufferVars)}
+     ${evaluateVariables(aggResults)}
     ${consume(ctx, resultVars)}
     """

@@ -422,10 +427,7 @@ case class TungstenAggregate(
     val eval = resultExpressions.map{ e =>
       BindReferences.bindReference(e, groupingAttributes).gen(ctx)
     }
-    s"""
-     ${eval.map(_.code).mkString("\n")}
-     ${consume(ctx, eval)}
-    """
+    consume(ctx, eval)
   }
 }

@@ -508,8 +510,8 @@ case class TungstenAggregate(
     ctx.currentVars = input
     val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)

-    val inputAttr = bufferAttributes ++ child.output
-    ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+    val inputAttr = aggregateBufferAttributes ++ child.output
+    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
     ctx.INPUT_ROW = buffer
     // TODO: support subexpression elimination
     val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +559,7 @@ case class TungstenAggregate(
     $incCounter

     // evaluate aggregate function
-    ${evals.map(_.code).mkString("\n").trim}
+    ${evaluateVariables(evals)}
     // update aggregate buffer
     ${updates.mkString("\n").trim}
   """
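A detail worth spelling out in the hash-aggregate hunks above: the buffer slots passed through ctx.currentVars can be ExprCode values whose statements were already emitted as mutable class fields, so their `code` is empty and updates are generated straight against the field names. A toy illustration under that assumption (the field names here are hypothetical, not the generated ones):

    case class ExprCode(var code: String, isNull: String, value: String)

    object AggBufferSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical names for generated class fields backing the buffer.
        val bufVar = ExprCode(code = "", isNull = "agg_bufIsNull", value = "agg_bufValue")
        // With an empty `code`, evaluateVariables would emit nothing for this
        // variable; updates are generated directly against the field names:
        val update = s"${bufVar.value} = ${bufVar.value} + agg_input;"
        println(update)  // agg_bufValue = agg_bufValue + agg_input;
      }
    }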
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4b82d5563460b..ab42672d3949f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -43,11 +43,12 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
-    val output = exprs.map(_.gen(ctx))
+    val resultVars = exprs.map(_.gen(ctx))
+    // Evaluation of non-deterministic expressions can't be deferred.
+    val nonDeterministicAttrs = projectList.zip(output).filter(!_._1.deterministic).unzip._2
     s"""
-      | ${output.map(_.code).mkString("\n")}
-      |
-      | ${consume(ctx, output)}
+      |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
+      |${consume(ctx, resultVars)}
     """.stripMargin
   }

@@ -89,11 +90,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
       s""
     }
     s"""
-      | ${eval.code}
-      | if ($nullCheck ${eval.value}) {
-      |   $numOutput.add(1);
-      |   ${consume(ctx, ctx.currentVars)}
-      | }
+      |${eval.code}
+      |if (!($nullCheck ${eval.value})) continue;
+      |$numOutput.add(1);
+      |${consume(ctx, ctx.currentVars)}
     """.stripMargin
   }

@@ -224,15 +224,13 @@ case class Range(
      |   }
      | }
      |
-      | while (!$overflow && $checkEnd) {
+      | while (!$overflow && $checkEnd && !shouldStop()) {
      |   long $value = $number;
      |   $number += ${step}L;
      |   if ($number < $value ^ ${step}L < 0) {
      |     $overflow = true;
      |   }
      |   ${consume(ctx, Seq(ev))}
-      |
-      |   if (shouldStop()) return;
      | }
    """.stripMargin
  }
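The Project hunk above forces non-deterministic expressions (rand() and friends) to be evaluated right where the operator runs: a deferred copy could end up behind a downstream `continue`, running for fewer rows or at a different point, which silently changes its semantics. (Patch 6 later simplifies the attribute computation to projectList.filterNot(_.deterministic).map(_.toAttribute).) A sketch of the pinning, with simplified stand-ins rather than the real Spark classes:

    case class ExprCode(var code: String, isNull: String, value: String)

    object NonDeterministicSketch {
      def evaluateRequiredVariables(required: Set[String], vars: Map[String, ExprCode]): String = {
        val out = new StringBuilder
        for ((name, ev) <- vars if ev.code != "" && required.contains(name)) {
          out.append(ev.code.trim).append("\n")
          ev.code = ""
        }
        out.toString
      }

      def main(args: Array[String]): Unit = {
        val rand = ExprCode("double project_value2 = rng.nextDouble();", "false", "project_value2")
        // Forcing rand() into the required set pins its evaluation to this exact
        // point, once per row, before any later filtering can skip it.
        println(evaluateRequiredVariables(Set("rand"), Map("rand" -> rand)))
      }
    }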
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index a64da225800a3..cde525b80f836 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -235,22 +235,17 @@ case class BroadcastHashJoin(
     }
     val numOutput = metricTerm(ctx, "numOutputRows")

-    val outputCode = if (condition.isDefined) {
+    val checkCondition = if (condition.isDefined) {
       // filter the output via condition
       ctx.currentVars = resultVars
       val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
       s"""
+        |${evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)}
        |${ev.code}
-       |if (!${ev.isNull} && ${ev.value}) {
-       |  $numOutput.add(1);
-       |  ${consume(ctx, resultVars)}
-       |}
+       |if (${ev.isNull} || !${ev.value}) continue;
      """.stripMargin
     } else {
-      s"""
-        |$numOutput.add(1);
-        |${consume(ctx, resultVars)}
-      """.stripMargin
+      ""
     }

     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
        |${keyEv.code}
        |// find matches from HashedRelation
        |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-       |if ($matched != null) {
-       |  ${buildVars.map(_.code).mkString("\n")}
-       |  $outputCode
-       |}
+       |if ($matched == null) continue;
+       |$checkCondition
+       |$numOutput.add(1);
+       |${consume(ctx, resultVars)}
      """.stripMargin

     } else {
       s"""
        |${keyEv.code}
        |// find matches from HashRelation
        |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
-       |if ($matches != null) {
-       |  int $size = $matches.size();
-       |  for (int $i = 0; $i < $size; $i++) {
-       |    UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
-       |    ${buildVars.map(_.code).mkString("\n")}
-       |    $outputCode
-       |  }
+       |if ($matches == null) continue;
+       |int $size = $matches.size();
+       |for (int $i = 0; $i < $size; $i++) {
+       |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+       |  $checkCondition
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars)}
        |}
      """.stripMargin
     }
   }

@@ -309,6 +304,7 @@ case class BroadcastHashJoin(
       val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
       s"""
        |boolean $conditionPassed = true;
+       |${evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)}
        |if ($matched != null) {
        |  ${ev.code}
        |  $conditionPassed = !${ev.isNull} && ${ev.value};
        |}
      """.stripMargin
@@ -324,11 +320,12 @@ case class BroadcastHashJoin(
        |${keyEv.code}
        |// find matches from HashedRelation
        |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-       |${buildVars.map(_.code).mkString("\n")}
        |${checkCondition.trim}
        |if (!$conditionPassed) {
        |  // reset to null
-       |  ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+       |  $matched = null;
+       |  // reset the variables that are already evaluated.
+       |  ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
        |}
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}
@@ -350,13 +347,11 @@ case class BroadcastHashJoin(
        |// the last iteration of this loop is to emit an empty row if there is no matched rows.
        |for (int $i = 0; $i <= $size; $i++) {
        |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
-       |  ${buildVars.map(_.code).mkString("\n")}
        |  ${checkCondition.trim}
-       |  if ($conditionPassed && ($i < $size || !$found)) {
-       |    $found = true;
-       |    $numOutput.add(1);
-       |    ${consume(ctx, resultVars)}
-       |  }
+       |  if (!$conditionPassed || ($i == $size && $found)) continue;
+       |  $found = true;
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars)}
        |}
      """.stripMargin
   }
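A pattern running through all the join hunks in this patch: filtering is rewritten from nested `if` blocks into guard-style `continue` statements, so the generated code stays flat no matter how many checks parent operators stack on top. A sketch of the resulting shape, in the document's own string-template idiom (the identifiers are hypothetical, not the actual generated names):

    object GuardStyleSketch {
      def main(args: Array[String]): Unit = {
        val matched = "bhj_matched"
        val numOutput = "bhj_numOutput"
        println(
          s"""
             |while (input.hasNext()) {
             |  InternalRow row = (InternalRow) input.next();
             |  // ... compute the join key ...
             |  UnsafeRow $matched = (UnsafeRow) relation.getValue(key);
             |  if ($matched == null) continue;   // guard, instead of nesting an if
             |  // ... optional condition check, also ending in `continue` ...
             |  $numOutput.add(1);
             |  // consume() splices the parent operator's code here, at top level
             |}
           """.stripMargin)
      }
    }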
From f6139e6c43303206730d144c147bd45adf1277fb Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 22 Feb 2016 13:30:03 -0800
Subject: [PATCH 2/7] Defer the evaluation of expressions in Project

---
 .../catalyst/expressions/BoundAttribute.scala |  3 +--
 .../sql/execution/WholeStageCodegen.scala     | 20 +++++++++----------
 .../aggregate/TungstenAggregate.scala         | 17 +++++++---------
 .../spark/sql/execution/basicOperators.scala  | 10 ++++++++++
 4 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 62e4e3082f5cc..72fe06545910c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -63,10 +63,9 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
       val oev = ctx.currentVars(ordinal)
-      // assert(oev.code == "", s"$this has not been evaluated yet.")
       ev.isNull = oev.isNull
       ev.value = oev.value
-      ""
+      oev.code
     } else if (nullable) {
       s"""
         boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index fb387571780ca..3c046a5ac60ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -80,7 +80,7 @@ trait CodegenSupport extends SparkPlan {
     ctx.freshNamePrefix = variablePrefix
     waitForSubqueries()
     s"""
-      |/*** PRODUCE: ${commentSafe(this.simpleString)} */
+      |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
      |${doProduce(ctx)}
    """.stripMargin
  }

@@ -142,9 +142,10 @@ trait CodegenSupport extends SparkPlan {
     evaluateVars
   }

-  protected def commentSafe(s: String): String = {
-    s.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
-  }
+  /**
+   * The subset of inputSet that should be evaluated before this plan.
+   */
+  def usedInputs: AttributeSet = references

   /**
    * Consume the columns generated from it's child, call doConsume() or emit the rows.
    */
@@ -167,8 +168,8 @@ trait CodegenSupport extends SparkPlan {
     }
     s"""
      |
-      |/*** CONSUME: ${commentSafe(this.simpleString)} */
-      |${evaluateRequiredVariables(child.output, inputVars, references)}
+      |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+      |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
      |${doConsume(ctx, inputVars)}
    """.stripMargin
  }

@@ -292,7 +293,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
      }

      /** Codegened pipeline for:
-       * ${commentSafe(plan.treeString.trim)}
+       * ${toCommentSafeString(plan.treeString.trim)}
       */
      class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

@@ -358,11 +355,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       val colExprs = output.zipWithIndex.map { case (attr, i) =>
         BoundReference(i, attr.dataType, attr.nullable)
       }
+      val evaluateInputs = evaluateVariables(input)
       // generate the code to create a UnsafeRow
       ctx.currentVars = input
       val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
       s"""
-        |${evaluateVariables(input)}
+        |$evaluateInputs
        |${code.code.trim}
        |currentRows.add(${code.value}.copy());
       """.stripMargin
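The BoundAttribute change at the top of this patch is the heart of the deferral: gen() now returns the pending statements (oev.code) instead of "", so whichever operator splices that code into the generated source becomes responsible for emitting it exactly once. A toy sketch of the handoff, assuming the simplified ExprCode from earlier (not the real Spark classes):

    case class ExprCode(var code: String, isNull: String, value: String)

    object DeferredBoundRefSketch {
      // Stand-in for BoundReference.gen() after the change: pass the pending
      // statements through instead of returning "".
      def gen(currentVar: ExprCode): ExprCode =
        ExprCode(currentVar.code, currentVar.isNull, currentVar.value)

      def main(args: Array[String]): Unit = {
        val pending = ExprCode("int scan_value3 = row.getInt(0);", "false", "scan_value3")
        val ev = gen(pending)
        // Whoever splices ev.code into the generated source now owns emitting it
        // exactly once before scan_value3 is referenced.
        println(ev.code)
      }
    }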
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f596a1a06f7ea..1f126e22a0b85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,13 +116,7 @@ case class TungstenAggregate(
   // all the mode of aggregate expressions
   private val modes = aggregateExpressions.map(_.mode).distinct

-  override def references: AttributeSet = {
-    AttributeSet(groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap {
-      case AggregateExpression(f, Final | PartialMerge, _) => f.inputAggBufferAttributes
-      case AggregateExpression(f, Partial | Complete, _) => f.references
-    }) ++ child.outputSet
-  }
+  override def usedInputs: AttributeSet = inputSet

   override def supportCodegen: Boolean = {
     // ImperativeAggregate is not supported right now
@@ -387,15 +381,18 @@ case class TungstenAggregate(
     val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
       BoundReference(i, e.dataType, e.nullable).gen(ctx)
     }
+    val evaluateKeyVars = evaluateVariables(keyVars)
     ctx.INPUT_ROW = bufferTerm
     val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
       BoundReference(i, e.dataType, e.nullable).gen(ctx)
     }
+    val evaluateBufferVars = evaluateVariables(bufferVars)
     // evaluate the aggregation result
     ctx.currentVars = bufferVars
     val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
       BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
     }
+    val evaluateAggResults = evaluateVariables(aggResults)
     // generate the final result
     ctx.currentVars = keyVars ++ aggResults
     val inputAttrs = groupingAttributes ++ aggregateAttributes
     val resultVars = resultExpressions.map { e =>
       BindReferences.bindReference(e, inputAttrs).gen(ctx)
     }
     s"""
-     ${evaluateVariables(keyVars)}
-     ${evaluateVariables(bufferVars)}
-     ${evaluateVariables(aggResults)}
+     $evaluateKeyVars
+     $evaluateBufferVars
+     $evaluateAggResults
     ${consume(ctx, resultVars)}
     """

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 288ef85d8bda2..a9105876c601b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,6 +39,16 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }

+  override def usedInputs: AttributeSet = {
+    // only the attributes that are used at least twice should be evaluated before this plan,
+    // otherwise we could defer the evaluation until the output attribute is actually used.
+    val usedExprIds = projectList.flatMap(_.collect {
+      case a: Attribute => a.exprId
+    })
+    val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+    references.filter(a => usedMoreThanOnce.contains(a.exprId))
+  }
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
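The usedInputs override added to Project above encodes a simple cost rule: an input column referenced by two or more project expressions is materialized up front (evaluating it lazily in each expression would duplicate work), while a column used once stays deferred. A small sketch of the counting, modelling each project expression as just the list of input names it references:

    object UsedInputsSketch {
      // Model each project expression as the set of input column names it uses.
      def usedMoreThanOnce(projectList: Seq[Seq[String]]): Set[String] = {
        val uses = projectList.flatten
        uses.groupBy(identity).filter(_._2.size > 1).keySet
      }

      def main(args: Array[String]): Unit = {
        // SELECT a + b, a * 2, c: `a` is referenced twice, `b` and `c` once each,
        // so only `a` is evaluated up front; `b` and `c` stay deferred.
        println(usedMoreThanOnce(Seq(Seq("a", "b"), Seq("a"), Seq("c"))))  // Set(a)
      }
    }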
From 4fb0bc885f56dae8ea60c1b255e43cb51c486cb3 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 23 Feb 2016 16:14:54 -0800
Subject: [PATCH 3/7] fix broadcast hash join

---
 .../execution/joins/BroadcastHashJoin.scala   | 32 +++++-----
 .../sql/execution/joins/SortMergeJoin.scala   | 63 ++++++++++---------
 2 files changed, 52 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 6522b2803adf3..b4c43a0bf1c47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -190,18 +190,14 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")

     val checkCondition = if (condition.isDefined) {
       // filter the output via condition
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(
+        condition.get, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
-        |${evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)}
        |${ev.code}
        |if (${ev.isNull} || !${ev.value}) continue;
      """.stripMargin
     } else {
       ""
     }

+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
        |// generate join key for stream side
@@ -252,22 +252,20 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")

     // filter the output via condition
     val conditionPassed = ctx.freshName("conditionPassed")
     val checkCondition = if (condition.isDefined) {
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(condition.get,
+        streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
        |boolean $conditionPassed = true;
-       |${evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)}
+       |${eval.trim}
+       |${ev.code}
        |if ($matched != null) {
-       |  ${ev.code}
        |  $conditionPassed = !${ev.isNull} && ${ev.value};
        |}
      """.stripMargin
     } else {
       s"final boolean $conditionPassed = true;"
     }

+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
        |// generate join key for stream side
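The fix above is about binding order: the generated variables are laid out as (stream side ++ build side), so the condition must be bound against streamedPlan.output ++ buildPlan.output. Binding it against this.output, which is build ++ stream when buildSide == BuildLeft, makes each ordinal point at the wrong variable. A tiny illustration with names standing in for attributes:

    object BindingOrderSketch {
      def main(args: Array[String]): Unit = {
        val stream = Seq("s_key", "s_val")
        val build = Seq("b_key", "b_val")
        val currentVars = stream ++ build   // slot layout the condition code sees
        val thisOutput = build ++ stream    // this.output when buildSide == BuildLeft
        // Binding against this.output resolves b_key to slot 0, which actually
        // holds the stream-side column:
        println(currentVars(thisOutput.indexOf("b_key")))         // s_key (wrong)
        // Binding against stream ++ build resolves every name to its real slot:
        println(currentVars((stream ++ build).indexOf("b_key")))  // b_key
      }
    }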
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 7ec4027188f14..cffd6f6032f2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -306,11 +306,11 @@ case class SortMergeJoin(
       val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
         condRefs.contains(a)
       }
-      val beforeCond = used.map(_._2.code).mkString("\n")
-      val afterCond = notUsed.map(_._2.code).mkString("\n")
+      val beforeCond = evaluateVariables(used.map(_._2))
+      val afterCond = evaluateVariables(notUsed.map(_._2))
       (beforeCond, afterCond)
     } else {
-      (variables.map(_.code).mkString("\n"), "")
+      (evaluateVariables(variables), "")
     }
   }

@@ -326,41 +326,48 @@ case class SortMergeJoin(
     val leftVars = createLeftVars(ctx, leftRow)
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
-    val resultVars = leftVars ++ rightVars
-
-    // Check condition
-    ctx.currentVars = resultVars
-    val cond = if (condition.isDefined) {
-      BindReferences.bindReference(condition.get, output).gen(ctx)
-    } else {
-      ExprCode("", "false", "true")
-    }
-    // Split the code of creating variables based on whether it's used by condition or not.
-    val loaded = ctx.freshName("loaded")
-    val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
-    val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-
     val size = ctx.freshName("size")
     val i = ctx.freshName("i")
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val (beforeLoop, condCheck) = if (condition.isDefined) {
+      // Split the code of creating variables based on whether it's used by condition or not.
+      val loaded = ctx.freshName("loaded")
+      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+      // Generate code for condition
+      ctx.currentVars = leftVars ++ rightVars
+      val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+      // evaluate the columns that are used by the condition before the loop
+      val before = s"""
+        |boolean $loaded = false;
+        |$leftBefore
      """.stripMargin
+
+      val checking = s"""
+        |$rightBefore
+        |${cond.code}
+        |if (${cond.isNull} || !${cond.value}) continue;
+        |if (!$loaded) {
+        |  $loaded = true;
+        |  $leftAfter
+        |}
+        |$rightAfter
      """.stripMargin
+      (before, checking)
+    } else {
+      (evaluateVariables(leftVars), "")
+    }
+
     s"""
      |while (findNextInnerJoinRows($leftInput, $rightInput)) {
      |  int $size = $matches.size();
-     |  boolean $loaded = false;
-     |  $leftBefore
+     |  ${beforeLoop.trim}
      |  for (int $i = 0; $i < $size; $i ++) {
      |    InternalRow $rightRow = (InternalRow) $matches.get($i);
-     |    $rightBefore
-     |    ${cond.code}
-     |    if (${cond.isNull} || !${cond.value}) continue;
-     |    if (!$loaded) {
-     |      $loaded = true;
-     |      $leftAfter
-     |    }
-     |    $rightAfter
+     |    ${condCheck.trim}
      |    $numOutput.add(1);
-     |    ${consume(ctx, resultVars)}
+     |    ${consume(ctx, leftVars ++ rightVars)}
      |  }
      |  if (shouldStop()) return;
      |}
    """.stripMargin
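The SortMergeJoin hunk above splits the left-side columns in two: columns the condition touches are built before the inner loop, while the rest are built lazily, guarded by a one-shot `loaded` flag, so they are materialized at most once per left row and only if some right row actually passes the condition. A simplified driver showing the effect of the flag (a sketch, not the generated Java):

    object LazyLeftSideSketch {
      def main(args: Array[String]): Unit = {
        var loaded = false
        var materializations = 0
        val passesCondition = Seq(true, false, true)  // per right-hand row
        for (ok <- passesCondition if ok) {
          if (!loaded) {          // one-shot guard, like the generated `loaded` flag
            loaded = true
            materializations += 1 // stands in for running $leftAfter
          }
          // ... emit the joined row using the now-materialized left columns ...
        }
        println(materializations)  // 1, even though two right rows matched
      }
    }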
From ef9e8f3558b20f6a4a7c94eb7c46a8f435694ad4 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 24 Feb 2016 16:22:14 -0800
Subject: [PATCH 4/7] fix bug

---
 .../sql/execution/joins/BroadcastHashJoin.scala | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index b4c43a0bf1c47..c52662a61e7f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -193,11 +193,14 @@ case class BroadcastHashJoin(
     val numOutput = metricTerm(ctx, "numOutputRows")

     val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
       // filter the output via condition
       ctx.currentVars = input ++ buildVars
-      val ev = BindReferences.bindReference(
-        condition.get, streamedPlan.output ++ buildPlan.output).gen(ctx)
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
+        |$eval
        |${ev.code}
        |if (${ev.isNull} || !${ev.value}) continue;
      """.stripMargin
     } else {
       ""
     }
@@ -257,10 +260,11 @@ case class BroadcastHashJoin(
     // filter the output via condition
     val conditionPassed = ctx.freshName("conditionPassed")
     val checkCondition = if (condition.isDefined) {
-      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, condition.get.references)
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
       ctx.currentVars = input ++ buildVars
-      val ev = BindReferences.bindReference(condition.get,
-        streamedPlan.output ++ buildPlan.output).gen(ctx)
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
        |boolean $conditionPassed = true;
        |${eval.trim}
        |${ev.code}
        |if ($matched != null) {
        |  $conditionPassed = !${ev.isNull} && ${ev.value};
        |}
      """.stripMargin
@@ -285,7 +289,6 @@ case class BroadcastHashJoin(
        |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
        |${checkCondition.trim}
        |if (!$conditionPassed) {
-       |  // reset to null
        |  $matched = null;
        |  // reset the variables that are already evaluated.
        |  ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
        |}
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}

From ca8fe0f5f55cabb1bb5903c3e85c150b31eaa7c7 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 25 Feb 2016 00:26:58 -0800
Subject: [PATCH 5/7] fix aggregate

---
 .../scala/org/apache/spark/sql/execution/ExistingRDD.scala | 1 -
 .../spark/sql/execution/aggregate/TungstenAggregate.scala  | 5 +++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 8649d2d69b62e..3f9e9a4b24913 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -152,7 +152,6 @@ private[sql] case class PhysicalRDD(
      | while ($input.hasNext()) {
      |   InternalRow $row = (InternalRow) $input.next();
      |   $numOutputRows.add(1);
-      |   ${columns.map(_.code).mkString("\n").trim}
      |   ${consume(ctx, columns).trim}
      |   if (shouldStop()) {
      |     return;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 89680c4eb6ee0..f07add83d5849 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -175,14 +175,15 @@ case class TungstenAggregate(
       val aggResults = functions.map(_.evaluateExpression).map { e =>
         BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
+      val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
       val resultVars = resultExpressions.map { e =>
         BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
       }
       (resultVars, s"""
-        | ${evaluateVariables(aggResults)}
-        | ${evaluateVariables(resultVars)}
+        |$evaluateAggResults
+        |${evaluateVariables(resultVars)}
      """.stripMargin)
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
       // output the aggregate buffer directly
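The "fix aggregate" change above is an ordering subtlety of the mutable ExprCode contract: evaluateVariables(aggResults) has to be captured before ctx.currentVars = aggResults and the generation of the result expressions, because that later generation consumes and clears the same `code` fields. A sketch of the hoisting, with the simplified stand-ins used earlier:

    case class ExprCode(var code: String, isNull: String, value: String)

    object EvaluationOrderSketch {
      def evaluateVariables(variables: Seq[ExprCode]): String = {
        val out = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
        variables.foreach(_.code = "")
        out
      }

      def main(args: Array[String]): Unit = {
        val agg = ExprCode("long agg_value4 = agg_sumField;", "false", "agg_value4")
        // Capture the buffer evaluation first ...
        val evaluateAggResults = evaluateVariables(Seq(agg))
        // ... then let downstream generation consume `agg` with code == "";
        // the statements can no longer be dropped or emitted in the wrong place.
        println(evaluateAggResults)
      }
    }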
From ffc9d8ca90192df2e2cff0b8ed4900b8c97601b0 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 4 Mar 2016 21:18:23 -0800
Subject: [PATCH 6/7] fix tests

---
 .../spark/sql/execution/ExistingRDD.scala     | 30 +++++++++++--------
 .../sql/execution/WholeStageCodegen.scala     |  3 +-
 .../spark/sql/execution/basicOperators.scala  | 11 +++++--
 3 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index a73fffbd8dd1b..0a132beca9a6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,9 +151,6 @@ private[sql] case class PhysicalRDD(
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))

     // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
     // by looking at the first value of the RDD and then calling the function which will process
     // the remaining. It is faster to return batches.
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
-
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns1 = exprs.map(_.gen(ctx))
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
      | private void $scanBatches() throws java.io.IOException {
      |  while (true) {
      |     int numRows = $batch.numRows();
      |     if ($idx == 0) $numOutputRows.add(numRows);
      |
-      |     while ($idx < numRows) {
+      |     while (!shouldStop() && $idx < numRows) {
      |       InternalRow $row = $batch.getRow($idx++);
-      |       ${consume(ctx, columns).trim}
-      |       if (shouldStop()) return;
+      |       ${consume(ctx, columns1).trim}
      |     }
+      |     if (shouldStop()) return;
      |
      |     if (!$input.hasNext()) {
      |       $batch = null;
      |     }
      |   }
      | }""".stripMargin)

+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns2 = exprs.map(_.gen(ctx))
+    val inputRow = if (isUnsafeRow) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
      | private void $scanRows(InternalRow $row) throws java.io.IOException {
-      |   while (true) {
+      |   boolean firstRow = true;
+      |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+      |     if (firstRow) {
+      |       firstRow = false;
+      |     } else {
+      |       $row = (InternalRow) $input.next();
+      |     }
      |     $numOutputRows.add(1);
-      |     ${consume(ctx, columns).trim}
-      |     if (shouldStop()) return;
-      |     if (!$input.hasNext()) break;
-      |     $row = (InternalRow)$input.next();
+      |     ${consume(ctx, columns2, inputRow).trim}
      |   }
      | }""".stripMargin)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 5c5ccf5ec38bb..5d90240468f42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -97,12 +97,11 @@ trait CodegenSupport extends SparkPlan {
    *   # call child.produce()
    *   initialized = true;
    * }
-   * while (hashmap.hasNext()) {
+   * while (!shouldStop() && hashmap.hasNext()) {
    *   row = hashmap.next();
    *   # build the aggregation results
    *   # create variables for results
    *   # call consume(), which will call parent.doConsume()
-   *   if (shouldStop()) return;
    * }
    */
   protected def doProduce(ctx: CodegenContext): String
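A recurring move in this patch: shouldStop() is folded into the loop condition instead of a per-iteration `if (shouldStop()) return;`. Control then falls out of the loop with all local state intact, and the next call into the generated iterator re-enters the same while condition. A sketch of the resulting loop shape, in the document's own string-template idiom:

    object ResumableLoopSketch {
      def main(args: Array[String]): Unit = {
        println(
          """
            |while (!shouldStop() && input.hasNext()) {
            |  InternalRow row = (InternalRow) input.next();
            |  // ... consume the row ...
            |}
          """.stripMargin)
      }
    }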
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index dca8b31ca9c2a..4a9e736f7abdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -40,8 +40,13 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
   }

   override def usedInputs: AttributeSet = {
-    // filter out the expressions that just pass the input to next operator.
-    AttributeSet(projectList.filterNot(inputSet.contains).flatMap(_.references))
+    // only the attributes that are used at least twice should be evaluated before this plan,
+    // otherwise we could defer the evaluation until the output attribute is actually used.
+    val usedExprIds = projectList.flatMap(_.collect {
+      case a: Attribute => a.exprId
+    })
+    val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+    references.filter(a => usedMoreThanOnce.contains(a.exprId))
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
@@ -50,7 +55,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     ctx.currentVars = input
     val resultVars = exprs.map(_.gen(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
-    val nonDeterministicAttrs = projectList.zip(output).filter(!_._1.deterministic).unzip._2
+    val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
       |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
       |${consume(ctx, resultVars)}

From f4311709dd0c66add99aeb248acdc70863fba239 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 7 Mar 2016 15:51:02 -0800
Subject: [PATCH 7/7] improve docs

---
 .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 ++
 .../apache/spark/sql/execution/WholeStageCodegen.scala   | 8 ++++++--
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 63e19564dd861..c4265a753933f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
  * @param code The sequence of statements required to evaluate the expression.
+ *             It should be an empty string if `isNull` and `value` already exist, or if no
+ *             code is needed to evaluate them (literals).
  * @param isNull A term that holds a boolean value representing whether the expression evaluated
  *               to null.
  * @param value A term for a (possibly primitive) value of the result of the evaluation. Not
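The @param note added above pins down the convention the whole series relies on: code == "" means isNull and value can be referenced directly with no further statements, with literals as the degenerate case. A simplified illustration of a consumer honoring that convention (toy ExprCode, hypothetical generated statement):

    case class ExprCode(var code: String, isNull: String, value: String)

    object EmptyCodeConventionSketch {
      def main(args: Array[String]): Unit = {
        val literal = ExprCode("", "false", "42")  // nothing left to evaluate
        val column = ExprCode("int scan_value5 = row.getInt(0);", "false", "scan_value5")
        for (ev <- Seq(literal, column)) {
          if (ev.code != "") println(ev.code)      // emit pending statements once
          println(s"use: ${ev.value}")
        }
      }
    }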
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 5d90240468f42..45578d50bfc0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -146,15 +146,19 @@ trait CodegenSupport extends SparkPlan {

   /**
    * The subset of inputSet that should be evaluated before this plan.
+   *
+   * We will use this to insert some code to access those columns that are actually used by
+   * the current plan before calling doConsume().
    */
   def usedInputs: AttributeSet = references

   /**
-   * Consume the columns generated from it's child, call doConsume() or emit the rows.
+   * Consume the columns generated from its child, call doConsume() or emit the rows.
    *
    * An operator could generate variables for the output, or a row, either one could be null.
    *
-   * If the row is not null, we create variables to access the columns before calling doConsume().
+   * If the row is not null, we create variables to access the columns that are actually used by
+   * the current plan before calling doConsume().
    */
   def consumeChild(
       ctx: CodegenContext,