Skip to content

Commit d1cb0b7

Browse files
cloud-fancmonkey
authored and committed
[SPARK-19060][SQL] remove the supportsPartial flag in AggregateFunction
## What changes were proposed in this pull request? Now all aggregation functions support partial aggregate, we can remove the `supportsPartual` flag in `AggregateFunction` ## How was this patch tested? existing tests. Author: Wenchen Fan <[email protected]> Closes apache#16461 from cloud-fan/partial.
1 parent 659eb1c commit d1cb0b7

File tree

7 files changed

+3
-48
lines changed

7 files changed

+3
-48
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,6 @@ abstract class AggregateFunction extends Expression {
173173
*/
174174
def inputAggBufferAttributes: Seq[AttributeReference]
175175

176-
/**
177-
* Indicates if this function supports partial aggregation.
178-
* Currently Hive UDAF is the only one that doesn't support partial aggregation.
179-
*/
180-
def supportsPartial: Boolean = true
181-
182176
/**
183177
* Result of the aggregate function when the input is empty. This is currently only used for the
184178
* proper rewriting of distinct aggregate functions.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,6 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF
436436
override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
437437
override def dataType: DataType = IntegerType
438438
override def nullable: Boolean = true
439-
override def supportsPartial: Boolean = false
440439
override lazy val mergeExpressions =
441440
throw new UnsupportedOperationException("Window Functions do not support merging.")
442441
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
131131
}
132132
}
133133

134-
// Check if the aggregates contains functions that do not support partial aggregation.
135-
val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
136-
137-
// Aggregation strategy can handle queries with a single distinct group and partial aggregates.
138-
if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) {
134+
// Aggregation strategy can handle queries with a single distinct group.
135+
if (distinctAggGroups.size > 1) {
139136
// Create the attributes for the grouping id and the group by clause.
140137
val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true)
141138
val groupByMap = a.groupingExpressions.collect {

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -262,18 +262,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
262262
}
263263

264264
val aggregateOperator =
265-
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
266-
if (functionsWithDistinct.nonEmpty) {
267-
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
268-
"aggregate functions which don't support partial aggregation.")
269-
} else {
270-
aggregate.AggUtils.planAggregateWithoutPartial(
271-
groupingExpressions,
272-
aggregateExpressions,
273-
resultExpressions,
274-
planLater(child))
275-
}
276-
} else if (functionsWithDistinct.isEmpty) {
265+
if (functionsWithDistinct.isEmpty) {
277266
aggregate.AggUtils.planAggregateWithoutDistinct(
278267
groupingExpressions,
279268
aggregateExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,6 @@ import org.apache.spark.sql.internal.SQLConf
2727
* Utility functions used by the query planner to convert our plan to new aggregation code path.
2828
*/
2929
object AggUtils {
30-
31-
def planAggregateWithoutPartial(
32-
groupingExpressions: Seq[NamedExpression],
33-
aggregateExpressions: Seq[AggregateExpression],
34-
resultExpressions: Seq[NamedExpression],
35-
child: SparkPlan): Seq[SparkPlan] = {
36-
37-
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
38-
val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
39-
SortAggregateExec(
40-
requiredChildDistributionExpressions = Some(groupingExpressions),
41-
groupingExpressions = groupingExpressions,
42-
aggregateExpressions = completeAggregateExpressions,
43-
aggregateAttributes = completeAggregateAttributes,
44-
initialInputBufferOffset = 0,
45-
resultExpressions = resultExpressions,
46-
child = child
47-
) :: Nil
48-
}
49-
5030
private def createAggregate(
5131
requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
5232
groupingExpressions: Seq[NamedExpression] = Nil,

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,6 @@ private[hive] case class HiveUDAFFunction(
380380

381381
override def nullable: Boolean = true
382382

383-
override def supportsPartial: Boolean = true
384-
385383
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
386384

387385
override def prettyName: String = name

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ case class TestingTypedCount(
4242

4343
override def nullable: Boolean = false
4444

45-
override val supportsPartial: Boolean = true
46-
4745
override def createAggregationBuffer(): State = TestingTypedCount.State(0L)
4846

4947
override def update(buffer: State, input: InternalRow): State = {

0 commit comments

Comments (0)