Skip to content

Commit d821a34

Browse files
committed
Cleanup.
1 parent 32aea9c commit d821a34

File tree

3 files changed

+193
-98
lines changed

3 files changed

+193
-98
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,53 @@ import org.apache.spark.sql.catalyst.{expressions, InternalRow}
2424
import org.apache.spark.sql.catalyst.trees.{LeafNode, UnaryNode}
2525
import org.apache.spark.sql.types._
2626

27+
/** The mode of an [[AggregateFunction2]]. */
2728
private[sql] sealed trait AggregateMode
2829

30+
/**
31+
* An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
32+
* This function updates the given aggregation buffer with the original input of this
33+
* function. When it has processed all input rows, the aggregation buffer is returned.
34+
*/
2935
private[sql] case object Partial extends AggregateMode
3036

37+
/**
38+
* An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
39+
* containing intermediate results for this function.
40+
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
41+
* When it has processed all input rows, the aggregation buffer is returned.
42+
*/
3143
private[sql] case object PartialMerge extends AggregateMode
3244

45+
/**
46+
* An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
47+
* containing intermediate results for this function and then generate the final result.
48+
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
49+
* When it has processed all input rows, the final result of this function is returned.
50+
*/
3351
private[sql] case object Final extends AggregateMode
3452

53+
/**
54+
* An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
55+
* from original input rows without any partial aggregation.
56+
* This function updates the given aggregation buffer with the original input of this
57+
* function. When it has processed all input rows, the final result of this function is returned.
58+
*/
3559
private[sql] case object Complete extends AggregateMode
3660

37-
case object NoOp extends Expression {
61+
private[sql] case object NoOp extends Expression {
3862
override def nullable: Boolean = true
39-
override def eval(input: InternalRow): Any = ???
63+
override def eval(input: InternalRow): Any = {
64+
throw new TreeNodeException(
65+
this, s"No function to evaluate expression. type: ${this.nodeName}")
66+
}
4067
override def dataType: DataType = NullType
4168
override def children: Seq[Expression] = Nil
4269
}
4370

4471
/**
45-
* A container of a Aggregate Function, Aggregate Mode, and a field (`isDistinct`) indicating
46-
* if DISTINCT keyword is specified for this function.
72+
* A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
73+
* (`isDistinct`) indicating if the DISTINCT keyword is specified for this function.
4774
* @param aggregateFunction
4875
* @param mode
4976
* @param isDistinct
@@ -54,60 +81,84 @@ private[sql] case class AggregateExpression2(
5481
isDistinct: Boolean) extends Expression {
5582

5683
override def children: Seq[Expression] = aggregateFunction :: Nil
57-
5884
override def dataType: DataType = aggregateFunction.dataType
5985
override def foldable: Boolean = false
6086
override def nullable: Boolean = aggregateFunction.nullable
6187

6288
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
6389

64-
override def eval(input: InternalRow = null): Any =
65-
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
66-
67-
def bufferSchema: StructType = aggregateFunction.bufferSchema
68-
def bufferAttributes: Seq[Attribute] = aggregateFunction.bufferAttributes
90+
override def eval(input: InternalRow = null): Any = {
91+
throw new TreeNodeException(
92+
this, s"No function to evaluate expression. type: ${this.nodeName}")
93+
}
6994
}
7095

7196
abstract class AggregateFunction2
7297
extends Expression {
7398

7499
self: Product =>
75100

76-
var bufferOffset: Int = 0
77-
101+
/** An aggregate function is not foldable. */
78102
override def foldable: Boolean = false
79103

104+
/**
105+
* The offset of this function's buffer in the underlying buffer shared with other functions.
106+
*/
107+
var bufferOffset: Int = 0
108+
80109
/** The schema of the aggregation buffer. */
81110
def bufferSchema: StructType
82111

83112
/** Attributes of fields in bufferSchema. */
84113
def bufferAttributes: Seq[Attribute]
85114

86-
def rightBufferSchema: Seq[Attribute]
115+
/** Clones bufferAttributes. */
116+
def cloneBufferAttributes: Seq[Attribute]
87117

118+
/**
119+
* Initializes its aggregation buffer located in `buffer`.
120+
* It will use bufferOffset to find the starting point of
121+
* its buffer in the given `buffer` shared with other functions.
122+
*/
88123
def initialize(buffer: MutableRow): Unit
89124

125+
/**
126+
* Updates its aggregation buffer located in `buffer` based on the given `input`.
127+
* It will use bufferOffset to find the starting point of its buffer in the given `buffer`
128+
* shared with other functions.
129+
*/
90130
def update(buffer: MutableRow, input: InternalRow): Unit
91131

132+
/**
133+
* Updates its aggregation buffer located in `buffer1` by combining intermediate results
134+
* in the current buffer and intermediate results from another buffer `buffer2`.
135+
* It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
136+
* and `buffer2`.
137+
*/
92138
def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
93-
94-
override def eval(buffer: InternalRow = null): Any
95139
}
96140

141+
/**
142+
* An example [[AggregateFunction2]] that is not an [[AlgebraicAggregate]].
143+
* This function calculates the sum of double values.
144+
* @param child
145+
*/
97146
case class MyDoubleSum(child: Expression) extends AggregateFunction2 {
98147
override val bufferSchema: StructType =
99148
StructType(StructField("currentSum", DoubleType, true) :: Nil)
100149

101150
override val bufferAttributes: Seq[Attribute] = bufferSchema.toAttributes
102-
override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
151+
152+
override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
153+
103154
override def initialize(buffer: MutableRow): Unit = {
104155
buffer.update(bufferOffset, null)
105156
}
106157

107158
override def update(buffer: MutableRow, input: InternalRow): Unit = {
108159
val inputValue = child.eval(input)
109160
if (inputValue != null) {
110-
if (buffer.isNullAt(bufferOffset) == null) {
161+
if (buffer.isNullAt(bufferOffset)) {
111162
buffer.setDouble(bufferOffset, inputValue.asInstanceOf[Double])
112163
} else {
113164
val currentSum = buffer.getDouble(bufferOffset)
@@ -151,10 +202,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
151202
val mergeExpressions: Seq[Expression]
152203
val evaluateExpression: Expression
153204

154-
override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
205+
override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
206+
155207
implicit class RichAttribute(a: AttributeReference) {
156208
def left = a
157-
def right = rightBufferSchema(bufferAttributes.indexOf(a))
209+
def right = cloneBufferAttributes(bufferAttributes.indexOf(a))
158210
}
159211

160212
/** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
235235
AggregateExpression2(aggregateFunction, Partial, isDistinct)
236236
}
237237
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
238-
agg.bufferAttributes
238+
agg.aggregateFunction.bufferAttributes
239239
}
240240
val partialAggregate =
241241
Aggregate2Sort(
242-
true,
243242
namedGroupingExpressions.map(_._2),
244243
partialAggregateExpressions,
245244
partialAggregateAttributes,
@@ -264,7 +263,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
264263
}.asInstanceOf[NamedExpression]
265264
}
266265
val finalAggregate = Aggregate2Sort(
267-
false,
268266
namedGroupingAttributes,
269267
finalAggregateExpressions,
270268
finalAggregateAttributes,

0 commit comments

Comments
 (0)