diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 691f71a7d4ac..2e8ce4541865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -72,7 +72,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.TakeOrderedAndProjectExec( limit, order, projectList, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.CollectLimitExec(limit, planLater(child)) :: Nil + // Normally wrapping child with `LocalLimitExec` here is a no-op, because + // `CollectLimitExec.executeCollect` will call `LocalLimitExec.executeTake`, which + // calls `child.executeTake`. If child supports whole stage codegen, adding this + // `LocalLimitExec` can stop the processing of whole stage codegen and trigger the + // resource releasing work, after we consume `limit` rows. + execution.CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 73a0f8735ed4..7cef5569717a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -54,6 +54,14 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output + // Do not enable whole stage codegen for a single limit. 
+ override def supportCodegen: Boolean = child match { + case plan: CodegenSupport => plan.supportCodegen + case _ => false + } + + override def executeTake(n: Int): Array[InternalRow] = child.executeTake(math.min(n, limit)) + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e95f6dba4607..923c6d8eb71f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2658,4 +2658,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("SELECT __auto_generated_subquery_name.i from (SELECT i FROM v)"), Row(1)) } } + + test("SPARK-21743: top-most limit should not cause memory leak") { + // In unit tests, Spark will fail the query if a memory leak is detected. + spark.range(100).groupBy("id").count().limit(1).collect() + } }