Skip to content

Commit b1d719e

Browse files
rxin authored and gatorsmile committed
[SPARK-21273][SQL] Propagate logical plan stats using visitor pattern and mixin
## What changes were proposed in this pull request?

We currently implement statistics propagation directly in logical plan. Given we already have two different implementations, it'd make sense to actually decouple the two and add stats propagation using mixin. This would reduce the coupling between logical plan and statistics handling.

This can also be a powerful pattern in the future to add additional properties (e.g. constraints).

## How was this patch tested?

Should be covered by existing test cases.

Author: Reynold Xin <[email protected]>

Closes #18479 from rxin/stats-trait.
1 parent 61b5df5 commit b1d719e

File tree

16 files changed

+409
-238
lines changed

16 files changed

+409
-238
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -438,7 +438,7 @@ case class CatalogRelation(
438438
case (attr, index) => attr.withExprId(ExprId(index + dataCols.length))
439439
})
440440

441-
override def computeStats: Statistics = {
441+
override def computeStats(): Statistics = {
442442
// For data source tables, we will create a `LogicalRelation` and won't call this method, for
443443
// hive serde tables, we will always generate a statistics.
444444
// TODO: unify the table stats generation.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,9 +66,8 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
6666
}
6767
}
6868

69-
override def computeStats: Statistics =
70-
Statistics(sizeInBytes =
71-
output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
69+
override def computeStats(): Statistics =
70+
Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
7271

7372
def toSQL(inlineTableName: String): String = {
7473
require(data.nonEmpty)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 9 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -22,11 +22,16 @@ import org.apache.spark.sql.AnalysisException
2222
import org.apache.spark.sql.catalyst.analysis._
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.QueryPlan
25+
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
2526
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
2627
import org.apache.spark.sql.types.StructType
2728

2829

29-
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstraints with Logging {
30+
abstract class LogicalPlan
31+
extends QueryPlan[LogicalPlan]
32+
with LogicalPlanStats
33+
with QueryPlanConstraints
34+
with Logging {
3035

3136
private var _analyzed: Boolean = false
3237

@@ -80,40 +85,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
8085
}
8186
}
8287

83-
/** A cache for the estimated statistics, such that it will only be computed once. */
84-
private var statsCache: Option[Statistics] = None
85-
86-
/**
87-
* Returns the estimated statistics for the current logical plan node. Under the hood, this
88-
* method caches the return value, which is computed based on the configuration passed in the
89-
* first time. If the configuration changes, the cache can be invalidated by calling
90-
* [[invalidateStatsCache()]].
91-
*/
92-
final def stats: Statistics = statsCache.getOrElse {
93-
statsCache = Some(computeStats)
94-
statsCache.get
95-
}
96-
97-
/** Invalidates the stats cache. See [[stats]] for more information. */
98-
final def invalidateStatsCache(): Unit = {
99-
statsCache = None
100-
children.foreach(_.invalidateStatsCache())
101-
}
102-
103-
/**
104-
* Computes [[Statistics]] for this plan. The default implementation assumes the output
105-
* cardinality is the product of all child plan's cardinality, i.e. applies in the case
106-
* of cartesian joins.
107-
*
108-
* [[LeafNode]]s must override this.
109-
*/
110-
protected def computeStats: Statistics = {
111-
if (children.isEmpty) {
112-
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
113-
}
114-
Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
115-
}
116-
11788
override def verboseStringWithSuffix: String = {
11889
super.verboseString + statsCache.map(", " + _.toString).getOrElse("")
11990
}
@@ -300,6 +271,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
300271
abstract class LeafNode extends LogicalPlan {
301272
override final def children: Seq[LogicalPlan] = Nil
302273
override def producedAttributes: AttributeSet = outputSet
274+
275+
/** Leaf nodes that can survive analysis must define their own statistics. */
276+
def computeStats(): Statistics = throw new UnsupportedOperationException
303277
}
304278

305279
/**
@@ -331,23 +305,6 @@ abstract class UnaryNode extends LogicalPlan {
331305
}
332306

333307
override protected def validConstraints: Set[Expression] = child.constraints
334-
335-
override def computeStats: Statistics = {
336-
// There should be some overhead in Row object, the size should not be zero when there is
337-
// no columns, this help to prevent divide-by-zero error.
338-
val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
339-
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
340-
// Assume there will be the same number of rows as child has.
341-
var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
342-
if (sizeInBytes == 0) {
343-
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
344-
// (product of children).
345-
sizeInBytes = 1
346-
}
347-
348-
// Don't propagate rowCount and attributeStats, since they are not estimated here.
349-
Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
350-
}
351308
}
352309

353310
/**
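
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala

Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change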
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.plans.logical
19+
20+
/**
21+
* A visitor pattern for traversing a [[LogicalPlan]] tree and computing some properties.
22+
*/
23+
trait LogicalPlanVisitor[T] {
24+
25+
def visit(p: LogicalPlan): T = p match {
26+
case p: Aggregate => visitAggregate(p)
27+
case p: Distinct => visitDistinct(p)
28+
case p: Except => visitExcept(p)
29+
case p: Expand => visitExpand(p)
30+
case p: Filter => visitFilter(p)
31+
case p: Generate => visitGenerate(p)
32+
case p: GlobalLimit => visitGlobalLimit(p)
33+
case p: Intersect => visitIntersect(p)
34+
case p: Join => visitJoin(p)
35+
case p: LocalLimit => visitLocalLimit(p)
36+
case p: Pivot => visitPivot(p)
37+
case p: Project => visitProject(p)
38+
case p: Range => visitRange(p)
39+
case p: Repartition => visitRepartition(p)
40+
case p: RepartitionByExpression => visitRepartitionByExpr(p)
41+
case p: Sample => visitSample(p)
42+
case p: ScriptTransformation => visitScriptTransform(p)
43+
case p: Union => visitUnion(p)
44+
case p: ResolvedHint => visitHint(p)
45+
case p: LogicalPlan => default(p)
46+
}
47+
48+
def default(p: LogicalPlan): T
49+
50+
def visitAggregate(p: Aggregate): T
51+
52+
def visitDistinct(p: Distinct): T
53+
54+
def visitExcept(p: Except): T
55+
56+
def visitExpand(p: Expand): T
57+
58+
def visitFilter(p: Filter): T
59+
60+
def visitGenerate(p: Generate): T
61+
62+
def visitGlobalLimit(p: GlobalLimit): T
63+
64+
def visitHint(p: ResolvedHint): T
65+
66+
def visitIntersect(p: Intersect): T
67+
68+
def visitJoin(p: Join): T
69+
70+
def visitLocalLimit(p: LocalLimit): T
71+
72+
def visitPivot(p: Pivot): T
73+
74+
def visitProject(p: Project): T
75+
76+
def visitRange(p: Range): T
77+
78+
def visitRepartition(p: Repartition): T
79+
80+
def visitRepartitionByExpr(p: RepartitionByExpression): T
81+
82+
def visitSample(p: Sample): T
83+
84+
def visitScriptTransform(p: ScriptTransformation): T
85+
86+
def visitUnion(p: Union): T
87+
}

0 commit comments

Comments
 (0)