@@ -54,7 +54,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
5454 @ volatile
5555 private var _streamName : String = _
5656
57- private lazy val kinesisClient = {
57+ protected lazy val kinesisClient = {
5858 val client = new AmazonKinesisClient (KinesisTestUtils .getAWSCredentials())
5959 client.setEndpoint(endpointUrl)
6060 client
@@ -66,14 +66,11 @@ private[kinesis] class KinesisTestUtils extends Logging {
6666 new DynamoDB (dynamoDBClient)
6767 }
6868
69- /** Left as a protected val so that we don't need to depend on KPL outside of tests. */
70- protected val kplProducer : KinesisProducer = null
71-
72- private def getProducer (aggregate : Boolean ): KinesisProducer = {
73- if (aggregate) {
74- kplProducer
69+ protected def getProducer (aggregate : Boolean ): KinesisDataGenerator = {
70+ if (! aggregate) {
71+ new SimpleDataGenerator (kinesisClient)
7572 } else {
76- new KinesisClientProducer (kinesisClient )
73+ throw new UnsupportedOperationException ( " Aggregation is not supported through this code path " )
7774 }
7875 }
7976
@@ -106,12 +103,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
106103 def pushData (testData : Seq [Int ], aggregate : Boolean ): Map [String , Seq [(Int , String )]] = {
107104 require(streamCreated, " Stream not yet created, call createStream() to create one" )
108105 val producer = getProducer(aggregate)
109-
110- testData.foreach { num =>
111- producer.putRecord(streamName, num)
112- }
113-
114- val shardIdToSeqNumbers = producer.flush()
106+ val shardIdToSeqNumbers = producer.sendData(streamName, testData)
115107 logInfo(s " Pushed $testData: \n\t ${shardIdToSeqNumbers.mkString(" \n\t " )}" )
116108 shardIdToSeqNumbers.toMap
117109 }
@@ -239,31 +231,30 @@ private[kinesis] object KinesisTestUtils {
239231}
240232
241233/** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */
242- private [kinesis] trait KinesisProducer {
243- /** Sends the data to Kinesis possibly with aggregation if KPL is used. */
244- def putRecord (streamName : String , num : Int ): Unit
245-
246- /** Flush all data in the buffer and return the metadata for everything that has been sent. */
247- def flush (): Map [String , Seq [(Int , String )]]
234+ private [kinesis] trait KinesisDataGenerator {
235+ /** Sends the data to Kinesis and returns the metadata for everything that has been sent. */
236+ def sendData (streamName : String , data : Seq [Int ]): Map [String , Seq [(Int , String )]]
248237}
249238
250- private [kinesis] class KinesisClientProducer (client : AmazonKinesisClient ) extends KinesisProducer {
251- private val shardIdToSeqNumbers = new mutable.HashMap [String , ArrayBuffer [(Int , String )]]()
252-
253- override def putRecord (streamName : String , num : Int ): Unit = {
254- val str = num.toString
255- val data = ByteBuffer .wrap(str.getBytes())
256- val putRecordRequest = new PutRecordRequest ().withStreamName(streamName)
257- .withData(data)
258- .withPartitionKey(str)
259-
260- val putRecordResult = client.putRecord(putRecordRequest)
261- val shardId = putRecordResult.getShardId
262- val seqNumber = putRecordResult.getSequenceNumber()
263- val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
264- new ArrayBuffer [(Int , String )]())
265- sentSeqNumbers += ((num, seqNumber))
266- }
239+ private [kinesis] class SimpleDataGenerator (
240+ client : AmazonKinesisClient ) extends KinesisDataGenerator {
241+ override def sendData (streamName : String , data : Seq [Int ]): Map [String , Seq [(Int , String )]] = {
242+ val shardIdToSeqNumbers = new mutable.HashMap [String , ArrayBuffer [(Int , String )]]()
243+ data.foreach { num =>
244+ val str = num.toString
245+ val data = ByteBuffer .wrap(str.getBytes())
246+ val putRecordRequest = new PutRecordRequest ().withStreamName(streamName)
247+ .withData(data)
248+ .withPartitionKey(str)
249+
250+ val putRecordResult = client.putRecord(putRecordRequest)
251+ val shardId = putRecordResult.getShardId
252+ val seqNumber = putRecordResult.getSequenceNumber()
253+ val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
254+ new ArrayBuffer [(Int , String )]())
255+ sentSeqNumbers += ((num, seqNumber))
256+ }
267257
268- override def flush (): Map [String , Seq [(Int , String )]] = shardIdToSeqNumbers.toMap
258+ shardIdToSeqNumbers.toMap
259+ }
269260}
0 commit comments