Commit 12703bd

limit operation within whole stage codegen should not consume all the inputs.
1 parent 411ecc3 commit 12703bd

4 files changed: +42 -12 lines
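At a high level, the change lets the limit operator signal the loops that whole-stage codegen emits for its inputs (the range scan and the hash-aggregate output iterators) to stop as soon as the limit has been satisfied, instead of draining the entire input and discarding the excess rows. The following is an illustrative Scala sketch of that control flow only, with made-up variable names; the real code is generated as Java source by the operators touched in this commit:

// Illustrative sketch of the post-fix control flow, not the generated code itself.
var stopEarly = false                     // flag owned by the limit operator
var count = 0
val limit = 1

val input = Iterator.range(0, 100)        // stands in for the range/aggregate producer
while (input.hasNext && !stopEarly) {     // the producer re-checks the flag every iteration
  val row = input.next()
  if (count < limit) {                    // the limit operator's per-row consume logic
    count += 1
    if (count == limit) stopEarly = true  // signal the producer once the limit is hit
    // ... downstream operators would consume `row` here ...
  }
}
assert(count == 1)                        // only one input row was consumed, not all 100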

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ case class HashAggregateExec(
243243
val aggTime = metricTerm(ctx, "aggTime")
244244
val beforeAgg = ctx.freshName("beforeAgg")
245245
s"""
246-
| while (!$initAgg) {
246+
| while (!$initAgg && !stopEarly()) {
247247
| $initAgg = true;
248248
| long $beforeAgg = System.nanoTime();
249249
| $doAggFuncName();
@@ -665,7 +665,7 @@ case class HashAggregateExec(
665665

666666
def outputFromRowBasedMap: String = {
667667
s"""
668-
|while ($iterTermForFastHashMap.next()) {
668+
|while ($iterTermForFastHashMap.next() && !stopEarly()) {
669669
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
670670
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
671671
| $outputFunc($keyTerm, $bufferTerm);
@@ -690,7 +690,7 @@ case class HashAggregateExec(
690690
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
691691
})
692692
s"""
693-
|while ($iterTermForFastHashMap.hasNext()) {
693+
|while ($iterTermForFastHashMap.hasNext() && !stopEarly()) {
694694
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
695695
| ${generateKeyRow.code}
696696
| ${generateBufferRow.code}
@@ -705,7 +705,7 @@ case class HashAggregateExec(
705705

706706
def outputFromRegularHashMap: String = {
707707
s"""
708-
|while ($iterTerm.next()) {
708+
|while ($iterTerm.next() && !stopEarly()) {
709709
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
710710
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
711711
| $outputFunc($keyTerm, $bufferTerm);
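The stopEarly() calls added above refer to a hook on the generated iterator: by default it reports that the loop may keep running, and the limit operator's generated code overrides it once its count is exhausted (see the limit.scala hunk below). A minimal Scala sketch of that assumed contract, with illustrative names (in Spark the generated class actually extends the Java BufferedRowIterator):

// Sketch of the assumed contract only; class and field names are made up.
abstract class GeneratedIteratorSketch {
  // Default: never stop early. Producer loops poll this on every iteration.
  protected def stopEarly(): Boolean = false
}

// A limit operator's generated subclass flips a flag and overrides the hook.
class LimitedIteratorSketch extends GeneratedIteratorSketch {
  protected var reachedLimit = false                        // set once the limit is hit
  override protected def stopEarly(): Boolean = reachedLimit
}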

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,13 +465,18 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
465465
| $initRangeFuncName(partitionIndex);
466466
| }
467467
|
468-
| while (true) {
468+
| while (true && !stopEarly()) {
469469
| long $range = $batchEnd - $number;
470470
| if ($range != 0L) {
471471
| int $localEnd = (int)($range / ${step}L);
472472
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
473473
| long $value = ((long)$localIdx * ${step}L) + $number;
474+
| $numOutput.add(1);
475+
| $inputMetrics.incRecordsRead(1);
474476
| ${consume(ctx, Seq(ev))}
477+
| if (stopEarly()) {
478+
| break;
479+
| }
475480
| $shouldStop
476481
| }
477482
| $number = $batchEnd;
@@ -488,9 +493,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
488493
| $numElementsTodo = 0;
489494
| if ($nextBatchTodo == 0) break;
490495
| }
491-
| $numOutput.add($nextBatchTodo);
492-
| $inputMetrics.incRecordsRead($nextBatchTodo);
493-
|
494496
| $batchEnd += $nextBatchTodo * ${step}L;
495497
| }
496498
""".stripMargin

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
8484
s"""
8585
| if ($countTerm < $limit) {
8686
| $countTerm += 1;
87+
| if ($countTerm == $limit) {
88+
| $stopEarly = true;
89+
| }
8790
| ${consume(ctx, input)}
88-
| } else {
89-
| $stopEarly = true;
9091
| }
9192
""".stripMargin
9293
}
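For context, $countTerm, $limit, and $stopEarly above live in BaseLimitExec.doConsume. A sketch of roughly how that surrounding method is assumed to look after this change, using the standard CodegenContext helpers (not part of this diff, so treat the details as illustrative):

// Assumed surrounding context for the hunk above; a sketch, not the literal file.
protected override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  // Mutable boolean field on the generated class, initialized to false.
  val stopEarly = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly")

  // Override the stopEarly() hook that the producer loops poll each iteration.
  ctx.addNewFunction("stopEarly", s"""
    @Override
    protected boolean stopEarly() {
      return $stopEarly;
    }
  """, inlineToOuterClass = true)

  // Per-partition row counter, also a mutable field, initialized to 0.
  val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count")
  s"""
     | if ($countTerm < $limit) {
     |   $countTerm += 1;
     |   if ($countTerm == $limit) {
     |     $stopEarly = true;
     |   }
     |   ${consume(ctx, input)}
     | }
   """.stripMargin
}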

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ import java.util.concurrent.atomic.AtomicBoolean
2525
import org.apache.spark.{AccumulatorSuite, SparkException}
2626
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
2727
import org.apache.spark.sql.catalyst.util.StringUtils
28-
import org.apache.spark.sql.execution.aggregate
28+
import org.apache.spark.sql.execution.{aggregate, FilterExec, RangeExec}
2929
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
30-
import org.apache.spark.sql.execution.datasources.FilePartition
3130
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
3231
import org.apache.spark.sql.functions._
3332
import org.apache.spark.sql.internal.SQLConf
@@ -2849,6 +2848,34 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
28492848
val result = ds.flatMap(_.bar).distinct
28502849
result.rdd.isEmpty
28512850
}
2851+
2852+
test("SPARK-25497: limit operation within whole stage codegen should not " +
2853+
"consume all the inputs") {
2854+
2855+
val aggDF = spark.range(0, 100, 1, 1)
2856+
.groupBy("id")
2857+
.count().limit(1).filter('count > 0)
2858+
aggDF.collect()
2859+
val aggNumRecords = aggDF.queryExecution.sparkPlan.collect {
2860+
case h: HashAggregateExec => h
2861+
}.map { hashNode =>
2862+
hashNode.metrics("numOutputRows").value
2863+
}.sum
2864+
// The first hash aggregate node outputs 100 records.
2865+
// The second hash aggregate before local limit outputs 1 record.
2866+
assert(aggNumRecords == 101)
2867+
2868+
val filterDF = spark.range(0, 100, 1, 1).filter('id >= 0)
2869+
.selectExpr("id + 1 as id2").limit(1).filter('id > 50)
2870+
filterDF.collect()
2871+
val filterNumRecords = filterDF.queryExecution.sparkPlan.collect {
2872+
case f @ FilterExec(_, r: RangeExec) => (f, r)
2873+
}.map { case (filterNode, rangeNode) =>
2874+
(filterNode.metrics("numOutputRows").value, rangeNode.metrics("numOutputRows").value)
2875+
}.head
2876+
// RangeNode and FilterNode both output 1 record.
2877+
assert(filterNumRecords == Tuple2(1, 1))
2878+
}
28522879
}
28532880

28542881
case class Foo(bar: Option[String])
