From 7ec5ebf867110f1c0105faee36ee6776ad36aab1 Mon Sep 17 00:00:00 2001 From: wangyang Date: Mon, 6 Feb 2017 20:07:57 +0800 Subject: [PATCH 1/3] SPARK-19471 --- .../spark/sql/execution/aggregate/AggregationIterator.scala | 4 ++++ .../spark/sql/execution/aggregate/HashAggregateExec.scala | 3 ++- .../spark/sql/execution/aggregate/SortAggregateExec.scala | 3 ++- .../execution/aggregate/SortBasedAggregationIterator.scala | 2 ++ .../sql/execution/aggregate/TungstenAggregationIterator.scala | 2 ++ 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 7c11fdb9792e8..28d2055aef22d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ * is used to generate result. */ abstract class AggregationIterator( + partIndex: Int, groupingExpressions: Seq[NamedExpression], inputAttributes: Seq[Attribute], aggregateExpressions: Seq[AggregateExpression], @@ -229,6 +230,7 @@ abstract class AggregationIterator( allImperativeAggregateFunctions(i).eval(currentBuffer)) i += 1 } + resultProjection.initialize(partIndex) resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { @@ -251,12 +253,14 @@ abstract class AggregationIterator( typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer) i += 1 } + resultProjection.initialize(partIndex) resultProjection(joinedRow(currentGroupingKey, currentBuffer)) } } else { // Grouping-only: we only output values based on grouping expressions. val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { + resultProjection.initialize(partIndex) resultProjection(currentGroupingKey) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4529ed067e565..f4582f3f50a8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -94,7 +94,7 @@ case class HashAggregateExec( val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsWithIndex { (partIndex, iter) => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { @@ -104,6 +104,7 @@ case class HashAggregateExec( } else { val aggregationIterator = new TungstenAggregationIterator( + partIndex, groupingExpressions, aggregateExpressions, aggregateAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index be3198b8e7d82..a43235790834e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -74,7 +74,7 @@ case class SortAggregateExec( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. val hasInput = iter.hasNext @@ -84,6 +84,7 @@ case class SortAggregateExec( Iterator[UnsafeRow]() } else { val outputIter = new SortBasedAggregationIterator( + partIndex, groupingExpressions, child.output, iter, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index bea2dce1a7657..c95e202f71932 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( + partIndex: Int, groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], @@ -37,6 +38,7 @@ class SortBasedAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, numOutputRows: SQLMetric) extends AggregationIterator( + partIndex: Int, groupingExpressions, valueAttributes, aggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 2988161ee5e7b..f332fa2c11122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -77,6 +77,7 @@ import org.apache.spark.unsafe.KVIterator * the iterator containing input [[UnsafeRow]]s. */ class TungstenAggregationIterator( + partIndex: Int, groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], @@ -90,6 +91,7 @@ class TungstenAggregationIterator( peakMemory: SQLMetric, spillSize: SQLMetric) extends AggregationIterator( + partIndex: Int, groupingExpressions, originalInputAttributes, aggregateExpressions, From 97b07a1aa3723af4d401aff89edce1ddcebbfeff Mon Sep 17 00:00:00 2001 From: wangyang Date: Mon, 6 Feb 2017 22:00:47 +0800 Subject: [PATCH 2/3] fix build --- .../sql/execution/aggregate/ObjectAggregationIterator.scala | 2 ++ .../sql/execution/aggregate/ObjectHashAggregateExec.scala | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 3a7fcf1fa9d89..3f86b1d11d1c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -30,6 +30,7 @@ import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter class ObjectAggregationIterator( + partIndex: Int, outputAttributes: Seq[Attribute], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -41,6 +42,7 @@ class ObjectAggregationIterator( inputRows: Iterator[InternalRow], fallbackCountThreshold: Int) extends AggregationIterator( + partIndex, groupingExpressions, originalInputAttributes, aggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fcb7ec9a6411..e28ba50dd53be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -98,7 +98,7 @@ case class ObjectHashAggregateExec( val numOutputRows = longMetric("numOutputRows") val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val hasInput = iter.hasNext if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input kvIterator is empty, @@ -107,6 +107,7 @@ case class ObjectHashAggregateExec( } else { val aggregationIterator = new ObjectAggregationIterator( + index, child.output, groupingExpressions, aggregateExpressions, From b9b969394a17773162f1802ce12d03dd293afd97 Mon Sep 17 00:00:00 2001 From: wangyang Date: Mon, 6 Feb 2017 22:32:32 +0800 Subject: [PATCH 3/3] add test --- .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0e9a2c6cf7dec..77989c8c379fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -448,6 +448,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } + + private def assertNoExceptions(c: Column): Unit = { + for (wholeStage <- Seq(true, false)) { + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + spark.range(0, 5).toDF("a").agg(sum("a")).withColumn("v", c).collect() + } + } + } + + test("[SPARK-19471] AggregationIterator does not initialize the generated result projection" + + " before using it") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertNoExceptions(_)) + } } object DataFrameFunctionsSuite {