diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 74c9c05992719..eb904d01ce1d6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -38,6 +38,11 @@ public abstract class BufferedRowIterator {
 
   protected int partitionIndex = -1;
 
+  // This indicates whether the query execution should be stopped even if input rows are still
+  // available. It is used by the limit operator: once the given number of rows has been
+  // produced, this flag is set and the execution should be stopped.
+  protected boolean isStopEarly = false;
+
   public boolean hasNext() throws IOException {
     if (currentRows.isEmpty()) {
       processNext();
@@ -73,6 +78,18 @@ public void append(InternalRow row) {
     currentRows.add(row);
   }
 
+  /**
+   * Sets the flag for stopping the query execution early under whole-stage codegen.
+   *
+   * This has two use cases:
+   * 1. Limit operators should call it with true when the given limit number is reached.
+   * 2. Blocking operators (sort, aggregate, etc.) should call it with false to reset it after
+   *    consuming all records from upstream.
+   */
+  public void setStopEarly(boolean value) {
+    isStopEarly = value;
+  }
+
   /**
    * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
    *
@@ -80,7 +97,7 @@ public void append(InternalRow row) {
    * This interface is mainly used to limit the number of input rows.
    */
   public boolean stopEarly() {
-    return false;
+    return isStopEarly;
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 0dc16ba5ce281..d3a9145895a2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -170,9 +170,12 @@ case class SortExec(
        |   $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
        |   $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
        |   $needToSort = false;
+       |
+       |   // Reset the stop-early flag set by a previous limit operator.
+       |   setStopEarly(false);
        | }
        |
-       | while ($sortedIterator.hasNext()) {
+       | while ($sortedIterator.hasNext() && !stopEarly()) {
        |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
        |   ${consume(ctx, null, outputRow)}
        |   if (shouldStop()) return;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 98adba50b2973..a11d6a7b5bd0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -665,7 +665,7 @@ case class HashAggregateExec(
 
     def outputFromRowBasedMap: String = {
       s"""
-         |while ($iterTermForFastHashMap.next()) {
+         |while ($iterTermForFastHashMap.next() && !stopEarly()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -690,7 +690,7 @@ case class HashAggregateExec(
           BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
         })
       s"""
-         |while ($iterTermForFastHashMap.hasNext()) {
+         |while ($iterTermForFastHashMap.hasNext() && !stopEarly()) {
          |  InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
          |  ${generateKeyRow.code}
          |  ${generateBufferRow.code}
@@ -705,7 +705,7 @@ case class HashAggregateExec(
 
     def outputFromRegularHashMap: String = {
       s"""
-         |while ($iterTerm.next()) {
+         |while ($iterTerm.next() && !stopEarly()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -723,6 +723,9 @@ case class HashAggregateExec(
         long $beforeAgg = System.nanoTime();
         $doAggFuncName();
         $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);
+
+        // Reset the stop-early flag set by a previous limit operator.
+        setStopEarly(false);
       }
 
       // output the result
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 222a1b8bc7301..3cfe51e4625d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -465,13 +465,16 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   $initRangeFuncName(partitionIndex);
       | }
       |
-      | while (true) {
+      | while (!stopEarly()) {
       |   long $range = $batchEnd - $number;
       |   if ($range != 0L) {
       |     int $localEnd = (int)($range / ${step}L);
       |     for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
       |       long $value = ((long)$localIdx * ${step}L) + $number;
       |       ${consume(ctx, Seq(ev))}
+      |       if (stopEarly()) {
+      |         break;
+      |       }
       |       $shouldStop
       |     }
       |     $number = $batchEnd;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 66bcda8913738..93fffa37eeb67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -71,22 +71,15 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val stopEarly =
-      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
-
-    ctx.addNewFunction("stopEarly", s"""
-      @Override
-      protected boolean stopEarly() {
-        return $stopEarly;
-      }
-    """, inlineToOuterClass = true)
     val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
     s"""
        | if ($countTerm < $limit) {
        |   $countTerm += 1;
        |   ${consume(ctx, input)}
-       | } else {
-       |   $stopEarly = true;
+       |
+       |   if ($countTerm == $limit) {
+       |     setStopEarly(true);
+       |   }
        | }
      """.stripMargin
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c44b7db2..46757397ddb5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -556,7 +556,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
   }
 
-  test("SPARK-18004 limit + aggregates") {
+  test("SPARK-18528 limit + aggregates") {
     val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
     val limit2Df = df.limit(2)
     checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 631ab1b7ece7f..e237169a956d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -25,9 +25,8 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.{aggregate, FilterExec, LocalLimitExec, RangeExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2850,6 +2849,80 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     result.rdd.isEmpty
   }
 
+  test("SPARK-25497: limit operation within whole stage codegen should not " +
+    "consume all the inputs") {
+
+    val aggDF = spark.range(0, 100, 1, 1)
+      .groupBy("id")
+      .count().limit(1).filter('count > 0)
+    aggDF.collect()
+    val aggNumRecords = aggDF.queryExecution.sparkPlan.collect {
+      case h: HashAggregateExec => h
+    }.map { hashNode =>
+      hashNode.metrics("numOutputRows").value
+    }.sum
+    // The first hash aggregate node outputs 100 records.
+    // The second hash aggregate, before the local limit, outputs 1 record.
+    assert(aggNumRecords == 101)
+
+    val aggNoGroupingDF = spark.range(0, 100, 1, 1)
+      .groupBy()
+      .count().limit(1).filter('count > 0)
+    aggNoGroupingDF.collect()
+    val aggNoGroupingNumRecords = aggNoGroupingDF.queryExecution.sparkPlan.collect {
+      case h: HashAggregateExec => h
+    }.map { hashNode =>
+      hashNode.metrics("numOutputRows").value
+    }.sum
+    assert(aggNoGroupingNumRecords == 2)
+
+    // Set `TOP_K_SORT_FALLBACK_THRESHOLD` to a low value so that sort + limit is not planned
+    // as a `TakeOrderedAndProject` node.
+    withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1") {
+      val sortDF = spark.range(0, 100, 1, 1)
+        .filter('id >= 0)
+        .limit(10)
+        .sortWithinPartitions("id")
+        // Use a non-deterministic expression to prevent the filter from being pushed down.
+ .selectExpr("rand() + id as id2") + .filter('id2 >= 0) + .limit(5) + .selectExpr("1 + id2 as id3") + sortDF.collect() + val sortNumRecords = sortDF.queryExecution.sparkPlan.collect { + case l @ LocalLimitExec(_, f: FilterExec) => f + }.map { filterNode => + filterNode.metrics("numOutputRows").value + } + assert(sortNumRecords.sorted === Seq(5, 10)) + } + + val filterDF = spark.range(0, 100, 1, 1).filter('id >= 0) + .selectExpr("id + 1 as id2").limit(1).filter('id > 50) + filterDF.collect() + val filterNumRecords = filterDF.queryExecution.sparkPlan.collect { + case f @ FilterExec(_, r: RangeExec) => f + }.map { case filterNode => + filterNode.metrics("numOutputRows").value + }.head + assert(filterNumRecords == 1) + + val twoLimitsDF = spark.range(0, 100, 1, 1) + .filter('id >= 0) + .limit(1) + .selectExpr("id + 1 as id2") + .limit(2) + .filter('id2 >= 0) + twoLimitsDF.collect() + val twoLimitsDFNumRecords = twoLimitsDF.queryExecution.sparkPlan.collect { + case f @ FilterExec(_, _: RangeExec) => f + }.map { filterNode => + filterNode.metrics("numOutputRows").value + }.head + assert(twoLimitsDFNumRecords == 1) + checkAnswer(twoLimitsDF, Row(1) :: Nil) + } + test("SPARK-25454: decimal division with negative scale") { // TODO: completely fix this issue even when LITERAL_PRECISE_PRECISION is true. withSQLConf(SQLConf.LITERAL_PICK_MINIMUM_PRECISION.key -> "false") {