diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 6f868cbd072c..71e852afe065 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} @@ -212,4 +213,172 @@ object EstimationUtils { } } + /** + * Returns overlapped ranges between two histograms, in the given value range + * [lowerBound, upperBound]. + */ + def getOverlappedRanges( + leftHistogram: Histogram, + rightHistogram: Histogram, + lowerBound: Double, + upperBound: Double): Seq[OverlappedRange] = { + val overlappedRanges = new ArrayBuffer[OverlappedRange]() + // Only bins whose range intersect [lowerBound, upperBound] have join possibility. + val leftBins = leftHistogram.bins + .filter(b => b.lo <= upperBound && b.hi >= lowerBound) + val rightBins = rightHistogram.bins + .filter(b => b.lo <= upperBound && b.hi >= lowerBound) + + leftBins.foreach { lb => + rightBins.foreach { rb => + val (left, leftHeight) = trimBin(lb, leftHistogram.height, lowerBound, upperBound) + val (right, rightHeight) = trimBin(rb, rightHistogram.height, lowerBound, upperBound) + // Only collect overlapped ranges. + if (left.lo <= right.hi && left.hi >= right.lo) { + // Collect overlapped ranges. + val range = if (right.lo >= left.lo && right.hi >= left.hi) { + // Case1: the left bin is "smaller" than the right bin + // left.lo right.lo left.hi right.hi + // --------+------------------+------------+----------------+-------> + if (left.hi == right.lo) { + // The overlapped range has only one value. + OverlappedRange( + lo = right.lo, + hi = right.lo, + leftNdv = 1, + rightNdv = 1, + leftNumRows = leftHeight / left.ndv, + rightNumRows = rightHeight / right.ndv + ) + } else { + val leftRatio = (left.hi - right.lo) / (left.hi - left.lo) + val rightRatio = (left.hi - right.lo) / (right.hi - right.lo) + OverlappedRange( + lo = right.lo, + hi = left.hi, + leftNdv = left.ndv * leftRatio, + rightNdv = right.ndv * rightRatio, + leftNumRows = leftHeight * leftRatio, + rightNumRows = rightHeight * rightRatio + ) + } + } else if (right.lo <= left.lo && right.hi <= left.hi) { + // Case2: the left bin is "larger" than the right bin + // right.lo left.lo right.hi left.hi + // --------+------------------+------------+----------------+-------> + if (right.hi == left.lo) { + // The overlapped range has only one value. + OverlappedRange( + lo = right.hi, + hi = right.hi, + leftNdv = 1, + rightNdv = 1, + leftNumRows = leftHeight / left.ndv, + rightNumRows = rightHeight / right.ndv + ) + } else { + val leftRatio = (right.hi - left.lo) / (left.hi - left.lo) + val rightRatio = (right.hi - left.lo) / (right.hi - right.lo) + OverlappedRange( + lo = left.lo, + hi = right.hi, + leftNdv = left.ndv * leftRatio, + rightNdv = right.ndv * rightRatio, + leftNumRows = leftHeight * leftRatio, + rightNumRows = rightHeight * rightRatio + ) + } + } else if (right.lo >= left.lo && right.hi <= left.hi) { + // Case3: the left bin contains the right bin + // left.lo right.lo right.hi left.hi + // --------+------------------+------------+----------------+-------> + val leftRatio = (right.hi - right.lo) / (left.hi - left.lo) + OverlappedRange( + lo = right.lo, + hi = right.hi, + leftNdv = left.ndv * leftRatio, + rightNdv = right.ndv, + leftNumRows = leftHeight * leftRatio, + rightNumRows = rightHeight + ) + } else { + assert(right.lo <= left.lo && right.hi >= left.hi) + // Case4: the right bin contains the left bin + // right.lo left.lo left.hi right.hi + // --------+------------------+------------+----------------+-------> + val rightRatio = (left.hi - left.lo) / (right.hi - right.lo) + OverlappedRange( + lo = left.lo, + hi = left.hi, + leftNdv = left.ndv, + rightNdv = right.ndv * rightRatio, + leftNumRows = leftHeight, + rightNumRows = rightHeight * rightRatio + ) + } + overlappedRanges += range + } + } + } + overlappedRanges + } + + /** + * Given an original bin and a value range [lowerBound, upperBound], returns the trimmed part + * of the bin in that range and its number of rows. + * @param bin the input histogram bin. + * @param height the number of rows of the given histogram bin inside an equi-height histogram. + * @param lowerBound lower bound of the given range. + * @param upperBound upper bound of the given range. + * @return trimmed part of the given bin and its number of rows. + */ + def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double) + : (HistogramBin, Double) = { + val (lo, hi) = if (bin.lo <= lowerBound && bin.hi >= upperBound) { + // bin.lo lowerBound upperBound bin.hi + // --------+------------------+------------+-------------+-------> + (lowerBound, upperBound) + } else if (bin.lo <= lowerBound && bin.hi >= lowerBound) { + // bin.lo lowerBound bin.hi upperBound + // --------+------------------+------------+-------------+-------> + (lowerBound, bin.hi) + } else if (bin.lo <= upperBound && bin.hi >= upperBound) { + // lowerBound bin.lo upperBound bin.hi + // --------+------------------+------------+-------------+-------> + (bin.lo, upperBound) + } else { + // lowerBound bin.lo bin.hi upperBound + // --------+------------------+------------+-------------+-------> + assert(bin.lo >= lowerBound && bin.hi <= upperBound) + (bin.lo, bin.hi) + } + + if (hi == lo) { + // Note that bin.hi == bin.lo also falls into this branch. + (HistogramBin(lo, hi, 1), height / bin.ndv) + } else { + assert(bin.hi != bin.lo) + val ratio = (hi - lo) / (bin.hi - bin.lo) + (HistogramBin(lo, hi, math.ceil(bin.ndv * ratio).toLong), height * ratio) + } + } + + /** + * A join between two equi-height histograms may produce multiple overlapped ranges. + * Each overlapped range is produced by a part of one bin in the left histogram and a part of + * one bin in the right histogram. + * @param lo lower bound of this overlapped range. + * @param hi higher bound of this overlapped range. + * @param leftNdv ndv in the left part. + * @param rightNdv ndv in the right part. + * @param leftNumRows number of rows in the left part. + * @param rightNumRows number of rows in the right part. + */ + case class OverlappedRange( + lo: Double, + hi: Double, + leftNdv: Double, + rightNdv: Double, + leftNumRows: Double, + rightNumRows: Double) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index b073108c26ee..f0294a424670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ @@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging { val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) - val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax) - keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat) + val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match { + case (Some(l: Histogram), Some(r: Histogram)) => + computeByHistogram(leftKey, rightKey, l, r, newMin, newMax) + case _ => + computeByNdv(leftKey, rightKey, newMin, newMax) + } + keyStatsAfterJoin += ( + // Histograms are propagated as unchanged. During future estimation, they should be + // truncated by the updated max/min. In this way, only pointers of the histograms are + // propagated and thus reduce memory consumption. + leftKey -> joinStat.copy(histogram = leftKeyStat.histogram), + rightKey -> joinStat.copy(histogram = rightKeyStat.histogram) + ) // Return cardinality estimated from the most selective join keys. if (card < joinCard) joinCard = card } else { @@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging { (ceil(card), newStats) } + /** Compute join cardinality using equi-height histograms. */ + private def computeByHistogram( + leftKey: AttributeReference, + rightKey: AttributeReference, + leftHistogram: Histogram, + rightHistogram: Histogram, + newMin: Option[Any], + newMax: Option[Any]): (BigInt, ColumnStat) = { + val overlappedRanges = getOverlappedRanges( + leftHistogram = leftHistogram, + rightHistogram = rightHistogram, + // Only numeric values have equi-height histograms. + lowerBound = newMin.get.toString.toDouble, + upperBound = newMax.get.toString.toDouble) + + var card: BigDecimal = 0 + var totalNdv: Double = 0 + for (i <- overlappedRanges.indices) { + val range = overlappedRanges(i) + if (i == 0 || range.hi != overlappedRanges(i - 1).hi) { + // If range.hi == overlappedRanges(i - 1).hi, that means the current range has only one + // value, and this value is already counted in the previous range. So there is no need to + // count it in this range. + totalNdv += math.min(range.leftNdv, range.rightNdv) + } + // Apply the formula in this overlapped range. + card += range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv) + } + + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + (ceil(card), newStats) + } + /** * Propagate or update column stats for output attributes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 097c78eb27fc..26139d85d25f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -67,6 +67,213 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = 2, attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo))) + private def estimateByHistogram( + leftHistogram: Histogram, + rightHistogram: Histogram, + expectedMin: Double, + expectedMax: Double, + expectedNdv: Long, + expectedRows: Long): Unit = { + val col1 = attr("key1") + val col2 = attr("key2") + val c1 = generateJoinChild(col1, leftHistogram, expectedMin, expectedMax) + val c2 = generateJoinChild(col2, rightHistogram, expectedMin, expectedMax) + + val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2))) + val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1))) + val expectedStatsAfterJoin = Statistics( + sizeInBytes = expectedRows * (8 + 2 * 4), + rowCount = Some(expectedRows), + attributeStats = AttributeMap(Seq( + col1 -> c1.stats.attributeStats(col1).copy( + distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + col2 -> c2.stats.attributeStats(col2).copy( + distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + ) + + // Join order should not affect estimation result. + Seq(c1JoinC2, c2JoinC1).foreach { join => + assert(join.stats == expectedStatsAfterJoin) + } + } + + private def generateJoinChild( + col: Attribute, + histogram: Histogram, + expectedMin: Double, + expectedMax: Double): LogicalPlan = { + val colStat = inferColumnStat(histogram) + StatsTestPlan( + outputList = Seq(col), + rowCount = (histogram.height * histogram.bins.length).toLong, + attributeStats = AttributeMap(Seq(col -> colStat))) + } + + /** Column statistics should be consistent with histograms in tests. */ + private def inferColumnStat(histogram: Histogram): ColumnStat = { + var ndv = 0L + for (i <- histogram.bins.indices) { + val bin = histogram.bins(i) + if (i == 0 || bin.hi != histogram.bins(i - 1).hi) { + ndv += bin.ndv + } + } + ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), + max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + histogram = Some(histogram)) + } + + test("equi-height histograms: a bin is contained by another one") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 10, upperBound = 60) + assert(t0 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h0 == 80) + val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 10, upperBound = 60) + assert(t1 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20) + + val expectedRanges = Seq( + // histogram1.bins(0) overlaps t0 + OverlappedRange(10, 30, 10, 40 * 1 / 2, 300, 80 * 1 / 2), + // histogram1.bins(1) overlaps t0 + OverlappedRange(30, 50, 30 * 2 / 3, 40 * 1 / 2, 300 * 2 / 3, 80 * 1 / 2), + // histogram1.bins(1) overlaps t1 + OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 10, upperBound = 60))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 10, + expectedMax = 60, + expectedNdv = 10 + 20 + 8, + expectedRows = 300 * 40 / 20 + 200 * 40 / 20 + 100 * 20 / 10) + } + + test("equi-height histograms: a bin has only one value after trimming") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 50, hi = 60, ndv = 10), HistogramBin(lo = 60, hi = 75, ndv = 3))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 50, upperBound = 75) + assert(t0 == HistogramBin(lo = 50, hi = 50, ndv = 1) && h0 == 2) + val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 50, upperBound = 75) + assert(t1 == HistogramBin(lo = 50, hi = 75, ndv = 20) && h1 == 50) + + val expectedRanges = Seq( + // histogram1.bins(0) overlaps t0 + OverlappedRange(50, 50, 1, 1, 300 / 10, 2), + // histogram1.bins(0) overlaps t1 + OverlappedRange(50, 60, 10, 20 * 10 / 25, 300, 50 * 10 / 25), + // histogram1.bins(1) overlaps t1 + OverlappedRange(60, 75, 3, 20 * 15 / 25, 300, 50 * 15 / 25) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 50, upperBound = 75))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 50, + expectedMax = 75, + expectedNdv = 1 + 8 + 3, + expectedRows = 30 * 2 / 1 + 300 * 20 / 10 + 300 * 30 / 12) + } + + test("equi-height histograms: skew distribution (some bins have only one value)") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), + HistogramBin(lo = 30, hi = 30, ndv = 1), + HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 60) + assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 40) + val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 30, upperBound = 60) + assert(t1 ==HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 40 / 20), + OverlappedRange(30, 30, 1, 1, 300, 40 / 20), + OverlappedRange(30, 50, 30 * 2 / 3, 20, 300 * 2 / 3, 40), + OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 60))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 30, + expectedMax = 60, + expectedNdv = 1 + 20 + 8, + expectedRows = 300 * 2 / 1 + 300 * 2 / 1 + 200 * 40 / 20 + 100 * 20 / 10) + } + + test("equi-height histograms: skew distribution (histograms have different skewed values") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 50, ndv = 1))) + // test bin trimming + val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 50) + assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 200) + val (t1, h1) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 50) + assert(t1 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h1 == 40) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 40 / 20), + OverlappedRange(30, 50, 20, 20, 200, 40), + OverlappedRange(50, 50, 1, 1, 200 / 20, 100) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 50))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 30, + expectedMax = 50, + expectedNdv = 1 + 20, + expectedRows = 300 * 2 / 1 + 200 * 40 / 20 + 10 * 100 / 1) + } + + test("equi-height histograms: skew distribution (both histograms have the same skewed value") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 150, Array( + HistogramBin(lo = 0, hi = 30, ndv = 30), HistogramBin(lo = 30, hi = 30, ndv = 1))) + // test bin trimming + val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 30) + assert(t0 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h0 == 10) + val (t1, h1) = trimBin(histogram2.bins(0), height = 150, lowerBound = 30, upperBound = 30) + assert(t1 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h1 == 5) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 5), + OverlappedRange(30, 30, 1, 1, 300, 150), + OverlappedRange(30, 30, 1, 1, 10, 5), + OverlappedRange(30, 30, 1, 1, 10, 150) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 30))) + + estimateByHistogram( + leftHistogram = histogram1, + rightHistogram = histogram2, + expectedMin = 30, + expectedMax = 30, + // only one value: 30 + expectedNdv = 1, + expectedRows = 300 * 5 / 1 + 300 * 150 / 1 + 10 * 5 / 1 + 10 * 150 / 1) + } + test("cross join") { // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5) // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)