@@ -697,16 +697,17 @@ object ColumnPruning extends Rule[LogicalPlan] {
  * `GlobalLimit(LocalLimit)` pattern is also considered.
  */
 object CollapseProject extends Rule[LogicalPlan] {

   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, p2: Project) =>
Member:

Seems like you can simplify it like this?

  private def hasTooManyExprs(exprs: Seq[Expression]): Boolean = {
    var numExprs = 0
    exprs.foreach { _.foreach { _ => numExprs += 1 } }
    numExprs > SQLConf.get.XXXX
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case p1 @ Project(_, p2: Project) if hasTooManyExprs(p2.projectList) => // skip
      p1

    case p1 @ Project(_, p2: Project) =>

Author:

Sure, I can simplify the condition-check logic. Do you suggest adding a new SQLConf instead of the hard limit? For the case statement there is already a condition check called 'haveCommonNonDeterministicOutput', so I put them together; the same goes for 'case p @ Project(_, agg: Aggregate)'.

Member:

Yea, if we add this logic, I think we need a conf for that.
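
For context on why a bare SQLConf.get call works here: during planning the active session's SQLConf is reachable through a thread-local accessor, so the rule can read the threshold without plumbing a conf parameter through. A rough sketch, assuming the accessor that is added later in this diff:

  // Sketch only: read the threshold added by this PR via the session-scoped SQLConf.
  // SQLConf.get resolves to the active session's conf, so no extra wiring is needed.
  import org.apache.spark.sql.internal.SQLConf

  val threshold = SQLConf.get.optimizerCollapseProjectExpressionThreshold
  val disabled = threshold == -1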

-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)
+          || hasTooManyExprs(p2.projectList)) {
         p1
       } else {
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
       }
     case p @ Project(_, agg: Aggregate) =>
-      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
+      if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)
+          || hasTooManyExprs(agg.aggregateExpressions)) {
         p
       } else {
         agg.copy(aggregateExpressions = buildCleanedProjectList(
@@ -725,6 +726,14 @@ object CollapseProject extends Rule[LogicalPlan] {
     s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
   }

+  private def hasTooManyExprs(exprs: Seq[Expression]): Boolean = {
+    if (SQLConf.get.optimizerCollapseProjectExpressionThreshold == -1) false else {
+      var numExprs = 0
+      exprs.foreach { _.foreach { _ => numExprs += 1 } }
+      numExprs > SQLConf.get.optimizerCollapseProjectExpressionThreshold
+    }
+  }
+
   private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
     AttributeMap(projectList.collect {
       case a: Alias => a.toAttribute -> a
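A note on what hasTooManyExprs counts: Expression.foreach visits every node in an expression tree, not just the leaves, so the threshold applies to the total node count across the project list. A small, hypothetical illustration using Catalyst expressions (not part of the PR):

  import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal, Subtract}

  // Mirrors the counting loop in hasTooManyExprs: every tree node is visited once.
  def countExprNodes(exprs: Seq[Expression]): Int = {
    var numExprs = 0
    exprs.foreach { _.foreach { _ => numExprs += 1 } }
    numExprs
  }

  // (1 + 2) + (1 - 2) contains 7 nodes: 3 arithmetic operators and 4 literal leaves.
  val expr = Add(Add(Literal(1), Literal(2)), Subtract(Literal(1), Literal(2)))
  countExprNodes(Seq(expr))  // 7

Nested arithmetic that survives a collapse therefore inflates the count quickly, which is the situation the default threshold of 1000 is meant to catch.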
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf

 trait OperationHelper {
   type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
@@ -122,6 +123,14 @@ object ScanOperation extends OperationHelper with PredicateHelper {
     }.exists(!_.deterministic))
   }

+  private def hasTooManyExprs(exprs: Seq[Expression]): Boolean = {
+    if (SQLConf.get.optimizerCollapseProjectExpressionThreshold == -1) false else {
+      var numExprs = 0
+      exprs.foreach { _.foreach { _ => numExprs += 1 } }
+      numExprs > SQLConf.get.optimizerCollapseProjectExpressionThreshold
+    }
+  }
+
   private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
     plan match {
       case Project(fields, child) =>
@@ -132,7 +141,9 @@
             if (!hasCommonNonDeterministic(fields, aliases)) {
               val substitutedFields =
                 fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
-              Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
+              if (hasTooManyExprs(substitutedFields)) None else {
+                Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
+              }
             } else {
               None
             }
@@ -245,6 +245,15 @@ object SQLConf {
     .stringConf
     .createOptional

+  val OPTIMIZER_COLLAPSE_PROJECT_EXPRESSION_THRESHOLD =
+    buildConf("spark.sql.optimizer.collapseProjectExpressionThreshold")
+      .internal()
+      .doc("Sets a threshold on the number of expressions when collapsing projects; if the " +
+        "current project has more expressions than the threshold, the projects won't be " +
+        "collapsed. Set to -1 to disable.")
+      .intConf
+      .createWithDefault(1000)
+
   val DYNAMIC_PARTITION_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.dynamicPartitionPruning.enabled")
       .doc("When true, we will generate predicate for partition column when it's used as join key")
@@ -2780,6 +2789,9 @@ class SQLConf extends Serializable with Logging {

   def optimizerPlanChangeBatches: Option[String] = getConf(OPTIMIZER_PLAN_CHANGE_LOG_BATCHES)

+  def optimizerCollapseProjectExpressionThreshold: Int =
+    getConf(OPTIMIZER_COLLAPSE_PROJECT_EXPRESSION_THRESHOLD)
+
   def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)

   def dynamicPartitionPruningUseStats: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_USE_STATS)
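For reference, a sketch (not part of the PR) of how the new conf could be exercised from a session once this change is in; the default of 1000 comes from the definition above:

  // Tighten the guard, or set -1 to restore the old always-collapse behaviour.
  spark.conf.set("spark.sql.optimizer.collapseProjectExpressionThreshold", "100")
  spark.conf.set("spark.sql.optimizer.collapseProjectExpressionThreshold", "-1")

Both the logical rule (CollapseProject) and the physical-planning pattern (ScanOperation) read the same value, which is what keeps the ProjectExec blow-up discussed in the review below in check as well.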
@@ -121,6 +121,17 @@ class CollapseProjectSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

+  test("do not collapse project if number of leaf expressions would be too large") {
+    var query: LogicalPlan = testRelation
+    for (_ <- 1 to 10) {
+      // after n iterations the number of leaf expressions will be 2^{n+1}
+      // => after 10 iterations we would end up with more than 1000 leaf expressions
+      query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
maropu (Member), Jul 21, 2020:

The same issue can happen in ProjectExec even if the CollapseProject issue is fixed?

      case Project(fields, child) =>
        collectProjectsAndFilters(child) match {
          case Some((_, filters, other, aliases)) =>
            // Follow CollapseProject and only keep going if the collected Projects
            // do not have common non-deterministic expressions.
            if (!hasCommonNonDeterministic(fields, aliases)) {
              val substitutedFields =
                fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
              Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))

scala> Seq((1, 2)).toDF("a", "b").write.saveAsTable("a")
scala> var query = spark.table("a")
scala> for( _ <- 1 to 10) {
     |   query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
     | }

scala> query.explain(true)
== Parsed Logical Plan ==
...

== Analyzed Logical Plan ==
...

== Optimized Logical Plan ==
Project [(a#49 + b#50) AS a#53, (a#49 - b#50) AS b#54]
+- Project [(a#45 + b#46) AS a#49, (a#45 - b#46) AS b#50]
   +- Project [(a#41 + b#42) AS a#45, (a#41 - b#42) AS b#46]
      +- Project [(a#37 + b#38) AS a#41, (a#37 - b#38) AS b#42]
         +- Project [(a#33 + b#34) AS a#37, (a#33 - b#34) AS b#38]
            +- Project [(a#29 + b#30) AS a#33, (a#29 - b#30) AS b#34]
               +- Project [(a#25 + b#26) AS a#29, (a#25 - b#26) AS b#30]
                  +- Project [(a#21 + b#22) AS a#25, (a#21 - b#22) AS b#26]
                     +- Project [(a#17 + b#18) AS a#21, (a#17 - b#18) AS b#22]
                        +- Project [(a#13 + b#14) AS a#17, (a#13 - b#14) AS b#18]
                           +- Relation[a#13,b#14] parquet

== Physical Plan ==
*(1) Project [((((((((((a#13 + b#14) AS a#17 + ...
  // too many expressions...

+- *(1) ColumnarToRow
   +- FileScan parquet default.a[a#13,b#14] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/a], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:int,b:int>

Author:

Thanks for the reminder; my test case is:

var query = spark.range(5).withColumn("new_column", 'id + 5 as "plus5").toDF("a","b")
for( a <- 1 to 10) {query = query.select(('a + 'b).as('a), ('a - 'b).as('b))}
query.explain(true)

And it works for both Optimized Logical Plan and Physical Plan.

I notice the difference is that my data type is bigint (org.apache.spark.sql.DataFrame = [a: bigint, b: bigint]); it seems the projects will not collapse in that case.

I tested the case above and the problem exists for the Physical Plan, so should we also add a check for that?

+    }
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size === 2) // should be collapsed to two projects
+  }

test("preserve top-level alias metadata while collapsing projects") {
def hasMetadata(logicalPlan: LogicalPlan): Boolean = {
logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key"))
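As a sanity check on the comment inside the new test (a hypothetical back-of-the-envelope calculation, not part of the PR): each collapsed iteration rewrites both columns in terms of both previous columns, so the leaf count per column doubles every round.

  // Leaf-count recurrence for the test's query: one leaf per column to start,
  // doubling on every collapsed select, two columns in total.
  val leavesPerColumn = (1 to 10).foldLeft(1L)((leaves, _) => leaves * 2)  // 2^10 = 1024
  val totalLeaves = 2 * leavesPerColumn                                    // 2^11 = 2048

2048 exceeds the default threshold of 1000 (and hasTooManyExprs counts operator nodes too, so the real count is even higher), which is why the test expects the chain not to collapse into a single Project.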