
Commit e0afca3

Gracefully fall back to the old aggregation code path.

1 parent 8a8ac4a · commit e0afca3

File tree

13 files changed: +264 −66 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 4 additions & 3 deletions
@@ -250,18 +250,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
   override def dataType: DataType = resultType
 
   // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
 
   private val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType(precision + 4, scale + 4)
     case DecimalType.Unlimited => DecimalType.Unlimited
-    case _ => DoubleType
+    case _ => child.dataType
   }
 
   private val sumDataType = child.dataType match {
     case _ @ DecimalType() => DecimalType.Unlimited
-    case _ => DoubleType
+    case _ => child.dataType
   }
 
   private val currentSum = AttributeReference("currentSum", sumDataType)()

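Net effect: Sum no longer widens every non-decimal input to DoubleType. The buffer type (sumDataType) and the result type keep the child's own type, and the narrowed TypeCollection means integral inputs are presumably coerced to LongType before reaching Sum. A standalone sketch of the resulting type resolution, written against the DecimalType API of this era (DecimalType.Unlimited and the Fixed extractor were removed in later Spark versions); the object name is ours:

import org.apache.spark.sql.types._

object SumTypeResolution {
  // Mirrors the patched `resultType`: decimals gain headroom
  // (precision + 4, scale + 4); everything else keeps its own type.
  def resultType(childType: DataType): DataType = childType match {
    case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4)
    case DecimalType.Unlimited => DecimalType.Unlimited
    case other => other
  }

  def main(args: Array[String]): Unit = {
    assert(resultType(LongType) == LongType)     // previously widened to DoubleType
    assert(resultType(DoubleType) == DoubleType)
    assert(resultType(DecimalType(10, 2)) == DecimalType(14, 6))
  }
}
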
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 0 additions & 2 deletions
@@ -71,8 +71,6 @@ private[sql] case object NoOp extends Expression with Unevaluable {
   override def children: Seq[Expression] = Nil
 }
 
-
-
 /**
  * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -865,7 +865,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       DDLStrategy ::
       TakeOrderedAndProject ::
       HashAggregation ::
-      AggregateOperator2 ::
+      Aggregation ::
       LeftSemiJoin ::
       HashJoin ::
       InMemoryScans ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 59 additions & 43 deletions
@@ -184,8 +184,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => Nil
     }
 
-    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean =
-      aggregate.Utils.tryConvert(plan, sqlContext.conf.useSqlAggregate2).isDefined
+    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
+      aggregate.Utils.tryConvert(
+        plan,
+        sqlContext.conf.useSqlAggregate2,
+        sqlContext.conf.codegenEnabled).isDefined
+    }
 
     def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
       case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false

@@ -202,50 +206,62 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   /**
    * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
    */
-  object AggregateOperator2 extends Strategy {
+  object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case aggregate.NewAggregation(groupingExpressions, resultExpressions, child)
-        if sqlContext.conf.useSqlAggregate2 =>
-        // Extracts all distinct aggregate expressions from the resultExpressions.
-        val aggregateExpressions = resultExpressions.flatMap { expr =>
-          expr.collect {
-            case agg: AggregateExpression2 => agg
-          }
-        }.toSet.toSeq
-        // For those distinct aggregate expressions, we create a map from the aggregate function
-        // to the corresponding attribute of the function.
-        val aggregateFunctionMap = aggregateExpressions.map { agg =>
-          val aggregateFunction = agg.aggregateFunction
-          (aggregateFunction, agg.isDistinct) ->
-            Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-        }.toMap
-
-        val (functionsWithDistinct, functionsWithoutDistinct) =
-          aggregateExpressions.partition(_.isDistinct)
-        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
-          // This is a sanity check. We should not reach here since we check the same thing in
-          // CheckAggregateFunction.
-          sys.error("Having more than one distinct column sets is not allowed.")
+      case p: logical.Aggregate =>
+        val converted =
+          aggregate.Utils.tryConvert(
+            p,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled)
+        converted match {
+          case None => Nil // Cannot convert to new aggregation code path.
+          case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
+            // Extracts all distinct aggregate expressions from the resultExpressions.
+            val aggregateExpressions = resultExpressions.flatMap { expr =>
+              expr.collect {
+                case agg: AggregateExpression2 => agg
+              }
+            }.toSet.toSeq
+            // For those distinct aggregate expressions, we create a map from the
+            // aggregate function to the corresponding attribute of the function.
+            val aggregateFunctionMap = aggregateExpressions.map { agg =>
+              val aggregateFunction = agg.aggregateFunction
+              (aggregateFunction, agg.isDistinct) ->
+                Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+            }.toMap
+
+            val (functionsWithDistinct, functionsWithoutDistinct) =
+              aggregateExpressions.partition(_.isDistinct)
+            if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+              // This is a sanity check. We should not reach here when we have multiple
+              // distinct column sets (aggregate.Utils.tryConvert will return None).
+              sys.error(
+                "Multiple distinct column sets are not supported by the new aggregation " +
+                  "code path.")
+            }
+
+            val aggregateOperator =
+              if (functionsWithDistinct.isEmpty) {
+                aggregate.Utils.planAggregateWithoutDistinct(
+                  groupingExpressions,
+                  aggregateExpressions,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              } else {
+                aggregate.Utils.planAggregateWithOneDistinct(
+                  groupingExpressions,
+                  functionsWithDistinct,
+                  functionsWithoutDistinct,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              }
+
+            aggregateOperator
         }
-        val aggregateOperator =
-          if (functionsWithDistinct.isEmpty) {
-            aggregate.Utils.planAggregateWithoutDistinct(
-              groupingExpressions,
-              aggregateExpressions,
-              aggregateFunctionMap,
-              resultExpressions,
-              planLater(child))
-          } else {
-            aggregate.Utils.planAggregateWithOneDistinct(
-              groupingExpressions,
-              functionsWithDistinct,
-              functionsWithoutDistinct,
-              aggregateFunctionMap,
-              resultExpressions,
-              planLater(child))
-          }
 
-        aggregateOperator
       case _ => Nil
     }
   }

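The rename from AggregateOperator2 to Aggregation comes with the actual fallback mechanism: a strategy that returns Nil contributes no candidate plans, so the planner simply moves on to the other strategies registered in SQLContext, where the old aggregation path still lives. A toy model of that behavior (all names ours, not Spark's):

trait ToyPlan
case object NewAggPlan extends ToyPlan
case object OldAggPlan extends ToyPlan

trait ToyStrategy { def apply(convertible: Boolean): Seq[ToyPlan] }

// Mirrors `case None => Nil` above: bail out when conversion fails.
object NewAggStrategy extends ToyStrategy {
  def apply(convertible: Boolean): Seq[ToyPlan] =
    if (convertible) Seq(NewAggPlan) else Nil
}

// Stands in for the pre-existing aggregation strategies.
object OldAggStrategy extends ToyStrategy {
  def apply(convertible: Boolean): Seq[ToyPlan] = Seq(OldAggPlan)
}

object ToyPlanner {
  private val strategies: Seq[ToyStrategy] = Seq(NewAggStrategy, OldAggStrategy)

  // Like Catalyst's QueryPlanner, concatenate candidates in strategy
  // order and take the first plan produced.
  def plan(convertible: Boolean): ToyPlan =
    strategies.view.flatMap(_.apply(convertible)).head

  def main(args: Array[String]): Unit = {
    assert(plan(convertible = true) == NewAggPlan)
    assert(plan(convertible = false) == OldAggPlan) // graceful fallback
  }
}
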
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala

Lines changed: 51 additions & 1 deletion
@@ -149,7 +149,7 @@ private[sql] abstract class SortAggregationIterator(
     }
   }
 
-  private def initialize(): Unit = {
+  protected def initialize(): Unit = {
     if (inputIter.hasNext) {
       initializeBuffer()
       val currentRow = inputIter.next().copy()

@@ -474,6 +474,31 @@
 
   override protected def initialBufferOffset: Int = groupingExpressions.length
 
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because merging
+        // two buffers that hold initial values generates a row that still holds initial
+        // values, we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // This iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
   override protected def processRow(row: InternalRow): Unit = {
     // Process all algebraic aggregate functions.
     algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))

@@ -659,6 +684,31 @@
     newMutableProjection(evalExpressions, bufferSchemata)()
   }
 
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because merging
+        // two buffers that hold initial values generates a row that still holds initial
+        // values, we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // This iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
   override protected def processRow(row: InternalRow): Unit = {
     val input = joinedRow(buffer, row)
     // For all aggregate functions with mode Complete, update buffers.

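The comment in both new initialize() overrides leans on one invariant: when the input iterator is empty and there are no grouping expressions, emitting a copy of the freshly initialized buffer is correct only if initial buffer values act as an identity for merge. A toy illustration of that invariant (standalone code, names ours):

// A count-style buffer whose initial value (0) is a merge identity,
// so an empty partition contributes nothing when merged downstream.
final case class CountBuffer(var count: Long)

object MergeIdentityCheck {
  def initialize(): CountBuffer = CountBuffer(0L)

  // Combine two partial aggregation buffers.
  def merge(a: CountBuffer, b: CountBuffer): CountBuffer =
    CountBuffer(a.count + b.count)

  def main(args: Array[String]): Unit = {
    // Merging two freshly initialized buffers must be indistinguishable
    // from a single freshly initialized buffer.
    assert(merge(initialize(), initialize()) == initialize())
  }
}
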
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala

Lines changed: 50 additions & 5 deletions
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.AnalysisException

@@ -6,10 +23,23 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
 
 object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def groupingKeySchemaIsSupported(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case array: ArrayType => true
+      case map: MapType => true
+      case struct: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
   private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
-    case p: Aggregate =>
+    case p: Aggregate if groupingKeySchemaIsSupported(p) =>
       val converted = p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(

@@ -76,17 +106,32 @@ object Utils {
         }.isDefined
       }
 
-      if (!hasAggregateExpression1) Some(converted) else None
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+          true
+        } else {
+          false
+        }
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
 
     case other => None
   }
 
   def tryConvert(
       plan: LogicalPlan,
-      useNewAggregation: Boolean): Option[Aggregate] = plan match {
-    case p: Aggregate =>
+      useNewAggregation: Boolean,
+      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+    case p: Aggregate if useNewAggregation && codeGenEnabled =>
       val converted = tryConvert(p)
-      if (useNewAggregation && converted.isDefined) {
+      if (converted.isDefined) {
         converted
       } else {
         // If the plan cannot be converted, we will do a final round check to see if the original

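Conversion is now guarded three ways: the public tryConvert overload requires both useSqlAggregate2 and codegenEnabled, the grouping key schema must be free of complex types, and at most one distinct column set may appear. A standalone sketch of the schema gate (helper and object names are ours, not Spark's):

import org.apache.spark.sql.types._

object GroupingKeyGate {
  // Mirrors groupingKeySchemaIsSupported: any ArrayType, MapType, or
  // StructType among the grouping key types disables the new code path.
  def isSupported(groupingKeyTypes: Seq[DataType]): Boolean =
    !groupingKeyTypes.exists {
      case _: ArrayType | _: MapType | _: StructType => true
      case _ => false
    }

  def main(args: Array[String]): Unit = {
    assert(isSupported(Seq(IntegerType, StringType)))
    assert(!isSupported(Seq(ArrayType(IntegerType))))      // complex key
    assert(!isSupported(Seq(StructType(Nil), DoubleType))) // complex key
  }
}
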
sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala

Lines changed: 5 additions & 1 deletion
@@ -67,7 +67,11 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /** Indicates if this function is deterministic. */
   def deterministic: Boolean
 
-  /** Initializes the given aggregation buffer. */
+  /**
+   * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+   * the condition that when merging two buffers with initial values, the new buffer should
+   * still store initial values.
+   */
   def initialize(buffer: MutableAggregationBuffer): Unit
 
   /** Updates the given aggregation buffer `buffer` with new input data from `input`. */

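This strengthened contract is exactly what the empty-input branch in sortBasedIterators.scala relies on. A hedged example of a user-defined sum that satisfies it; the import path follows this commit's file location (org.apache.spark.sql.expressions.aggregate) and the method signatures follow the UserDefinedAggregateFunction API as later released, either of which may differ slightly at this exact commit:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.aggregate.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class SumOfDoubles extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  // 0.0 satisfies the contract: merging two freshly initialized buffers
  // (0.0 + 0.0) still stores the initial value.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0.0)

  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer.update(0, buffer.getDouble(0) + input.getDouble(0))

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))

  def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
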
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 3 additions & 1 deletion
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._

@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     var hasGeneratedAgg = false
     df.queryExecution.executedPlan.foreach {
       case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+      case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
       case _ =>
     }
     if (!hasGeneratedAgg) {

@@ -285,7 +287,7 @@
       // Aggregate with Code generation handling all null values
       testCodeGen(
         "SELECT sum('a'), avg('a'), count(null) FROM testData",
-        Row(0, null, 0) :: Nil)
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)

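The expected row changes because sum('a') aggregates the string literal 'a', whose cast to sum's numeric input type is presumably NULL on every row, and SQL semantics define SUM over only-NULL input as NULL rather than 0, while COUNT of NULLs stays 0. A plain-Scala model of those semantics (Option standing in for a nullable column), just to pin down the expectation:

object SumNullSemantics {
  // SUM ignores NULLs and returns NULL when no non-NULL input remains;
  // COUNT of an always-NULL expression is 0.
  def sqlSum(xs: Seq[Option[Double]]): Option[Double] = {
    val present = xs.flatten
    if (present.isEmpty) None else Some(present.sum)
  }

  def sqlCount(xs: Seq[Option[Double]]): Long = xs.flatten.size.toLong

  def main(args: Array[String]): Unit = {
    val allNull = Seq[Option[Double]](None, None, None)
    assert(sqlSum(allNull).isEmpty)  // NULL, not 0 -- hence Row(null, null, 0)
    assert(sqlCount(allNull) == 0L)  // count(null) is still 0
  }
}
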