Commit c5419b3

addressed comments
1 parent 143e1ef commit c5419b3

3 files changed: +16 / -11 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

Lines changed: 6 additions & 2 deletions
@@ -64,6 +64,11 @@ case class GeneratedAggregate(
     }
   }
 
+  // even with empty input iterator, if this group-by operator is for
+  // global(groupingExpression.isEmpty) and final(partial=false),
+  // we still need to make a row from empty buffer.
+  def needEmptyBufferForwarded: Boolean = groupingExpressions.isEmpty && !partial
+
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -270,8 +275,7 @@ case class GeneratedAggregate(
 
       val joinedRow = new JoinedRow3
 
-      if (!iter.hasNext && (partial || groupingExpressions.nonEmpty)) {
-        // even with empty input, final-global groupby should forward value of empty buffer
+      if (!iter.hasNext && !needEmptyBufferForwarded) {
        Iterator[InternalRow]()
      } else if (groupingExpressions.isEmpty) {
        // TODO: Codegening anything other than the updateProjection is probably over kill.
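
The new predicate isolates the one mode that must emit output even when its input iterator is empty: a final (partial = false), global (no grouping expressions) aggregation has to produce a single row from the empty aggregation buffer, e.g. a COUNT of zero. With it in place, the guard in doExecute reads "!iter.hasNext && !needEmptyBufferForwarded", so an empty result iterator is returned only when no such row is required. The standalone sketch below (hypothetical names, not the Spark source) simply enumerates the four partial/grouping combinations to show which one forwards the empty-buffer row:

// Minimal, self-contained sketch: mirrors the needEmptyBufferForwarded logic
// outside of Spark, using made-up names. Only the final + global combination
// must still produce a row (e.g. SELECT COUNT(*) over an empty relation is 0).
object EmptyBufferDemo extends App {
  def needEmptyBufferForwarded(groupingIsEmpty: Boolean, partial: Boolean): Boolean =
    groupingIsEmpty && !partial

  for (partial <- Seq(true, false); groupingIsEmpty <- Seq(true, false)) {
    val forwards = needEmptyBufferForwarded(groupingIsEmpty, partial)
    println(s"partial=$partial, globalGroupBy=$groupingIsEmpty -> forwardEmptyBufferRow=$forwards")
  }
}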

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 2 additions & 1 deletion
@@ -154,7 +154,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
 
   protected def newProjection(
       expressions: Seq[Expression],
-      inputSchema: Seq[Attribute], mutableRow: Boolean = false): Projection = {
+      inputSchema: Seq[Attribute],
+      mutableRow: Boolean = false): Projection = {
     log.debug(
       s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if (codegenEnabled && expressions.forall(_.isThreadSafe)) {

sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala

Lines changed: 8 additions & 8 deletions
@@ -27,29 +27,29 @@ class AggregateSuite extends SparkPlanTest {
   test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {
 
     val input0 = Seq.empty[(String, Int, Double)]
-    val input1 = Seq(("Hello", 4, 2.0))
-
-    // hack : current default parallelism of test local backend is two
+    // in the case of needEmptyBufferForwarded=true, task makes a row from empty buffer
+    // even with empty input. And current default parallelism of SparkPlanTest is two (local[2])
     val x0 = Seq(Tuple1(0L), Tuple1(0L))
     val y0 = Seq.empty[Tuple1[Long]]
 
+    val input1 = Seq(("Hello", 4, 2.0))
     val x1 = Seq(Tuple1(0L), Tuple1(1L))
     val y1 = Seq(Tuple1(1L))
 
     val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
+    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
     try {
       for ((input, x, y) <- Seq((input0, x0, y0), (input1, x1, y1))) {
         val df = input.toDF("a", "b", "c")
         val colB = df.col("b").expr
         val colC = df.col("c").expr
         val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()
 
-        for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
-          partial <- Seq(false, true); groupExpr <- Seq(colB :: Nil, Seq.empty)) {
-          TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegen)
+        for (partial <- Seq(false, true); groupExpr <- Seq(Seq(colB), Seq.empty)) {
+          val aggregate = GeneratedAggregate(partial, groupExpr, Seq(aggrExpr), true, _: SparkPlan)
           checkAnswer(df,
-            GeneratedAggregate(partial, groupExpr, aggrExpr :: Nil, unsafe, _: SparkPlan),
-            if (groupExpr.isEmpty && !partial) x else y)
+            aggregate,
+            if (aggregate(null).needEmptyBufferForwarded) x else y)
         }
       }
     } finally {
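
The reworked test exercises both partial modes and both grouping shapes, switches codegen on once before the try block (restoring codegenDefault in the finally), and derives the expected answer from the plan itself instead of re-deriving the condition inline. The trailing "_: SparkPlan" partially applies the GeneratedAggregate constructor, leaving the child slot open so SparkPlanTest can splice in its own input plan; invoking the resulting function (the test passes null) merely builds an instance so needEmptyBufferForwarded can be read. A minimal sketch of that partial-application pattern, using hypothetical stand-in types rather than the real Spark classes:

// Hypothetical stand-ins for SparkPlan / GeneratedAggregate, only to illustrate
// the "constructor with the child left open" pattern used by the test.
trait Plan
case object DummyChild extends Plan

case class Aggregate(partial: Boolean, globalGroupBy: Boolean, child: Plan) extends Plan {
  def needEmptyBufferForwarded: Boolean = globalGroupBy && !partial
}

object CurriedPlanDemo extends App {
  // Leaving the last argument as "_: Plan" yields a Plan => Aggregate function,
  // analogous to GeneratedAggregate(partial, groupExpr, Seq(aggrExpr), true, _: SparkPlan).
  // Here partial = false and globalGroupBy = true, the final/global case.
  val buildAggregate: Plan => Aggregate = Aggregate(false, true, _: Plan)

  // Instantiating it with a dummy child lets us inspect the flag that decides
  // whether an empty input still yields one output row per partition.
  val expectRowFromEmptyBuffer = buildAggregate(DummyChild).needEmptyBufferForwarded
  println(s"expect a row from the empty buffer: $expectRowFromEmptyBuffer")
}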
