diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 0cfec43ec72c1..137e9cb9900f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -118,7 +118,21 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions
       .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
+      .groupBy { e =>
+        val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+        if (unfoldableChildren.nonEmpty) {
+          // Only expand the unfoldable children.
+          unfoldableChildren
+        } else {
+          // If all of the aggregateFunction's children are foldable,
+          // we must still expand at least one of them (here we take the first child);
+          // otherwise we get a wrong result, for example:
+          // count(distinct 1) would be rewritten to count(1) by this rule.
+          // In general, a distinct aggregateFunction should not run a
+          // foldable TypeCheck on its first child.
+          e.aggregateFunction.children.take(1).toSet
+        }
+      }
 
     // Aggregation strategy can handle the query with single distinct
     if (distinctAggGroups.size > 1) {
@@ -134,10 +148,9 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
       def patchAggregateFunctionChildren(
           af: AggregateFunction)(
-          attrs: Expression => Expression): AggregateFunction = {
-        af.withNewChildren(af.children.map {
-          case afc => attrs(afc)
-        }).asInstanceOf[AggregateFunction]
+          attrs: Expression => Option[Expression]): AggregateFunction = {
+        val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+        af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
       }
 
       // Setup unique distinct aggregate children.
@@ -161,7 +174,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
         val operators = expressions.map { e =>
          val af = e.aggregateFunction
           val naf = patchAggregateFunctionChildren(af) { x =>
-            evalWithinGroup(id, distinctAggChildAttrLookup(x))
+            distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
           }
           (e, e.copy(aggregateFunction = naf, isDistinct = false))
         }
@@ -170,8 +183,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       }
 
       // Setup expand for the 'regular' aggregate expressions.
-      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      // Only expand the unfoldable children.
+      val regularAggExprs = aggExpressions
+        .filter(e => !e.isDistinct && e.children.exists(!_.foldable))
+      val regularAggChildren = regularAggExprs
+        .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+        .distinct
       val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
 
       // Setup aggregates for 'regular' aggregate expressions.
@@ -179,7 +196,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
       val regularAggOperatorMap = regularAggExprs.map { e =>
         // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
         val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
 
         // Select the result of the first aggregate in the last aggregate.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 3f73657814839..310a7a2c486f7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -143,6 +143,40 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
   }
 
   test("Generic UDAF aggregates") {
+    checkAnswer(sql(
+      """
+        |SELECT percentile_approx(2, 0.99999),
+        |       sum(distinct 1),
+        |       count(distinct 1,2,3,4) FROM src LIMIT 1
+      """.stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)
+
+    checkAnswer(sql(
+      """
+        |SELECT ceiling(percentile_approx(distinct key, 0.99999)),
+        |       count(distinct key),
+        |       sum(distinct key),
+        |       count(distinct 1),
+        |       sum(distinct 1),
+        |       sum(1) FROM src LIMIT 1
+      """.stripMargin),
+      sql(
+        """
+          |SELECT max(key),
+          |       count(distinct key),
+          |       sum(distinct key),
+          |       1, 1, sum(1) FROM src LIMIT 1
+        """.stripMargin).collect().toSeq)
+
+    checkAnswer(sql(
+      """
+        |SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
+        |       count(distinct key), sum(distinct key),
+        |       count(distinct 1), sum(distinct 1),
+        |       sum(1) FROM src LIMIT 1
+      """.stripMargin),
+      sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
+        .collect().toSeq)
+
     checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
       sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
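
For reviewers: a minimal, self-contained sketch of the two behavioral changes in DistinctAggregationRewriter above. `Expr`, `Lit`, `Col`, and the attribute name `key#1` are hypothetical stand-ins for Catalyst's `Expression` API, not Spark code:

```scala
// Hypothetical stand-ins for Catalyst expressions; `foldable` marks constants.
sealed trait Expr { def foldable: Boolean }
case class Lit(value: Any) extends Expr { val foldable = true }
case class Col(name: String) extends Expr { val foldable = false }

object RewriteSketch {
  // Mirrors the new groupBy key: keep only the unfoldable children, but when
  // every child is foldable (e.g. count(distinct 1)) fall back to the first
  // child, so the distinct aggregate is still expanded rather than silently
  // rewritten into its non-distinct form.
  def groupingKey(children: Seq[Expr]): Set[Expr] = {
    val unfoldable = children.filter(!_.foldable).toSet
    if (unfoldable.nonEmpty) unfoldable else children.take(1).toSet
  }

  // Mirrors the patched patchAggregateFunctionChildren: the callback now
  // returns an Option, so foldable children that were never projected by
  // Expand are kept as-is instead of failing the attribute lookup.
  def patchChildren(children: Seq[Expr])(attrs: Expr => Option[Expr]): Seq[Expr] =
    children.map(c => attrs(c).getOrElse(c))

  def main(args: Array[String]): Unit = {
    // count(distinct key, 1): only the unfoldable `key` enters the key.
    println(groupingKey(Seq(Col("key"), Lit(1))))  // Set(Col(key))
    // count(distinct 1): the fallback keeps the literal in the key.
    println(groupingKey(Seq(Lit(1))))              // Set(Lit(1))

    // Only Col("key") has a replacement attribute; Lit(1) passes through.
    val lookup: Map[Expr, Expr] = Map(Col("key") -> Col("key#1"))
    println(patchChildren(Seq(Col("key"), Lit(1)))(lookup.get))
    // List(Col(key#1), Lit(1))
  }
}
```

A side effect worth noting: because literals are dropped from the key, expressions such as count(distinct key) and count(distinct key, 1) now land in the same distinct group, which keeps them on the single-distinct planning path.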
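The new HiveUDFSuite cases pin down the arithmetic the rewrite must preserve. Below is a plain-Scala illustration of those expectations over a hypothetical handful of `key` values (the real tests run against the Hive `src` table):

```scala
object DistinctSemanticsSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical sample of the `key` column; values chosen for illustration.
    val keys = Seq(0, 0, 2, 4, 4, 5)

    // count(distinct 1): one distinct constant across all rows, not a row count.
    assert(keys.map(_ => 1).distinct.size == 1)

    // sum(distinct 1): the single distinct value 1, regardless of row count.
    assert(keys.map(_ => 1).distinct.sum == 1)

    // count(distinct 1, 2, 3, 4): one distinct tuple of constants.
    assert(keys.map(_ => (1, 2, 3, 4)).distinct.size == 1)

    // sum(distinct key): duplicates are removed before summing.
    assert(keys.distinct.sum == 0 + 2 + 4 + 5)

    println("all distinct-aggregate expectations hold")
  }
}
```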