Skip to content

Commit 39ee975

Browse files
committed
Code cleanup: Remove unnecessary AttributeReferences.
1 parent b7720ba commit 39ee975

File tree

3 files changed

+58
-45
lines changed

3 files changed

+58
-45
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ private[sql] case class AggregateExpression2(
6363

6464
override def eval(input: InternalRow = null): Any =
6565
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
66+
67+
def bufferSchema: StructType = aggregateFunction.bufferSchema
68+
def bufferAttributes: Seq[Attribute] = aggregateFunction.bufferAttributes
6669
}
6770

6871
abstract class AggregateFunction2
@@ -77,7 +80,11 @@ abstract class AggregateFunction2
7780
this
7881
}
7982

80-
def bufferValueDataTypes: StructType
83+
/** The schema of the aggregation buffer. */
84+
def bufferSchema: StructType
85+
86+
/** Attributes of fields in bufferSchema. */
87+
def bufferAttributes: Seq[Attribute]
8188

8289
def initialize(buffer: MutableRow): Unit
8390

@@ -94,7 +101,6 @@ abstract class AggregateFunction2
94101
abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
95102
self: Product =>
96103

97-
val bufferSchema: Seq[Attribute]
98104
val initialValues: Seq[Expression]
99105
val updateExpressions: Seq[Expression]
100106
val mergeExpressions: Seq[Expression]
@@ -105,23 +111,25 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
105111

106112
def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())
107113

108-
lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
114+
lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
109115
implicit class RichAttribute(a: AttributeReference) {
110116
def left = a
111-
def right = rightBufferSchema(bufferSchema.indexOf(a))
117+
def right = rightBufferSchema(bufferAttributes.indexOf(a))
112118
}
113119

114-
override def bufferValueDataTypes: StructType = StructType.fromAttributes(bufferSchema)
120+
/** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
121+
override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
122+
115123
override def initialize(buffer: MutableRow): Unit = {
116124
var i = 0
117-
while (i < bufferSchema.size) {
125+
while (i < bufferAttributes.size) {
118126
buffer(i + bufferOffset) = initialValues(i).eval()
119127
i += 1
120128
}
121129
}
122130

123131
lazy val boundUpdateExpressions = {
124-
val updateSchema = inputSchema ++ offsetExpressions ++ bufferSchema
132+
val updateSchema = inputSchema ++ offsetExpressions ++ bufferAttributes
125133
val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
126134
println(s"update: ${updateExpressions.mkString(",")}")
127135
println(s"update: ${bound.mkString(",")}")
@@ -131,29 +139,29 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
131139
val joinedRow = new JoinedRow
132140
override def update(buffer: MutableRow, input: InternalRow): Unit = {
133141
var i = 0
134-
while (i < bufferSchema.size) {
142+
while (i < bufferAttributes.size) {
135143
buffer(i + bufferOffset) = boundUpdateExpressions(i).eval(joinedRow(input, buffer))
136144
i += 1
137145
}
138146
}
139147

140148
lazy val boundMergeExpressions = {
141-
val mergeSchema = offsetExpressions ++ bufferSchema ++ offsetExpressions ++ rightBufferSchema
149+
val mergeSchema = offsetExpressions ++ bufferAttributes ++ offsetExpressions ++ rightBufferSchema
142150
mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
143151
}
144152
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
145153
var i = 0
146154
println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}")
147155
joinedRow(buffer1, buffer2)
148-
while (i < bufferSchema.size) {
156+
while (i < bufferAttributes.size) {
149157
println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}")
150158
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
151159
i += 1
152160
}
153161
}
154162

155163
lazy val boundEvaluateExpression =
156-
BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferSchema)
164+
BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferAttributes)
157165
override def eval(buffer: InternalRow): Any = {
158166
println(s"eval: $buffer")
159167
val res = boundEvaluateExpression.eval(buffer)
@@ -170,26 +178,26 @@ case class Average(child: Expression) extends AlgebraicAggregate {
170178
case _ => DoubleType
171179
}
172180

173-
val intermediateType = child.dataType match {
181+
val sumDataType = child.dataType match {
174182
case _ @ DecimalType() => DecimalType.Unlimited
175183
case _ => DoubleType
176184
}
177185

178-
val currentSum = AttributeReference("currentSum", DoubleType)()
186+
val currentSum = AttributeReference("currentSum", sumDataType)()
179187
val currentCount = AttributeReference("currentCount", LongType)()
180188

181-
val bufferSchema = currentSum :: currentCount :: Nil
189+
override val bufferAttributes = currentSum :: currentCount :: Nil
182190

183191
val initialValues = Seq(
184-
/* currentSum = */ Cast(Literal(0), intermediateType),
192+
/* currentSum = */ Cast(Literal(0), sumDataType),
185193
/* currentCount = */ Literal(0L)
186194
)
187195

188196
val updateExpressions = Seq(
189197
/* currentSum = */
190198
Add(
191199
currentSum,
192-
Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)),
200+
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
193201
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
194202
)
195203

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
2424
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
2525
import org.apache.spark.sql.catalyst.plans.physical._
2626
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
27+
import org.apache.spark.sql.execution.aggregate2.Aggregate2Sort
2728
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
2829
import org.apache.spark.sql.parquet._
2930
import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
@@ -186,67 +187,71 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
186187
exprs.flatMap(_.collect { case a: AggregateExpression => a })
187188
}
188189

190+
/**
191+
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
192+
*/
189193
object AggregateOperator2 extends Strategy {
190194
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
191195
case logical.Aggregate(groupingExpressions, resultExpressions, child)
192196
if sqlContext.conf.useSqlAggregate2 =>
193-
// 1. Extracts all aggregate expressions.
197+
// 1. Extracts all distinct aggregate expressions from the resultExpressions.
194198
val aggregateExpressions = resultExpressions.flatMap { expr =>
195199
expr.collect {
196200
case agg: AggregateExpression2 => agg
197201
}
198202
}.toSet.toSeq
199-
val aggregateExpressionMap = aggregateExpressions.zipWithIndex.map {
200-
case (agg, index) =>
201-
agg.aggregateFunction -> Alias(agg, s"_agg$index")().toAttribute
203+
// For those distinct aggregate expressions, we create a map from the aggregate function
204+
// to the corresponding attribute of the function.
205+
val aggregateFunctionMap = aggregateExpressions.map { agg =>
206+
val aggregateFunction = agg.aggregateFunction
207+
aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
202208
}.toMap
203209

204-
// 2. Create Pre-shuffle Aggregate Operator
205-
val namedGroupingExpressions = groupingExpressions.zipWithIndex.map {
206-
case (ne: NamedExpression, index) => ne
207-
case (other, index) => Alias(other, s"_groupingExpr$index")()
210+
// 2. Create an Aggregate Operator for partial aggregations.
211+
val namedGroupingExpressions = groupingExpressions.map {
212+
case ne: NamedExpression => ne
213+
// If the expression is not a NamedExpression, we add an alias.
214+
// So, when we generate the result of the operator, the Aggregate Operator
215+
// can directly get the Seq of attributes representing the grouping expressions.
216+
case other => Alias(other, other.toString)()
208217
}
209218
val namedGroupingAttributes = namedGroupingExpressions.map(_.toAttribute)
210-
val preShuffleAggregateExpressions = aggregateExpressions.map {
219+
val partialAggregateExpressions = aggregateExpressions.map {
211220
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
212221
AggregateExpression2(aggregateFunction, Partial, isDistinct)
213222
}
214-
val preShuffleAggregateAttributes = preShuffleAggregateExpressions.zipWithIndex.flatMap {
215-
case (AggregateExpression2(aggregateFunction, Partial, isDistinct), index) =>
216-
aggregateFunction.bufferValueDataTypes.map {
217-
case StructField(name, dataType, nullable, metadata) =>
218-
AttributeReference(s"_partialAgg${index}_${name}", dataType, nullable, metadata)()
219-
}
223+
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
224+
agg.bufferAttributes
220225
}
221226
val partialAggregate =
222227
Aggregate2Sort(
223228
true,
224229
namedGroupingExpressions,
225-
preShuffleAggregateExpressions,
226-
preShuffleAggregateAttributes,
227-
namedGroupingAttributes ++ preShuffleAggregateAttributes,
230+
partialAggregateExpressions,
231+
partialAggregateAttributes,
232+
namedGroupingAttributes ++ partialAggregateAttributes,
228233
planLater(child))
229234

230-
// 3. Create post-shuffle Aggregate Operator.
231-
val postShuffleAggregateExpressions = aggregateExpressions.map {
235+
// 3. Create an Aggregate Operator for final aggregations.
236+
val finalAggregateExpressions = aggregateExpressions.map {
232237
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
233238
AggregateExpression2(aggregateFunction, Final, isDistinct)
234239
}
235-
val postShuffleAggregateAttributes =
236-
postShuffleAggregateExpressions.map {
237-
expr => aggregateExpressionMap(expr.aggregateFunction)
240+
val finalAggregateAttributes =
241+
finalAggregateExpressions.map {
242+
expr => aggregateFunctionMap(expr.aggregateFunction)
238243
}
239244
val rewrittenResultExpressions = resultExpressions.map { expr =>
240245
expr.transform {
241246
case agg: AggregateExpression2 =>
242-
aggregateExpressionMap(agg.aggregateFunction).toAttribute
247+
aggregateFunctionMap(agg.aggregateFunction).toAttribute
243248
}.asInstanceOf[NamedExpression]
244249
}
245250
val finalAggregate = Aggregate2Sort(
246251
false,
247252
namedGroupingAttributes,
248-
postShuffleAggregateExpressions,
249-
postShuffleAggregateAttributes,
253+
finalAggregateExpressions,
254+
finalAggregateAttributes,
250255
rewrittenResultExpressions,
251256
partialAggregate)
252257

sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,22 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
4949
test("test average2") {
5050
ctx.sql(
5151
"""
52-
|SELECT key, avg2(value)
52+
|SELECT key, avg(value)
5353
|FROM agg2
5454
|GROUP BY key
5555
""".stripMargin).explain(true)
5656

5757
ctx.sql(
5858
"""
59-
|SELECT key, avg2(value)
59+
|SELECT key, avg(value)
6060
|FROM agg2
6161
|GROUP BY key
6262
""".stripMargin).queryExecution.executedPlan(3).execute().collect().foreach(println)
6363

6464
checkAnswer(
6565
ctx.sql(
6666
"""
67-
|SELECT key, avg2(value)
67+
|SELECT key, avg(value)
6868
|FROM agg2
6969
|GROUP BY key
7070
""".stripMargin),

0 commit comments

Comments
 (0)