@@ -96,6 +96,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
9696 /** Must be filled in by the executors */
9797 var inputSchema : Seq [Attribute ] = _
9898
99+ def offsetExpressions : Seq [Attribute ] = Seq .fill(bufferOffset)(AttributeReference (" offset" , NullType )())
100+
99101 lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
100102 implicit class RichAttribute (a : AttributeReference ) {
101103 def left = a
@@ -112,8 +114,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
112114 }
113115
114116 lazy val boundUpdateExpressions = {
115- val updateSchema = inputSchema ++ bufferSchema
116- updateExpressions.map(BindReferences .bindReference(_, updateSchema)).toArray
117+ val updateSchema = inputSchema ++ offsetExpressions ++ bufferSchema
118+ val bound = updateExpressions.map(BindReferences .bindReference(_, updateSchema)).toArray
119+ println(s " update: ${updateExpressions.mkString(" ," )}" )
120+ println(s " update: ${bound.mkString(" ," )}" )
121+ bound
117122 }
118123
119124 val joinedRow = new JoinedRow
@@ -126,20 +131,27 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
126131 }
127132
128133 lazy val boundMergeExpressions = {
129- val mergeSchema = bufferSchema ++ rightBufferSchema
134+ val mergeSchema = offsetExpressions ++ bufferSchema ++ offsetExpressions ++ rightBufferSchema
130135 mergeExpressions.map(BindReferences .bindReference(_, mergeSchema)).toArray
131136 }
132137 override def merge (buffer1 : MutableRow , buffer2 : InternalRow ): Unit = {
133138 var i = 0
139+ println(s " Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(" ," )}" )
140+ joinedRow(buffer1, buffer2)
134141 while (i < bufferSchema.size) {
135- buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow(buffer1, buffer2))
142+ println(s " $i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}" )
143+ buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
136144 i += 1
137145 }
138146 }
139147
140- lazy val boundEvaluateExpression = BindReferences .bindReference(evaluateExpression, bufferSchema)
148+ lazy val boundEvaluateExpression =
149+ BindReferences .bindReference(evaluateExpression, offsetExpressions ++ bufferSchema)
141150 override def eval (buffer : InternalRow ): Any = {
142- boundEvaluateExpression.eval(buffer)
151+ println(s " eval: $buffer" )
152+ val res = boundEvaluateExpression.eval(buffer)
153+ println(s " eval: $buffer with $boundEvaluateExpression => $res" )
154+ res
143155 }
144156}
145157
@@ -171,15 +183,15 @@ case class Average(child: Expression) extends AlgebraicAggregate {
171183 Add (
172184 currentSum,
173185 Coalesce (Cast (child, intermediateType) :: Cast (Literal (0 ), intermediateType) :: Nil )),
174- /* currentCount = */ If (IsNotNull (child), currentCount, currentCount + 1L )
186+ /* currentCount = */ If (IsNull (child), currentCount, currentCount + 1L )
175187 )
176188
177189 val mergeExpressions = Seq (
178190 /* currentSum = */ currentSum.left + currentSum.right,
179191 /* currentCount = */ currentCount.left + currentCount.right
180192 )
181193
182- val evaluateExpression = Cast (currentCount , resultType) / Cast (currentSum , resultType)
194+ val evaluateExpression = Cast (currentSum , resultType) / Cast (currentCount , resultType)
183195
184196 override def nullable : Boolean = false
185197 override def dataType : DataType = resultType
0 commit comments