Skip to content

Commit aff9534

Browse files
committed
Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions.
1 parent 2857b55 commit aff9534

File tree

6 files changed

+205
-40
lines changed

6 files changed

+205
-40
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.sql.AnalysisException
21+
import org.apache.spark.sql.catalyst.expressions.aggregate2.{Complete, AggregateExpression2, AggregateFunction2}
2122
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.plans.logical._
@@ -483,7 +484,10 @@ class Analyzer(
483484
q transformExpressions {
484485
case u @ UnresolvedFunction(name, children) =>
485486
withPosition(u) {
486-
registry.lookupFunction(name, children)
487+
registry.lookupFunction(name, children) match {
488+
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, false)
489+
case other => other
490+
}
487491
}
488492
}
489493
}
@@ -501,6 +505,7 @@ class Analyzer(
501505
def containsAggregates(exprs: Seq[Expression]): Boolean = {
502506
exprs.foreach(_.foreach {
503507
case agg: AggregateExpression => return true
508+
case agg2: AggregateExpression2 => return true
504509
case _ =>
505510
})
506511
false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20+
import org.apache.spark.sql.catalyst.expressions.aggregate2.MyDoubleSum
21+
2022
import scala.reflect.ClassTag
2123
import scala.util.{Failure, Success, Try}
2224

@@ -143,6 +145,7 @@ object FunctionRegistry {
143145
expression[Max]("max"),
144146
expression[Min]("min"),
145147
expression[Sum]("sum"),
148+
expression[MyDoubleSum]("mydoublesum"),
146149

147150
// string functions
148151
expression[Ascii]("ascii"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,56 @@ abstract class AggregateFunction2
9595
override def eval(buffer: InternalRow = null): Any
9696
}
9797

98+
case class MyDoubleSum(child: Expression) extends AggregateFunction2 {
99+
override val bufferSchema: StructType =
100+
StructType(StructField("currentSum", DoubleType, true) :: Nil)
101+
102+
override val bufferAttributes: Seq[Attribute] = bufferSchema.toAttributes
103+
104+
override def initialize(buffer: MutableRow): Unit = {
105+
buffer.update(bufferOffset, null)
106+
}
107+
108+
override def update(buffer: MutableRow, input: InternalRow): Unit = {
109+
val inputValue = child.eval(input)
110+
if (inputValue != null) {
111+
if (buffer.isNullAt(bufferOffset) == null) {
112+
buffer.setDouble(bufferOffset, inputValue.asInstanceOf[Double])
113+
} else {
114+
val currentSum = buffer.getDouble(bufferOffset)
115+
buffer.setDouble(bufferOffset, currentSum + inputValue.asInstanceOf[Double])
116+
}
117+
}
118+
}
119+
120+
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
121+
if (!buffer2.isNullAt(bufferOffset)) {
122+
if (buffer1.isNullAt(bufferOffset)) {
123+
buffer1.setDouble(bufferOffset, buffer2.getDouble(bufferOffset))
124+
} else {
125+
val currentSum = buffer1.getDouble(bufferOffset)
126+
buffer1.setDouble(bufferOffset, currentSum + buffer2.getDouble(bufferOffset))
127+
}
128+
}
129+
}
130+
131+
override def eval(buffer: InternalRow = null): Any = {
132+
if (buffer.isNullAt(bufferOffset)) {
133+
null
134+
} else {
135+
buffer.getDouble(bufferOffset)
136+
}
137+
}
138+
139+
override def nullable: Boolean = true
140+
override def dataType: DataType = DoubleType
141+
override def children: Seq[Expression] = child :: Nil
142+
}
143+
98144
/**
99145
* A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
100146
*/
101-
abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
147+
abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
102148
self: Product =>
103149

104150
val initialValues: Seq[Expression]
@@ -109,6 +155,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
109155
/** Must be filled in by the executors */
110156
var inputSchema: Seq[Attribute] = _
111157

158+
override def withBufferOffset(newBufferOffset: Int): AlgebraicAggregate = {
159+
bufferOffset = newBufferOffset
160+
this
161+
}
162+
112163
def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())
113164

114165
lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
@@ -182,7 +233,7 @@ case class Average(child: Expression) extends AlgebraicAggregate {
182233

183234
val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
184235

185-
override def nullable: Boolean = false
236+
override def nullable: Boolean = true
186237
override def dataType: DataType = resultType
187238
override def children: Seq[Expression] = child :: Nil
188239
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21+
import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateExpression2
2122
import org.apache.spark.sql.catalyst.plans._
2223
import org.apache.spark.sql.types._
2324
import org.apache.spark.util.collection.OpenHashSet
@@ -28,6 +29,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
2829
override lazy val resolved: Boolean = {
2930
val hasSpecialExpressions = projectList.exists ( _.collect {
3031
case agg: AggregateExpression => agg
32+
case agg: AggregateExpression2 => agg
3133
case generator: Generator => generator
3234
case window: WindowExpression => window
3335
}.nonEmpty

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala

Lines changed: 87 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
2626
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
2727
import org.apache.spark.sql.types.NullType
2828

29+
import scala.collection.mutable.ArrayBuffer
30+
2931
case class Aggregate2Sort(
3032
preShuffle: Boolean,
3133
groupingExpressions: Seq[NamedExpression],
@@ -57,40 +59,73 @@ case class Aggregate2Sort(
5759
child.execute().mapPartitions { iter =>
5860

5961
new Iterator[InternalRow] {
60-
private val aggregateFunctions: Array[AggregateFunction2] = {
62+
private val aggregateExprsWithBufferOffset = {
6163
var bufferOffset =
6264
if (preShuffle) {
6365
0
6466
} else {
6567
groupingExpressions.length
6668
}
69+
val bufferOffsets = new ArrayBuffer[Int]()
6770
var i = 0
68-
val functions = new Array[AggregateFunction2](aggregateExpressions.length)
6971
while (i < aggregateExpressions.length) {
70-
val func = aggregateExpressions(i).aggregateFunction.withBufferOffset(bufferOffset)
71-
functions(i) = aggregateExpressions(i).mode match {
72-
case Partial | Complete => func
73-
case PartialMerge | Final => func
74-
}
72+
val func = aggregateExpressions(i).aggregateFunction
73+
bufferOffsets += bufferOffset
7574
bufferOffset = aggregateExpressions(i).mode match {
7675
case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
7776
case Final | Complete => bufferOffset + 1
7877
}
7978
i += 1
8079
}
80+
aggregateExpressions.zip(bufferOffsets)
81+
}
8182

82-
functions.foreach {
83-
case ae: AlgebraicAggregate => ae.inputSchema = child.output
84-
case _ =>
83+
private val algebraicAggregateFunctions: Array[AlgebraicAggregate] = {
84+
aggregateExprsWithBufferOffset.collect {
85+
case (AggregateExpression2(agg: AlgebraicAggregate, mode, isDistinct), offset) =>
86+
agg.inputSchema = child.output
87+
agg.withBufferOffset(offset)
88+
}.toArray
89+
}
90+
91+
private val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
92+
aggregateExprsWithBufferOffset.collect {
93+
case (AggregateExpression2(agg: AggregateFunction2, mode, isDistinct), offset)
94+
if !agg.isInstanceOf[AlgebraicAggregate] =>
95+
val func = agg.withBufferOffset(offset)
96+
mode match {
97+
case Partial | Complete =>
98+
// Only need to bind reference when the function is not an AlgebraicAggregate
99+
// and the mode is Partial or Complete.
100+
BindReferences.bindReference(func, child.output)
101+
case _ => func
102+
}
103+
}.toArray
104+
}
105+
106+
private val nonAlgebraicAggregateFunctionOrdinals: Array[Int] = {
107+
val ordinals = new ArrayBuffer[Int]()
108+
var i = 0
109+
while (i < aggregateExpressions.length) {
110+
aggregateExpressions(i).aggregateFunction match {
111+
case agg: AlgebraicAggregate =>
112+
case _ => ordinals += i
113+
}
114+
i += 1
85115
}
86-
functions
116+
ordinals.toArray
87117
}
88118

89119
private val bufferSize: Int = {
90-
var i = 0
91120
var size = 0
92-
while (i < aggregateFunctions.length) {
93-
size += aggregateFunctions(i).bufferSchema.length
121+
var i = 0
122+
while (i < algebraicAggregateFunctions.length) {
123+
size += algebraicAggregateFunctions(i).bufferSchema.length
124+
i += 1
125+
}
126+
i = 0
127+
while (i < nonAlgebraicAggregateFunctions.length) {
128+
size += nonAlgebraicAggregateFunctions(i).bufferSchema.length
94129
i += 1
95130
}
96131
if (preShuffle) {
@@ -124,48 +159,49 @@ case class Aggregate2Sort(
124159
val offsetAttributes = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
125160
val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)
126161

127-
val initialProjection = {
128-
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
162+
val algebraicInitialProjection = {
163+
val initExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
129164
case ae: AlgebraicAggregate => ae.initialValues
130165
}
131166
// println(initExpressions.mkString(","))
167+
132168
newMutableProjection(initExpressions, Nil)().target(buffer)
133169
}
134170

135-
lazy val updateProjection = {
136-
val bufferSchema = aggregateFunctions.flatMap {
171+
lazy val algebraicUpdateProjection = {
172+
val bufferSchema = algebraicAggregateFunctions.flatMap {
137173
case ae: AlgebraicAggregate => ae.bufferAttributes
138174
}
139-
val updateExpressions = aggregateFunctions.flatMap {
175+
val updateExpressions = algebraicAggregateFunctions.flatMap {
140176
case ae: AlgebraicAggregate => ae.updateExpressions
141177
}
142178

143179
// println(updateExpressions.mkString(","))
144180
newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
145181
}
146182

147-
lazy val mergeProjection = {
183+
lazy val algebraicMergeProjection = {
148184
val bufferSchemata =
149-
offsetAttributes ++ aggregateFunctions.flatMap {
185+
offsetAttributes ++ algebraicAggregateFunctions.flatMap {
150186
case ae: AlgebraicAggregate => ae.bufferAttributes
151-
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
187+
} ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
152188
case ae: AlgebraicAggregate => ae.rightBufferSchema
153189
}
154-
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
190+
val mergeExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
155191
case ae: AlgebraicAggregate => ae.mergeExpressions
156192
}
157193

158194
newMutableProjection(mergeExpressions, bufferSchemata)()
159195
}
160196

161-
lazy val evalProjection = {
197+
lazy val algebraicEvalProjection = {
162198
val bufferSchemata =
163-
offsetAttributes ++ aggregateFunctions.flatMap {
199+
offsetAttributes ++ algebraicAggregateFunctions.flatMap {
164200
case ae: AlgebraicAggregate => ae.bufferAttributes
165-
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
201+
} ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
166202
case ae: AlgebraicAggregate => ae.rightBufferSchema
167203
}
168-
val evalExpressions = aggregateFunctions.map {
204+
val evalExpressions = algebraicAggregateFunctions.map {
169205
case ae: AlgebraicAggregate => ae.evaluateExpression
170206
}
171207

@@ -190,16 +226,31 @@ case class Aggregate2Sort(
190226
}
191227

192228
private def initializeBuffer(): Unit = {
193-
initialProjection(EmptyRow)
229+
algebraicInitialProjection(EmptyRow)
230+
var i = 0
231+
while (i < nonAlgebraicAggregateFunctions.length) {
232+
nonAlgebraicAggregateFunctions(i).initialize(buffer)
233+
i += 1
234+
}
194235
// println("initilized: " + buffer)
195236
}
196237

197238
private def processRow(row: InternalRow): Unit = {
198239
// The new row is still in the current group.
199240
if (preShuffle) {
200-
updateProjection(joinedRow(buffer, row))
241+
algebraicUpdateProjection(joinedRow(buffer, row))
242+
var i = 0
243+
while (i < nonAlgebraicAggregateFunctions.length) {
244+
nonAlgebraicAggregateFunctions(i).update(buffer, row)
245+
i += 1
246+
}
201247
} else {
202-
mergeProjection.target(buffer)(joinedRow(buffer, row))
248+
algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
249+
var i = 0
250+
while (i < nonAlgebraicAggregateFunctions.length) {
251+
nonAlgebraicAggregateFunctions(i).merge(buffer, row)
252+
i += 1
253+
}
203254
}
204255
}
205256

@@ -244,15 +295,15 @@ case class Aggregate2Sort(
244295
// If it is preShuffle, we just output the grouping columns and the buffer.
245296
joinedRow(currentGroupingKey, buffer).copy()
246297
} else {
247-
/*
298+
algebraicEvalProjection.target(aggregateResult)(buffer)
248299
var i = 0
249-
while (i < aggregateFunctions.length) {
250-
aggregateResult.update(i, aggregateFunctions(i).eval(buffer))
300+
while (i < nonAlgebraicAggregateFunctions.length) {
301+
aggregateResult.update(
302+
nonAlgebraicAggregateFunctionOrdinals(i),
303+
nonAlgebraicAggregateFunctions(i).eval(buffer))
251304
i += 1
252305
}
253-
resultProjection(joinedRow(currentGroupingKey, aggregateResult)).copy()
254-
*/
255-
resultProjection(joinedRow(currentGroupingKey, evalProjection.target(aggregateResult)(buffer)))
306+
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
256307

257308
}
258309
initializeBuffer()

0 commit comments

Comments
 (0)