diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala
similarity index 74%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala
index 5004108d348b6..adc696618d3d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateNullability.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -52,3 +52,22 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Updates nullability of [[GroupingExprRef]]s in a resolved LogicalPlan by using the nullability
+ * of the referenced grouping expressions.
+ */
+object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a: Aggregate =>
+      val nullabilities = a.groupingExpressions.map(_.nullable).toArray
+
+      val newAggregateExpressions =
+        a.aggregateExpressions.map(_.transform {
+          case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) =>
+            g.copy(nullable = nullabilities(g.ordinal))
+        }.asInstanceOf[NamedExpression])
+
+      a.copy(aggregateExpressions = newAggregateExpressions)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
index 1f3f762662252..e9673d7f20f20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
@@ -35,7 +35,7 @@ trait AliasHelper {
   protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
     // Find all the aliased expressions in the aggregate list that don't include any actual
     // AggregateExpression or PythonUDF, and create a map from the alias to the expression
-    val aliasMap = plan.aggregateExpressions.collect {
+    val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
       case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
         PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
         (a.toAttribute, a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 281734c6f14ae..8c70c86aa1868 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -80,6 +80,14 @@ object AggregateExpression {
       filter,
       NamedExpression.newExprId)
   }
+
+  def containsAggregate(expr: Expression): Boolean = {
+    expr.find(isAggregate).isDefined
+  }
+
+  def isAggregate(expr: Expression): Boolean = {
+    expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 0f14203a901c6..808c8222d2938 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -277,3 +277,22 @@ object GroupingAnalytics {
     }
   }
 }
+
+/**
+ * A reference to a grouping expression in an [[Aggregate]] node.
+ *
+ * @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression
+ *                refers to.
+ * @param dataType The [[DataType]] of the referenced grouping expression.
+ * @param nullable True if null is a valid value for the referenced grouping expression.
+ */
+case class GroupingExprRef(
+    ordinal: Int,
+    dataType: DataType,
+    nullable: Boolean)
+  extends LeafExpression with Unevaluable {
+
+  override def stringArgs: Iterator[Any] = {
+    Iterator(ordinal)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 0ff11ca49f3d1..8f1548a9788af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -26,15 +26,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
  */
 object SimplifyExtractValueOps extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // One place where this optimization is invalid is an aggregation where the select
-    // list expression is a function of a grouping expression:
-    //
-    // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
-    //
-    // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
-    // optimization for Aggregates (although this misses some cases where the optimization
-    // can be made).
-    case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala
new file mode 100644
index 0000000000000..74042fcbc85b8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EnforceGroupingReferencesInAggregates.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule ensures that [[Aggregate]] nodes contain all required [[GroupingExprRef]]
+ * references for the optimization phase.
+ */
+object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transform {
+      case a: Aggregate =>
+        Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3402374630f2c..2f1020535ff4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -118,7 +118,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
         OptimizeUpdateFields,
         SimplifyExtractValueOps,
         OptimizeCsvJsonExprs,
-        CombineConcats) ++
+        CombineConcats,
+        UpdateGroupingExprRefNullability) ++
       extendedOperatorOptimizationRules
 
     val operatorOptimizationBatch: Seq[Batch] = {
@@ -147,6 +148,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView,
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
+      EnforceGroupingReferencesInAggregates,
       ComputeCurrentTime,
       GetCurrentDatabaseAndCatalog(catalogManager)) ::
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -266,7 +268,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceUpdateFieldsExpression.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName ::
+      EnforceGroupingReferencesInAggregates.ruleName ::
+      UpdateGroupingExprRefNullability.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
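As a quick illustration of what the new rule pair produces, here is a minimal sketch built against the classes added above. It is not part of the patch; the `LocalRelation` input and the `not_null` alias name are made up for the example:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}
    import org.apache.spark.sql.types._

    val c = AttributeReference("c", IntegerType, nullable = true)()

    // SELECT not(c IS NULL) ... GROUP BY c IS NULL
    val grouping = Seq(IsNull(c))
    val aggExprs = Seq(Alias(Not(IsNull(c)), "not_null")())

    // Aggregate.withGroupingRefs (added by this patch) replaces the embedded
    // IsNull(c) with a reference to grouping expression 0, so optimizer rules
    // can no longer rewrite it to IsNotNull(c) and break the link to the key.
    val agg = Aggregate.withGroupingRefs(grouping, aggExprs, LocalRelation(c))
    // agg.aggregateExpressions(0) is now roughly:
    //   Alias(Not(GroupingExprRef(0, BooleanType, nullable = false)), "not_null")

If the nullability of a grouping expression later changes, `UpdateGroupingExprRefNullability` copies the new nullability back into the reference.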
@@ -506,7 +510,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
     case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
       val aliasMap = getAliasMap(lower)
 
-      val newAggregate = upper.copy(
+      val newAggregate = Aggregate.withGroupingRefs(
         child = lower.child,
         groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
         aggregateExpressions = upper.aggregateExpressions.map(
@@ -522,23 +526,19 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
   }
 
   private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
-    val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
+    val upperHasNoAggregateExpressions =
+      !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)
 
     lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
       lower
         .aggregateExpressions
         .filter(_.deterministic)
-        .filter(!isAggregate(_))
+        .filterNot(AggregateExpression.containsAggregate)
         .map(_.toAttribute)
     ))
 
     upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
   }
-
-  private def isAggregate(expr: Expression): Boolean = {
-    expr.find(e => e.isInstanceOf[AggregateExpression] ||
-      PythonUDF.isGroupedAggPandasUDF(e)).isDefined
-  }
 }
 
 /**
@@ -1979,7 +1979,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
     case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
       val newGrouping = grouping.filter(!_.foldable)
       if (newGrouping.nonEmpty) {
-        a.copy(groupingExpressions = newGrouping)
+        val droppedGroupsBefore =
+          grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray
+
+        val newAggregateExpressions =
+          a.aggregateExpressions.map(_.transform {
+            case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
+              g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
+          }.asInstanceOf[NamedExpression])
+
+        a.copy(
+          groupingExpressions = newGrouping,
+          aggregateExpressions = newAggregateExpressions)
       } else {
         // All grouping expressions are literals. We should not drop them all, because this can
         // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
@@ -2000,7 +2011,25 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
     if (newGrouping.size == grouping.size) {
       a
     } else {
-      a.copy(groupingExpressions = newGrouping)
+      var i = 0
+      val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
+        n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
+          i += 1
+          0
+        } else {
+          1
+        })
+      ).toArray
+
+      val newAggregateExpressions =
+        a.aggregateExpressions.map(_.transform {
+          case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
+            g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
+        }.asInstanceOf[NamedExpression])
+
+      a.copy(
+        groupingExpressions = newGrouping,
+        aggregateExpressions = newAggregateExpressions)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index ff212f7c04e84..ef73e58645a89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -632,9 +632,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
    * subqueries.
    */
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
-    case a @ Aggregate(grouping, expressions, child) =>
+    case a @ Aggregate(grouping, _, child) =>
       val subqueries = ArrayBuffer.empty[ScalarSubquery]
-      val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
+        .map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
         // We currently only allow correlated subqueries in an aggregate if they are part of the
         // grouping expressions. As a result we need to replace all the scalar subqueries in the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index c22a874779fca..a96674fe9705d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -287,7 +287,7 @@
     (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
 
   def unapply(a: Any): Option[ReturnType] = a match {
-    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+    case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
       // A single aggregate expression might appear multiple times in resultExpressions.
       // In order to avoid evaluating an individual aggregate function multiple times, we'll
       // build a set of semantically distinct aggregate expressions and re-write expressions so
@@ -297,11 +297,9 @@ object PhysicalAggregation {
       val aggregateExpressions = resultExpressions.flatMap { expr =>
         expr.collect {
           // addExpr() always returns false for non-deterministic expressions and do not add them.
-          case agg: AggregateExpression
-            if !equivalentAggregateExpressions.addExpr(agg) => agg
-          case udf: PythonUDF
-            if PythonUDF.isGroupedAggPandasUDF(udf) &&
-              !equivalentAggregateExpressions.addExpr(udf) => udf
+          case a
+            if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
+            a
         }
       }
 
@@ -322,7 +320,7 @@ object PhysicalAggregation {
       // which takes the grouping columns and final aggregate result buffer as input.
       // Thus, we must re-write the result expressions so that their attributes match up with
       // the attributes of the final result projection's input row:
-      val rewrittenResultExpressions = resultExpressions.map { expr =>
+      val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
         expr.transformDown {
           case ae: AggregateExpression =>
             // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index b31e930e490c7..21e87b4c62606 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRelation, TypeCoercion, TypeCoercionBase}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
@@ -779,14 +781,23 @@ case class Range(
 /**
  * This is a Group by operator with the aggregate functions and projections.
  *
- * @param groupingExpressions expressions for grouping keys
- * @param aggregateExpressions expressions for a project list, which could contain
- *                             [[AggregateExpression]]s.
+ * @param groupingExpressions Expressions for grouping keys.
+ * @param aggregateExpressions Expressions for a project list, which can contain
+ *                             [[AggregateExpression]]s and [[GroupingExprRef]]s.
+ * @param child The child of the aggregate node.
+ *
+ * Expressions without aggregate functions in [[aggregateExpressions]] can contain
+ * [[GroupingExprRef]]s to refer to complex grouping expressions in [[groupingExpressions]]. These
+ * references ensure that optimization rules don't change the aggregate expressions to invalid ones
+ * that no longer refer to any grouping expressions, and they also simplify expression
+ * transformations on the node (each expression needs to be transformed only once).
  *
- * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before
- * separating projection from grouping and aggregate, we should avoid expression-level optimization
- * on aggregateExpressions, which could reference an expression in groupingExpressions.
- * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]]
+ * For example, in the following query Spark shouldn't optimize the aggregate expression
+ * `Not(IsNull(c))` to `IsNotNull(c)`, as the grouping expression is `IsNull(c)`:
+ *   SELECT not(c IS NULL)
+ *   FROM t
+ *   GROUP BY c IS NULL
+ * Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`.
  */
 case class Aggregate(
     groupingExpressions: Seq[Expression],
@@ -813,8 +824,21 @@ case class Aggregate(
     }
   }
 
+  private def expandGroupingReferences(e: Expression): Expression = {
+    e match {
+      case _ if AggregateExpression.isAggregate(e) => e
+      case g: GroupingExprRef => groupingExpressions(g.ordinal)
+      case _ => e.mapChildren(expandGroupingReferences)
+    }
+  }
+
+  lazy val aggregateExpressionsWithoutGroupingRefs = {
+    aggregateExpressions.map(expandGroupingReferences(_).asInstanceOf[NamedExpression])
+  }
+
   override lazy val validConstraints: ExpressionSet = {
-    val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
+    val nonAgg = aggregateExpressionsWithoutGroupingRefs.
+      filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
     getAllValidConstraints(nonAgg)
   }
 
@@ -822,6 +846,51 @@ case class Aggregate(
     copy(child = newChild)
 }
 
+object Aggregate {
+  private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = {
+    val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)]
+    var i = 0
+    groupingExpressions.foreach { ge =>
+      if (!ge.foldable && ge.children.nonEmpty &&
+        !complexGroupingExpressions.contains(ge.canonicalized)) {
+        complexGroupingExpressions += ge.canonicalized -> (ge, i)
+      }
+      i += 1
+    }
+    complexGroupingExpressions
+  }
+
+  private def insertGroupingReferences(
+      aggregateExpressions: Seq[NamedExpression],
+      groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = {
+    def insertGroupingExprRefs(e: Expression): Expression = {
+      e match {
+        case _ if AggregateExpression.isAggregate(e) => e
+        case _ if groupingExpressions.contains(e.canonicalized) =>
+          val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized)
+          GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable)
+        case _ => e.mapChildren(insertGroupingExprRefs)
+      }
+    }
+
+    aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression])
+  }
+
+  def withGroupingRefs(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[NamedExpression],
+      child: LogicalPlan): Aggregate = {
+    val complexGroupingExpressions = collectComplexGroupingExpressions(groupingExpressions)
+    val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) {
+      insertGroupingReferences(aggregateExpressions, complexGroupingExpressions)
+    } else {
+      aggregateExpressions
+    }
+
+    new Aggregate(groupingExpressions, aggrExprWithGroupingReferences, child)
+  }
+}
+
 case class Window(
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
index d376c31ef965f..3eba003d7752b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
@@ -96,7 +96,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
       .groupBy('a + 'b)(('a + 'b) as 'c)
       .analyze
     val optimized = Optimize.execute(query)
-    comparePlans(optimized, expected)
+    comparePlans(optimized, EnforceGroupingReferencesInAggregates(expected))
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index dcd2fbbf00529..d14996709401b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -36,6 +36,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
     val batches =
+      Batch("Finish Analysis", Once,
+        EnforceGroupingReferencesInAggregates) ::
       Batch("collapse projections", FixedPoint(10),
         CollapseProject) ::
       Batch("Constant Folding", FixedPoint(10),
@@ -57,7 +59,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
   private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
     val optimized = Optimizer.execute(originalQuery.analyze)
     assert(optimized.resolved, "optimized plans must be still resolvable")
-    comparePlans(optimized, correctAnswer.analyze)
+    comparePlans(optimized, EnforceGroupingReferencesInAggregates(correctAnswer.analyze))
   }
 
   test("explicit get from namedStruct") {
@@ -405,14 +407,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     val arrayAggRel = relation.groupBy(
       CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
     checkRule(arrayAggRel, arrayAggRel)
-
-    // This could be done if we had a more complex rule that checks that
-    // the CreateMap does not come from key.
-    val originalQuery = relation
-      .groupBy('id)(
-        GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
-      )
-    checkRule(originalQuery, originalQuery)
   }
 
   test("SPARK-23500: namedStruct and getField in the same Project #1") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index c6a70fb204354..1c018be6d5773 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -40,6 +40,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
   private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
     e.isInstanceOf[AggregateExpression] ||
       PythonUDF.isGroupedAggPandasUDF(e) ||
+      e.isInstanceOf[GroupingExprRef] ||
       agg.groupingExpressions.exists(_.semanticEquals(e))
   }
 
@@ -119,23 +120,8 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
         groupingExpr += expr
       }
     }
-    val aggExpr = agg.aggregateExpressions.map { expr =>
-      expr.transformUp {
-        // PythonUDF over aggregate was pull out by ExtractPythonUDFFromAggregate.
-        // PythonUDF here should be either
-        // 1. Argument of an aggregate function.
-        //    CheckAnalysis guarantees the arguments are deterministic.
-        // 2. PythonUDF in grouping key. Grouping key must be deterministic.
-        // 3. PythonUDF not in grouping key. It is either no arguments or with grouping key
-        //    in its arguments. Such PythonUDF was pull out by ExtractPythonUDFFromAggregate, too.
-        case p: PythonUDF if p.udfDeterministic =>
-          val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
-          attributeMap.getOrElse(canonicalized, p)
-      }.asInstanceOf[NamedExpression]
-    }
     agg.copy(
       groupingExpressions = groupingExpr.toSeq,
-      aggregateExpressions = aggExpr,
       child = Project((projList ++ agg.child.output).toSeq, agg.child))
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 6ee1014739759..988ad99418a10 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -179,3 +179,12 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(
 
 -- Aggregate with multiple distinct decimal columns
 SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);
+
+-- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function
+SELECT not(a IS NULL), count(*) AS c
+FROM testData
+GROUP BY a IS NULL;
+
+SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
+FROM testData
+GROUP BY a IS NULL;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 1d8c44c29129a..b5471a785a224 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 62
+-- Number of queries: 64
 
 
 -- !query
@@ -642,3 +642,25 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1
 struct
 -- !query output
 1.0000	1
+
+
+-- !query
+SELECT not(a IS NULL), count(*) AS c
+FROM testData
+GROUP BY a IS NULL
+-- !query schema
+struct<(NOT (a IS NULL)):boolean,c:bigint>
+-- !query output
+false	2
+true	7
+
+
+-- !query
+SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
+FROM testData
+GROUP BY a IS NULL
+-- !query schema
+struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
+-- !query output
+0.7604953758285915	7
+1.0	2
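
A note on the ordinal bookkeeping in RemoveLiteralFromGroupExpressions and RemoveRepetitionFromGroupExpressions above: when grouping expressions are dropped, every GroupingExprRef ordinal has to shift down by the number of expressions dropped before it. A small worked example of the scanLeft counting trick; the foldability flags below are made up for illustration:

    // Foldability flags for five grouping expressions; #1 and #3 are foldable
    // literals and get filtered out of the grouping list.
    val foldable = Seq(false, true, false, true, false)

    // scanLeft produces, for each ordinal, the number of dropped expressions
    // that precede it: Array(0, 0, 1, 1, 2, 2)
    val droppedGroupsBefore =
      foldable.scanLeft(0)((n, f) => n + (if (f) 1 else 0)).toArray

    // A GroupingExprRef(ordinal = 4) is remapped to 4 - droppedGroupsBefore(4) = 2,
    // its position in the shrunken grouping list (old ordinals 0, 2, 4 -> new 0, 1, 2).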