Closed
@@ -266,7 +266,18 @@ case class GeneratedAggregate(

       val joinedRow = new JoinedRow3

-      if (groupingExpressions.isEmpty) {
+      if (!iter.hasNext) {
+        // This is an empty input, so return early so that we do not allocate data structures
+        // that won't be cleaned up (see SPARK-8357).
+        if (groupingExpressions.isEmpty) {
Contributor Author

Here I made a slight simplification compared to @navis's original patch: if groupingExpressions is empty and the input is empty, we always return an empty aggregation buffer. @navis's patch had an additional branch here that skipped this output when partial = true, but I think that is an unnecessary performance optimization, given that the non-generated Aggregate operator still outputs an empty row even on empty inputs. Removing the branch leaves fewer cases to test.
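
To make the empty-input contract concrete, here is a minimal sketch in plain Scala. It is not Spark code: `EmptyInputSketch`, `countRows`, and the `grouped` flag are hypothetical stand-ins, with `grouped = false` modelling `groupingExpressions.isEmpty`, and the tuple result standing in for the projected aggregation buffer.

```scala
// Hypothetical stand-in for the per-partition aggregation logic; not Spark code.
object EmptyInputSketch {
  // `grouped = false` models groupingExpressions.isEmpty (a global aggregate).
  // The result pairs an optional grouping key with a row count.
  def countRows(iter: Iterator[(Int, Int)], grouped: Boolean): Iterator[(Option[Int], Long)] = {
    if (!iter.hasNext) {
      // Empty input: return before allocating any aggregation state.
      if (!grouped) Iterator((None, 0L)) // global aggregate: one row from the initial buffer
      else Iterator.empty                // grouped aggregate: no groups, so no rows
    } else if (!grouped) {
      // Global aggregate over non-empty input: a single running count.
      var n = 0L
      iter.foreach(_ => n += 1)
      Iterator((None, n))
    } else {
      // Grouped aggregate: count rows per value of the second column.
      val counts = scala.collection.mutable.LinkedHashMap.empty[Int, Long]
      iter.foreach { case (_, b) => counts(b) = counts.getOrElse(b, 0L) + 1 }
      counts.iterator.map { case (k, v) => (Some(k), v) }
    }
  }
}
```

Under these assumptions, a global count over empty input yields exactly one row with count 0, regardless of any partial flag, while a grouped count over empty input yields no rows.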

+          // This is a global aggregate, so return an empty aggregation buffer.
+          val resultProjection = resultProjectionBuilder()
+          Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
+        } else {
+          // This is a grouped aggregate, so return an empty iterator.
+          Iterator[InternalRow]()
+        }
+      } else if (groupingExpressions.isEmpty) {
         // TODO: Codegening anything other than the updateProjection is probably over kill.
         val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
         var currentRow: InternalRow = null
@@ -280,6 +291,7 @@ case class GeneratedAggregate(
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
       } else if (unsafeEnabled) {
+        assert(iter.hasNext, "There should be at least one row for this path")
         log.info("Using Unsafe-based aggregator")
         val aggregationMap = new UnsafeFixedWidthAggregationMap(
           newAggregationBuffer,
@@ -648,6 +648,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Row(2, 1, 2, 2, 1))
   }

+  test("count of empty table") {
Contributor Author

I added this test to make it easier to catch mistakes in the handling of the groupingExpressions.isEmpty && !iter.hasNext case: an implementation that fails to return an empty buffer there will be caught quickly.

+    withTempTable("t") {
+      Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t")
+      checkAnswer(
+        sql("select count(a) from t"),
+        Row(0))
+    }
+  }
+
   test("inner join where, one match per row") {
     checkAnswer(
       sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+
+class AggregateSuite extends SparkPlanTest {
+
+  test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
Contributor Author

I simplified this test by removing the large cross-product of options, since I think there is only one problematic case that really needs a regression test.
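
For readers unfamiliar with the failure mode, the sketch below models one way a leak on empty input can arise, and why an early return avoids it. It is an assumption-laden illustration, not the actual Spark implementation: `LeakSketch`, `TrackedMap`, `leaky`, and `guarded` are hypothetical names, and the "free on last row" pattern is assumed for the sake of the example.

```scala
// Hypothetical model of a leak on empty input; names are illustrative, not Spark APIs.
object LeakSketch {
  // Stand-in for an aggregation map that must be freed explicitly.
  final class TrackedMap {
    var freed = false
    def free(): Unit = freed = true
  }

  // Records every map handed out so a test can check for leaks.
  val allocated = scala.collection.mutable.ArrayBuffer.empty[TrackedMap]

  private def allocate(): TrackedMap = {
    val m = new TrackedMap
    allocated += m
    m
  }

  // Assumed leaky pattern: the map is released only after the last row is consumed.
  def leaky(input: Iterator[Int]): Iterator[Int] = {
    val map = allocate()
    new Iterator[Int] {
      def hasNext: Boolean = input.hasNext
      def next(): Int = {
        val row = input.next()
        if (!input.hasNext) map.free() // never reached when the input is empty
        row
      }
    }
  }

  // Guarded variant: return before allocating when there is nothing to aggregate.
  def guarded(input: Iterator[Int]): Iterator[Int] =
    if (!input.hasNext) Iterator.empty else leaky(input)
}
```

In this model, `leaky(Iterator.empty)` leaves one unfreed map behind, while `guarded(Iterator.empty)` allocates nothing, which is the single case the regression test targets.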

+    val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
+    val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
+    try {
+      TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+      TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+      val df = Seq.empty[(Int, Int)].toDF("a", "b")
+      checkAnswer(
+        df,
+        GeneratedAggregate(
+          partial = true,
+          Seq(df.col("b").expr),
+          Seq(Alias(Count(df.col("a").expr), "cnt")()),
+          unsafeEnabled = true,
+          _: SparkPlan),
+        Seq.empty
+      )
+    } finally {
+      TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
+      TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
+    }
+  }
+}