From f41871490da4df5d7cb5d352f5ef3795e8a6f625 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 5 Oct 2020 19:38:24 -0700
Subject: [PATCH 01/11] Avoid collapsing projects if reaching max allowed common exprs.

---
 .../sql/catalyst/optimizer/Optimizer.scala    | 23 ++++++++++--
 .../apache/spark/sql/internal/SQLConf.scala   | 15 ++++++++
 .../optimizer/CollapseProjectSuite.scala      | 35 +++++++++++++++++--
 3 files changed, 69 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f2360150e47b..11017aa27d7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -732,10 +732,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
  * `GlobalLimit(LocalLimit)` pattern is also considered.
  */
 object CollapseProject extends Rule[LogicalPlan] {
-
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, p2: Project) =>
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+      val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
+
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
+          getLargestNumOfCommonOutput(p1.projectList, p2.projectList) >= maxCommonExprs) {
         p1
       } else {
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
@@ -766,6 +768,23 @@ object CollapseProject extends Rule[LogicalPlan] {
     })
   }
 
+  // Counts for the largest times common outputs from lower operator are used in upper operators.
+  private def getLargestNumOfCommonOutput(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Int = {
+    val aliases = collectAliases(lower)
+    val exprMap = mutable.HashMap.empty[Attribute, Int]
+
+    upper.foreach(_.collect {
+      case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
+    })
+
+    if (exprMap.size > 0) {
+      exprMap.maxBy(_._2)._2
+    } else {
+      0
+    }
+  }
+
   private def haveCommonNonDeterministicOutput(
       upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
     // Create a map of Aliases to their values from the lower projection.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 18ffc655b217..ae531f8cd720 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1926,6 +1926,19 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT =
+    buildConf("spark.sql.optimizer.maxCommonExprsInCollapseProject")
+      .doc("An integer number indicates the maximum allowed number of a common expression " +
+        "can be collapsed into upper Project from lower Project by optimizer rule " +
+        "`CollapseProject`. Normally `CollapseProject` will collapse adjacent Project " +
+        "and merge expressions. But in some edge cases, expensive expressions might be " +
+        "duplicated many times in merged Project by this optimization. This config sets " +
+        "a maximum number. Once an expression is duplicated equal to or more than this number " +
+        "if merging two Project, Spark SQL will skip the merging.")
+      .version("3.1.0")
+      .intConf
+      .createWithDefault(20)
+
   val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =
     buildConf("spark.sql.decimalOperations.allowPrecisionLoss")
       .internal()
@@ -3289,6 +3302,8 @@ class SQLConf extends Serializable with Logging {
 
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
+  def maxCommonExprsInCollapseProject: Int = getConf(MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT)
+
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
 
   def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 42bcd13ee378..1e8e2c1ff6eb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Rand}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.MetadataBuilder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{MetadataBuilder, StructType}
 
 class CollapseProjectSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
@@ -170,4 +171,34 @@ class CollapseProjectSuite extends PlanTest {
     val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
     comparePlans(optimized, expected)
   }
+
+  test("SPARK-32945: avoid collapsing projects if reaching max allowed common exprs") {
+    val options = Map.empty[String, String]
+    val schema = StructType.fromDDL("a int, b int, c string, d long")
+
+    Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
+      withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
+        // If we collapse two Projects, `JsonToStructs` will be repeated three times.
+        val relation = LocalRelation('json.string)
+        val query = relation.select(
+          JsonToStructs(schema, options, 'json).as("struct"))
+          .select(
+            GetStructField('struct, 0).as("a"),
+            GetStructField('struct, 1).as("b"),
+            GetStructField('struct, 2).as("c")).analyze
+        val optimized = Optimize.execute(query)
+
+        if (maxCommonExprs.toInt <= 3) {
+          val expected = query
+          comparePlans(optimized, expected)
+        } else {
+          val expected = relation.select(
+            GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
+            GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
+            GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze
+          comparePlans(optimized, expected)
+        }
+      }
+    }
+  }
 }

From 98843dd0d86c47fce7871c56d1b94e67976d473e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 6 Oct 2020 18:56:56 -0700
Subject: [PATCH 02/11] Avoid collapsing projection lists in physical query.

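The ScanOperation pattern used during physical planning collapses adjacent
Projects on its own, so it needs the same guard as CollapseProject. As an
illustrative sketch (not part of the diff below), the plan shape this guards
against is the one from the suite test in PATCH 01; `relation`, `schema` and
`options` are assumed to be defined as in that test:

    // The lower Project computes an expensive expression once...
    val query = relation
      .select(JsonToStructs(schema, options, 'json).as("struct"))
      // ...and the upper Project references the alias three times.
      .select(
        GetStructField('struct, 0).as("a"),
        GetStructField('struct, 1).as("b"),
        GetStructField('struct, 2).as("c"))
    // Collapsed naively, JsonToStructs would be evaluated once per field.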
---
 .../sql/catalyst/planning/patterns.scala | 28 +++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 2880e87ab156..a7621b7237b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import scala.collection.mutable
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
@@ -108,6 +110,8 @@ object ScanOperation extends OperationHelper with PredicateHelper {
   type ScanReturnType = Option[(Option[Seq[NamedExpression]],
     Seq[Expression], LogicalPlan, AttributeMap[Expression])]
 
+  val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
+
   def unapply(plan: LogicalPlan): Option[ReturnType] = {
     collectProjectsAndFilters(plan) match {
       case Some((fields, filters, child, _)) =>
@@ -124,14 +128,34 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       }.exists(!_.deterministic))
   }
 
+  def equalOrMoreThanMaxAllowedCommonOutput(
+    expr: Seq[NamedExpression],
+    aliases: AttributeMap[Expression]): Boolean = {
+    val exprMap = mutable.HashMap.empty[Attribute, Int]
+
+    expr.foreach(_.collect {
+      case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
+    })
+
+    val commonOutputs = if (exprMap.size > 0) {
+      exprMap.maxBy(_._2)._2
+    } else {
+      0
+    }
+
+    commonOutputs >= maxCommonExprs
+  }
+
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
     plan match {
       case Project(fields, child) =>
         collectProjectsAndFilters(child) match {
           case Some((_, filters, other, aliases)) =>
             // Follow CollapseProject and only keep going if the collected Projects
-            // do not have common non-deterministic expressions.
-            if (!hasCommonNonDeterministic(fields, aliases)) {
+            // do not have common non-deterministic expressions, or do not have equal to/more than
+            // maximum allowed common outputs.
+            if (!hasCommonNonDeterministic(fields, aliases)
+              || !equalOrMoreThanMaxAllowedCommonOutput(fields, aliases)) {
               val substitutedFields =
                 fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
               Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))

From 1b567e7a46dbaf926e708fdd9f3228530fd2c5d9 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 6 Oct 2020 19:09:01 -0700
Subject: [PATCH 03/11] Use "more than" instead of "equal to or more than".

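A worked example of the boundary change (the counts are hypothetical, not
taken from the diff): with the config set to 3 and a lower Project's alias
referenced exactly three times above it, the old check skipped the merge
while the new one still allows it:

    val maxCommonExprs = 3
    val useCount = 3                          // alias used 3 times in the upper Project
    val skipOld = useCount >= maxCommonExprs  // true: merging was skipped
    val skipNew = useCount > maxCommonExprs   // false: merging still happens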
---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
 .../org/apache/spark/sql/catalyst/planning/patterns.scala   | 6 +++---
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala  | 2 +-
 .../spark/sql/catalyst/optimizer/CollapseProjectSuite.scala | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 11017aa27d7b..ea533c4662a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -737,7 +737,7 @@ object CollapseProject extends Rule[LogicalPlan] {
       val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
 
       if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
-          getLargestNumOfCommonOutput(p1.projectList, p2.projectList) >= maxCommonExprs) {
+          getLargestNumOfCommonOutput(p1.projectList, p2.projectList) > maxCommonExprs) {
         p1
       } else {
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index a7621b7237b1..e87f6ae0104c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -128,7 +128,7 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       }.exists(!_.deterministic))
   }
 
-  def equalOrMoreThanMaxAllowedCommonOutput(
+  def moreThanMaxAllowedCommonOutput(
     expr: Seq[NamedExpression],
     aliases: AttributeMap[Expression]): Boolean = {
     val exprMap = mutable.HashMap.empty[Attribute, Int]
@@ -143,7 +143,7 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       0
     }
 
-    commonOutputs >= maxCommonExprs
+    commonOutputs > maxCommonExprs
   }
 
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
@@ -155,7 +155,7 @@ object ScanOperation extends OperationHelper with PredicateHelper {
             // do not have common non-deterministic expressions, or do not have equal to/more than
             // maximum allowed common outputs.
             if (!hasCommonNonDeterministic(fields, aliases)
-              || !equalOrMoreThanMaxAllowedCommonOutput(fields, aliases)) {
+              || !moreThanMaxAllowedCommonOutput(fields, aliases)) {
               val substitutedFields =
                 fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
               Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ae531f8cd720..2b67256a8bf0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1933,7 +1933,7 @@ object SQLConf {
         "`CollapseProject`. Normally `CollapseProject` will collapse adjacent Project " +
         "and merge expressions. But in some edge cases, expensive expressions might be " +
         "duplicated many times in merged Project by this optimization. This config sets " +
-        "a maximum number. Once an expression is duplicated equal to or more than this number " +
+        "a maximum number. Once an expression is duplicated more than this number " +
         "if merging two Project, Spark SQL will skip the merging.")
       .version("3.1.0")
       .intConf
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 1e8e2c1ff6eb..bcac96f468e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -188,7 +188,7 @@ class CollapseProjectSuite extends PlanTest {
             GetStructField('struct, 2).as("c")).analyze
         val optimized = Optimize.execute(query)
 
-        if (maxCommonExprs.toInt <= 3) {
+        if (maxCommonExprs.toInt < 3) {
           val expected = query
           comparePlans(optimized, expected)
         } else {

From 76509b36cf2afef38a2794422cf5684f86bd1cf6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 9 Oct 2020 12:35:12 -0700
Subject: [PATCH 04/11] Update the approach.

---
 .../sql/catalyst/optimizer/Optimizer.scala | 30 ++++++++-----
 .../sql/catalyst/planning/patterns.scala   |  8 ++--
 .../optimizer/CollapseProjectSuite.scala   | 43 +++++++++++++++----
 3 files changed, 57 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ea533c4662a5..645d4e409f75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -726,22 +726,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
 /**
  * Combines two [[Project]] operators into one and perform alias substitution,
  * merging the expressions into one single expression for the following cases.
- * 1. When two [[Project]] operators are adjacent.
+ * 1. When two [[Project]] operators are adjacent, if the number of common expressions in the
+ *    combined [[Project]] is not more than `spark.sql.optimizer.maxCommonExprsInCollapseProject`.
  * 2. When two [[Project]] operators have LocalLimit/Sample/Repartition operator between them
  *    and the upper project consists of the same number of columns which is equal or aliasing.
  *    `GlobalLimit(LocalLimit)` pattern is also considered.
  */
 object CollapseProject extends Rule[LogicalPlan] {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case p1 @ Project(_, p2: Project) =>
-      val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
-
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
-          getLargestNumOfCommonOutput(p1.projectList, p2.projectList) > maxCommonExprs) {
-        p1
-      } else {
-        p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
-      }
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case p @ Project(_, _: Project) =>
+      collapseProjects(p)
     case p @ Project(_, agg: Aggregate) =>
       if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
         p
@@ -766,6 +756,20 @@ object CollapseProject extends Rule[LogicalPlan] {
       s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
   }
 
+  private def collapseProjects(plan: LogicalPlan): LogicalPlan = plan match {
+    case p1 @ Project(_, p2: Project) =>
+      val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
+
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
+          getLargestNumOfCommonOutput(p1.projectList, p2.projectList) > maxCommonExprs) {
+        p1
+      } else {
+        collapseProjects(
+          p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)))
+      }
+    case _ => plan
+  }
+
   private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
     AttributeMap(projectList.collect {
       case a: Alias => a.toAttribute -> a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index e87f6ae0104c..f74359b92598 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,8 +129,8 @@ object ScanOperation extends OperationHelper with PredicateHelper {
   }
 
   def moreThanMaxAllowedCommonOutput(
-    expr: Seq[NamedExpression],
-    aliases: AttributeMap[Expression]): Boolean = {
+      expr: Seq[NamedExpression],
+      aliases: AttributeMap[Expression]): Boolean = {
     val exprMap = mutable.HashMap.empty[Attribute, Int]
 
@@ -154,8 +154,8 @@ object ScanOperation extends OperationHelper with PredicateHelper {
             // Follow CollapseProject and only keep going if the collected Projects
             // do not have common non-deterministic expressions, or do not have equal to/more than
             // maximum allowed common outputs.
-            if (!hasCommonNonDeterministic(fields, aliases)
-              || !moreThanMaxAllowedCommonOutput(fields, aliases)) {
+            if (!hasCommonNonDeterministic(fields, aliases) ||
+                !moreThanMaxAllowedCommonOutput(fields, aliases)) {
               val substitutedFields =
                 fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
               Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index bcac96f468e6..1a57731700dd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -180,23 +180,48 @@ class CollapseProjectSuite extends PlanTest {
       withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
         // If we collapse two Projects, `JsonToStructs` will be repeated three times.
         val relation = LocalRelation('json.string)
-        val query = relation.select(
+        val query1 = relation.select(
           JsonToStructs(schema, options, 'json).as("struct"))
           .select(
             GetStructField('struct, 0).as("a"),
             GetStructField('struct, 1).as("b"),
-            GetStructField('struct, 2).as("c")).analyze
-        val optimized = Optimize.execute(query)
-
-        if (maxCommonExprs.toInt < 3) {
-          val expected = query
-          comparePlans(optimized, expected)
+            GetStructField('struct, 2).as("c"),
+            GetStructField('struct, 3).as("d")).analyze
+        val optimized1 = Optimize.execute(query1)
+
+        val query2 = relation
+          .select('json, JsonToStructs(schema, options, 'json).as("struct"))
+          .select('json, 'struct, GetStructField('struct, 0).as("a"))
+          .select('json, 'struct, 'a, GetStructField('struct, 1).as("b"))
+          .select('json, 'struct, 'a, 'b, GetStructField('struct, 2).as("c"))
+          .analyze
+        val optimized2 = Optimize.execute(query2)
+
+        if (maxCommonExprs.toInt < 4) {
+          val expected1 = query1
+          comparePlans(optimized1, expected1)
+
+          val expected2 = relation
+            .select('json, JsonToStructs(schema, options, 'json).as("struct"))
+            .select('json, 'struct,
+              GetStructField('struct, 0).as("a"),
+              GetStructField('struct, 1).as("b"),
+              GetStructField('struct, 2).as("c"))
+            .analyze
+          comparePlans(optimized2, expected2)
         } else {
-          val expected = relation.select(
+          val expected1 = relation.select(
             GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
             GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
-            GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze
-          comparePlans(optimized, expected)
+            GetStructField(JsonToStructs(schema, options, 'json), 2).as("c"),
+            GetStructField(JsonToStructs(schema, options, 'json), 3).as("d")).analyze
+          comparePlans(optimized1, expected1)
+
+          val expected2 = relation.select('json,
+            JsonToStructs(schema, options, 'json).as("struct"),
+            GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"),
+            GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"),
+            GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze
+          comparePlans(optimized2, expected2)
         }
       }
     }

From 43eb50d57eb334395c9cbe7b193de28644c62e68 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 9 Oct 2020 13:29:45 -0700
Subject: [PATCH 05/11] Add end-to-end test.

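Besides the new test, this also fixes the guard in ScanOperation: the two
Projects may only be collapsed when BOTH conditions hold. A minimal
truth-table sketch of that fix, with the guards simplified to hypothetical
plain Booleans:

    val hasCommonNonDeterministic = false   // first guard passes
    val moreThanMaxCommonOutput = true      // second guard fails
    // Before: `||` collapsed even though one guard failed.
    val collapsedBefore = !hasCommonNonDeterministic || !moreThanMaxCommonOutput  // true
    // After: `&&` collapses only when both guards pass.
    val collapsedAfter = !hasCommonNonDeterministic && !moreThanMaxCommonOutput   // false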
---
 .../sql/catalyst/planning/patterns.scala      |  8 ++-----
 .../org/apache/spark/sql/DataFrameSuite.scala | 50 ++++++++++++++++++-
 2 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index f74359b92598..2de6fbb2c846 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -110,8 +110,6 @@ object ScanOperation extends OperationHelper with PredicateHelper {
   type ScanReturnType = Option[(Option[Seq[NamedExpression]],
     Seq[Expression], LogicalPlan, AttributeMap[Expression])]
 
-  val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
-
   def unapply(plan: LogicalPlan): Option[ReturnType] = {
     collectProjectsAndFilters(plan) match {
       case Some((fields, filters, child, _)) =>
@@ -143,7 +141,7 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       0
     }
 
-    commonOutputs > maxCommonExprs
+    commonOutputs > SQLConf.get.maxCommonExprsInCollapseProject
   }
 
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
@@ -152,9 +150,9 @@ object ScanOperation extends OperationHelper with PredicateHelper {
         collectProjectsAndFilters(child) match {
           case Some((_, filters, other, aliases)) =>
             // Follow CollapseProject and only keep going if the collected Projects
-            // do not have common non-deterministic expressions, or do not have equal to/more than
+            // do not have common non-deterministic expressions, and do not have more than
             // maximum allowed common outputs.
-            if (!hasCommonNonDeterministic(fields, aliases) ||
+            if (!hasCommonNonDeterministic(fields, aliases) &&
                 !moreThanMaxAllowedCommonOutput(fields, aliases)) {
               val substitutedFields =
                 fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 321f4966178d..5308700037de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -32,12 +32,13 @@ import org.scalatest.matchers.should.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -2567,6 +2568,51 @@ class DataFrameSuite extends QueryTest
     val df = l.join(r, $"col2" === $"col4", "LeftOuter")
     checkAnswer(df, Row("2", "2"))
   }
+
+  test("SPARK-32945: Avoid collapsing projects if reaching max allowed common exprs") {
+    val options = Map.empty[String, String]
+    val schema = StructType.fromDDL("a int, b int, c long, d string")
+
+    withTable("test_table") {
+      val jsonDF = Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""").toDF("json")
+      jsonDF.write.saveAsTable("test_table")
+
+      Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
+        withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
+
+          val jsonDF = spark.read.table("test_table")
+          val jsonStruct = UnresolvedAttribute("struct")
+          val df = jsonDF
+            .select(from_json('json, schema, options).as("struct"))
+            .select(
+              Column(GetStructField(jsonStruct, 0)).as("a"),
+              Column(GetStructField(jsonStruct, 1)).as("b"),
+              Column(GetStructField(jsonStruct, 2)).as("c"),
+              Column(GetStructField(jsonStruct, 3)).as("d"))
+
+          val numProjects = df.queryExecution.executedPlan.collect {
+            case p: ProjectExec => p
+          }.size
+
+          val numFromJson = df.queryExecution.executedPlan.collect {
+            case p: ProjectExec => p.projectList.flatMap(_.collect {
+              case j: JsonToStructs => j
+            })
+          }.flatten.size
+
+          if (maxCommonExprs.toInt < 4) {
+            assert(numProjects == 2)
+            assert(numFromJson == 1)
+          } else {
+            assert(numProjects == 1)
+            assert(numFromJson == 4)
+          }
+
+          checkAnswer(df, Row(1, 2, 123L, "test"))
+        }
+      }
+    }
+  }
 }

From 4bf4dc252c663afce30848481ec3887a70f7b611 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 9 Oct 2020 17:11:53 -0700
Subject: [PATCH 06/11] Make the RewriteSubquery batch run to a fixed point.

---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 645d4e409f75..b0c3e933cc39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -216,7 +216,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) :+
-    Batch("RewriteSubquery", Once,
+    Batch("RewriteSubquery", fixedPoint,
       RewritePredicateSubquery,
       ColumnPruning,
       CollapseProject,

From 9bfafc75d66987fd550052e6b3d53efe703b55ec Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 12 Oct 2020 21:56:37 -0700
Subject: [PATCH 07/11] For review comments.

---
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 +
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala  | 5 +++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b0c3e933cc39..572554fc5a91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -216,6 +216,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) :+
+    // `CollapseProject` cannot collapse all projects in one pass. So we need `fixedPoint` here.
     Batch("RewriteSubquery", fixedPoint,
       RewritePredicateSubquery,
       ColumnPruning,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2b67256a8bf0..eb3500796023 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1928,8 +1928,8 @@ object SQLConf {
 
   val MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT =
     buildConf("spark.sql.optimizer.maxCommonExprsInCollapseProject")
-      .doc("An integer number indicates the maximum allowed number of a common expression " +
-        "can be collapsed into upper Project from lower Project by optimizer rule " +
+      .doc("An integer number indicates the maximum allowed number of common input expression " +
+        "from lower Project when being collapsed into upper Project by optimizer rule " +
         "`CollapseProject`. Normally `CollapseProject` will collapse adjacent Project " +
         "and merge expressions. But in some edge cases, expensive expressions might be " +
         "duplicated many times in merged Project by this optimization. This config sets " +
@@ -1937,6 +1937,7 @@ object SQLConf {
         "if merging two Project, Spark SQL will skip the merging.")
       .version("3.1.0")
       .intConf
+      .checkValue(m => m > 0, "maxCommonExprsInCollapseProject must be larger than zero.")
       .createWithDefault(20)
 
   val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS =

From c2c01e4a1c3d4e5e68058e1739df9323e74acf3e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 13 Oct 2020 09:18:49 -0700
Subject: [PATCH 08/11] Add more doc to config.

---
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index eb3500796023..d7c608d5bc54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1934,7 +1934,10 @@ object SQLConf {
         "and merge expressions. But in some edge cases, expensive expressions might be " +
         "duplicated many times in merged Project by this optimization. This config sets " +
         "a maximum number. Once an expression is duplicated more than this number " +
-        "if merging two Project, Spark SQL will skip the merging.")
+        "if merging two Project, Spark SQL will skip the merging. Note that normally " +
+        "in whole-stage codegen Project operator will de-duplicate expressions internally, " +
+        "but in edge cases Spark cannot do whole-stage codegen and fallback to interpreted " +
+        "mode. In such cases, users can use this config to avoid duplicate expressions")
       .version("3.1.0")
       .intConf
       .checkValue(m => m > 0, "maxCommonExprsInCollapseProject must be larger than zero.")

From 4990375fd1ed3acbd56e00817e7fd167a63031d6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 22 Oct 2020 10:51:34 -0700
Subject: [PATCH 09/11] Update for review comment.

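The counting helpers now return a Boolean directly instead of a raw count. A
standalone sketch of the underlying use-count logic (simplified types:
`String` stands in for `Attribute`, and the sample values are made up for
illustration):

    import scala.collection.mutable

    val aliases = Set("struct")                      // outputs aliased by the lower Project
    val upperRefs = Seq("struct", "struct", "json")  // references made by the upper Project
    val exprMap = mutable.HashMap.empty[String, Int]
    upperRefs.filter(aliases.contains).foreach { a =>
      exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
    }
    // Largest number of times any lower alias is reused above: here, 2.
    val maxUses = if (exprMap.nonEmpty) exprMap.maxBy(_._2)._2 else 0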
---
 .../sql/catalyst/optimizer/Optimizer.scala      | 17 ++++++++---------
 .../spark/sql/catalyst/planning/patterns.scala  |  8 +++-----
 .../org/apache/spark/sql/internal/SQLConf.scala |  8 ++++++--
 .../org/apache/spark/sql/DataFrameSuite.scala   |  8 ++++----
 4 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 572554fc5a91..0f11b326755d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -759,10 +759,8 @@ object CollapseProject extends Rule[LogicalPlan] {
 
   private def collapseProjects(plan: LogicalPlan): LogicalPlan = plan match {
     case p1 @ Project(_, p2: Project) =>
-      val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject
-
       if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
-          getLargestNumOfCommonOutput(p1.projectList, p2.projectList) > maxCommonExprs) {
+          moreThanMaxAllowedCommonOutput(p1.projectList, p2.projectList)) {
         p1
       } else {
         collapseProjects(
@@ -777,9 +775,10 @@ object CollapseProject extends Rule[LogicalPlan] {
     })
   }
 
-  // Counts for the largest times common outputs from lower operator are used in upper operators.
-  private def getLargestNumOfCommonOutput(
-      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Int = {
+  // Whether the largest times common outputs from lower operator used in upper operators is
+  // larger than allowed.
+  private def moreThanMaxAllowedCommonOutput(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
     val aliases = collectAliases(lower)
     val exprMap = mutable.HashMap.empty[Attribute, Int]
 
@@ -787,10 +786,10 @@ object CollapseProject extends Rule[LogicalPlan] {
       case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
     })
 
-    if (exprMap.size > 0) {
-      exprMap.maxBy(_._2)._2
+    if (exprMap.nonEmpty) {
+      exprMap.maxBy(_._2)._2 > SQLConf.get.maxCommonExprsInCollapseProject
     } else {
-      0
+      false
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 2de6fbb2c846..3fc1168960d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -135,13 +135,11 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1)
     })
 
-    val commonOutputs = if (exprMap.size > 0) {
-      exprMap.maxBy(_._2)._2
+    if (exprMap.nonEmpty) {
+      exprMap.maxBy(_._2)._2 > SQLConf.get.maxCommonExprsInCollapseProject
     } else {
-      0
+      false
     }
-
-    commonOutputs > SQLConf.get.maxCommonExprsInCollapseProject
   }
 
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d7c608d5bc54..ed3d612d3199 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1937,10 +1937,14 @@ object SQLConf {
         "a maximum number. Once an expression is duplicated more than this number " +
         "if merging two Project, Spark SQL will skip the merging. Note that normally " +
         "in whole-stage codegen Project operator will de-duplicate expressions internally, " +
         "but in edge cases Spark cannot do whole-stage codegen and fallback to interpreted " +
-        "mode. In such cases, users can use this config to avoid duplicate expressions")
+        "mode. In such cases, users can use this config to avoid duplicate expressions. " +
+        "Note that even users exclude `CollapseProject` rule using " +
+        "`spark.sql.optimizer.excludedRules`, at physical planning phase Spark will still " +
+        "collapse projections. This config is also effective on collapsing projections in " +
+        "the physical planning.")
       .version("3.1.0")
       .intConf
-      .checkValue(m => m > 0, "maxCommonExprsInCollapseProject must be larger than zero.")
+      .checkValue(_ > 0, "The value of maxCommonExprsInCollapseProject must be larger than zero.")
       .createWithDefault(20)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5308700037de..cccc44094c98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2574,15 +2574,15 @@ class DataFrameSuite extends QueryTest
     val options = Map.empty[String, String]
     val schema = StructType.fromDDL("a int, b int, c long, d string")
 
     withTable("test_table") {
-      val jsonDF = Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""").toDF("json")
-      jsonDF.write.saveAsTable("test_table")
+      val jsonDf = Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""").toDF("json")
+      jsonDf.write.saveAsTable("test_table")
 
       Seq("1", "2", "3", "4").foreach { maxCommonExprs =>
         withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) {
 
-          val jsonDF = spark.read.table("test_table")
+          val jsonDf = spark.read.table("test_table")
           val jsonStruct = UnresolvedAttribute("struct")
-          val df = jsonDF
+          val df = jsonDf
             .select(from_json('json, schema, options).as("struct"))

From 58e71d85fb0752a2963b9e343d321cf1c6a8e634 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 31 Oct 2020 10:49:20 -0700
Subject: [PATCH 10/11] For review comment.

---
 .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3fc1168960d1..36ccea341a8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -126,7 +126,7 @@ object ScanOperation extends OperationHelper with PredicateHelper {
       }.exists(!_.deterministic))
   }
 
-  def moreThanMaxAllowedCommonOutput(
+  private def moreThanMaxAllowedCommonOutput(
       expr: Seq[NamedExpression],
       aliases: AttributeMap[Expression]): Boolean = {
     val exprMap = mutable.HashMap.empty[Attribute, Int]

From bbaae3ed9b96ca517ec5435838ac37feb578c959 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 12 Nov 2020 16:05:40 -0800
Subject: [PATCH 11/11] Change config default value.

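With the default raised to Int.MaxValue the limit is effectively disabled
unless a user opts in, so existing plans keep collapsing exactly as before.
A hypothetical opt-in from a spark-shell session:

    // Allow an expression from a lower Project to be duplicated at most
    // three times when adjacent Projects are merged.
    spark.conf.set("spark.sql.optimizer.maxCommonExprsInCollapseProject", 3)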
--- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3b872573874b..752fdcd50688 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1982,7 +1982,7 @@ object SQLConf { .version("3.1.0") .intConf .checkValue(_ > 0, "The value of maxCommonExprsInCollapseProject must be larger than zero.") - .createWithDefault(20) + .createWithDefault(Int.MaxValue) val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = buildConf("spark.sql.decimalOperations.allowPrecisionLoss")