1717
1818package org .apache .spark .streaming .mqtt
1919
20- import java .net .{URI , ServerSocket }
21- import java .util .concurrent .CountDownLatch
22- import java .util .concurrent .TimeUnit
23-
2420import scala .concurrent .duration ._
2521import scala .language .postfixOps
2622
27- import org .apache .activemq .broker .{TransportConnector , BrokerService }
28- import org .apache .commons .lang3 .RandomUtils
29- import org .eclipse .paho .client .mqttv3 ._
30- import org .eclipse .paho .client .mqttv3 .persist .MqttDefaultFilePersistence
31-
32- import org .scalatest .BeforeAndAfter
23+ import org .scalatest .BeforeAndAfterAll
3324import org .scalatest .concurrent .Eventually
3425
35- import org .apache .spark .streaming .{Milliseconds , StreamingContext }
36- import org .apache .spark .storage .StorageLevel
37- import org .apache .spark .streaming .dstream .ReceiverInputDStream
38- import org .apache .spark .streaming .scheduler .StreamingListener
39- import org .apache .spark .streaming .scheduler .StreamingListenerReceiverStarted
4026import org .apache .spark .{SparkConf , SparkFunSuite }
41- import org .apache .spark .util .Utils
42-
43- class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
27+ import org .apache .spark .storage .StorageLevel
28+ import org .apache .spark .streaming .{Milliseconds , StreamingContext }
4429
45- private val batchDuration = Milliseconds (500 )
46- private val master = " local[2]"
47- private val framework = this .getClass.getSimpleName
48- private val freePort = findFreePort()
49- private val brokerUri = " //localhost:" + freePort
50- private val topic = " def"
51- private val persistenceDir = Utils .createTempDir()
/**
 * Integration suite for the MQTT receiver input stream: publishes a message
 * through an embedded broker (managed by `MQTTTestUtils`) and verifies that a
 * Spark Streaming context receives it.
 */
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {

  // Topic the test publishes to and subscribes from.
  private val topic = "topic"

  // Streaming context created inside each test; stopped in afterAll as a safety net.
  private var ssc: StreamingContext = _

  // Broker/test helper created in beforeAll, torn down in afterAll.
  // NOTE(review): the field name shadows its type name `MQTTTestUtils` —
  // lowerCamelCase (e.g. `mqttTestUtils`) would be conventional; renaming
  // requires touching every method that reads the field, so it is only flagged here.
  private var MQTTTestUtils: MQTTTestUtils = _
5635
57- before {
58- ssc = new StreamingContext (master, framework, batchDuration)
59- setupMQTT ()
36+ override def beforeAll () : Unit = {
37+ MQTTTestUtils = new MQTTTestUtils
38+ MQTTTestUtils .setup ()
6039 }
6140
62- after {
41+ override def afterAll () : Unit = {
6342 if (ssc != null ) {
6443 ssc.stop()
6544 ssc = null
6645 }
67- Utils .deleteRecursively(persistenceDir)
68- tearDownMQTT()
46+
47+ if (MQTTTestUtils != null ) {
48+ MQTTTestUtils .teardown()
49+ MQTTTestUtils = null
50+ }
6951 }
7052
7153 test(" mqtt input stream" ) {
54+ val sparkConf = new SparkConf ().setMaster(" local[4]" ).setAppName(this .getClass.getSimpleName)
55+ ssc = new StreamingContext (sparkConf, Milliseconds (500 ))
7256 val sendMessage = " MQTT demo for spark streaming"
7357 val receiveStream =
74- MQTTUtils .createStream(ssc, " tcp:" + brokerUri, topic, StorageLevel .MEMORY_ONLY )
58+ MQTTUtils .createStream(ssc, " tcp:// " + MQTTTestUtils . brokerUri, topic, StorageLevel .MEMORY_ONLY )
7559 @ volatile var receiveMessage : List [String ] = List ()
7660 receiveStream.foreachRDD { rdd =>
7761 if (rdd.collect.length > 0 ) {
@@ -83,85 +67,13 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
8367
8468 // wait for the receiver to start before publishing data, or we risk failing
8569 // the test nondeterministically. See SPARK-4631
86- waitForReceiverToStart()
70+ MQTTTestUtils .waitForReceiverToStart(ssc)
71+
72+ MQTTTestUtils .publishData(topic, sendMessage)
8773
88- publishData(sendMessage)
8974 eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
9075 assert(sendMessage.equals(receiveMessage(0 )))
9176 }
9277 ssc.stop()
9378 }
94-
95- private def setupMQTT () {
96- broker = new BrokerService ()
97- broker.setDataDirectoryFile(Utils .createTempDir())
98- connector = new TransportConnector ()
99- connector.setName(" mqtt" )
100- connector.setUri(new URI (" mqtt:" + brokerUri))
101- broker.addConnector(connector)
102- broker.start()
103- }
104-
105- private def tearDownMQTT () {
106- if (broker != null ) {
107- broker.stop()
108- broker = null
109- }
110- if (connector != null ) {
111- connector.stop()
112- connector = null
113- }
114- }
115-
116- private def findFreePort (): Int = {
117- val candidatePort = RandomUtils .nextInt(1024 , 65536 )
118- Utils .startServiceOnPort(candidatePort, (trialPort : Int ) => {
119- val socket = new ServerSocket (trialPort)
120- socket.close()
121- (null , trialPort)
122- }, new SparkConf ())._2
123- }
124-
125- def publishData (data : String ): Unit = {
126- var client : MqttClient = null
127- try {
128- val persistence = new MqttDefaultFilePersistence (persistenceDir.getAbsolutePath)
129- client = new MqttClient (" tcp:" + brokerUri, MqttClient .generateClientId(), persistence)
130- client.connect()
131- if (client.isConnected) {
132- val msgTopic = client.getTopic(topic)
133- val message = new MqttMessage (data.getBytes(" utf-8" ))
134- message.setQos(1 )
135- message.setRetained(true )
136-
137- for (i <- 0 to 10 ) {
138- try {
139- msgTopic.publish(message)
140- } catch {
141- case e : MqttException if e.getReasonCode == MqttException .REASON_CODE_MAX_INFLIGHT =>
142- // wait for Spark streaming to consume something from the message queue
143- Thread .sleep(50 )
144- }
145- }
146- }
147- } finally {
148- client.disconnect()
149- client.close()
150- client = null
151- }
152- }
153-
154- /**
155- * Block until at least one receiver has started or timeout occurs.
156- */
157- private def waitForReceiverToStart () = {
158- val latch = new CountDownLatch (1 )
159- ssc.addStreamingListener(new StreamingListener {
160- override def onReceiverStarted (receiverStarted : StreamingListenerReceiverStarted ) {
161- latch.countDown()
162- }
163- })
164-
165- assert(latch.await(10 , TimeUnit .SECONDS ), " Timeout waiting for receiver to start." )
166- }
16779}
0 commit comments