Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate {
)

override val evaluateExpression = Cast(count, LongType)

/** Result reported when the input is empty: the long literal zero. */
override def defaultResult: Option[Literal] = Some(Literal(0L))
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
Expand All @@ -41,7 +42,7 @@ object Utils {

private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
case p: Aggregate if supportsGroupingKeySchema(p) =>
val converted = p.transformExpressionsDown {
val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
case expressions.Average(child) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Average(child),
Expand Down Expand Up @@ -144,7 +145,8 @@ object Utils {
aggregateFunction = aggregate.VarianceSamp(child),
mode = aggregate.Complete,
isDistinct = false)
}
})

// Check if there is any expressions.AggregateExpression1 left.
// If so, we cannot convert this plan.
val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
Expand All @@ -156,6 +158,7 @@ object Utils {
}

// Check if there are multiple distinct columns.
// TODO remove this.
val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression2 => agg
Expand Down Expand Up @@ -213,3 +216,178 @@ object Utils {
case other => None
}
}

/**
* This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
* aggregation in which the regular aggregation expressions and every distinct clause are each
* aggregated in a separate group. The results are then combined in a second aggregate.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be really helpful if there was an example of what this rewrite looks like here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add an example in the follow-up PR.

*
* TODO Expression canonicalization
* TODO Eliminate foldable expressions from distinct clauses.
* TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
* operator. Perhaps this is a good thing? It is much simpler to plan later on...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just making sure I understand, these are all optimizations not correctness concerns?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are all optimizations. The last one is perhaps a question for @yhuai: we have a choice to rewrite all distinct expressions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we can use this path to handle all cases. If I understand correctly, this rewriting approach will first create two logical Aggregate operators and then we shuffle data twice. Our current planning rule for a single distinct agg will shuffle data once, which can be bad if we do not have group by clause (because we will have a single reducer). To make the ideal decision, we need to know the statistics of grouping columns and distinct column. However, for the cases that we have a single distinct column and we do not have a group by clause, I feel your rewriting approach should be strictly better. What do you think?

*/
object MultipleDistinctRewriter extends Rule[LogicalPlan] {

/**
 * Rule entry point: hands each [[Aggregate]] operator visited by `resolveOperators` to
 * `rewrite` and returns every other operator unchanged.
 */
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
  case agg: Aggregate => rewrite(agg)
  case other => other
}

def rewrite(a: Aggregate): Aggregate = {

// Collect all aggregate expressions.
val aggExpressions = a.aggregateExpressions.flatMap { e =>
e.collect {
case ae: AggregateExpression2 => ae
}
}

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)

// Only continue to rewrite if there is more than one distinct group.
if (distinctAggGroups.size > 1) {
// Create the attributes for the grouping id and the group by clause.
val gid = new AttributeReference("gid", IntegerType, false)()
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)

// Functions used to modify aggregate functions and their inputs.
// Guards an expression by group: yields `e` when the grouping id attribute equals `id`, and a
// null literal of e's data type otherwise.
def evalWithinGroup(id: Literal, e: Expression) = {
  val inGroup = EqualTo(gid, id)
  If(inGroup, e, nullify(e))
}
// Rebuilds an aggregate function so that each child is replaced by its counterpart from `attrs`,
// wrapped by evalWithinGroup for group `id`. Every child of `af` must have an entry in `attrs`;
// a missing entry makes the map lookup fail.
def patchAggregateFunctionChildren(
    af: AggregateFunction2,
    id: Literal,
    attrs: Map[Expression, Expression]): AggregateFunction2 = {
  val guardedChildren = af.children.map(child => evalWithinGroup(id, attrs(child)))
  af.withNewChildren(guardedChildren).asInstanceOf[AggregateFunction2]
}

// Setup unique distinct aggregate children.
// The attribute sequence is derived by mapping over distinctAggChildren instead of taking the
// map's values: an immutable Map with more than four entries does not iterate in insertion
// order, and these attributes must line up positionally with the expand projections, which are
// built from distinctAggChildren.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
val distinctAggChildAttrs = distinctAggChildren.map(distinctAggChildAttrMap)

// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)

// Expand projection
val projection = distinctAggChildren.map {
case e if group.contains(e) => e
case e => nullify(e)
} :+ id

// Final aggregate
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}

(projection, operators)
}

// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap

// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
val regularAggOperatorMap = regularAggExprs.map { e =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on what each tuple element is, or maybe even use a case class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add documentation in a follow-up PR.

// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(
e.aggregateFunction,
regularGroupId,
regularAggChildAttrMap)
val a = Alias(e.copy(aggregateFunction = af), e.toString)()

// Get the result of the first aggregate in the last aggregate.
val b = AggregateExpression2(
aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
mode = Complete,
isDistinct = false)

// Some aggregate functions (COUNT) have the special property that they can return a
// non-null result without any input. We need to make sure we return a result in this case.
val c = af.defaultResult match {
case Some(lit) => Coalesce(Seq(b, lit))
case None => b
}

(e, a, c)
}

// Construct the regular aggregate input projection only if we need one.
val regularAggProjection = if (regularAggExprs.nonEmpty) {
Seq(a.groupingExpressions ++
distinctAggChildren.map(nullify) ++
Seq(regularGroupId) ++
regularAggChildren)
} else {
Seq.empty[Seq[Expression]]
}

// Construct the distinct aggregate input projections.
val regularAggNulls = regularAggChildren.map(nullify)
val distinctAggProjections = distinctAggOperatorMap.map {
case (projection, _) =>
a.groupingExpressions ++
projection ++
regularAggNulls
}

// Construct the expand operator.
// The output attributes of the regular aggregate children are looked up in regularAggChildren
// order rather than taken from the map's values: the regular projection lists its inputs in
// regularAggChildren order, and a Map with more than four entries does not iterate in
// insertion order, so using `values` could misalign output attributes and projected columns.
val expand = Expand(
  regularAggProjection ++ distinctAggProjections,
  groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++
    regularAggChildren.map(regularAggChildAttrMap),
  a.child)

// Construct the first aggregate operator. This de-duplicates all the children of
// distinct operators, and applies the regular aggregate operators.
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
val firstAggregate = Aggregate(
firstAggregateGroupBy,
firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
expand)

// Construct the second aggregate
val transformations: Map[Expression, Expression] =
(distinctAggOperatorMap.flatMap(_._2) ++
regularAggOperatorMap.map(e => (e._1, e._3))).toMap

val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
// The same GROUP BY clauses can have different forms (different names for instance) in
// the groupBy and aggregate expressions of an aggregate. This makes a map lookup
// tricky. So we do a linear search for a semantically equal group by expression.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've talked about adding an ExpressionMap similar to AttributeMap in the past. Seems like that would be useful here.

groupByMap
.find(ge => e.semanticEquals(ge._1))
.map(_._2)
.getOrElse(transformations.getOrElse(e, e))
}.asInstanceOf[NamedExpression]
}
Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
} else {
a
}
}

/** Creates a null literal carrying the data type of the given expression. */
private def nullify(e: Expression) = {
  val targetType = e.dataType
  Literal.create(null, targetType)
}

/**
 * Pairs an expression with a newly created, nullable attribute named after the expression's
 * `prettyName` and carrying its data type.
 *
 * A new reference is created even when the expression is a NamedExpression, instead of reusing
 * its attribute. This prevents collisions between distinct and regular aggregate children: with
 * attribute reuse, the input of the regular aggregate would be bound to the (nulled out) input
 * of the distinct aggregate.
 */
private def expressionAttributePair(e: Expression) = {
  val freshAttr = new AttributeReference(e.prettyName, e.dataType, nullable = true)()
  e -> freshAttr
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
*/
def supportsPartial: Boolean = true

/**
 * The value this aggregate function produces when the input is empty, if it has one. The
 * default is `None`. This is currently only used for the proper rewriting of distinct
 * aggregate functions.
 */
def defaultResult: Option[Literal] = Option.empty[Literal]

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
case a @ Aggregate(_, _, e @ Expand(_, _, child))
if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))

// Eliminate attributes that are not needed to calculate the specified aggregates.
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,33 +235,17 @@ case class Window(
projectList ++ windowExpressions.map(_.toAttribute)
}

/**
* Apply the all of the GroupExpressions to every input row, hence we will get
* multiple output rows for a input row.
* @param bitmasks The bitmask set represents the grouping sets
* @param groupByExprs The grouping by expressions
* @param child Child operator
*/
case class Expand(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
gid: Attribute,
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
val sizeInBytes = child.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}

val projections: Seq[Seq[Expression]] = expand()

private[sql] object Expand {
/**
* Extract attribute set according to the grouping id
* Extract attribute set according to the grouping id.
*
* @param bitmask bitmask to represent the selected of the attribute sequence
* @param exprs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
: OpenHashSet[Expression] = {
private def buildNonSelectExprSet(
bitmask: Int,
exprs: Seq[Expression]): OpenHashSet[Expression] = {
val set = new OpenHashSet[Expression](2)

var bit = exprs.length - 1
Expand All @@ -274,18 +258,28 @@ case class Expand(
}

/**
* Create an array of Projections for the child projection, and replace the projections'
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
* Apply all of the GroupExpressions to every input row, hence we will get
* multiple output rows for an input row.
*
* @param bitmasks The bitmask set represents the grouping sets
* @param groupByExprs The grouping by expressions
* @param gid Attribute of the grouping id
* @param child Child operator
*/
private[this] def expand(): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

bitmasks.foreach { bitmask =>
def apply(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
gid: Attribute,
child: LogicalPlan): Expand = {
// Create an array of Projections for the child projection, and replace the projections'
// expressions which equal GroupBy expressions with Literal(null), if those expressions
// are not set for this grouping set (according to the bit mask).
val projections = bitmasks.map { bitmask =>
// get the non selected grouping attributes according to the bit mask
val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)

val substitution = (child.output :+ gid).map(expr => expr transformDown {
(child.output :+ gid).map(expr => expr transformDown {
// TODO this causes a problem when a column is used both for grouping and aggregation.
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute is in the invalid grouping expression set for this group,
// replace it with constant null
Expand All @@ -294,15 +288,29 @@ case class Expand(
// replace the groupingId with concrete value (the bit mask)
Literal.create(bitmask, IntegerType)
})

result += substitution
}

result.toSeq
Expand(projections, child.output :+ gid, child)
}
}

override def output: Seq[Attribute] = {
child.output :+ gid
/**
 * Logical operator that applies a number of projections to every input row, so that a single
 * input row yields multiple output rows.
 *
 * @param projections the projections to apply
 * @param output the output attributes of all projections
 * @param child the child operator
 */
case class Expand(
    projections: Seq[Seq[Expression]],
    output: Seq[Attribute],
    child: LogicalPlan) extends UnaryNode {

  /** Size estimate: the child's size in bytes multiplied by the number of projections. */
  override def statistics: Statistics = {
    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
    // row?
    val expandedSize = child.statistics.sizeInBytes * projections.length
    Statistics(sizeInBytes = expandedSize)
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
case e @ logical.Expand(_, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
Expand Down