@@ -22,26 +22,27 @@ import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
import org.apache.spark.sql.types._

/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
// Right now, we do not support complex types in the grouping key schema.
private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
case array: ArrayType => true
case map: MapType => true
case struct: StructType => true
case _ => false
}

!hasComplexTypes
// Checks whether the given DataType cannot be used in a group by clause.
private def isUnGroupable(dt: DataType): Boolean = dt match {
case _: ArrayType | _: MapType => true
case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType))
case _ => false
}

// Right now, we do not support arrays and maps (or structs containing them) in the grouping key schema.
private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean =
!aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType))

private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
case p: Aggregate if supportsGroupingKeySchema(p) =>

val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
case expressions.Average(child) =>
aggregate.AggregateExpression2(
@@ -55,10 +56,14 @@
mode = aggregate.Complete,
isDistinct = false)

// We do not support multiple COUNT DISTINCT columns for now.
case expressions.CountDistinct(children) if children.length == 1 =>
case expressions.CountDistinct(children) =>
val child = if (children.size > 1) {
DropAnyNull(CreateStruct(children))
Contributor Author:

@yhuai if we combine this with the distinct rewriting rule, it will add a struct to the groupBy clause of the first aggregate. This is currently not allowed in the new UDAF path, so it falls back to the old path. For example:

val data2 = Seq[(Integer, Integer, Integer)](
    (1, 10, -10),
    (null, -60, 60),
    (1, 30, -30),
    (1, 30, 30),
    (2, 1, 1),
    (null, -10, 10),
    (2, -1, null),
    (2, 1, 1),
    (2, null, 1),
    (null, 100, -10),
    (3, null, 3),
    (null, null, null),
    (3, null, null)).toDF("key", "value1", "value2")
data2.registerTempTable("agg2")

val q = sql(
 """
 |SELECT
 |  key,
 |  count(distinct value1),
 |  count(distinct value2),
 |  count(distinct value1, value2)
 |FROM agg2
 |GROUP BY key
 """.stripMargin)

Will create the following physical plan:

== Physical Plan ==
TungstenAggregate(key=[key#3], functions=[(count(if ((gid#44 = 1)) attributereference#45 else null),mode=Final,isDistinct=false),(count(if ((gid#44 = 3)) attributereference#47 else null),mode=Final,isDistinct=false),(count(if ((gid#44 = 2)) dropanynull#46 else null),mode=Final,isDistinct=false)], output=[key#3,_c1#32L,_c2#33L,_c3#34L])
 TungstenExchange(Shuffle without coordinator) hashpartitioning(key#3,200), None
  TungstenAggregate(key=[key#3], functions=[(count(if ((gid#44 = 1)) attributereference#45 else null),mode=Partial,isDistinct=false),(count(if ((gid#44 = 3)) attributereference#47 else null),mode=Partial,isDistinct=false),(count(if ((gid#44 = 2)) dropanynull#46 else null),mode=Partial,isDistinct=false)], output=[key#3,count#49L,count#53L,count#51L])
   Aggregate false, [key#3,attributereference#45,dropanynull#46,attributereference#47,gid#44], [key#3,attributereference#45,dropanynull#46,attributereference#47,gid#44]
    ConvertToSafe
     TungstenExchange(Shuffle without coordinator) hashpartitioning(key#3,attributereference#45,dropanynull#46,attributereference#47,gid#44,200), None
      ConvertToUnsafe
       Aggregate true, [key#3,attributereference#45,dropanynull#46,attributereference#47,gid#44], [key#3,attributereference#45,dropanynull#46,attributereference#47,gid#44]
        !Expand [List(key#3, value1#4, null, null, 1),List(key#3, null, dropanynull(struct(value1#4,value2#5)), null, 2),List(key#3, null, null, value2#5, 3)], [key#3,attributereference#45,dropanynull#46,attributereference#47,gid#44]
         LocalTableScan [key#3,value1#4,value2#5], [[1,10,-10],[null,-60,60],[1,30,-30],[1,30,30],[2,1,1],[null,-10,10],[2,-1,null],[2,1,1],[2,null,1],[null,100,-10],[3,null,3],[null,null,null],[3,null,null]]

Is it possible to add support for fixed-width structs as group by expressions in the new aggregation path?

Contributor Author:

Quick follow-up.

Allowing structs does not seem to create a problem. I disabled this line locally: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala#L36, and now it uses TungstenAggregate.

Contributor:

oh, yes. Based on https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala#L240-L266, we can compare two struct values; it looks like only array and map types are not handled there. So I think we can visit all the data types of a struct, and if it does not contain an array or a map, we can use the new agg code path. Can you update Utils.scala? I am also thinking that if an array or a map appears in the grouping expressions, we should throw an analysis error saying it is not allowed right now.
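
A minimal sketch of the recursive check being described, in plain Scala (the name containsArrayOrMap is illustrative; the real change is the isUnGroupable helper in Utils.scala above):

import org.apache.spark.sql.types._

// True when the type itself, or any field nested inside a struct, is an
// array or a map -- the two types struct comparison cannot handle.
def containsArrayOrMap(dt: DataType): Boolean = dt match {
  case _: ArrayType | _: MapType => true
  case s: StructType => s.fields.exists(f => containsArrayOrMap(f.dataType))
  case _ => false
}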

Contributor Author:

Added (proper) StructType checking.

Do you want me to also start throwing AnalysisErrors?

Contributor:

I can make the change to throw the analysis error in my PR.

} else {
children.head
}
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Count(children.head),
aggregateFunction = aggregate.Count(child),
mode = aggregate.Complete,
isDistinct = true)
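
For context, the rewrite above gives a multi-column COUNT(DISTINCT a, b) its SQL semantics: tuples containing a null are dropped before the distinct count. A plain-Scala model of that behavior (illustrative only, not Catalyst code):

// Distinct pairs are counted only when every component is non-null,
// mirroring Count(DropAnyNull(CreateStruct(children))).
val pairs = Seq((Some(1), Some(2)), (Some(1), Some(2)), (Some(3), Some(4)), (None, Some(5)), (Some(6), None))
val distinctCount = pairs.collect { case (Some(a), Some(b)) => (a, b) }.toSet.size
// distinctCount == 2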

@@ -320,7 +325,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
val gid = new AttributeReference("gid", IntegerType, false)()
val groupByMap = a.groupingExpressions.collect {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)()
case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
}
val groupByAttrs = groupByMap.map(_._2)

@@ -365,14 +370,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)

// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()

// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression2(
@@ -416,7 +422,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// Construct the expand operator.
val expand = Expand(
regularAggProjection ++ distinctAggProjections,
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq,
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
a.child)

// Construct the first aggregate operator. This de-duplicates all the children of
@@ -457,5 +463,5 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children; in this case attribute reuse causes the input of the regular aggregate to be bound to
// the (nulled out) input of the distinct aggregate.
e -> new AttributeReference(e.prettyName, e.dataType, true)()
e -> new AttributeReference(e.prettyString, e.dataType, true)()
}
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{NullType, BooleanType, DataType}
import org.apache.spark.sql.types._


case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
@@ -419,3 +419,31 @@ case class Greatest(children: Seq[Expression]) extends Expression {
"""
}
}

/** Operator that drops a row (evaluates to null) when the input struct contains any null field. */
case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(StructType)

protected override def nullSafeEval(input: Any): InternalRow = {
val row = input.asInstanceOf[InternalRow]
if (row.anyNull) {
null
} else {
row
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
if ($eval.anyNull()) {
${ev.isNull} = true;
} else {
${ev.value} = $eval;
}
"""
})
}
}
@@ -306,6 +306,9 @@ case class Expand(
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {

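// Collect references from every expression in every projection. projections
// is a Seq[Seq[Expression]], which the default references computation does
// not traverse, so attributes used only by the projections would be missed.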
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))

override def statistics: Statistics = {
// TODO shouldn't we factor in the size of the projection versus the size of the backing child
// row?
@@ -231,4 +231,18 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}

test("function dropAnyNull") {
val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
val a = create_row("a", "q")
val nullStr: String = null
checkEvaluation(drop, a, a)
checkEvaluation(drop, null, create_row("b", nullStr))
checkEvaluation(drop, null, create_row(nullStr, nullStr))

val row = 'r.struct(
StructField("a", StringType, false),
StructField("b", StringType, true)).at(0)
checkEvaluation(DropAnyNull(row), null, create_row(null))
}
}
@@ -162,6 +162,31 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}

test("multiple column distinct count") {
val df1 = Seq(
("a", "b", "c"),
("a", "b", "c"),
("a", "b", "d"),
("x", "y", "z"),
("x", "q", null.asInstanceOf[String]))
.toDF("key1", "key2", "key3")

checkAnswer(
df1.agg(countDistinct('key1, 'key2)),
Row(3)
)

checkAnswer(
df1.agg(countDistinct('key1, 'key2, 'key3)),
Row(3)
)

checkAnswer(
df1.groupBy('key1).agg(countDistinct('key2, 'key3)),
Seq(Row("a", 2), Row("x", 1))
)
}

test("zero count") {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
@@ -516,21 +516,46 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(3, 4, 4, 3, null) :: Nil)
}

test("multiple distinct column sets") {
test("single distinct multiple columns set") {
checkAnswer(
sqlContext.sql(
"""
|SELECT
| key,
| count(distinct value1, value2)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(null, 3) ::
Row(1, 3) ::
Row(2, 1) ::
Row(3, 0) :: Nil)
}

test("multiple distinct multiple columns sets") {
checkAnswer(
sqlContext.sql(
"""
|SELECT
| key,
| count(distinct value1),
| count(distinct value2)
| sum(distinct value1),
| count(distinct value2),
| sum(distinct value2),
| count(distinct value1, value2),
| count(value1),
| sum(value1),
| count(value2),
| sum(value2),
| count(*),
| count(1)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(null, 3, 3) ::
Row(1, 2, 3) ::
Row(2, 2, 1) ::
Row(3, 0, 1) :: Nil)
Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
}

test("test count") {