From a376c17c66b4c8f093be33c2a1535ce28cb1bb18 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 26 Jun 2022 19:08:32 +0800 Subject: [PATCH] Use child stats to estimate order operator --- .../plans/logical/statsEstimation/BasicStatsPlanVisitor.scala | 4 +--- .../statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala | 2 +- .../catalyst/statsEstimation/BasicStatsEstimationSuite.scala | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 59a302b1af90..21799a5c683a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -102,9 +102,7 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitWindow(p: Window): Statistics = fallback(p) - override def visitSort(p: Sort): Statistics = { - BasicStatsPlanVisitor.visit(p.child) - } + override def visitSort(p: Sort): Statistics = fallback(p) override def visitTail(p: Tail): Statistics = { fallback(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 311dd31a96b3..77c728ba7c57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -162,7 +162,7 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitWindow(p: Window): Statistics = visitUnaryNode(p) - override def visitSort(p: Sort): Statistics = default(p) + override def visitSort(p: Sort): Statistics = p.child.stats override def visitTail(p: Tail): Statistics = { val limit = p.limitExpr.eval().asInstanceOf[Int] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 86a1cb4c3c5d..4362e0c5172d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -360,7 +360,7 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { checkStats( sort, expectedStatsCboOn = expectedSortStats, - expectedStatsCboOff = Statistics(sizeInBytes = expectedSize)) + expectedStatsCboOff = expectedSortStats) } /** Check estimated stats when cbo is turned on/off. */