@@ -23,15 +23,18 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.services.kinesis.AmazonKinesisClient
 import com.amazonaws.services.kinesis.model._
 
+import org.apache.spark._
 import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition}
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.NextIterator
-import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
 
+
+/** Class representing a range of Kinesis sequence numbers */
 private[kinesis]
 case class SequenceNumberRange(
     streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String)
 
+/** Class representing an array of Kinesis sequence number ranges */
 private[kinesis]
 case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
   def isEmpty(): Boolean = ranges.isEmpty
@@ -41,20 +44,13 @@ case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
 
 private[kinesis]
 object SequenceNumberRanges {
-
   def apply(range: SequenceNumberRange): SequenceNumberRanges = {
     new SequenceNumberRanges(Array(range))
   }
-
-  def apply(ranges: Seq[SequenceNumberRange]): SequenceNumberRanges = {
-    new SequenceNumberRanges(ranges.toArray)
-  }
-
-  def empty: SequenceNumberRanges = {
-    new SequenceNumberRanges(Array.empty)
-  }
 }
 
+
+/** Partition storing the ranges of Kinesis sequence numbers to read */
 private[kinesis]
 class KinesisBackedBlockRDDPartition(
     idx: Int,
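
For reference, a rough sketch of how the slimmed-down companion object is used once this hunk drops the Seq-based apply and empty: only the single-range helper remains, and other cases go through the primary constructor. The stream, shard, and sequence-number strings below are placeholders, not values from this change.

val range = SequenceNumberRange("myStream", "shardId-000000000000", "<fromSeqNum>", "<toSeqNum>")
val single = SequenceNumberRanges(range)          // helper kept in the companion object
val several = SequenceNumberRanges(Array(range))  // primary constructor replaces the removed Seq apply
val none = SequenceNumberRanges(Array.empty)      // replaces the removed `empty`
assert(none.isEmpty() && !single.isEmpty())
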
@@ -63,14 +59,19 @@ class KinesisBackedBlockRDDPartition(
     val seqNumberRanges: SequenceNumberRanges
   ) extends BlockRDDPartition(blockId, idx)
 
+/**
+ * A BlockRDD where the block data is backed by Kinesis, which can be accessed using the
+ * sequence numbers of the corresponding blocks.
+ */
 private[kinesis]
 class KinesisBackedBlockRDD(
     sc: SparkContext,
     regionId: String,
     endpointUrl: String,
     @transient blockIds: Array[BlockId],
     @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges],
-    @transient isBlockIdValid: Array[Boolean] = Array.empty
+    @transient isBlockIdValid: Array[Boolean] = Array.empty,
+    awsCredentialsOption: Option[SerializableAWSCredentials] = None
   ) extends BlockRDD[Array[Byte]](sc, blockIds) {
 
   require(blockIds.length == arrayOfseqNumberRanges.length,
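
A hedged sketch of how a caller might use the new awsCredentialsOption parameter; SerializableAWSCredentials is assumed to be the simple key/secret wrapper defined elsewhere in this module, and blockIds/ranges stand in for the arrays produced by the Kinesis receiver.

val rdd = new KinesisBackedBlockRDD(
  sc,
  regionId = "us-east-1",
  endpointUrl = "https://kinesis.us-east-1.amazonaws.com",
  blockIds = blockIds,
  arrayOfseqNumberRanges = ranges,
  awsCredentialsOption = Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
// With the default None, compute() falls back to DefaultAWSCredentialsProviderChain (next hunk).
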
@@ -96,11 +97,11 @@ class KinesisBackedBlockRDD(
     }
 
     def getBlockFromKinesis(): Iterator[Array[Byte]] = {
-      val credenentials = new DefaultAWSCredentialsProviderChain().getCredentials()
+      val credenentials = awsCredentialsOption.getOrElse {
+        new DefaultAWSCredentialsProviderChain().getCredentials()
+      }
       partition.seqNumberRanges.ranges.iterator.flatMap { range =>
-        new KinesisSequenceRangeIterator(
-          credenentials, endpointUrl, regionId,
-          range.streamName, range.shardId, range.fromSeqNumber, range.toSeqNumber)
+        new KinesisSequenceRangeIterator(credenentials, endpointUrl, regionId, range)
       }
     }
     if (partition.isBlockIdValid) {
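
For context, getBlockFromKinesis() is the fallback used when a block is no longer available in a block manager; the surrounding compute() logic (partly outside this hunk, so the exact shape of getBlockFromBlockManager() here is an assumption) reads roughly as:

if (partition.isBlockIdValid) {
  // Block should still be held by a block manager; refetch from Kinesis only if it was evicted
  getBlockFromBlockManager().getOrElse { getBlockFromKinesis() }
} else {
  // Block is known to be gone, so re-read the records from Kinesis using the stored ranges
  getBlockFromKinesis()
}
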
@@ -112,15 +113,13 @@ class KinesisBackedBlockRDD(
 }
 
 
+/** An iterator that returns the Kinesis data for the given range of sequence numbers */
 private[kinesis]
 class KinesisSequenceRangeIterator(
     credentials: AWSCredentials,
     endpointUrl: String,
     regionId: String,
-    streamName: String,
-    shardId: String,
-    fromSeqNumber: String,
-    toSeqNumber: String
+    range: SequenceNumberRange
   ) extends NextIterator[Array[Byte]] {
 
   private val backoffTimeMillis = 1000
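
With the four stream/shard/sequence-number strings collapsed into a single SequenceNumberRange, constructing the iterator now looks roughly like this; the endpoint, region, and range values are illustrative placeholders.

val credentials = new DefaultAWSCredentialsProviderChain().getCredentials()
val iterator = new KinesisSequenceRangeIterator(
  credentials,
  "https://kinesis.us-east-1.amazonaws.com",
  "us-east-1",
  SequenceNumberRange("myStream", "shardId-000000000000", "<fromSeqNum>", "<toSeqNum>"))
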
@@ -142,7 +141,7 @@ class KinesisSequenceRangeIterator(
 
         // If the internal iterator has not been initialized,
         // then fetch records from starting sequence number
-        getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, fromSeqNumber)
+        getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber)
       } else if (!internalIterator.hasNext) {
 
         // If the internal iterator does not have any more records,
@@ -155,9 +154,7 @@ class KinesisSequenceRangeIterator(
         // If the internal iterator still does not have any data, then throw exception
         // and terminate this iterator
         finished = true
-        throw new Exception("Could not read until the specified end sequence number: " +
-          s"shardId = $shardId, fromSequenceNumber = $fromSeqNumber, " +
-          s"toSequenceNumber = $toSeqNumber")
+        throw new SparkException(s"Could not read until the specified end sequence number: $range")
       } else {
 
         // Get the record, and remember its sequence number
@@ -167,7 +164,7 @@ class KinesisSequenceRangeIterator(
 
         // If this record's sequence number matches the stopping sequence number, then make sure
         // the iterator is marked finished next time getNext() is called
-        if (nextRecord.getSequenceNumber == toSeqNumber) {
+        if (nextRecord.getSequenceNumber == range.toSeqNumber) {
           toSeqNumberReceived = true
         }
       }
@@ -179,11 +176,12 @@ class KinesisSequenceRangeIterator(
   override protected def close(): Unit = { }
 
   private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Unit = {
-    val shardIterator = getKinesisIterator(streamName, shardId, iteratorType, seqNum)
+    val shardIterator = getKinesisIterator(range.streamName, range.shardId, iteratorType, seqNum)
     var records: Seq[Record] = null
     do {
       try {
-        val getResult = getRecordsAndNextKinesisIterator(streamName, shardId, shardIterator)
+        val getResult = getRecordsAndNextKinesisIterator(
+          range.streamName, range.shardId, shardIterator)
         records = getResult._1
       } catch {
         case ptee: ProvisionedThroughputExceededException =>
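
The hunk is cut off before the catch body, but given the backoffTimeMillis field declared earlier, the loop presumably sleeps and retries when Kinesis throttles the read. A self-contained sketch of that back-off-and-retry idea (not the file's verbatim tail):

import com.amazonaws.services.kinesis.model.ProvisionedThroughputExceededException

def retryOnThrottle[T](backoffMillis: Long)(body: => T): T = {
  var result: Option[T] = None
  do {
    try {
      result = Some(body)
    } catch {
      case _: ProvisionedThroughputExceededException =>
        // Kinesis rejected the call due to provisioned throughput limits; back off and retry
        Thread.sleep(backoffMillis)
    }
  } while (result.isEmpty)
  result.get
}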