From 9650b69cbedcd6fd8d1da678c70ee813a9289fb6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Apr 2023 14:40:25 +0200 Subject: [PATCH 1/5] [SPARK-43124][SQL] Dataset.show() projects CommandResults locally --- .../sql/catalyst/analysis/Analyzer.scala | 4 +++ .../plans/logical/basicLogicalOperators.scala | 14 +++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 30 +++++++++++++------ .../spark/sql/execution/SparkStrategies.scala | 2 ++ .../execution/basicPhysicalOperators.scala | 28 +++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 17 +++++++++++ 6 files changed, 86 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8821e652a31f0..c5bb37e32b118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3839,6 +3839,10 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { val cleanedProjectList = projectList.map(trimNonTopLevelAliases) Project(cleanedProjectList, child) + case LocalProject(projectList, child) => + val cleanedProjectList = projectList.map(trimNonTopLevelAliases) + LocalProject(cleanedProjectList, child) + case Aggregate(grouping, aggs, child) => val cleanedAggs = aggs.map(trimNonTopLevelAliases) Aggregate(grouping.map(trimAliases), cleanedAggs, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 91726185090f6..c436469e7c08b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -234,6 +234,20 @@ object Project { } } +// `LocalProject` is used only in `Dataset.getRows()` to avoid any job execution due to casting +// columns to String. +case class LocalProject(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override def maxRows: Option[Long] = child.maxRows + + override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition + + override protected def withNewChildInternal(newChild: LogicalPlan): LocalProject = + copy(child = newChild) +} + /** * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 584ce19c77a28..110206a0e42c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -272,16 +272,28 @@ class Dataset[T] private[sql]( numRows: Int, truncate: Int): Seq[Seq[String]] = { val newDf = toDF() - val castCols = newDf.logicalPlan.output.map { col => - // Since binary types in top-level schema fields have a specific format to print, - // so we do not cast them to strings here. 
- if (col.dataType == BinaryType) { - Column(col) - } else { - Column(col).cast(StringType) - } + val data: Array[Row] = newDf.logicalPlan match { + case c: CommandResult => + val localProject = LocalProject(c.output.map { a => + if (a.dataType == BinaryType) { + a + } else { + Alias(Cast(a, StringType), a.name)() + } + }, c) + Dataset.ofRows(sparkSession, localProject).take(numRows + 1) + case _ => + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. + if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } + } + newDf.select(castCols: _*).take(numRows + 1) } - val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 824d18043cb9a..75ce26bba2be6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -854,6 +854,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.SortExec(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => execution.ProjectExec(projectList, planLater(child)) :: Nil + case logical.LocalProject(projectList, child) => + execution.LocalProjectExec(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.FilterExec(condition, planLater(child)) :: Nil case f: logical.TypedFilter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 68f056d894b9f..a3afed183c6fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.{LongType, StructType} @@ -118,6 +119,33 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) copy(child = newChild) } +case class LocalProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) + extends UnaryExecNode { + + @transient private lazy val project = UnsafeProjection.create(projectList, child.output) + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + override protected def withNewChildInternal(newChild: SparkPlan): LocalProjectExec = + copy(child = newChild) + + override def executeCollect(): Array[InternalRow] = { + child.executeCollect().map(project) + } + + override def executeTake(n: Int): Array[InternalRow] = { + child.executeTake(n).map(project) + } + + override def executeTail(n: Int): Array[InternalRow] = { + child.executeTail(n).map(project) + } + + override protected def 
doExecute(): RDD[InternalRow] = { + throw QueryExecutionErrors.executeCodePathUnsupportedError("LocalProjectExec") + } +} + trait GeneratePredicateHelper extends PredicateHelper { self: CodegenSupport => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4aca7c8a5a666..8e4047a709a67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -2429,6 +2429,23 @@ class DatasetSuite extends QueryTest assert(parquetFiles.size === 10) } } + + test("SPARK-43124: Show does not trigger job execution on CommandResults") { + withTable("t1") { + sql("create table t1(c int) using parquet") + + @volatile var jobCounter = 0 + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobCounter += 1 + } + } + withListener(spark.sparkContext, listener) { _ => + sql("show tables").show() + } + assert(jobCounter === 0) + } + } } class DatasetLargeResultCollectingSuite extends QueryTest From 8c014a250bf2931c9737e437b06f11a0b55fd9da Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Apr 2023 15:10:03 +0200 Subject: [PATCH 2/5] remove type --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 110206a0e42c8..c3aac3851da0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -272,7 +272,7 @@ class Dataset[T] private[sql]( numRows: Int, truncate: Int): Seq[Seq[String]] = { val newDf = toDF() - val data: Array[Row] = newDf.logicalPlan match { + val data = newDf.logicalPlan match { case c: CommandResult => val localProject = LocalProject(c.output.map { a => if (a.dataType == BinaryType) { From 4c1d505de02f974d922352ef549bfacbfad871b8 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Apr 2023 15:35:51 +0200 Subject: [PATCH 3/5] revert LocalProject --- .../sql/catalyst/analysis/Analyzer.scala | 4 --- .../plans/logical/basicLogicalOperators.scala | 14 ---------- .../spark/sql/execution/SparkStrategies.scala | 2 -- .../execution/basicPhysicalOperators.scala | 28 ------------------- 4 files changed, 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c5bb37e32b118..8821e652a31f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3839,10 +3839,6 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { val cleanedProjectList = projectList.map(trimNonTopLevelAliases) Project(cleanedProjectList, child) - case LocalProject(projectList, child) => - val cleanedProjectList = projectList.map(trimNonTopLevelAliases) - LocalProject(cleanedProjectList, child) - case Aggregate(grouping, aggs, child) => val cleanedAggs = aggs.map(trimNonTopLevelAliases) Aggregate(grouping.map(trimAliases), cleanedAggs, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c436469e7c08b..91726185090f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -234,20 +234,6 @@ object Project { } } -// `LocalProject` is used only in `Dataset.getRows()` to avoid any job execution due to casting -// columns to String. -case class LocalProject(projectList: Seq[NamedExpression], child: LogicalPlan) - extends OrderPreservingUnaryNode { - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - override def maxRows: Option[Long] = child.maxRows - - override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition - - override protected def withNewChildInternal(newChild: LogicalPlan): LocalProject = - copy(child = newChild) -} - /** * Applies a [[Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 75ce26bba2be6..824d18043cb9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -854,8 +854,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.SortExec(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => execution.ProjectExec(projectList, planLater(child)) :: Nil - case logical.LocalProject(projectList, child) => - execution.LocalProjectExec(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.FilterExec(condition, planLater(child)) :: Nil case f: logical.TypedFilter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a3afed183c6fc..68f056d894b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.{LongType, StructType} @@ -119,33 +118,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) copy(child = newChild) } -case class LocalProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { - - @transient private lazy val project = UnsafeProjection.create(projectList, child.output) - - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - override protected def withNewChildInternal(newChild: SparkPlan): LocalProjectExec = - copy(child = newChild) - - override def executeCollect(): Array[InternalRow] = { - child.executeCollect().map(project) - } - - override def executeTake(n: Int): 
Array[InternalRow] = { - child.executeTake(n).map(project) - } - - override def executeTail(n: Int): Array[InternalRow] = { - child.executeTail(n).map(project) - } - - override protected def doExecute(): RDD[InternalRow] = { - throw QueryExecutionErrors.executeCodePathUnsupportedError("LocalProjectExec") - } -} - trait GeneratePredicateHelper extends PredicateHelper { self: CodegenSupport => From 98bf071cfb016b6461545a964ec07576419a6bb6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Apr 2023 15:36:40 +0200 Subject: [PATCH 4/5] convert to LocalRelation in getRows --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c3aac3851da0c..8699124622461 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -274,14 +274,18 @@ class Dataset[T] private[sql]( val newDf = toDF() val data = newDf.logicalPlan match { case c: CommandResult => - val localProject = LocalProject(c.output.map { a => + // Do the casting locally to avoid triggering a job + val projectList = c.output.map { a => if (a.dataType == BinaryType) { a } else { Alias(Cast(a, StringType), a.name)() } - }, c) - Dataset.ofRows(sparkSession, localProject).take(numRows + 1) + } + val projection = new InterpretedMutableProjection(projectList, c.output) + val casted = LocalRelation(projectList.map(_.toAttribute), + c.rows.take(numRows + 1).map(projection(_).copy())) + Dataset.ofRows(sparkSession, casted).collect() case _ => val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, From d8210ff849727961681b7d9df71b70e19ae24d41 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 13 Apr 2023 17:44:45 +0200 Subject: [PATCH 5/5] simpify --- .../scala/org/apache/spark/sql/Dataset.scala | 40 +++++++------------ .../org/apache/spark/sql/DatasetSuite.scala | 22 +++++----- 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8699124622461..14131b261f539 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -271,33 +271,23 @@ class Dataset[T] private[sql]( private[sql] def getRows( numRows: Int, truncate: Int): Seq[Seq[String]] = { - val newDf = toDF() - val data = newDf.logicalPlan match { + val newDf = logicalPlan match { case c: CommandResult => - // Do the casting locally to avoid triggering a job - val projectList = c.output.map { a => - if (a.dataType == BinaryType) { - a - } else { - Alias(Cast(a, StringType), a.name)() - } - } - val projection = new InterpretedMutableProjection(projectList, c.output) - val casted = LocalRelation(projectList.map(_.toAttribute), - c.rows.take(numRows + 1).map(projection(_).copy())) - Dataset.ofRows(sparkSession, casted).collect() - case _ => - val castCols = newDf.logicalPlan.output.map { col => - // Since binary types in top-level schema fields have a specific format to print, - // so we do not cast them to strings here. 
- if (col.dataType == BinaryType) { - Column(col) - } else { - Column(col).cast(StringType) - } - } - newDf.select(castCols: _*).take(numRows + 1) + // Convert to `LocalRelation` and let `ConvertToLocalRelation` do the casting locally to + // avoid triggering a job + Dataset.ofRows(sparkSession, LocalRelation(c.output, c.rows)) + case _ => toDF() + } + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. + if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } } + val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8e4047a709a67..2bd1a4ac2780b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -2431,19 +2431,21 @@ class DatasetSuite extends QueryTest } test("SPARK-43124: Show does not trigger job execution on CommandResults") { - withTable("t1") { - sql("create table t1(c int) using parquet") + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") { + withTable("t1") { + sql("create table t1(c int) using parquet") - @volatile var jobCounter = 0 - val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobCounter += 1 + @volatile var jobCounter = 0 + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobCounter += 1 + } } + withListener(spark.sparkContext, listener) { _ => + sql("show tables").show() + } + assert(jobCounter === 0) } - withListener(spark.sparkContext, listener) { _ => - sql("show tables").show() - } - assert(jobCounter === 0) } } }
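
The final revision above works by wrapping the `CommandResult` rows in a `LocalRelation`, so the cast-to-string projection that `getRows()` adds is collapsed by the `ConvertToLocalRelation` optimizer rule and evaluated on the driver. The following self-contained sketch (not part of the patch series) illustrates the user-visible effect using only public APIs: with the change applied, calling `show()` on a command result should start no Spark job. The object name, table name, and the crude sleep used to let the asynchronous listener bus drain are illustrative choices, not anything taken from the patches.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.SparkSession

object ShowCommandResultLocally {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-43124 demo")
      .getOrCreate()
    try {
      // Create a table so that SHOW TABLES has at least one row to render.
      spark.sql("CREATE TABLE t1(c INT) USING parquet")

      // Count every job submitted from this point on.
      @volatile var jobCount = 0
      spark.sparkContext.addSparkListener(new SparkListener {
        override def onJobStart(jobStart: SparkListenerJobStart): Unit = jobCount += 1
      })

      // SHOW TABLES produces a CommandResult plan. With the fix, getRows() wraps its rows in a
      // LocalRelation, so the string-casting projection runs locally and no job is scheduled.
      spark.sql("SHOW TABLES").show()

      // Listener events are delivered asynchronously; a real test should wait for the listener
      // bus to drain (as the suite's withListener helper does) rather than sleep.
      Thread.sleep(1000)
      println(s"jobs started by show(): $jobCount") // expected to print 0 with the fix applied
    } finally {
      spark.stop()
    }
  }
}
```

This mirrors the SPARK-43124 regression test added in the last patch, which asserts the same job count of zero but relies on the test suite's `withListener` helper to flush listener events before the assertion.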