@@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.expressions._
2626import org .apache .spark .sql .catalyst .plans .logical .{Rollup , Cube , Aggregate }
2727import org .apache .spark .sql .types .NumericType
2828
29+ /**
30+ * The Grouping Type
31+ */
32+ sealed private [sql] trait GroupType
33+
34+ /**
35+ * To indicate it's the GroupBy
36+ */
37+ private [sql] object GroupByType extends GroupType
38+
39+ /**
40+ * To indicate it's the CUBE
41+ */
42+ private [sql] object CubeType extends GroupType
43+
44+ /**
45+ * To indicate it's the ROLLUP
46+ */
47+ private [sql] object RollupType extends GroupType
2948
3049/**
3150 * :: Experimental ::
@@ -34,10 +53,13 @@ import org.apache.spark.sql.types.NumericType
3453 * @since 1.3.0
3554 */
3655@ Experimental
37- class GroupedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ]) {
56+ class GroupedData protected [sql](
57+ df : DataFrame ,
58+ groupingExprs : Seq [Expression ],
59+ groupType : GroupType ) {
3860
3961 protected def aggregateExpressions (aggrExprs : Seq [NamedExpression ])
40- : Seq [NamedExpression ] = {
62+ : Seq [NamedExpression ] = {
4163 if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
4264 val retainedExprs = groupingExprs.map {
4365 case expr : NamedExpression => expr
@@ -50,8 +72,17 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
5072 }
5173
5274 protected [sql] implicit def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
53- DataFrame (
54- df.sqlContext, Aggregate (groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan))
75+ groupType match {
76+ case GroupByType =>
77+ DataFrame (
78+ df.sqlContext, Aggregate (groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan))
79+ case RollupType =>
80+ DataFrame (
81+ df.sqlContext, Rollup (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
82+ case CubeType =>
83+ DataFrame (
84+ df.sqlContext, Cube (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
85+ }
5586 }
5687
5788 private [this ] def aggregateNumericColumns (colNames : String * )(f : Expression => Expression )
@@ -259,27 +290,3 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
259290 }
260291
261292}
262-
263- /**
264- * A set of methods for aggregations on a [[DataFrame ]] cube, created by [[DataFrame.cube ]].
265- */
266- private [sql] class CubedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ])
267- extends GroupedData (df, groupingExprs) {
268-
269- protected [sql] implicit override def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
270- DataFrame (
271- df.sqlContext, Cube (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
272- }
273- }
274-
275- /**
276- * A set of methods for aggregations on a [[DataFrame ]] rollup, created by [[DataFrame.rollup ]].
277- */
278- private [sql] class RollupedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ])
279- extends GroupedData (df, groupingExprs) {
280-
281- protected [sql] implicit override def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
282- DataFrame (
283- df.sqlContext, Rollup (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
284- }
285- }
0 commit comments