Skip to content

Commit 336cd98

Browse files
wangyum authored and GitHub Enterprise committed
[CARMEL-6106] Improve join stats estimation if one side can keep uniqueness (#1024)
1 parent b11685d commit 336cd98

File tree

3 files changed

+61
-9
lines changed

3 files changed

+61
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable
2121
import scala.collection.mutable.ArrayBuffer
2222

2323
import org.apache.spark.internal.Logging
24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, ExpressionSet}
2525
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
2626
import org.apache.spark.sql.catalyst.plans._
2727
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics}
@@ -56,10 +56,13 @@ case class JoinEstimation(join: Join) extends Logging {
5656
case _ if !rowCountsExist(join.left, join.right) =>
5757
None
5858

59-
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) =>
59+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, left, right, _) =>
6060
// 1. Compute join selectivity
6161
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
62-
val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)
62+
val leftUniqueness = left.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys)))
63+
val rightUniqueness = right.distinctKeys.exists(_.subsetOf(ExpressionSet(rightKeys)))
64+
val (numInnerJoinedRows, keyStatsAfterJoin) =
65+
computeCardinalityAndStats(joinKeyPairs, leftUniqueness, rightUniqueness)
6366

6467
// 2. Estimate the number of output rows
6568
val leftRows = leftStats.rowCount.get
@@ -177,10 +180,17 @@ case class JoinEstimation(join: Join) extends Logging {
177180
* @return join cardinality, and column stats for join keys after the join
178181
*/
179182
// scalastyle:on
180-
private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)])
181-
: (BigInt, AttributeMap[ColumnStat]) = {
183+
private def computeCardinalityAndStats(
184+
keyPairs: Seq[(AttributeReference, AttributeReference)],
185+
leftUniqueness: Boolean,
186+
rightUniqueness: Boolean): (BigInt, AttributeMap[ColumnStat]) = {
182187
// If there's no column stats available for join keys, estimate as cartesian product.
183-
var joinCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get
188+
var joinCard: BigInt = (leftUniqueness, rightUniqueness) match {
189+
case (true, true) => leftStats.rowCount.get.min(rightStats.rowCount.get)
190+
case (true, false) => rightStats.rowCount.get
191+
case (false, true) => leftStats.rowCount.get
192+
case _ => leftStats.rowCount.get * rightStats.rowCount.get
193+
}
184194
val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]()
185195
var i = 0
186196
while(i < keyPairs.length && joinCard != 0) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
1919

20-
import org.apache.spark.sql.catalyst.expressions.AttributeMap
21-
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
20+
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, ExpressionSet}
21+
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
22+
import org.apache.spark.sql.catalyst.plans._
2223
import org.apache.spark.sql.catalyst.plans.logical._
2324

2425
/**
@@ -103,6 +104,16 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
103104
case LeftAnti | LeftSemi =>
104105
// LeftSemi and LeftAnti won't ever be bigger than left
105106
p.left.stats
107+
case Inner | LeftOuter | RightOuter | FullOuter =>
108+
p match {
109+
case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, left, right, _)
110+
if left.distinctKeys.exists(_.subsetOf(ExpressionSet(leftKeys))) ||
111+
right.distinctKeys.exists(_.subsetOf(ExpressionSet(rightKeys))) =>
112+
// Only sum children whose sizeInBytes is > 1: multiplying by 1 leaves a size unchanged, but adding 1 does not (sizeInBytes * 1 != sizeInBytes + 1).
113+
Statistics(sizeInBytes = p.children.map(_.stats.sizeInBytes).filter(_ > 1L).sum)
114+
case _ =>
115+
default(p)
116+
}
106117
case _ =>
107118
default(p)
108119
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolvedNamespace
2323
import org.apache.spark.sql.catalyst.dsl.expressions._
2424
import org.apache.spark.sql.catalyst.dsl.plans._
2525
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
26-
import org.apache.spark.sql.catalyst.plans.PlanTest
26+
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
2727
import org.apache.spark.sql.catalyst.plans.logical._
2828
import org.apache.spark.sql.connector.catalog.SupportsNamespaces
2929
import org.apache.spark.sql.internal.SQLConf
@@ -155,6 +155,37 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
155155
expectedStatsCboOff = Statistics(sizeInBytes = sizeInBytes))
156156
}
157157

158+
test("SPARK-39851: Improve join stats estimation if one side can keep uniqueness") {
159+
val brandId = attr("brand_id")
160+
val classId = attr("class_id")
161+
val aliasedBrandId = brandId.as("new_brand_id")
162+
val aliasedClassId = classId.as("new_class_id")
163+
164+
val tableSize = 4059900
165+
val tableRowCnt = 202995
166+
167+
val tbl = StatsTestPlan(
168+
outputList = Seq(brandId, classId),
169+
size = Some(tableSize),
170+
rowCount = tableRowCnt,
171+
attributeStats =
172+
AttributeMap(Seq(
173+
brandId -> ColumnStat(Some(858), Some(101001), Some(1016017), Some(0), Some(4), Some(4)),
174+
classId -> ColumnStat(Some(16), Some(1), Some(16), Some(0), Some(4), Some(4)))))
175+
176+
val join = Join(
177+
tbl,
178+
tbl.groupBy(brandId, classId)(aliasedBrandId, aliasedClassId),
179+
Inner,
180+
Some(brandId === aliasedBrandId.toAttribute && classId === aliasedClassId.toAttribute),
181+
JoinHint.NONE)
182+
183+
checkStats(
184+
join,
185+
expectedStatsCboOn = Statistics(4871880, Some(tableRowCnt), join.stats.attributeStats),
186+
expectedStatsCboOff = Statistics(sizeInBytes = 4059900 * 2))
187+
}
188+
158189
/** Check estimated stats when cbo is turned on/off. */
159190
private def checkStats(
160191
plan: LogicalPlan,

0 commit comments

Comments
 (0)