diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index d1c38c7ca5d69..911b00e2b579f 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -47,6 +47,12 @@ ${project.version} provided + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/DefaultSource.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/DefaultSource.scala new file mode 100644 index 0000000000000..8be8dd7df5dc4 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/DefaultSource.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import com.amazonaws.AmazonClientException +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.StructType + +class DefaultSource extends StreamSourceProvider with DataSourceRegister { + + override def shortName(): String = "kinesis" + + override def createSource( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val caseInsensitiveOptions = new CaseInsensitiveMap(parameters) + + val streams = caseInsensitiveOptions.getOrElse("stream", { + throw new IllegalArgumentException( + "Option 'stream' must be specified. 
Examples: " + + """option("stream", "stream1"), option("stream", "stream1,stream2")""") + }).split(",", -1).toSet + + if (streams.isEmpty || streams.exists(_.isEmpty)) { + throw new IllegalArgumentException( + "Option 'stream' is invalid, as stream names cannot be empty.") + } + + val regionOption = caseInsensitiveOptions.get("region") + val endpointOption = caseInsensitiveOptions.get("endpoint") + val (region, endpoint) = (regionOption, endpointOption) match { + case (Some(_region), Some(_endpoint)) => + if (RegionUtils.getRegionByEndpoint(_endpoint).getName != _region) { + throw new IllegalArgumentException( + s"'region'(${_region}) doesn't match to 'endpoint'(${_endpoint})") + } + (_region, _endpoint) + case (Some(_region), None) => + (_region, RegionUtils.getRegion(_region).getServiceEndpoint("kinesis")) + case (None, Some(_endpoint)) => + (RegionUtils.getRegionByEndpoint(_endpoint).getName, _endpoint) + case (None, None) => + throw new IllegalArgumentException( + "Either option 'region' or option 'endpoint' must be specified. Examples: " + + """option("region", "us-west-2"), """ + + """option("endpoint", "https://kinesis.us-west-2.amazonaws.com")""") + } + + val initialPosInStream = + caseInsensitiveOptions.getOrElse("position", InitialPositionInStream.LATEST.name) match { + case pos if pos.toUpperCase == InitialPositionInStream.LATEST.name => + InitialPositionInStream.LATEST + case pos if pos.toUpperCase == InitialPositionInStream.TRIM_HORIZON.name => + InitialPositionInStream.TRIM_HORIZON + case pos => + throw new IllegalArgumentException(s"Unknown value of option 'position': $pos") + } + + val accessKeyOption = caseInsensitiveOptions.get("accessKey") + val secretKeyOption = caseInsensitiveOptions.get("secretKey") + val credentials = (accessKeyOption, secretKeyOption) match { + case (Some(accessKey), Some(secretKey)) => + new SerializableAWSCredentials(accessKey, secretKey) + case (Some(accessKey), None) => + throw new IllegalArgumentException( + s"'accessKey' is set but 'secretKey' is not found") + case (None, Some(secretKey)) => + throw new IllegalArgumentException( + s"'secretKey' is set but 'accessKey' is not found") + case (None, None) => + try { + SerializableAWSCredentials(new DefaultAWSCredentialsProviderChain().getCredentials()) + } catch { + case _: AmazonClientException => + throw new IllegalArgumentException( + "No credential found using default AWS provider chain. Specify credentials using " + + "options 'accessKey' and 'secretKey'. Examples: " + + """option("accessKey", "your-aws-access-key"), """ + + """option("secretKey", "your-aws-secret-key")""") + } + } + + new KinesisSource( + sqlContext, + region, + endpoint, + streams, + initialPosInStream, + credentials) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisDataFetcher.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisDataFetcher.scala new file mode 100644 index 0000000000000..90d2e2bd5471a --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisDataFetcher.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.storage.{BlockId, StorageLevel} + +/** + * However, this class runs in the driver so could be a bottleneck. + */ +private[kinesis] class KinesisDataFetcher( + credentials: SerializableAWSCredentials, + endpointUrl: String, + fromSeqNums: Seq[(Shard, Option[String], BlockId)], + initialPositionInStream: InitialPositionInStream, + readTimeoutMs: Long = 2000L) extends Serializable with Logging { + + /** + * Use lazy because the client needs to be created in executors + */ + @transient private lazy val client = new AmazonKinesisClient(credentials) + + /** + * Launch a Spark job to fetch latest data from the specified `shard`s. This method will try to + * fetch arriving data in `readTimeoutMs` milliseconds so as to get the latest sequence numbers. + * New data will be pushed to the block manager to avoid fetching them again. + * + * This is a workaround since Kinesis doesn't provider an API to fetch the latest sequence number. + */ + def fetch(sc: SparkContext): Array[(BlockId, SequenceNumberRange)] = { + sc.makeRDD(fromSeqNums, fromSeqNums.size).map { + case (shard, fromSeqNum, blockId) => fetchPartition(shard, fromSeqNum, blockId) + }.collect().flatten + } + + /** + * Fetch latest data from the specified `shard` since `fromSeqNum`. This method will try to fetch + * arriving data in `readTimeoutMs` milliseconds so as to get the latest sequence number. New data + * will be pushed to the block manager to avoid fetching them again. + * + * This is a workaround since Kinesis doesn't provider an API to fetch the latest sequence number. 
+ */ + private def fetchPartition( + shard: Shard, + fromSeqNum: Option[String], + blockId: BlockId): Option[(BlockId, SequenceNumberRange)] = { + client.setEndpoint(endpointUrl) + + val endTime = System.currentTimeMillis + readTimeoutMs + def timeLeft = math.max(endTime - System.currentTimeMillis, 0) + + val buffer = new ArrayBuffer[Array[Byte]] + var firstSeqNumber: String = null + var lastSeqNumber: String = fromSeqNum.orNull + var lastIterator: String = null + try { + logDebug(s"Trying to fetch data from $shard, from seq num $lastSeqNumber") + + while (timeLeft > 0) { + val (records, nextIterator) = retryOrTimeout("getting shard iterator", timeLeft) { + if (lastIterator == null) { + lastIterator = if (lastSeqNumber != null) { + getKinesisIterator(shard, ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } else { + if (initialPositionInStream == InitialPositionInStream.LATEST) { + getKinesisIterator(shard, ShardIteratorType.LATEST, lastSeqNumber) + } else { + getKinesisIterator(shard, ShardIteratorType.TRIM_HORIZON, lastSeqNumber) + } + } + } + getRecordsAndNextKinesisIterator(lastIterator) + } + + records.foreach { record => + buffer += JavaUtils.bufferToArray(record.getData()) + if (firstSeqNumber == null) { + firstSeqNumber = record.getSequenceNumber + } + lastSeqNumber = record.getSequenceNumber + } + + lastIterator = nextIterator + } + + if (buffer.nonEmpty) { + SparkEnv.get.blockManager.putIterator(blockId, buffer.iterator, StorageLevel.MEMORY_ONLY) + val range = SequenceNumberRange( + shard.streamName, shard.shardId, firstSeqNumber, lastSeqNumber) + logDebug(s"Received block $blockId having range $range from shard $shard") + Some(blockId -> range) + } else { + None + } + } catch { + case NonFatal(e) => + logWarning(s"Error fetching data from shard $shard", e) + None + } + } + + /** + * Get the records starting from using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for next consumption. + */ + private def getRecordsAndNextKinesisIterator(shardIterator: String): (Seq[Record], String) = { + val getRecordsRequest = new GetRecordsRequest().withShardIterator(shardIterator) + getRecordsRequest.setRequestCredentials(credentials) + val getRecordsResult = client.getRecords(getRecordsRequest) + // De-aggregate records, if KPL was used in producing the records. The KCL automatically + // handles de-aggregation during regular operation. This code path is used during recovery + val records = UserRecord.deaggregate(getRecordsResult.getRecords) + logTrace( + s"Got ${records.size()} records and next iterator ${getRecordsResult.getNextShardIterator}") + (records.asScala, getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number. 
+ */ + private def getKinesisIterator( + shard: Shard, + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest() + .withStreamName(shard.streamName) + .withShardId(shard.shardId) + .withShardIteratorType(iteratorType.toString) + .withStartingSequenceNumber(sequenceNumber) + getShardIteratorRequest.setRequestCredentials(credentials) + val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest) + logTrace(s"Shard $shard: Got iterator ${getShardIteratorResult.getShardIterator}") + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String, retryTimeoutMs: Long)(body: => T): T = { + import KinesisSequenceRangeIterator._ + val startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { + // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index ca13a21087cc8..d295c4d7a63eb 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -41,7 +41,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) override def getAWSSecretKey: String = secretKey } -object SerializableAWSCredentials { +private[kinesis] object SerializableAWSCredentials { def apply(credentials: AWSCredentials): SerializableAWSCredentials = { new SerializableAWSCredentials(credentials.getAWSAccessKeyId, credentials.getAWSSecretKey) } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala index 07876f5f45aa1..77c6a70235686 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisSource.scala @@ -16,23 +16,22 @@ */ package org.apache.spark.streaming.kinesis +import java.util.concurrent.atomic.AtomicLong + import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import 
scala.util.control.NonFatal -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.AbortedException import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord -import com.amazonaws.services.kinesis.model._ import org.apache.spark._ +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.{Batch, Offset, Source, StreamingRelation} +import org.apache.spark.sql.execution.streaming.{Batch, Offset, Source} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} -import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} +import org.apache.spark.storage.{BlockId, StreamBlockId} private[kinesis] case class Shard(streamName: String, shardId: String) @@ -40,14 +39,13 @@ private[kinesis] case class KinesisSourceOffset(seqNums: Map[Shard, String]) extends Offset { override def compareTo(otherOffset: Offset): Int = otherOffset match { - case that: KinesisSourceOffset => val allShards = this.seqNums.keySet ++ that.seqNums.keySet val comparisons = allShards.map { shard => (this.seqNums.get(shard).map(BigInt(_)), that.seqNums.get(shard).map(BigInt(_))) match { case (Some(thisNum), Some(thatNum)) => thisNum.compare(thatNum) case (None, _) => -1 // new shard started by resharding - case (_, None) => -1 // old shard got eliminated by resharding + case (_, None) => 1 // old shard got eliminated by resharding } } @@ -57,7 +55,7 @@ private[kinesis] case class KinesisSourceOffset(seqNums: Map[Shard, String]) case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) case _ => // there are both 1s and -1s throw new IllegalArgumentException( - s"Invalid comparison between non-linear histories: $this <=> $that") + s"Invalid comparison between KinesisSource offsets: \n\t this: $this\n\t that: $that") } case _ => @@ -65,50 +63,58 @@ private[kinesis] case class KinesisSourceOffset(seqNums: Map[Shard, String]) } } - -private[kinesis] object KinesisSourceOffset { - def fromOffset(offset: Offset): KinesisSourceOffset = { - offset match { - case o: KinesisSourceOffset => o - case _ => - throw new IllegalArgumentException( - s"Invalid conversion from offset of ${offset.getClass} to $getClass") - } - } -} - - -private[kinesis] case class KinesisSource( +private[kinesis] class KinesisSource( sqlContext: SQLContext, regionName: String, endpointUrl: String, streamNames: Set[String], - initialPosInStream: InitialPositionInStream = InitialPositionInStream.LATEST, - awsCredentialsOption: Option[SerializableAWSCredentials] = None) extends Source { + initialPosInStream: InitialPositionInStream, + awsCredentials: SerializableAWSCredentials) extends Source { + + // How long we should wait before calling `fetchShards()`. Because DescribeStream has a limit of + // 10 transactions per second per account, we should not request too frequently. + private val FETCH_SHARDS_INTERVAL_MS = 200L + + // The last time `fetchShards()` is called. 
+ private var lastFetchShardsTimeMS = 0L implicit private val encoder = ExpressionEncoder[Array[Byte]] - private val logicalPlan = StreamingRelation(this) - @transient val credentials = SerializableAWSCredentials( - awsCredentialsOption.getOrElse(new DefaultAWSCredentialsProviderChain().getCredentials()) - ) + private val client = new AmazonKinesisClient(awsCredentials) + client.setEndpoint(endpointUrl) - @transient private val client = new AmazonKinesisClient(credentials) - client.setEndpoint(endpointUrl, "kinesis", regionName) + private var cachedBlocks = new mutable.HashSet[BlockId] - override def schema: StructType = encoder.schema + override val schema: StructType = encoder.schema override def getNextBatch(start: Option[Offset]): Option[Batch] = { - val startOffset = start.map(KinesisSourceOffset.fromOffset) + val now = System.currentTimeMillis() + if (now - lastFetchShardsTimeMS < FETCH_SHARDS_INTERVAL_MS) { + // Because DescribeStream has a limit of 10 transactions per second per account, we should not + // request too frequently. + return None + } + lastFetchShardsTimeMS = now + val startOffset = start.map(_.asInstanceOf[KinesisSourceOffset]) val shards = fetchShards() // Get the starting seq number of each shard if available val fromSeqNums = shards.map { shard => shard -> startOffset.flatMap(_.seqNums.get(shard)) } - /** Prefetch Kinesis data from the starting seq nums */ - val prefetchedData = new KinesisDataFetcher( - sqlContext, credentials, endpointUrl, regionName, fromSeqNums, initialPosInStream).fetch() + // Assign a unique block id for each shard + val fromSeqNumsWithBlockId = fromSeqNums.map { case (shard, seqNum) => + val uniqueBlockId = KinesisSource.nextBlockId + (shard, seqNum, uniqueBlockId) + } + + // Prefetch Kinesis data from the starting seq nums + val prefetchedData = + new KinesisDataFetcher( + awsCredentials, + endpointUrl, + fromSeqNumsWithBlockId, + initialPosInStream).fetch(sqlContext.sparkContext) if (prefetchedData.nonEmpty) { val prefetechedRanges = prefetchedData.map(_._2) @@ -125,189 +131,52 @@ private[kinesis] case class KinesisSource( val rdd = new KinesisBackedBlockRDD[Array[Byte]](sqlContext.sparkContext, regionName, endpointUrl, prefetchedBlockIds, prefetechedRanges.map(SequenceNumberRanges.apply)) + + dropOldBlocks() + cachedBlocks ++= prefetchedBlockIds + Some(new Batch(new KinesisSourceOffset(endOffset), sqlContext.createDataset(rdd).toDF)) } else { None } } - def toDS(): Dataset[Array[Byte]] = { - toDF.as[Array[Byte]] - } - - def toDF(): DataFrame = { - new DataFrame(sqlContext, logicalPlan) - } - - private def fetchShards(): Seq[Shard] = { - streamNames.toSeq.flatMap { streamName => - val desc = client.describeStream(streamName) - desc.getStreamDescription.getShards.asScala.map { s => - Shard(streamName, s.getShardId) + private def dropOldBlocks(): Unit = { + val droppedBlocks = ArrayBuffer[BlockId]() + try { + for (blockId <- cachedBlocks) { + SparkEnv.get.blockManager.removeBlock(blockId) + droppedBlocks += blockId } + } finally { + cachedBlocks --= droppedBlocks } } -} - -private[kinesis] class KinesisDataFetcher( - sqlContext: SQLContext, - credentials: SerializableAWSCredentials, - endpointUrl: String, - regionName: String, - fromSeqNums: Seq[(Shard, Option[String])], - initialPositionInStream: InitialPositionInStream, - readTimeoutMs: Int = 2000 - ) extends Serializable with Logging { - - @transient private lazy val client = new AmazonKinesisClient(credentials) - - def fetch(): Array[(BlockId, SequenceNumberRange)] = { - 
sqlContext.sparkContext.makeRDD(fromSeqNums, fromSeqNums.size).map { - case (shard, fromSeqNum) => fetchPartition(shard, fromSeqNum) - }.collect().flatten - } - - private def fetchPartition( - shard: Shard, - fromSeqNum: Option[String]): Option[(BlockId, SequenceNumberRange)] = { - client.setEndpoint(endpointUrl, "kinesis", regionName) - - val endTime = System.currentTimeMillis + readTimeoutMs - def timeLeft = math.max(endTime - System.currentTimeMillis, 0) - - val buffer = new ArrayBuffer[Array[Byte]] - var firstSeqNumber: String = null - var lastSeqNumber: String = fromSeqNum.orNull - var lastIterator: String = null + private def fetchShards(): Seq[Shard] = { try { - logDebug(s"Trying to fetch data from $shard, from seq num $lastSeqNumber") - - while (timeLeft > 0) { - val (records, nextIterator) = retryOrTimeout("getting shard iterator", timeLeft) { - if (lastIterator == null) { - lastIterator = if (lastSeqNumber != null) { - getKinesisIterator(shard, ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) - } else { - if (initialPositionInStream == InitialPositionInStream.LATEST) { - getKinesisIterator(shard, ShardIteratorType.LATEST, lastSeqNumber) - } else { - getKinesisIterator(shard, ShardIteratorType.TRIM_HORIZON, lastSeqNumber) - } - } - } - getRecordsAndNextKinesisIterator(lastIterator) + streamNames.toSeq.flatMap { streamName => + val desc = client.describeStream(streamName) + desc.getStreamDescription.getShards.asScala.map { s => + Shard(streamName, s.getShardId) } - - records.foreach { record => - val byteBuffer = record.getData() - val byteArray = new Array[Byte](byteBuffer.remaining()) - byteBuffer.get(byteArray) - buffer += byteArray - if (firstSeqNumber == null) { - firstSeqNumber = record.getSequenceNumber - } - lastSeqNumber = record.getSequenceNumber - } - - lastIterator = nextIterator - } - - if (buffer.nonEmpty) { - val blockId = StreamBlockId(0, Random.nextLong) - SparkEnv.get.blockManager.putIterator(blockId, buffer.iterator, StorageLevel.MEMORY_ONLY) - val range = SequenceNumberRange( - shard.streamName, shard.shardId, firstSeqNumber, lastSeqNumber) - logDebug(s"Received block $blockId having range $range from shard $shard") - Some(blockId -> range) - } else { - None } } catch { - case NonFatal(e) => - logWarning(s"Error fetching data from shard $shard", e) - None + case e: AbortedException => + // AbortedException will be thrown if the current thread is interrupted + // So let's convert it back to InterruptedException + val e1 = new InterruptedException("thread is interrupted") + e1.addSuppressed(e) + throw e1 } } + override def toString: String = s"KinesisSource[streamNames=${streamNames.mkString(",")}]" +} - /** - * Get the records starting from using a Kinesis shard iterator (which is a progress handle - * to get records from Kinesis), and get the next shard iterator for next consumption. - */ - private def getRecordsAndNextKinesisIterator( - shardIterator: String): (Seq[Record], String) = { - val getRecordsRequest = new GetRecordsRequest - getRecordsRequest.setRequestCredentials(credentials) - getRecordsRequest.setShardIterator(shardIterator) - val getRecordsResult = client.getRecords(getRecordsRequest) - // De-aggregate records, if KPL was used in producing the records. The KCL automatically - // handles de-aggregation during regular operation. 
This code path is used during recovery - val records = UserRecord.deaggregate(getRecordsResult.getRecords) - logTrace( - s"Got ${records.size()} records and next iterator ${getRecordsResult.getNextShardIterator}") - (records.asScala, getRecordsResult.getNextShardIterator) - } +private[kinesis] object KinesisSource { - /** - * Get the Kinesis shard iterator for getting records starting from or after the given - * sequence number. - */ - private def getKinesisIterator( - shard: Shard, - iteratorType: ShardIteratorType, - sequenceNumber: String): String = { - val getShardIteratorRequest = new GetShardIteratorRequest - getShardIteratorRequest.setRequestCredentials(credentials) - getShardIteratorRequest.setStreamName(shard.streamName) - getShardIteratorRequest.setShardId(shard.shardId) - getShardIteratorRequest.setShardIteratorType(iteratorType.toString) - getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) - val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest) - logTrace(s"Shard $shard: Got iterator ${getShardIteratorResult.getShardIterator}") - getShardIteratorResult.getShardIterator - } + private val nextId = new AtomicLong(0) - /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ - private def retryOrTimeout[T](message: String, retryTimeoutMs: Long)(body: => T): T = { - import KinesisSequenceRangeIterator._ - var startTimeMs = System.currentTimeMillis() - var retryCount = 0 - var waitTimeMs = MIN_RETRY_WAIT_TIME_MS - var result: Option[T] = None - var lastError: Throwable = null - - def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs - def isMaxRetryDone = retryCount >= MAX_RETRIES - - while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { - if (retryCount > 0) { - // wait only if this is a retry - Thread.sleep(waitTimeMs) - waitTimeMs *= 2 // if you have waited, then double wait time for next round - } - try { - result = Some(body) - } catch { - case NonFatal(t) => - lastError = t - t match { - case ptee: ProvisionedThroughputExceededException => - logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) - case e: Throwable => - throw new SparkException(s"Error while $message", e) - } - } - retryCount += 1 - } - result.getOrElse { - if (isTimedOut) { - throw new SparkException( - s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) - } else { - throw new SparkException( - s"Gave up after $retryCount retries while $message, last exception: ", lastError) - } - } - } + def nextBlockId: StreamBlockId = StreamBlockId(Int.MaxValue, nextId.getAndIncrement) } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index e3814be676e85..8d32c5c5aee64 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -42,7 +42,7 @@ import org.apache.spark.Logging private[kinesis] class KinesisTestUtils extends Logging { val endpointUrl = KinesisTestUtils.endpointUrl - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisTestUtils.regionName val streamShardCount = 2 private val createStreamTimeoutSeconds = 300 @@ -54,6 +54,8 @@ private[kinesis] class KinesisTestUtils extends Logging { @volatile private var _streamName: String = _ + 
private val shardIdToLatestSeqNum = mutable.HashMap[String, String]() + protected lazy val kinesisClient = { val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) client.setEndpoint(endpointUrl) @@ -105,9 +107,14 @@ private[kinesis] class KinesisTestUtils extends Logging { val producer = getProducer(aggregate) val shardIdToSeqNumbers = producer.sendData(streamName, testData) logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") + shardIdToSeqNumbers.foreach { case (shardId, seq) => + shardIdToLatestSeqNum(shardId) = seq.last._2 + } shardIdToSeqNumbers.toMap } + def getLatestSeqNumsOfShards(): Map[String, String] = shardIdToLatestSeqNum.toMap + /** * Expose a Python friendly API. */ @@ -118,7 +125,6 @@ private[kinesis] class KinesisTestUtils extends Logging { def deleteStream(): Unit = { try { if (streamCreated) { - logInfo(s"Deleting stream $streamName") kinesisClient.deleteStream(streamName) } } catch { @@ -210,6 +216,8 @@ private[kinesis] object KinesisTestUtils { url } + lazy val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + def isAWSCredentialsPresent: Boolean = { Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess } @@ -256,10 +264,6 @@ private[kinesis] class SimpleDataGenerator( val seqNumber = putRecordResult.getSequenceNumber() val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, new ArrayBuffer[(Int, String)]()) - // scalastyle:off println - println(s"$data with key $str in shard ${putRecordResult.getShardId} " + - s"and seq ${putRecordResult.getSequenceNumber}") - // scalastyle:on println sentSeqNumbers += ((num, seqNumber)) seqNumForOrdering = seqNumber } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/package.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/package.scala new file mode 100644 index 0000000000000..efb79bad1ee9a --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/package.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.apache.spark.sql.DataFrameReader + +package object kinesis { + + /** + * Add the `kinesis` method to DataFrameReader that allows people to read from Kinesis using + * the DataFileReader. 
+ */ + implicit class KinesisDataFrameReader(reader: DataFrameReader) { + def kinesis(): DataFrameReader = reader.format("org.apache.spark.streaming.kinesis") + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala index f17131dc93c6c..29336051d089d 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisSourceSuite.scala @@ -17,63 +17,209 @@ package org.apache.spark.streaming.kinesis - import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.StreamTest -import org.apache.spark.sql.execution.streaming.{Offset, Source} +import org.apache.spark.SparkEnv +import org.apache.spark.sql.{AnalysisException, StreamTest} +import org.apache.spark.sql.execution.streaming.{Offset, Source, StreamingRelation} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.storage.StreamBlockId +abstract class KinesisSourceTest extends StreamTest with SharedSQLContext { -class KinesisSourceSuite extends StreamTest with SharedSQLContext with KinesisFunSuite { - - import testImplicits._ - - private var testUtils: KPLBasedKinesisTestUtils = _ - private var streamName: String = _ + case class AddKinesisData( + testUtils: KPLBasedKinesisTestUtils, + kinesisSource: KinesisSource, + data: Seq[Int]) extends AddData { - override val streamingTimout = 60.seconds - - case class AddKinesisData(kinesisSource: KinesisSource, data: Int*) extends AddData { override def addData(): Offset = { - val shardIdToSeqNums = testUtils.pushData(data, false).map { case (shardId, info) => - (Shard(streamName, shardId), info.last._2) + testUtils.pushData(data, false) + val shardIdToSeqNums = testUtils.getLatestSeqNumsOfShards().map { case (shardId, seqNum) => + (Shard(testUtils.streamName, shardId), seqNum) } - assert(shardIdToSeqNums.size === testUtils.streamShardCount, - s"Data must be send to all ${testUtils.streamShardCount} shards of stream $streamName") KinesisSourceOffset(shardIdToSeqNums) } override def source: Source = kinesisSource } - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KPLBasedKinesisTestUtils - testUtils.createStream() - streamName = testUtils.streamName + def createKinesisSourceForTest(testUtils: KPLBasedKinesisTestUtils): KinesisSource = { + new KinesisSource( + sqlContext, + testUtils.regionName, + testUtils.endpointUrl, + Set(testUtils.streamName), + InitialPositionInStream.TRIM_HORIZON, + SerializableAWSCredentials(KinesisTestUtils.getAWSCredentials())) } +} + +class KinesisSourceSuite extends KinesisSourceTest with KinesisFunSuite { + + import testImplicits._ + + testIfEnabled("basic receiving and failover") { + var streamBlocksInLastBatch: Seq[StreamBlockId] = Seq.empty - override def afterAll(): Unit = { - if (testUtils != null) { + def assertStreamBlocks: Boolean = { + if (sqlContext.sparkContext.isLocal) { + // Only test this one in local mode so that we can assume there is only one BlockManager + val streamBlocks = + SparkEnv.get.blockManager.getMatchingBlockIds(_.isInstanceOf[StreamBlockId]) + val cleaned = streamBlocks.intersect(streamBlocksInLastBatch).isEmpty + streamBlocksInLastBatch = streamBlocks.map(_.asInstanceOf[StreamBlockId]) + cleaned + } else { + true + } + } + + val 
testUtils = new KPLBasedKinesisTestUtils + testUtils.createStream() + try { + val kinesisSource = createKinesisSourceForTest(testUtils) + val mapped = + kinesisSource.toDS[Array[Byte]]().map((bytes: Array[Byte]) => new String(bytes).toInt + 1) + val testData = 1 to 10 + testStream(mapped)( + AddKinesisData(testUtils, kinesisSource, testData), + CheckAnswer((1 to 10).map(_ + 1): _*), + Assert(assertStreamBlocks, "Old stream blocks should be cleaned"), + StopStream, + AddKinesisData(testUtils, kinesisSource, 11 to 20), + StartStream, + CheckAnswer((1 to 20).map(_ + 1): _*), + Assert(assertStreamBlocks, "Old stream blocks should be cleaned"), + AddKinesisData(testUtils, kinesisSource, 21 to 30), + CheckAnswer((1 to 30).map(_ + 1): _*), + Assert(assertStreamBlocks, "Old stream blocks should be cleaned") + ) + } finally { testUtils.deleteStream() } - super.afterAll() } - test("basic receiving") { - val kinesisSource = KinesisSource( - sqlContext, - testUtils.regionName, - testUtils.endpointUrl, - Set(streamName), - InitialPositionInStream.TRIM_HORIZON) - val mapped = kinesisSource.toDS().map[Int]((bytes: Array[Byte]) => new String(bytes).toInt + 1) - val testData = 1 to 10 // This ensures that data is sent to multiple shards for 2 shard streams - testStream(mapped)( - AddKinesisData(kinesisSource, testData: _*), - CheckAnswer(testData.map { _ + 1 }: _*) - ) + test("DataFrameReader") { + val df = sqlContext.read + .option("endpoint", KinesisTestUtils.endpointUrl) + .option("stream", "stream1") + .option("accessKey", "accessKey") + .option("secretKey", "secretKey") + .option("position", InitialPositionInStream.TRIM_HORIZON.name()) + .kinesis().stream() + + val sources = df.queryExecution.analyzed.collect { + case StreamingRelation(s: KinesisSource, _) => s + } + assert(sources.size === 1) + + // stream + assertExceptionAndMessage[IllegalArgumentException]( + "Option 'stream' must be specified.") { + sqlContext.read.kinesis().stream() + } + assertExceptionAndMessage[IllegalArgumentException]( + "Option 'stream' is invalid, as stream names cannot be empty.") { + sqlContext.read.option("stream", "").kinesis().stream() + } + assertExceptionAndMessage[IllegalArgumentException]( + "Option 'stream' is invalid, as stream names cannot be empty.") { + sqlContext.read.option("stream", "a,").kinesis().stream() + } + assertExceptionAndMessage[IllegalArgumentException]( + "Option 'stream' is invalid, as stream names cannot be empty.") { + sqlContext.read.option("stream", ",a").kinesis().stream() + } + + // region and endpoint + // Setting either endpoint or region is fine + sqlContext.read + .option("stream", "stream1") + .option("endpoint", KinesisTestUtils.endpointUrl) + .option("accessKey", "accessKey") + .option("secretKey", "secretKey") + .kinesis().stream() + sqlContext.read + .option("stream", "stream1") + .option("region", KinesisTestUtils.regionName) + .option("accessKey", "accessKey") + .option("secretKey", "secretKey") + .kinesis().stream() + + assertExceptionAndMessage[IllegalArgumentException]( + "Either option 'region' or option 'endpoint' must be specified.") { + sqlContext.read.option("stream", "stream1").kinesis().stream() + } + assertExceptionAndMessage[IllegalArgumentException]( + s"'region'(invalid-region) doesn't match to 'endpoint'(${KinesisTestUtils.endpointUrl})") { + sqlContext.read + .option("stream", "stream1") + .option("region", "invalid-region") + .option("endpoint", KinesisTestUtils.endpointUrl) + .kinesis().stream() + } + + // position + 
assertExceptionAndMessage[IllegalArgumentException]( + "Unknown value of option 'position': invalid") { + sqlContext.read + .option("stream", "stream1") + .option("endpoint", KinesisTestUtils.endpointUrl) + .option("position", "invalid") + .kinesis().stream() + } + + // accessKey and secretKey + assertExceptionAndMessage[IllegalArgumentException]( + "'accessKey' is set but 'secretKey' is not found") { + sqlContext.read + .option("stream", "stream1") + .option("endpoint", KinesisTestUtils.endpointUrl) + .option("position", InitialPositionInStream.TRIM_HORIZON.name()) + .option("accessKey", "test") + .kinesis().stream() + } + assertExceptionAndMessage[IllegalArgumentException]( + "'secretKey' is set but 'accessKey' is not found") { + sqlContext.read + .option("stream", "stream1") + .option("endpoint", KinesisTestUtils.endpointUrl) + .option("position", InitialPositionInStream.TRIM_HORIZON.name()) + .option("secretKey", "test") + .kinesis().stream() + } + } + + test("call kinesis when not using stream") { + intercept[AnalysisException] { + sqlContext.read.kinesis().load() + } + } + + private def assertExceptionAndMessage[T <: Exception : Manifest]( + expectedMessage: String)(body: => Unit): Unit = { + val e = intercept[T] { + body + } + assert(e.getMessage.contains(expectedMessage)) + } +} + +class KinesisSourceStressTestSuite extends KinesisSourceTest with KinesisFunSuite { + + import testImplicits._ + + testIfEnabled("kinesis source stress test") { + val testUtils = new KPLBasedKinesisTestUtils + testUtils.createStream() + try { + val kinesisSource = createKinesisSourceForTest(testUtils) + val ds = kinesisSource.toDS[String]().map(_.toInt + 1) + runStressTest(ds, data => { + AddKinesisData(testUtils, kinesisSource, data) + }) + } finally { + testUtils.deleteStream() + } } -} \ No newline at end of file +} diff --git a/pom.xml b/pom.xml index 244355b080221..88b9ab2a37ce3 100644 --- a/pom.xml +++ b/pom.xml @@ -154,9 +154,9 @@ 1.7.7 hadoop2 0.7.1 - 1.6.1 + 1.4.0 - 0.10.2 + 0.10.1 4.3.2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index bb5135826e2f3..fbf6f6ba49600 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -380,10 +380,10 @@ trait StreamTest extends QueryTest with Timeouts { pos += 1 } } catch { - case _: InterruptedException if streamDeathCause != null => - failTest("Stream Thread Died") - case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out waiting for stream") + case e: InterruptedException if streamDeathCause != null => + failTest("Stream Thread Died", e) + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out waiting for stream", e) } finally { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop()
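A usage sketch for the reader API this patch adds (DefaultSource is registered under the short name "kinesis", and package.scala adds the implicit kinesis() method to DataFrameReader). This is only an illustration, assuming an existing SQLContext; the helper name, stream name, and region are placeholders, and the recognized options are the ones DefaultSource parses: 'stream', 'region'/'endpoint', 'position', 'accessKey'/'secretKey'.

    import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream

    import org.apache.spark.sql.{DataFrame, SQLContext}
    import org.apache.spark.streaming.kinesis._  // brings the implicit kinesis() method into scope

    def kinesisStream(sqlContext: SQLContext): DataFrame = {
      sqlContext.read
        .option("stream", "my-stream")    // required; "a,b" reads from several streams
        .option("region", "us-west-2")    // or .option("endpoint", "https://kinesis.us-west-2.amazonaws.com")
        .option("position", InitialPositionInStream.TRIM_HORIZON.name())  // defaults to LATEST
        // .option("accessKey", "...").option("secretKey", "...")  // otherwise DefaultAWSCredentialsProviderChain is tried
        .kinesis()
        .stream()  // streaming DataFrame of raw Array[Byte] Kinesis records
    }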
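A second, test-style sketch of the offset semantics changed above: KinesisSourceOffset.compareTo treats a shard that is absent from `this` as "this offset is behind" (-1, a shard newly created by resharding) and a shard that is absent from the other offset as "this offset is ahead" (+1, a shard eliminated by resharding), and rejects offsets that disagree in both directions. Shard and KinesisSourceOffset are private[kinesis], so code like this would only compile in a suite under org.apache.spark.streaming.kinesis; the stream name, shard ids, and the ScalaTest intercept are assumptions for illustration.

    val shardA = Shard("events", "shardId-000000000000")
    val shardB = Shard("events", "shardId-000000000001")

    val before = KinesisSourceOffset(Map(shardA -> "100"))
    val after  = KinesisSourceOffset(Map(shardA -> "200", shardB -> "5"))  // shardB appeared after resharding

    assert(before.compareTo(after) == -1)  // `before` lags on shardA and does not know shardB yet
    assert(after.compareTo(before) == 1)   // the reverse view: shardB is missing from the older offset

    // Histories that move in opposite directions are non-linear and rejected:
    val diverged = KinesisSourceOffset(Map(shardA -> "50", shardB -> "500"))
    intercept[IllegalArgumentException] { after.compareTo(diverged) }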