Skip to content

Commit 3d28081

Browse files
hvanhovellyhuai
authored andcommitted
[SPARK-12024][SQL] More efficient multi-column counting.
In #9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell <[email protected]> Closes #10015 from hvanhovell/SPARK-12024.
1 parent cc7a1bc commit 3d28081

File tree

6 files changed

+33
-86
lines changed

6 files changed

+33
-86
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.types._
2323

24-
case class Count(child: Expression) extends DeclarativeAggregate {
25-
override def children: Seq[Expression] = child :: Nil
24+
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
2625

2726
override def nullable: Boolean = false
2827

2928
// Return data type.
3029
override def dataType: DataType = LongType
3130

3231
// Expected input data type.
33-
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
32+
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
3433

3534
private lazy val count = AttributeReference("count", LongType)()
3635

@@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
4140
)
4241

4342
override lazy val updateExpressions = Seq(
44-
/* count = */ If(IsNull(child), count, count + 1L)
43+
/* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
4544
)
4645

4746
override lazy val mergeExpressions = Seq(
@@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
5453
}
5554

5655
object Count {
57-
def apply(children: Seq[Expression]): Count = {
58-
// This is used to deal with COUNT DISTINCT. When we have multiple
59-
// children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
60-
// Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
61-
// null in the arguments, we will not count that row. So, we use DropAnyNull at here
62-
// to return a null when any field of the created STRUCT is null.
63-
val child = if (children.size > 1) {
64-
DropAnyNull(CreateStruct(children))
65-
} else {
66-
children.head
67-
}
68-
Count(child)
69-
}
56+
def apply(child: Expression): Count = Count(child :: Nil)
7057
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
426426
}
427427
}
428428

429-
/** Operator that drops a row when it contains any nulls. */
430-
case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
431-
override def nullable: Boolean = true
432-
override def dataType: DataType = child.dataType
433-
override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
434-
435-
protected override def nullSafeEval(input: Any): InternalRow = {
436-
val row = input.asInstanceOf[InternalRow]
437-
if (row.anyNull) {
438-
null
439-
} else {
440-
row
441-
}
442-
}
443-
444-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
445-
nullSafeCodeGen(ctx, ev, eval => {
446-
s"""
447-
if ($eval.anyNull()) {
448-
${ev.isNull} = true;
449-
} else {
450-
${ev.value} = $eval;
451-
}
452-
"""
453-
})
454-
}
455-
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
362362
* Null value propagation from bottom to top of the expression tree.
363363
*/
364364
object NullPropagation extends Rule[LogicalPlan] {
365+
def nonNullLiteral(e: Expression): Boolean = e match {
366+
case Literal(null, _) => false
367+
case _ => true
368+
}
369+
365370
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
366371
case q: LogicalPlan => q transformExpressionsUp {
367-
case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
372+
case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
368373
Cast(Literal(0L), e.dataType)
369374
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
370375
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
377382
Literal.create(null, e.dataType)
378383
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
379384
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
380-
case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
385+
case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
381386
// This rule should be only triggered when isDistinct field is false.
382387
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
383388

384389
// For Coalesce, remove null literals.
385390
case e @ Coalesce(children) =>
386-
val newChildren = children.filter {
387-
case Literal(null, _) => false
388-
case _ => true
389-
}
391+
val newChildren = children.filter(nonNullLiteral)
390392
if (newChildren.length == 0) {
391393
Literal.create(null, e.dataType)
392394
} else if (newChildren.length == 1) {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
231231
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
232232
}
233233
}
234-
235-
test("function dropAnyNull") {
236-
val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
237-
val a = create_row("a", "q")
238-
val nullStr: String = null
239-
checkEvaluation(drop, a, a)
240-
checkEvaluation(drop, null, create_row("b", nullStr))
241-
checkEvaluation(drop, null, create_row(nullStr, nullStr))
242-
243-
val row = 'r.struct(
244-
StructField("a", StringType, false),
245-
StructField("b", StringType, true)).at(0)
246-
checkEvaluation(DropAnyNull(row), null, create_row(null))
247-
}
248234
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,16 @@ object Utils {
146146
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
147147

148148
// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
149-
// DISTINCT aggregate function, all of those functions will have the same column expression.
149+
// DISTINCT aggregate function, all of those functions will have the same column expressions.
150150
// For example, it would be valid for functionsWithDistinct to be
151151
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
152152
// disallowed because those two distinct aggregates have different column expressions.
153-
val distinctColumnExpression: Expression = {
154-
val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
155-
assert(allDistinctColumnExpressions.length == 1)
156-
allDistinctColumnExpressions.head
157-
}
158-
val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
153+
val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
154+
val namedDistinctColumnExpressions = distinctColumnExpressions.map {
159155
case ne: NamedExpression => ne
160156
case other => Alias(other, other.toString)()
161157
}
162-
val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
158+
val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
163159
val groupingAttributes = groupingExpressions.map(_.toAttribute)
164160

165161
// 1. Create an Aggregate Operator for partial aggregations.
@@ -170,10 +166,11 @@ object Utils {
170166
// We will group by the original grouping expression, plus an additional expression for the
171167
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
172168
// expressions will be [key, value].
173-
val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
169+
val partialAggregateGroupingExpressions =
170+
groupingExpressions ++ namedDistinctColumnExpressions
174171
val partialAggregateResult =
175172
groupingAttributes ++
176-
Seq(distinctColumnAttribute) ++
173+
distinctColumnAttributes ++
177174
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
178175
if (usesTungstenAggregate) {
179176
TungstenAggregate(
@@ -208,28 +205,28 @@ object Utils {
208205
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
209206
val partialMergeAggregateResult =
210207
groupingAttributes ++
211-
Seq(distinctColumnAttribute) ++
208+
distinctColumnAttributes ++
212209
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
213210
if (usesTungstenAggregate) {
214211
TungstenAggregate(
215212
requiredChildDistributionExpressions = Some(groupingAttributes),
216-
groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
213+
groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
217214
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
218215
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
219216
completeAggregateExpressions = Nil,
220217
completeAggregateAttributes = Nil,
221-
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
218+
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
222219
resultExpressions = partialMergeAggregateResult,
223220
child = partialAggregate)
224221
} else {
225222
SortBasedAggregate(
226223
requiredChildDistributionExpressions = Some(groupingAttributes),
227-
groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
224+
groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
228225
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
229226
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
230227
completeAggregateExpressions = Nil,
231228
completeAggregateAttributes = Nil,
232-
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
229+
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
233230
resultExpressions = partialMergeAggregateResult,
234231
child = partialAggregate)
235232
}
@@ -244,14 +241,16 @@ object Utils {
244241
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
245242
}
246243

244+
val distinctColumnAttributeLookup =
245+
distinctColumnExpressions.zip(distinctColumnAttributes).toMap
247246
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
248247
// Children of an AggregateFunction with DISTINCT keyword has already
249248
// been evaluated. At here, we need to replace original children
250249
// to AttributeReferences.
251250
case agg @ AggregateExpression(aggregateFunction, mode, true) =>
252-
val rewrittenAggregateFunction = aggregateFunction.transformDown {
253-
case expr if expr == distinctColumnExpression => distinctColumnAttribute
254-
}.asInstanceOf[AggregateFunction]
251+
val rewrittenAggregateFunction = aggregateFunction
252+
.transformDown(distinctColumnAttributeLookup)
253+
.asInstanceOf[AggregateFunction]
255254
// We rewrite the aggregate function to a non-distinct aggregation because
256255
// its input will have distinct arguments.
257256
// We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -270,7 +269,7 @@ object Utils {
270269
nonCompleteAggregateAttributes = finalAggregateAttributes,
271270
completeAggregateExpressions = completeAggregateExpressions,
272271
completeAggregateAttributes = completeAggregateAttributes,
273-
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
272+
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
274273
resultExpressions = resultExpressions,
275274
child = partialMergeAggregate)
276275
} else {
@@ -281,7 +280,7 @@ object Utils {
281280
nonCompleteAggregateAttributes = finalAggregateAttributes,
282281
completeAggregateExpressions = completeAggregateExpressions,
283282
completeAggregateAttributes = completeAggregateAttributes,
284-
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
283+
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
285284
resultExpressions = resultExpressions,
286285
child = partialMergeAggregate)
287286
}

sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ class WindowSpec private[sql](
152152
case Sum(child) => WindowExpression(
153153
UnresolvedWindowFunction("sum", child :: Nil),
154154
WindowSpecDefinition(partitionSpec, orderSpec, frame))
155-
case Count(child) => WindowExpression(
156-
UnresolvedWindowFunction("count", child :: Nil),
155+
case Count(children) => WindowExpression(
156+
UnresolvedWindowFunction("count", children),
157157
WindowSpecDefinition(partitionSpec, orderSpec, frame))
158158
case First(child, ignoreNulls) => WindowExpression(
159159
// TODO this is a hack for Hive UDAF first_value

0 commit comments

Comments
 (0)