@@ -21,16 +21,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

-case class Count(child: Expression) extends DeclarativeAggregate {
-  override def children: Seq[Expression] = child :: Nil
+case class Count(children: Seq[Expression]) extends DeclarativeAggregate {

  override def nullable: Boolean = false

  // Return data type.
  override def dataType: DataType = LongType

  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)

  private lazy val count = AttributeReference("count", LongType)()

@@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
  )

  override lazy val updateExpressions = Seq(
-    /* count = */ If(IsNull(child), count, count + 1L)
+    /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
  )

  override lazy val mergeExpressions = Seq(
@@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
}

object Count {
-  def apply(children: Seq[Expression]): Count = {
-    // This is used to deal with COUNT DISTINCT. When we have multiple
-    // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
-    // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
-    // null in the arguments, we will not count that row. So, we use DropAnyNull at here
-    // to return a null when any field of the created STRUCT is null.
-    val child = if (children.size > 1) {
-      DropAnyNull(CreateStruct(children))
-    } else {
-      children.head
-    }
-    Count(child)
-  }
+  def apply(child: Expression): Count = Count(child :: Nil)
}
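With `children: Seq[Expression]`, the update expression `children.map(IsNull).reduce(Or)` increments the buffer only when none of the arguments is null, matching SQL's COUNT(col1, col2, ...) semantics, and the companion `apply` keeps single-argument call sites source-compatible. A minimal sketch of what this unfolds to for two hypothetical columns `a` and `b`:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.aggregate.Count
    import org.apache.spark.sql.types._

    val a = AttributeReference("a", StringType)()
    val b = AttributeReference("b", StringType)()
    val count = AttributeReference("count", LongType)()

    // children.map(IsNull).reduce(Or) for children = Seq(a, b) unfolds to:
    val anyNull: Expression = Or(IsNull(a), IsNull(b))
    // ...so a row is counted only when every argument is non-null:
    val update: Expression = If(anyNull, count, Add(count, Literal(1L)))

    val single = Count(a)          // via the companion apply(child: Expression)
    val multi  = Count(Seq(a, b))  // via the primary constructor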
@@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
}

-/** Operator that drops a row when it contains any nulls. */
-case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-  override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
-
-  protected override def nullSafeEval(input: Any): InternalRow = {
-    val row = input.asInstanceOf[InternalRow]
-    if (row.anyNull) {
-      null
-    } else {
-      row
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, eval => {
-      s"""
-        if ($eval.anyNull()) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = $eval;
-        }
-      """
-    })
-  }
-}
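For context, `DropAnyNull` existed only to express the multi-column distinct count: the arguments were packed into a struct, and the struct was nulled out whenever any field was null so the row would not be counted. With `Count` accepting several children and doing the any-null check itself, the wrapper and its codegen path have no remaining callers. Roughly, for COUNT(DISTINCT a, b) (column names hypothetical):

    // Before: null filtering via a struct wrapper.
    //   Count(DropAnyNull(CreateStruct(Seq(a, b))))
    // After: the aggregate takes its arguments directly.
    //   Count(Seq(a, b))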
@@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
+  def nonNullLiteral(e: Expression): Boolean = e match {
+    case Literal(null, _) => false
+    case _ => true
+  }

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
-      case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+      case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+      case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
        // This rule should only be triggered when the isDistinct field is false.
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)

// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
-        val newChildren = children.filter {
-          case Literal(null, _) => false
-          case _ => true
-        }
+        val newChildren = children.filter(nonNullLiteral)
if (newChildren.length == 0) {
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {
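Factoring out `nonNullLiteral` lets the Count and Coalesce cases share one predicate, and both Count rules generalize from a single child to an argument list. Illustrative inputs for the two rewrites (column names hypothetical):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    // Rule 1: every argument is a null literal, so the count is always 0
    // and the whole aggregate folds to Cast(Literal(0L), LongType).
    val allNull = Seq(Literal.create(null, StringType), Literal.create(null, IntegerType))

    // Rule 2: no argument can ever be null, so the per-row null check is
    // unnecessary and COUNT(a, b) is rewritten to COUNT(1).
    val a = AttributeReference("a", IntegerType, nullable = false)()
    val b = AttributeReference("b", LongType, nullable = false)()
    assert(!Seq(a, b).exists(_.nullable))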
@@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}

test("function dropAnyNull") {
val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
val a = create_row("a", "q")
val nullStr: String = null
checkEvaluation(drop, a, a)
checkEvaluation(drop, null, create_row("b", nullStr))
checkEvaluation(drop, null, create_row(nullStr, nullStr))

val row = 'r.struct(
StructField("a", StringType, false),
StructField("b", StringType, true)).at(0)
checkEvaluation(DropAnyNull(row), null, create_row(null))
}
}
@@ -146,20 +146,16 @@ object Utils {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))

// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
-  // DISTINCT aggregate function, all of those functions will have the same column expression.
+  // DISTINCT aggregate function, all of those functions will have the same column expressions.
// For example, it would be valid for functionsWithDistinct to be
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
// disallowed because those two distinct aggregates have different column expressions.
-    val distinctColumnExpression: Expression = {
-      val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
-      assert(allDistinctColumnExpressions.length == 1)
-      allDistinctColumnExpressions.head
-    }
-    val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+    val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
      case ne: NamedExpression => ne
      case other => Alias(other, other.toString)()
    }
-    val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)

// 1. Create an Aggregate Operator for partial aggregations.
@@ -170,10 +166,11 @@
// We will group by the original grouping expression, plus an additional expression for the
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
// expressions will be [key, value].
-    val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+    val partialAggregateGroupingExpressions =
+      groupingExpressions ++ namedDistinctColumnExpressions
val partialAggregateResult =
groupingAttributes ++
-        Seq(distinctColumnAttribute) ++
+        distinctColumnAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
@@ -208,28 +205,28 @@ object Utils {
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialMergeAggregateResult =
groupingAttributes ++
-        Seq(distinctColumnAttribute) ++
+        distinctColumnAttributes ++
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+        groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
-        initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
} else {
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+        groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
-        initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
}
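Taken together, the planner now spreads the distinct columns through the grouping key instead of threading a single packed expression. For a hypothetical query SELECT key, COUNT(DISTINCT a, b) FROM t GROUP BY key, the three operators have roughly this shape:

    // partial aggregate:       group by [key, a, b]  -- deduplicates (a, b) per key
    // partial-merge aggregate: group by [key, a, b]  -- merges partial buffers
    // final aggregate:         group by [key], evaluating the count over the
    //                          now-distinct (a, b) attributes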
@@ -244,14 +241,16 @@ object Utils {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}

+    val distinctColumnAttributeLookup =
+      distinctColumnExpressions.zip(distinctColumnAttributes).toMap
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
      // The children of an AggregateFunction with the DISTINCT keyword have
      // already been evaluated. Here, we need to replace the original children
      // with AttributeReferences.
case agg @ AggregateExpression(aggregateFunction, mode, true) =>
-        val rewrittenAggregateFunction = aggregateFunction.transformDown {
-          case expr if expr == distinctColumnExpression => distinctColumnAttribute
-        }.asInstanceOf[AggregateFunction]
+        val rewrittenAggregateFunction = aggregateFunction
+          .transformDown(distinctColumnAttributeLookup)
+          .asInstanceOf[AggregateFunction]
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
// We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -270,7 +269,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
-        initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
} else {
@@ -281,7 +280,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
-        initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+        initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
}
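Replacing the single-expression equality check with a lookup table handles any number of distinct columns in one traversal; since a Scala `Map` is a `PartialFunction`, it can be passed to `transformDown` directly. A sketch for COUNT(DISTINCT a, b), where primed names stand for the attributes produced upstream:

    // distinctColumnExpressions = Seq(a, b)
    // distinctColumnAttributes  = Seq(a', b')
    // lookup = Map(a -> a', b -> b')
    // aggregateFunction.transformDown(lookup) turns Count(Seq(a, b)) into
    // Count(Seq(a', b')), whose inputs are already distinct.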
@@ -152,8 +152,8 @@ class WindowSpec private[sql](
case Sum(child) => WindowExpression(
UnresolvedWindowFunction("sum", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
-    case Count(child) => WindowExpression(
-      UnresolvedWindowFunction("count", child :: Nil),
+    case Count(children) => WindowExpression(
+      UnresolvedWindowFunction("count", children),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case First(child, ignoreNulls) => WindowExpression(
// TODO this is a hack for Hive UDAF first_value
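The window path changes mechanically: whatever argument list `Count` carries is forwarded to the unresolved `count` window function. Sketch:

    // Count(Seq(a, b)) under a window spec w becomes:
    //   WindowExpression(UnresolvedWindowFunction("count", Seq(a, b)), w)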