
Commit a09e60f

Fix limit codegen.
1 parent 12703bd commit a09e60f


5 files changed: 44 additions & 13 deletions

sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java

Lines changed: 13 additions & 1 deletion
@@ -38,6 +38,11 @@ public abstract class BufferedRowIterator {
 
   protected int partitionIndex = -1;
 
+  // Indicates whether query execution should stop even if input rows are still
+  // available. This is used by the limit operator: once it has received the given
+  // number of rows, it sets this flag and execution should stop.
+  protected boolean isStopEarly = false;
+
   public boolean hasNext() throws IOException {
     if (currentRows.isEmpty()) {
       processNext();
@@ -73,14 +78,21 @@ public void append(InternalRow row) {
     currentRows.add(row);
   }
 
+  /**
+   * Sets the flag for stopping query execution early.
+   */
+  public void setStopEarly(boolean value) {
+    isStopEarly = value;
+  }
+
   /**
    * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
    *
    * If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
    * This interface is mainly used to limit the number of input rows.
    */
   public boolean stopEarly() {
-    return false;
+    return isStopEarly;
   }
 
   /**
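For orientation, the loop referenced by the stopEarly() Javadoc is the one whole-stage codegen wraps around each input iterator. A minimal Scala sketch of that template (illustrative, not the exact InputAdapter source; input, row, and consumeCode stand in for names and fragments minted by CodegenContext):

  // Sketch of the per-row loop emitted for an input RDD. The generated class
  // extends BufferedRowIterator, so stopEarly() and shouldStop() are in scope.
  def inputLoop(input: String, row: String, consumeCode: String): String =
    s"""
       |while ($input.hasNext() && !stopEarly()) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $consumeCode
       |  if (shouldStop()) return;
       |}
     """.stripMargin

Since stopEarly() is now backed by the mutable isStopEarly field rather than a generated override, any operator in the pipeline can halt this loop by calling setStopEarly(true).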

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 4 additions & 1 deletion
@@ -243,7 +243,7 @@ case class HashAggregateExec(
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
     s"""
-       | while (!$initAgg && !stopEarly()) {
+       | while (!$initAgg) {
        |   $initAgg = true;
        |   long $beforeAgg = System.nanoTime();
        |   $doAggFuncName();
@@ -723,6 +723,9 @@ case class HashAggregateExec(
       long $beforeAgg = System.nanoTime();
       $doAggFuncName();
       $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);
+
+      // Reset the stop-early flag that a previous limit operator may have set
+      setStopEarly(false);
     }
 
     // output the result
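Taken together, the two hunks adapt the aggregate to the shared flag: the produce-side while loop no longer consults stopEarly(), and once the aggregation has drained its input, any flag set by a limit feeding the aggregate is cleared so the output phase and downstream operators are not cut short. A standalone sketch of the guarded pattern, mirroring the new no-grouping regression test in SQLQuerySuite below (assumes a local Spark build on the classpath; the object and session names are illustrative):

  import org.apache.spark.sql.SparkSession

  object StopEarlyRepro extends App {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // The limit sets the stop-early flag inside the same whole-stage pipeline;
    // the reset in the aggregate keeps its output phase running regardless.
    val df = spark.range(0, 100, 1, 1).groupBy().count().limit(1).filter('count > 0)
    df.show()  // expected: a single row with count = 100

    spark.stop()
  }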

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 1 addition & 10 deletions
@@ -71,21 +71,12 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val stopEarly =
-      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
-
-    ctx.addNewFunction("stopEarly", s"""
-      @Override
-      protected boolean stopEarly() {
-        return $stopEarly;
-      }
-    """, inlineToOuterClass = true)
     val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
     s"""
        | if ($countTerm < $limit) {
        |   $countTerm += 1;
        |   if ($countTerm == $limit) {
-       |     $stopEarly = true;
+       |     setStopEarly(true);
        |   }
        |   ${consume(ctx, input)}
        | }
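This is the heart of the fix: previously each limit emitted its own stopEarly() override into the generated outer class (via inlineToOuterClass = true), which breaks down once two limit operators land in the same whole-stage pipeline; now every limit simply flips the flag inherited from BufferedRowIterator. A sketch of a plan shape this enables, mirroring the new twoLimitsDF test below (reusing the spark session and implicits from the earlier sketch):

  // Two limits in one single-partition pipeline: the first limit(1) sets the
  // shared flag after one row, stopping the range; the second limit tracks its
  // own count but writes to the same flag through setStopEarly(true).
  val twoLimits = spark.range(0, 100, 1, 1)
    .limit(1)
    .filter('id >= 0)
    .selectExpr("id + 1 as id2")
    .limit(2)
    .filter('id > 50)
  twoLimits.collect()  // the source range should emit exactly one row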

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -556,7 +556,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
   }
 
-  test("SPARK-18004 limit + aggregates") {
+  test("SPARK-18528 limit + aggregates") {
     val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
     val limit2Df = df.limit(2)
     checkAnswer(

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 25 additions & 0 deletions
@@ -2865,6 +2865,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     // The second hash aggregate before local limit outputs 1 record.
     assert(aggNumRecords == 101)
 
+    val aggNoGroupingDF = spark.range(0, 100, 1, 1)
+      .groupBy()
+      .count().limit(1).filter('count > 0)
+    aggNoGroupingDF.collect()
+    val aggNoGroupingNumRecords = aggNoGroupingDF.queryExecution.sparkPlan.collect {
+      case h: HashAggregateExec => h
+    }.map { hashNode =>
+      hashNode.metrics("numOutputRows").value
+    }.sum
+    assert(aggNoGroupingNumRecords == 2)
+
     val filterDF = spark.range(0, 100, 1, 1).filter('id >= 0)
       .selectExpr("id + 1 as id2").limit(1).filter('id > 50)
     filterDF.collect()
@@ -2875,6 +2886,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }.head
     // RangeNode and FilterNode both output 1 record.
     assert(filterNumRecords == Tuple2(1, 1))
+
+    val twoLimitsDF = spark.range(0, 100, 1, 1)
+      .limit(1)
+      .filter('id >= 0)
+      .selectExpr("id + 1 as id2")
+      .limit(2)
+      .filter('id > 50)
+    twoLimitsDF.collect()
+    val twoLimitsDFNumRecords = twoLimitsDF.queryExecution.sparkPlan.collect {
+      case r: RangeExec => r
+    }.map { rangeNode =>
+      rangeNode.metrics("numOutputRows").value
+    }.head
+    assert(twoLimitsDFNumRecords == 1)
   }
 }
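The asserted numbers follow from the single-partition plans: the no-grouping query sums numOutputRows over both HashAggregateExec nodes, and the partial and final aggregates each emit exactly one row (1 + 1 = 2); in the two-limits query the first limit(1) stops the source after a single row, so RangeExec reports numOutputRows == 1. A quick way to eyeball the same metrics interactively (sketch, same session assumptions as above; the helper name is illustrative):

  import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.execution.RangeExec
  import org.apache.spark.sql.execution.aggregate.HashAggregateExec

  // Print numOutputRows for each range/aggregate node after a query has run.
  def dumpNumOutputRows(df: DataFrame): Unit =
    df.queryExecution.sparkPlan.foreach {
      case p @ (_: RangeExec | _: HashAggregateExec) =>
        println(s"${p.nodeName}: ${p.metrics("numOutputRows").value}")
      case _ =>
    }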
