diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml
index d1c38c7ca5d69..911b00e2b579f 100644
--- a/extras/kinesis-asl-assembly/pom.xml
+++ b/extras/kinesis-asl-assembly/pom.xml
@@ -47,6 +47,12 @@
<version>${project.version}</version>
<scope>provided</scope>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/DefaultSource.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/DefaultSource.scala
new file mode 100644
index 0000000000000..8be8dd7df5dc4
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/DefaultSource.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import com.amazonaws.AmazonClientException
+import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
+import com.amazonaws.regions.RegionUtils
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap
+import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.types.StructType
+
+class DefaultSource extends StreamSourceProvider with DataSourceRegister {
+
+ override def shortName(): String = "kinesis"
+
+ override def createSource(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)
+
+ val streams = caseInsensitiveOptions.getOrElse("stream", {
+ throw new IllegalArgumentException(
+ "Option 'stream' must be specified. Examples: " +
+ """option("stream", "stream1"), option("stream", "stream1,stream2")""")
+ }).split(",", -1).toSet
+
+ if (streams.isEmpty || streams.exists(_.isEmpty)) {
+ throw new IllegalArgumentException(
+ "Option 'stream' is invalid, as stream names cannot be empty.")
+ }
+
+ val regionOption = caseInsensitiveOptions.get("region")
+ val endpointOption = caseInsensitiveOptions.get("endpoint")
+ val (region, endpoint) = (regionOption, endpointOption) match {
+ case (Some(_region), Some(_endpoint)) =>
+ if (RegionUtils.getRegionByEndpoint(_endpoint).getName != _region) {
+ throw new IllegalArgumentException(
+ s"'region'(${_region}) doesn't match to 'endpoint'(${_endpoint})")
+ }
+ (_region, _endpoint)
+ case (Some(_region), None) =>
+ (_region, RegionUtils.getRegion(_region).getServiceEndpoint("kinesis"))
+ case (None, Some(_endpoint)) =>
+ (RegionUtils.getRegionByEndpoint(_endpoint).getName, _endpoint)
+ case (None, None) =>
+ throw new IllegalArgumentException(
+ "Either option 'region' or option 'endpoint' must be specified. Examples: " +
+ """option("region", "us-west-2"), """ +
+ """option("endpoint", "https://kinesis.us-west-2.amazonaws.com")""")
+ }
+
+ val initialPosInStream =
+ caseInsensitiveOptions.getOrElse("position", InitialPositionInStream.LATEST.name) match {
+ case pos if pos.toUpperCase == InitialPositionInStream.LATEST.name =>
+ InitialPositionInStream.LATEST
+ case pos if pos.toUpperCase == InitialPositionInStream.TRIM_HORIZON.name =>
+ InitialPositionInStream.TRIM_HORIZON
+ case pos =>
+ throw new IllegalArgumentException(s"Unknown value of option 'position': $pos")
+ }
+
+ val accessKeyOption = caseInsensitiveOptions.get("accessKey")
+ val secretKeyOption = caseInsensitiveOptions.get("secretKey")
+ val credentials = (accessKeyOption, secretKeyOption) match {
+ case (Some(accessKey), Some(secretKey)) =>
+ new SerializableAWSCredentials(accessKey, secretKey)
+ case (Some(accessKey), None) =>
+ throw new IllegalArgumentException(
+ s"'accessKey' is set but 'secretKey' is not found")
+ case (None, Some(secretKey)) =>
+ throw new IllegalArgumentException(
+ s"'secretKey' is set but 'accessKey' is not found")
+ case (None, None) =>
+ try {
+ SerializableAWSCredentials(new DefaultAWSCredentialsProviderChain().getCredentials())
+ } catch {
+ case _: AmazonClientException =>
+ throw new IllegalArgumentException(
+ "No credential found using default AWS provider chain. Specify credentials using " +
+ "options 'accessKey' and 'secretKey'. Examples: " +
+ """option("accessKey", "your-aws-access-key"), """ +
+ """option("secretKey", "your-aws-secret-key")""")
+ }
+ }
+
+ new KinesisSource(
+ sqlContext,
+ region,
+ endpoint,
+ streams,
+ initialPosInStream,
+ credentials)
+ }
+}
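
For illustration, a minimal sketch of how the options validated in `createSource` above are expected
to be supplied through the DataFrameReader API; the stream names, region, and credential values are
placeholders, and a SQLContext named `sqlContext` is assumed to be in scope:

// Placeholder values, shown only to illustrate the option names handled by DefaultSource.
val kinesisDF = sqlContext.read
  .format("org.apache.spark.streaming.kinesis")  // resolves to the DefaultSource above
  .option("stream", "stream1,stream2")           // required; comma-separated stream names
  .option("region", "us-west-2")                 // or "endpoint"; at least one is required
  .option("position", "TRIM_HORIZON")            // optional; defaults to LATEST
  .option("accessKey", "your-aws-access-key")    // optional if the default AWS
  .option("secretKey", "your-aws-secret-key")    // credentials provider chain is configured
  .stream()
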
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisDataFetcher.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisDataFetcher.scala
new file mode 100644
index 0000000000000..90d2e2bd5471a
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisDataFetcher.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
+
+import com.amazonaws.services.kinesis.AmazonKinesisClient
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord
+import com.amazonaws.services.kinesis.model._
+
+import org.apache.spark._
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.storage.{BlockId, StorageLevel}
+
+/**
+ * Fetches the latest data from Kinesis shards by launching a Spark job (see `fetch`). Note that
+ * the fetching is coordinated from the driver on every batch, so this class could become a
+ * bottleneck.
+ */
+private[kinesis] class KinesisDataFetcher(
+ credentials: SerializableAWSCredentials,
+ endpointUrl: String,
+ fromSeqNums: Seq[(Shard, Option[String], BlockId)],
+ initialPositionInStream: InitialPositionInStream,
+ readTimeoutMs: Long = 2000L) extends Serializable with Logging {
+
+ /**
+ * Lazy because the client needs to be created on the executors, where `fetchPartition` runs,
+ * rather than serialized from the driver.
+ */
+ @transient private lazy val client = new AmazonKinesisClient(credentials)
+
+ /**
+ * Launch a Spark job to fetch the latest data from the specified shards. This method tries to
+ * fetch data arriving within `readTimeoutMs` milliseconds so as to get the latest sequence
+ * numbers. New data is pushed to the block manager to avoid fetching it again.
+ *
+ * This is a workaround since Kinesis doesn't provide an API to fetch the latest sequence number.
+ */
+ def fetch(sc: SparkContext): Array[(BlockId, SequenceNumberRange)] = {
+ sc.makeRDD(fromSeqNums, fromSeqNums.size).map {
+ case (shard, fromSeqNum, blockId) => fetchPartition(shard, fromSeqNum, blockId)
+ }.collect().flatten
+ }
+
+ /**
+ * Fetch the latest data from the specified `shard` since `fromSeqNum`. This method tries to
+ * fetch data arriving within `readTimeoutMs` milliseconds so as to get the latest sequence
+ * number. New data is pushed to the block manager to avoid fetching it again.
+ *
+ * This is a workaround since Kinesis doesn't provide an API to fetch the latest sequence number.
+ */
+ private def fetchPartition(
+ shard: Shard,
+ fromSeqNum: Option[String],
+ blockId: BlockId): Option[(BlockId, SequenceNumberRange)] = {
+ client.setEndpoint(endpointUrl)
+
+ val endTime = System.currentTimeMillis + readTimeoutMs
+ def timeLeft = math.max(endTime - System.currentTimeMillis, 0)
+
+ val buffer = new ArrayBuffer[Array[Byte]]
+ var firstSeqNumber: String = null
+ var lastSeqNumber: String = fromSeqNum.orNull
+ var lastIterator: String = null
+ try {
+ logDebug(s"Trying to fetch data from $shard, from seq num $lastSeqNumber")
+
+ while (timeLeft > 0) {
+ val (records, nextIterator) = retryOrTimeout("getting shard iterator", timeLeft) {
+ if (lastIterator == null) {
+ lastIterator = if (lastSeqNumber != null) {
+ getKinesisIterator(shard, ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
+ } else {
+ if (initialPositionInStream == InitialPositionInStream.LATEST) {
+ getKinesisIterator(shard, ShardIteratorType.LATEST, lastSeqNumber)
+ } else {
+ getKinesisIterator(shard, ShardIteratorType.TRIM_HORIZON, lastSeqNumber)
+ }
+ }
+ }
+ getRecordsAndNextKinesisIterator(lastIterator)
+ }
+
+ records.foreach { record =>
+ buffer += JavaUtils.bufferToArray(record.getData())
+ if (firstSeqNumber == null) {
+ firstSeqNumber = record.getSequenceNumber
+ }
+ lastSeqNumber = record.getSequenceNumber
+ }
+
+ lastIterator = nextIterator
+ }
+
+ if (buffer.nonEmpty) {
+ SparkEnv.get.blockManager.putIterator(blockId, buffer.iterator, StorageLevel.MEMORY_ONLY)
+ val range = SequenceNumberRange(
+ shard.streamName, shard.shardId, firstSeqNumber, lastSeqNumber)
+ logDebug(s"Received block $blockId having range $range from shard $shard")
+ Some(blockId -> range)
+ } else {
+ None
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error fetching data from shard $shard", e)
+ None
+ }
+ }
+
+ /**
+ * Get the records starting from the given Kinesis shard iterator (a progress handle used to
+ * read records from Kinesis), and get the next shard iterator for the next read.
+ */
+ private def getRecordsAndNextKinesisIterator(shardIterator: String): (Seq[Record], String) = {
+ val getRecordsRequest = new GetRecordsRequest().withShardIterator(shardIterator)
+ getRecordsRequest.setRequestCredentials(credentials)
+ val getRecordsResult = client.getRecords(getRecordsRequest)
+ // De-aggregate records, if KPL was used in producing the records. The KCL automatically
+ // handles de-aggregation during regular operation. This code path is used during recovery
+ val records = UserRecord.deaggregate(getRecordsResult.getRecords)
+ logTrace(
+ s"Got ${records.size()} records and next iterator ${getRecordsResult.getNextShardIterator}")
+ (records.asScala, getRecordsResult.getNextShardIterator)
+ }
+
+ /**
+ * Get the Kinesis shard iterator for getting records starting from or after the given
+ * sequence number.
+ */
+ private def getKinesisIterator(
+ shard: Shard,
+ iteratorType: ShardIteratorType,
+ sequenceNumber: String): String = {
+ val getShardIteratorRequest = new GetShardIteratorRequest()
+ .withStreamName(shard.streamName)
+ .withShardId(shard.shardId)
+ .withShardIteratorType(iteratorType.toString)
+ .withStartingSequenceNumber(sequenceNumber)
+ getShardIteratorRequest.setRequestCredentials(credentials)
+ val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest)
+ logTrace(s"Shard $shard: Got iterator ${getShardIteratorResult.getShardIterator}")
+ getShardIteratorResult.getShardIterator
+ }
+
+ /** Helper method to retry a Kinesis API request with exponential backoff and timeouts. */
+ private def retryOrTimeout[T](message: String, retryTimeoutMs: Long)(body: => T): T = {
+ import KinesisSequenceRangeIterator._
+ val startTimeMs = System.currentTimeMillis()
+ var retryCount = 0
+ var waitTimeMs = MIN_RETRY_WAIT_TIME_MS
+ var result: Option[T] = None
+ var lastError: Throwable = null
+
+ def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs
+ def isMaxRetryDone = retryCount >= MAX_RETRIES
+
+ while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
+ if (retryCount > 0) {
+ // wait only if this is a retry
+ Thread.sleep(waitTimeMs)
+ waitTimeMs *= 2 // if you have waited, then double wait time for next round
+ }
+ try {
+ result = Some(body)
+ } catch {
+ case NonFatal(t) =>
+ lastError = t
+ t match {
+ case ptee: ProvisionedThroughputExceededException =>
+ logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee)
+ case e: Throwable =>
+ throw new SparkException(s"Error while $message", e)
+ }
+ }
+ retryCount += 1
+ }
+ result.getOrElse {
+ if (isTimedOut) {
+ throw new SparkException(
+ s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError)
+ } else {
+ throw new SparkException(
+ s"Gave up after $retryCount retries while $message, last exception: ", lastError)
+ }
+ }
+ }
+}
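
For illustration, a sketch of how the fetcher above is expected to be driven from the driver,
mirroring what `KinesisSource.getNextBatch` does; the shard id, credentials, and endpoint are
placeholders, and `sc` is assumed to be a live SparkContext:

// Each tuple pairs a shard with an optional starting sequence number and a pre-assigned block id.
val fromSeqNums = Seq(
  (Shard("stream1", "shardId-000000000000"), None, KinesisSource.nextBlockId))
val fetcher = new KinesisDataFetcher(
  SerializableAWSCredentials("your-aws-access-key", "your-aws-secret-key"),
  "https://kinesis.us-west-2.amazonaws.com",
  fromSeqNums,
  InitialPositionInStream.TRIM_HORIZON)
// Runs a Spark job over the shards; returns the filled block ids and their sequence number ranges.
val blocksWithRanges: Array[(BlockId, SequenceNumberRange)] = fetcher.fetch(sc)
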
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index ca13a21087cc8..d295c4d7a63eb 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -41,7 +41,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
override def getAWSSecretKey: String = secretKey
}
-object SerializableAWSCredentials {
+private[kinesis] object SerializableAWSCredentials {
def apply(credentials: AWSCredentials): SerializableAWSCredentials = {
new SerializableAWSCredentials(credentials.getAWSAccessKeyId, credentials.getAWSSecretKey)
}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala
index 07876f5f45aa1..77c6a70235686 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala
@@ -16,23 +16,22 @@
*/
package org.apache.spark.streaming.kinesis
+import java.util.concurrent.atomic.AtomicLong
+
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
-import scala.util.control.NonFatal
-import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
+import com.amazonaws.AbortedException
import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
-import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord
-import com.amazonaws.services.kinesis.model._
import org.apache.spark._
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.execution.streaming.{Batch, Offset, Source, StreamingRelation}
+import org.apache.spark.sql.execution.streaming.{Batch, Offset, Source}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
-import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
+import org.apache.spark.storage.{BlockId, StreamBlockId}
private[kinesis] case class Shard(streamName: String, shardId: String)
@@ -40,14 +39,13 @@ private[kinesis] case class KinesisSourceOffset(seqNums: Map[Shard, String])
extends Offset {
override def compareTo(otherOffset: Offset): Int = otherOffset match {
-
case that: KinesisSourceOffset =>
val allShards = this.seqNums.keySet ++ that.seqNums.keySet
val comparisons = allShards.map { shard =>
(this.seqNums.get(shard).map(BigInt(_)), that.seqNums.get(shard).map(BigInt(_))) match {
case (Some(thisNum), Some(thatNum)) => thisNum.compare(thatNum)
case (None, _) => -1 // new shard started by resharding
- case (_, None) => -1 // old shard got eliminated by resharding
+ case (_, None) => 1 // old shard got eliminated by resharding
}
}
@@ -57,7 +55,7 @@ private[kinesis] case class KinesisSourceOffset(seqNums: Map[Shard, String])
case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s)
case _ => // there are both 1s and -1s
throw new IllegalArgumentException(
- s"Invalid comparison between non-linear histories: $this <=> $that")
+ s"Invalid comparison between KinesisSource offsets: \n\t this: $this\n\t that: $that")
}
case _ =>
@@ -65,50 +63,58 @@ private[kinesis] case class KinesisSourceOffset(seqNums: Map[Shard, String])
}
}
-
-private[kinesis] object KinesisSourceOffset {
- def fromOffset(offset: Offset): KinesisSourceOffset = {
- offset match {
- case o: KinesisSourceOffset => o
- case _ =>
- throw new IllegalArgumentException(
- s"Invalid conversion from offset of ${offset.getClass} to $getClass")
- }
- }
-}
-
-
-private[kinesis] case class KinesisSource(
+private[kinesis] class KinesisSource(
sqlContext: SQLContext,
regionName: String,
endpointUrl: String,
streamNames: Set[String],
- initialPosInStream: InitialPositionInStream = InitialPositionInStream.LATEST,
- awsCredentialsOption: Option[SerializableAWSCredentials] = None) extends Source {
+ initialPosInStream: InitialPositionInStream,
+ awsCredentials: SerializableAWSCredentials) extends Source {
+
+ // How long we should wait before calling `fetchShards()`. Because DescribeStream has a limit of
+ // 10 transactions per second per account, we should not request too frequently.
+ private val FETCH_SHARDS_INTERVAL_MS = 200L
+
+ // The last time `fetchShards()` was called.
+ private var lastFetchShardsTimeMS = 0L
implicit private val encoder = ExpressionEncoder[Array[Byte]]
- private val logicalPlan = StreamingRelation(this)
- @transient val credentials = SerializableAWSCredentials(
- awsCredentialsOption.getOrElse(new DefaultAWSCredentialsProviderChain().getCredentials())
- )
+ private val client = new AmazonKinesisClient(awsCredentials)
+ client.setEndpoint(endpointUrl)
- @transient private val client = new AmazonKinesisClient(credentials)
- client.setEndpoint(endpointUrl, "kinesis", regionName)
+ private var cachedBlocks = new mutable.HashSet[BlockId]
- override def schema: StructType = encoder.schema
+ override val schema: StructType = encoder.schema
override def getNextBatch(start: Option[Offset]): Option[Batch] = {
- val startOffset = start.map(KinesisSourceOffset.fromOffset)
+ val now = System.currentTimeMillis()
+ if (now - lastFetchShardsTimeMS < FETCH_SHARDS_INTERVAL_MS) {
+ // Because DescribeStream has a limit of 10 transactions per second per account, we should not
+ // request too frequently.
+ return None
+ }
+ lastFetchShardsTimeMS = now
+ val startOffset = start.map(_.asInstanceOf[KinesisSourceOffset])
val shards = fetchShards()
// Get the starting seq number of each shard if available
val fromSeqNums = shards.map { shard => shard -> startOffset.flatMap(_.seqNums.get(shard)) }
- /** Prefetch Kinesis data from the starting seq nums */
- val prefetchedData = new KinesisDataFetcher(
- sqlContext, credentials, endpointUrl, regionName, fromSeqNums, initialPosInStream).fetch()
+ // Assign a unique block id for each shard
+ val fromSeqNumsWithBlockId = fromSeqNums.map { case (shard, seqNum) =>
+ val uniqueBlockId = KinesisSource.nextBlockId
+ (shard, seqNum, uniqueBlockId)
+ }
+
+ // Prefetch Kinesis data from the starting seq nums
+ val prefetchedData =
+ new KinesisDataFetcher(
+ awsCredentials,
+ endpointUrl,
+ fromSeqNumsWithBlockId,
+ initialPosInStream).fetch(sqlContext.sparkContext)
if (prefetchedData.nonEmpty) {
val prefetechedRanges = prefetchedData.map(_._2)
@@ -125,189 +131,52 @@ private[kinesis] case class KinesisSource(
val rdd = new KinesisBackedBlockRDD[Array[Byte]](sqlContext.sparkContext, regionName,
endpointUrl, prefetchedBlockIds, prefetechedRanges.map(SequenceNumberRanges.apply))
+
+ dropOldBlocks()
+ cachedBlocks ++= prefetchedBlockIds
+
Some(new Batch(new KinesisSourceOffset(endOffset), sqlContext.createDataset(rdd).toDF))
} else {
None
}
}
- def toDS(): Dataset[Array[Byte]] = {
- toDF.as[Array[Byte]]
- }
-
- def toDF(): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
- }
-
- private def fetchShards(): Seq[Shard] = {
- streamNames.toSeq.flatMap { streamName =>
- val desc = client.describeStream(streamName)
- desc.getStreamDescription.getShards.asScala.map { s =>
- Shard(streamName, s.getShardId)
+ private def dropOldBlocks(): Unit = {
+ val droppedBlocks = ArrayBuffer[BlockId]()
+ try {
+ for (blockId <- cachedBlocks) {
+ SparkEnv.get.blockManager.removeBlock(blockId)
+ droppedBlocks += blockId
}
+ } finally {
+ cachedBlocks --= droppedBlocks
}
}
-}
-
-private[kinesis] class KinesisDataFetcher(
- sqlContext: SQLContext,
- credentials: SerializableAWSCredentials,
- endpointUrl: String,
- regionName: String,
- fromSeqNums: Seq[(Shard, Option[String])],
- initialPositionInStream: InitialPositionInStream,
- readTimeoutMs: Int = 2000
- ) extends Serializable with Logging {
-
- @transient private lazy val client = new AmazonKinesisClient(credentials)
-
- def fetch(): Array[(BlockId, SequenceNumberRange)] = {
- sqlContext.sparkContext.makeRDD(fromSeqNums, fromSeqNums.size).map {
- case (shard, fromSeqNum) => fetchPartition(shard, fromSeqNum)
- }.collect().flatten
- }
-
- private def fetchPartition(
- shard: Shard,
- fromSeqNum: Option[String]): Option[(BlockId, SequenceNumberRange)] = {
- client.setEndpoint(endpointUrl, "kinesis", regionName)
-
- val endTime = System.currentTimeMillis + readTimeoutMs
- def timeLeft = math.max(endTime - System.currentTimeMillis, 0)
-
- val buffer = new ArrayBuffer[Array[Byte]]
- var firstSeqNumber: String = null
- var lastSeqNumber: String = fromSeqNum.orNull
- var lastIterator: String = null
+ private def fetchShards(): Seq[Shard] = {
try {
- logDebug(s"Trying to fetch data from $shard, from seq num $lastSeqNumber")
-
- while (timeLeft > 0) {
- val (records, nextIterator) = retryOrTimeout("getting shard iterator", timeLeft) {
- if (lastIterator == null) {
- lastIterator = if (lastSeqNumber != null) {
- getKinesisIterator(shard, ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
- } else {
- if (initialPositionInStream == InitialPositionInStream.LATEST) {
- getKinesisIterator(shard, ShardIteratorType.LATEST, lastSeqNumber)
- } else {
- getKinesisIterator(shard, ShardIteratorType.TRIM_HORIZON, lastSeqNumber)
- }
- }
- }
- getRecordsAndNextKinesisIterator(lastIterator)
+ streamNames.toSeq.flatMap { streamName =>
+ val desc = client.describeStream(streamName)
+ desc.getStreamDescription.getShards.asScala.map { s =>
+ Shard(streamName, s.getShardId)
}
-
- records.foreach { record =>
- val byteBuffer = record.getData()
- val byteArray = new Array[Byte](byteBuffer.remaining())
- byteBuffer.get(byteArray)
- buffer += byteArray
- if (firstSeqNumber == null) {
- firstSeqNumber = record.getSequenceNumber
- }
- lastSeqNumber = record.getSequenceNumber
- }
-
- lastIterator = nextIterator
- }
-
- if (buffer.nonEmpty) {
- val blockId = StreamBlockId(0, Random.nextLong)
- SparkEnv.get.blockManager.putIterator(blockId, buffer.iterator, StorageLevel.MEMORY_ONLY)
- val range = SequenceNumberRange(
- shard.streamName, shard.shardId, firstSeqNumber, lastSeqNumber)
- logDebug(s"Received block $blockId having range $range from shard $shard")
- Some(blockId -> range)
- } else {
- None
}
} catch {
- case NonFatal(e) =>
- logWarning(s"Error fetching data from shard $shard", e)
- None
+ case e: AbortedException =>
+ // AbortedException is thrown if the current thread is interrupted,
+ // so convert it back to an InterruptedException.
+ val e1 = new InterruptedException("thread is interrupted")
+ e1.addSuppressed(e)
+ throw e1
}
}
+ override def toString: String = s"KinesisSource[streamNames=${streamNames.mkString(",")}]"
+}
- /**
- * Get the records starting from using a Kinesis shard iterator (which is a progress handle
- * to get records from Kinesis), and get the next shard iterator for next consumption.
- */
- private def getRecordsAndNextKinesisIterator(
- shardIterator: String): (Seq[Record], String) = {
- val getRecordsRequest = new GetRecordsRequest
- getRecordsRequest.setRequestCredentials(credentials)
- getRecordsRequest.setShardIterator(shardIterator)
- val getRecordsResult = client.getRecords(getRecordsRequest)
- // De-aggregate records, if KPL was used in producing the records. The KCL automatically
- // handles de-aggregation during regular operation. This code path is used during recovery
- val records = UserRecord.deaggregate(getRecordsResult.getRecords)
- logTrace(
- s"Got ${records.size()} records and next iterator ${getRecordsResult.getNextShardIterator}")
- (records.asScala, getRecordsResult.getNextShardIterator)
- }
+private[kinesis] object KinesisSource {
- /**
- * Get the Kinesis shard iterator for getting records starting from or after the given
- * sequence number.
- */
- private def getKinesisIterator(
- shard: Shard,
- iteratorType: ShardIteratorType,
- sequenceNumber: String): String = {
- val getShardIteratorRequest = new GetShardIteratorRequest
- getShardIteratorRequest.setRequestCredentials(credentials)
- getShardIteratorRequest.setStreamName(shard.streamName)
- getShardIteratorRequest.setShardId(shard.shardId)
- getShardIteratorRequest.setShardIteratorType(iteratorType.toString)
- getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber)
- val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest)
- logTrace(s"Shard $shard: Got iterator ${getShardIteratorResult.getShardIterator}")
- getShardIteratorResult.getShardIterator
- }
+ private val nextId = new AtomicLong(0)
- /** Helper method to retry Kinesis API request with exponential backoff and timeouts */
- private def retryOrTimeout[T](message: String, retryTimeoutMs: Long)(body: => T): T = {
- import KinesisSequenceRangeIterator._
- var startTimeMs = System.currentTimeMillis()
- var retryCount = 0
- var waitTimeMs = MIN_RETRY_WAIT_TIME_MS
- var result: Option[T] = None
- var lastError: Throwable = null
-
- def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs
- def isMaxRetryDone = retryCount >= MAX_RETRIES
-
- while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
- if (retryCount > 0) {
- // wait only if this is a retry
- Thread.sleep(waitTimeMs)
- waitTimeMs *= 2 // if you have waited, then double wait time for next round
- }
- try {
- result = Some(body)
- } catch {
- case NonFatal(t) =>
- lastError = t
- t match {
- case ptee: ProvisionedThroughputExceededException =>
- logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee)
- case e: Throwable =>
- throw new SparkException(s"Error while $message", e)
- }
- }
- retryCount += 1
- }
- result.getOrElse {
- if (isTimedOut) {
- throw new SparkException(
- s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError)
- } else {
- throw new SparkException(
- s"Gave up after $retryCount retries while $message, last exception: ", lastError)
- }
- }
- }
+ def nextBlockId: StreamBlockId = StreamBlockId(Int.MaxValue, nextId.getAndIncrement)
}
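
For illustration, a small sketch of the resharding semantics of `KinesisSourceOffset.compareTo`,
including the `(_, None) => 1` fix above; the shard ids and sequence numbers are made up:

val s0 = Shard("stream1", "shardId-000000000000")
val s1 = Shard("stream1", "shardId-000000000001")
val before = KinesisSourceOffset(Map(s0 -> "100"))
// s1 appears only in the newer offset, e.g. a shard created by resharding.
val after = KinesisSourceOffset(Map(s0 -> "200", s1 -> "5"))
assert(before.compareTo(after) < 0)  // a shard missing from `before` counts as "behind"
assert(after.compareTo(before) > 0)  // and the reverse comparison is therefore "ahead"
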
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index e3814be676e85..8d32c5c5aee64 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -42,7 +42,7 @@ import org.apache.spark.Logging
private[kinesis] class KinesisTestUtils extends Logging {
val endpointUrl = KinesisTestUtils.endpointUrl
- val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
+ val regionName = KinesisTestUtils.regionName
val streamShardCount = 2
private val createStreamTimeoutSeconds = 300
@@ -54,6 +54,8 @@ private[kinesis] class KinesisTestUtils extends Logging {
@volatile
private var _streamName: String = _
+ private val shardIdToLatestSeqNum = mutable.HashMap[String, String]()
+
protected lazy val kinesisClient = {
val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials())
client.setEndpoint(endpointUrl)
@@ -105,9 +107,14 @@ private[kinesis] class KinesisTestUtils extends Logging {
val producer = getProducer(aggregate)
val shardIdToSeqNumbers = producer.sendData(streamName, testData)
logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}")
+ shardIdToSeqNumbers.foreach { case (shardId, seq) =>
+ shardIdToLatestSeqNum(shardId) = seq.last._2
+ }
shardIdToSeqNumbers.toMap
}
+ def getLatestSeqNumsOfShards(): Map[String, String] = shardIdToLatestSeqNum.toMap
+
/**
* Expose a Python friendly API.
*/
@@ -118,7 +125,6 @@ private[kinesis] class KinesisTestUtils extends Logging {
def deleteStream(): Unit = {
try {
if (streamCreated) {
- logInfo(s"Deleting stream $streamName")
kinesisClient.deleteStream(streamName)
}
} catch {
@@ -210,6 +216,8 @@ private[kinesis] object KinesisTestUtils {
url
}
+ lazy val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
+
def isAWSCredentialsPresent: Boolean = {
Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess
}
@@ -256,10 +264,6 @@ private[kinesis] class SimpleDataGenerator(
val seqNumber = putRecordResult.getSequenceNumber()
val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
new ArrayBuffer[(Int, String)]())
- // scalastyle:off println
- println(s"$data with key $str in shard ${putRecordResult.getShardId} " +
- s"and seq ${putRecordResult.getSequenceNumber}")
- // scalastyle:on println
sentSeqNumbers += ((num, seqNumber))
seqNumForOrdering = seqNumber
}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/package.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/package.scala
new file mode 100644
index 0000000000000..efb79bad1ee9a
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/package.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.apache.spark.sql.DataFrameReader
+
+package object kinesis {
+
+ /**
+ * Adds a `kinesis` method to [[DataFrameReader]] so that users can read from Amazon Kinesis
+ * through the DataFrameReader API.
+ */
+ implicit class KinesisDataFrameReader(reader: DataFrameReader) {
+ def kinesis(): DataFrameReader = reader.format("org.apache.spark.streaming.kinesis")
+ }
+}
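
For illustration, with the implicit conversion above in scope, reading from Kinesis looks like the
sketch below; the endpoint and stream name are placeholders, and `stream()` is the DataFrameReader
method already exercised by the tests in this patch:

import org.apache.spark.streaming.kinesis._  // brings KinesisDataFrameReader into scope

val df = sqlContext.read
  .option("stream", "stream1")
  .option("endpoint", "https://kinesis.us-west-2.amazonaws.com")
  .kinesis()  // equivalent to .format("org.apache.spark.streaming.kinesis")
  .stream()
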
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala
index f17131dc93c6c..29336051d089d 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala
@@ -17,63 +17,209 @@
package org.apache.spark.streaming.kinesis
-
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
-import org.scalatest.time.SpanSugar._
-import org.apache.spark.sql.StreamTest
-import org.apache.spark.sql.execution.streaming.{Offset, Source}
+import org.apache.spark.SparkEnv
+import org.apache.spark.sql.{AnalysisException, StreamTest}
+import org.apache.spark.sql.execution.streaming.{Offset, Source, StreamingRelation}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.storage.StreamBlockId
+abstract class KinesisSourceTest extends StreamTest with SharedSQLContext {
-class KinesisSourceSuite extends StreamTest with SharedSQLContext with KinesisFunSuite {
-
- import testImplicits._
-
- private var testUtils: KPLBasedKinesisTestUtils = _
- private var streamName: String = _
+ case class AddKinesisData(
+ testUtils: KPLBasedKinesisTestUtils,
+ kinesisSource: KinesisSource,
+ data: Seq[Int]) extends AddData {
- override val streamingTimout = 60.seconds
-
- case class AddKinesisData(kinesisSource: KinesisSource, data: Int*) extends AddData {
override def addData(): Offset = {
- val shardIdToSeqNums = testUtils.pushData(data, false).map { case (shardId, info) =>
- (Shard(streamName, shardId), info.last._2)
+ testUtils.pushData(data, false)
+ val shardIdToSeqNums = testUtils.getLatestSeqNumsOfShards().map { case (shardId, seqNum) =>
+ (Shard(testUtils.streamName, shardId), seqNum)
}
- assert(shardIdToSeqNums.size === testUtils.streamShardCount,
- s"Data must be send to all ${testUtils.streamShardCount} shards of stream $streamName")
KinesisSourceOffset(shardIdToSeqNums)
}
override def source: Source = kinesisSource
}
- override def beforeAll(): Unit = {
- super.beforeAll()
- testUtils = new KPLBasedKinesisTestUtils
- testUtils.createStream()
- streamName = testUtils.streamName
+ def createKinesisSourceForTest(testUtils: KPLBasedKinesisTestUtils): KinesisSource = {
+ new KinesisSource(
+ sqlContext,
+ testUtils.regionName,
+ testUtils.endpointUrl,
+ Set(testUtils.streamName),
+ InitialPositionInStream.TRIM_HORIZON,
+ SerializableAWSCredentials(KinesisTestUtils.getAWSCredentials()))
}
+}
+
+class KinesisSourceSuite extends KinesisSourceTest with KinesisFunSuite {
+
+ import testImplicits._
+
+ testIfEnabled("basic receiving and failover") {
+ var streamBlocksInLastBatch: Seq[StreamBlockId] = Seq.empty
- override def afterAll(): Unit = {
- if (testUtils != null) {
+ def assertStreamBlocks: Boolean = {
+ if (sqlContext.sparkContext.isLocal) {
+ // Only test this one in local mode so that we can assume there is only one BlockManager
+ val streamBlocks =
+ SparkEnv.get.blockManager.getMatchingBlockIds(_.isInstanceOf[StreamBlockId])
+ val cleaned = streamBlocks.intersect(streamBlocksInLastBatch).isEmpty
+ streamBlocksInLastBatch = streamBlocks.map(_.asInstanceOf[StreamBlockId])
+ cleaned
+ } else {
+ true
+ }
+ }
+
+ val testUtils = new KPLBasedKinesisTestUtils
+ testUtils.createStream()
+ try {
+ val kinesisSource = createKinesisSourceForTest(testUtils)
+ val mapped =
+ kinesisSource.toDS[Array[Byte]]().map((bytes: Array[Byte]) => new String(bytes).toInt + 1)
+ val testData = 1 to 10
+ testStream(mapped)(
+ AddKinesisData(testUtils, kinesisSource, testData),
+ CheckAnswer((1 to 10).map(_ + 1): _*),
+ Assert(assertStreamBlocks, "Old stream blocks should be cleaned"),
+ StopStream,
+ AddKinesisData(testUtils, kinesisSource, 11 to 20),
+ StartStream,
+ CheckAnswer((1 to 20).map(_ + 1): _*),
+ Assert(assertStreamBlocks, "Old stream blocks should be cleaned"),
+ AddKinesisData(testUtils, kinesisSource, 21 to 30),
+ CheckAnswer((1 to 30).map(_ + 1): _*),
+ Assert(assertStreamBlocks, "Old stream blocks should be cleaned")
+ )
+ } finally {
testUtils.deleteStream()
}
- super.afterAll()
}
- test("basic receiving") {
- val kinesisSource = KinesisSource(
- sqlContext,
- testUtils.regionName,
- testUtils.endpointUrl,
- Set(streamName),
- InitialPositionInStream.TRIM_HORIZON)
- val mapped = kinesisSource.toDS().map[Int]((bytes: Array[Byte]) => new String(bytes).toInt + 1)
- val testData = 1 to 10 // This ensures that data is sent to multiple shards for 2 shard streams
- testStream(mapped)(
- AddKinesisData(kinesisSource, testData: _*),
- CheckAnswer(testData.map { _ + 1 }: _*)
- )
+ test("DataFrameReader") {
+ val df = sqlContext.read
+ .option("endpoint", KinesisTestUtils.endpointUrl)
+ .option("stream", "stream1")
+ .option("accessKey", "accessKey")
+ .option("secretKey", "secretKey")
+ .option("position", InitialPositionInStream.TRIM_HORIZON.name())
+ .kinesis().stream()
+
+ val sources = df.queryExecution.analyzed.collect {
+ case StreamingRelation(s: KinesisSource, _) => s
+ }
+ assert(sources.size === 1)
+
+ // stream
+ assertExceptionAndMessage[IllegalArgumentException](
+ "Option 'stream' must be specified.") {
+ sqlContext.read.kinesis().stream()
+ }
+ assertExceptionAndMessage[IllegalArgumentException](
+ "Option 'stream' is invalid, as stream names cannot be empty.") {
+ sqlContext.read.option("stream", "").kinesis().stream()
+ }
+ assertExceptionAndMessage[IllegalArgumentException](
+ "Option 'stream' is invalid, as stream names cannot be empty.") {
+ sqlContext.read.option("stream", "a,").kinesis().stream()
+ }
+ assertExceptionAndMessage[IllegalArgumentException](
+ "Option 'stream' is invalid, as stream names cannot be empty.") {
+ sqlContext.read.option("stream", ",a").kinesis().stream()
+ }
+
+ // region and endpoint
+ // Setting either endpoint or region is fine
+ sqlContext.read
+ .option("stream", "stream1")
+ .option("endpoint", KinesisTestUtils.endpointUrl)
+ .option("accessKey", "accessKey")
+ .option("secretKey", "secretKey")
+ .kinesis().stream()
+ sqlContext.read
+ .option("stream", "stream1")
+ .option("region", KinesisTestUtils.regionName)
+ .option("accessKey", "accessKey")
+ .option("secretKey", "secretKey")
+ .kinesis().stream()
+
+ assertExceptionAndMessage[IllegalArgumentException](
+ "Either option 'region' or option 'endpoint' must be specified.") {
+ sqlContext.read.option("stream", "stream1").kinesis().stream()
+ }
+ assertExceptionAndMessage[IllegalArgumentException](
+ s"'region'(invalid-region) doesn't match to 'endpoint'(${KinesisTestUtils.endpointUrl})") {
+ sqlContext.read
+ .option("stream", "stream1")
+ .option("region", "invalid-region")
+ .option("endpoint", KinesisTestUtils.endpointUrl)
+ .kinesis().stream()
+ }
+
+ // position
+ assertExceptionAndMessage[IllegalArgumentException](
+ "Unknown value of option 'position': invalid") {
+ sqlContext.read
+ .option("stream", "stream1")
+ .option("endpoint", KinesisTestUtils.endpointUrl)
+ .option("position", "invalid")
+ .kinesis().stream()
+ }
+
+ // accessKey and secretKey
+ assertExceptionAndMessage[IllegalArgumentException](
+ "'accessKey' is set but 'secretKey' is not found") {
+ sqlContext.read
+ .option("stream", "stream1")
+ .option("endpoint", KinesisTestUtils.endpointUrl)
+ .option("position", InitialPositionInStream.TRIM_HORIZON.name())
+ .option("accessKey", "test")
+ .kinesis().stream()
+ }
+ assertExceptionAndMessage[IllegalArgumentException](
+ "'secretKey' is set but 'accessKey' is not found") {
+ sqlContext.read
+ .option("stream", "stream1")
+ .option("endpoint", KinesisTestUtils.endpointUrl)
+ .option("position", InitialPositionInStream.TRIM_HORIZON.name())
+ .option("secretKey", "test")
+ .kinesis().stream()
+ }
+ }
+
+ test("call kinesis when not using stream") {
+ intercept[AnalysisException] {
+ sqlContext.read.kinesis().load()
+ }
+ }
+
+ private def assertExceptionAndMessage[T <: Exception : Manifest](
+ expectedMessage: String)(body: => Unit): Unit = {
+ val e = intercept[T] {
+ body
+ }
+ assert(e.getMessage.contains(expectedMessage))
+ }
+}
+
+class KinesisSourceStressTestSuite extends KinesisSourceTest with KinesisFunSuite {
+
+ import testImplicits._
+
+ testIfEnabled("kinesis source stress test") {
+ val testUtils = new KPLBasedKinesisTestUtils
+ testUtils.createStream()
+ try {
+ val kinesisSource = createKinesisSourceForTest(testUtils)
+ val ds = kinesisSource.toDS[String]().map(_.toInt + 1)
+ runStressTest(ds, data => {
+ AddKinesisData(testUtils, kinesisSource, data)
+ })
+ } finally {
+ testUtils.deleteStream()
+ }
}
-}
\ No newline at end of file
+}
diff --git a/pom.xml b/pom.xml
index 244355b080221..88b9ab2a37ce3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -154,9 +154,9 @@
<avro.version>1.7.7</avro.version>
<avro.mapred.classifier>hadoop2</avro.mapred.classifier>
<jets3t.version>0.7.1</jets3t.version>
- <aws.kinesis.client.version>1.6.1</aws.kinesis.client.version>
+ <aws.kinesis.client.version>1.4.0</aws.kinesis.client.version>
- <aws.kinesis.producer.version>0.10.2</aws.kinesis.producer.version>
+ <aws.kinesis.producer.version>0.10.1</aws.kinesis.producer.version>
<commons.httpclient.version>4.3.2</commons.httpclient.version>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index bb5135826e2f3..fbf6f6ba49600 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -380,10 +380,10 @@ trait StreamTest extends QueryTest with Timeouts {
pos += 1
}
} catch {
- case _: InterruptedException if streamDeathCause != null =>
- failTest("Stream Thread Died")
- case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
- failTest("Timed out waiting for stream")
+ case e: InterruptedException if streamDeathCause != null =>
+ failTest("Stream Thread Died", e)
+ case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+ failTest("Timed out waiting for stream", e)
} finally {
if (currentStream != null && currentStream.microBatchThread.isAlive) {
currentStream.stop()