From 7837e0fa233dc950efa111539e2a9966c7cb0e02 Mon Sep 17 00:00:00 2001
From: David Navas
Date: Tue, 6 Jun 2017 20:02:47 -0700
Subject: [PATCH 1/2] Skip literals so that percentile_approx works correctly

---
 .../DistinctAggregationRewriter.scala              | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 0cfec43ec72c1..eadf1fa09eedb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -142,7 +142,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
 
     // Setup unique distinct aggregate children.
     val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-    val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+    val distinctAggChildAttrMap = distinctAggChildren.flatMap(expressionAttributePair)
     val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
 
     // Setup expand & aggregate operators for distinct aggregate expressions.
@@ -161,7 +161,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       val operators = expressions.map { e =>
         val af = e.aggregateFunction
         val naf = patchAggregateFunctionChildren(af) { x =>
-          evalWithinGroup(id, distinctAggChildAttrLookup(x))
+          evalWithinGroup(id, distinctAggChildAttrLookup.getOrElse(x, x))
         }
         (e, e.copy(aggregateFunction = naf, isDistinct = false))
       }
@@ -172,14 +172,16 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
     // Setup expand for the 'regular' aggregate expressions.
     val regularAggExprs = aggExpressions.filter(!_.isDistinct)
     val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
-    val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
+    val regularAggChildAttrMap = regularAggChildren.flatMap(expressionAttributePair)
 
     // Setup aggregates for 'regular' aggregate expressions.
     val regularGroupId = Literal(0)
     val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
     val regularAggOperatorMap = regularAggExprs.map { e =>
       // Perform the actual aggregation in the initial aggregate.
-      val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+      val af = patchAggregateFunctionChildren(e.aggregateFunction) { x =>
+        regularAggChildAttrLookup.getOrElse(x, x)
+      }
       val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
 
       // Select the result of the first aggregate in the last aggregate.
@@ -260,10 +262,15 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
 
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
-  private def expressionAttributePair(e: Expression) =
+  private def expressionAttributePair(e: Expression) = {
     // We are creating a new reference here instead of reusing the attribute in case of a
     // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
    // children, in this case attribute reuse causes the input of the regular aggregate to bound to
     // the (nulled out) input of the distinct aggregate.
-    e -> new AttributeReference(e.prettyString, e.dataType, true)()
+    e match {
+      case x: Literal => None
+      case x => Some(e -> new AttributeReference(e.prettyString, e.dataType, true)())
+
+    }
+  }
 }
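The mechanism of this first patch, restated: expressionAttributePair now returns an Option that is None for a Literal, so flatMap drops literals from the child-to-attribute map, and the getOrElse fallbacks let a literal child pass through the rewrite untouched. That is what un-breaks percentile_approx, whose percentage argument Hive requires to be a constant: substituting a generated AttributeReference for the literal 0.99999 made that argument non-foldable. The following minimal, self-contained Scala sketch shows the idea; Expr, Lit, Col and Attr are hypothetical stand-ins for Catalyst's Expression, Literal and AttributeReference, not Spark's API.

// Hypothetical stand-ins for the Catalyst Expression hierarchy.
sealed trait Expr
case class Lit(value: Any) extends Expr
case class Col(name: String) extends Expr
case class Attr(name: String) extends Expr

object SkipLiteralsSketch {
  // As in the patch: a literal yields no pair, so flatMap drops it.
  private def expressionAttributePair(e: Expr): Option[(Expr, Attr)] = e match {
    case _: Lit => None
    case other => Some(other -> Attr(s"attr($other)"))
  }

  def main(args: Array[String]): Unit = {
    // Children of percentile_approx(key, 0.99999): a column and a literal.
    val children: Seq[Expr] = Seq(Col("key"), Lit(0.99999))
    val attrMap: Map[Expr, Expr] = children.flatMap(expressionAttributePair).toMap

    // Mirrors distinctAggChildAttrLookup.getOrElse(x, x): the column is
    // rewritten to its fresh attribute, the literal passes through untouched.
    val patched = children.map(c => attrMap.getOrElse(c, c))
    println(patched) // List(Attr(attr(Col(key))), Lit(0.99999))
  }
}

The second patch below supersedes this Option-returning helper (it restores the original expressionAttributePair) but keeps the same end result by filtering foldable children out before the map is ever built.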
From e757de5903f91c416605b2c4429bcece21276952 Mon Sep 17 00:00:00 2001
From: David Navas
Date: Wed, 7 Jun 2017 16:49:30 -0700
Subject: [PATCH 2/2] Backport of apache/spark/pull/15668

---
 .../DistinctAggregationRewriter.scala              | 50 +++++++++++--------
 .../sql/hive/execution/HiveUDFSuite.scala          | 34 +++++++++++++
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index eadf1fa09eedb..137e9cb9900f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -118,7 +118,21 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions
       .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
+      .groupBy { e =>
+        val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+        if (unfoldableChildren.nonEmpty) {
+          // Only expand the unfoldable children
+          unfoldableChildren
+        } else {
+          // If aggregateFunction's children are all foldable,
+          // we must expand at least one of the children (here we take the first child);
+          // if we don't, we will get the wrong result, for example:
+          // count(distinct 1) would be rewritten to count(1) by this rule.
+          // Generally, the distinct aggregateFunction should not run
+          // its foldable TypeCheck against the first child.
+          e.aggregateFunction.children.take(1).toSet
+        }
+      }
 
     // Aggregation strategy can handle the query with single distinct
     if (distinctAggGroups.size > 1) {
@@ -134,15 +148,14 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
     def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
     def patchAggregateFunctionChildren(
         af: AggregateFunction)(
-        attrs: Expression => Expression): AggregateFunction = {
-      af.withNewChildren(af.children.map {
-        case afc => attrs(afc)
-      }).asInstanceOf[AggregateFunction]
+        attrs: Expression => Option[Expression]): AggregateFunction = {
+      val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+      af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
     }
 
     // Setup unique distinct aggregate children.
     val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-    val distinctAggChildAttrMap = distinctAggChildren.flatMap(expressionAttributePair)
+    val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
     val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
 
     // Setup expand & aggregate operators for distinct aggregate expressions.
@@ -161,7 +174,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       val operators = expressions.map { e =>
         val af = e.aggregateFunction
         val naf = patchAggregateFunctionChildren(af) { x =>
-          evalWithinGroup(id, distinctAggChildAttrLookup.getOrElse(x, x))
+          distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
         }
         (e, e.copy(aggregateFunction = naf, isDistinct = false))
       }
@@ -170,18 +183,20 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
     }
 
     // Setup expand for the 'regular' aggregate expressions.
-    val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-    val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
-    val regularAggChildAttrMap = regularAggChildren.flatMap(expressionAttributePair)
+    // only expand unfoldable children
+    val regularAggExprs = aggExpressions
+      .filter(e => !e.isDistinct && e.children.exists(!_.foldable))
+    val regularAggChildren = regularAggExprs
+      .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+      .distinct
+    val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
 
     // Setup aggregates for 'regular' aggregate expressions.
     val regularGroupId = Literal(0)
     val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
     val regularAggOperatorMap = regularAggExprs.map { e =>
       // Perform the actual aggregation in the initial aggregate.
-      val af = patchAggregateFunctionChildren(e.aggregateFunction) { x =>
-        regularAggChildAttrLookup.getOrElse(x, x)
-      }
+      val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
       val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
 
       // Select the result of the first aggregate in the last aggregate.
@@ -262,15 +277,10 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
 
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
-  private def expressionAttributePair(e: Expression) = {
+  private def expressionAttributePair(e: Expression) =
     // We are creating a new reference here instead of reusing the attribute in case of a
     // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
     // children, in this case attribute reuse causes the input of the regular aggregate to bound to
     // the (nulled out) input of the distinct aggregate.
-    e match {
-      case x: Literal => None
-      case x => Some(e -> new AttributeReference(e.prettyString, e.dataType, true)())
-
-    }
-  }
+    e -> new AttributeReference(e.prettyString, e.dataType, true)()
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 3f73657814839..310a7a2c486f7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -143,6 +143,40 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
   }
 
   test("Generic UDAF aggregates") {
+    checkAnswer(sql(
+      """
+        |SELECT percentile_approx(2, 0.99999),
+        |       sum(distinct 1),
+        |       count(distinct 1,2,3,4) FROM src LIMIT 1
+      """.stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)
+
+    checkAnswer(sql(
+      """
+        |SELECT ceiling(percentile_approx(distinct key, 0.99999)),
+        |       count(distinct key),
+        |       sum(distinct key),
+        |       count(distinct 1),
+        |       sum(distinct 1),
+        |       sum(1) FROM src LIMIT 1
+      """.stripMargin),
+      sql(
+        """
+          |SELECT max(key),
+          |       count(distinct key),
+          |       sum(distinct key),
+          |       1, 1, sum(1) FROM src LIMIT 1
+        """.stripMargin).collect().toSeq)
+
+    checkAnswer(sql(
+      """
+        |SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
+        |       count(distinct key), sum(distinct key),
+        |       count(distinct 1), sum(distinct 1),
+        |       sum(1) FROM src LIMIT 1
+      """.stripMargin),
+      sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
+        .collect().toSeq)
+
     checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
       sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
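The grouping rule backported above, sketched standalone: distinct aggregates are keyed by their unfoldable children only, so sum(distinct key) and percentile_approx(distinct key, 0.99999) now land in a single Expand group, while an aggregate whose children are all foldable keeps its first child as its group key, so count(distinct 1) is still expanded rather than collapsing into count(1). The same foldability filter is applied to the regular aggregate expressions. A minimal, self-contained Scala sketch of that keying logic follows; Child and DistinctAgg are illustrative types, not Catalyst classes.

object GroupByUnfoldableSketch {
  // Hypothetical stand-ins: `foldable` marks literal/constant children.
  case class Child(sql: String, foldable: Boolean)
  case class DistinctAgg(name: String, children: Seq[Child])

  // Mirrors the groupBy key in the patch: only unfoldable children are
  // expanded; an all-foldable aggregate contributes its first child.
  def groupKey(e: DistinctAgg): Set[Child] = {
    val unfoldable = e.children.filterNot(_.foldable).toSet
    if (unfoldable.nonEmpty) unfoldable
    else e.children.take(1).toSet
  }

  def main(args: Array[String]): Unit = {
    val aggs = Seq(
      DistinctAgg("count", Seq(Child("1", foldable = true))),
      DistinctAgg("sum", Seq(Child("key", foldable = false))),
      DistinctAgg("percentile_approx",
        Seq(Child("key", foldable = false), Child("0.99999", foldable = true))))

    // sum(distinct key) and percentile_approx(distinct key, 0.99999) share
    // the group {key}; count(distinct 1) keeps {1} as its own group.
    aggs.groupBy(groupKey).foreach { case (groupChildren, members) =>
      println(s"${groupChildren.map(_.sql)} -> ${members.map(_.name)}")
    }
  }
}

Because the two key-based distinct aggregates share one group, the rewrite needs no extra Expand projection for the percentile call, and the foldable 0.99999 never needs an attribute at all; the new tests exercise exactly these shapes, checking ceiling(percentile_approx(distinct key, 0.99999)) against max(key) over src.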