@@ -38,6 +38,11 @@ public abstract class BufferedRowIterator {

protected int partitionIndex = -1;

// This indicates whether the query execution should be stopped even if input rows are still
// available. It is used by the limit operator: once the given number of rows has been produced,
// this flag is set and the execution should be stopped.
protected boolean isStopEarly = false;
Contributor:
what if there are 2 limits in the query?

Member Author:
I've added a test for 2 limits.

When either of the 2 limits sets isStopEarly, I think the execution should be stopped. Is there any case where it shouldn't be?


public boolean hasNext() throws IOException {
if (currentRows.isEmpty()) {
processNext();
@@ -73,14 +78,26 @@ public void append(InternalRow row) {
currentRows.add(row);
}

/**
* Sets the flag that stops the query execution early under whole-stage codegen.
*
* This has two use cases:
* 1. Limit operators should call it with true when the given limit number is reached.
* 2. Blocking operators (sort, aggregate, etc.) should call it with false to reset it after
* consuming all records from upstream.
*/
public void setStopEarly(boolean value) {
Contributor:
Can we have more documentation about how to use it? For now I see 2 use cases:

  1. the limit operator should call it with true when the limit is hit
  2. blocking operators (sort, agg, etc.) should call it with false to reset it.

Member Author:
Ok. Let me add it.

Member Author:
You also hinted that we should reset the stop-early flag in the sort exec node too. I will add it and a related test.

isStopEarly = value;
}

/**
* Returns whether this iterator should stop fetching the next row from [[CodegenSupport#inputRDDs]].
*
* If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
* This interface is mainly used to limit the number of input rows.
*/
public boolean stopEarly() {
return false;
return isStopEarly;
}
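
For readers following along, here is a minimal sketch, assuming a simple scan feeding a LocalLimit(10), of roughly what a whole-stage-codegen class built on this BufferedRowIterator could look like after the change. All identifier names (GeneratedStageSketch, inputadapter_input, limit_count) are made up for illustration; only the BufferedRowIterator API (init, append, shouldStop, stopEarly, setStopEarly, partitionIndex) comes from the code above.

// Illustrative sketch only -- not Spark's literal generated code.
import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.BufferedRowIterator;

class GeneratedStageSketch extends BufferedRowIterator {
  private scala.collection.Iterator<InternalRow> inputadapter_input;
  private int limit_count = 0;                  // state added by the limit operator

  @Override
  public void init(int index, scala.collection.Iterator<InternalRow>[] inputs) {
    partitionIndex = index;
    inputadapter_input = inputs[0];
  }

  @Override
  protected void processNext() throws IOException {
    // Loop emitted for InputAdapter: stop fetching input as soon as any
    // downstream operator has flipped the stop-early flag.
    while (inputadapter_input.hasNext() && !stopEarly()) {
      InternalRow row = inputadapter_input.next();
      if (limit_count < 10) {                   // LocalLimit(10)
        limit_count += 1;
        append(row);                            // hand the row to the consumer
      }
      if (limit_count == 10) {
        setStopEarly(true);                     // flag is set after the 10th row was consumed
      }
      if (shouldStop()) return;
    }
  }
}

The point is that the loop generated from InputAdapter is the one that consults stopEarly(), while the limit's generated code only flips the shared flag via setStopEarly.
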

@@ -170,9 +170,12 @@ case class SortExec(
| $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
| $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
| $needToSort = false;
|
| // Reset the stop-early flag that may have been set by a previous limit operator
| setStopEarly(false);
| }
|
| while ($sortedIterator.hasNext()) {
| while ($sortedIterator.hasNext() && !stopEarly()) {
| UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
| ${consume(ctx, null, outputRow)}
| if (shouldStop()) return;
@@ -665,7 +665,7 @@ case class HashAggregateExec(

def outputFromRowBasedMap: String = {
s"""
|while ($iterTermForFastHashMap.next()) {
|while ($iterTermForFastHashMap.next() && !stopEarly()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -690,7 +690,7 @@
BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
})
s"""
|while ($iterTermForFastHashMap.hasNext()) {
|while ($iterTermForFastHashMap.hasNext() && !stopEarly()) {
| InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
| ${generateKeyRow.code}
| ${generateBufferRow.code}
@@ -705,7 +705,7 @@

def outputFromRegularHashMap: String = {
s"""
|while ($iterTerm.next()) {
|while ($iterTerm.next() && !stopEarly()) {
| UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
| UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
| $outputFunc($keyTerm, $bufferTerm);
@@ -723,6 +723,9 @@
long $beforeAgg = System.nanoTime();
$doAggFuncName();
$aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);

// Reset the stop-early flag that may have been set by a previous limit operator
setStopEarly(false);
}

// output the result
@@ -465,13 +465,16 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| $initRangeFuncName(partitionIndex);
| }
|
| while (true) {
| while (true && !stopEarly()) {
| long $range = $batchEnd - $number;
| if ($range != 0L) {
| int $localEnd = (int)($range / ${step}L);
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
| long $value = ((long)$localIdx * ${step}L) + $number;
| ${consume(ctx, Seq(ev))}
| if (stopEarly()) {
| break;
| }
| $shouldStop
| }
| $number = $batchEnd;
15 changes: 4 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -71,22 +71,15 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val stopEarly =
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false

ctx.addNewFunction("stopEarly", s"""
@Override
protected boolean stopEarly() {
return $stopEarly;
}
""", inlineToOuterClass = true)
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
| ${consume(ctx, input)}
| } else {
Contributor:
Do we need to remove this? Isn't it safer to leave it here?

Member Author:
I think we never execute this branch; if we do, there is a bug.

| $stopEarly = true;
|
| if ($countTerm == $limit) {
| setStopEarly(true);
Contributor:
shall we do this after consume?

Member Author:
Won't shouldStop be called inside consume? If it is, stopEarly will not be set.

Contributor:
if ($countTerm == $limit) means this is the last record, and we should still consume it?

Member Author:
Oh, I see. And I think shouldStop shouldn't be called inside consume.

Member Author:
Actually, as I look at the query again, there should not be a stopEarly check inside consume that prevents us from consuming the last record, because the check should be in the outer while loop.

The cases that have a stopEarly check inside consume are blocking operators like sort and aggregate; for them we need to reset the flag.

But for safety, I think I will also move this after consume.

| }
| }
""".stripMargin
}
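
To tie the thread above together, here is a hedged sketch of what the generated code for two stacked limits, e.g. Range(0, 100) -> LocalLimit(1) -> Project(id + 1 as id2) -> LocalLimit(2), might roughly look like under this change. The class and field names are made up and the row construction is illustrative only (real generated code uses UnsafeRow writers); the sketch only shows that either limit can flip the shared flag, that the flag is set after the limit-th row has been passed downstream, and that the outer producer loop exits once stopEarly() returns true.

// Illustrative sketch only -- not Spark's literal generated code.
import java.io.IOException;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.BufferedRowIterator;

class TwoLimitsStageSketch extends BufferedRowIterator {
  private long range_number = 0L;   // produced by the Range operator
  private long range_end = 100L;
  private int limit1_count = 0;     // inner LocalLimit(1)
  private int limit2_count = 0;     // outer LocalLimit(2)

  @Override
  public void init(int index, scala.collection.Iterator<InternalRow>[] inputs) {
    partitionIndex = index;
  }

  @Override
  protected void processNext() throws IOException {
    // Outer producer loop: exits once either limit has set the flag.
    while (range_number < range_end && !stopEarly()) {
      long id = range_number++;
      // doConsume of the first limit
      if (limit1_count < 1) {
        limit1_count += 1;
        long id2 = id + 1L;                       // Project: id + 1 as id2
        // doConsume of the second limit
        if (limit2_count < 2) {
          limit2_count += 1;
          append(new GenericInternalRow(new Object[] { id2 }));
        }
        if (limit2_count == 2) {
          setStopEarly(true);                     // set only after the row was consumed
        }
      }
      if (limit1_count == 1) {
        setStopEarly(true);
      }
      if (shouldStop()) return;
    }
  }
}

Under this shape the pipeline emits exactly one row (id2 = 1) before stopping, which is what the two-limits test added below asserts.
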
@@ -556,7 +556,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
}

test("SPARK-18004 limit + aggregates") {
test("SPARK-18528 limit + aggregates") {
Member Author:
This JIRA number is wrong.

val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
val limit2Df = df.limit(2)
checkAnswer(
77 changes: 75 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -25,9 +25,8 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.{aggregate, FilterExec, LocalLimitExec, RangeExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -2850,6 +2849,80 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
result.rdd.isEmpty
}

test("SPARK-25497: limit operation within whole stage codegen should not " +
"consume all the inputs") {

val aggDF = spark.range(0, 100, 1, 1)
.groupBy("id")
.count().limit(1).filter('count > 0)
aggDF.collect()
val aggNumRecords = aggDF.queryExecution.sparkPlan.collect {
case h: HashAggregateExec => h
}.map { hashNode =>
hashNode.metrics("numOutputRows").value
}.sum
// The first hash aggregate node outputs 100 records.
// The second hash aggregate before local limit outputs 1 record.
assert(aggNumRecords == 101)

val aggNoGroupingDF = spark.range(0, 100, 1, 1)
.groupBy()
.count().limit(1).filter('count > 0)
aggNoGroupingDF.collect()
val aggNoGroupingNumRecords = aggNoGroupingDF.queryExecution.sparkPlan.collect {
case h: HashAggregateExec => h
}.map { hashNode =>
hashNode.metrics("numOutputRows").value
}.sum
assert(aggNoGroupingNumRecords == 2)

// Set `TOP_K_SORT_FALLBACK_THRESHOLD` to a low value because we don't want sort + limit
// to be planned as a `TakeOrderedAndProject` node.
withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1") {
val sortDF = spark.range(0, 100, 1, 1)
.filter('id >= 0)
.limit(10)
.sortWithinPartitions("id")
// Use a non-deterministic expr to prevent the filter from being pushed down.
.selectExpr("rand() + id as id2")
.filter('id2 >= 0)
.limit(5)
.selectExpr("1 + id2 as id3")
sortDF.collect()
val sortNumRecords = sortDF.queryExecution.sparkPlan.collect {
case l @ LocalLimitExec(_, f: FilterExec) => f
}.map { filterNode =>
filterNode.metrics("numOutputRows").value
}
assert(sortNumRecords.sorted === Seq(5, 10))
}

val filterDF = spark.range(0, 100, 1, 1).filter('id >= 0)
.selectExpr("id + 1 as id2").limit(1).filter('id > 50)
filterDF.collect()
val filterNumRecords = filterDF.queryExecution.sparkPlan.collect {
case f @ FilterExec(_, r: RangeExec) => f
}.map { case filterNode =>
filterNode.metrics("numOutputRows").value
}.head
assert(filterNumRecords == 1)

val twoLimitsDF = spark.range(0, 100, 1, 1)
.filter('id >= 0)
.limit(1)
.selectExpr("id + 1 as id2")
.limit(2)
.filter('id2 >= 0)
twoLimitsDF.collect()
val twoLimitsDFNumRecords = twoLimitsDF.queryExecution.sparkPlan.collect {
case f @ FilterExec(_, _: RangeExec) => f
}.map { filterNode =>
filterNode.metrics("numOutputRows").value
}.head
assert(twoLimitsDFNumRecords == 1)
checkAnswer(twoLimitsDF, Row(1) :: Nil)
}

test("SPARK-25454: decimal division with negative scale") {
// TODO: completely fix this issue even when LITERAL_PRECISE_PRECISION is true.
withSQLConf(SQLConf.LITERAL_PICK_MINIMUM_PRECISION.key -> "false") {