1717
1818package org .apache .spark .sql .execution
1919
20- import org .apache .spark .SparkContext
2120import org .apache .spark .annotation .DeveloperApi
2221import org .apache .spark .sql .SQLContext
23- import org .apache .spark .sql .catalyst .errors ._
2422import org .apache .spark .sql .catalyst .expressions ._
2523import org .apache .spark .sql .catalyst .plans .physical ._
2624import org .apache .spark .sql .catalyst .types ._
@@ -51,8 +49,6 @@ case class GeneratedAggregate(
5149 child : SparkPlan )(@ transient sqlContext : SQLContext )
5250 extends UnaryNode with NoBind {
5351
54- private def sc = sqlContext.sparkContext
55-
5652 override def requiredChildDistribution =
5753 if (partial) {
5854 UnspecifiedDistribution :: Nil
@@ -66,24 +62,24 @@ case class GeneratedAggregate(
6662
6763 override def otherCopyArgs = sqlContext :: Nil
6864
69- def output = aggregateExpressions.map(_.toAttribute)
65+ override def output = aggregateExpressions.map(_.toAttribute)
7066
71- def execute () = {
67+ override def execute () = {
7268 val aggregatesToCompute = aggregateExpressions.flatMap { a =>
7369 a.collect { case agg : AggregateExpression => agg}
7470 }
7571
7672 val computeFunctions = aggregatesToCompute.map {
77- case c@ Count (expr) =>
78- val currentCount = AttributeReference (" currentCount" , LongType , true )()
73+ case c @ Count (expr) =>
74+ val currentCount = AttributeReference (" currentCount" , LongType , nullable = false )()
7975 val initialValue = Literal (0L )
8076 val updateFunction = If (IsNotNull (expr), Add (currentCount, Literal (1L )), currentCount)
8177 val result = currentCount
8278
8379 AggregateEvaluation (currentCount :: Nil , initialValue :: Nil , updateFunction :: Nil , result)
8480
8581 case Sum (expr) =>
86- val currentSum = AttributeReference (" currentSum" , expr.dataType, true )()
82+ val currentSum = AttributeReference (" currentSum" , expr.dataType, nullable = false )()
8783 val initialValue = Cast (Literal (0L ), expr.dataType)
8884
8985 // Coalesce avoids double calculation...
@@ -93,9 +89,9 @@ case class GeneratedAggregate(
9389
9490 AggregateEvaluation (currentSum :: Nil , initialValue :: Nil , updateFunction :: Nil , result)
9591
96- case a@ Average (expr) =>
97- val currentCount = AttributeReference (" currentCount" , LongType , true )()
98- val currentSum = AttributeReference (" currentSum" , expr.dataType, true )()
92+ case a @ Average (expr) =>
93+ val currentCount = AttributeReference (" currentCount" , LongType , nullable = false )()
94+ val currentSum = AttributeReference (" currentSum" , expr.dataType, nullable = false )()
9995 val initialCount = Literal (0L )
10096 val initialSum = Cast (Literal (0L ), expr.dataType)
10197 val updateCount = If (IsNotNull (expr), Add (currentCount, Literal (1L )), currentCount)
@@ -131,50 +127,70 @@ case class GeneratedAggregate(
131127
132128 child.execute().mapPartitions { iter =>
133129 // Builds a new custom class for holding the results of aggregation for a group.
130+ @ transient
134131 val newAggregationBuffer =
135132 newProjection(computeFunctions.flatMap(_.initialValues), child.output)
136133
137134 // A projection that is used to update the aggregate values for a group given a new tuple.
138135 // This projection should be targeted at the current values for the group and then applied
139136 // to a joined row of the current values with the new input row.
137+ @ transient
140138 val updateProjection =
141139 newMutableProjection(
142140 computeFunctions.flatMap(_.update),
143141 computeFunctions.flatMap(_.schema) ++ child.output)()
144142
145143 // A projection that computes the group given an input tuple.
144+ @ transient
146145 val groupProjection = newProjection(groupingExpressions, child.output)
147146
148147 // A projection that produces the final result, given a computation.
148+ @ transient
149149 val resultProjectionBuilder =
150150 newMutableProjection(
151151 resultExpressions,
152152 (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
153153
154- val buffers = new java.util.HashMap [Row , MutableRow ]()
155154 val joinedRow = new JoinedRow
156155
157- var currentRow : Row = null
158- while (iter.hasNext) {
159- currentRow = iter.next()
160- val currentGroup = groupProjection(currentRow)
161- var currentBuffer = buffers.get(currentGroup)
162- if (currentBuffer == null ) {
163- currentBuffer = newAggregationBuffer(EmptyRow ).asInstanceOf [MutableRow ]
164- buffers.put(currentGroup, currentBuffer)
156+ if (groupingExpressions.isEmpty) {
157+ // TODO: Code-generating anything other than the updateProjection is probably overkill.
158+ val buffer = newAggregationBuffer(EmptyRow ).asInstanceOf [MutableRow ]
159+ var currentRow : Row = null
160+ while (iter.hasNext) {
161+ currentRow = iter.next()
162+ updateProjection.target(buffer)(joinedRow(buffer, currentRow))
163+ }
164+
165+ val resultProjection = resultProjectionBuilder()
166+ Iterator (resultProjection(buffer))
167+ } else {
168+ val buffers = new java.util.HashMap [Row , MutableRow ]()
169+
170+ var currentRow : Row = null
171+ while (iter.hasNext) {
172+ currentRow = iter.next()
173+ val currentGroup = groupProjection(currentRow)
174+ var currentBuffer = buffers.get(currentGroup)
175+ if (currentBuffer == null ) {
176+ currentBuffer = newAggregationBuffer(EmptyRow ).asInstanceOf [MutableRow ]
177+ buffers.put(currentGroup, currentBuffer)
178+ }
179+ // Target the projection at the current aggregation buffer and then project the updated
180+ // values.
181+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
165182 }
166- updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
167- }
168183
169- new Iterator [Row ] {
170- private [this ] val resultIterator = buffers.entrySet.iterator()
171- private [this ] val resultProjection = resultProjectionBuilder()
184+ new Iterator [Row ] {
185+ private [this ] val resultIterator = buffers.entrySet.iterator()
186+ private [this ] val resultProjection = resultProjectionBuilder()
172187
173- def hasNext = resultIterator.hasNext
188+ def hasNext = resultIterator.hasNext
174189
175- def next () = {
176- val currentGroup = resultIterator.next()
177- resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
190+ def next () = {
191+ val currentGroup = resultIterator.next()
192+ resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
193+ }
178194 }
179195 }
180196 }
0 commit comments