Skip to content

Commit 620f072

Browse files
zhengruifengmaropu
authored andcommitted
[SPARK-35231][SQL] logical.Range override maxRowsPerPartition
### What changes were proposed in this pull request? when `numSlices` is avaiable, `logical.Range` should compute a exact `maxRowsPerPartition` ### Why are the changes needed? `maxRowsPerPartition` is used in optimizer, we should provide an exact value if possible ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing testsuites Closes #32350 from zhengruifeng/range_maxRowsPerPartition. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Takeshi Yamamuro <[email protected]>
1 parent 5b65d8a commit 620f072

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
6969
extends OrderPreservingUnaryNode {
7070
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
7171
override def maxRows: Option[Long] = child.maxRows
72+
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
7273

7374
final override val nodePatterns: Seq[TreePattern] = Seq(PROJECT)
7475

@@ -163,6 +164,7 @@ case class Filter(condition: Expression, child: LogicalPlan)
163164
override def output: Seq[Attribute] = child.output
164165

165166
override def maxRows: Option[Long] = child.maxRows
167+
override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
166168

167169
final override val nodePatterns: Seq[TreePattern] = Seq(FILTER)
168170

@@ -746,6 +748,16 @@ case class Range(
746748
}
747749
}
748750

751+
override def maxRowsPerPartition: Option[Long] = {
752+
if (numSlices.isDefined) {
753+
var m = numElements / numSlices.get
754+
if (numElements % numSlices.get != 0) m += 1
755+
if (m.isValidLong) Some(m.toLong) else maxRows
756+
} else {
757+
maxRows
758+
}
759+
}
760+
749761
override def computeStats(): Statistics = {
750762
if (numElements == 0) {
751763
Statistics(sizeInBytes = 0, rowCount = Some(0))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
package org.apache.spark.sql.catalyst.plans
1919

2020
import org.apache.spark.SparkFunSuite
21-
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.dsl.plans._
23+
import org.apache.spark.sql.catalyst.expressions._
2224
import org.apache.spark.sql.catalyst.plans.logical._
2325
import org.apache.spark.sql.types.IntegerType
2426

@@ -96,4 +98,11 @@ class LogicalPlanSuite extends SparkFunSuite {
9698
OneRowRelation())
9799
assert(result.sameResult(expected))
98100
}
101+
102+
test("SPARK-35231: logical.Range override maxRowsPerPartition") {
103+
assert(Range(0, 100, 1, 3).maxRowsPerPartition === Some(34))
104+
assert(Range(0, 100, 1, 4).maxRowsPerPartition === Some(25))
105+
assert(Range(0, 100, 1, 3).select('id).maxRowsPerPartition === Some(34))
106+
assert(Range(0, 100, 1, 3).where('id % 2 === 1).maxRowsPerPartition === Some(34))
107+
}
99108
}

0 commit comments

Comments
 (0)