@@ -71,34 +71,35 @@ case class Aggregate2Sort(
7171 while (i < aggregateExpressions.length) {
7272 val func = aggregateExpressions(i).aggregateFunction
7373 bufferOffsets += bufferOffset
74- bufferOffset = aggregateExpressions(i).mode match {
75- case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
76- case Final | Complete => bufferOffset + 1
77- }
74+ bufferOffset += func.bufferSchema.length
7875 i += 1
7976 }
8077 aggregateExpressions.zip(bufferOffsets)
8178 }
82-
83- private val algebraicAggregateFunctions : Array [AlgebraicAggregate ] = {
84- aggregateExprsWithBufferOffset.collect {
85- case (AggregateExpression2 (agg : AlgebraicAggregate , mode, isDistinct), offset) =>
86- agg.inputSchema = child.output
87- agg.withBufferOffset(offset)
79+ // println("aggregateExprsWithBufferOffset " + aggregateExprsWithBufferOffset)
80+
81+ private val aggregateFunctions : Array [AggregateFunction2 ] = {
82+ aggregateExprsWithBufferOffset.map {
83+ case (aggExpr, bufferOffset) =>
84+ val func = aggExpr.aggregateFunction
85+ func.bufferOffset = bufferOffset
86+ func
8887 }.toArray
8988 }
9089
9190 private val nonAlgebraicAggregateFunctions : Array [AggregateFunction2 ] = {
9291 aggregateExprsWithBufferOffset.collect {
9392 case (AggregateExpression2 (agg : AggregateFunction2 , mode, isDistinct), offset)
9493 if ! agg.isInstanceOf [AlgebraicAggregate ] =>
95- val func = agg.withBufferOffset(offset)
9694 mode match {
9795 case Partial | Complete =>
9896 // Only need to bind reference when the function is not an AlgebraicAggregate
9997 // and the mode is Partial or Complete.
100- BindReferences .bindReference(func, child.output)
101- case _ => func
98+ val func = BindReferences .bindReference(agg, child.output)
99+ // Need to set it again since BindReference will create a new instance.
100+ func.bufferOffset = offset
101+ func
102+ case _ => agg
102103 }
103104 }.toArray
104105 }
@@ -119,13 +120,8 @@ case class Aggregate2Sort(
119120 private val bufferSize : Int = {
120121 var size = 0
121122 var i = 0
122- while (i < algebraicAggregateFunctions.length) {
123- size += algebraicAggregateFunctions(i).bufferSchema.length
124- i += 1
125- }
126- i = 0
127- while (i < nonAlgebraicAggregateFunctions.length) {
128- size += nonAlgebraicAggregateFunctions(i).bufferSchema.length
123+ while (i < aggregateFunctions.length) {
124+ size += aggregateFunctions(i).bufferSchema.length
129125 i += 1
130126 }
131127 if (preShuffle) {
@@ -160,20 +156,23 @@ case class Aggregate2Sort(
160156 val offsetExpressions = if (preShuffle) Nil else Seq .fill(groupingExpressions.length)(NoOp )
161157
162158 val algebraicInitialProjection = {
163- val initExpressions = offsetExpressions ++ algebraicAggregateFunctions .flatMap {
159+ val initExpressions = offsetExpressions ++ aggregateFunctions .flatMap {
164160 case ae : AlgebraicAggregate => ae.initialValues
161+ case agg : AggregateFunction2 => NoOp :: Nil
165162 }
166163 // println(initExpressions.mkString(","))
167164
168165 newMutableProjection(initExpressions, Nil )().target(buffer)
169166 }
170167
171168 lazy val algebraicUpdateProjection = {
172- val bufferSchema = algebraicAggregateFunctions .flatMap {
169+ val bufferSchema = aggregateFunctions .flatMap {
173170 case ae : AlgebraicAggregate => ae.bufferAttributes
171+ case agg : AggregateFunction2 => agg.bufferAttributes
174172 }
175- val updateExpressions = algebraicAggregateFunctions .flatMap {
173+ val updateExpressions = aggregateFunctions .flatMap {
176174 case ae : AlgebraicAggregate => ae.updateExpressions
175+ case agg : AggregateFunction2 => NoOp :: Nil
177176 }
178177
179178 // println(updateExpressions.mkString(","))
@@ -182,27 +181,33 @@ case class Aggregate2Sort(
182181
183182 lazy val algebraicMergeProjection = {
184183 val bufferSchemata =
185- offsetAttributes ++ algebraicAggregateFunctions .flatMap {
184+ offsetAttributes ++ aggregateFunctions .flatMap {
186185 case ae : AlgebraicAggregate => ae.bufferAttributes
187- } ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
186+ case agg : AggregateFunction2 => agg.bufferAttributes
187+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
188188 case ae : AlgebraicAggregate => ae.rightBufferSchema
189+ case agg : AggregateFunction2 => agg.rightBufferSchema
189190 }
190- val mergeExpressions = offsetExpressions ++ algebraicAggregateFunctions .flatMap {
191+ val mergeExpressions = offsetExpressions ++ aggregateFunctions .flatMap {
191192 case ae : AlgebraicAggregate => ae.mergeExpressions
193+ case agg : AggregateFunction2 => NoOp :: Nil
192194 }
193195
194196 newMutableProjection(mergeExpressions, bufferSchemata)()
195197 }
196198
197199 lazy val algebraicEvalProjection = {
198200 val bufferSchemata =
199- offsetAttributes ++ algebraicAggregateFunctions .flatMap {
201+ offsetAttributes ++ aggregateFunctions .flatMap {
200202 case ae : AlgebraicAggregate => ae.bufferAttributes
201- } ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
203+ case agg : AggregateFunction2 => agg.bufferAttributes
204+ } ++ offsetAttributes ++ aggregateFunctions.flatMap {
202205 case ae : AlgebraicAggregate => ae.rightBufferSchema
206+ case agg : AggregateFunction2 => agg.rightBufferSchema
203207 }
204- val evalExpressions = algebraicAggregateFunctions .map {
208+ val evalExpressions = aggregateFunctions .map {
205209 case ae : AlgebraicAggregate => ae.evaluateExpression
210+ case agg : AggregateFunction2 => NoOp
206211 }
207212
208213 newMutableProjection(evalExpressions, bufferSchemata)()
@@ -251,6 +256,7 @@ case class Aggregate2Sort(
251256 nonAlgebraicAggregateFunctions(i).merge(buffer, row)
252257 i += 1
253258 }
259+ // println("buffer merge " + buffer + " " + row)
254260 }
255261 }
256262
@@ -293,6 +299,7 @@ case class Aggregate2Sort(
293299 val outputRow =
294300 if (preShuffle) {
295301 // If it is preShuffle, we just output the grouping columns and the buffer.
302+ // println("buffer " + buffer)
296303 joinedRow(currentGroupingKey, buffer).copy()
297304 } else {
298305 algebraicEvalProjection.target(aggregateResult)(buffer)
@@ -304,7 +311,6 @@ case class Aggregate2Sort(
304311 i += 1
305312 }
306313 resultProjection(joinedRow(currentGroupingKey, aggregateResult))
307-
308314 }
309315 initializeBuffer()
310316
0 commit comments