From 1ad48f443ea2bf18f119be4c99480e94ab06214b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 27 Jul 2023 19:12:17 +0800 Subject: [PATCH] [SPARK-44571][SQL] Eliminate the Join by combine multiple Aggregates --- .../expressions/BloomFilterMightContain.scala | 4 +- .../expressions/collectionOperations.scala | 6 +- .../expressions/nullExpressions.scala | 4 +- .../expressions/regexpExpressions.scala | 2 + .../expressions/stringExpressions.scala | 4 +- .../optimizer/CombineJoinedAggregates.scala | 152 +++++ .../optimizer/MergeScalarSubqueries.scala | 19 +- .../MergeScalarSubqueriesHelper.scala | 43 ++ .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../sql/catalyst/trees/TreePatterns.scala | 5 + .../apache/spark/sql/internal/SQLConf.scala | 22 + .../CombineJoinedAggregatesSuite.scala | 517 ++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 228 ++++++++ 14 files changed, 986 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala index 784bea899c4c8..3f8653e26a52e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.catalyst.trees.TreePattern.{BLOOM_FILTER, OUTER_REFERENCE, TreePattern} import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.BloomFilter @@ -47,6 +47,8 @@ case class BloomFilterMightContain( override def right: Expression = valueExpression override def prettyName: String = "might_contain" + final override val nodePatterns: Seq[TreePattern] = Seq(BLOOM_FILTER) + override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 39bf6734eb27b..c0615da631e0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import 
org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} -import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_CONTAINS, ARRAYS_OVERLAP, ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -1428,6 +1428,8 @@ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Predicate with QueryErrorsBase { + final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_CONTAINS) + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -1651,6 +1653,8 @@ case class ArrayAppend(left: Expression, right: Expression) extends ArrayPendBas case class ArraysOverlap(left: Expression, right: Expression) extends BinaryArrayExpressionWithImplicitCast with NullIntolerant with Predicate { + final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_OVERLAP) + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => TypeUtils.checkForOrderingExpr(elementType, prettyName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 010d79f808d10..123eb092306dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{AT_LEAST_N_NON_NULLS, COALESCE, NULL_CHECK, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -429,6 +429,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate override def nullable: Boolean = false override def foldable: Boolean = children.forall(_.foldable) + final override val nodePatterns: Seq[TreePattern] = Seq(AT_LEAST_N_NON_NULLS) + private[this] val childrenArray = children.toArray override def eval(input: InternalRow): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b33de303b5d55..4855b2d4a10d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -466,6 +466,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress override def toString: String = s"RLIKE($left, $right)" override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}(${left.sql}, ${right.sql})" + final override val nodePatterns: Seq[TreePattern] 
= Seq(LIKE_FAMLIY) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2b7703ed82b37..d31359582b815 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} +import org.apache.spark.sql.catalyst.trees.TreePattern.{STRING_PREDICATE, TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -508,6 +508,8 @@ abstract class StringPredicate extends BinaryExpression override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, StringTypeAnyCollation) + final override val nodePatterns: Seq[TreePattern] = Seq(STRING_PREDICATE) + protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala new file mode 100644 index 0000000000000..c28d95f46fa27 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, LogicalPlan, Project, SerializeFromObject}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, ARRAY_CONTAINS, ARRAYS_OVERLAP, AT_LEAST_N_NON_NULLS, BLOOM_FILTER, DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY, HIGH_ORDER_FUNCTION, IN, IN_SUBQUERY, INSET, INVOKE, JOIN, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF, STRING_PREDICATE}
+
+/**
+ * This rule eliminates a [[Join]] whose sides are all [[Aggregate]]s by combining those
+ * [[Aggregate]]s into a single one. The rule also supports nested [[Join]]s, as long as both
+ * sides of every [[Join]] are [[Aggregate]]s.
+ *
+ * Note: this rule doesn't support the following cases:
+ * 1. [[Aggregate]]s to be merged when at least one of them does not have a predicate or has a
+ *    predicate with low selectivity.
+ * 2. [[Aggregate]]s to be merged whose upstream (child) plans contain a [[Join]].
+ */
+object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
+
+  private def isSupportedJoinType(joinType: JoinType): Boolean =
+    Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType)
+
+  private def isCheapPredicate(e: Expression): Boolean = {
+    !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
+      REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, DYNAMIC_PRUNING_SUBQUERY, DYNAMIC_PRUNING_EXPRESSION,
+      HIGH_ORDER_FUNCTION, IN_SUBQUERY, IN, INSET, EXISTS_SUBQUERY, STRING_PREDICATE,
+      AT_LEAST_N_NON_NULLS, BLOOM_FILTER, ARRAY_CONTAINS, ARRAYS_OVERLAP) &&
+      Option(e.apply(conf.maxTreeNodeNumOfPredicate)).isEmpty
+  }
+
+  /**
+   * Try to merge two `Aggregate`s by traversing down both plans recursively.
+   *
+   * @return The optional tuple as follows:
+   *         1. the merged plan
+   *         2. the attribute mapping from the old to the merged version
+   *         3. optional filters of both plans that need to be propagated and merged in an
+   *            ancestor `Aggregate` node if possible.
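+   *
+   * For example (an illustrative sketch that mirrors the plans exercised by
+   * CombineJoinedAggregatesSuite), merging
+   *   Aggregate [sum(b) AS sum_b]            Aggregate [avg(b) AS avg_b]
+   *   +- Filter (a = 1)                      +- Filter (a = 2)
+   *      +- Relation r                          +- Relation r
+   * yields
+   *   Aggregate [sum(b) FILTER (WHERE a = 1) AS sum_b, avg(b) FILTER (WHERE a = 2) AS avg_b]
+   *   +- Filter ((a = 1) OR (a = 2))
+   *      +- Relation r
+   * so the Join and the second scan of r are eliminated.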
+ */ + private def mergePlan( + left: LogicalPlan, + right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[Expression])] = { + (left, right) match { + case (la: Aggregate, ra: Aggregate) => + mergePlan(la.child, ra.child).map { case (newChild, outputMap, filters) => + val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap)) + + val mergedAggregateExprs = if (filters.length == 2) { + Seq( + (la.aggregateExpressions, filters.head), + (rightAggregateExprs, filters.last) + ).flatMap { case (aggregateExpressions, propagatedFilter) => + aggregateExpressions.map { ne => + ne.transform { + case ae @ AggregateExpression(_, _, _, filterOpt, _) => + val newFilter = filterOpt.map { filter => + And(propagatedFilter, filter) + }.orElse(Some(propagatedFilter)) + ae.copy(filter = newFilter) + }.asInstanceOf[NamedExpression] + } + } + } else { + la.aggregateExpressions ++ rightAggregateExprs + } + + (Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty) + } + case (lp: Project, rp: Project) => + val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*) + + mergePlan(lp.child, rp.child).map { case (newChild, outputMap, filters) => + val allFilterReferences = filters.flatMap(_.references) + val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne => + val mapped = mapAttributes(ne, outputMap) + + val withoutAlias = mapped match { + case Alias(child, _) => child + case e => e + } + + val outputAttr = mergedProjectList.find { + case Alias(child, _) => child semanticEquals withoutAlias + case e => e semanticEquals withoutAlias + }.getOrElse { + mergedProjectList += mapped + mapped + }.toAttribute + ne.toAttribute -> outputAttr + }) + + (Project(mergedProjectList.toSeq, newChild), newOutputMap, filters) + } + case (lf: Filter, rf: Filter) + if isCheapPredicate(lf.condition) && isCheapPredicate(rf.condition) => + mergePlan(lf.child, rf.child).map { + case (newChild, outputMap, filters) => + val mappedRightCondition = mapAttributes(rf.condition, outputMap) + val (newLeftCondition, newRightCondition) = if (filters.length == 2) { + (And(lf.condition, filters.head), And(mappedRightCondition, filters.last)) + } else { + (lf.condition, mappedRightCondition) + } + val newCondition = Or(newLeftCondition, newRightCondition) + + (Filter(newCondition, newChild), outputMap, Seq(newLeftCondition, newRightCondition)) + } + case (ll: LeafNode, rl: LeafNode) => + checkIdenticalPlans(rl, ll).map { outputMap => + (ll, outputMap, Seq.empty) + } + case (ls: SerializeFromObject, rs: SerializeFromObject) => + checkIdenticalPlans(rs, ls).map { outputMap => + (ls, outputMap, Seq.empty) + } + case _ => None + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.combineJoinedAggregatesEnabled) return plan + + plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) { + case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _) + if isSupportedJoinType(joinType) && + left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty => + val mergedAggregate = mergePlan(left, right) + mergedAggregate.map(_._1).getOrElse(j) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 2d1e71a63a8ce..c38d06b5c3a1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -101,7 +101,7 @@ import org.apache.spark.sql.types.DataType * : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125] * +- *(1) Scan OneRowRelation[] */ -object MergeScalarSubqueries extends Rule[LogicalPlan] { +object MergeScalarSubqueries extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper { def apply(plan: LogicalPlan): LogicalPlan = { plan match { // Subquery reuse needs to be enabled for this optimization. @@ -212,17 +212,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { } } - // If 2 plans are identical return the attribute mapping from the new to the cached version. - private def checkIdenticalPlans( - newPlan: LogicalPlan, - cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = { - if (newPlan.canonicalized == cachedPlan.canonicalized) { - Some(AttributeMap(newPlan.output.zip(cachedPlan.output))) - } else { - None - } - } - // Recursively traverse down and try merging 2 plans. If merge is possible then return the merged // plan with the attribute mapping from the new to the merged version. // Please note that merging arbitrary plans can be complicated, the current version supports only @@ -314,12 +303,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { plan) } - private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = { - expr.transform { - case a: Attribute => outputMap.getOrElse(a, a) - }.asInstanceOf[T] - } - // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to // the merged version that can be propagated up during merging nodes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala new file mode 100644 index 0000000000000..a490559be1b4c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * The helper class used to merge scalar subqueries. + */ +trait MergeScalarSubqueriesHelper { + + // If 2 plans are identical return the attribute mapping from the left to the right. 
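+  // The returned mapping can be fed to `mapAttributes` below, so that expressions written
+  // against the left plan's output can be rewritten to resolve against the right plan's output.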
+ protected def checkIdenticalPlans( + left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = { + if (left.canonicalized == right.canonicalized) { + Some(AttributeMap(left.output.zip(right.output))) + } else { + None + } + } + + protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = { + expr.transform { + case a: Attribute => outputMap.getOrElse(a, a) + }.asInstanceOf[T] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e4548ee69dc4a..aee14483000ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -97,6 +97,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateOffsets, EliminateLimits, CombineUnions, + CombineJoinedAggregates, // Constant folding and strength reduction OptimizeRepartition, EliminateWindowPartitions, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 11df764ebb03b..dd721b837d4a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -112,6 +112,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.ColumnPruning" :: "org.apache.spark.sql.catalyst.optimizer.CombineConcats" :: "org.apache.spark.sql.catalyst.optimizer.CombineFilters" :: + "org.apache.spark.sql.catalyst.optimizer.CombineJoinedAggregates" :: "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" :: "org.apache.spark.sql.catalyst.optimizer.CombineUnions" :: "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 4ab075db5709a..3d6b8cadecbf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -26,13 +26,17 @@ object TreePattern extends Enumeration { val AGGREGATE_EXPRESSION = Value(0) val ALIAS: Value = Value val AND: Value = Value + val ARRAY_CONTAINS: Value = Value + val ARRAYS_OVERLAP: Value = Value val ARRAYS_ZIP: Value = Value val ATTRIBUTE_REFERENCE: Value = Value val APPEND_COLUMNS: Value = Value val AVERAGE: Value = Value + val AT_LEAST_N_NON_NULLS = Value val GROUPING_ANALYTICS: Value = Value val BINARY_ARITHMETIC: Value = Value val BINARY_COMPARISON: Value = Value + val BLOOM_FILTER: Value = Value val CASE_WHEN: Value = Value val CAST: Value = Value val COALESCE: Value = Value @@ -88,6 +92,7 @@ object TreePattern extends Enumeration { val SCALA_UDF: Value = Value val SESSION_WINDOW: Value = Value val SORT: Value = Value + val STRING_PREDICATE: Value = Value val SUBQUERY_ALIAS: Value = Value val SUM: Value = Value val TIME_WINDOW: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 428ad052eba81..b148b0ed527dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4831,6 +4831,23 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMBINE_JOINED_AGGREGATES_ENABLED =
+    buildConf("spark.sql.optimizer.combineJoinedAggregates.enabled")
+      .doc("When true, we attempt to eliminate a join by combining the aggregates on both of " +
+        "its sides, to reduce the number of table scans and avoid shuffles.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val MAX_TREE_NODE_NUM_OF_PREDICATE =
+    buildConf("spark.sql.optimizer.combineJoinedAggregates.maxTreeNodeNumOfPredicate")
+      .doc("Maximum number of tree nodes allowed in a predicate. If a predicate exceeds this " +
+        "limit, CombineJoinedAggregates will not merge the aggregates connected by the join.")
+      .version("4.0.0")
+      .intConf
+      .checkValue(_ > 0, "The threshold of tree node numbers should be positive")
+      .createWithDefault(10)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -5755,6 +5772,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def legacyRaiseErrorWithoutErrorClass: Boolean =
     getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)
 
+  def combineJoinedAggregatesEnabled: Boolean =
+    getConf(SQLConf.COMBINE_JOINED_AGGREGATES_ENABLED)
+
+  def maxTreeNodeNumOfPredicate: Int = getConf(SQLConf.MAX_TREE_NODE_NUM_OF_PREDICATE)
+
   def stackTracesInDataFrameContext: Int = getConf(SQLConf.STACK_TRACES_IN_DATAFRAME_CONTEXT)
 
   def legacyJavaCharsets: Boolean = getConf(SQLConf.LEGACY_JAVA_CHARSETS)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala
new file mode 100644
index 0000000000000..76fcc52702ec1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala
@@ -0,0 +1,517 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +class CombineJoinedAggregatesSuite extends PlanTest { + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Eliminate Join By Combine Aggregate", FixedPoint(10), + CollapseProject, + RemoveNoopOperators, + PushDownPredicates, + CombineJoinedAggregates, + BooleanSimplification) :: Nil + } + + private object WithoutOptimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Eliminate Join By Combine Aggregate", FixedPoint(10), + CollapseProject, + RemoveNoopOperators, + PushDownPredicates, + BooleanSimplification) :: Nil + } + + private val testRelation = LocalRelation.fromExternalRows( + Seq("a".attr.int, "b".attr.int, "c".attr.int), + 1.to(6).map(i => Row(i, 2 * i, 3 * i))) + private val a = testRelation.output(0) + private val b = testRelation.output(1) + private val c = testRelation.output(2) + + override def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.COMBINE_JOINED_AGGREGATES_ENABLED.key -> "true") { + testFun + } + } + } + + test("join type is unsupported") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(sum(b).as("sum_b")), joinType) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(originalQuery1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b")), joinType).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b"))) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(originalQuery2.analyze)) + } + } + + test("join with condition is unsupported") { + Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).foreach { joinType => + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).as("left").join( + testRelation.where(a === 2).groupBy()(sum(b).as("sum_b")).as("right"), + joinType, Some($"left.sum_b" === $"right.sum_b")) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(originalQuery1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).as("left").join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b")).as("right"), + joinType, Some($"left.sum_b" === $"right.avg_b")).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b"))) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(originalQuery2.analyze)) + } + } + + test("join side doesn't contains Aggregate") { + val originalQuery1 = + testRelation.where(a === 1).join( + testRelation.where(a === 2)) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(originalQuery1.analyze)) + + val originalQuery2 = + 
testRelation.where(a === 1).select(b, c).join( + testRelation.where(a === 2).select(b, c)) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(originalQuery2.analyze)) + } + + test("join side is not Aggregate") { + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2)) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(originalQuery1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b"), avg(b).as("avg_b")).join( + testRelation.where(a === 2).select(b, c)) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(originalQuery2.analyze)) + + // Nested Join + val originalQuery3 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2)).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b"))) + + comparePlans( + Optimize.execute(originalQuery3.analyze), + WithoutOptimize.execute(originalQuery3.analyze)) + + val originalQuery4 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b")))) + + comparePlans( + Optimize.execute(originalQuery4.analyze), + WithoutOptimize.execute(originalQuery4.analyze)) + } + + test("join side contains Aggregate with group by clause") { + val originalQuery1 = + testRelation.where(a === 1).groupBy(c)(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(sum(b).as("sum_b"))) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(originalQuery1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy(c)(sum(b).as("sum_b"))) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(originalQuery2.analyze)) + + val originalQuery3 = + testRelation.where(a === 1).groupBy(c)(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy(c)(sum(b).as("sum_b"))) + + comparePlans( + Optimize.execute(originalQuery3.analyze), + WithoutOptimize.execute(originalQuery3.analyze)) + + val originalQuery4 = + testRelation.where(a === 1).groupBy(c)(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b"))) + + comparePlans( + Optimize.execute(originalQuery4.analyze), + WithoutOptimize.execute(originalQuery4.analyze)) + } + + test("join two side are Aggregates with subquery") { + val subQuery1 = testRelation.where(a === 1).as("tab1") + val subQuery2 = testRelation.where(a === 2).as("tab2") + val b1 = subQuery1.output(1) + val c1 = subQuery1.output(2) + val b2 = subQuery2.output(1) + val c2 = subQuery2.output(2) + val originalQuery = + subQuery1.where(c1 === 1).groupBy()(sum(b1).as("sum_b")).join( + subQuery2.where(c2 === 2).groupBy()(avg(b2).as("avg_b"))) + + val correctAnswer = + testRelation.where((a === 1 && c === 1) || (a === 2 && c === 2)).groupBy()( + sum(b, Some(a === 1 && c === 1)).as("sum_b"), + avg(b, Some(a === 2 && c === 2)).as("avg_b")) + + comparePlans( + Optimize.execute(originalQuery.analyze), + WithoutOptimize.execute(correctAnswer.analyze)) + } + + test("join two side are Aggregates and only one side with Filter") { + val originalQuery = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.groupBy()(sum(b).as("sum_b"))) + + 
comparePlans( + Optimize.execute(originalQuery.analyze), + WithoutOptimize.execute(originalQuery.analyze)) + } + + test("join two side are Aggregates without Filter") { + val originalQuery = + testRelation.groupBy()(sum(b).as("sum_b")).join( + testRelation.groupBy()(sum(b).as("sum_b"))) + + val correctAnswer = testRelation.groupBy()(sum(b).as("sum_b"), sum(b).as("sum_b")) + + comparePlans( + Optimize.execute(originalQuery.analyze), + WithoutOptimize.execute(correctAnswer.analyze)) + } + + test("join two side are Aggregates with Filter") { + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(sum(b).as("sum_b"))) + + val correctAnswer1 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), sum(b, Some(a === 2)).as("sum_b")) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(correctAnswer1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))) + + val correctAnswer2 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), avg(b, Some(a === 2)).as("avg_b")) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(correctAnswer2.analyze)) + + val originalQuery3 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(c).as("avg_c"))) + + val correctAnswer3 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), avg(c, Some(a === 2)).as("avg_c")) + + comparePlans( + Optimize.execute(originalQuery3.analyze), + WithoutOptimize.execute(correctAnswer3.analyze)) + + val originalQuery4 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b"), count(c).as("count_c")).join( + testRelation.where(a === 2).groupBy()(avg(c).as("avg_c"))) + + val correctAnswer4 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + count(c, Some(a === 1)).as("count_c"), + avg(c, Some(a === 2)).as("avg_c")) + + comparePlans( + Optimize.execute(originalQuery4.analyze), + WithoutOptimize.execute(correctAnswer4.analyze)) + + val originalQuery5 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(c).as("avg_c"), count(c).as("count_c"))) + + val correctAnswer5 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(c, Some(a === 2)).as("avg_c"), + count(c, Some(a === 2)).as("count_c")) + + comparePlans( + Optimize.execute(originalQuery5.analyze), + WithoutOptimize.execute(correctAnswer5.analyze)) + } + + test("all side of nested join are Aggregates") { + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b"))) + + val correctAnswer1 = + testRelation.where(a === 1 || a === 2 || a === 3).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b"), + count(b, Some(a === 3)).as("count_b")) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(correctAnswer1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b")).join( + testRelation.where(a === 
3).groupBy()(count(b).as("count_b")))) + + val correctAnswer2 = + testRelation.where(a === 1 || (a === 2 || a === 3)).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b"), + count(b, Some(a === 3)).as("count_b")) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(correctAnswer2.analyze)) + + val originalQuery3 = + testRelation.where(a === 1).groupBy()(avg(a).as("avg_a"), sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"), sum(a).as("sum_a")).join( + testRelation.where(a === 3).groupBy()( + count(a).as("count_a"), + count(b).as("count_b"), + count(c).as("count_c")))) + + val correctAnswer3 = + testRelation.where(a === 1 || (a === 2 || a === 3)).groupBy()( + avg(a, Some(a === 1)).as("avg_a"), + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b"), + sum(a, Some(a === 2)).as("sum_a"), + count(a, Some(a === 3)).as("count_a"), + count(b, Some(a === 3)).as("count_b"), + count(c, Some(a === 3)).as("count_c")) + + comparePlans( + Optimize.execute(originalQuery3.analyze), + WithoutOptimize.execute(correctAnswer3.analyze)) + + val originalQuery4 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3).groupBy()(countDistinct(b).as("count_distinct_b"))) + + val correctAnswer4 = + testRelation.where(a === 1 || a === 2 || a === 3).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b"), + countDistinctWithFilter(a === 3, b).as("count_distinct_b")) + + comparePlans( + Optimize.execute(originalQuery4.analyze), + WithoutOptimize.execute(correctAnswer4.analyze)) + } + + test("join two side are Aggregates and aggregate expressions exist Filter clause") { + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))) + + val correctAnswer1 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some((a === 1) && (c === 1))).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b")) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(correctAnswer1.analyze)) + + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b, Some(c === 1)).as("avg_b"))) + + val correctAnswer2 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some((a === 2) && (c === 1))).as("avg_b")) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(correctAnswer2.analyze)) + + val originalQuery3 = + testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b, Some(c === 1)).as("avg_b"))) + + val correctAnswer3 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some((a === 1) && (c === 1))).as("sum_b"), + avg(b, Some((a === 2) && (c === 1))).as("avg_b")) + + comparePlans( + Optimize.execute(originalQuery3.analyze), + WithoutOptimize.execute(correctAnswer3.analyze)) + + val originalQuery4 = + testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b, Some(c === 1)).as("avg_b"))).join( + testRelation.where(a === 3).groupBy()(count(b, Some(c > 1)).as("count_b"))) + + val correctAnswer4 = + testRelation.where(a === 1 || a === 2 || a === 3).groupBy()( + sum(b, 
Some(((a === 1) || (a === 2)) && ((a === 1) && (c === 1)))).as("sum_b"), + avg(b, Some(((a === 1) || (a === 2)) && ((a === 2) && (c === 1)))).as("avg_b"), + count(b, Some((a === 3) && (c > 1))).as("count_b")) + + comparePlans( + Optimize.execute(originalQuery4.analyze), + WithoutOptimize.execute(correctAnswer4.analyze)) + + val originalQuery5 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3).groupBy()(count(b, Some(c === 1)).as("count_b"))) + + val correctAnswer5 = + testRelation.where(a === 1 || a === 2 || a === 3).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b"), + count(b, Some((a === 3) && (c === 1))).as("count_b")) + + comparePlans( + Optimize.execute(originalQuery5.analyze), + WithoutOptimize.execute(correctAnswer5.analyze)) + } + + test("upstream join could be optimized") { + Seq(LeftSemi, LeftAnti).foreach { joinType => + val originalQuery1 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b")), joinType) + + val correctAnswer1 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b")).join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b")), joinType) + + comparePlans( + Optimize.execute(originalQuery1.analyze), + WithoutOptimize.execute(correctAnswer1.analyze)) + } + + Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).foreach { joinType => + val originalQuery2 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).as("left").join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b")).as("right"), + joinType, Some($"left.sum_b" === $"right.count_b")) + + val correctAnswer2 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b")).as("left").join( + testRelation.where(a === 3).groupBy()(count(b).as("count_b")).as("right"), + joinType, Some($"left.sum_b" === $"right.count_b")) + + comparePlans( + Optimize.execute(originalQuery2.analyze), + WithoutOptimize.execute(correctAnswer2.analyze)) + } + + val originalQuery3 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3)) + + val correctAnswer3 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b")).join( + testRelation.where(a === 3)) + + comparePlans( + Optimize.execute(originalQuery3.analyze), + WithoutOptimize.execute(correctAnswer3.analyze)) + + val originalQuery4 = + testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join( + testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join( + testRelation.where(a === 3).groupBy(c)(count(b).as("count_b"))) + + val correctAnswer4 = + testRelation.where(a === 1 || a === 2).groupBy()( + sum(b, Some(a === 1)).as("sum_b"), + avg(b, Some(a === 2)).as("avg_b")).join( + testRelation.where(a === 3).groupBy(c)(count(b).as("count_b"))) + + comparePlans( + Optimize.execute(originalQuery4.analyze), + WithoutOptimize.execute(correctAnswer4.analyze)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 620ee430cab20..62bdd06e0647e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,6 +24,9 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} +import org.apache.spark.sql.catalyst.expressions.EqualTo +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, Sum} +import org.apache.spark.sql.catalyst.optimizer.PushDownPredicates import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2161,6 +2164,231 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("test CombineJoinedAggregates") { + val df = spark.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6), + Fact(20151124, 19, 45, "room1", 18.7), + Fact(20151124, 19, 25, "room2", 32.4), + Fact(20151124, 19, 26, "room1", 17.8), + Fact(20151124, 19, 26, "room2", 23.6), + Fact(20151125, 20, 15, "room1", 28.1), + Fact(20151125, 20, 25, "room2", 22.8), + Fact(20151125, 20, 36, "room1", 27.3), + Fact(20151125, 20, 46, "room2", 13.2), + Fact(20151125, 20, 59, "room2", 53.9))).toDF() + + Seq(false, true).foreach { enabled => + withSQLConf(SQLConf.COMBINE_JOINED_AGGREGATES_ENABLED.key -> enabled.toString) { + // join two side are Aggregates without Filter + val join1 = + df.agg(sum($"temp").as("sum_temp")).join( + df.agg(sum($"temp").as("sum_temp"))) + checkAnswer(join1, Row(321.79999999999995, 321.79999999999995)) + + // join two side are Aggregates with Filter + val join2 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(sum($"temp").as("sum_temp"))) + checkAnswer(join2, Row(84.0, 92.5)) + + val join3 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))) + checkAnswer(join3, Row(84.0, 23.125)) + + val join4 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"minute").as("avg_minute"))) + checkAnswer(join4, Row(84.0, 30.5)) + + val join5 = + df.where($"date" === 20151123).agg( + sum($"temp").as("sum_temp"), + count($"minute").as("count_minute")).join( + df.where($"date" === 20151124).agg(avg($"minute").as("avg_minute"))) + checkAnswer(join5, Row(84.0, 4, 30.5)) + + val join6 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg( + avg($"minute").as("avg_minute"), + count($"minute").as("count_minute"))) + checkAnswer(join6, Row(84.0, 30.5, 4)) + + // all side of nested join are Aggregates + val join7 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join7, Row(84.0, 23.125, 5)) + + val join8 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")).join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")))) 
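+        // join8 nests the second join on the right-hand side; the rule merges joins bottom-up,
+        // so when spark.sql.optimizer.combineJoinedAggregates.enabled is true all three
+        // aggregates still collapse into a single scan of df.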
+ checkAnswer(join8, Row(84.0, 23.125, 5)) + + val join9 = + df.where($"date" === 20151123).agg( + avg($"minute").as("avg_minute"), + sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg( + avg($"temp").as("avg_temp"), + sum($"minute").as("sum_minute")).join( + df.where($"date" === 20151125).agg( + count($"minute").as("count_minute"), + count($"temp").as("count_temp"), + count($"room_name").as("count_room_name")))) + checkAnswer(join9, Row(35.5, 84.0, 23.125, 122, 5, 5, 5)) + + val join10 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")), + Seq.empty, "inner").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join10, Row(84.0, 23.125, 5)) + + val join11 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")), + Seq.empty, "cross").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join11, Row(84.0, 23.125, 5)) + + val join12 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")), + Seq.empty, "left_outer").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join12, Row(84.0, 23.125, 5)) + + val join13 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")), + Seq.empty, "right_outer").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join13, Row(84.0, 23.125, 5)) + + val join14 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")), + Seq.empty, "full_outer").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join14, Row(84.0, 23.125, 5)) + + // PushLeftSemiLeftAntiThroughJoin push LEFT SEMI and LEFT ANTI through JOIN, + // So EliminateJoinByCombineAggregate can't eliminate JOIN. + val join15 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")), + Seq.empty, "left_semi") + checkAnswer(join15, Row(84.0, 23.125)) + + val join16 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")), + Seq.empty, "left_anti") + checkAnswer(join16, Seq.empty) + + // ReorderJoin push the join condition of inner like join into upstream join, + // So EliminateJoinByCombineAggregate can't eliminate JOIN. 
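+        // (CombineJoinedAggregates only matches a Join whose condition is None, so once
+        // ReorderJoin attaches the equality condition to the upstream join, nothing is merged
+        // for the inner and cross join variants below.)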
+ val join17 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).as("left").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")).as("right"), + $"left.sum_temp" === $"right.count_temp", "inner") + checkAnswer(join17, Seq.empty) + + val join18 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).as("left").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")).as("right"), + $"left.sum_temp" === $"right.count_temp", "cross") + checkAnswer(join18, Seq.empty) + + // ReorderJoin can't push the join condition of non inner like join into upstream join, + // So EliminateJoinByCombineAggregate can still eliminate JOIN. + val join19 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).as("left").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")).as("right"), + $"left.sum_temp" === $"right.count_temp", "left_outer") + checkAnswer(join19, Row(84.0, 23.125, null)) + + val join20 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).as("left").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")).as("right"), + $"left.sum_temp" === $"right.count_temp", "right_outer") + checkAnswer(join20, Row(null, null, 5)) + + val join21 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).as("left").join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp")).as("right"), + $"left.sum_temp" === $"right.count_temp", "full_outer") + checkAnswer(join21, Seq(Row(84.0, 23.125, null), Row(null, null, 5))) + + // join two side are Aggregates and aggregate expressions exist Filter clause + val sumWithFilter = Sum($"temp".expr).toAggregateExpression( + false, Some(EqualTo($"room_name".expr, lit("room1").expr))) + val join22 = + df.where($"date" === 20151123).agg( + new Column(sumWithFilter).as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))) + checkAnswer(join22, Row(36.0, 23.125)) + + val avgWithFilter = Average($"temp".expr).toAggregateExpression( + false, Some(EqualTo($"room_name".expr, lit("room1").expr))) + val join23 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))) + checkAnswer(join23, Row(84.0, 18.25)) + + val join24 = + df.where($"date" === 20151123).agg(new Column(sumWithFilter).as("sum_temp")).join( + df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))) + checkAnswer(join24, Row(36.0, 18.25)) + + val countWithFilter = Count($"temp".expr).toAggregateExpression( + false, Some(EqualTo($"room_name".expr, lit("room2").expr))) + val join25 = + df.where($"date" === 20151123).agg(new Column(sumWithFilter).as("sum_temp")).join( + df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))).join( + df.where($"date" === 20151125).agg(new Column(countWithFilter).as("count_temp"))) + checkAnswer(join25, Row(36.0, 18.25, 3)) + + val join26 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).join( + df.where($"date" === 
20151125).agg(new Column(countWithFilter).as("count_temp"))) + checkAnswer(join26, Row(84.0, 23.125, 3)) + + val join27 = + df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join( + df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))).join( + df.where($"date" === 20151125).agg(count($"temp").as("count_temp"))) + checkAnswer(join27, Row(84.0, 18.25, 5)) + + Seq(PushDownPredicates.ruleName, "").map { ruleName => + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ruleName) { + val subQuery1 = df.where($"date" === 20151123).as("tab1") + val subQuery2 = df.where($"date" === 20151124).as("tab2") + val join28 = + subQuery1.where($"tab1.minute" > 30).agg(sum($"tab1.temp").as("sum_temp")).join( + subQuery2.where($"tab2.minute" < 30).agg(avg($"tab2.temp").as("avg_temp"))) + checkAnswer(join28, Row(84.0, 24.600000000000005)) + } + } + } + } + } + private def assertAggregateOnDataframe(df: DataFrame, expected: Int, aggregateColumn: String): Unit = { val configurations = Seq(