
Commit 302d68d

brkyvz authored and tdas committed
[SPARK-12058][STREAMING][KINESIS][TESTS] fix Kinesis python tests
Python tests require access to the `KinesisTestUtils` file. When this file lives under src/test, Python can't access it, since it is not available in the assembly jar. However, moving KinesisTestUtils to src/main as-is would make the Kinesis Producer Library (KPL) a dependency of the main module.

To avoid that, this patch moves a KPL-free KinesisTestUtils to src/main and extends it with KPLBasedKinesisTestUtils under src/test, which adds support for the KPL.

cc zsxwing tdas

Author: Burak Yavuz <[email protected]>

Closes #10050 from brkyvz/kinesis-py.
1 parent d0d8222 commit 302d68d

File tree: 5 files changed, +115 −50 lines changed
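At a glance, the patch splits the test utilities so that the KPL dependency stays out of src/main. A condensed sketch of the resulting layout (method bodies abbreviated; members referenced here, such as kinesisClient and regionName, are defined in the full diffs below):

// src/main: KPL-free, so the class ships in the assembly jar that the Python tests load.
private[kinesis] trait KinesisDataGenerator {
  def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]]
}

private[kinesis] class KinesisTestUtils extends Logging {
  // Factory method: the base class only knows the plain PutRecord path.
  protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
    if (!aggregate) {
      new SimpleDataGenerator(kinesisClient)
    } else {
      throw new UnsupportedOperationException("Aggregation is not supported through this code path")
    }
  }
}

// src/test: pulls in the KPL only for JVM tests, by overriding the factory method.
private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
  override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
    if (!aggregate) new SimpleDataGenerator(kinesisClient) else new KPLDataGenerator(regionName)
  }
}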
extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala

Lines changed: 41 additions & 47 deletions

@@ -31,13 +31,13 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient
 import com.amazonaws.services.dynamodbv2.document.DynamoDB
 import com.amazonaws.services.kinesis.AmazonKinesisClient
 import com.amazonaws.services.kinesis.model._
-import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult}
-import com.google.common.util.concurrent.{FutureCallback, Futures}
 
 import org.apache.spark.Logging
 
 /**
- * Shared utility methods for performing Kinesis tests that actually transfer data
+ * Shared utility methods for performing Kinesis tests that actually transfer data.
+ *
+ * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE!
  */
 private[kinesis] class KinesisTestUtils extends Logging {
 
@@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
   @volatile
   private var _streamName: String = _
 
-  private lazy val kinesisClient = {
+  protected lazy val kinesisClient = {
     val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials())
     client.setEndpoint(endpointUrl)
     client
@@ -66,14 +66,12 @@ private[kinesis] class KinesisTestUtils extends Logging {
     new DynamoDB(dynamoDBClient)
   }
 
-  private lazy val kinesisProducer: KinesisProducer = {
-    val conf = new KinesisProducerConfiguration()
-      .setRecordMaxBufferedTime(1000)
-      .setMaxConnections(1)
-      .setRegion(regionName)
-      .setMetricsLevel("none")
-
-    new KinesisProducer(conf)
+  protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
+    if (!aggregate) {
+      new SimpleDataGenerator(kinesisClient)
+    } else {
+      throw new UnsupportedOperationException("Aggregation is not supported through this code path")
+    }
   }
 
   def streamName: String = {
@@ -104,41 +102,8 @@ private[kinesis] class KinesisTestUtils extends Logging {
    */
   def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = {
     require(streamCreated, "Stream not yet created, call createStream() to create one")
-    val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
-
-    testData.foreach { num =>
-      val str = num.toString
-      val data = ByteBuffer.wrap(str.getBytes())
-      if (aggregate) {
-        val future = kinesisProducer.addUserRecord(streamName, str, data)
-        val kinesisCallBack = new FutureCallback[UserRecordResult]() {
-          override def onFailure(t: Throwable): Unit = {} // do nothing
-
-          override def onSuccess(result: UserRecordResult): Unit = {
-            val shardId = result.getShardId
-            val seqNumber = result.getSequenceNumber()
-            val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
-              new ArrayBuffer[(Int, String)]())
-            sentSeqNumbers += ((num, seqNumber))
-          }
-        }
-
-        Futures.addCallback(future, kinesisCallBack)
-        kinesisProducer.flushSync() // make sure we send all data before returning the map
-      } else {
-        val putRecordRequest = new PutRecordRequest().withStreamName(streamName)
-          .withData(data)
-          .withPartitionKey(str)
-
-        val putRecordResult = kinesisClient.putRecord(putRecordRequest)
-        val shardId = putRecordResult.getShardId
-        val seqNumber = putRecordResult.getSequenceNumber()
-        val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
-          new ArrayBuffer[(Int, String)]())
-        sentSeqNumbers += ((num, seqNumber))
-      }
-    }
-
+    val producer = getProducer(aggregate)
+    val shardIdToSeqNumbers = producer.sendData(streamName, testData)
     logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}")
     shardIdToSeqNumbers.toMap
   }
@@ -264,3 +229,32 @@ private[kinesis] object KinesisTestUtils {
     }
   }
 }
+
+/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */
+private[kinesis] trait KinesisDataGenerator {
+  /** Sends the data to Kinesis and returns the metadata for everything that has been sent. */
+  def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]]
+}
+
+private[kinesis] class SimpleDataGenerator(
+    client: AmazonKinesisClient) extends KinesisDataGenerator {
+  override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
+    val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
+    data.foreach { num =>
+      val str = num.toString
+      val data = ByteBuffer.wrap(str.getBytes())
+      val putRecordRequest = new PutRecordRequest().withStreamName(streamName)
+        .withData(data)
+        .withPartitionKey(str)
+
+      val putRecordResult = client.putRecord(putRecordRequest)
+      val shardId = putRecordResult.getShardId
+      val seqNumber = putRecordResult.getSequenceNumber()
+      val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
+        new ArrayBuffer[(Int, String)]())
+      sentSeqNumbers += ((num, seqNumber))
+    }
+
+    shardIdToSeqNumbers.toMap
+  }
+}
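The net effect on callers: pushData keeps its signature, but aggregation now only works through the KPL-backed subclass. A usage sketch (assuming AWS credentials are configured and Kinesis tests are enabled; stream setup details omitted):

val base = new KinesisTestUtils()
base.createStream()
base.pushData(Seq(1, 2, 3), aggregate = false)    // plain PutRecord path, no KPL required
// base.pushData(Seq(1, 2, 3), aggregate = true)  // would throw UnsupportedOperationException

val kpl = new KPLBasedKinesisTestUtils()           // available under src/test only
kpl.createStream()
kpl.pushData(Seq(1, 2, 3), aggregate = true)       // KPL aggregation path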
extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
+import com.google.common.util.concurrent.{FutureCallback, Futures}
+
+private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
+  override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
+    if (!aggregate) {
+      new SimpleDataGenerator(kinesisClient)
+    } else {
+      new KPLDataGenerator(regionName)
+    }
+  }
+}
+
+/** A wrapper for the KinesisProducer provided in the KPL. */
+private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator {
+
+  private lazy val producer: KPLProducer = {
+    val conf = new KinesisProducerConfiguration()
+      .setRecordMaxBufferedTime(1000)
+      .setMaxConnections(1)
+      .setRegion(regionName)
+      .setMetricsLevel("none")
+
+    new KPLProducer(conf)
+  }
+
+  override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
+    val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
+    data.foreach { num =>
+      val str = num.toString
+      val data = ByteBuffer.wrap(str.getBytes())
+      val future = producer.addUserRecord(streamName, str, data)
+      val kinesisCallBack = new FutureCallback[UserRecordResult]() {
+        override def onFailure(t: Throwable): Unit = {} // do nothing
+
+        override def onSuccess(result: UserRecordResult): Unit = {
+          val shardId = result.getShardId
+          val seqNumber = result.getSequenceNumber()
+          val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
+            new ArrayBuffer[(Int, String)]())
+          sentSeqNumbers += ((num, seqNumber))
+        }
+      }
+      Futures.addCallback(future, kinesisCallBack)
+    }
+    producer.flushSync()
+    shardIdToSeqNumbers.toMap
+  }
+}
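One behavioral detail worth noting: the old inlined code called flushSync() once per record inside the loop, whereas KPLDataGenerator queues every record first and flushes once at the end, using the blocking flush as a completion barrier for the callbacks. A minimal sketch of that pattern (records and callback are hypothetical placeholders, not names from the diff):

records.foreach { r =>
  val future = producer.addUserRecord(streamName, r.key, r.payload)  // hypothetical record fields
  Futures.addCallback(future, callback)  // callback captures shard/sequence metadata on success
}
producer.flushSync()  // blocks until all queued records are sent, so the metadata map is complete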

extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
 
   override def beforeAll(): Unit = {
     runIfTestsEnabled("Prepare KinesisTestUtils") {
-      testUtils = new KinesisTestUtils()
+      testUtils = new KPLBasedKinesisTestUtils()
       testUtils.createStream()
 
       shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData)

extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
     sc = new SparkContext(conf)
 
     runIfTestsEnabled("Prepare KinesisTestUtils") {
-      testUtils = new KinesisTestUtils()
+      testUtils = new KPLBasedKinesisTestUtils()
       testUtils.createStream()
     }
   }

python/pyspark/streaming/tests.py

Lines changed: 0 additions & 1 deletion

@@ -1458,7 +1458,6 @@ def test_kinesis_stream_api(self):
             InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2,
             "awsAccessKey", "awsSecretKey")
 
-    @unittest.skip("Enable it when we fix SPAKR-12058")
    def test_kinesis_stream(self):
         if not are_kinesis_tests_enabled:
             sys.stderr.write(
