From 9c77d73160d0aefb75eb0e5e10670a0e510f958d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 Mar 2016 15:13:01 +0000 Subject: [PATCH 1/6] init import. --- .../apache/spark/sql/execution/Expand.scala | 2 +- .../org/apache/spark/sql/execution/Sort.scala | 28 +++++---- .../sql/execution/WholeStageCodegen.scala | 59 ++++++++++++++----- .../aggregate/TungstenAggregate.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 4 +- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../apache/spark/sql/execution/limit.scala | 2 +- 7 files changed, 66 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 12998a38f59e..09062a3ba627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -93,7 +93,7 @@ case class Expand( child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { /* * When the projections list looks like: * expr1A, exprB, expr1C diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 2ea889ea72c7..7bc6492bd453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -105,6 +105,8 @@ case class Sort( // Name of sorter variable used in codegen. private var sorterVariable: String = _ + override def consumeUnsafeRow: Boolean = true + override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") @@ -153,18 +155,22 @@ case class Sort( """.stripMargin.trim } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val colExprs = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + if (row != null) { + s"$sorterVariable.insertRow((UnsafeRow)$row.copy());" + } else { + val colExprs = child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } - ctx.currentVars = input - val code = GenerateUnsafeProjection.createCode(ctx, colExprs) + ctx.currentVars = input + val code = GenerateUnsafeProjection.createCode(ctx, colExprs) - s""" - | // Convert the input attributes to an UnsafeRow and add it to the sorter - | ${code.code} - | $sorterVariable.insertRow(${code.value}); - """.stripMargin.trim + s""" + | // Convert the input attributes to an UnsafeRow and add it to the sorter + | ${code.code} + | $sorterVariable.insertRow(${code.value}); + """.stripMargin.trim + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index cb68ca6ada36..951332973154 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -67,7 +67,12 @@ trait CodegenSupport extends SparkPlan { /** * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. */ - private var parent: CodegenSupport = null + protected var parent: CodegenSupport = null + + /** + * Whether this SparkPlan accepts UnsafeRow as input in consumeChild. + */ + def consumeUnsafeRow: Boolean = false /** * Returns all the RDDs of InternalRow which generates the input rows. @@ -109,7 +114,7 @@ trait CodegenSupport extends SparkPlan { * Consume the columns generated from current SparkPlan, call it's parent. */ final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null) { + if (input != null && !parent.consumeUnsafeRow) { assert(input.length == output.length) } parent.consumeChild(ctx, this, input, row) @@ -125,14 +130,27 @@ trait CodegenSupport extends SparkPlan { row: String = null): String = { ctx.freshNamePrefix = variablePrefix if (row != null) { - ctx.currentVars = null - ctx.INPUT_ROW = row - val evals = child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + val evals: Seq[ExprCode] = if (!consumeUnsafeRow) { + // If this SparkPlan can't consume UnsafeRow and there is an UnsafeRow, + // we extract the columns from the row and call doConsume. + ctx.currentVars = null + ctx.INPUT_ROW = row + child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + } + } else { + // If this SparkPlan consumes UnsafeRow and there is an UnsafeRow, + // we don't need to unpack variables from the row. + Seq.empty + } + val evalCode = if (evals.isEmpty) { + "" + } else { + s"${evals.map(_.code).mkString("\n")}" } s""" - | ${evals.map(_.code).mkString("\n")} - | ${doConsume(ctx, evals)} + | $evalCode + | ${doConsume(ctx, evals, row)} """.stripMargin } else { doConsume(ctx, input) @@ -151,7 +169,7 @@ trait CodegenSupport extends SparkPlan { * # call consume(), which will call parent.doConsume() * } */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { throw new UnsupportedOperationException } } @@ -191,17 +209,26 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one upstream. ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - - val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") - ctx.INPUT_ROW = row - ctx.currentVars = null - val columns = exprs.map(_.gen(ctx)) + + val columns: Seq[ExprCode] = if (!this.parent.consumeUnsafeRow) { + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + ctx.INPUT_ROW = row + ctx.currentVars = null + exprs.map(_.gen(ctx)) + } else { + Seq.empty + } + val columnsCode = if (columns.isEmpty) { + "" + } else { + s"${columns.map(_.code).mkString("\n").trim}" + } s""" | while ($input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); - | ${columns.map(_.code).mkString("\n").trim} - | ${consume(ctx, columns).trim} + | $columnsCode + | ${consume(ctx, columns, row).trim} | if (shouldStop()) { | return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a46722963a6e..ef4391786c28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -133,7 +133,7 @@ case class TungstenAggregate( } } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { if (groupingExpressions.isEmpty) { doConsumeWithoutKeys(ctx, input) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b2f443c0e9ae..a709538d2bcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -39,7 +39,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -77,7 +77,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { val numOutput = metricTerm(ctx, "numOutputRows") val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 6699dbafe7e7..3f28c60069dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -107,7 +107,7 @@ case class BroadcastHashJoin( streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { if (joinType == Inner) { codegenInner(ctx, input) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 45175d36d5c9..6f2d7817aff7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -66,7 +66,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { val stopEarly = ctx.freshName("stopEarly") ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") From 6941eb1370b7bbba35dcb45dd5ae59b43bfedf40 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 3 Mar 2016 04:22:27 +0000 Subject: [PATCH 2/6] Add some comments. --- .../apache/spark/sql/execution/WholeStageCodegen.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 951332973154..6d4efa757e79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -114,7 +114,10 @@ trait CodegenSupport extends SparkPlan { * Consume the columns generated from current SparkPlan, call it's parent. */ final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - if (input != null && !parent.consumeUnsafeRow) { + // We check if input expressions has same length as output when: + // 1. parent can't consume UnsafeRow and input is not null. + // 2. parent consumes UnsafeRow and row is null. + if ((input != null && !parent.consumeUnsafeRow) || (parent.consumeUnsafeRow && row == null)) { assert(input.length == output.length) } parent.consumeChild(ctx, this, input, row) @@ -211,12 +214,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val row = ctx.freshName("row") + // If the parent of this InputAdapter can't consume UnsafeRow, + // we unpack variables from the row. val columns: Seq[ExprCode] = if (!this.parent.consumeUnsafeRow) { val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) ctx.INPUT_ROW = row ctx.currentVars = null exprs.map(_.gen(ctx)) } else { + // If the parent consumes UnsafeRow, we don't need to unpack the row. Seq.empty } val columnsCode = if (columns.isEmpty) { From 7428fd4ce7008631194a29c9c02a3e2cd0aa9a7c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Mar 2016 14:42:21 +0800 Subject: [PATCH 3/6] Address comments. --- .../src/main/scala/org/apache/spark/sql/execution/Sort.scala | 2 +- .../org/apache/spark/sql/execution/WholeStageCodegen.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/basicOperators.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/debug/package.scala | 2 +- .../src/main/scala/org/apache/spark/sql/execution/limit.scala | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 7bc6492bd453..2e252cb26071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -155,7 +155,7 @@ case class Sort( """.stripMargin.trim } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { if (row != null) { s"$sorterVariable.insertRow((UnsafeRow)$row.copy());" } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index eb5f2755e66b..ce46467dbb56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -68,7 +68,7 @@ trait CodegenSupport extends SparkPlan { protected var parent: CodegenSupport = null /** - * Whether this SparkPlan accepts UnsafeRow as input in consumeChild. + * Whether this SparkPlan accepts UnsafeRow as input in doConsume. */ def consumeUnsafeRow: Boolean = false @@ -211,7 +211,7 @@ trait CodegenSupport extends SparkPlan { * if (isNull1 || !value2) continue; * # call consume(), which will call parent.doConsume() */ - protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { throw new UnsupportedOperationException } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 5e27eea81b1d..1e3446192c2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) references.filter(a => usedMoreThanOnce.contains(a.exprId)) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input @@ -88,7 +88,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { val numOutput = metricTerm(ctx, "numOutputRows") val expr = ExpressionCanonicalizer.execute( BindReferences.bindReference(condition, child.output)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 228b3791875f..034bf152620d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -136,7 +136,7 @@ package object debug { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { consume(ctx, input) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 5193180ff7e3..ca624a5a84e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport { child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { val stopEarly = ctx.freshName("stopEarly") ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") From 6400eb22aefa986fb5d96d3d6d0242778e2f332a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Mar 2016 15:02:24 +0800 Subject: [PATCH 4/6] Simplify the solution. --- .../sql/execution/WholeStageCodegen.scala | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index ce46467dbb56..04bbc38ab759 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -174,20 +174,21 @@ trait CodegenSupport extends SparkPlan { input: Seq[ExprCode], row: String = null): String = { ctx.freshNamePrefix = variablePrefix + val realUsedInput = + if (row != null && consumeUnsafeRow) { + // If this SparkPlan consumes UnsafeRow and there is an UnsafeRow passed in, + // we don't need to evaluate inputs because doConsume will directly consume the UnsafeRow. + AttributeSet.empty + } else { + usedInputs + } + val inputVars = if (row != null) { - if (!consumeUnsafeRow) { - // If this SparkPlan can't consume UnsafeRow and there is an UnsafeRow, - // we extract the columns from the row and call doConsume. - ctx.currentVars = null - ctx.INPUT_ROW = row - child.output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) - } - } else { - // If this SparkPlan consumes UnsafeRow and there is an UnsafeRow, - // we don't need to unpack variables from the row. - Seq.empty + ctx.currentVars = null + ctx.INPUT_ROW = row + child.output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable).gen(ctx) } } else { input @@ -195,7 +196,7 @@ trait CodegenSupport extends SparkPlan { s""" | |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */ - |${evaluateRequiredVariables(child.output, inputVars, usedInputs)} + |${evaluateRequiredVariables(child.output, inputVars, realUsedInput)} |${doConsume(ctx, inputVars, row)} """.stripMargin } @@ -245,19 +246,12 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one upstream. ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") - val row = ctx.freshName("row") - // If the parent of this InputAdapter can't consume UnsafeRow, - // we unpack variables from the row. - val columns: Seq[ExprCode] = if (!this.parent.consumeUnsafeRow) { - val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) - ctx.INPUT_ROW = row - ctx.currentVars = null - exprs.map(_.gen(ctx)) - } else { - // If the parent consumes UnsafeRow, we don't need to unpack the row. - Seq.empty - } + val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) + val row = ctx.freshName("row") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columns = exprs.map(_.gen(ctx)) s""" | while (!shouldStop() && $input.hasNext()) { | InternalRow $row = (InternalRow) $input.next(); From dea644a74251c62f1ccb0fd095083d434acd1a8c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 Mar 2016 15:07:01 +0800 Subject: [PATCH 5/6] Address comments. --- .../src/main/scala/org/apache/spark/sql/execution/Expand.scala | 2 +- .../spark/sql/execution/aggregate/TungstenAggregate.scala | 2 +- .../apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a3d49a7d1fc1..a84e180ad1dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -93,7 +93,7 @@ case class Expand( child.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { /* * When the projections list looks like: * expr1A, exprB, expr1C diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1b0ee6b2990e..2eea910929c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -135,7 +135,7 @@ case class TungstenAggregate( } } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { if (groupingExpressions.isEmpty) { doConsumeWithoutKeys(ctx, input) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index e949680dea29..4c8f8080a98d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -107,7 +107,7 @@ case class BroadcastHashJoin( streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { if (joinType == Inner) { codegenInner(ctx, input) } else { From 6f0ae353a648f6b060b318726d550181bc40b4e2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 10 Mar 2016 11:01:37 +0800 Subject: [PATCH 6/6] Address comments. --- .../org/apache/spark/sql/execution/Sort.scala | 4 +-- .../sql/execution/WholeStageCodegen.scala | 29 +++++++++---------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 2e252cb26071..5a67cd0c24b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -105,7 +105,7 @@ case class Sort( // Name of sorter variable used in codegen. private var sorterVariable: String = _ - override def consumeUnsafeRow: Boolean = true + override def preferUnsafeRow: Boolean = true override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") @@ -157,7 +157,7 @@ case class Sort( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = { if (row != null) { - s"$sorterVariable.insertRow((UnsafeRow)$row.copy());" + s"$sorterVariable.insertRow((UnsafeRow)$row);" } else { val colExprs = child.output.zipWithIndex.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 04bbc38ab759..f084ce9a68b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -68,9 +68,9 @@ trait CodegenSupport extends SparkPlan { protected var parent: CodegenSupport = null /** - * Whether this SparkPlan accepts UnsafeRow as input in doConsume. + * Whether this SparkPlan prefers to accept UnsafeRow as input in doConsume. */ - def consumeUnsafeRow: Boolean = false + def preferUnsafeRow: Boolean = false /** * Returns all the RDDs of InternalRow which generates the input rows. @@ -115,10 +115,7 @@ trait CodegenSupport extends SparkPlan { * Consume the columns generated from current SparkPlan, call it's parent. */ final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = { - // We check if input expressions has same length as output when: - // 1. parent can't consume UnsafeRow and input is not null. - // 2. parent consumes UnsafeRow and row is null. - if ((input != null && !parent.consumeUnsafeRow) || (parent.consumeUnsafeRow && row == null)) { + if (input != null) { assert(input.length == output.length) } parent.consumeChild(ctx, this, input, row) @@ -174,15 +171,6 @@ trait CodegenSupport extends SparkPlan { input: Seq[ExprCode], row: String = null): String = { ctx.freshNamePrefix = variablePrefix - val realUsedInput = - if (row != null && consumeUnsafeRow) { - // If this SparkPlan consumes UnsafeRow and there is an UnsafeRow passed in, - // we don't need to evaluate inputs because doConsume will directly consume the UnsafeRow. - AttributeSet.empty - } else { - usedInputs - } - val inputVars = if (row != null) { ctx.currentVars = null @@ -193,10 +181,19 @@ trait CodegenSupport extends SparkPlan { } else { input } + + val evaluated = + if (row != null && preferUnsafeRow) { + // Current plan can consume UnsafeRows directly. + "" + } else { + evaluateRequiredVariables(child.output, inputVars, usedInputs) + } + s""" | |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */ - |${evaluateRequiredVariables(child.output, inputVars, realUsedInput)} + |${evaluated} |${doConsume(ctx, inputVars, row)} """.stripMargin }