From 32e66fb2b6c8828bdc8b4f3f810df456d09cf580 Mon Sep 17 00:00:00 2001
From: Addison Higham
Date: Tue, 16 Feb 2016 16:45:13 -0700
Subject: [PATCH 1/4] [SPARK-13367] Refactor KinesisUtils to specify more KCL options

This patch refactors KinesisUtils and adds new APIs so that it is easy to pass more configuration options into the KCL, such as using a different DynamoDBClient with a different endpoint.

The core of this refactoring/change is to introduce the `KinesisConfig` class, which is intended to encapsulate all configuration concerns. Currently, it doesn't do much more than allow for setting the DynamoDB endpoint, but it is a good place where other options could be contained without changing underlying APIs. It could also be extended to override behavior for more options without requiring any API changes (such as overriding `buildKCLConfig` to allow for more configuration changes).

This also introduces a new external API for creating a Kinesis DStream, which takes a `KinesisConfig` object and may be a more manageable API going forward.

Docs for the new class are still lacking, as are some basic unit specs.

--- .../kinesis/KinesisBackedBlockRDD.scala | 26 +--- .../streaming/kinesis/KinesisConfig.scala | 140 ++++++++++++++++++ .../kinesis/KinesisInputDStream.scala | 18 +-- .../streaming/kinesis/KinesisReceiver.scala | 66 +-------- .../streaming/kinesis/KinesisTestUtils.scala | 8 + .../streaming/kinesis/KinesisUtils.scala | 93 ++++++++++-- .../kinesis/JavaKinesisStreamSuite.java | 2 - .../kinesis/KinesisBackedBlockRDDSuite.scala | 15 +- .../kinesis/KinesisStreamSuite.scala | 12 +- 9 files changed, 260 insertions(+), 120 deletions(-) create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 3996f168e69ee..117ed81f92523 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -21,8 +21,6 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ @@ -71,14 +69,12 @@ class KinesisBackedBlockRDDPartition( private[kinesis] class KinesisBackedBlockRDD[T: ClassTag]( sc: SparkContext, - val regionName: String, - val endpointUrl: String, + val config: KinesisConfig, @transient private val _blockIds: Array[BlockId], @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, - val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val awsCredentialsOption: Option[SerializableAWSCredentials] = None + val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _ ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -104,12 +100,8 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = awsCredentialsOption.getOrElse { - new
DefaultAWSCredentialsProviderChain().getCredentials() - } partition.seqNumberRanges.ranges.iterator.flatMap { range => - new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, - range, retryTimeoutMs).map(messageHandler) + new KinesisSequenceRangeIterator(config, range, retryTimeoutMs).map(messageHandler) } } if (partition.isBlockIdValid) { @@ -128,13 +120,11 @@ class KinesisBackedBlockRDD[T: ClassTag]( */ private[kinesis] class KinesisSequenceRangeIterator( - credentials: AWSCredentials, - endpointUrl: String, - regionId: String, + config: KinesisConfig, range: SequenceNumberRange, retryTimeoutMs: Int) extends NextIterator[Record] with Logging { - private val client = new AmazonKinesisClient(credentials) + private val client = config.buildKinesisClient() private val streamName = range.streamName private val shardId = range.shardId @@ -142,7 +132,7 @@ class KinesisSequenceRangeIterator( private var lastSeqNumber: String = null private var internalIterator: Iterator[Record] = null - client.setEndpoint(endpointUrl, "kinesis", regionId) + client.setEndpoint(config.endpointUrl) override protected def getNext(): Record = { var nextRecord: Record = null @@ -205,7 +195,7 @@ class KinesisSequenceRangeIterator( private def getRecordsAndNextKinesisIterator( shardIterator: String): (Iterator[Record], String) = { val getRecordsRequest = new GetRecordsRequest - getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setRequestCredentials(config.awsCreds) getRecordsRequest.setShardIterator(shardIterator) val getRecordsResult = retryOrTimeout[GetRecordsResult]( s"getting records using shard iterator") { @@ -225,7 +215,7 @@ class KinesisSequenceRangeIterator( iteratorType: ShardIteratorType, sequenceNumber: String): String = { val getShardIteratorRequest = new GetShardIteratorRequest - getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setRequestCredentials(config.awsCreds) getShardIteratorRequest.setStreamName(streamName) getShardIteratorRequest.setShardId(shardId) getShardIteratorRequest.setShardIteratorType(iteratorType.toString) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala new file mode 100644 index 0000000000000..aaa3ed51784fe --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.streaming.kinesis + +import scala.reflect.ClassTag + + +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} +import com.amazonaws.regions.{RegionUtils, Region} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} +import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient + + +case class KinesisConfig( + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream = InitialPositionInStream.TRIM_HORIZON, + awsCredentialsOption: Option[SerializableAWSCredentials] = None, + dynamoEndpointUrl: Option[String] = None, + dynamoCredentials: Option[SerializableAWSCredentials] = None + ) { + + def buildKCLConfig(workerId: String): KinesisClientLibConfiguration = { + // KCL config instance + val kinesisClientLibConfiguration = + new KinesisClientLibConfiguration(kinesisAppName, streamName, resolveAWSCredentialsProvider(), workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) + return kinesisClientLibConfiguration + + } + + def region: Region = { + RegionUtils.getRegion(regionName) + } + + def buildDynamoClient(): AmazonDynamoDBClient = { + val client = if (dynamoCredentials.isDefined) new AmazonDynamoDBClient(resolveAWSCredentialsProvider(dynamoCredentials)) else new AmazonDynamoDBClient(resolveAWSCredentialsProvider()) + client.setRegion(region) + if (dynamoEndpointUrl.isDefined) { + client.setEndpoint(dynamoEndpointUrl.get) + } + client + } + + def buildKinesisClient(): AmazonKinesisClient = { + val client = new AmazonKinesisClient(resolveAWSCredentialsProvider()) + client.setRegion(region) + client.setEndpoint(endpointUrl) + client + + } + + def buildCloudwatchClient(): AmazonCloudWatchClient = { + val client = new AmazonCloudWatchClient(resolveAWSCredentialsProvider()) + client.setRegion(region) + client + + } + + def awsCreds: AWSCredentials = { + awsCredentialsOption.getOrElse(new DefaultAWSCredentialsProviderChain().getCredentials()) + + } + + + /** + * If AWS credential is provided, return a AWSCredentialProvider returning that credential. + * Otherwise, return the DefaultAWSCredentialsProviderChain. 
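 *
 * A sketch of the two cases (the app/stream names, endpoint, and key strings below are
 * placeholders, not values from this patch):
 * {{{
 *   // explicit credentials: the resolved provider always returns exactly these keys
 *   KinesisConfig("myApp", "myStream", "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
 *     awsCredentialsOption = Some(SerializableAWSCredentials("ACCESS_KEY", "SECRET_KEY")))
 *   // no credentials given: falls back to the DefaultAWSCredentialsProviderChain
 *   KinesisConfig("myApp", "myStream", "https://kinesis.us-east-1.amazonaws.com", "us-east-1")
 * }}}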
+ */ + private def resolveAWSCredentialsProvider(awsCredOpt: Option[SerializableAWSCredentials] = awsCredentialsOption): AWSCredentialsProvider = { + awsCredOpt match { + case Some(awsCredentials) => + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = awsCredentials + override def refresh(): Unit = { } + } + case None => + new DefaultAWSCredentialsProviderChain() + } + } + +} + +case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) + extends AWSCredentials { + override def getAWSAccessKeyId: String = accessKeyId + override def getAWSSecretKey: String = secretKey +} + +private object KinesisConfig { + + + /* + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * + */ + def buildConfig( + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream = InitialPositionInStream.TRIM_HORIZON, + awsCredentialsOption: Option[SerializableAWSCredentials] = None): KinesisConfig = { + new KinesisConfig(kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, awsCredentialsOption) + } + +} \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 5223c81a8e0e0..62b13a49ee627 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -19,7 +19,6 @@ package org.apache.spark.streaming.kinesis import scala.reflect.ClassTag -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.Record import org.apache.spark.rdd.RDD @@ -31,15 +30,10 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo private[kinesis] class KinesisInputDStream[T: ClassTag]( _ssc: StreamingContext, - streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointAppName: String, + config: KinesisConfig, checkpointInterval: Duration, storageLevel: StorageLevel, - messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials] + messageHandler: Record => T ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -57,11 +51,10 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " + s"seq number ranges: ${seqNumRanges.mkString(", ")} ") new KinesisBackedBlockRDD( - context.sc, regionName, endpointUrl, blockIds, seqNumRanges, + context.sc, config, blockIds, seqNumRanges, isBlockIdValid = isBlockIdValid, 
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, - messageHandler = messageHandler, - awsCredentialsOption = awsCredentialsOption) + messageHandler = messageHandler) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -70,7 +63,6 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( } override def getReceiver(): Receiver[T] = { - new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption) + new KinesisReceiver(config, checkpointInterval, storageLevel, messageHandler) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 48ee2a959786b..fbe55996809b4 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,9 +23,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker import com.amazonaws.services.kinesis.model.Record import org.apache.spark.storage.{StorageLevel, StreamBlockId} @@ -34,12 +33,6 @@ import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListen import org.apache.spark.util.Utils import org.apache.spark.Logging -private[kinesis] -case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) - extends AWSCredentials { - override def getAWSAccessKeyId: String = accessKeyId - override def getAWSSecretKey: String = secretKey -} /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. @@ -60,37 +53,17 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence * number for it own shard. * - * @param streamName Kinesis stream name - * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) - * @param regionName Region name used by the Kinesis Client Library for - * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. - * The values are either the beginning of the stream - * per Kinesis' limit of 24 hours - * (InitialPositionInStream.TRIM_HORIZON) or - * the tip of the stream (InitialPositionInStream.LATEST). - * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams - * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing - * DynamoDB table with the same name this Kinesis application. 
+ * @param config SparkKinesisConfig object, * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies - * the credentials */ private[kinesis] class KinesisReceiver[T]( - val streamName: String, - endpointUrl: String, - regionName: String, - initialPositionInStream: InitialPositionInStream, - checkpointAppName: String, + config: KinesisConfig, checkpointInterval: Duration, storageLevel: StorageLevel, - messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials]) + messageHandler: Record => T) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -147,14 +120,6 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) - // KCL config instance - val awsCredProvider = resolveAWSCredentialsProvider() - val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) - .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream) - .withTaskBackoffTimeMillis(500) - .withRegionName(regionName) /* * RecordProcessorFactory creates impls of IRecordProcessor. @@ -167,7 +132,8 @@ private[kinesis] class KinesisReceiver[T]( new KinesisRecordProcessor(receiver, workerId) } - worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) + worker = new Worker(recordProcessorFactory, config.buildKCLConfig(workerId), + config.buildKinesisClient(), config.buildDynamoClient(), config.buildCloudwatchClient()) workerThread = new Thread() { override def run(): Unit = { try { @@ -215,7 +181,7 @@ private[kinesis] class KinesisReceiver[T]( private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { if (records.size > 0) { val dataIterator = records.iterator().asScala.map(messageHandler) - val metadata = SequenceNumberRange(streamName, shardId, + val metadata = SequenceNumberRange(config.streamName, shardId, records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) } @@ -299,24 +265,6 @@ private[kinesis] class KinesisReceiver[T]( } } - /** - * If AWS credential is provided, return a AWSCredentialProvider returning that credential. - * Otherwise, return the DefaultAWSCredentialsProviderChain. - */ - private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { - awsCredentialsOption match { - case Some(awsCredentials) => - logInfo("Using provided AWS credentials") - new AWSCredentialsProvider { - override def getCredentials: AWSCredentials = awsCredentials - override def refresh(): Unit = { } - } - case None => - logInfo("Using DefaultAWSCredentialsProviderChain") - new DefaultAWSCredentialsProviderChain() - } - } - /** * Class to handle blocks generated by this receiver's block generator. 
Specifically, in diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0ace453ee9280..9327f7a893e83 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.regions.RegionUtils import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB @@ -66,6 +67,10 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } + lazy val sparkKinesisConfig: KinesisConfig = { + KinesisConfig.buildConfig("kinesis-asl-unit-test", _streamName, endpointUrl, regionName, InitialPositionInStream.TRIM_HORIZON) + } + protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { if (!aggregate) { new SimpleDataGenerator(kinesisClient) @@ -74,11 +79,13 @@ private[kinesis] class KinesisTestUtils extends Logging { } } + def streamName: String = { require(streamCreated, "Stream not yet created, call createStream() to create one") _streamName } + def createStream(): Unit = { require(!streamCreated, "Stream already created") _streamName = findNonExistentStreamName() @@ -228,6 +235,7 @@ private[kinesis] object KinesisTestUtils { """.stripMargin) } } + } /** A wrapper interface that will allow us to consolidate the code for synthetic data generation. */ diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 15ac588b82587..7c65cd0ce2fa4 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -27,8 +27,38 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, StreamingContext} import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.util.Utils object KinesisUtils { + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: the + * + * @param ssc StreamingContext object + * @param config SparkKinesisConfig object, + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. 
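 *
 * A minimal sketch of a custom handler (the checkpoint interval and storage level are
 * arbitrary placeholders) that keeps only each record's partition key:
 * {{{
 *   val keys = KinesisUtils.createStream(
 *     ssc, config, Seconds(10), StorageLevel.MEMORY_AND_DISK_2,
 *     (record: Record) => record.getPartitionKey())
 * }}}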
+ */ + def createStream[T: ClassTag]( + ssc: StreamingContext, + config: KinesisConfig, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T): ReceiverInputDStream[T] = { + val cleanedHandler = ssc.sc.clean(messageHandler) + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[T](ssc, config, checkpointInterval, storageLevel, cleanedHandler) + } + } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -70,13 +100,15 @@ object KinesisUtils { messageHandler: Record => T): ReceiverInputDStream[T] = { val cleanedHandler = ssc.sc.clean(messageHandler) // Setting scope to override receiver stream's scope of "receiver stream" + val kinesisClientConfig = KinesisConfig.buildConfig(kinesisAppName, streamName, + endpointUrl, validateRegion(regionName), initialPositionInStream, None) ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, None) + new KinesisInputDStream[T](ssc, kinesisClientConfig, + checkpointInterval, storageLevel, cleanedHandler) } } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -123,10 +155,36 @@ object KinesisUtils { awsSecretKey: String): ReceiverInputDStream[T] = { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) + val kinesisClientConfig = KinesisConfig.buildConfig(kinesisAppName, streamName, + endpointUrl, validateRegion(regionName), initialPositionInStream, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + ssc.withNamedScope("kinesis stream") { + new KinesisInputDStream[T](ssc, kinesisClientConfig, + checkpointInterval, storageLevel, cleanedHandler) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: the + * + * @param ssc StreamingContext object + * @param config SparkKinesisConfig object, + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. 
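 *
 * A minimal usage sketch (the application name, stream name, endpoints, and region are
 * placeholders), including the optional DynamoDB endpoint override this config enables:
 * {{{
 *   val config = KinesisConfig("myApp", "myStream",
 *     "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
 *     dynamoEndpointUrl = Some("http://localhost:8000"))
 *   val stream = KinesisUtils.createStream(ssc, config, Seconds(10),
 *     StorageLevel.MEMORY_AND_DISK_2)
 * }}}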
+ */ + def createStream( + ssc: StreamingContext, + config: KinesisConfig, + checkpointInterval: Duration, + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { + // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + new KinesisInputDStream[Array[Byte]](ssc, config, checkpointInterval, storageLevel, defaultMessageHandler) } } @@ -167,10 +225,11 @@ object KinesisUtils { checkpointInterval: Duration, storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { // Setting scope to override receiver stream's scope of "receiver stream" + val kinesisClientConfig = KinesisConfig.buildConfig(kinesisAppName, streamName, + endpointUrl, validateRegion(regionName), initialPositionInStream, None) ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, None) + new KinesisInputDStream[Array[Byte]](ssc, kinesisClientConfig, + checkpointInterval, storageLevel, defaultMessageHandler) } } @@ -214,10 +273,11 @@ object KinesisUtils { storageLevel: StorageLevel, awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { + val kinesisClientConfig = KinesisConfig.buildConfig(kinesisAppName, streamName, + endpointUrl, validateRegion(regionName), initialPositionInStream, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), - initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + new KinesisInputDStream[Array[Byte]](ssc, kinesisClientConfig, + checkpointInterval, storageLevel, defaultMessageHandler) } } @@ -259,10 +319,11 @@ object KinesisUtils { initialPositionInStream: InitialPositionInStream, storageLevel: StorageLevel ): ReceiverInputDStream[Array[Byte]] = { + val kinesisClientConfig = KinesisConfig.buildConfig(ssc.sc.appName, streamName, + endpointUrl, getRegionByEndpoint(endpointUrl), initialPositionInStream) ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, - getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName, - checkpointInterval, storageLevel, defaultMessageHandler, None) + new KinesisInputDStream[Array[Byte]](ssc, kinesisClientConfig, + checkpointInterval, storageLevel, defaultMessageHandler) } } @@ -558,3 +619,5 @@ private class KinesisUtilsPythonHelper { } } + + diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index 3f0f6793d2d21..5c2371c5430b3 100644 --- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -28,8 +28,6 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; -import java.nio.ByteBuffer; - /** * 
Demonstrate the use of the KinesisUtils Java API */ diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index e916f1ee0893b..ec7bdb8d8cac1 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -28,6 +28,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) private val testData = 1 to 8 private var testUtils: KinesisTestUtils = null + private var sparkKinesisConfig: KinesisConfig = null private var shardIds: Seq[String] = null private var shardIdToData: Map[String, Seq[Int]] = null private var shardIdToSeqNumbers: Map[String, Seq[String]] = null @@ -42,6 +43,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) runIfTestsEnabled("Prepare KinesisTestUtils") { testUtils = new KPLBasedKinesisTestUtils() testUtils.createStream() + sparkKinesisConfig = testUtils.sparkKinesisConfig shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") @@ -73,22 +75,19 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) testIfEnabled("Basic reading from Kinesis") { // Verify all data using multiple ranges in a single RDD partition - val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, - testUtils.endpointUrl, fakeBlockIds(1), + val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, sparkKinesisConfig, fakeBlockIds(1), Array(SequenceNumberRanges(allRanges.toArray)) ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData1.toSet === testData.toSet) // Verify all data using one range in each of the multiple RDD partitions - val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, - testUtils.endpointUrl, fakeBlockIds(allRanges.size), + val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, sparkKinesisConfig, fakeBlockIds(allRanges.size), allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray ).map { bytes => new String(bytes).toInt }.collect() assert(receivedData2.toSet === testData.toSet) // Verify ordering within each partition - val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName, - testUtils.endpointUrl, fakeBlockIds(allRanges.size), + val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, sparkKinesisConfig, fakeBlockIds(allRanges.size), allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray ).map { bytes => new String(bytes).toInt }.collectPartitions() assert(receivedData3.length === allRanges.size) @@ -210,7 +209,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) ) val rdd = new KinesisBackedBlockRDD[Array[Byte]]( - sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges) + sc, sparkKinesisConfig, blockIds, ranges) val collectedData = rdd.map { bytes => new String(bytes).toInt }.collect() @@ -224,7 +223,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") val rdd2 = new KinesisBackedBlockRDD[Array[Byte]]( - sc, 
testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges, + sc, sparkKinesisConfig, blockIds.toArray, ranges, isBlockIdValid = Array.fill(blockIds.length)(false)) intercept[SparkException] { rdd2.collect() diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ee6a5f0390d04..e3d9af315533d 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -139,11 +139,13 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[_]] val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[_]] - assert(kinesisRDD.regionName === dummyRegionName) - assert(kinesisRDD.endpointUrl === dummyEndpointUrl) + assert(kinesisRDD.config.regionName === dummyRegionName) + assert(kinesisRDD.config.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.awsCredentialsOption === - Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(kinesisRDD.config.awsCreds.getAWSAccessKeyId() === + dummyAWSAccessKey) + assert(kinesisRDD.config.awsCreds.getAWSSecretKey() === + dummyAWSSecretKey) assert(nonEmptyRDD.partitions.size === blockInfos.size) nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } val partitions = nonEmptyRDD.partitions.map { @@ -154,7 +156,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Verify that KinesisBackedBlockRDD is generated even when there are no blocks val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty) - emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + emptyRDD shouldBe a [KinesisBackedBlockRDD[_]] emptyRDD.partitions shouldBe empty // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid From 6aec426331f397792ef9efd7249001d0c367988d Mon Sep 17 00:00:00 2001 From: Addison Higham Date: Fri, 19 Feb 2016 13:24:27 -0700 Subject: [PATCH 2/4] Add docs and tests --- .../kinesis/KinesisBackedBlockRDD.scala | 4 +- .../streaming/kinesis/KinesisConfig.scala | 70 ++++++++++++++++--- .../kinesis/KinesisConfigSuite.scala | 56 +++++++++++++++ .../kinesis/KinesisStreamSuite.scala | 4 +- 4 files changed, 121 insertions(+), 13 deletions(-) create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 117ed81f92523..c1e81314154e9 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -195,7 +195,7 @@ class KinesisSequenceRangeIterator( private def getRecordsAndNextKinesisIterator( shardIterator: String): (Iterator[Record], String) = { val getRecordsRequest = new GetRecordsRequest - getRecordsRequest.setRequestCredentials(config.awsCreds) + getRecordsRequest.setRequestCredentials(config.awsCredentials) 
getRecordsRequest.setShardIterator(shardIterator) val getRecordsResult = retryOrTimeout[GetRecordsResult]( s"getting records using shard iterator") { @@ -215,7 +215,7 @@ class KinesisSequenceRangeIterator( iteratorType: ShardIteratorType, sequenceNumber: String): String = { val getShardIteratorRequest = new GetShardIteratorRequest - getShardIteratorRequest.setRequestCredentials(config.awsCreds) + getShardIteratorRequest.setRequestCredentials(config.awsCredentials) getShardIteratorRequest.setStreamName(streamName) getShardIteratorRequest.setShardId(shardId) getShardIteratorRequest.setShardIteratorType(iteratorType.toString) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala index aaa3ed51784fe..8a3b5e7b7801f 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala @@ -27,6 +27,31 @@ import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient +/** + * Configuration container for settings to be passed down into the kinesis-client-library (KCL). + * This class is also used to build any of the client instances used by the KCL so we + * can override the things like the endpoint. + * + * + * @param kinesisAppName The name of kinesis application (used in creating dynamo tables) + * @param streamName The name of the actual kinesis stream + * @param endpointUrl The AWS API endpoint that will be used for the kinesis client + * @param regionName The AWS region that will be connected to (will set default enpoint for dynamo and cloudwatch) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param awsCredentialsOption None or Some instance of SerializableAWSCredentials that will be used for + * credentials for Kinesis and the default for other clients. If None, then the + * DefaultAWSCredentialsProviderChain will be used + * @param dynamoEndpointUrl None or Some AWS API endpoint that will be used for the DynamoDBClient, if None, then the regionName + * will be used to build the default endpoint + * @param dynamoCredentials None or Some SerializableAWSCredentials that will be used as the credentials. 
If None, + * then the DefaultProviderKeychain will be used to build credentials + * + */ case class KinesisConfig( kinesisAppName: String, streamName: String, @@ -38,6 +63,13 @@ case class KinesisConfig( dynamoCredentials: Option[SerializableAWSCredentials] = None ) { + /** + * Builds a KinesisClientLibConfiguration object, which contains all the configuration options + * See the docs for more info: + * http://static.javadoc.io/com.amazonaws/amazon-kinesis-client/1.6.1/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.html + * + * @param workerId A unique string to identify a worker + */ def buildKCLConfig(workerId: String): KinesisClientLibConfiguration = { // KCL config instance val kinesisClientLibConfiguration = @@ -50,10 +82,10 @@ case class KinesisConfig( } - def region: Region = { - RegionUtils.getRegion(regionName) - } + /** + * Returns a AmazonDynamoDBClient instance configured with the proper region/endpoint + */ def buildDynamoClient(): AmazonDynamoDBClient = { val client = if (dynamoCredentials.isDefined) new AmazonDynamoDBClient(resolveAWSCredentialsProvider(dynamoCredentials)) else new AmazonDynamoDBClient(resolveAWSCredentialsProvider()) client.setRegion(region) @@ -63,6 +95,9 @@ case class KinesisConfig( client } + /** + * Returns a AmazonKinesisClient instance configured with the proper region/endpoint + */ def buildKinesisClient(): AmazonKinesisClient = { val client = new AmazonKinesisClient(resolveAWSCredentialsProvider()) client.setRegion(region) @@ -71,6 +106,9 @@ case class KinesisConfig( } + /** + * Returns a AmazonCloudWatchClient instance configured with the proper region/endpoint + */ def buildCloudwatchClient(): AmazonCloudWatchClient = { val client = new AmazonCloudWatchClient(resolveAWSCredentialsProvider()) client.setRegion(region) @@ -78,9 +116,11 @@ case class KinesisConfig( } - def awsCreds: AWSCredentials = { - awsCredentialsOption.getOrElse(new DefaultAWSCredentialsProviderChain().getCredentials()) - + /** + * Returns the provided credentials or resolves a pair of credentials using DefaultAWSCredentialsProviderChain + */ + def awsCredentials: AWSCredentials = { + resolveAWSCredentialsProvider().getCredentials() } @@ -100,8 +140,22 @@ case class KinesisConfig( } } + /** + * Resolves string region into the region object + */ + private def region: Region = { + RegionUtils.getRegion(regionName) + } + } +/** + * A small class that extends AWSCredentials that is marked as serializable, which + * is needed in order to have it serialized into a spark context + * + * @param accessKeyId An AWS accessKeyId + * @param secretKey An AWS secretKey + */ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) extends AWSCredentials { override def getAWSAccessKeyId: String = accessKeyId @@ -109,9 +163,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) } private object KinesisConfig { - - - /* + /** * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB * @param streamName Kinesis stream name diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala new file mode 100644 index 0000000000000..36fc2a4a0a27a --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.language.postfixOps + + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} + +class KinesisConfigSuite extends KinesisFunSuite { + + private val workerId = "dummyWorkerId" + private val kinesisAppName = "testApp" + private val kinesisStreamName = "testStream" + private val regionName = "us-east-1" + private val endpointUrl = "https://testendpoint.local" + private val streamPosition = InitialPositionInStream.TRIM_HORIZON + + private val awsAccessKey = "accessKey" + private val awsSecretKey = "secretKey" + private val awsCreds = new SerializableAWSCredentials(awsAccessKey, awsSecretKey) + + + + test("builds a KinesisClientLibConfiguration with defaults set") { + val kinesisConfig = new KinesisConfig(kinesisAppName, kinesisStreamName, regionName, endpointUrl, streamPosition) + val kclConfig = kinesisConfig.buildKCLConfig(workerId) + assert(kclConfig.getApplicationName() == kinesisAppName) + assert(kclConfig.getStreamName() == kinesisStreamName) + assert(kclConfig.getInitialPositionInStream() == streamPosition) + assert(kclConfig.getApplicationName() == kinesisAppName) + assert(kclConfig.getRegionName() == regionName) + assert(kclConfig.getKinesisEndpoint() == endpointUrl) + } + + test("returns given creds if creds are specified") { + val kinesisConfig = new KinesisConfig(kinesisAppName, kinesisStreamName, regionName, endpointUrl, streamPosition, Some(awsCreds)) + assert(kinesisConfig.awsCredentials == awsCreds) + } + +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index e3d9af315533d..66a014407169f 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -142,9 +142,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.config.regionName === dummyRegionName) assert(kinesisRDD.config.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.config.awsCreds.getAWSAccessKeyId() === + assert(kinesisRDD.config.awsCredentials.getAWSAccessKeyId() === dummyAWSAccessKey) - assert(kinesisRDD.config.awsCreds.getAWSSecretKey() === + assert(kinesisRDD.config.awsCredentials.getAWSSecretKey() === dummyAWSSecretKey) assert(nonEmptyRDD.partitions.size === blockInfos.size) nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } From e0a29682d9510c82dc12fb96d73418306ae2ca5f Mon Sep 17 00:00:00 2001 From: 
Addison Higham Date: Fri, 19 Feb 2016 14:34:14 -0700 Subject: [PATCH 3/4] pass checkstyle --- .../streaming/kinesis/KinesisConfig.scala | 50 ++++++++++++------- .../streaming/kinesis/KinesisTestUtils.scala | 9 +++- .../streaming/kinesis/KinesisUtils.scala | 9 ++-- .../kinesis/KinesisConfigSuite.scala | 2 +- 4 files changed, 47 insertions(+), 23 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala index 8a3b5e7b7801f..efdf0834e3511 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala @@ -18,13 +18,12 @@ package org.apache.spark.streaming.kinesis import scala.reflect.ClassTag - import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.regions.{RegionUtils, Region} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} +import com.amazonaws.regions.{Region, RegionUtils} +import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.kinesis.AmazonKinesisClient -import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} /** @@ -36,20 +35,24 @@ import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient * @param kinesisAppName The name of kinesis application (used in creating dynamo tables) * @param streamName The name of the actual kinesis stream * @param endpointUrl The AWS API endpoint that will be used for the kinesis client - * @param regionName The AWS region that will be connected to (will set default enpoint for dynamo and cloudwatch) + * @param regionName The AWS region that will be connected to + * (will set default enpoint for dynamo and cloudwatch) * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param awsCredentialsOption None or Some instance of SerializableAWSCredentials that will be used for - * credentials for Kinesis and the default for other clients. If None, then the + * @param awsCredentialsOption None or Some instance of SerializableAWSCredentials that + * will be used for credentials for Kinesis and the default + * for other clients. If None, then the * DefaultAWSCredentialsProviderChain will be used - * @param dynamoEndpointUrl None or Some AWS API endpoint that will be used for the DynamoDBClient, if None, then the regionName + * @param dynamoEndpointUrl None or Some AWS API endpoint that will be used for + * the DynamoDBClient, if None, then the regionName * will be used to build the default endpoint - * @param dynamoCredentials None or Some SerializableAWSCredentials that will be used as the credentials. If None, - * then the DefaultProviderKeychain will be used to build credentials + * @param dynamoCredentials None or Some SerializableAWSCredentials that will be used + * as the credentials. 
If None, then the + * DefaultProviderKeychain will be used to build credentials * */ case class KinesisConfig( @@ -65,15 +68,21 @@ case class KinesisConfig( /** * Builds a KinesisClientLibConfiguration object, which contains all the configuration options - * See the docs for more info: - * http://static.javadoc.io/com.amazonaws/amazon-kinesis-client/1.6.1/com/amazonaws/services/kinesis/clientlibrary/lib/worker/KinesisClientLibConfiguration.html + * See the + * KinesisClientLibConfiguration docs + * for more info: + * * * @param workerId A unique string to identify a worker */ def buildKCLConfig(workerId: String): KinesisClientLibConfiguration = { // KCL config instance val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(kinesisAppName, streamName, resolveAWSCredentialsProvider(), workerId) + new KinesisClientLibConfiguration( + kinesisAppName, + streamName, + resolveAWSCredentialsProvider(), + workerId) .withKinesisEndpoint(endpointUrl) .withInitialPositionInStream(initialPositionInStream) .withTaskBackoffTimeMillis(500) @@ -87,7 +96,11 @@ case class KinesisConfig( * Returns a AmazonDynamoDBClient instance configured with the proper region/endpoint */ def buildDynamoClient(): AmazonDynamoDBClient = { - val client = if (dynamoCredentials.isDefined) new AmazonDynamoDBClient(resolveAWSCredentialsProvider(dynamoCredentials)) else new AmazonDynamoDBClient(resolveAWSCredentialsProvider()) + val client = if (dynamoCredentials.isDefined) { + new AmazonDynamoDBClient(resolveAWSCredentialsProvider(dynamoCredentials)) + } else { + new AmazonDynamoDBClient(resolveAWSCredentialsProvider()) + } client.setRegion(region) if (dynamoEndpointUrl.isDefined) { client.setEndpoint(dynamoEndpointUrl.get) @@ -117,7 +130,8 @@ case class KinesisConfig( } /** - * Returns the provided credentials or resolves a pair of credentials using DefaultAWSCredentialsProviderChain + * Returns the provided credentials or resolves a + * pair of credentials using DefaultAWSCredentialsProviderChain */ def awsCredentials: AWSCredentials = { resolveAWSCredentialsProvider().getCredentials() @@ -128,7 +142,9 @@ case class KinesisConfig( * If AWS credential is provided, return a AWSCredentialProvider returning that credential. * Otherwise, return the DefaultAWSCredentialsProviderChain. 
*/ - private def resolveAWSCredentialsProvider(awsCredOpt: Option[SerializableAWSCredentials] = awsCredentialsOption): AWSCredentialsProvider = { + private def resolveAWSCredentialsProvider( + awsCredOpt: Option[SerializableAWSCredentials] = awsCredentialsOption + ): AWSCredentialsProvider = { awsCredOpt match { case Some(awsCredentials) => new AWSCredentialsProvider { @@ -189,4 +205,4 @@ private object KinesisConfig { regionName, initialPositionInStream, awsCredentialsOption) } -} \ No newline at end of file +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 9327f7a893e83..a134cbf8b26d6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -26,10 +26,10 @@ import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.regions.RegionUtils import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ @@ -68,7 +68,12 @@ private[kinesis] class KinesisTestUtils extends Logging { } lazy val sparkKinesisConfig: KinesisConfig = { - KinesisConfig.buildConfig("kinesis-asl-unit-test", _streamName, endpointUrl, regionName, InitialPositionInStream.TRIM_HORIZON) + KinesisConfig.buildConfig( + "kinesis-asl-unit-test", + _streamName, + endpointUrl, + regionName, + InitialPositionInStream.TRIM_HORIZON) } protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 7c65cd0ce2fa4..b6952d9d05ab6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -156,7 +156,8 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) val kinesisClientConfig = KinesisConfig.buildConfig(kinesisAppName, streamName, - endpointUrl, validateRegion(regionName), initialPositionInStream, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + endpointUrl, validateRegion(regionName), initialPositionInStream, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, kinesisClientConfig, checkpointInterval, storageLevel, cleanedHandler) @@ -184,7 +185,8 @@ object KinesisUtils { storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { // Setting scope to override receiver stream's scope of "receiver stream" ssc.withNamedScope("kinesis stream") { - new KinesisInputDStream[Array[Byte]](ssc, config, checkpointInterval, storageLevel, defaultMessageHandler) + new KinesisInputDStream[Array[Byte]](ssc, config, checkpointInterval, + storageLevel, defaultMessageHandler) } } @@ -274,7 +276,8 @@ object 
KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { val kinesisClientConfig = KinesisConfig.buildConfig(kinesisAppName, streamName, - endpointUrl, validateRegion(regionName), initialPositionInStream, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + endpointUrl, validateRegion(regionName), initialPositionInStream, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, kinesisClientConfig, checkpointInterval, storageLevel, defaultMessageHandler) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala index 36fc2a4a0a27a..95e62918ea020 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala @@ -38,7 +38,7 @@ class KinesisConfigSuite extends KinesisFunSuite { test("builds a KinesisClientLibConfiguration with defaults set") { - val kinesisConfig = new KinesisConfig(kinesisAppName, kinesisStreamName, regionName, endpointUrl, streamPosition) + val kinesisConfig = new KinesisConfig(kinesisAppName, kinesisStreamName, endpointUrl, regionName, streamPosition) val kclConfig = kinesisConfig.buildKCLConfig(workerId) assert(kclConfig.getApplicationName() == kinesisAppName) assert(kclConfig.getStreamName() == kinesisStreamName) From dd2e0dcbcfbaca110bc768a358d0f9654ff2ec3f Mon Sep 17 00:00:00 2001 From: Addison Higham Date: Tue, 23 Feb 2016 15:47:43 -0700 Subject: [PATCH 4/4] refactor to newer KCL interfaces, remove regionName to fix endpoints being overwritten --- .../kinesis/KinesisBackedBlockRDD.scala | 2 -- .../streaming/kinesis/KinesisConfig.scala | 18 ++++------- .../streaming/kinesis/KinesisReceiver.scala | 16 +++++++--- .../kinesis/KinesisRecordProcessor.scala | 22 +++++++------ .../kinesis/KinesisConfigSuite.scala | 1 - .../kinesis/KinesisReceiverSuite.scala | 32 ++++++++++++------- 6 files changed, 51 insertions(+), 40 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index c1e81314154e9..74d1295ce3120 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -132,8 +132,6 @@ class KinesisSequenceRangeIterator( private var lastSeqNumber: String = null private var internalIterator: Iterator[Record] = null - client.setEndpoint(config.endpointUrl) - override protected def getNext(): Record = { var nextRecord: Record = null if (toSeqNumberReceived) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala index efdf0834e3511..727e691c4e83d 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala @@ -64,7 +64,7 @@ case class KinesisConfig( awsCredentialsOption: Option[SerializableAWSCredentials] = None, dynamoEndpointUrl: Option[String] = None, 
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index c1e81314154e9..74d1295ce3120 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -132,8 +132,6 @@ class KinesisSequenceRangeIterator(
   private var lastSeqNumber: String = null
   private var internalIterator: Iterator[Record] = null
 
-  client.setEndpoint(config.endpointUrl)
-
   override protected def getNext(): Record = {
     var nextRecord: Record = null
     if (toSeqNumberReceived) {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala
index efdf0834e3511..727e691c4e83d 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisConfig.scala
@@ -64,7 +64,7 @@ case class KinesisConfig(
     awsCredentialsOption: Option[SerializableAWSCredentials] = None,
     dynamoEndpointUrl: Option[String] = None,
     dynamoCredentials: Option[SerializableAWSCredentials] = None
-  ) {
+  ) extends Serializable {
 
   /**
    * Builds a KinesisClientLibConfiguration object, which contains all the configuration options
@@ -86,7 +86,6 @@ case class KinesisConfig(
       .withKinesisEndpoint(endpointUrl)
       .withInitialPositionInStream(initialPositionInStream)
       .withTaskBackoffTimeMillis(500)
-      .withRegionName(regionName)
 
     return kinesisClientLibConfiguration
   }
@@ -101,11 +100,12 @@ case class KinesisConfig(
     } else {
       new AmazonDynamoDBClient(resolveAWSCredentialsProvider())
     }
-    client.setRegion(region)
+
     if (dynamoEndpointUrl.isDefined) {
-      client.setEndpoint(dynamoEndpointUrl.get)
+      client.withEndpoint(dynamoEndpointUrl.get)
+    } else {
+      client.withRegion(region)
     }
-    client
   }
 
   /**
@@ -113,9 +113,7 @@ case class KinesisConfig(
    */
   def buildKinesisClient(): AmazonKinesisClient = {
     val client = new AmazonKinesisClient(resolveAWSCredentialsProvider())
-    client.setRegion(region)
-    client.setEndpoint(endpointUrl)
-    client
+    client.withEndpoint(endpointUrl)
   }
 
 
@@ -124,9 +122,7 @@ case class KinesisConfig(
    */
   def buildCloudwatchClient(): AmazonCloudWatchClient = {
     val client = new AmazonCloudWatchClient(resolveAWSCredentialsProvider())
-    client.setRegion(region)
-    client
-
+    client.withRegion(region)
   }
 
   /**
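
With the builders above, a custom DynamoDB endpoint (for example a local DynamoDB instance backing the KCL lease table) is supplied through the config rather than through a hand-built client. A minimal sketch; the order of the first five constructor arguments is assumed to match the one exercised in KinesisConfigSuite, and all values are invented:

    import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
    import org.apache.spark.streaming.kinesis.KinesisConfig

    val config = new KinesisConfig(
      "example-app",                                // KCL application name
      "example-stream",                             // Kinesis stream name
      "https://kinesis.us-west-2.amazonaws.com",    // Kinesis endpoint
      "us-west-2",                                  // region name
      InitialPositionInStream.TRIM_HORIZON,
      dynamoEndpointUrl = Some("http://localhost:8000"))  // e.g. DynamoDB Local

    val dynamo = config.buildDynamoClient()     // bound to the custom DynamoDB endpoint
    val kinesis = config.buildKinesisClient()   // bound to the Kinesis endpoint above
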
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index fbe55996809b4..61010d1e15c07 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -23,7 +23,8 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.control.NonFatal
 
-import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.{IRecordProcessor, IRecordProcessorFactory}
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker
 import com.amazonaws.services.kinesis.model.Record
 
@@ -33,7 +34,6 @@ import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListen
 import org.apache.spark.util.Utils
 import org.apache.spark.Logging
 
-
 /**
  * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver.
  * This implementation relies on the Kinesis Client Library (KCL) Worker as described here:
@@ -60,7 +60,7 @@ import org.apache.spark.Logging
  * @param storageLevel Storage level to use for storing the received objects
  */
 private[kinesis] class KinesisReceiver[T](
-    config: KinesisConfig,
+    val config: KinesisConfig,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
     messageHandler: Record => T)
@@ -132,8 +132,14 @@ private[kinesis] class KinesisReceiver[T](
       new KinesisRecordProcessor(receiver, workerId)
     }
 
-    worker = new Worker(recordProcessorFactory, config.buildKCLConfig(workerId),
-      config.buildKinesisClient(), config.buildDynamoClient(), config.buildCloudwatchClient())
+    worker = new Worker.Builder()
+      .recordProcessorFactory(recordProcessorFactory)
+      .config(config.buildKCLConfig(workerId))
+      .kinesisClient(config.buildKinesisClient())
+      .dynamoDBClient(config.buildDynamoClient())
+      .cloudWatchClient(config.buildCloudwatchClient())
+      .build()
+
     workerThread = new Thread() {
       override def run(): Unit = {
         try {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index b5b76cb92d866..5deaa4d9907b7 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -22,8 +22,8 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException}
-import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
-import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor
+import com.amazonaws.services.kinesis.clientlibrary.types.{InitializationInput, ProcessRecordsInput, ShutdownInput, ShutdownReason}
 import com.amazonaws.services.kinesis.model.Record
 
 import org.apache.spark.Logging
@@ -50,10 +50,10 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
   /**
    * The Kinesis Client Library calls this method during IRecordProcessor initialization.
    *
-   * @param shardId assigned by the KCL to this particular RecordProcessor.
+   * @param initInput contains information about the shard this processor starts from
    */
-  override def initialize(shardId: String) {
-    this.shardId = shardId
+  override def initialize(initInput: InitializationInput) {
+    this.shardId = initInput.getShardId()
     logInfo(s"Initialized workerId $workerId with shardId $shardId")
   }
 
@@ -66,12 +66,13 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
    * @param checkpointer used to update Kinesis when this batch has been processed/stored
    *                     in the DStream
    */
-  override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) {
+  override def processRecords(recordInput: ProcessRecordsInput) {
     if (!receiver.isStopped()) {
       try {
+        val batch = recordInput.getRecords()
         receiver.addRecords(shardId, batch)
         logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
-        receiver.setCheckpointer(shardId, checkpointer)
+        receiver.setCheckpointer(shardId, recordInput.getCheckpointer())
       } catch {
         case NonFatal(e) => {
           /*
@@ -103,8 +104,9 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
    * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE
    * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE)
    */
-  override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) {
-    logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason")
+  override def shutdown(shutdownInput: ShutdownInput) {
+    val reason = shutdownInput.getShutdownReason()
+    logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason.")
     reason match {
       /*
        * TERMINATE Use Case. Checkpoint.
@@ -112,7 +114,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
        * It's now OK to read from the new shards that resulted from a resharding event.
        */
       case ShutdownReason.TERMINATE =>
-        receiver.removeCheckpointer(shardId, checkpointer)
+        receiver.removeCheckpointer(shardId, shutdownInput.getCheckpointer())
 
       /*
        * ZOMBIE Use Case or Unknown reason. NoOp.
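
Collected in one place, the v2 contract that `KinesisRecordProcessor` now implements: each callback receives a single input object instead of loose arguments. A bare skeleton for orientation only, with the class name invented; the real handling stays in the processor above:

    import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2.IRecordProcessor
    import com.amazonaws.services.kinesis.clientlibrary.types.{InitializationInput, ProcessRecordsInput, ShutdownInput}

    class SkeletonProcessor extends IRecordProcessor {
      private var shardId: String = _

      override def initialize(initInput: InitializationInput): Unit = {
        shardId = initInput.getShardId()               // shard assignment arrives wrapped in the input
      }

      override def processRecords(input: ProcessRecordsInput): Unit = {
        val records = input.getRecords()               // records and their checkpointer travel together
        val checkpointer = input.getCheckpointer()
      }

      override def shutdown(shutdownInput: ShutdownInput): Unit = {
        val reason = shutdownInput.getShutdownReason() // TERMINATE vs ZOMBIE is read from the input
      }
    }
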
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala
index 95e62918ea020..aaefb7784e3d3 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisConfigSuite.scala
@@ -44,7 +44,6 @@ class KinesisConfigSuite extends KinesisFunSuite {
     assert(kclConfig.getStreamName() == kinesisStreamName)
     assert(kclConfig.getInitialPositionInStream() == streamPosition)
     assert(kclConfig.getApplicationName() == kinesisAppName)
-    assert(kclConfig.getRegionName() == regionName)
     assert(kclConfig.getKinesisEndpoint() == endpointUrl)
   }
 
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index fd15b6ccdc889..868c045023a5f 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -22,7 +22,7 @@ import java.util.Arrays
 
 import com.amazonaws.services.kinesis.clientlibrary.exceptions._
 import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
-import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+import com.amazonaws.services.kinesis.clientlibrary.types.{ShutdownReason, InitializationInput, ProcessRecordsInput, ShutdownInput}
 import com.amazonaws.services.kinesis.model.Record
 import org.mockito.Matchers._
 import org.mockito.Matchers.{eq => meq}
@@ -53,6 +53,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
   val record2 = new Record()
   record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8)))
   val batch = Arrays.asList(record1, record2)
+  val initInput = new InitializationInput()
+    .withShardId(shardId)
+
+  var recordInput = new ProcessRecordsInput()
+    .withRecords(batch)
+
+  var shutdownInput = new ShutdownInput()
+    .withShutdownReason(ShutdownReason.TERMINATE)
 
   var receiverMock: KinesisReceiver[Array[Byte]] = _
   var checkpointerMock: IRecordProcessorCheckpointer = _
@@ -60,6 +68,8 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
   override def beforeFunction(): Unit = {
     receiverMock = mock[KinesisReceiver[Array[Byte]]]
     checkpointerMock = mock[IRecordProcessorCheckpointer]
+    recordInput = recordInput.withCheckpointer(checkpointerMock)
+    shutdownInput = shutdownInput.withCheckpointer(checkpointerMock)
   }
 
   test("check serializability of SerializableAWSCredentials") {
@@ -71,8 +81,8 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
     when(receiverMock.isStopped()).thenReturn(false)
 
     val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
-    recordProcessor.initialize(shardId)
-    recordProcessor.processRecords(batch, checkpointerMock)
+    recordProcessor.initialize(initInput)
+    recordProcessor.processRecords(recordInput)
 
     verify(receiverMock, times(1)).isStopped()
     verify(receiverMock, times(1)).addRecords(shardId, batch)
@@ -83,7 +93,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
     when(receiverMock.isStopped()).thenReturn(true)
 
     val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
-    recordProcessor.processRecords(batch, checkpointerMock)
+    recordProcessor.processRecords(recordInput)
 
     verify(receiverMock, times(1)).isStopped()
     verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record]))
@@ -98,8 +108,8 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
 
     intercept[RuntimeException] {
       val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
-      recordProcessor.initialize(shardId)
-      recordProcessor.processRecords(batch, checkpointerMock)
+      recordProcessor.initialize(initInput)
+      recordProcessor.processRecords(recordInput)
     }
 
     verify(receiverMock, times(1)).isStopped()
@@ -111,8 +121,8 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
 
     val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
-    recordProcessor.initialize(shardId)
-    recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE)
+    recordProcessor.initialize(initInput)
+    recordProcessor.shutdown(shutdownInput)
 
     verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock))
   }
@@ -122,9 +132,9 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
 
     val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
-    recordProcessor.initialize(shardId)
-    recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
-    recordProcessor.shutdown(checkpointerMock, null)
+    recordProcessor.initialize(initInput)
+    recordProcessor.shutdown(shutdownInput.withShutdownReason(ShutdownReason.ZOMBIE))
+    recordProcessor.shutdown(shutdownInput.withShutdownReason(null))
 
     verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), meq[IRecordProcessorCheckpointer](null))