sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
-import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BLOOM_FILTER, OUTER_REFERENCE, TreePattern}
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.BloomFilter

@@ -47,6 +47,8 @@ case class BloomFilterMightContain(
override def right: Expression = valueExpression
override def prettyName: String = "might_contain"

+final override val nodePatterns: Seq[TreePattern] = Seq(BLOOM_FILTER)
+
override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
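These one-line nodePatterns overrides (repeated for each expression touched below) are what make pattern-based pruning work: Catalyst caches the union of every subtree's patterns in a bit set, so a rule can check for, say, a Bloom filter probe without walking the tree. A minimal consumer-side sketch, assuming only the existing containsAnyPattern API (freeOfTaggedNodes is a made-up name):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_CONTAINS, BLOOM_FILTER}

// True when no node in the expression tree is tagged with either pattern;
// the check reads a cached bit set instead of traversing the tree.
def freeOfTaggedNodes(e: Expression): Boolean =
  !e.containsAnyPattern(BLOOM_FILTER, ARRAY_CONTAINS)

isCheapPredicate in the new CombineJoinedAggregates rule below is exactly this kind of consumer.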
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAY_CONTAINS, ARRAYS_OVERLAP, ARRAYS_ZIP, CONCAT, TreePattern}
import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -1428,6 +1428,8 @@ case class ArrayContains(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Predicate
with QueryErrorsBase {

+final override val nodePatterns: Seq[TreePattern] = Seq(ARRAY_CONTAINS)
+
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)

@@ -1651,6 +1653,8 @@ case class ArrayAppend(left: Expression, right: Expression) extends ArrayPendBase
case class ArraysOverlap(left: Expression, right: Expression)
extends BinaryArrayExpressionWithImplicitCast with NullIntolerant with Predicate {

+final override val nodePatterns: Seq[TreePattern] = Seq(ARRAYS_OVERLAP)
+
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(elementType, prettyName)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AT_LEAST_N_NON_NULLS, COALESCE, NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -429,6 +429,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)

+final override val nodePatterns: Seq[TreePattern] = Seq(AT_LEAST_N_NON_NULLS)
+
private[this] val childrenArray = children.toArray

override def eval(input: InternalRow): Boolean = {
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -466,6 +466,8 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpression
override def toString: String = s"RLIKE($left, $right)"
override def sql: String = s"${prettyName.toUpperCase(Locale.ROOT)}(${left.sql}, ${right.sql})"

+final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{STRING_PREDICATE, TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -508,6 +508,8 @@ abstract class StringPredicate extends BinaryExpression
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

+final override val nodePatterns: Seq[TreePattern] = Seq(STRING_PREDICATE)
+
protected override def nullSafeEval(input1: Any, input2: Any): Any =
compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala (new file)
@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, NamedExpression, Or}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, LogicalPlan, Project, SerializeFromObject}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, ARRAY_CONTAINS, ARRAYS_OVERLAP, AT_LEAST_N_NON_NULLS, BLOOM_FILTER, DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY, HIGH_ORDER_FUNCTION, IN, IN_SUBQUERY, INSET, INVOKE, JOIN, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF, STRING_PREDICATE}

/**
* This rule eliminates a [[Join]] whose sides are all [[Aggregate]]s, by combining these
* [[Aggregate]]s. Nested [[Join]]s are also supported, as long as all the sides of
* every [[Join]] are [[Aggregate]]s.
*
* Note: this rule doesn't support the following cases:
* 1. The [[Aggregate]]s are not merged if at least one of them has no predicate, or has a
*    predicate with low selectivity.
* 2. The [[Aggregate]]s are not merged if any of their upstream nodes contains a [[Join]].
*/
object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {

private def isSupportedJoinType(joinType: JoinType): Boolean =
Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType)

private def isCheapPredicate(e: Expression): Boolean = {
!e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, DYNAMIC_PRUNING_SUBQUERY, DYNAMIC_PRUNING_EXPRESSION,
HIGH_ORDER_FUNCTION, IN_SUBQUERY, IN, INSET, EXISTS_SUBQUERY, STRING_PREDICATE,
AT_LEAST_N_NON_NULLS, BLOOM_FILTER, ARRAY_CONTAINS, ARRAYS_OVERLAP) &&
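// Editorial assumption on the size check below: TreeNode.apply(n) returns the
// n-th node of a preorder traversal, or null when the tree has fewer nodes, so
// an empty Option means the predicate tree stays below maxTreeNodeNumOfPredicate.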
Option(e.apply(conf.maxTreeNodeNumOfPredicate)).isEmpty
}

/**
* Try to merge two `Aggregate`s by traversing down recursively.
*
* @return The optional tuple as follows:
* 1. the merged plan
* 2. the attribute mapping from the old to the merged version
* 3. optional filters of both plans that need to be propagated and merged in an
* ancestor `Aggregate` node if possible.
*/
[Review comment from a Contributor: "the return type is a bit complicated, let's add comment to explain"]
private def mergePlan(
left: LogicalPlan,
right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[Expression])] = {
(left, right) match {
case (la: Aggregate, ra: Aggregate) =>
mergePlan(la.child, ra.child).map { case (newChild, outputMap, filters) =>
val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap))

val mergedAggregateExprs = if (filters.length == 2) {
Seq(
(la.aggregateExpressions, filters.head),
(rightAggregateExprs, filters.last)
).flatMap { case (aggregateExpressions, propagatedFilter) =>
aggregateExpressions.map { ne =>
ne.transform {
case ae @ AggregateExpression(_, _, _, filterOpt, _) =>
val newFilter = filterOpt.map { filter =>
And(propagatedFilter, filter)
}.orElse(Some(propagatedFilter))
ae.copy(filter = newFilter)
}.asInstanceOf[NamedExpression]
}
}
} else {
la.aggregateExpressions ++ rightAggregateExprs
}

(Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty)
}
case (lp: Project, rp: Project) =>
[Review comment from a Contributor: "do you have tests that hit this branch? Ideally Aggregate and Project will be merged."]
[Reply from the author @beliefer, Aug 2, 2023: "Yes. This test case has already been covered. Please see test("join side is not Aggregate")"]

val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*)

mergePlan(lp.child, rp.child).map { case (newChild, outputMap, filters) =>
val allFilterReferences = filters.flatMap(_.references)
val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne =>
val mapped = mapAttributes(ne, outputMap)

val withoutAlias = mapped match {
case Alias(child, _) => child
case e => e
}

val outputAttr = mergedProjectList.find {
case Alias(child, _) => child semanticEquals withoutAlias
case e => e semanticEquals withoutAlias
}.getOrElse {
mergedProjectList += mapped
mapped
}.toAttribute
ne.toAttribute -> outputAttr
})

(Project(mergedProjectList.toSeq, newChild), newOutputMap, filters)
}
case (lf: Filter, rf: Filter)
if isCheapPredicate(lf.condition) && isCheapPredicate(rf.condition) =>
mergePlan(lf.child, rf.child).map {
case (newChild, outputMap, filters) =>
val mappedRightCondition = mapAttributes(rf.condition, outputMap)
val (newLeftCondition, newRightCondition) = if (filters.length == 2) {
(And(lf.condition, filters.head), And(mappedRightCondition, filters.last))
} else {
(lf.condition, mappedRightCondition)
}
val newCondition = Or(newLeftCondition, newRightCondition)

(Filter(newCondition, newChild), outputMap, Seq(newLeftCondition, newRightCondition))
}
case (ll: LeafNode, rl: LeafNode) =>
checkIdenticalPlans(rl, ll).map { outputMap =>
(ll, outputMap, Seq.empty)
}
case (ls: SerializeFromObject, rs: SerializeFromObject) =>
checkIdenticalPlans(rs, ls).map { outputMap =>
(ls, outputMap, Seq.empty)
}
case _ => None
}
}

def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.combineJoinedAggregatesEnabled) return plan

plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) {
case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _)
if isSupportedJoinType(joinType) &&
left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty =>
val mergedAggregate = mergePlan(left, right)
mergedAggregate.map(_._1).getOrElse(j)
}
}
}
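For intuition, here is a rough before/after sketch of the rewrite this rule performs (an illustration, not output from the PR: relation t and columns a, b, c1, c2 are made up, and the plan shapes are simplified):

Before:
Join Inner
:- Aggregate [sum(a) AS sum_a]
:  +- Filter (c1 = 1)
:     +- Relation t
+- Aggregate [avg(b) AS avg_b]
   +- Filter (c2 = 2)
      +- Relation t

After:
Aggregate [sum(a) FILTER (WHERE c1 = 1) AS sum_a, avg(b) FILTER (WHERE c2 = 2) AS avg_b]
+- Filter ((c1 = 1) OR (c2 = 2))
   +- Relation t

Both sides are ungrouped, single-row aggregates, so the join can be replaced by one Aggregate over the shared scan: each side's Filter condition survives as a per-function FILTER clause, and the merged Filter node keeps the disjunction of both conditions, which is exactly what the Filter and Aggregate cases of mergePlan construct.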
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
@@ -101,7 +101,7 @@ import org.apache.spark.sql.types.DataType
* : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
* +- *(1) Scan OneRowRelation[]
*/
-object MergeScalarSubqueries extends Rule[LogicalPlan] {
+object MergeScalarSubqueries extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
def apply(plan: LogicalPlan): LogicalPlan = {
plan match {
// Subquery reuse needs to be enabled for this optimization.
@@ -212,17 +212,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}

-// If 2 plans are identical return the attribute mapping from the new to the cached version.
-private def checkIdenticalPlans(
-newPlan: LogicalPlan,
-cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
-if (newPlan.canonicalized == cachedPlan.canonicalized) {
-Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
-} else {
-None
-}
-}

// Recursively traverse down and try merging 2 plans. If merge is possible then return the merged
// plan with the attribute mapping from the new to the merged version.
// Please note that merging arbitrary plans can be complicated, the current version supports only
@@ -314,12 +303,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
plan)
}

-private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = {
-expr.transform {
-case a: Attribute => outputMap.getOrElse(a, a)
-}.asInstanceOf[T]
-}

// Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into
// `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to
// the merged version that can be propagated up during merging nodes.
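(The two methods removed above reappear, slightly generalized, in the new MergeScalarSubqueriesHelper trait shown below, so both MergeScalarSubqueries and CombineJoinedAggregates can share them.)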
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala (new file)
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* The helper trait used to merge scalar subqueries.
*/
trait MergeScalarSubqueriesHelper {
[Comment from the author: "@peter-toth I created this trait to share the common functions."]

// If 2 plans are identical, return the attribute mapping from the left to the right.
protected def checkIdenticalPlans(
left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = {
if (left.canonicalized == right.canonicalized) {
Some(AttributeMap(left.output.zip(right.output)))
} else {
None
}
}

protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = {
expr.transform {
case a: Attribute => outputMap.getOrElse(a, a)
}.asInstanceOf[T]
}
}
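A minimal usage sketch of the trait (PlanMergerSketch and rewriteIfIdentical are made-up names; the PR's real consumers are MergeScalarSubqueries above and CombineJoinedAggregates):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical consumer, shown only to illustrate how the helpers compose.
object PlanMergerSketch extends MergeScalarSubqueriesHelper {
  def rewriteIfIdentical(
      newPlan: LogicalPlan,
      cachedPlan: LogicalPlan,
      expr: Expression): Option[Expression] = {
    // checkIdenticalPlans yields a newPlan-output -> cachedPlan-output map;
    // mapAttributes then rewrites expr into cachedPlan's attribute space.
    checkIdenticalPlans(newPlan, cachedPlan).map(mapAttributes(expr, _))
  }
}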
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -97,6 +97,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateOffsets,
EliminateLimits,
CombineUnions,
+CombineJoinedAggregates,
// Constant folding and strength reduction
OptimizeRepartition,
EliminateWindowPartitions,
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -112,6 +112,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.ColumnPruning" ::
"org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
"org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
"org.apache.spark.sql.catalyst.optimizer.CombineJoinedAggregates" ::
"org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
"org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
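Registering the rule here is what makes the ruleId argument of transformUpWithPruning in CombineJoinedAggregates.apply resolvable; the id lets the transform skip subtrees already marked as ineffective for this rule.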
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePattern.scala
@@ -26,13 +26,17 @@ object TreePattern extends Enumeration {
val AGGREGATE_EXPRESSION = Value(0)
val ALIAS: Value = Value
val AND: Value = Value
+val ARRAY_CONTAINS: Value = Value
+val ARRAYS_OVERLAP: Value = Value
val ARRAYS_ZIP: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
val APPEND_COLUMNS: Value = Value
val AVERAGE: Value = Value
+val AT_LEAST_N_NON_NULLS: Value = Value
val GROUPING_ANALYTICS: Value = Value
val BINARY_ARITHMETIC: Value = Value
val BINARY_COMPARISON: Value = Value
+val BLOOM_FILTER: Value = Value
val CASE_WHEN: Value = Value
val CAST: Value = Value
val COALESCE: Value = Value
@@ -88,6 +92,7 @@ object TreePattern extends Enumeration {
val SCALA_UDF: Value = Value
val SESSION_WINDOW: Value = Value
val SORT: Value = Value
+val STRING_PREDICATE: Value = Value
val SUBQUERY_ALIAS: Value = Value
val SUM: Value = Value
val TIME_WINDOW: Value = Value