From 03633e3915afeaa7bbd6280b6aa55946b272b4de Mon Sep 17 00:00:00 2001
From: caoxuewen
Date: Thu, 17 Aug 2017 15:12:43 +0800
Subject: [PATCH 1/2] fix a special case for non-deterministic projects in optimizer

---
 .../sql/catalyst/planning/patterns.scala      | 20 ++++++++++++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala | 22 +++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 8d034c21a496..64a86087f9b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -23,6 +23,24 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 
+/**
+ * A pattern that matches a [[Project]] whose fields are all deterministic, or a [[Project]]
+ * directly on top of a leaf node when the project references at least one attribute.
+ */
+object ProjectOperation extends PredicateHelper {
+  type ReturnType = (Seq[NamedExpression], LogicalPlan)
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+    case Project(fields, child) if fields.forall(_.deterministic) =>
+      Some((fields, child))
+
+    case p @ Project(fields, child: LeafNode) if p.references.nonEmpty =>
+      Some((fields, child))
+
+    case _ => None
+  }
+}
+
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
  * operator. All filter operators are collected and their conditions are broken up and returned
@@ -55,7 +73,7 @@ object PhysicalOperation extends PredicateHelper {
   private def collectProjectsAndFilters(plan: LogicalPlan):
       (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
     plan match {
-      case Project(fields, child) if fields.forall(_.deterministic) =>
+      case ProjectOperation(fields, child) =>
         val (_, filters, other, aliases) = collectProjectsAndFilters(child)
         val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
         (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 13341645e8ff..30a6e1fe3900 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2034,6 +2034,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-21520: the fields of Project contain non-deterministic expressions") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        val p = path.getAbsolutePath
+        Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c").write.partitionBy("a").parquet(p)
+        val df = spark.read.parquet(p)
+
+        val qe = df.select($"a", rand(10).as('rand))
+        // FileScan parquet [a#38]
+        assert(qe.queryExecution.sparkPlan.inputSet.toString.contains("a#"))
+        assert(!qe.queryExecution.sparkPlan.inputSet.toString.contains("b#"))
+        assert(!qe.queryExecution.sparkPlan.inputSet.toString.contains("c#"))
+
+        val qe2 = df.select($"a", $"b", rand(10).as('rand2))
+        // FileScan parquet [b#70,a#72]
+        assert(qe2.queryExecution.sparkPlan.inputSet.toString.contains("a#"))
+        assert(qe2.queryExecution.sparkPlan.inputSet.toString.contains("b#"))
+        assert(!qe2.queryExecution.sparkPlan.inputSet.toString.contains("c#"))
+      }
+    }
+  }
+
   test("order-by ordinal.") {
     checkAnswer(
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),

From 4ebabc7b07d5a0b3514920387c7a2405e43d9d8f Mon Sep 17 00:00:00 2001
From: caoxuewen
Date: Mon, 11 Sep 2017 18:10:59 +0800
Subject: [PATCH 2/2] a new fix

---
 .../sql/catalyst/planning/patterns.scala  | 20 +------------------
 .../datasources/FileSourceStrategy.scala  | 11 +++++++++-
 .../spark/sql/hive/HiveStrategies.scala   | 12 +++++++++--
 3 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 64a86087f9b4..8d034c21a496 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -23,24 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 
-/**
- * A pattern that matches a [[Project]] whose fields are all deterministic, or a [[Project]]
- * directly on top of a leaf node when the project references at least one attribute.
- */
-object ProjectOperation extends PredicateHelper {
-  type ReturnType = (Seq[NamedExpression], LogicalPlan)
-
-  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
-    case Project(fields, child) if fields.forall(_.deterministic) =>
-      Some((fields, child))
-
-    case p @ Project(fields, child: LeafNode) if p.references.nonEmpty =>
-      Some((fields, child))
-
-    case _ => None
-  }
-}
-
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
  * operator. All filter operators are collected and their conditions are broken up and returned
@@ -73,7 +55,7 @@ object PhysicalOperation extends PredicateHelper {
   private def collectProjectsAndFilters(plan: LogicalPlan):
       (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
     plan match {
-      case ProjectOperation(fields, child) =>
+      case Project(fields, child) if fields.forall(_.deterministic) =>
         val (_, filters, other, aliases) = collectProjectsAndFilters(child)
         val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
         (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 16b22717b8d9..e55a0f5ab761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.SparkPlan
 
@@ -51,6 +51,15 @@ import org.apache.spark.sql.execution.SparkPlan
  */
 object FileSourceStrategy extends Strategy with Logging {
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case p @ Project(fields, child)
+        if !fields.forall(_.deterministic) && p.references.nonEmpty =>
+      collectFileSource(Project(child.output.filter(p.references.contains), child))
+        .map(p => execution.ProjectExec(fields, p)).toList
+
+    case _ => collectFileSource(plan)
+  }
+
+  private def collectFileSource(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(projects, filters,
       l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index caf554d9ea51..d9abf1f3a236 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
-  ScriptTransformation}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
@@ -239,6 +238,15 @@ private[hive] trait HiveStrategies {
    */
   object HiveTableScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case p @ Project(fields, child)
+          if !fields.forall(_.deterministic) && p.references.nonEmpty =>
+        collectHiveTableSource(Project(child.output.filter(p.references.contains), child))
+          .map(p => ProjectExec(fields, p)).toList
+
+      case _ => collectHiveTableSource(plan)
+    }
+
+    private def collectHiveTableSource(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) =>
         // Filter out all predicates that only deal with partition keys, these are given to the
         // hive table scan operator to be used for partition pruning.
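
For reference, the column pruning that the second patch implements can be observed end to end from the DataFrame API. The sketch below is illustrative only, not part of either patch: it assumes a Spark build containing this change, a local SparkSession, and a made-up scratch path (/tmp/spark-21520-demo). With the fix, the printed input set of the physical plan should mention only column "a", even though the query also projects the non-deterministic rand() -- the same property the SPARK-21520 test asserts via sparkPlan.inputSet.

// Hedged sketch of the behavior targeted by SPARK-21520; object name and path are illustrative.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object Spark21520Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SPARK-21520 sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Write a small table partitioned by "a" so the scan can prune both
    // partition and data columns.
    val path = "/tmp/spark-21520-demo"  // illustrative location
    Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c")
      .write.mode("overwrite").partitionBy("a").parquet(path)

    val df = spark.read.parquet(path)

    // rand() is non-deterministic, so PhysicalOperation alone will not collapse
    // this Project into the scan; the strategy plans a pruning Project over the
    // referenced columns and wraps the scan in a ProjectExec instead.
    val q = df.select($"a", rand(10).as("r"))

    // With the fix, the FileScan should only require column "a";
    // "b" and "c" should be absent from the scan's input set.
    q.explain()
    println(q.queryExecution.sparkPlan.inputSet)

    spark.stop()
  }
}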