Skip to content

Commit 5b46d41

Browse files
committed
Bug fix.
1 parent aff9534 commit 5b46d41

File tree

6 files changed

+89
-56
lines changed

6 files changed

+89
-56
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ private[sql] case class AggregateExpression2(
5656
override def children: Seq[Expression] = aggregateFunction :: Nil
5757

5858
override def dataType: DataType = aggregateFunction.dataType
59-
override def foldable: Boolean = aggregateFunction.foldable
59+
override def foldable: Boolean = false
6060
override def nullable: Boolean = aggregateFunction.nullable
6161

6262
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
@@ -75,17 +75,16 @@ abstract class AggregateFunction2
7575

7676
var bufferOffset: Int = 0
7777

78-
def withBufferOffset(newBufferOffset: Int): AggregateFunction2 = {
79-
bufferOffset = newBufferOffset
80-
this
81-
}
78+
override def foldable: Boolean = false
8279

8380
/** The schema of the aggregation buffer. */
8481
def bufferSchema: StructType
8582

8683
/** Attributes of fields in bufferSchema. */
8784
def bufferAttributes: Seq[Attribute]
8885

86+
def rightBufferSchema: Seq[Attribute]
87+
8988
def initialize(buffer: MutableRow): Unit
9089

9190
def update(buffer: MutableRow, input: InternalRow): Unit
@@ -100,7 +99,7 @@ case class MyDoubleSum(child: Expression) extends AggregateFunction2 {
10099
StructType(StructField("currentSum", DoubleType, true) :: Nil)
101100

102101
override val bufferAttributes: Seq[Attribute] = bufferSchema.toAttributes
103-
102+
override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
104103
override def initialize(buffer: MutableRow): Unit = {
105104
buffer.update(bufferOffset, null)
106105
}
@@ -152,17 +151,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
152151
val mergeExpressions: Seq[Expression]
153152
val evaluateExpression: Expression
154153

155-
/** Must be filled in by the executors */
156-
var inputSchema: Seq[Attribute] = _
157-
158-
override def withBufferOffset(newBufferOffset: Int): AlgebraicAggregate = {
159-
bufferOffset = newBufferOffset
160-
this
161-
}
162-
163-
def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())
164-
165-
lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
154+
override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
166155
implicit class RichAttribute(a: AttributeReference) {
167156
def left = a
168157
def right = rightBufferSchema(bufferAttributes.indexOf(a))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/AggregateExpressionSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ class AggregateExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2525

2626
test("Average") {
2727
val inputValues = Array(Int.MaxValue, null, 1000, Int.MinValue, 2)
28-
val avg = Average(child = BoundReference(0, IntegerType, true)).withBufferOffset(2)
28+
val avg = Average(child = BoundReference(0, IntegerType, true))
29+
avg.bufferOffset = 2
2930
val inputRow = new GenericMutableRow(1)
3031
val buffer = new GenericMutableRow(4)
3132
avg.initialize(buffer)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
205205
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
206206
case logical.Aggregate(groupingExpressions, resultExpressions, child)
207207
if sqlContext.conf.useSqlAggregate2 =>
208-
// 0. Make sure we can convert.
209-
resultExpressions.foreach {
210-
case agg1: AggregateExpression =>
211-
sys.error(s"$agg1 is not supported. Please set spark.sql.useAggregate2 to false.")
212-
case _ => // ok
213-
}
214208
// 1. Extracts all distinct aggregate expressions from the resultExpressions.
215209
val aggregateExpressions = resultExpressions.flatMap { expr =>
216210
expr.collect {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,34 +71,35 @@ case class Aggregate2Sort(
7171
while (i < aggregateExpressions.length) {
7272
val func = aggregateExpressions(i).aggregateFunction
7373
bufferOffsets += bufferOffset
74-
bufferOffset = aggregateExpressions(i).mode match {
75-
case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
76-
case Final | Complete => bufferOffset + 1
77-
}
74+
bufferOffset += func.bufferSchema.length
7875
i += 1
7976
}
8077
aggregateExpressions.zip(bufferOffsets)
8178
}
82-
83-
private val algebraicAggregateFunctions: Array[AlgebraicAggregate] = {
84-
aggregateExprsWithBufferOffset.collect {
85-
case (AggregateExpression2(agg: AlgebraicAggregate, mode, isDistinct), offset) =>
86-
agg.inputSchema = child.output
87-
agg.withBufferOffset(offset)
79+
// println("aggregateExprsWithBufferOffset " + aggregateExprsWithBufferOffset)
80+
81+
private val aggregateFunctions: Array[AggregateFunction2] = {
82+
aggregateExprsWithBufferOffset.map {
83+
case (aggExpr, bufferOffset) =>
84+
val func = aggExpr.aggregateFunction
85+
func.bufferOffset = bufferOffset
86+
func
8887
}.toArray
8988
}
9089

9190
private val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
9291
aggregateExprsWithBufferOffset.collect {
9392
case (AggregateExpression2(agg: AggregateFunction2, mode, isDistinct), offset)
9493
if !agg.isInstanceOf[AlgebraicAggregate] =>
95-
val func = agg.withBufferOffset(offset)
9694
mode match {
9795
case Partial | Complete =>
9896
// Only need to bind reference when the function is not an AlgebraicAggregate
9997
// and the mode is Partial or Complete.
100-
BindReferences.bindReference(func, child.output)
101-
case _ => func
98+
val func = BindReferences.bindReference(agg, child.output)
99+
// Need to set it again since BindReference will create a new instance.
100+
func.bufferOffset = offset
101+
func
102+
case _ => agg
102103
}
103104
}.toArray
104105
}
@@ -119,13 +120,8 @@ case class Aggregate2Sort(
119120
private val bufferSize: Int = {
120121
var size = 0
121122
var i = 0
122-
while (i < algebraicAggregateFunctions.length) {
123-
size += algebraicAggregateFunctions(i).bufferSchema.length
124-
i += 1
125-
}
126-
i = 0
127-
while (i < nonAlgebraicAggregateFunctions.length) {
128-
size += nonAlgebraicAggregateFunctions(i).bufferSchema.length
123+
while (i < aggregateFunctions.length) {
124+
size += aggregateFunctions(i).bufferSchema.length
129125
i += 1
130126
}
131127
if (preShuffle) {
@@ -160,20 +156,23 @@ case class Aggregate2Sort(
160156
val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)
161157

162158
val algebraicInitialProjection = {
163-
val initExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
159+
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
164160
case ae: AlgebraicAggregate => ae.initialValues
161+
case agg: AggregateFunction2 => NoOp :: Nil
165162
}
166163
// println(initExpressions.mkString(","))
167164

168165
newMutableProjection(initExpressions, Nil)().target(buffer)
169166
}
170167

171168
lazy val algebraicUpdateProjection = {
172-
val bufferSchema = algebraicAggregateFunctions.flatMap {
169+
val bufferSchema = aggregateFunctions.flatMap {
173170
case ae: AlgebraicAggregate => ae.bufferAttributes
171+
case agg: AggregateFunction2 => agg.bufferAttributes
174172
}
175-
val updateExpressions = algebraicAggregateFunctions.flatMap {
173+
val updateExpressions = aggregateFunctions.flatMap {
176174
case ae: AlgebraicAggregate => ae.updateExpressions
175+
case agg: AggregateFunction2 => NoOp :: Nil
177176
}
178177

179178
// println(updateExpressions.mkString(","))
@@ -182,27 +181,33 @@ case class Aggregate2Sort(
182181

183182
lazy val algebraicMergeProjection = {
184183
val bufferSchemata =
185-
offsetAttributes ++ algebraicAggregateFunctions.flatMap {
184+
offsetAttributes ++ aggregateFunctions.flatMap {
186185
case ae: AlgebraicAggregate => ae.bufferAttributes
187-
} ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
186+
case agg: AggregateFunction2 => agg.bufferAttributes
187+
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
188188
case ae: AlgebraicAggregate => ae.rightBufferSchema
189+
case agg: AggregateFunction2 => agg.rightBufferSchema
189190
}
190-
val mergeExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
191+
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
191192
case ae: AlgebraicAggregate => ae.mergeExpressions
193+
case agg: AggregateFunction2 => NoOp :: Nil
192194
}
193195

194196
newMutableProjection(mergeExpressions, bufferSchemata)()
195197
}
196198

197199
lazy val algebraicEvalProjection = {
198200
val bufferSchemata =
199-
offsetAttributes ++ algebraicAggregateFunctions.flatMap {
201+
offsetAttributes ++ aggregateFunctions.flatMap {
200202
case ae: AlgebraicAggregate => ae.bufferAttributes
201-
} ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
203+
case agg: AggregateFunction2 => agg.bufferAttributes
204+
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
202205
case ae: AlgebraicAggregate => ae.rightBufferSchema
206+
case agg: AggregateFunction2 => agg.rightBufferSchema
203207
}
204-
val evalExpressions = algebraicAggregateFunctions.map {
208+
val evalExpressions = aggregateFunctions.map {
205209
case ae: AlgebraicAggregate => ae.evaluateExpression
210+
case agg: AggregateFunction2 => NoOp
206211
}
207212

208213
newMutableProjection(evalExpressions, bufferSchemata)()
@@ -251,6 +256,7 @@ case class Aggregate2Sort(
251256
nonAlgebraicAggregateFunctions(i).merge(buffer, row)
252257
i += 1
253258
}
259+
// println("buffer merge " + buffer + " " + row)
254260
}
255261
}
256262

@@ -293,6 +299,7 @@ case class Aggregate2Sort(
293299
val outputRow =
294300
if (preShuffle) {
295301
// If it is preShuffle, we just output the grouping columns and the buffer.
302+
// println("buffer " + buffer)
296303
joinedRow(currentGroupingKey, buffer).copy()
297304
} else {
298305
algebraicEvalProjection.target(aggregateResult)(buffer)
@@ -304,7 +311,6 @@ case class Aggregate2Sort(
304311
i += 1
305312
}
306313
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
307-
308314
}
309315
initializeBuffer()
310316

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
package org.apache.spark.sql.execution.aggregate2
1919

20-
import org.apache.spark.sql.SQLContext
21-
import org.apache.spark.sql.catalyst.expressions.{Average => Average1}
20+
import org.apache.spark.sql.{SQLConf, AnalysisException, SQLContext}
21+
import org.apache.spark.sql.catalyst.expressions.{Average => Average1, AggregateExpression}
2222
import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, AggregateExpression2, Complete}
2323
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2424
import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,3 +32,18 @@ case class ConvertAggregateFunction(context: SQLContext) extends Rule[LogicalPla
3232
}
3333
}
3434
}
35+
36+
case class CheckAggregateFunction(context: SQLContext) extends (LogicalPlan => Unit) {
37+
def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) }
38+
39+
def apply(plan: LogicalPlan): Unit = plan.foreachUp {
40+
case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp {
41+
case agg: AggregateExpression =>
42+
failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is enabled. Please disable it to use $agg.")
43+
}
44+
case p if !context.conf.useSqlAggregate2 => p.transformExpressionsUp {
45+
case agg: AggregateExpression2 =>
46+
failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is disabled. Please enable it to use $agg.")
47+
}
48+
}
49+
}

sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,34 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
154154
Row(null) :: Nil)
155155

156156
}
157+
158+
test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
159+
checkAnswer(
160+
ctx.sql(
161+
"""
162+
|SELECT mydoublesum(cast(value as double)), key, avg(value)
163+
|FROM agg2
164+
|GROUP BY key
165+
""".stripMargin),
166+
Row(60.0, 1, 20.0) :: Row(-1.0, 2, -0.5) :: Row(null, 3, null) :: Nil)
167+
168+
checkAnswer(
169+
ctx.sql(
170+
"""
171+
|SELECT
172+
| mydoublesum(cast(value as double) + 1.5 * key),
173+
| avg(value - key),
174+
| key,
175+
| mydoublesum(cast(value as double) - 1.5 * key),
176+
| avg(value)
177+
|FROM agg2
178+
|GROUP BY key
179+
""".stripMargin),
180+
Row(64.5, 19.0, 1, 55.5, 20.0) ::
181+
Row(5.0, -2.5, 2, -7.0, -0.5) ::
182+
Row(null, null, 3, null, null) :: Nil)
183+
}
184+
157185
override def afterAll(): Unit = {
158186
ctx.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
159187
}

0 commit comments

Comments
 (0)