Skip to content

Commit 8874b70

Browse files
committed
Updated Kinesis RDD
1 parent 575bdbc commit 8874b70

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,18 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
2323
import com.amazonaws.services.kinesis.AmazonKinesisClient
2424
import com.amazonaws.services.kinesis.model._
2525

26+
import org.apache.spark._
2627
import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition}
2728
import org.apache.spark.storage.BlockId
2829
import org.apache.spark.util.NextIterator
29-
import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
3030

31+
32+
/** Class representing a range of Kinesis sequence numbers */
3133
private[kinesis]
3234
case class SequenceNumberRange(
3335
streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String)
3436

37+
/** Class representing an array of Kinesis sequence number ranges */
3538
private[kinesis]
3639
case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
3740
def isEmpty(): Boolean = ranges.isEmpty
@@ -41,20 +44,13 @@ case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
4144

4245
private[kinesis]
4346
object SequenceNumberRanges {
44-
4547
def apply(range: SequenceNumberRange): SequenceNumberRanges = {
4648
new SequenceNumberRanges(Array(range))
4749
}
48-
49-
def apply(ranges: Seq[SequenceNumberRange]): SequenceNumberRanges = {
50-
new SequenceNumberRanges(ranges.toArray)
51-
}
52-
53-
def empty: SequenceNumberRanges = {
54-
new SequenceNumberRanges(Array.empty)
55-
}
5650
}
5751

52+
53+
/** Partition storing the information of the ranges of Kinesis sequence numbers to read */
5854
private[kinesis]
5955
class KinesisBackedBlockRDDPartition(
6056
idx: Int,
@@ -63,14 +59,19 @@ class KinesisBackedBlockRDDPartition(
6359
val seqNumberRanges: SequenceNumberRanges
6460
) extends BlockRDDPartition(blockId, idx)
6561

62+
/**
63+
* A BlockRDD where the block data is backed by Kinesis, which can be accessed using the
64+
* sequence numbers of the corresponding blocks.
65+
*/
6666
private[kinesis]
6767
class KinesisBackedBlockRDD(
6868
sc: SparkContext,
6969
regionId: String,
7070
endpointUrl: String,
7171
@transient blockIds: Array[BlockId],
7272
@transient arrayOfseqNumberRanges: Array[SequenceNumberRanges],
73-
@transient isBlockIdValid: Array[Boolean] = Array.empty
73+
@transient isBlockIdValid: Array[Boolean] = Array.empty,
74+
awsCredentialsOption: Option[SerializableAWSCredentials] = None
7475
) extends BlockRDD[Array[Byte]](sc, blockIds) {
7576

7677
require(blockIds.length == arrayOfseqNumberRanges.length,
@@ -96,11 +97,11 @@ class KinesisBackedBlockRDD(
9697
}
9798

9899
def getBlockFromKinesis(): Iterator[Array[Byte]] = {
99-
val credenentials = new DefaultAWSCredentialsProviderChain().getCredentials()
100+
val credenentials = awsCredentialsOption.getOrElse {
101+
new DefaultAWSCredentialsProviderChain().getCredentials()
102+
}
100103
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
101-
new KinesisSequenceRangeIterator(
102-
credenentials, endpointUrl, regionId,
103-
range.streamName, range.shardId, range.fromSeqNumber, range.toSeqNumber)
104+
new KinesisSequenceRangeIterator(credenentials, endpointUrl, regionId, range)
104105
}
105106
}
106107
if (partition.isBlockIdValid) {
@@ -112,15 +113,13 @@ class KinesisBackedBlockRDD(
112113
}
113114

114115

116+
/** An iterator that returns the Kinesis data based on the given range of sequence numbers */
115117
private[kinesis]
116118
class KinesisSequenceRangeIterator(
117119
credentials: AWSCredentials,
118120
endpointUrl: String,
119121
regionId: String,
120-
streamName: String,
121-
shardId: String,
122-
fromSeqNumber: String,
123-
toSeqNumber: String
122+
range: SequenceNumberRange
124123
) extends NextIterator[Array[Byte]] {
125124

126125
private val backoffTimeMillis = 1000
@@ -142,7 +141,7 @@ class KinesisSequenceRangeIterator(
142141

143142
// If the internal iterator has not been initialized,
144143
// then fetch records from starting sequence number
145-
getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, fromSeqNumber)
144+
getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber)
146145
} else if (!internalIterator.hasNext) {
147146

148147
// If the internal iterator does not have any more records,
@@ -155,9 +154,7 @@ class KinesisSequenceRangeIterator(
155154
// If the internal iterator still does not have any data, then throw exception
156155
// and terminate this iterator
157156
finished = true
158-
throw new Exception("Could not read until the specified end sequence number: " +
159-
s"shardId = $shardId, fromSequenceNumber = $fromSeqNumber, " +
160-
s"toSequenceNumber = $toSeqNumber")
157+
throw new SparkException(s"Could not read until the specified end sequence number: $range")
161158
} else {
162159

163160
// Get the record, and remember its sequence number
@@ -167,7 +164,7 @@ class KinesisSequenceRangeIterator(
167164

168165
// If this record's sequence number matches the stopping sequence number, then make sure
169166
// the iterator is marked finished next time getNext() is called
170-
if (nextRecord.getSequenceNumber == toSeqNumber) {
167+
if (nextRecord.getSequenceNumber == range.toSeqNumber) {
171168
toSeqNumberReceived = true
172169
}
173170
}
@@ -179,11 +176,12 @@ class KinesisSequenceRangeIterator(
179176
override protected def close(): Unit = { }
180177

181178
private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Unit = {
182-
val shardIterator = getKinesisIterator(streamName, shardId, iteratorType, seqNum)
179+
val shardIterator = getKinesisIterator(range.streamName, range.shardId, iteratorType, seqNum)
183180
var records: Seq[Record] = null
184181
do {
185182
try {
186-
val getResult = getRecordsAndNextKinesisIterator(streamName, shardId, shardIterator)
183+
val getResult = getRecordsAndNextKinesisIterator(
184+
range.streamName, range.shardId, shardIterator)
187185
records = getResult._1
188186
} catch {
189187
case ptee: ProvisionedThroughputExceededException =>

0 commit comments

Comments
 (0)