Skip to content

Commit b803b66

Browse files
wangzhenhua authored and gatorsmile committed
[SPARK-21180][SQL] Remove conf from stats functions since now we have conf in LogicalPlan
## What changes were proposed in this pull request?

After wiring `SQLConf` in logical plan ([PR 18299](#18299)), we no longer need to pass `conf` into `def stats` and `def computeStats`.

## How was this patch tested?

Covered by existing tests, plus some modified existing tests.

Author: wangzhenhua <[email protected]>
Author: Zhenhua Wang <[email protected]>

Closes #18391 from wzhfy/removeConf.
1 parent 07479b3 commit b803b66

38 files changed

+178
-173
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
3131
import org.apache.spark.sql.catalyst.plans.logical._
3232
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
3333
import org.apache.spark.sql.catalyst.util.quoteIdentifier
34-
import org.apache.spark.sql.internal.SQLConf
3534
import org.apache.spark.sql.types.StructType
3635

3736

@@ -436,7 +435,7 @@ case class CatalogRelation(
436435
createTime = -1
437436
))
438437

439-
override def computeStats(conf: SQLConf): Statistics = {
438+
override def computeStats: Statistics = {
440439
// For data source tables, we will create a `LogicalRelation` and won't call this method, for
441440
// hive serde tables, we will always generate statistics.
442441
// TODO: unify the table stats generation.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
5858
// Do reordering if the number of items is appropriate and join conditions exist.
5959
// We also need to check if costs of all items can be evaluated.
6060
if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
61-
items.forall(_.stats(conf).rowCount.isDefined)) {
61+
items.forall(_.stats.rowCount.isDefined)) {
6262
JoinReorderDP.search(conf, items, conditions, output)
6363
} else {
6464
plan
@@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
322322
/** Get the cost of the root node of this plan tree. */
323323
def rootCost(conf: SQLConf): Cost = {
324324
if (itemIds.size > 1) {
325-
val rootStats = plan.stats(conf)
325+
val rootStats = plan.stats
326326
Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
327327
} else {
328328
// If the plan is a leaf item, it has zero cost.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
317317
case FullOuter =>
318318
(left.maxRows, right.maxRows) match {
319319
case (None, None) =>
320-
if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
320+
if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
321321
join.copy(left = maybePushLimit(exp, left))
322322
} else {
323323
join.copy(right = maybePushLimit(exp, right))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
8282
// Find if the input plans are eligible for star join detection.
8383
// An eligible plan is a base table access with valid statistics.
8484
val foundEligibleJoin = input.forall {
85-
case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
85+
case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
8686
case _ => false
8787
}
8888

@@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
181181
val leafCol = findLeafNodeCol(column, plan)
182182
leafCol match {
183183
case Some(col) if t.outputSet.contains(col) =>
184-
val stats = t.stats(conf)
184+
val stats = t.stats
185185
stats.rowCount match {
186186
case Some(rowCount) if rowCount >= 0 =>
187187
if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
@@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
237237
val leafCol = findLeafNodeCol(column, plan)
238238
leafCol match {
239239
case Some(col) if t.outputSet.contains(col) =>
240-
val stats = t.stats(conf)
240+
val stats = t.stats
241241
stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
242242
case None => false
243243
}
@@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
296296
*/
297297
private def getTableAccessCardinality(
298298
input: LogicalPlan): Option[BigInt] = input match {
299-
case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
300-
if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
301-
Option(input.stats(conf).rowCount.get)
299+
case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
300+
if (conf.cboEnabled && input.stats.rowCount.isDefined) {
301+
Option(input.stats.rowCount.get)
302302
} else {
303-
Option(t.stats(conf).rowCount.get)
303+
Option(t.stats.rowCount.get)
304304
}
305305
case _ => None
306306
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
2121
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
2222
import org.apache.spark.sql.catalyst.analysis
2323
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
24-
import org.apache.spark.sql.internal.SQLConf
2524
import org.apache.spark.sql.types.{StructField, StructType}
2625

2726
object LocalRelation {
@@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
6766
}
6867
}
6968

70-
override def computeStats(conf: SQLConf): Statistics =
69+
override def computeStats: Statistics =
7170
Statistics(sizeInBytes =
7271
output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
7372

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.QueryPlan
2525
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
26-
import org.apache.spark.sql.internal.SQLConf
2726
import org.apache.spark.sql.types.StructType
2827

2928

@@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
9089
* first time. If the configuration changes, the cache can be invalidated by calling
9190
* [[invalidateStatsCache()]].
9291
*/
93-
final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
94-
statsCache = Some(computeStats(conf))
92+
final def stats: Statistics = statsCache.getOrElse {
93+
statsCache = Some(computeStats)
9594
statsCache.get
9695
}
9796

@@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
108107
*
109108
* [[LeafNode]]s must override this.
110109
*/
111-
protected def computeStats(conf: SQLConf): Statistics = {
110+
protected def computeStats: Statistics = {
112111
if (children.isEmpty) {
113112
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
114113
}
115-
Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
114+
Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
116115
}
117116

118117
override def verboseStringWithSuffix: String = {
@@ -333,21 +332,21 @@ abstract class UnaryNode extends LogicalPlan {
333332

334333
override protected def validConstraints: Set[Expression] = child.constraints
335334

336-
override def computeStats(conf: SQLConf): Statistics = {
335+
override def computeStats: Statistics = {
337336
// There should be some overhead in Row object, the size should not be zero when there is
338337
// no columns; this helps to prevent a divide-by-zero error.
339338
val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
340339
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
341340
// Assume there will be the same number of rows as child has.
342-
var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
341+
var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
343342
if (sizeInBytes == 0) {
344343
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
345344
// (product of children).
346345
sizeInBytes = 1
347346
}
348347

349348
// Don't propagate rowCount and attributeStats, since they are not estimated here.
350-
Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
349+
Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
351350
}
352351
}
353352

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 32 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2424
import org.apache.spark.sql.catalyst.plans._
2525
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
26-
import org.apache.spark.sql.internal.SQLConf
2726
import org.apache.spark.sql.types._
2827
import org.apache.spark.util.Utils
2928
import org.apache.spark.util.random.RandomSampler
@@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
6564
override def validConstraints: Set[Expression] =
6665
child.constraints.union(getAliasedConstraints(projectList))
6766

68-
override def computeStats(conf: SQLConf): Statistics = {
67+
override def computeStats: Statistics = {
6968
if (conf.cboEnabled) {
70-
ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
69+
ProjectEstimation.estimate(this).getOrElse(super.computeStats)
7170
} else {
72-
super.computeStats(conf)
71+
super.computeStats
7372
}
7473
}
7574
}
@@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
139138
child.constraints.union(predicates.toSet)
140139
}
141140

142-
override def computeStats(conf: SQLConf): Statistics = {
141+
override def computeStats: Statistics = {
143142
if (conf.cboEnabled) {
144-
FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
143+
FilterEstimation(this).estimate.getOrElse(super.computeStats)
145144
} else {
146-
super.computeStats(conf)
145+
super.computeStats
147146
}
148147
}
149148
}
@@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
192191
}
193192
}
194193

195-
override def computeStats(conf: SQLConf): Statistics = {
196-
val leftSize = left.stats(conf).sizeInBytes
197-
val rightSize = right.stats(conf).sizeInBytes
194+
override def computeStats: Statistics = {
195+
val leftSize = left.stats.sizeInBytes
196+
val rightSize = right.stats.sizeInBytes
198197
val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
199198
Statistics(
200199
sizeInBytes = sizeInBytes,
201-
hints = left.stats(conf).hints.resetForJoin())
200+
hints = left.stats.hints.resetForJoin())
202201
}
203202
}
204203

@@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
209208

210209
override protected def validConstraints: Set[Expression] = leftConstraints
211210

212-
override def computeStats(conf: SQLConf): Statistics = {
213-
left.stats(conf).copy()
211+
override def computeStats: Statistics = {
212+
left.stats.copy()
214213
}
215214
}
216215

@@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
248247
children.length > 1 && childrenResolved && allChildrenCompatible
249248
}
250249

251-
override def computeStats(conf: SQLConf): Statistics = {
252-
val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
250+
override def computeStats: Statistics = {
251+
val sizeInBytes = children.map(_.stats.sizeInBytes).sum
253252
Statistics(sizeInBytes = sizeInBytes)
254253
}
255254

@@ -357,20 +356,20 @@ case class Join(
357356
case _ => resolvedExceptNatural
358357
}
359358

360-
override def computeStats(conf: SQLConf): Statistics = {
359+
override def computeStats: Statistics = {
361360
def simpleEstimation: Statistics = joinType match {
362361
case LeftAnti | LeftSemi =>
363362
// LeftSemi and LeftAnti won't ever be bigger than left
364-
left.stats(conf)
363+
left.stats
365364
case _ =>
366365
// Make sure we don't propagate isBroadcastable in other joins, because
367366
// they could explode the size.
368-
val stats = super.computeStats(conf)
367+
val stats = super.computeStats
369368
stats.copy(hints = stats.hints.resetForJoin())
370369
}
371370

372371
if (conf.cboEnabled) {
373-
JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
372+
JoinEstimation.estimate(this).getOrElse(simpleEstimation)
374373
} else {
375374
simpleEstimation
376375
}
@@ -523,7 +522,7 @@ case class Range(
523522

524523
override def newInstance(): Range = copy(output = output.map(_.newInstance()))
525524

526-
override def computeStats(conf: SQLConf): Statistics = {
525+
override def computeStats: Statistics = {
527526
val sizeInBytes = LongType.defaultSize * numElements
528527
Statistics( sizeInBytes = sizeInBytes )
529528
}
@@ -556,20 +555,20 @@ case class Aggregate(
556555
child.constraints.union(getAliasedConstraints(nonAgg))
557556
}
558557

559-
override def computeStats(conf: SQLConf): Statistics = {
558+
override def computeStats: Statistics = {
560559
def simpleEstimation: Statistics = {
561560
if (groupingExpressions.isEmpty) {
562561
Statistics(
563562
sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
564563
rowCount = Some(1),
565-
hints = child.stats(conf).hints)
564+
hints = child.stats.hints)
566565
} else {
567-
super.computeStats(conf)
566+
super.computeStats
568567
}
569568
}
570569

571570
if (conf.cboEnabled) {
572-
AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
571+
AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
573572
} else {
574573
simpleEstimation
575574
}
@@ -672,8 +671,8 @@ case class Expand(
672671
override def references: AttributeSet =
673672
AttributeSet(projections.flatten.flatMap(_.references))
674673

675-
override def computeStats(conf: SQLConf): Statistics = {
676-
val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
674+
override def computeStats: Statistics = {
675+
val sizeInBytes = super.computeStats.sizeInBytes * projections.length
677676
Statistics(sizeInBytes = sizeInBytes)
678677
}
679678

@@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
743742
case _ => None
744743
}
745744
}
746-
override def computeStats(conf: SQLConf): Statistics = {
745+
override def computeStats: Statistics = {
747746
val limit = limitExpr.eval().asInstanceOf[Int]
748-
val childStats = child.stats(conf)
747+
val childStats = child.stats
749748
val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
750749
// Don't propagate column stats, because we don't know the distribution after a limit operation
751750
Statistics(
@@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
763762
case _ => None
764763
}
765764
}
766-
override def computeStats(conf: SQLConf): Statistics = {
765+
override def computeStats: Statistics = {
767766
val limit = limitExpr.eval().asInstanceOf[Int]
768-
val childStats = child.stats(conf)
767+
val childStats = child.stats
769768
if (limit == 0) {
770769
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
771770
// (product of children).
@@ -832,9 +831,9 @@ case class Sample(
832831

833832
override def output: Seq[Attribute] = child.output
834833

835-
override def computeStats(conf: SQLConf): Statistics = {
834+
override def computeStats: Statistics = {
836835
val ratio = upperBound - lowerBound
837-
val childStats = child.stats(conf)
836+
val childStats = child.stats
838837
var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
839838
if (sizeInBytes == 0) {
840839
sizeInBytes = 1
@@ -898,7 +897,7 @@ case class RepartitionByExpression(
898897
case object OneRowRelation extends LeafNode {
899898
override def maxRows: Option[Long] = Some(1)
900899
override def output: Seq[Attribute] = Nil
901-
override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
900+
override def computeStats: Statistics = Statistics(sizeInBytes = 1)
902901
}
903902

904903
/** A logical plan for `dropDuplicates`. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.sql.catalyst.expressions.Attribute
21-
import org.apache.spark.sql.internal.SQLConf
2221

2322
/**
2423
* A general hint for the child that is not yet resolved. This node is generated by the parser and
@@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
4443

4544
override lazy val canonicalized: LogicalPlan = child.canonicalized
4645

47-
override def computeStats(conf: SQLConf): Statistics = {
48-
val stats = child.stats(conf)
46+
override def computeStats: Statistics = {
47+
val stats = child.stats
4948
stats.copy(hints = hints)
5049
}
5150
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
1919

2020
import org.apache.spark.sql.catalyst.expressions.Attribute
2121
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
22-
import org.apache.spark.sql.internal.SQLConf
2322

2423

2524
object AggregateEstimation {
@@ -29,13 +28,13 @@ object AggregateEstimation {
2928
* Estimate the number of output rows based on column stats of group-by columns, and propagate
3029
* column stats for aggregate expressions.
3130
*/
32-
def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
33-
val childStats = agg.child.stats(conf)
31+
def estimate(agg: Aggregate): Option[Statistics] = {
32+
val childStats = agg.child.stats
3433
// Check if we have column stats for all group-by columns.
3534
val colStatsExist = agg.groupingExpressions.forall { e =>
3635
e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
3736
}
38-
if (rowCountsExist(conf, agg.child) && colStatsExist) {
37+
if (rowCountsExist(agg.child) && colStatsExist) {
3938
// Multiply distinct counts of group-by columns. This is an upper bound, which assumes
4039
// the data contains all combinations of distinct values of group-by columns.
4140
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(

0 commit comments

Comments
 (0)