Skip to content

Commit 6bbc6ba

Browse files
committed
now with correct answers\!
1 parent f7996d0 commit 6bbc6ba

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
9696
/** Must be filled in by the executors */
9797
var inputSchema: Seq[Attribute] = _
9898

99+
def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())
100+
99101
lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
100102
implicit class RichAttribute(a: AttributeReference) {
101103
def left = a
@@ -112,8 +114,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
112114
}
113115

114116
lazy val boundUpdateExpressions = {
115-
val updateSchema = inputSchema ++ bufferSchema
116-
updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
117+
val updateSchema = inputSchema ++ offsetExpressions ++ bufferSchema
118+
val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
119+
println(s"update: ${updateExpressions.mkString(",")}")
120+
println(s"update: ${bound.mkString(",")}")
121+
bound
117122
}
118123

119124
val joinedRow = new JoinedRow
@@ -126,20 +131,27 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
126131
}
127132

128133
lazy val boundMergeExpressions = {
129-
val mergeSchema = bufferSchema ++ rightBufferSchema
134+
val mergeSchema = offsetExpressions ++ bufferSchema ++ offsetExpressions ++ rightBufferSchema
130135
mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
131136
}
132137
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
133138
var i = 0
139+
println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}")
140+
joinedRow(buffer1, buffer2)
134141
while (i < bufferSchema.size) {
135-
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow(buffer1, buffer2))
142+
println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}")
143+
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
136144
i += 1
137145
}
138146
}
139147

140-
lazy val boundEvaluateExpression = BindReferences.bindReference(evaluateExpression, bufferSchema)
148+
lazy val boundEvaluateExpression =
149+
BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferSchema)
141150
override def eval(buffer: InternalRow): Any = {
142-
boundEvaluateExpression.eval(buffer)
151+
println(s"eval: $buffer")
152+
val res = boundEvaluateExpression.eval(buffer)
153+
println(s"eval: $buffer with $boundEvaluateExpression => $res")
154+
res
143155
}
144156
}
145157

@@ -171,15 +183,15 @@ case class Average(child: Expression) extends AlgebraicAggregate {
171183
Add(
172184
currentSum,
173185
Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)),
174-
/* currentCount = */ If(IsNotNull(child), currentCount, currentCount + 1L)
186+
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
175187
)
176188

177189
val mergeExpressions = Seq(
178190
/* currentSum = */ currentSum.left + currentSum.right,
179191
/* currentCount = */ currentCount.left + currentCount.right
180192
)
181193

182-
val evaluateExpression = Cast(currentCount, resultType) / Cast(currentSum, resultType)
194+
val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
183195

184196
override def nullable: Boolean = false
185197
override def dataType: DataType = resultType

sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -61,12 +61,15 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
6161
|GROUP BY key
6262
""".stripMargin).queryExecution.executedPlan(3).execute().collect().foreach(println)
6363

64-
ctx.sql(
65-
"""
66-
|SELECT key, avg2(value)
67-
|FROM agg2
68-
|GROUP BY key
69-
""".stripMargin).show()
64+
checkAnswer(
65+
ctx.sql(
66+
"""
67+
|SELECT key, avg2(value)
68+
|FROM agg2
69+
|GROUP BY key
70+
""".stripMargin),
71+
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil)
72+
7073
}
7174

7275
override def afterAll(): Unit = {

0 commit comments

Comments
 (0)