Changes from all commits
AggregationIterator.scala

@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
  * is used to generate result.
  */
 abstract class AggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     inputAttributes: Seq[Attribute],
     aggregateExpressions: Seq[AggregateExpression],
@@ -229,6 +230,7 @@ abstract class AggregationIterator(
             allImperativeAggregateFunctions(i).eval(currentBuffer))
           i += 1
         }
+        resultProjection.initialize(partIndex)
         resultProjection(joinedRow(currentGroupingKey, aggregateResult))
       }
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
@@ -251,12 +253,14 @@ abstract class AggregationIterator(
           typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
           i += 1
         }
+        resultProjection.initialize(partIndex)
         resultProjection(joinedRow(currentGroupingKey, currentBuffer))
       }
     } else {
       // Grouping-only: we only output values based on grouping expressions.
       val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
+        resultProjection.initialize(partIndex)
         resultProjection(currentGroupingKey)
       }
     }
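Why the new resultProjection.initialize(partIndex) calls matter: nondeterministic expressions such as rand() and spark_partition_id() must be initialized with the partition index before their first evaluation, and a generated result projection forwards initialize to the expressions it evaluates. Below is a minimal illustrative model of that contract, not Spark's actual Nondeterministic trait; all names here are hypothetical.

// Illustrative model only -- not Spark's real classes.
trait DemoNondeterministic {
  private var initialized = false

  // Must run once per partition before the first eval().
  final def initialize(partitionIndex: Int): Unit = {
    initInternal(partitionIndex)
    initialized = true
  }

  protected def initInternal(partitionIndex: Int): Unit

  final def eval(): Long = {
    // Pre-fix, AggregationIterator applied its result projection without
    // this step, so expressions were evaluated in an uninitialized state.
    require(initialized, "expression not initialized")
    evalInternal()
  }

  protected def evalInternal(): Long
}

// spark_partition_id() is the simplest case: its value is the index itself.
class DemoPartitionId extends DemoNondeterministic {
  private var partitionId: Long = _
  override protected def initInternal(partitionIndex: Int): Unit = partitionId = partitionIndex
  override protected def evalInternal(): Long = partitionId
}

Before this patch, the result projections built above were applied without any such initialization, which is exactly the failure the new test at the bottom of this diff reproduces with rand() and spark_partition_id().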
HashAggregateExec.scala

@@ -94,7 +94,7 @@ case class HashAggregateExec(
     val peakMemory = longMetric("peakMemory")
     val spillSize = longMetric("spillSize")

-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsWithIndex { (partIndex, iter) =>

       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
@@ -104,6 +104,7 @@ case class HashAggregateExec(
       } else {
         val aggregationIterator =
           new TungstenAggregationIterator(
+            partIndex,
             groupingExpressions,
             aggregateExpressions,
             aggregateAttributes,
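The switch from mapPartitions to mapPartitionsWithIndex is what makes the partition index available to the aggregation iterators in the first place. A self-contained sketch of that RDD API (the demo object and names are illustrative):

import org.apache.spark.sql.SparkSession

// Minimal demo of mapPartitionsWithIndex: the closure receives the partition
// index alongside the row iterator, mirroring how the exec nodes in this diff
// thread the index into their aggregation iterator constructors.
object MapPartitionsWithIndexDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    val perPartitionCounts = spark.sparkContext
      .parallelize(1 to 10, numSlices = 4)
      .mapPartitionsWithIndex { (partIndex, iter) =>
        // Per-partition state (an RNG, a projection, ...) can be seeded with partIndex here.
        Iterator((partIndex, iter.size))
      }
      .collect()
    perPartitionCounts.foreach { case (idx, n) => println(s"partition $idx has $n rows") }
    spark.stop()
  }
}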
ObjectAggregationIterator.scala

@@ -30,6 +30,7 @@ import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

 class ObjectAggregationIterator(
+    partIndex: Int,
     outputAttributes: Seq[Attribute],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
@@ -41,6 +42,7 @@ class ObjectAggregationIterator(
     inputRows: Iterator[InternalRow],
     fallbackCountThreshold: Int)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
ObjectHashAggregateExec.scala

@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
     val numOutputRows = longMetric("numOutputRows")
     val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold

-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input kvIterator is empty,
@@ -107,6 +107,7 @@ case class ObjectHashAggregateExec(
       } else {
         val aggregationIterator =
           new ObjectAggregationIterator(
+            index,
             child.output,
             groupingExpressions,
             aggregateExpressions,
SortAggregateExec.scala

@@ -74,7 +74,7 @@ case class SortAggregateExec(

   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
@@ -84,6 +84,7 @@ case class SortAggregateExec(
         Iterator[UnsafeRow]()
       } else {
         val outputIter = new SortBasedAggregationIterator(
+          partIndex,
           groupingExpressions,
           child.output,
           iter,
SortBasedAggregationIterator.scala

@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 * sorted by values of [[groupingExpressions]].
 */
 class SortBasedAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
@@ -37,6 +38,7 @@ class SortBasedAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex: Int,
     groupingExpressions,
     valueAttributes,
     aggregateExpressions,
TungstenAggregationIterator.scala

@@ -77,6 +77,7 @@ import org.apache.spark.unsafe.KVIterator
 * the iterator containing input [[UnsafeRow]]s.
 */
 class TungstenAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
@@ -90,6 +91,7 @@ class TungstenAggregationIterator(
     peakMemory: SQLMetric,
     spillSize: SQLMetric)
   extends AggregationIterator(
+    partIndex: Int,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
DataFrameFunctionsSuite.scala

@@ -448,6 +448,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       rand(Random.nextLong()), randn(Random.nextLong())
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
+
+  private def assertNoExceptions(c: Column): Unit = {
+    for (wholeStage <- Seq(true, false)) {
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) {
+        spark.range(0, 5).toDF("a").agg(sum("a")).withColumn("v", c).collect()
Review comment from a Contributor on the line above:
This test also passes without your fix. I think you need to reference a NonDeterministic expression in the aggregate.

Could also make sure that we test all aggregation paths:

1. HashAggregate
2. ObjectHashAggregate
3. SortAggregate
+      }
+    }
+  }
+
+  test("[SPARK-19471] AggregationIterator does not initialize the generated result projection" +
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions(_))
+  }
 }

 object DataFrameFunctionsSuite {
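A sketch of how the reviewer's three-path suggestion could be folded into the helper above; this is not the follow-up commit itself. The object-hash-aggregate toggle key and the planner behavior in the comments are assumptions: sum over a numeric column is expected to plan a HashAggregateExec, while collect_list should plan an ObjectHashAggregateExec when the toggle is on and fall back to SortAggregateExec when it is off.

// Sketch only: exercises HashAggregate, ObjectHashAggregate and SortAggregate
// with a nondeterministic column. The config key and planner behavior noted
// above are assumptions, and the merged test may differ.
private def assertNoExceptionsOnAllPaths(c: Column): Unit = {
  for {
    wholeStage <- Seq(true, false)
    useObjectHashAgg <- Seq(true, false)
  } {
    withSQLConf(
      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage.toString,
      "spark.sql.execution.useObjectHashAggregateExec" -> useObjectHashAgg.toString) {
      val df = spark.range(0, 5).toDF("a")
      // HashAggregateExec path: declarative aggregate over a primitive column.
      df.groupBy("a").agg(c, sum("a")).collect()
      // ObjectHashAggregateExec when enabled, SortAggregateExec otherwise:
      // collect_list's buffer is object-based rather than unsafe-row-backed,
      // so it cannot use the hash aggregate's row buffers.
      df.groupBy("a").agg(c, collect_list("a")).collect()
    }
  }
}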