Skip to content

Commit 0192d2d

Browse files
committed
address
1 parent 9c208d7 commit 0192d2d

File tree

4 files changed

+55
-64
lines changed

4 files changed

+55
-64
lines changed

extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
5454
@volatile
5555
private var _streamName: String = _
5656

57-
private lazy val kinesisClient = {
57+
protected lazy val kinesisClient = {
5858
val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials())
5959
client.setEndpoint(endpointUrl)
6060
client
@@ -66,14 +66,11 @@ private[kinesis] class KinesisTestUtils extends Logging {
6666
new DynamoDB(dynamoDBClient)
6767
}
6868

69-
/** Left as a protected val so that we don't need to depend on KPL outside of tests. */
70-
protected val kplProducer: KinesisProducer = null
71-
72-
private def getProducer(aggregate: Boolean): KinesisProducer = {
73-
if (aggregate) {
74-
kplProducer
69+
protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
70+
if (!aggregate) {
71+
new SimpleDataGenerator(kinesisClient)
7572
} else {
76-
new KinesisClientProducer(kinesisClient)
73+
throw new UnsupportedOperationException("Aggregation is not supported through this code path")
7774
}
7875
}
7976

@@ -106,12 +103,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
106103
def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = {
107104
require(streamCreated, "Stream not yet created, call createStream() to create one")
108105
val producer = getProducer(aggregate)
109-
110-
testData.foreach { num =>
111-
producer.putRecord(streamName, num)
112-
}
113-
114-
val shardIdToSeqNumbers = producer.flush()
106+
val shardIdToSeqNumbers = producer.sendData(streamName, testData)
115107
logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}")
116108
shardIdToSeqNumbers.toMap
117109
}
@@ -239,31 +231,30 @@ private[kinesis] object KinesisTestUtils {
239231
}
240232

241233
/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */
242-
private[kinesis] trait KinesisProducer {
243-
/** Sends the data to Kinesis possibly with aggregation if KPL is used. */
244-
def putRecord(streamName: String, num: Int): Unit
245-
246-
/** Flush all data in the buffer and return the metadata for everything that has been sent. */
247-
def flush(): Map[String, Seq[(Int, String)]]
234+
private[kinesis] trait KinesisDataGenerator {
235+
/** Sends the data to Kinesis and returns the metadata for everything that has been sent. */
236+
def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]]
248237
}
249238

250-
private[kinesis] class KinesisClientProducer(client: AmazonKinesisClient) extends KinesisProducer {
251-
private val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
252-
253-
override def putRecord(streamName: String, num: Int): Unit = {
254-
val str = num.toString
255-
val data = ByteBuffer.wrap(str.getBytes())
256-
val putRecordRequest = new PutRecordRequest().withStreamName(streamName)
257-
.withData(data)
258-
.withPartitionKey(str)
259-
260-
val putRecordResult = client.putRecord(putRecordRequest)
261-
val shardId = putRecordResult.getShardId
262-
val seqNumber = putRecordResult.getSequenceNumber()
263-
val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
264-
new ArrayBuffer[(Int, String)]())
265-
sentSeqNumbers += ((num, seqNumber))
266-
}
239+
private[kinesis] class SimpleDataGenerator(
240+
client: AmazonKinesisClient) extends KinesisDataGenerator {
241+
override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
242+
val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
243+
data.foreach { num =>
244+
val str = num.toString
245+
val data = ByteBuffer.wrap(str.getBytes())
246+
val putRecordRequest = new PutRecordRequest().withStreamName(streamName)
247+
.withData(data)
248+
.withPartitionKey(str)
249+
250+
val putRecordResult = client.putRecord(putRecordRequest)
251+
val shardId = putRecordResult.getShardId
252+
val seqNumber = putRecordResult.getSequenceNumber()
253+
val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
254+
new ArrayBuffer[(Int, String)]())
255+
sentSeqNumbers += ((num, seqNumber))
256+
}
267257

268-
override def flush(): Map[String, Seq[(Int, String)]] = shardIdToSeqNumbers.toMap
258+
shardIdToSeqNumbers.toMap
259+
}
269260
}
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,18 @@ import scala.collection.mutable.ArrayBuffer
2424
import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
2525
import com.google.common.util.concurrent.{FutureCallback, Futures}
2626

27-
private[kinesis] class ExtendedKinesisTestUtils extends KinesisTestUtils {
28-
override protected val kplProducer: KinesisProducer = {
29-
new KinesisProducerLibraryProducer(regionName)
27+
private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
28+
override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
29+
if (!aggregate) {
30+
new SimpleDataGenerator(kinesisClient)
31+
} else {
32+
new KPLDataGenerator(regionName)
33+
}
3034
}
3135
}
3236

3337
/** A wrapper for the KinesisProducer provided in the KPL. */
34-
private[kinesis] class KinesisProducerLibraryProducer(regionName: String) extends KinesisProducer {
38+
private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator {
3539

3640
private lazy val producer: KPLProducer = {
3741
val conf = new KinesisProducerConfiguration()
@@ -43,30 +47,26 @@ private[kinesis] class KinesisProducerLibraryProducer(regionName: String) extend
4347
new KPLProducer(conf)
4448
}
4549

46-
private val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
47-
48-
override def putRecord(streamName: String, num: Int): Unit = {
49-
val str = num.toString
50-
val data = ByteBuffer.wrap(str.getBytes())
51-
val future = producer.addUserRecord(streamName, str, data)
52-
val kinesisCallBack = new FutureCallback[UserRecordResult]() {
53-
override def onFailure(t: Throwable): Unit = {} // do nothing
50+
override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
51+
val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
52+
data.foreach { num =>
53+
val str = num.toString
54+
val data = ByteBuffer.wrap(str.getBytes())
55+
val future = producer.addUserRecord(streamName, str, data)
56+
val kinesisCallBack = new FutureCallback[UserRecordResult]() {
57+
override def onFailure(t: Throwable): Unit = {} // do nothing
5458

55-
override def onSuccess(result: UserRecordResult): Unit = {
56-
val shardId = result.getShardId
57-
val seqNumber = result.getSequenceNumber()
58-
val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
59-
new ArrayBuffer[(Int, String)]())
60-
sentSeqNumbers += ((num, seqNumber))
59+
override def onSuccess(result: UserRecordResult): Unit = {
60+
val shardId = result.getShardId
61+
val seqNumber = result.getSequenceNumber()
62+
val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
63+
new ArrayBuffer[(Int, String)]())
64+
sentSeqNumbers += ((num, seqNumber))
65+
}
6166
}
67+
Futures.addCallback(future, kinesisCallBack)
6268
}
63-
64-
Futures.addCallback(future, kinesisCallBack)
6569
producer.flushSync()
66-
}
67-
68-
override def flush(): Map[String, Seq[(Int, String)]] = {
6970
shardIdToSeqNumbers.toMap
7071
}
71-
7272
}

extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
4040

4141
override def beforeAll(): Unit = {
4242
runIfTestsEnabled("Prepare KinesisTestUtils") {
43-
testUtils = new ExtendedKinesisTestUtils()
43+
testUtils = new KPLBasedKinesisTestUtils()
4444
testUtils.createStream()
4545

4646
shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData)

extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
6363
sc = new SparkContext(conf)
6464

6565
runIfTestsEnabled("Prepare KinesisTestUtils") {
66-
testUtils = new ExtendedKinesisTestUtils()
66+
testUtils = new KPLBasedKinesisTestUtils()
6767
testUtils.createStream()
6868
}
6969
}

0 commit comments

Comments
 (0)