
Commit 7dd9fc6

kanzhang authored and mateiz committed
[SPARK-1837] NumericRange should be partitioned in the same way as other sequences

Author: Kan Zhang <[email protected]>

Closes apache#776 from kanzhang/SPARK-1837 and squashes the following commits:

e48f018 [Kan Zhang] [SPARK-1837] code refactoring
67c33b5 [Kan Zhang] minor change
403f9b1 [Kan Zhang] [SPARK-1837] NumericRange should be partitioned in the same way as other sequences
1 parent b52603b commit 7dd9fc6


2 files changed, +37 -12 lines changed


core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala

Lines changed: 19 additions & 12 deletions
@@ -117,6 +117,15 @@ private object ParallelCollectionRDD {
     if (numSlices < 1) {
       throw new IllegalArgumentException("Positive number of slices required")
     }
+    // Sequences need to be sliced at the same set of index positions for operations
+    // like RDD.zip() to behave as expected
+    def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
+      (0 until numSlices).iterator.map(i => {
+        val start = ((i * length) / numSlices).toInt
+        val end = (((i + 1) * length) / numSlices).toInt
+        (start, end)
+      })
+    }
     seq match {
       case r: Range.Inclusive => {
         val sign = if (r.step < 0) {
@@ -128,30 +137,28 @@ private object ParallelCollectionRDD {
           r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
       }
       case r: Range => {
-        (0 until numSlices).map(i => {
-          val start = ((i * r.length.toLong) / numSlices).toInt
-          val end = (((i + 1) * r.length.toLong) / numSlices).toInt
-          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
-        }).asInstanceOf[Seq[Seq[T]]]
+        positions(r.length, numSlices).map({
+          case (start, end) =>
+            new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+        }).toSeq.asInstanceOf[Seq[Seq[T]]]
       }
       case nr: NumericRange[_] => {
         // For ranges of Long, Double, BigInteger, etc
         val slices = new ArrayBuffer[Seq[T]](numSlices)
-        val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
         var r = nr
-        for (i <- 0 until numSlices) {
+        for ((start, end) <- positions(nr.length, numSlices)) {
+          val sliceSize = end - start
           slices += r.take(sliceSize).asInstanceOf[Seq[T]]
           r = r.drop(sliceSize)
         }
         slices
       }
       case _ => {
         val array = seq.toArray // To prevent O(n^2) operations for List etc
-        (0 until numSlices).map(i => {
-          val start = ((i * array.length.toLong) / numSlices).toInt
-          val end = (((i + 1) * array.length.toLong) / numSlices).toInt
-          array.slice(start, end).toSeq
-        })
+        positions(array.length, numSlices).map({
+          case (start, end) =>
+            array.slice(start, end).toSeq
+        }).toSeq
       }
     }
   }
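The shared positions helper is the heart of the change: every sequence type is now cut at the same boundary indices, computed with integer division as (i * length) / numSlices. Below is a minimal standalone sketch of those boundaries; positions is copied from the patch above, while the PositionsDemo wrapper and its printout are illustrative assumptions, not part of the commit.

object PositionsDemo {
  // Copied from the patch: (start, end) index pairs for each slice.
  // Integer division guarantees the boundaries depend only on the
  // sequence length, never on the element type.
  def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
    (0 until numSlices).iterator.map(i => {
      val start = ((i * length) / numSlices).toInt
      val end = (((i + 1) * length) / numSlices).toInt
      (start, end)
    })
  }

  def main(args: Array[String]): Unit = {
    // 7 elements over 4 slices -> sizes 1, 2, 2, 2 for every sequence type
    positions(7, 4).foreach { case (start, end) =>
      println(s"slice [$start, $end) has ${end - start} element(s)")
    }
  }
}

For length 7 and 4 slices this yields sizes 1, 2, 2, 2. The removed round-up formula (nr.size + numSlices - 1) / numSlices gave NumericRange slices of 2, 2, 2, 1 for the same input, which is exactly the mismatch against a Range sliced as 1, 2, 2, 2.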

core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -111,6 +111,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
     assert(slices.forall(_.isInstanceOf[Range]))
   }
 
+  test("identical slice sizes between Range and NumericRange") {
+    val r = ParallelCollectionRDD.slice(1 to 7, 4)
+    val nr = ParallelCollectionRDD.slice(1L to 7L, 4)
+    assert(r.size === 4)
+    for (i <- 0 until r.size) {
+      assert(r(i).size === nr(i).size)
+    }
+  }
+
+  test("identical slice sizes between List and NumericRange") {
+    val r = ParallelCollectionRDD.slice(List(1, 2), 4)
+    val nr = ParallelCollectionRDD.slice(1L to 2L, 4)
+    assert(r.size === 4)
+    for (i <- 0 until r.size) {
+      assert(r(i).size === nr(i).size)
+    }
+  }
+
   test("large ranges don't overflow") {
     val N = 100 * 1000 * 1000
     val data = 0 until N
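The tests pin down the contract: slicing 1 to 7 and 1L to 7L into 4 parts must produce identical per-slice sizes. A hedged end-to-end sketch of the user-visible behavior this restores, assuming a Spark build that includes this patch; the object name, master URL, and printed output are illustrative assumptions, not part of the commit.

import org.apache.spark.{SparkConf, SparkContext}

object ZipRangesDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("zip-ranges-demo")
    val sc = new SparkContext(conf)
    // An Int Range and a Long NumericRange of the same length, sliced the same way
    val ints = sc.parallelize(1 to 7, 4)
    val longs = sc.parallelize(1L to 7L, 4)
    // zip() requires matching element counts per partition; before this patch
    // the NumericRange slices were sized 2, 2, 2, 1 while the Range slices
    // were 1, 2, 2, 2, so this failed at runtime
    ints.zip(longs).collect().foreach(println) // (1,1), (2,2), ..., (7,7)
    sc.stop()
  }
}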
