|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.execution.aggregate |
19 | 19 |
|
20 | | -import org.apache.spark.TaskContext |
21 | | -import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD} |
| 20 | +import org.apache.spark.rdd.RDD |
22 | 21 | import org.apache.spark.sql.catalyst.InternalRow |
23 | 22 | import org.apache.spark.sql.catalyst.errors._ |
24 | | -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 |
25 | 23 | import org.apache.spark.sql.catalyst.expressions._ |
| 24 | +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 |
26 | 25 | import org.apache.spark.sql.catalyst.plans.physical._ |
27 | | -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan} |
28 | 26 | import org.apache.spark.sql.execution.metric.SQLMetrics |
| 27 | +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} |
29 | 28 | import org.apache.spark.sql.types.StructType |
30 | 29 |
|
31 | 30 | case class TungstenAggregate( |
@@ -84,59 +83,39 @@ case class TungstenAggregate( |
84 | 83 | val dataSize = longMetric("dataSize") |
85 | 84 | val spillSize = longMetric("spillSize") |
86 | 85 |
|
87 | | - /** |
88 | | - * Set up the underlying unsafe data structures used before computing the parent partition. |
89 | | - * This makes sure our iterator is not starved by other operators in the same task. |
90 | | - */ |
91 | | - def preparePartition(): TungstenAggregationIterator = { |
92 | | - new TungstenAggregationIterator( |
93 | | - groupingExpressions, |
94 | | - nonCompleteAggregateExpressions, |
95 | | - nonCompleteAggregateAttributes, |
96 | | - completeAggregateExpressions, |
97 | | - completeAggregateAttributes, |
98 | | - initialInputBufferOffset, |
99 | | - resultExpressions, |
100 | | - newMutableProjection, |
101 | | - child.output, |
102 | | - testFallbackStartsAt, |
103 | | - numInputRows, |
104 | | - numOutputRows, |
105 | | - dataSize, |
106 | | - spillSize) |
107 | | - } |
| 86 | + child.execute().mapPartitions { iter => |
108 | 87 |
|
109 | | - /** Compute a partition using the iterator already set up previously. */ |
110 | | - def executePartition( |
111 | | - context: TaskContext, |
112 | | - partitionIndex: Int, |
113 | | - aggregationIterator: TungstenAggregationIterator, |
114 | | - parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = { |
115 | | - val hasInput = parentIterator.hasNext |
116 | | - if (!hasInput) { |
117 | | - // We're not using the underlying map, so we just can free it here |
118 | | - aggregationIterator.free() |
119 | | - if (groupingExpressions.isEmpty) { |
| 88 | + val hasInput = iter.hasNext |
| 89 | + if (!hasInput && groupingExpressions.nonEmpty) { |
| 90 | + // This is a grouped aggregate and the input iterator is empty, |
| 91 | + // so return an empty iterator. |
| 92 | + Iterator.empty |
| 93 | + } else { |
| 94 | + val aggregationIterator = |
| 95 | + new TungstenAggregationIterator( |
| 96 | + groupingExpressions, |
| 97 | + nonCompleteAggregateExpressions, |
| 98 | + nonCompleteAggregateAttributes, |
| 99 | + completeAggregateExpressions, |
| 100 | + completeAggregateAttributes, |
| 101 | + initialInputBufferOffset, |
| 102 | + resultExpressions, |
| 103 | + newMutableProjection, |
| 104 | + child.output, |
| 105 | + iter, |
| 106 | + testFallbackStartsAt, |
| 107 | + numInputRows, |
| 108 | + numOutputRows, |
| 109 | + dataSize, |
| 110 | + spillSize) |
| 111 | + if (!hasInput && groupingExpressions.isEmpty) { |
120 | 112 | numOutputRows += 1 |
121 | 113 | Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) |
122 | 114 | } else { |
123 | | - // This is a grouped aggregate and the input iterator is empty, |
124 | | - // so return an empty iterator. |
125 | | - Iterator.empty |
| 115 | + aggregationIterator |
126 | 116 | } |
127 | | - } else { |
128 | | - aggregationIterator.start(parentIterator) |
129 | | - aggregationIterator |
130 | 117 | } |
131 | 118 | } |
132 | | - |
133 | | - // Note: we need to set up the iterator in each partition before computing the |
134 | | - // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747). |
135 | | - val resultRdd = { |
136 | | - new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator]( |
137 | | - child.execute(), preparePartition, executePartition, preservesPartitioning = true) |
138 | | - } |
139 | | - resultRdd.asInstanceOf[RDD[InternalRow]] |
140 | 119 | } |
141 | 120 |
|
142 | 121 | override def simpleString: String = { |
|
0 commit comments