@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
2626import org .apache .spark .sql .execution .{SparkPlan , UnaryNode }
2727import org .apache .spark .sql .types .NullType
2828
29+ import scala .collection .mutable .ArrayBuffer
30+
2931case class Aggregate2Sort (
3032 preShuffle : Boolean ,
3133 groupingExpressions : Seq [NamedExpression ],
@@ -57,40 +59,73 @@ case class Aggregate2Sort(
5759 child.execute().mapPartitions { iter =>
5860
5961 new Iterator [InternalRow ] {
60- private val aggregateFunctions : Array [ AggregateFunction2 ] = {
62+ private val aggregateExprsWithBufferOffset = {
6163 var bufferOffset =
6264 if (preShuffle) {
6365 0
6466 } else {
6567 groupingExpressions.length
6668 }
69+ val bufferOffsets = new ArrayBuffer [Int ]()
6770 var i = 0
68- val functions = new Array [AggregateFunction2 ](aggregateExpressions.length)
6971 while (i < aggregateExpressions.length) {
70- val func = aggregateExpressions(i).aggregateFunction.withBufferOffset(bufferOffset)
71- functions(i) = aggregateExpressions(i).mode match {
72- case Partial | Complete => func
73- case PartialMerge | Final => func
74- }
72+ val func = aggregateExpressions(i).aggregateFunction
73+ bufferOffsets += bufferOffset
7574 bufferOffset = aggregateExpressions(i).mode match {
7675 case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
7776 case Final | Complete => bufferOffset + 1
7877 }
7978 i += 1
8079 }
80+ aggregateExpressions.zip(bufferOffsets)
81+ }
8182
82- functions.foreach {
83- case ae : AlgebraicAggregate => ae.inputSchema = child.output
84- case _ =>
83+ private val algebraicAggregateFunctions : Array [AlgebraicAggregate ] = {
84+ aggregateExprsWithBufferOffset.collect {
85+ case (AggregateExpression2 (agg : AlgebraicAggregate , mode, isDistinct), offset) =>
86+ agg.inputSchema = child.output
87+ agg.withBufferOffset(offset)
88+ }.toArray
89+ }
90+
91+ private val nonAlgebraicAggregateFunctions : Array [AggregateFunction2 ] = {
92+ aggregateExprsWithBufferOffset.collect {
93+ case (AggregateExpression2 (agg : AggregateFunction2 , mode, isDistinct), offset)
94+ if ! agg.isInstanceOf [AlgebraicAggregate ] =>
95+ val func = agg.withBufferOffset(offset)
96+ mode match {
97+ case Partial | Complete =>
98+ // Only need to bind reference when the function is not an AlgebraicAggregate
99+ // and the mode is Partial or Complete.
100+ BindReferences .bindReference(func, child.output)
101+ case _ => func
102+ }
103+ }.toArray
104+ }
105+
106+ private val nonAlgebraicAggregateFunctionOrdinals : Array [Int ] = {
107+ val ordinals = new ArrayBuffer [Int ]()
108+ var i = 0
109+ while (i < aggregateExpressions.length) {
110+ aggregateExpressions(i).aggregateFunction match {
111+ case agg : AlgebraicAggregate =>
112+ case _ => ordinals += i
113+ }
114+ i += 1
85115 }
86- functions
116+ ordinals.toArray
87117 }
88118
89119 private val bufferSize : Int = {
90- var i = 0
91120 var size = 0
92- while (i < aggregateFunctions.length) {
93- size += aggregateFunctions(i).bufferSchema.length
121+ var i = 0
122+ while (i < algebraicAggregateFunctions.length) {
123+ size += algebraicAggregateFunctions(i).bufferSchema.length
124+ i += 1
125+ }
126+ i = 0
127+ while (i < nonAlgebraicAggregateFunctions.length) {
128+ size += nonAlgebraicAggregateFunctions(i).bufferSchema.length
94129 i += 1
95130 }
96131 if (preShuffle) {
@@ -124,48 +159,49 @@ case class Aggregate2Sort(
124159 val offsetAttributes = if (preShuffle) Nil else Seq .fill(groupingExpressions.length)(AttributeReference (" offset" , NullType )())
125160 val offsetExpressions = if (preShuffle) Nil else Seq .fill(groupingExpressions.length)(NoOp )
126161
127- val initialProjection = {
128- val initExpressions = offsetExpressions ++ aggregateFunctions .flatMap {
162+ val algebraicInitialProjection = {
163+ val initExpressions = offsetExpressions ++ algebraicAggregateFunctions .flatMap {
129164 case ae : AlgebraicAggregate => ae.initialValues
130165 }
131166 // println(initExpressions.mkString(","))
167+
132168 newMutableProjection(initExpressions, Nil )().target(buffer)
133169 }
134170
135- lazy val updateProjection = {
136- val bufferSchema = aggregateFunctions .flatMap {
171+ lazy val algebraicUpdateProjection = {
172+ val bufferSchema = algebraicAggregateFunctions .flatMap {
137173 case ae : AlgebraicAggregate => ae.bufferAttributes
138174 }
139- val updateExpressions = aggregateFunctions .flatMap {
175+ val updateExpressions = algebraicAggregateFunctions .flatMap {
140176 case ae : AlgebraicAggregate => ae.updateExpressions
141177 }
142178
143179 // println(updateExpressions.mkString(","))
144180 newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
145181 }
146182
147- lazy val mergeProjection = {
183+ lazy val algebraicMergeProjection = {
148184 val bufferSchemata =
149- offsetAttributes ++ aggregateFunctions .flatMap {
185+ offsetAttributes ++ algebraicAggregateFunctions .flatMap {
150186 case ae : AlgebraicAggregate => ae.bufferAttributes
151- } ++ offsetAttributes ++ aggregateFunctions .flatMap {
187+ } ++ offsetAttributes ++ algebraicAggregateFunctions .flatMap {
152188 case ae : AlgebraicAggregate => ae.rightBufferSchema
153189 }
154- val mergeExpressions = offsetExpressions ++ aggregateFunctions .flatMap {
190+ val mergeExpressions = offsetExpressions ++ algebraicAggregateFunctions .flatMap {
155191 case ae : AlgebraicAggregate => ae.mergeExpressions
156192 }
157193
158194 newMutableProjection(mergeExpressions, bufferSchemata)()
159195 }
160196
161- lazy val evalProjection = {
197+ lazy val algebraicEvalProjection = {
162198 val bufferSchemata =
163- offsetAttributes ++ aggregateFunctions .flatMap {
199+ offsetAttributes ++ algebraicAggregateFunctions .flatMap {
164200 case ae : AlgebraicAggregate => ae.bufferAttributes
165- } ++ offsetAttributes ++ aggregateFunctions .flatMap {
201+ } ++ offsetAttributes ++ algebraicAggregateFunctions .flatMap {
166202 case ae : AlgebraicAggregate => ae.rightBufferSchema
167203 }
168- val evalExpressions = aggregateFunctions .map {
204+ val evalExpressions = algebraicAggregateFunctions .map {
169205 case ae : AlgebraicAggregate => ae.evaluateExpression
170206 }
171207
@@ -190,16 +226,31 @@ case class Aggregate2Sort(
190226 }
191227
192228 private def initializeBuffer (): Unit = {
193- initialProjection(EmptyRow )
229+ algebraicInitialProjection(EmptyRow )
230+ var i = 0
231+ while (i < nonAlgebraicAggregateFunctions.length) {
232+ nonAlgebraicAggregateFunctions(i).initialize(buffer)
233+ i += 1
234+ }
194235 // println("initilized: " + buffer)
195236 }
196237
197238 private def processRow (row : InternalRow ): Unit = {
198239 // The new row is still in the current group.
199240 if (preShuffle) {
200- updateProjection(joinedRow(buffer, row))
241+ algebraicUpdateProjection(joinedRow(buffer, row))
242+ var i = 0
243+ while (i < nonAlgebraicAggregateFunctions.length) {
244+ nonAlgebraicAggregateFunctions(i).update(buffer, row)
245+ i += 1
246+ }
201247 } else {
202- mergeProjection.target(buffer)(joinedRow(buffer, row))
248+ algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
249+ var i = 0
250+ while (i < nonAlgebraicAggregateFunctions.length) {
251+ nonAlgebraicAggregateFunctions(i).merge(buffer, row)
252+ i += 1
253+ }
203254 }
204255 }
205256
@@ -244,15 +295,15 @@ case class Aggregate2Sort(
244295 // If it is preShuffle, we just output the grouping columns and the buffer.
245296 joinedRow(currentGroupingKey, buffer).copy()
246297 } else {
247- /*
298+ algebraicEvalProjection.target(aggregateResult)(buffer)
248299 var i = 0
249- while (i < aggregateFunctions.length) {
250- aggregateResult.update(i, aggregateFunctions(i).eval(buffer))
300+ while (i < nonAlgebraicAggregateFunctions.length) {
301+ aggregateResult.update(
302+ nonAlgebraicAggregateFunctionOrdinals(i),
303+ nonAlgebraicAggregateFunctions(i).eval(buffer))
251304 i += 1
252305 }
253- resultProjection(joinedRow(currentGroupingKey, aggregateResult)).copy()
254- */
255- resultProjection(joinedRow(currentGroupingKey, evalProjection.target(aggregateResult)(buffer)))
306+ resultProjection(joinedRow(currentGroupingKey, aggregateResult))
256307
257308 }
258309 initializeBuffer()
0 commit comments