@@ -24,26 +24,53 @@ import org.apache.spark.sql.catalyst.{expressions, InternalRow}
2424import org .apache .spark .sql .catalyst .trees .{LeafNode , UnaryNode }
2525import org .apache .spark .sql .types ._
2626
27+ /** The mode of an [[AggregateFunction ]]. */
2728private [sql] sealed trait AggregateMode
2829
30+ /**
31+ * An [[AggregateFunction ]] with [[Partial ]] mode is used for partial aggregation.
32+ * This function updates the given aggregation buffer with the original input of this
33+ * function. When it has processed all input rows, the aggregation buffer is returned.
34+ */
2935private [sql] case object Partial extends AggregateMode
3036
37+ /**
38+ * An [[AggregateFunction ]] with [[PartialMerge ]] mode is used to merge aggregation buffers
39+ * containing intermediate results for this function.
40+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
41+ * When it has processed all input rows, the aggregation buffer is returned.
42+ */
3143private [sql] case object PartialMerge extends AggregateMode
3244
45+ /**
46+ * An [[AggregateFunction ]] with [[PartialMerge ]] mode is used to merge aggregation buffers
47+ * containing intermediate results for this function and the generate final result.
48+ * This function updates the given aggregation buffer by merging multiple aggregation buffers.
49+ * When it has processed all input rows, the final result of this function is returned.
50+ */
3351private [sql] case object Final extends AggregateMode
3452
53+ /**
54+ * An [[AggregateFunction2 ]] with [[Partial ]] mode is used to evaluate this function directly
55+ * from original input rows without any partial aggregation.
56+ * This function updates the given aggregation buffer with the original input of this
57+ * function. When it has processed all input rows, the final result of this function is returned.
58+ */
3559private [sql] case object Complete extends AggregateMode
3660
37- case object NoOp extends Expression {
61+ private [sql] case object NoOp extends Expression {
3862 override def nullable : Boolean = true
39- override def eval (input : InternalRow ): Any = ???
63+ override def eval (input : InternalRow ): Any = {
64+ throw new TreeNodeException (
65+ this , s " No function to evaluate expression. type: ${this .nodeName}" )
66+ }
4067 override def dataType : DataType = NullType
4168 override def children : Seq [Expression ] = Nil
4269}
4370
4471/**
45- * A container of a Aggregate Function, Aggregate Mode, and a field (`isDistinct`) indicating
46- * if DISTINCT keyword is specified for this function.
72+ * A container for an [[ AggregateFunction2 ]] with its [[ AggregateMode ]] and a field
73+ * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
4774 * @param aggregateFunction
4875 * @param mode
4976 * @param isDistinct
@@ -54,60 +81,84 @@ private[sql] case class AggregateExpression2(
5481 isDistinct : Boolean ) extends Expression {
5582
5683 override def children : Seq [Expression ] = aggregateFunction :: Nil
57-
5884 override def dataType : DataType = aggregateFunction.dataType
5985 override def foldable : Boolean = false
6086 override def nullable : Boolean = aggregateFunction.nullable
6187
6288 override def toString : String = s " ( ${aggregateFunction}2,mode= $mode,isDistinct= $isDistinct) "
6389
64- override def eval (input : InternalRow = null ): Any =
65- throw new TreeNodeException (this , s " No function to evaluate expression. type: ${this .nodeName}" )
66-
67- def bufferSchema : StructType = aggregateFunction.bufferSchema
68- def bufferAttributes : Seq [Attribute ] = aggregateFunction.bufferAttributes
90+ override def eval (input : InternalRow = null ): Any = {
91+ throw new TreeNodeException (
92+ this , s " No function to evaluate expression. type: ${this .nodeName}" )
93+ }
6994}
7095
7196abstract class AggregateFunction2
7297 extends Expression {
7398
7499 self : Product =>
75100
76- var bufferOffset : Int = 0
77-
101+ /** An aggregate function is not foldable. */
78102 override def foldable : Boolean = false
79103
104+ /**
105+ * The offset of this function's buffer in the underlying buffer shared with other functions.
106+ */
107+ var bufferOffset : Int = 0
108+
80109 /** The schema of the aggregation buffer. */
81110 def bufferSchema : StructType
82111
83112 /** Attributes of fields in bufferSchema. */
84113 def bufferAttributes : Seq [Attribute ]
85114
86- def rightBufferSchema : Seq [Attribute ]
115+ /** Clones bufferAttributes. */
116+ def cloneBufferAttributes : Seq [Attribute ]
87117
118+ /**
119+ * Initializes its aggregation buffer located in `buffer`.
120+ * It will use bufferOffset to find the starting point of
121+ * its buffer in the given `buffer` shared with other functions.
122+ */
88123 def initialize (buffer : MutableRow ): Unit
89124
125+ /**
126+ * Updates its aggregation buffer located in `buffer` based on the given `input`.
127+ * It will use bufferOffset to find the starting point of its buffer in the given `buffer`
128+ * shared with other functions.
129+ */
90130 def update (buffer : MutableRow , input : InternalRow ): Unit
91131
132+ /**
133+ * Updates its aggregation buffer located in `buffer1` by combining intermediate results
134+ * in the current buffer and intermediate results from another buffer `buffer2`.
135+ * It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
136+ * and `buffer2`.
137+ */
92138 def merge (buffer1 : MutableRow , buffer2 : InternalRow ): Unit
93-
94- override def eval (buffer : InternalRow = null ): Any
95139}
96140
141+ /**
142+ * An example [[AggregateFunction2 ]] that is not an [[AlgebraicAggregate ]].
143+ * This function calculate the sum of double values.
144+ * @param child
145+ */
97146case class MyDoubleSum (child : Expression ) extends AggregateFunction2 {
98147 override val bufferSchema : StructType =
99148 StructType (StructField (" currentSum" , DoubleType , true ) :: Nil )
100149
101150 override val bufferAttributes : Seq [Attribute ] = bufferSchema.toAttributes
102- override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
151+
152+ override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
153+
103154 override def initialize (buffer : MutableRow ): Unit = {
104155 buffer.update(bufferOffset, null )
105156 }
106157
107158 override def update (buffer : MutableRow , input : InternalRow ): Unit = {
108159 val inputValue = child.eval(input)
109160 if (inputValue != null ) {
110- if (buffer.isNullAt(bufferOffset) == null ) {
161+ if (buffer.isNullAt(bufferOffset)) {
111162 buffer.setDouble(bufferOffset, inputValue.asInstanceOf [Double ])
112163 } else {
113164 val currentSum = buffer.getDouble(bufferOffset)
@@ -151,10 +202,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
151202 val mergeExpressions : Seq [Expression ]
152203 val evaluateExpression : Expression
153204
154- override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
205+ override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
206+
155207 implicit class RichAttribute (a : AttributeReference ) {
156208 def left = a
157- def right = rightBufferSchema (bufferAttributes.indexOf(a))
209+ def right = cloneBufferAttributes (bufferAttributes.indexOf(a))
158210 }
159211
160212 /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
0 commit comments