Commit 0e33195

[SPARK-39834][SQL][SS] Include the origin stats and constraints for LogicalRDD if it comes from DataFrame
Credit to juliuszsompolski for figuring out the issues and proposing the alternative.

### What changes were proposed in this pull request?

This PR effectively reverts SPARK-39748, but instead includes the origin stats and constraints in LogicalRDD if it comes from a DataFrame, to help the optimizer figure out a better plan.

### Why are the changes needed?

We identified several issues with [SPARK-39748](https://issues.apache.org/jira/browse/SPARK-39748):

1. One of the major use cases for DataFrame.checkpoint is ML, especially "iterative algorithms", where the purpose of calling checkpoint is to "prune" the logical plan. Including the origin logical plan defeats that purpose and risks nesting LogicalRDDs, which grows the size of the logical plan without bound.
2. We leverage the logical plan to carry over stats, but the correct stats information is in the optimized plan.
3. (Not an issue, but a missed spot) constraints are also something we can carry over.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing and new UTs.

Closes #37248 from HeartSaVioR/SPARK-39834.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
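For readers skimming the diff below, a minimal sketch of the "iterative algorithm" scenario the description refers to (the column names and update step are hypothetical, not taken from this PR): each checkpoint() truncates the accumulated lineage into a LogicalRDD, so embedding the full origin plan there, as SPARK-39748 did, would re-grow the plan on every iteration, while carrying only stats and constraints keeps it pruned.

```scala
// Hypothetical iterative-refinement loop (spark-shell style; names are made up for illustration).
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

var scores: DataFrame = spark.range(0, 1000).select($"id", lit(0.0).as("score"))
for (_ <- 1 to 10) {
  scores = scores.withColumn("score", $"score" + rand())
  // Each checkpoint() replaces the lineage with a LogicalRDD; after this PR the node
  // carries only the optimized plan's stats and constraints, not the plan itself.
  scores = scores.checkpoint()
}
scores.count()
```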
1 parent 869fc21 · commit 0e33195

File tree: 5 files changed, +141 / -115 lines


sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 2 additions & 22 deletions
@@ -46,7 +46,6 @@ import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -710,29 +709,10 @@ class Dataset[T] private[sql](
         internalRdd.doCheckpoint()
       }
 
-      // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
-      // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
-      // joins.
-      @scala.annotation.tailrec
-      def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
-        partitioning match {
-          case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
-          case p => p
-        }
-      }
-
-      val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)
-
       Dataset.ofRows(
         sparkSession,
-        LogicalRDD(
-          logicalPlan.output,
-          internalRdd,
-          Some(queryExecution.analyzed),
-          outputPartitioning,
-          physicalPlan.outputOrdering,
-          isStreaming
-        )(sparkSession)).as[T]
+        LogicalRDD.fromDataset(rdd = internalRdd, originDataset = this, isStreaming = isStreaming)
+      ).as[T]
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 63 additions & 17 deletions
@@ -18,12 +18,12 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Encoder, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -86,19 +86,24 @@ case class ExternalRDDScanExec[T](
 /**
  * Logical plan node for scanning data from an RDD of InternalRow.
  *
- * It is advised to set the field `originLogicalPlan` if the RDD is directly built from DataFrame,
- * as the stat can be inherited from `originLogicalPlan`.
+ * It is advised to set the field `originStats` and `originConstraints` if the RDD is directly
+ * built from DataFrame, so that Spark can make better optimizations.
 */
 case class LogicalRDD(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
-    originLogicalPlan: Option[LogicalPlan] = None,
     outputPartitioning: Partitioning = UnknownPartitioning(0),
    override val outputOrdering: Seq[SortOrder] = Nil,
-    override val isStreaming: Boolean = false)(session: SparkSession)
+    override val isStreaming: Boolean = false)(
+    session: SparkSession,
+    // originStats and originConstraints are intentionally placed to "second" parameter list,
+    // to prevent catalyst rules to mistakenly transform and rewrite them. Do not change this.
+    originStats: Option[Statistics] = None,
+    originConstraints: Option[ExpressionSet] = None)
   extends LeafNode with MultiInstanceRelation {
 
-  override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
+  override protected final def otherCopyArgs: Seq[AnyRef] =
+    session :: originStats :: originConstraints :: Nil
 
   override def newInstance(): LogicalRDD.this.type = {
     val rewrite = output.zip(output.map(_.newInstance())).toMap
@@ -116,37 +121,78 @@ case class LogicalRDD(
       case e: Attribute => rewrite.getOrElse(e, e)
     }.asInstanceOf[SortOrder])
 
-    val rewrittenOriginLogicalPlan = originLogicalPlan.map { plan =>
-      assert(output == plan.output, "The output columns are expected to the same for output " +
-        s"and originLogicalPlan. output: $output / output in originLogicalPlan: ${plan.output}")
+    val rewrittenStatistics = originStats.map { s =>
+      Statistics(
+        s.sizeInBytes,
+        s.rowCount,
+        AttributeMap[ColumnStat](s.attributeStats.map {
+          case (attr, v) => (rewrite.getOrElse(attr, attr), v)
+        }),
+        s.isRuntime
+      )
+    }
 
-      val projectList = output.map { attr =>
-        Alias(attr, attr.name)(exprId = rewrite(attr).exprId)
-      }
-      Project(projectList, plan)
+    val rewrittenConstraints = originConstraints.map { c =>
+      c.map(_.transform {
+        case e: Attribute => rewrite.getOrElse(e, e)
+      })
     }
 
     LogicalRDD(
       output.map(rewrite),
       rdd,
-      rewrittenOriginLogicalPlan,
      rewrittenPartitioning,
      rewrittenOrdering,
      isStreaming
-    )(session).asInstanceOf[this.type]
+    )(session, rewrittenStatistics, rewrittenConstraints).asInstanceOf[this.type]
  }
 
   override protected def stringArgs: Iterator[Any] = Iterator(output, isStreaming)
 
   override def computeStats(): Statistics = {
-    originLogicalPlan.map(_.stats).getOrElse {
+    originStats.getOrElse {
       Statistics(
         // TODO: Instead of returning a default value here, find a way to return a meaningful size
         // estimate for RDDs. See PR 1238 for more discussions.
         sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
       )
     }
   }
+
+  override lazy val constraints: ExpressionSet = originConstraints.getOrElse(ExpressionSet())
+}
+
+object LogicalRDD {
+  /**
+   * Create a new LogicalRDD based on existing Dataset. Stats and constraints are inherited from
+   * origin Dataset.
+   */
+  private[sql] def fromDataset(
+      rdd: RDD[InternalRow],
+      originDataset: Dataset[_],
+      isStreaming: Boolean): LogicalRDD = {
+    // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
+    // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
+    // joins.
+    @scala.annotation.tailrec
+    def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
+      partitioning match {
+        case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
+        case p => p
+      }
+    }
+
+    val optimizedPlan = originDataset.queryExecution.optimizedPlan
+    val executedPlan = originDataset.queryExecution.executedPlan
+
+    LogicalRDD(
+      originDataset.logicalPlan.output,
+      rdd,
+      firstLeafPartitioning(executedPlan.outputPartitioning),
+      executedPlan.outputOrdering,
+      isStreaming
+    )(originDataset.sparkSession, Some(optimizedPlan.stats), Some(optimizedPlan.constraints))
+  }
 }
 
 /** Physical plan node for scanning data from an RDD of InternalRow. */
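
The inline comment about the "second" parameter list is the load-bearing design choice here: Catalyst rules transform nodes through their primary constructor arguments, while curried arguments are only carried along via otherCopyArgs. A standalone sketch of that Scala behavior, using a made-up Node class rather than Spark's actual TreeNode:

```scala
// Standalone illustration (not Spark's TreeNode): only the first parameter list of a
// case class participates in productIterator, equals, and copy.
case class Node(children: Seq[Node], label: String)(val metadata: Option[String] = None)

val n = Node(Nil, "leaf")(Some("rowCount=2"))

// A generic "visit every constructor field" pass sees only the first list...
assert(n.productIterator.toList == List(Nil, "leaf"))

// ...so data in the curried list is never rewritten by such a pass; it has to be
// threaded through explicitly, which is what LogicalRDD.otherCopyArgs does for
// originStats and originConstraints.
println(n.metadata) // Some(rowCount=2)
```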

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala

Lines changed: 2 additions & 23 deletions
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming.sources
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.streaming.DataStreamWriter
@@ -28,33 +27,13 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr
   extends Sink {
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    val rdd = data.queryExecution.toRdd
-    val executedPlan = data.queryExecution.executedPlan
-    val analyzedPlanWithoutMarkerNode = eliminateWriteMarkerNode(data.queryExecution.analyzed)
-    // assertion on precondition
-    assert(data.logicalPlan.output == analyzedPlanWithoutMarkerNode.output)
-    val node = LogicalRDD(
-      data.logicalPlan.output,
-      rdd,
-      Some(analyzedPlanWithoutMarkerNode),
-      executedPlan.outputPartitioning,
-      executedPlan.outputOrdering)(data.sparkSession)
+    val node = LogicalRDD.fromDataset(rdd = data.queryExecution.toRdd, originDataset = data,
+      isStreaming = false)
     implicit val enc = encoder
     val ds = Dataset.ofRows(data.sparkSession, node).as[T]
     batchWriter(ds, batchId)
   }
 
-  /**
-   * ForEachBatchSink implementation reuses the logical plan of `data` which breaks the contract
-   * of Sink.addBatch, which `data` should be just used to "collect" the output data.
-   * We have to deal with eliminating marker node here which we do this in streaming specific
-   * optimization rule.
-   */
-  private def eliminateWriteMarkerNode(plan: LogicalPlan): LogicalPlan = plan match {
-    case node: WriteToMicroBatchDataSourceV1 => node.child
-    case node => node
-  }
-
   override def toString(): String = "ForeachBatchSink"
 }
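
For context, a hedged sketch of how user code reaches ForeachBatchSink.addBatch (the source, sink logic, and names below are illustrative, not from this PR): the Dataset handed to the user function is now built with LogicalRDD.fromDataset, so it carries the micro-batch's stats and constraints without re-using the streaming query's logical plan.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Hypothetical batch handler: `batch` wraps the already-computed RDD for this
// micro-batch, backed by LogicalRDD.fromDataset, so further DataFrame operations
// on it can use the origin stats and constraints.
def processBatch(batch: DataFrame, batchId: Long): Unit = {
  batch.groupBy().count().show()
}

val query = spark.readStream
  .format("rate")
  .load()
  .writeStream
  .foreachBatch(processBatch _)
  .start()

query.awaitTermination()
```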

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 73 additions & 9 deletions
@@ -31,10 +31,10 @@ import org.scalatest.matchers.should.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -2011,7 +2011,7 @@ class DataFrameSuite extends QueryTest
     }
   }
 
-  test("SPARK-39748: build the stats for LogicalRDD based on originLogicalPlan") {
+  test("SPARK-39834: build the stats for LogicalRDD based on origin stats") {
     def buildExpectedColumnStats(attrs: Seq[Attribute]): AttributeMap[ColumnStat] = {
       AttributeMap(
         attrs.map {
@@ -2040,7 +2040,8 @@ class DataFrameSuite extends QueryTest
 
     val outputList = Seq(
       AttributeReference("cbool", BooleanType)(),
-      AttributeReference("cbyte", BooleanType)()
+      AttributeReference("cbyte", ByteType)(),
+      AttributeReference("cint", IntegerType)()
     )
 
     val expectedSize = 16
@@ -2052,9 +2053,11 @@ class DataFrameSuite extends QueryTest
     withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
       val df = Dataset.ofRows(spark, statsPlan)
 
+      // We can't leverage LogicalRDD.fromDataset here, since it triggers physical planning and
+      // there is no matching physical node for OutputListAwareStatsTestPlan.
       val logicalRDD = LogicalRDD(
-        df.logicalPlan.output, spark.sparkContext.emptyRDD, Some(df.queryExecution.analyzed),
-        isStreaming = true)(spark)
+        df.logicalPlan.output, spark.sparkContext.emptyRDD[InternalRow], isStreaming = true)(
+        spark, Some(df.queryExecution.optimizedPlan.stats), None)
 
       val stats = logicalRDD.computeStats()
       val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount = Some(2),
@@ -2065,14 +2068,52 @@ class DataFrameSuite extends QueryTest
       // reflected as well.
       val newLogicalRDD = logicalRDD.newInstance()
       val newStats = newLogicalRDD.computeStats()
-      // LogicalRDD.newInstance adds projection to originLogicalPlan, which triggers estimation
-      // on sizeInBytes. We don't intend to check the estimated value.
-      val newExpectedStats = Statistics(sizeInBytes = newStats.sizeInBytes, rowCount = Some(2),
+      val newExpectedStats = Statistics(sizeInBytes = expectedSize, rowCount = Some(2),
         attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
       assert(newStats === newExpectedStats)
     }
   }
 
+  test("SPARK-39834: build the constraints for LogicalRDD based on origin constraints") {
+    def buildExpectedConstraints(attrs: Seq[Attribute]): ExpressionSet = {
+      val exprs = attrs.flatMap { attr =>
+        attr.dataType match {
+          case BooleanType => Some(EqualTo(attr, Literal(true, BooleanType)))
+          case IntegerType => Some(GreaterThan(attr, Literal(5, IntegerType)))
+          case _ => None
+        }
+      }
+      ExpressionSet(exprs)
+    }
+
+    val outputList = Seq(
+      AttributeReference("cbool", BooleanType)(),
+      AttributeReference("cbyte", ByteType)(),
+      AttributeReference("cint", IntegerType)()
+    )
+
+    val statsPlan = OutputListAwareConstraintsTestPlan(outputList = outputList)
+
+    val df = Dataset.ofRows(spark, statsPlan)
+
+    // We can't leverage LogicalRDD.fromDataset here, since it triggers physical planning and
+    // there is no matching physical node for OutputListAwareConstraintsTestPlan.
+    val logicalRDD = LogicalRDD(
+      df.logicalPlan.output, spark.sparkContext.emptyRDD[InternalRow], isStreaming = true)(
+      spark, None, Some(df.queryExecution.optimizedPlan.constraints))
+
+    val constraints = logicalRDD.constraints
+    val expectedConstraints = buildExpectedConstraints(logicalRDD.output)
+    assert(constraints === expectedConstraints)
+
+    // This method re-issues expression IDs for all outputs. We expect constraints to be
+    // reflected as well.
+    val newLogicalRDD = logicalRDD.newInstance()
+    val newConstraints = newLogicalRDD.constraints
+    val newExpectedConstraints = buildExpectedConstraints(newLogicalRDD.output)
+    assert(newConstraints === newExpectedConstraints)
+  }
+
   test("SPARK-10656: completely support special chars") {
     val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.")
     checkAnswer(df.select(df("*")), Row(1, "a"))
@@ -3356,3 +3397,26 @@ case class OutputListAwareStatsTestPlan(
   }
   override def newInstance(): LogicalPlan = copy(outputList = outputList.map(_.newInstance()))
 }
+
+/**
+ * This class is used for unit-testing. It's a logical plan whose output is passed in.
+ */
+case class OutputListAwareConstraintsTestPlan(
+    outputList: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
+  override def output: Seq[Attribute] = outputList
+
+  override lazy val constraints: ExpressionSet = {
+    val exprs = outputList.flatMap { attr =>
+      attr.dataType match {
+        case BooleanType => Some(EqualTo(attr, Literal(true, BooleanType)))
+        case IntegerType => Some(GreaterThan(attr, Literal(5, IntegerType)))
+        case _ => None
+      }
+    }
+    ExpressionSet(exprs)
+  }
+
+  override def newInstance(): LogicalPlan = copy(outputList = outputList.map(_.newInstance()))
+}
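
Beyond the unit tests above, a rough end-to-end sketch of why the carried-over constraints can matter (not part of this PR's test suite; whether the redundant filter is actually pruned depends on the optimizer rules in the running Spark version):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

// The filter establishes the constraint `id > 5` on the origin (optimized) plan.
val filtered = spark.range(0, 100).filter($"id" > 5)

// checkpoint() cuts the lineage, but the resulting LogicalRDD keeps that constraint.
val pruned = filtered.checkpoint()

// A downstream filter already implied by the carried-over constraint can be
// recognized as redundant by the optimizer (e.g. PruneFilters) instead of
// being re-evaluated against every row.
pruned.filter($"id" > 5).explain(true)
```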
