@@ -21,7 +21,7 @@ import scala.collection.mutable
2121import scala .collection .mutable .ArrayBuffer
2222
2323import org .apache .spark .internal .Logging
24- import org .apache .spark .sql .catalyst .expressions .{Attribute , AttributeMap , AttributeReference , Expression }
24+ import org .apache .spark .sql .catalyst .expressions .{Attribute , AttributeMap , AttributeReference , Expression , ExpressionSet }
2525import org .apache .spark .sql .catalyst .planning .ExtractEquiJoinKeys
2626import org .apache .spark .sql .catalyst .plans ._
2727import org .apache .spark .sql .catalyst .plans .logical .{ColumnStat , Histogram , Join , Statistics }
@@ -56,10 +56,13 @@ case class JoinEstimation(join: Join) extends Logging {
5656 case _ if ! rowCountsExist(join.left, join.right) =>
5757 None
5858
59- case ExtractEquiJoinKeys (joinType, leftKeys, rightKeys, _, _, _ , _) =>
59+ case ExtractEquiJoinKeys (joinType, leftKeys, rightKeys, _, left, right , _) =>
6060 // 1. Compute join selectivity
6161 val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
62- val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)
62+ val leftUniqueness = left.distinctKeys.exists(_.subsetOf(ExpressionSet (leftKeys)))
63+ val rightUniqueness = right.distinctKeys.exists(_.subsetOf(ExpressionSet (rightKeys)))
64+ val (numInnerJoinedRows, keyStatsAfterJoin) =
65+ computeCardinalityAndStats(joinKeyPairs, leftUniqueness, rightUniqueness)
6366
6467 // 2. Estimate the number of output rows
6568 val leftRows = leftStats.rowCount.get
@@ -177,10 +180,17 @@ case class JoinEstimation(join: Join) extends Logging {
177180 * @return join cardinality, and column stats for join keys after the join
178181 */
179182 // scalastyle:on
180- private def computeCardinalityAndStats (keyPairs : Seq [(AttributeReference , AttributeReference )])
181- : (BigInt , AttributeMap [ColumnStat ]) = {
183+ private def computeCardinalityAndStats (
184+ keyPairs : Seq [(AttributeReference , AttributeReference )],
185+ leftUniqueness : Boolean ,
186+ rightUniqueness : Boolean ): (BigInt , AttributeMap [ColumnStat ]) = {
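182187 // If there's no column stats available for join keys, start from an upper bound: the cartesian product, tightened by join-key uniqueness (the other side's row count if one side is unique, the smaller row count if both are).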
183- var joinCard : BigInt = leftStats.rowCount.get * rightStats.rowCount.get
188+ var joinCard : BigInt = (leftUniqueness, rightUniqueness) match {
189+ case (true , true ) => leftStats.rowCount.get.min(rightStats.rowCount.get)
190+ case (true , false ) => rightStats.rowCount.get
191+ case (false , true ) => leftStats.rowCount.get
192+ case _ => leftStats.rowCount.get * rightStats.rowCount.get
193+ }
184194 val keyStatsAfterJoin = new mutable.HashMap [Attribute , ColumnStat ]()
185195 var i = 0
186196 while (i < keyPairs.length && joinCard != 0 ) {
0 commit comments