Skip to content

Commit f6e35c8

Browse files
committed
Added retry logic to make it more robust
1 parent 8874b70 commit f6e35c8

File tree

1 file changed

+94
-41
lines changed

1 file changed

+94
-41
lines changed

extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala

Lines changed: 94 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.streaming.kinesis
1919

2020
import scala.collection.JavaConversions._
21+
import scala.util.control.NonFatal
2122

2223
import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
2324
import com.amazonaws.services.kinesis.AmazonKinesisClient
@@ -29,7 +30,7 @@ import org.apache.spark.storage.BlockId
2930
import org.apache.spark.util.NextIterator
3031

3132

32-
/** Class representing a range of Kinesis sequence numbers */
33+
/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */
3334
private[kinesis]
3435
case class SequenceNumberRange(
3536
streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String)
@@ -71,8 +72,9 @@ class KinesisBackedBlockRDD(
7172
@transient blockIds: Array[BlockId],
7273
@transient arrayOfseqNumberRanges: Array[SequenceNumberRanges],
7374
@transient isBlockIdValid: Array[Boolean] = Array.empty,
75+
retryTimeoutMs: Int = 10000,
7476
awsCredentialsOption: Option[SerializableAWSCredentials] = None
75-
) extends BlockRDD[Array[Byte]](sc, blockIds) {
77+
) extends BlockRDD[Array[Byte]](sc, blockIds) {
7678

7779
require(blockIds.length == arrayOfseqNumberRanges.length,
7880
"Number of blockIds is not equal to the number of sequence number ranges")
@@ -101,7 +103,8 @@ class KinesisBackedBlockRDD(
101103
new DefaultAWSCredentialsProviderChain().getCredentials()
102104
}
103105
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
104-
new KinesisSequenceRangeIterator(credenentials, endpointUrl, regionId, range)
106+
new KinesisSequenceRangeIterator(
107+
credenentials, endpointUrl, regionId, range, retryTimeoutMs)
105108
}
106109
}
107110
if (partition.isBlockIdValid) {
@@ -113,17 +116,23 @@ class KinesisBackedBlockRDD(
113116
}
114117

115118

116-
/** An iterator that return the Kinesis data based on the given range of Sequence numbers */
119+
/**
120+
* An iterator that returns the Kinesis data based on the given range of sequence numbers.
121+
* Internally, it repeatedly fetches sets of records starting from the range's fromSeqNumber,
122+
* until the range's toSeqNumber is reached.
123+
*/
117124
private[kinesis]
118125
class KinesisSequenceRangeIterator(
119126
credentials: AWSCredentials,
120127
endpointUrl: String,
121128
regionId: String,
122-
range: SequenceNumberRange
123-
) extends NextIterator[Array[Byte]] {
129+
range: SequenceNumberRange,
130+
retryTimeoutMs: Int
131+
) extends NextIterator[Array[Byte]] with Logging {
124132

125-
private val backoffTimeMillis = 1000
126133
private val client = new AmazonKinesisClient(credentials)
134+
private val streamName = range.streamName
135+
private val shardId = range.shardId
127136

128137
private var toSeqNumberReceived = false
129138
private var lastSeqNumber: String = null
@@ -141,25 +150,28 @@ class KinesisSequenceRangeIterator(
141150

142151
// If the internal iterator has not been initialized,
143152
// then fetch records from starting sequence number
144-
getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber)
153+
internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber)
145154
} else if (!internalIterator.hasNext) {
146155

147156
// If the internal iterator does not have any more records,
148157
// then fetch more records after the last consumed sequence number
149-
getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
158+
internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
150159
}
151160

152161
if (!internalIterator.hasNext) {
153162

154163
// If the internal iterator still does not have any data, then throw exception
155164
// and terminate this iterator
156165
finished = true
157-
throw new SparkException(s"Could not read until the specified end sequence number: $range")
166+
throw new SparkException(
167+
s"Could not read until the end sequence number of the range: $range")
158168
} else {
159169

160-
// Get the record, and remember its sequence number
161-
val nextRecord = internalIterator.next()
162-
nextBytes = nextRecord.getData().array()
170+
// Get the record, copy the data into a byte array and remember its sequence number
171+
val nextRecord: Record = internalIterator.next()
172+
val byteBuffer = nextRecord.getData()
173+
nextBytes = new Array[Byte](byteBuffer.remaining())
174+
byteBuffer.get(nextBytes )
163175
lastSeqNumber = nextRecord.getSequenceNumber()
164176

165177
// If this record's sequence number matches the stopping sequence number, then make sure
@@ -173,51 +185,92 @@ class KinesisSequenceRangeIterator(
173185
nextBytes
174186
}
175187

176-
override protected def close(): Unit = { }
188+
override protected def close(): Unit = {
189+
client.shutdown()
190+
}
177191

178-
private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Unit = {
179-
val shardIterator = getKinesisIterator(range.streamName, range.shardId, iteratorType, seqNum)
180-
var records: Seq[Record] = null
181-
do {
182-
try {
183-
val getResult = getRecordsAndNextKinesisIterator(
184-
range.streamName, range.shardId, shardIterator)
185-
records = getResult._1
186-
} catch {
187-
case ptee: ProvisionedThroughputExceededException =>
188-
Thread.sleep(backoffTimeMillis)
189-
}
190-
} while (records == null || records.length == 0) // TODO: put a limit on the number of retries
191-
if (records != null && records.nonEmpty) {
192-
internalIterator = records.iterator
193-
}
192+
/**
193+
* Get records starting from or after the given sequence number.
194+
*/
195+
private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = {
196+
val shardIterator = getKinesisIterator(iteratorType, seqNum)
197+
val result = getRecordsAndNextKinesisIterator(shardIterator)
198+
result._1
194199
}
195200

201+
/**
202+
* Get the records using a Kinesis shard iterator (which is a progress handle
203+
* to get records from Kinesis), and get the next shard iterator for next consumption.
204+
*/
196205
private def getRecordsAndNextKinesisIterator(
197-
streamName: String,
198-
shardId: String,
199-
shardIterator: String
200-
): (Seq[Record], String) = {
206+
shardIterator: String): (Iterator[Record], String) = {
201207
val getRecordsRequest = new GetRecordsRequest
202208
getRecordsRequest.setRequestCredentials(credentials)
203209
getRecordsRequest.setShardIterator(shardIterator)
204-
val getRecordsResult = client.getRecords(getRecordsRequest)
205-
(getRecordsResult.getRecords, getRecordsResult.getNextShardIterator)
210+
val getRecordsResult = retryOrTimeout[GetRecordsResult](
211+
s"getting records using shard iterator") {
212+
client.getRecords(getRecordsRequest)
213+
}
214+
(getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator)
206215
}
207216

217+
/**
218+
* Get the Kinesis shard iterator for getting records starting from or after the given
219+
* sequence number.
220+
*/
208221
private def getKinesisIterator(
209-
streamName: String,
210-
shardId: String,
211222
iteratorType: ShardIteratorType,
212-
sequenceNumber: String
213-
): String = {
223+
sequenceNumber: String): String = {
214224
val getShardIteratorRequest = new GetShardIteratorRequest
215225
getShardIteratorRequest.setRequestCredentials(credentials)
216226
getShardIteratorRequest.setStreamName(streamName)
217227
getShardIteratorRequest.setShardId(shardId)
218228
getShardIteratorRequest.setShardIteratorType(iteratorType.toString)
219229
getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber)
220-
val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest)
230+
val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult](
231+
s"getting shard iterator from sequence number $sequenceNumber") {
232+
client.getShardIterator(getShardIteratorRequest)
233+
}
221234
getShardIteratorResult.getShardIterator
222235
}
236+
237+
/**
 * Evaluate `body`, retrying with exponential backoff while Kinesis reports throttling.
 *
 * The body is attempted at most `MAX_RETRIES` times, and a new attempt is only started
 * while the time spent so far does not exceed `retryTimeoutMs`. The first attempt is made
 * without waiting; the wait before each later attempt starts at `MIN_RETRY_WAIT_TIME_MS`
 * and doubles after every attempt.
 *
 * @param message description of the operation, interpolated into log and error messages
 * @param body the operation to evaluate; it is re-evaluated on every attempt
 * @return the value produced by the first successful evaluation of `body`
 * @throws SparkException if `body` fails with a non-fatal error other than
 *                        `ProvisionedThroughputExceededException`, or if no attempt succeeds
 *                        within the retry and time limits
 */
private def retryOrTimeout[T](message: String)(body: => T): T = {
  import KinesisSequenceRangeIterator._

  val startTimeMs = System.currentTimeMillis()
  var retryCount = 0
  var waitTimeMs = 0
  var result: Option[T] = None
  var lastError: Throwable = null

  def timeSpentMs = System.currentTimeMillis() - startTimeMs

  // `result` is an Option, so it must be tested with isEmpty. Comparing it against null is
  // always false, which would skip the loop and never evaluate `body` at all.
  while (result.isEmpty && retryCount < MAX_RETRIES && timeSpentMs <= retryTimeoutMs) {
    Thread.sleep(waitTimeMs)
    try {
      result = Some(body)
    } catch {
      case NonFatal(t) =>
        lastError = t
        t match {
          case ptee: ProvisionedThroughputExceededException =>
            // Throttling is the only failure that is retried
            logWarning(s"Exception while $message", ptee)
          case e: Throwable =>
            // Any other non-fatal failure aborts immediately
            throw new SparkException(s"Error $message", e)
        }
    }
    // Book-keeping for the next attempt: no wait before the first try, then exponential backoff
    retryCount += 1
    waitTimeMs = if (waitTimeMs == 0) MIN_RETRY_WAIT_TIME_MS else waitTimeMs * 2
  }
  result.getOrElse {
    throw new SparkException(s"Timed out while $message, last exception: ", lastError)
  }
}
223270
}
271+
272+
private[streaming]
273+
object KinesisSequenceRangeIterator {
274+
val MAX_RETRIES = 3
275+
val MIN_RETRY_WAIT_TIME_MS = 100
276+
}

0 commit comments

Comments
 (0)