Skip to content

Commit fc522d5

Browse files
committed
Hook generated aggregation in to the planner.
1 parent e742640 commit fc522d5

File tree

6 files changed

+110
-52
lines changed

6 files changed

+110
-52
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,62 @@ object PhysicalOperation extends PredicateHelper {
104104
}
105105
}
106106

107+
object PartialAggregation {
108+
type ReturnType =
109+
(Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
110+
111+
def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
112+
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
113+
// Collect all aggregate expressions.
114+
val allAggregates =
115+
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
116+
// Collect all aggregate expressions that can be computed partially.
117+
val partialAggregates =
118+
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
119+
120+
// Only do partial aggregation if supported by all aggregate expressions.
121+
if (allAggregates.size == partialAggregates.size) {
122+
// Create a map of expressions to their partial evaluations for all aggregate expressions.
123+
val partialEvaluations: Map[Long, SplitEvaluation] =
124+
partialAggregates.map(a => (a.id, a.asPartial)).toMap
125+
126+
// We need to pass all grouping expressions though so the grouping can happen a second
127+
// time. However some of them might be unnamed so we alias them allowing them to be
128+
// referenced in the second aggregation.
129+
val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
130+
case n: NamedExpression => (n, n)
131+
case other => (other, Alias(other, "PartialGroup")())
132+
}.toMap
133+
134+
// Replace aggregations with a new expression that computes the result from the already
135+
// computed partial evaluations and grouping values.
136+
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
137+
case e: Expression if partialEvaluations.contains(e.id) =>
138+
partialEvaluations(e.id).finalEvaluation
139+
case e: Expression if namedGroupingExpressions.contains(e) =>
140+
namedGroupingExpressions(e).toAttribute
141+
}).asInstanceOf[Seq[NamedExpression]]
142+
143+
val partialComputation =
144+
(namedGroupingExpressions.values ++
145+
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
146+
147+
val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
148+
149+
Some(
150+
(namedGroupingAttributes,
151+
rewrittenAggregateExpressions,
152+
groupingExpressions,
153+
partialComputation,
154+
child))
155+
} else {
156+
None
157+
}
158+
case _ => None
159+
}
160+
}
161+
162+
107163
/**
108164
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
109165
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CodeGenerationSuite.scala

Whitespace-only changes.

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
239239
val strategies: Seq[Strategy] =
240240
CommandStrategy(self) ::
241241
TakeOrdered ::
242-
PartialAggregation ::
242+
HashAggregation ::
243243
LeftSemiJoin ::
244244
HashJoin ::
245245
InMemoryScans ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -95,58 +95,57 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
9595
}
9696
}
9797

98-
object PartialAggregation extends Strategy {
98+
object HashAggregation extends Strategy {
9999
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
100-
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
101-
// Collect all aggregate expressions.
102-
val allAggregates =
103-
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
104-
// Collect all aggregate expressions that can be computed partially.
105-
val partialAggregates =
106-
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })
100+
// Aggregations that can be performed in two phases, before and after the shuffle.
107101

108-
// Only do partial aggregation if supported by all aggregate expressions.
109-
if (allAggregates.size == partialAggregates.size) {
110-
// Create a map of expressions to their partial evaluations for all aggregate expressions.
111-
val partialEvaluations: Map[Long, SplitEvaluation] =
112-
partialAggregates.map(a => (a.id, a.asPartial)).toMap
113-
114-
// We need to pass all grouping expressions though so the grouping can happen a second
115-
// time. However some of them might be unnamed so we alias them allowing them to be
116-
// referenced in the second aggregation.
117-
val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
118-
case n: NamedExpression => (n, n)
119-
case other => (other, Alias(other, "PartialGroup")())
120-
}.toMap
121-
122-
// Replace aggregations with a new expression that computes the result from the already
123-
// computed partial evaluations and grouping values.
124-
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
125-
case e: Expression if partialEvaluations.contains(e.id) =>
126-
partialEvaluations(e.id).finalEvaluation
127-
case e: Expression if namedGroupingExpressions.contains(e) =>
128-
namedGroupingExpressions(e).toAttribute
129-
}).asInstanceOf[Seq[NamedExpression]]
130-
131-
val partialComputation =
132-
(namedGroupingExpressions.values ++
133-
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
134-
135-
// Construct two phased aggregation.
136-
execution.Aggregate(
102+
// Where all aggregates can be codegened.
103+
case PartialAggregation(
104+
namedGroupingAttributes,
105+
rewrittenAggregateExpressions,
106+
groupingExpressions,
107+
partialComputation,
108+
child)
109+
if canBeCodeGened(
110+
allAggregates(partialComputation) ++
111+
allAggregates(rewrittenAggregateExpressions))=>
112+
execution.HashAggregate(
137113
partial = false,
138-
namedGroupingExpressions.values.map(_.toAttribute).toSeq,
114+
namedGroupingAttributes,
139115
rewrittenAggregateExpressions,
140-
execution.Aggregate(
116+
execution.HashAggregate(
141117
partial = true,
142118
groupingExpressions,
143119
partialComputation,
144120
planLater(child))(sqlContext))(sqlContext) :: Nil
145-
} else {
146-
Nil
147-
}
121+
122+
123+
// Where some aggregate can not be codegened
124+
case PartialAggregation(
125+
namedGroupingAttributes,
126+
rewrittenAggregateExpressions,
127+
groupingExpressions,
128+
partialComputation,
129+
child) =>
130+
execution.Aggregate(
131+
partial = false,
132+
namedGroupingAttributes,
133+
rewrittenAggregateExpressions,
134+
execution.Aggregate(
135+
partial = true,
136+
groupingExpressions,
137+
partialComputation,
138+
planLater(child))(sqlContext))(sqlContext) :: Nil
148139
case _ => Nil
149140
}
141+
142+
def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
143+
case _: Sum | _: Count => false
144+
case _ => true
145+
}
146+
147+
def allAggregates(exprs: Seq[Expression]) =
148+
exprs.flatMap(_.collect { case a: AggregateExpression => a })
150149
}
151150

152151
object BroadcastNestedLoopJoin extends Strategy {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.SparkContext
21+
import org.apache.spark.sql.SQLContext
2122
import org.apache.spark.sql.catalyst.errors._
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.plans.physical._
@@ -33,12 +34,14 @@ import org.apache.spark.sql.catalyst.types._
3334
* @param child the input data source.
3435
*/
3536
case class HashAggregate(
36-
partial: Boolean,
37-
groupingExpressions: Seq[Expression],
38-
aggregateExpressions: Seq[NamedExpression],
39-
child: SparkPlan)(@transient sc: SparkContext)
37+
partial: Boolean,
38+
groupingExpressions: Seq[Expression],
39+
aggregateExpressions: Seq[NamedExpression],
40+
child: SparkPlan)(@transient sqlContext: SQLContext)
4041
extends UnaryNode with NoBind {
4142

43+
private def sc = sqlContext.sparkContext
44+
4245
override def requiredChildDistribution =
4346
if (partial) {
4447
UnspecifiedDistribution :: Nil
@@ -50,7 +53,7 @@ case class HashAggregate(
5053
}
5154
}
5255

53-
override def otherCopyArgs = sc :: Nil
56+
override def otherCopyArgs = sqlContext :: Nil
5457

5558
def output = aggregateExpressions.map(_.toAttribute)
5659

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite {
3939

4040
test("count is partially aggregated") {
4141
val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
42-
val planned = PartialAggregation(query).head
43-
val aggregations = planned.collect { case a: Aggregate => a }
42+
val planned = HashAggregation(query).head
43+
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
4444

4545
assert(aggregations.size === 2)
4646
}
4747

4848
test("count distinct is not partially aggregated") {
4949
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
50-
val planned = PartialAggregation(query)
50+
val planned = HashAggregation(query)
5151
assert(planned.isEmpty)
5252
}
5353

5454
test("mixed aggregates are not partially aggregated") {
5555
val query =
5656
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
57-
val planned = PartialAggregation(query)
57+
val planned = HashAggregation(query)
5858
assert(planned.isEmpty)
5959
}
6060
}

0 commit comments

Comments
 (0)