1818package org .apache .spark .streaming .kinesis
1919
2020import scala .collection .JavaConversions ._
21+ import scala .util .control .NonFatal
2122
2223import com .amazonaws .auth .{AWSCredentials , DefaultAWSCredentialsProviderChain }
2324import com .amazonaws .services .kinesis .AmazonKinesisClient
@@ -29,7 +30,7 @@ import org.apache.spark.storage.BlockId
2930import org .apache .spark .util .NextIterator
3031
3132
32- /** Class representing a range of Kinesis sequence numbers */
33+ /** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */
3334private [kinesis]
3435case class SequenceNumberRange (
3536 streamName : String , shardId : String , fromSeqNumber : String , toSeqNumber : String )
@@ -71,8 +72,9 @@ class KinesisBackedBlockRDD(
7172 @ transient blockIds : Array [BlockId ],
7273 @ transient arrayOfseqNumberRanges : Array [SequenceNumberRanges ],
7374 @ transient isBlockIdValid : Array [Boolean ] = Array .empty,
75+ retryTimeoutMs : Int = 10000 ,
7476 awsCredentialsOption : Option [SerializableAWSCredentials ] = None
75- ) extends BlockRDD [Array [Byte ]](sc, blockIds) {
77+ ) extends BlockRDD [Array [Byte ]](sc, blockIds) {
7678
7779 require(blockIds.length == arrayOfseqNumberRanges.length,
7880 " Number of blockIds is not equal to the number of sequence number ranges" )
@@ -101,7 +103,8 @@ class KinesisBackedBlockRDD(
101103 new DefaultAWSCredentialsProviderChain ().getCredentials()
102104 }
103105 partition.seqNumberRanges.ranges.iterator.flatMap { range =>
104- new KinesisSequenceRangeIterator (credenentials, endpointUrl, regionId, range)
106+ new KinesisSequenceRangeIterator (
107+ credenentials, endpointUrl, regionId, range, retryTimeoutMs)
105108 }
106109 }
107110 if (partition.isBlockIdValid) {
@@ -113,17 +116,23 @@ class KinesisBackedBlockRDD(
113116}
114117
115118
116- /** An iterator that return the Kinesis data based on the given range of Sequence numbers */
119+ /**
120+ * An iterator that return the Kinesis data based on the given range of sequence numbers.
121+ * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber,
122+ * until the endSequenceNumber is reached.
123+ */
117124private [kinesis]
118125class KinesisSequenceRangeIterator (
119126 credentials : AWSCredentials ,
120127 endpointUrl : String ,
121128 regionId : String ,
122- range : SequenceNumberRange
123- ) extends NextIterator [Array [Byte ]] {
129+ range : SequenceNumberRange ,
130+ retryTimeoutMs : Int
131+ ) extends NextIterator [Array [Byte ]] with Logging {
124132
125- private val backoffTimeMillis = 1000
126133 private val client = new AmazonKinesisClient (credentials)
134+ private val streamName = range.streamName
135+ private val shardId = range.shardId
127136
128137 private var toSeqNumberReceived = false
129138 private var lastSeqNumber : String = null
@@ -141,25 +150,28 @@ class KinesisSequenceRangeIterator(
141150
142151 // If the internal iterator has not been initialized,
143152 // then fetch records from starting sequence number
144- getRecords(ShardIteratorType .AT_SEQUENCE_NUMBER , range.fromSeqNumber)
153+ internalIterator = getRecords(ShardIteratorType .AT_SEQUENCE_NUMBER , range.fromSeqNumber)
145154 } else if (! internalIterator.hasNext) {
146155
147156 // If the internal iterator does not have any more records,
148157 // then fetch more records after the last consumed sequence number
149- getRecords(ShardIteratorType .AFTER_SEQUENCE_NUMBER , lastSeqNumber)
158+ internalIterator = getRecords(ShardIteratorType .AFTER_SEQUENCE_NUMBER , lastSeqNumber)
150159 }
151160
152161 if (! internalIterator.hasNext) {
153162
154163 // If the internal iterator still does not have any data, then throw exception
155164 // and terminate this iterator
156165 finished = true
157- throw new SparkException (s " Could not read until the specified end sequence number: $range" )
166+ throw new SparkException (
167+ s " Could not read until the end sequence number of the range: $range" )
158168 } else {
159169
160- // Get the record, and remember its sequence number
161- val nextRecord = internalIterator.next()
162- nextBytes = nextRecord.getData().array()
170+ // Get the record, copy the data into a byte array and remember its sequence number
171+ val nextRecord : Record = internalIterator.next()
172+ val byteBuffer = nextRecord.getData()
173+ nextBytes = new Array [Byte ](byteBuffer.remaining())
174+ byteBuffer.get(nextBytes )
163175 lastSeqNumber = nextRecord.getSequenceNumber()
164176
165177 // If the this record's sequence number matches the stopping sequence number, then make sure
@@ -173,51 +185,92 @@ class KinesisSequenceRangeIterator(
173185 nextBytes
174186 }
175187
176- override protected def close (): Unit = { }
188+ override protected def close (): Unit = {
189+ client.shutdown()
190+ }
177191
178- private def getRecords (iteratorType : ShardIteratorType , seqNum : String ): Unit = {
179- val shardIterator = getKinesisIterator(range.streamName, range.shardId, iteratorType, seqNum)
180- var records : Seq [Record ] = null
181- do {
182- try {
183- val getResult = getRecordsAndNextKinesisIterator(
184- range.streamName, range.shardId, shardIterator)
185- records = getResult._1
186- } catch {
187- case ptee : ProvisionedThroughputExceededException =>
188- Thread .sleep(backoffTimeMillis)
189- }
190- } while (records == null || records.length == 0 ) // TODO: put a limit on the number of retries
191- if (records != null && records.nonEmpty) {
192- internalIterator = records.iterator
193- }
192+ /**
193+ * Get records starting from or after the given sequence number.
194+ */
195+ private def getRecords (iteratorType : ShardIteratorType , seqNum : String ): Iterator [Record ] = {
196+ val shardIterator = getKinesisIterator(iteratorType, seqNum)
197+ val result = getRecordsAndNextKinesisIterator(shardIterator)
198+ result._1
194199 }
195200
201+ /**
202+ * Get the records starting from using a Kinesis shard iterator (which is a progress handle
203+ * to get records from Kinesis), and get the next shard iterator for next consumption.
204+ */
196205 private def getRecordsAndNextKinesisIterator (
197- streamName : String ,
198- shardId : String ,
199- shardIterator : String
200- ): (Seq [Record ], String ) = {
206+ shardIterator : String ): (Iterator [Record ], String ) = {
201207 val getRecordsRequest = new GetRecordsRequest
202208 getRecordsRequest.setRequestCredentials(credentials)
203209 getRecordsRequest.setShardIterator(shardIterator)
204- val getRecordsResult = client.getRecords(getRecordsRequest)
205- (getRecordsResult.getRecords, getRecordsResult.getNextShardIterator)
210+ val getRecordsResult = retryOrTimeout[GetRecordsResult ](
211+ s " getting records using shard iterator " ) {
212+ client.getRecords(getRecordsRequest)
213+ }
214+ (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator)
206215 }
207216
217+ /**
218+ * Get the Kinesis shard iterator for getting records starting from or after the given
219+ * sequence number.
220+ */
208221 private def getKinesisIterator (
209- streamName : String ,
210- shardId : String ,
211222 iteratorType : ShardIteratorType ,
212- sequenceNumber : String
213- ): String = {
223+ sequenceNumber : String ): String = {
214224 val getShardIteratorRequest = new GetShardIteratorRequest
215225 getShardIteratorRequest.setRequestCredentials(credentials)
216226 getShardIteratorRequest.setStreamName(streamName)
217227 getShardIteratorRequest.setShardId(shardId)
218228 getShardIteratorRequest.setShardIteratorType(iteratorType.toString)
219229 getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber)
220- val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest)
230+ val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult ](
231+ s " getting shard iterator from sequence number $sequenceNumber" ) {
232+ client.getShardIterator(getShardIteratorRequest)
233+ }
221234 getShardIteratorResult.getShardIterator
222235 }
236+
237+ private def retryOrTimeout [T ](message : String )(body : => T ): T = {
238+ import KinesisSequenceRangeIterator ._
239+
240+ var startTimeMs = System .currentTimeMillis()
241+ var retryCount = 0
242+ var waitTimeMs = 0
243+ var result : Option [T ] = None
244+ var lastError : Throwable = null
245+
246+ def timeSpentMs = System .currentTimeMillis() - startTimeMs
247+
248+ while (result == null && retryCount < MAX_RETRIES && timeSpentMs <= retryTimeoutMs) {
249+ Thread .sleep(waitTimeMs)
250+ try {
251+ result = Some (body)
252+ } catch {
253+ case NonFatal (t) =>
254+ lastError = t
255+ t match {
256+ case ptee : ProvisionedThroughputExceededException =>
257+ logWarning(s " Exception while $message" , ptee)
258+ case e : Throwable =>
259+ throw new SparkException (s " Error $message" , e)
260+ }
261+ } finally {
262+ retryCount += 1
263+ if (waitTimeMs == 0 ) waitTimeMs = MIN_RETRY_WAIT_TIME_MS else waitTimeMs *= 2
264+ }
265+ }
266+ result.getOrElse {
267+ throw new SparkException (s " Timed out while $message, last exception: " , lastError)
268+ }
269+ }
223270}
271+
272+ private [streaming]
273+ object KinesisSequenceRangeIterator {
274+ val MAX_RETRIES = 3
275+ val MIN_RETRY_WAIT_TIME_MS = 100
276+ }
0 commit comments