
Commit 5b45b1b

polish and tests
1 parent 554eafb commit 5b45b1b

File tree

5 files changed: +360 -122 lines


sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 org.apache.spark.sql.execution.datasources.text.TextFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
 org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
+org.apache.spark.sql.execution.streaming.RateSourceProvider
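For context: Spark discovers the data sources listed in this service file through java.util.ServiceLoader, which is what lets a query refer to the new source by its short name. A simplified sketch of that lookup (Spark's actual resolution lives in its internal DataSource code and handles more cases; this is illustrative only):

    import java.util.ServiceLoader
    import scala.collection.JavaConverters._

    import org.apache.spark.sql.sources.DataSourceRegister

    // Scan every provider registered under META-INF/services and pick the one
    // whose short name matches. Illustrative sketch, not Spark's exact code.
    val providers = ServiceLoader.load(classOf[DataSourceRegister]).asScala
    val rateProvider = providers.find(_.shortName() == "rate")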
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import java.io._
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit

import org.apache.commons.io.IOUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types._
import org.apache.spark.util.{ManualClock, SystemClock}

/**
 * A source that generates incrementing long values with timestamps. Each generated row has two
 * columns: a timestamp column for the generation time and an auto-incrementing long column
 * starting at 0L.
 *
 * This source supports the following options:
 * - `tuplesPerSecond` (default: 1): How many tuples should be generated per second.
 * - `rampUpTimeSeconds` (default: 0): How many seconds to ramp up before the generation speed
 *   reaches `tuplesPerSecond`.
 * - `numPartitions` (default: Spark's default parallelism): The number of partitions for the
 *   generated tuples.
 */
class RateSourceProvider extends StreamSourceProvider with DataSourceRegister {

  override def sourceSchema(
      sqlContext: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) =
    (shortName(), RateSourceProvider.SCHEMA)

  override def createSource(
      sqlContext: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    val params = CaseInsensitiveMap(parameters)

    val tuplesPerSecond = params.get("tuplesPerSecond").map(_.toLong).getOrElse(1L)
    if (tuplesPerSecond <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${params("tuplesPerSecond")}' for option 'tuplesPerSecond', " +
          "must be positive")
    }

    val rampUpTimeSeconds = params.get("rampUpTimeSeconds").map(_.toLong).getOrElse(0L)
    if (rampUpTimeSeconds < 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${params("rampUpTimeSeconds")}' for option 'rampUpTimeSeconds', " +
          "must not be negative")
    }

    val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
      sqlContext.sparkContext.defaultParallelism)
    if (numPartitions <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${params("numPartitions")}' for option 'numPartitions', " +
          "must be positive")
    }

    new RateStreamSource(
      sqlContext,
      metadataPath,
      tuplesPerSecond,
      rampUpTimeSeconds,
      numPartitions,
      params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
    )
  }
  override def shortName(): String = "rate"
}

object RateSourceProvider {
  val SCHEMA =
    StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)

  val VERSION = 1
}

class RateStreamSource(
    sqlContext: SQLContext,
    metadataPath: String,
    tuplesPerSecond: Long,
    rampUpTimeSeconds: Long,
    numPartitions: Int,
    useManualClock: Boolean) extends Source with Logging {

  import RateSourceProvider._

  val clock = if (useManualClock) new ManualClock else new SystemClock

  private val maxSeconds = Long.MaxValue / tuplesPerSecond

  if (rampUpTimeSeconds > maxSeconds) {
    throw new ArithmeticException("integer overflow. Max offset with tuplesPerSecond " +
      s"$tuplesPerSecond is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
  }

  private val startTimeMs = {
    val metadataLog =
      new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
        override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
          writer.write("v" + VERSION + "\n")
          writer.write(metadata.json)
          writer.flush()
        }

        override def deserialize(in: InputStream): LongOffset = {
          val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
          // HDFSMetadataLog guarantees that it never creates a partial file.
          assert(content.length != 0)
          if (content(0) == 'v') {
            val indexOfNewLine = content.indexOf("\n")
            if (indexOfNewLine > 0) {
              val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
              LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
            } else {
              throw new IllegalStateException(
                s"Log file was malformed: failed to detect the log file version line.")
            }
          } else {
            throw new IllegalStateException(
              s"Log file was malformed: failed to detect the log file version line.")
          }
        }
      }

    metadataLog.get(0).getOrElse {
      val offset = LongOffset(clock.getTimeMillis())
      metadataLog.add(0, offset)
      logInfo(s"Start time: $offset")
      offset
    }.offset
  }
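  // Note: the metadata log above stores a "v1" version line followed by the
  // offset JSON; persisting the start time lets a restarted query keep the
  // same value-to-timestamp mapping across recovery.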

  /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
  @volatile private var lastTimeMs = startTimeMs

  override def schema: StructType = RateSourceProvider.SCHEMA

  override def getOffset: Option[Offset] = {
    val now = clock.getTimeMillis()
    if (lastTimeMs < now) {
      lastTimeMs = now
    }
    Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
  }
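  // Offsets for this source are whole seconds elapsed since startTimeMs, so a
  // batch from `start` to `end` covers the rows generated in that time window.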

  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
    val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
    assert(startSeconds <= endSeconds)
    if (endSeconds > maxSeconds) {
      throw new ArithmeticException("integer overflow. Max offset with " +
        s"tuplesPerSecond $tuplesPerSecond is $maxSeconds, but it's $endSeconds now.")
    }
    // Fix "lastTimeMs" for recovery
    if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
      lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
    }
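    // During ramp-up, the cumulative row count at second s is
    // rint(tuplesPerSecond * s / rampUpTimeSeconds) * s; after ramp-up it is
    // simply s * tuplesPerSecond. For example, with tuplesPerSecond = 10 and
    // rampUpTimeSeconds = 5, second 2 ends at row rint(10 * 2 / 5) * 2 = 8.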
    val (rangeStart, rangeEnd) = if (rampUpTimeSeconds > endSeconds) {
      (math.rint(tuplesPerSecond * (startSeconds * 1.0 / rampUpTimeSeconds)).toLong * startSeconds,
        math.rint(tuplesPerSecond * (endSeconds * 1.0 / rampUpTimeSeconds)).toLong * endSeconds)
    } else if (startSeconds < rampUpTimeSeconds) {
      (math.rint(tuplesPerSecond * (startSeconds * 1.0 / rampUpTimeSeconds)).toLong * startSeconds,
        endSeconds * tuplesPerSecond)
    } else {
      (startSeconds * tuplesPerSecond, endSeconds * tuplesPerSecond)
    }
    logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
      s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
    val localStartTimeMs = startTimeMs
    val localPerSecond = tuplesPerSecond

    val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
      val relative = v * 1000L / localPerSecond
      InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
    }
    sqlContext.internalCreateDataFrame(rdd, schema)
  }

  override def stop(): Unit = {}
}
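
Because the provider registers the short name "rate", a streaming query can refer to it without the fully qualified class name. A minimal usage sketch (assuming a running SparkSession named `spark` with this source on the classpath; the option values are illustrative):

    // Minimal sketch, assuming a SparkSession named `spark`; the "rate" short
    // name and the option names come from RateSourceProvider above.
    val rateStream = spark.readStream
      .format("rate")
      .option("tuplesPerSecond", "10")   // target rate after ramp-up
      .option("rampUpTimeSeconds", "5")  // ramp up over the first 5 seconds
      .option("numPartitions", "2")      // partitions for the generated rows
      .load()                            // schema: timestamp (TimestampType), value (LongType)

    // Print the generated rows to the console for a quick smoke test.
    val query = rateStream.writeStream.format("console").start()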

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimeSource.scala

Lines changed: 0 additions & 122 deletions
This file was deleted.
