@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
2020import org .apache .spark .sql .catalyst .expressions ._
2121import org .apache .spark .sql .catalyst .plans ._
2222import org .apache .spark .sql .types ._
23+ import org .apache .spark .util .collection .OpenHashSet
2324
2425case class Project (projectList : Seq [NamedExpression ], child : LogicalPlan ) extends UnaryNode {
2526 override def output : Seq [Attribute ] = projectList.map(_.toAttribute)
@@ -228,24 +229,76 @@ case class Window(
228229/**
229230 * Apply the all of the GroupExpressions to every input row, hence we will get
230231 * multiple output rows for a input row.
231- * @param projections The group of expressions, all of the group expressions should
232- * output the same schema specified by the parameter `output`
233- * @param output The output Schema
232+ * @param bitmasks The bitmask set represents the grouping sets
233+ * @param groupByExprs The grouping by expressions
234234 * @param child Child operator
235235 */
236236case class Expand (
237- projections : Seq [Seq [Expression ]],
238- output : Seq [Attribute ],
237+ bitmasks : Seq [Int ],
238+ groupByExprs : Seq [Expression ],
239+ gid : Attribute ,
239240 child : LogicalPlan ) extends UnaryNode {
240241 override def statistics : Statistics = {
241242 val sizeInBytes = child.statistics.sizeInBytes * projections.length
242243 Statistics (sizeInBytes = sizeInBytes)
243244 }
245+
246+ val projections : Seq [Seq [Expression ]] = expand()
247+
248+ /**
249+ * Extract attribute set according to the grouping id
250+ * @param bitmask bitmask to represent the selected of the attribute sequence
251+ * @param exprs the attributes in sequence
252+ * @return the attributes of non selected specified via bitmask (with the bit set to 1)
253+ */
254+ private def buildNonSelectExprSet (bitmask : Int , exprs : Seq [Expression ])
255+ : OpenHashSet [Expression ] = {
256+ val set = new OpenHashSet [Expression ](2 )
257+
258+ var bit = exprs.length - 1
259+ while (bit >= 0 ) {
260+ if (((bitmask >> bit) & 1 ) == 0 ) set.add(exprs(bit))
261+ bit -= 1
262+ }
263+
264+ set
265+ }
266+
267+ /**
268+ * Create an array of Projections for the child projection, and replace the projections'
269+ * expressions which equal GroupBy expressions with Literal(null), if those expressions
270+ * are not set for this grouping set (according to the bit mask).
271+ */
272+ private [this ] def expand (): Seq [Seq [Expression ]] = {
273+ val result = new scala.collection.mutable.ArrayBuffer [Seq [Expression ]]
274+
275+ bitmasks.foreach { bitmask =>
276+ // get the non selected grouping attributes according to the bit mask
277+ val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
278+
279+ val substitution = (child.output :+ gid).map(expr => expr transformDown {
280+ case x : Expression if nonSelectedGroupExprSet.contains(x) =>
281+ // if the input attribute in the Invalid Grouping Expression set of for this group
282+ // replace it with constant null
283+ Literal .create(null , expr.dataType)
284+ case x if x == gid =>
285+ // replace the groupingId with concrete value (the bit mask)
286+ Literal .create(bitmask, IntegerType )
287+ })
288+
289+ result += substitution
290+ }
291+
292+ result.toSeq
293+ }
294+
295+ override def output : Seq [Attribute ] = {
296+ child.output :+ gid
297+ }
244298}
245299
246300trait GroupingAnalytics extends UnaryNode {
247301 self : Product =>
248- def gid : AttributeReference
249302 def groupByExprs : Seq [Expression ]
250303 def aggregations : Seq [NamedExpression ]
251304
@@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
266319 * @param child Child operator
267320 * @param aggregations The Aggregation expressions, those non selected group by expressions
268321 * will be considered as constant null if it appears in the expressions
269- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
270- * the bitmask indicates the selected GroupBy Expressions for each
271- * aggregating output row.
272- * The associated output will be one of the value in `bitmasks`
273322 */
274323case class GroupingSets (
275324 bitmasks : Seq [Int ],
276325 groupByExprs : Seq [Expression ],
277326 child : LogicalPlan ,
278- aggregations : Seq [NamedExpression ],
279- gid : AttributeReference = VirtualColumn .newGroupingId) extends GroupingAnalytics {
327+ aggregations : Seq [NamedExpression ]) extends GroupingAnalytics {
280328
281329 def withNewAggs (aggs : Seq [NamedExpression ]): GroupingAnalytics =
282330 this .copy(aggregations = aggs)
@@ -290,15 +338,11 @@ case class GroupingSets(
290338 * @param child Child operator
291339 * @param aggregations The Aggregation expressions, those non selected group by expressions
292340 * will be considered as constant null if it appears in the expressions
293- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
294- * the bitmask indicates the selected GroupBy Expressions for each
295- * aggregating output row.
296341 */
297342case class Cube (
298343 groupByExprs : Seq [Expression ],
299344 child : LogicalPlan ,
300- aggregations : Seq [NamedExpression ],
301- gid : AttributeReference = VirtualColumn .newGroupingId) extends GroupingAnalytics {
345+ aggregations : Seq [NamedExpression ]) extends GroupingAnalytics {
302346
303347 def withNewAggs (aggs : Seq [NamedExpression ]): GroupingAnalytics =
304348 this .copy(aggregations = aggs)
@@ -313,15 +357,11 @@ case class Cube(
313357 * @param child Child operator
314358 * @param aggregations The Aggregation expressions, those non selected group by expressions
315359 * will be considered as constant null if it appears in the expressions
316- * @param gid The attribute represents the virtual column GROUPING__ID, and it's also
317- * the bitmask indicates the selected GroupBy Expressions for each
318- * aggregating output row.
319360 */
320361case class Rollup (
321362 groupByExprs : Seq [Expression ],
322363 child : LogicalPlan ,
323- aggregations : Seq [NamedExpression ],
324- gid : AttributeReference = VirtualColumn .newGroupingId) extends GroupingAnalytics {
364+ aggregations : Seq [NamedExpression ]) extends GroupingAnalytics {
325365
326366 def withNewAggs (aggs : Seq [NamedExpression ]): GroupingAnalytics =
327367 this .copy(aggregations = aggs)
0 commit comments