Commit ffc9d8c

Author: Davies Liu
Parent: 76ca6c6

fix tests

3 files changed: +27 −17 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 18 additions & 12 deletions
@@ -151,17 +151,16 @@ private[sql] case class PhysicalRDD(
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))
 
     // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
     // by looking at the first value of the RDD and then calling the function which will process
     // the remaining. It is faster to return batches.
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
 
-
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns1 = exprs.map(_.gen(ctx))
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
@@ -170,11 +169,11 @@ private[sql] case class PhysicalRDD(
       |     int numRows = $batch.numRows();
       |     if ($idx == 0) $numOutputRows.add(numRows);
       |
-      |     while ($idx < numRows) {
+      |     while (!shouldStop() && $idx < numRows) {
       |       InternalRow $row = $batch.getRow($idx++);
-      |       ${consume(ctx, columns).trim}
-      |       if (shouldStop()) return;
+      |       ${consume(ctx, columns1).trim}
       |     }
+      |     if (shouldStop()) return;
       |
       |     if (!$input.hasNext()) {
       |       $batch = null;
@@ -185,16 +184,23 @@ private[sql] case class PhysicalRDD(
       |   }
       | }""".stripMargin)
 
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns2 = exprs.map(_.gen(ctx))
+    val inputRow = if (isUnsafeRow) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
       | private void $scanRows(InternalRow $row) throws java.io.IOException {
-      |   while (true) {
+      |   boolean firstRow = true;
+      |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+      |     if (firstRow) {
+      |       firstRow = false;
+      |     } else {
+      |       $row = (InternalRow) $input.next();
+      |     }
       |     $numOutputRows.add(1);
-      |     ${consume(ctx, columns).trim}
-      |     if (shouldStop()) return;
-      |     if (!$input.hasNext()) break;
-      |     $row = (InternalRow)$input.next();
+      |     ${consume(ctx, columns2, inputRow).trim}
       |   }
       | }""".stripMargin)
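Why the column accessors are now generated twice: CodegenContext is mutable state, and gen() binds to whatever ctx.INPUT_ROW and ctx.currentVars hold at the moment it is called. Generating columns once up front (as before this commit) and reusing the result for both the batch path and the row path would bind both paths to the same row variable. A minimal, self-contained sketch of that failure mode, using a toy context rather than Spark's actual CodegenContext API:

    // Toy model (not Spark's API): generated code captures the context's
    // current input-row variable, so it cannot be reused across paths.
    object StatefulCodegenSketch {
      class Ctx {
        var inputRow: String = _          // analogous to ctx.INPUT_ROW
        def gen(ordinal: Int): String =   // bakes the current inputRow into the code
          s"$inputRow.get($ordinal)"
      }

      def main(args: Array[String]): Unit = {
        val ctx = new Ctx
        ctx.inputRow = "batchRow"
        val columns1 = (0 until 2).map(ctx.gen) // accessors bound to batchRow

        ctx.inputRow = "scanRow"
        val columns2 = (0 until 2).map(ctx.gen) // accessors bound to scanRow

        println(columns1) // Vector(batchRow.get(0), batchRow.get(1))
        println(columns2) // Vector(scanRow.get(0), scanRow.get(1))
      }
    }

This is why the commit resets ctx.INPUT_ROW and ctx.currentVars immediately before each exprs.map(_.gen(ctx)) call, producing columns1 for the batch loop and columns2 for the row loop.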

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 1 addition & 2 deletions
@@ -97,12 +97,11 @@ trait CodegenSupport extends SparkPlan {
    *     # call child.produce()
    *     initialized = true;
    *   }
-   *   while (hashmap.hasNext()) {
+   *   while (!shouldStop() && hashmap.hasNext()) {
    *     row = hashmap.next();
    *     # build the aggregation results
    *     # create variables for results
    *     # call consume(), which will call parent.doConsume()
-   *     if (shouldStop()) return;
    *   }
    */
   protected def doProduce(ctx: CodegenContext): String
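The contract change documented here is subtle: checking shouldStop() in the loop condition stops before the next row is produced, whereas the old `if (shouldStop()) return;` after the consume step produced one extra row first. A self-contained sketch of the new loop shape, with toy names in plain Scala rather than generated Java:

    object ShouldStopSketch {
      // New shape from the comment: while (!shouldStop() && hasNext) { consume }
      def run(rows: Iterator[Int], shouldStop: () => Boolean): Seq[Int] = {
        val produced = scala.collection.mutable.Buffer[Int]()
        while (!shouldStop() && rows.hasNext) {
          produced += rows.next() // stands in for "# call consume()"
        }
        produced.toSeq
      }

      def main(args: Array[String]): Unit = {
        var checks = 0
        val stopAfterOneRow = () => { checks += 1; checks > 1 }
        println(run(Iterator(1, 2, 3), stopAfterOneRow)) // List(1)
      }
    }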

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 8 additions & 3 deletions
@@ -40,8 +40,13 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def usedInputs: AttributeSet = {
-    // filter out the expressions that just pass the input to next operator.
-    AttributeSet(projectList.filterNot(inputSet.contains).flatMap(_.references))
+    // Only the attributes that are used at least twice should be evaluated before this plan;
+    // otherwise we can defer the evaluation until the output attribute is actually used.
+    val usedExprIds = projectList.flatMap(_.collect {
+      case a: Attribute => a.exprId
+    })
+    val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+    references.filter(a => usedMoreThanOnce.contains(a.exprId))
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
@@ -50,7 +55,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     ctx.currentVars = input
     val resultVars = exprs.map(_.gen(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
-    val nonDeterministicAttrs = projectList.zip(output).filter(!_._1.deterministic).unzip._2
+    val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
       |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
       |${consume(ctx, resultVars)}
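To make the new usedInputs rule concrete: with a project list such as `a + a AS x, b AS y`, attribute `a` is referenced twice and must be evaluated before Project, while `b` is referenced once and can be evaluated lazily at its single point of use (non-deterministic expressions are still forced separately in doConsume above). A self-contained sketch of the counting logic over plain strings instead of Catalyst attributes:

    object UsedInputsSketch {
      def main(args: Array[String]): Unit = {
        // exprIds collected from: a + a AS x, b AS y  (a twice, b once)
        val usedExprIds = Seq("a", "a", "b")
        // Same grouping as the commit: keep only ids that occur more than once.
        val usedMoreThanOnce = usedExprIds.groupBy(identity).filter(_._2.size > 1).keySet
        println(usedMoreThanOnce) // Set(a): only `a` is evaluated before Project
      }
    }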
