Skip to content

Commit 1b0bb3f

Browse files
committed
Do not bind references in AlgebraicAggregate and use code gen in all places.
1 parent 072209f commit 1b0bb3f

File tree

3 files changed

+53
-43
lines changed

3 files changed

+53
-43
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -128,45 +128,19 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
128128
}
129129
}
130130

131-
lazy val boundUpdateExpressions = {
132-
val updateSchema = inputSchema ++ offsetExpressions ++ bufferAttributes
133-
val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
134-
println(s"update: ${updateExpressions.mkString(",")}")
135-
println(s"update: ${bound.mkString(",")}")
136-
bound
137-
}
138-
139-
val joinedRow = new JoinedRow
140131
override def update(buffer: MutableRow, input: InternalRow): Unit = {
141-
var i = 0
142-
while (i < bufferAttributes.size) {
143-
buffer(i + bufferOffset) = boundUpdateExpressions(i).eval(joinedRow(input, buffer))
144-
i += 1
145-
}
132+
throw new UnsupportedOperationException(
133+
"AlgebraicAggregate's update should not be called directly")
146134
}
147135

148-
lazy val boundMergeExpressions = {
149-
val mergeSchema = offsetExpressions ++ bufferAttributes ++ offsetExpressions ++ rightBufferSchema
150-
mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
151-
}
152136
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
153-
var i = 0
154-
println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}")
155-
joinedRow(buffer1, buffer2)
156-
while (i < bufferAttributes.size) {
157-
println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}")
158-
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
159-
i += 1
160-
}
137+
throw new UnsupportedOperationException(
138+
"AlgebraicAggregate's merge should not be called directly")
161139
}
162140

163-
lazy val boundEvaluateExpression =
164-
BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferAttributes)
165141
override def eval(buffer: InternalRow): Any = {
166-
println(s"eval: $buffer")
167-
val res = boundEvaluateExpression.eval(buffer)
168-
println(s"eval: $buffer with $boundEvaluateExpression => $res")
169-
res
142+
throw new UnsupportedOperationException(
143+
"AlgebraicAggregate's eval should not be called directly")
170144
}
171145
}
172146

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ case class Aggregate2Sort(
117117
private val buffer: MutableRow = new GenericMutableRow(bufferSize)
118118
private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
119119
private val joinedRow = new JoinedRow4
120-
private val resultProjection =
121-
new InterpretedMutableProjection(
122-
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)
120+
private lazy val resultProjection =
121+
newMutableProjection(
122+
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
123123

124124
val offsetAttributes = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
125125
val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)
@@ -128,7 +128,7 @@ case class Aggregate2Sort(
128128
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
129129
case ae: AlgebraicAggregate => ae.initialValues
130130
}
131-
println(initExpressions.mkString(","))
131+
// println(initExpressions.mkString(","))
132132
newMutableProjection(initExpressions, Nil)().target(buffer)
133133
}
134134

@@ -140,24 +140,38 @@ case class Aggregate2Sort(
140140
case ae: AlgebraicAggregate => ae.updateExpressions
141141
}
142142

143-
println(updateExpressions.mkString(","))
143+
// println(updateExpressions.mkString(","))
144144
newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
145145
}
146146

147-
val mergeProjection = {
147+
lazy val mergeProjection = {
148148
val bufferSchemata =
149149
offsetAttributes ++ aggregateFunctions.flatMap {
150150
case ae: AlgebraicAggregate => ae.bufferAttributes
151151
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
152152
case ae: AlgebraicAggregate => ae.rightBufferSchema
153153
}
154-
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
155-
case ae: AlgebraicAggregate => ae.mergeExpressions
156-
}
154+
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
155+
case ae: AlgebraicAggregate => ae.mergeExpressions
156+
}
157157

158158
newMutableProjection(mergeExpressions, bufferSchemata)()
159159
}
160160

161+
lazy val evalProjection = {
162+
val bufferSchemata =
163+
offsetAttributes ++ aggregateFunctions.flatMap {
164+
case ae: AlgebraicAggregate => ae.bufferAttributes
165+
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
166+
case ae: AlgebraicAggregate => ae.rightBufferSchema
167+
}
168+
val evalExpressions = aggregateFunctions.map {
169+
case ae: AlgebraicAggregate => ae.evaluateExpression
170+
}
171+
172+
newMutableProjection(evalExpressions, bufferSchemata)()
173+
}
174+
161175
// Initialize this iterator.
162176
initialize()
163177

@@ -177,7 +191,7 @@ case class Aggregate2Sort(
177191

178192
private def initializeBuffer(): Unit = {
179193
initialProjection(EmptyRow)
180-
println("initilized: " + buffer)
194+
// println("initilized: " + buffer)
181195
}
182196

183197
private def processRow(row: InternalRow): Unit = {
@@ -230,16 +244,20 @@ case class Aggregate2Sort(
230244
// If it is preShuffle, we just output the grouping columns and the buffer.
231245
joinedRow(currentGroupingKey, buffer).copy()
232246
} else {
247+
/*
233248
var i = 0
234249
while (i < aggregateFunctions.length) {
235250
aggregateResult.update(i, aggregateFunctions(i).eval(buffer))
236251
i += 1
237252
}
238253
resultProjection(joinedRow(currentGroupingKey, aggregateResult)).copy()
254+
*/
255+
resultProjection(joinedRow(currentGroupingKey, evalProjection.target(aggregateResult)(buffer)))
256+
239257
}
240258
initializeBuffer()
241259

242-
println(s"outputRow $preShuffle " + outputRow)
260+
// println(s"outputRow $preShuffle " + outputRow)
243261
outputRow
244262
} else {
245263
// no more result

sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,24 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
7070
""".stripMargin),
7171
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil)
7272

73+
checkAnswer(
74+
ctx.sql(
75+
"""
76+
|SELECT avg(value), key
77+
|FROM agg2
78+
|GROUP BY key
79+
""".stripMargin),
80+
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Nil)
81+
82+
checkAnswer(
83+
ctx.sql(
84+
"""
85+
|SELECT avg(value) + 1.5, key + 10
86+
|FROM agg2
87+
|GROUP BY key + 10
88+
""".stripMargin),
89+
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Nil)
90+
7391
}
7492

7593
override def afterAll(): Unit = {

0 commit comments

Comments (0)