Skip to content

Commit d5e7492

Browse files
committed
Use double to simpily the codes
1 parent f471651 commit d5e7492

File tree

2 files changed

+11
-66
lines changed

2 files changed

+11
-66
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala

Lines changed: 6 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -199,52 +199,13 @@ class RateStreamSource(
199199
}
200200

201201
val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
202-
val timeIntervalSizeMs = TimeUnit.SECONDS.toMillis(endSeconds - startSeconds)
202+
val relativeMsPerValue =
203+
TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
203204

204-
val func =
205-
if (timeIntervalSizeMs < rangeEnd - rangeStart) {
206-
// Different rows may have the same timestamp
207-
val valueSizePerMs = (rangeEnd - rangeStart) / timeIntervalSizeMs
208-
val remainderValue = (rangeEnd - rangeStart) % timeIntervalSizeMs
209-
210-
(v: Long) => {
211-
val relativeValue = v - rangeStart
212-
val relativeMs = {
213-
// Increase the timestamp per "valueSizePerMs + 1" values before
214-
// "(valueSizePerMs + 1) * remainderValue", and increase the timestamp per
215-
// "valueSizePerMs" values for remaining values.
216-
217-
// The following condition is the same as
218-
// "relativeValue < (valueSizePerMs + 1) * remainderValue", just rewrite it to avoid
219-
// overflow.
220-
if (relativeValue - remainderValue < valueSizePerMs * remainderValue) {
221-
relativeValue / (valueSizePerMs + 1)
222-
} else {
223-
(relativeValue - remainderValue) / valueSizePerMs
224-
}
225-
}
226-
InternalRow(DateTimeUtils.fromMillis(relativeMs + localStartTimeMs), v)
227-
}
228-
} else {
229-
// Different rows never have the same timestamp
230-
val relativeMsPerValue = timeIntervalSizeMs / (rangeEnd - rangeStart)
231-
val remainderMs = timeIntervalSizeMs % (rangeEnd - rangeStart)
232-
233-
(v: Long) => {
234-
val relativeValue = v - rangeStart
235-
// The interval size for the first "remainderMs" values will be "relativeMsPerValue + 1",
236-
// and the interval size for remaining values will be "relativeMsPerValue".
237-
val relativeMs =
238-
if (relativeValue < remainderMs) {
239-
relativeValue * (relativeMsPerValue + 1)
240-
} else {
241-
remainderMs + relativeValue * relativeMsPerValue
242-
}
243-
InternalRow(DateTimeUtils.fromMillis(relativeMs + localStartTimeMs), v)
244-
}
245-
}
246-
247-
val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map(func)
205+
val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
206+
val relative = math.round((v - rangeStart) * relativeMsPerValue)
207+
InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
208+
}
248209
sqlContext.internalCreateDataFrame(rdd, schema)
249210
}
250211

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,17 @@ class RateSourceSuite extends StreamTest {
5656
)
5757
}
5858

59-
test("uniform distribution of event timestamps: rowsPerSecond > 1000") {
59+
test("uniform distribution of event timestamps") {
6060
val input = spark.readStream
6161
.format("rate")
6262
.option("rowsPerSecond", "1500")
6363
.option("useManualClock", "true")
6464
.load()
6565
.as[(java.sql.Timestamp, Long)]
6666
.map(v => (v._1.getTime, v._2))
67-
val expectedAnswer =
68-
(0 until 1000).map(v => (v / 2, v)) ++ // Two values share the same timestamp.
69-
((1000 until 1500).map(v => (v - 500, v))) // Each value has one timestamp
70-
testStream(input)(
71-
AdvanceRateManualClock(seconds = 1),
72-
CheckLastBatch(expectedAnswer: _*)
73-
)
74-
}
75-
76-
test("uniform distribution of event timestamps: rowsPerSecond < 1000") {
77-
val input = spark.readStream
78-
.format("rate")
79-
.option("rowsPerSecond", "400")
80-
.option("useManualClock", "true")
81-
.load()
82-
.as[(java.sql.Timestamp, Long)]
83-
.map(v => (v._1.getTime, v._2))
84-
val expectedAnswer = (0 until 200).map(v => (v * 3, v)) ++
85-
((200 until 400).map(v => (600 + (v - 200) * 2, v)))
67+
val expectedAnswer = (0 until 1500).map { v =>
68+
(math.round(v * (1000.0 / 1500)), v)
69+
}
8670
testStream(input)(
8771
AdvanceRateManualClock(seconds = 1),
8872
CheckLastBatch(expectedAnswer: _*)
@@ -121,7 +105,7 @@ class RateSourceSuite extends StreamTest {
121105
CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
122106
AdvanceRateManualClock(seconds = 1),
123107
CheckLastBatch({
124-
Seq(2000 -> 6, 2167 -> 7, 2334 -> 8, 2501 -> 9, 2668 -> 10, 2834 -> 11)
108+
Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
125109
}: _*), // speed = 6
126110
AdvanceRateManualClock(seconds = 1),
127111
CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8

0 commit comments

Comments
 (0)