6 changes: 6 additions & 0 deletions extras/kinesis-asl-assembly/pom.xml
@@ -47,6 +47,12 @@
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
    <!--
      Demote already included in the Spark assembly.
    -->
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.streaming.kinesis

import com.amazonaws.AmazonClientException
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource extends StreamSourceProvider with DataSourceRegister {

  override def shortName(): String = "kinesis"

  override def createSource(
      sqlContext: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    val caseInsensitiveOptions = new CaseInsensitiveMap(parameters)

    val streams = caseInsensitiveOptions.getOrElse("stream", {
      throw new IllegalArgumentException(
        "Option 'stream' must be specified. Examples: " +
          """option("stream", "stream1"), option("stream", "stream1,stream2")""")
    }).split(",", -1).toSet

    if (streams.isEmpty || streams.exists(_.isEmpty)) {
      throw new IllegalArgumentException(
        "Option 'stream' is invalid, as stream names cannot be empty.")
    }

    val regionOption = caseInsensitiveOptions.get("region")
    val endpointOption = caseInsensitiveOptions.get("endpoint")
    val (region, endpoint) = (regionOption, endpointOption) match {
      case (Some(_region), Some(_endpoint)) =>
        if (RegionUtils.getRegionByEndpoint(_endpoint).getName != _region) {
          throw new IllegalArgumentException(
            s"'region'(${_region}) does not match 'endpoint'(${_endpoint})")
        }
        (_region, _endpoint)
      case (Some(_region), None) =>
        (_region, RegionUtils.getRegion(_region).getServiceEndpoint("kinesis"))
      case (None, Some(_endpoint)) =>
        (RegionUtils.getRegionByEndpoint(_endpoint).getName, _endpoint)
      case (None, None) =>
        throw new IllegalArgumentException(
          "Either option 'region' or option 'endpoint' must be specified. Examples: " +
            """option("region", "us-west-2"), """ +
            """option("endpoint", "https://kinesis.us-west-2.amazonaws.com")""")
    }

    val initialPosInStream =
      caseInsensitiveOptions.getOrElse("position", InitialPositionInStream.LATEST.name) match {
        case pos if pos.toUpperCase == InitialPositionInStream.LATEST.name =>
          InitialPositionInStream.LATEST
        case pos if pos.toUpperCase == InitialPositionInStream.TRIM_HORIZON.name =>
          InitialPositionInStream.TRIM_HORIZON
        case pos =>
          throw new IllegalArgumentException(s"Unknown value of option 'position': $pos")
      }

    val accessKeyOption = caseInsensitiveOptions.get("accessKey")
    val secretKeyOption = caseInsensitiveOptions.get("secretKey")
    val credentials = (accessKeyOption, secretKeyOption) match {
      case (Some(accessKey), Some(secretKey)) =>
        new SerializableAWSCredentials(accessKey, secretKey)
      case (Some(accessKey), None) =>
        throw new IllegalArgumentException(
          s"'accessKey' is set but 'secretKey' is not found")
      case (None, Some(secretKey)) =>
        throw new IllegalArgumentException(
          s"'secretKey' is set but 'accessKey' is not found")
      case (None, None) =>
        try {
          SerializableAWSCredentials(new DefaultAWSCredentialsProviderChain().getCredentials())
        } catch {
          case _: AmazonClientException =>
            throw new IllegalArgumentException(
              "No credential found using default AWS provider chain. Specify credentials using " +
                "options 'accessKey' and 'secretKey'. Examples: " +
                """option("accessKey", "your-aws-access-key"), """ +
                """option("secretKey", "your-aws-secret-key")""")
        }
    }
Owner commented:
Check whether access keys are available (either specified, or loaded from a provider) right here. Otherwise, it fails later, and gives a bad error message like

com.amazonaws.AmazonClientException: Unable to load AWS credentials from any provider in the chain
    at com.amazonaws.auth.AWSCredentialsProviderChain.getCredentials(AWSCredentialsProviderChain.java:117)
    at org.apache.spark.streaming.kinesis.KinesisSource$$anonfun$5.apply(KinesisSource.scala:85)

So it is better to resolve the credentials right here and pass them on to the source, rather than passing them through an option.


    new KinesisSource(
      sqlContext,
      region,
      endpoint,
      streams,
      initialPosInStream,
      credentials)
  }
}
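
For reference, a minimal usage sketch of the options parsed above. Only the option names (stream, region, endpoint, position, accessKey, secretKey) and the "kinesis" short name come from DefaultSource; the sqlContext.read...stream() entry point is an assumption about the streaming DataFrame API on this branch and may differ.

// Hypothetical usage sketch; only the option names are taken from DefaultSource above.
val kinesisDF = sqlContext.read
  .format("kinesis")                          // resolved to this DefaultSource via shortName()
  .option("stream", "stream1,stream2")        // required, comma-separated stream names
  .option("region", "us-west-2")              // or option("endpoint", "https://kinesis.us-west-2.amazonaws.com")
  .option("position", "TRIM_HORIZON")         // defaults to LATEST
  .option("accessKey", "your-aws-access-key") // omit both keys to fall back to DefaultAWSCredentialsProviderChain
  .option("secretKey", "your-aws-secret-key")
  .stream()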
@@ -0,0 +1,203 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.kinesis

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import com.amazonaws.services.kinesis.AmazonKinesisClient
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord
import com.amazonaws.services.kinesis.model._

import org.apache.spark._
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.{BlockId, StorageLevel}

/**
 * A helper class that launches Spark jobs to fetch data from Kinesis shards. Note that this
 * class runs in the driver, so it could become a bottleneck.
 */
private[kinesis] class KinesisDataFetcher(
    credentials: SerializableAWSCredentials,
    endpointUrl: String,
    fromSeqNums: Seq[(Shard, Option[String], BlockId)],
    initialPositionInStream: InitialPositionInStream,
    readTimeoutMs: Long = 2000L) extends Serializable with Logging {

  /**
   * The client is lazy because it needs to be created in the executors.
   */
  @transient private lazy val client = new AmazonKinesisClient(credentials)

  /**
   * Launch a Spark job to fetch the latest data from the specified `shard`s. This method will try
   * to fetch arriving data for `readTimeoutMs` milliseconds so as to get the latest sequence
   * numbers. New data will be pushed to the block manager to avoid fetching it again.
   *
   * This is a workaround since Kinesis doesn't provide an API to fetch the latest sequence number.
   */
  def fetch(sc: SparkContext): Array[(BlockId, SequenceNumberRange)] = {
Owner commented: Add scala docs.
    sc.makeRDD(fromSeqNums, fromSeqNums.size).map {
      case (shard, fromSeqNum, blockId) => fetchPartition(shard, fromSeqNum, blockId)
    }.collect().flatten
  }

  /**
   * Fetch the latest data from the specified `shard` since `fromSeqNum`. This method will try to
   * fetch arriving data for `readTimeoutMs` milliseconds so as to get the latest sequence number.
   * New data will be pushed to the block manager to avoid fetching it again.
   *
   * This is a workaround since Kinesis doesn't provide an API to fetch the latest sequence number.
   */
  private def fetchPartition(
Owner commented: Add scala docs.
      shard: Shard,
      fromSeqNum: Option[String],
      blockId: BlockId): Option[(BlockId, SequenceNumberRange)] = {
    client.setEndpoint(endpointUrl)

    val endTime = System.currentTimeMillis + readTimeoutMs
    def timeLeft = math.max(endTime - System.currentTimeMillis, 0)

    val buffer = new ArrayBuffer[Array[Byte]]
    var firstSeqNumber: String = null
    var lastSeqNumber: String = fromSeqNum.orNull
    var lastIterator: String = null
    try {
      logDebug(s"Trying to fetch data from $shard, from seq num $lastSeqNumber")

      while (timeLeft > 0) {
        val (records, nextIterator) = retryOrTimeout("getting shard iterator", timeLeft) {
          if (lastIterator == null) {
            lastIterator = if (lastSeqNumber != null) {
              getKinesisIterator(shard, ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
            } else {
              if (initialPositionInStream == InitialPositionInStream.LATEST) {
                getKinesisIterator(shard, ShardIteratorType.LATEST, lastSeqNumber)
              } else {
                getKinesisIterator(shard, ShardIteratorType.TRIM_HORIZON, lastSeqNumber)
              }
            }
          }
          getRecordsAndNextKinesisIterator(lastIterator)
        }

        records.foreach { record =>
          buffer += JavaUtils.bufferToArray(record.getData())
          if (firstSeqNumber == null) {
            firstSeqNumber = record.getSequenceNumber
          }
          lastSeqNumber = record.getSequenceNumber
        }

        lastIterator = nextIterator
      }

      if (buffer.nonEmpty) {
        SparkEnv.get.blockManager.putIterator(blockId, buffer.iterator, StorageLevel.MEMORY_ONLY)
        val range = SequenceNumberRange(
          shard.streamName, shard.shardId, firstSeqNumber, lastSeqNumber)
        logDebug(s"Received block $blockId having range $range from shard $shard")
        Some(blockId -> range)
      } else {
        None
      }
    } catch {
      case NonFatal(e) =>
        logWarning(s"Error fetching data from shard $shard", e)
        None
    }
  }

  /**
   * Get the records starting from a Kinesis shard iterator (which is a progress handle
   * to get records from Kinesis), along with the next shard iterator for the next consumption.
   */
  private def getRecordsAndNextKinesisIterator(shardIterator: String): (Seq[Record], String) = {
    val getRecordsRequest = new GetRecordsRequest().withShardIterator(shardIterator)
    getRecordsRequest.setRequestCredentials(credentials)
    val getRecordsResult = client.getRecords(getRecordsRequest)
    // De-aggregate records, if KPL was used in producing the records. The KCL automatically
    // handles de-aggregation during regular operation. This code path is used during recovery.
    val records = UserRecord.deaggregate(getRecordsResult.getRecords)
    logTrace(
      s"Got ${records.size()} records and next iterator ${getRecordsResult.getNextShardIterator}")
    (records.asScala, getRecordsResult.getNextShardIterator)
  }

  /**
   * Get the Kinesis shard iterator for getting records starting from or after the given
   * sequence number.
   */
  private def getKinesisIterator(
      shard: Shard,
      iteratorType: ShardIteratorType,
      sequenceNumber: String): String = {
    val getShardIteratorRequest = new GetShardIteratorRequest()
      .withStreamName(shard.streamName)
      .withShardId(shard.shardId)
      .withShardIteratorType(iteratorType.toString)
      .withStartingSequenceNumber(sequenceNumber)
    getShardIteratorRequest.setRequestCredentials(credentials)
    val getShardIteratorResult = client.getShardIterator(getShardIteratorRequest)
    logTrace(s"Shard $shard: Got iterator ${getShardIteratorResult.getShardIterator}")
    getShardIteratorResult.getShardIterator
  }

  /** Helper method to retry Kinesis API requests with exponential backoff and timeouts */
  private def retryOrTimeout[T](message: String, retryTimeoutMs: Long)(body: => T): T = {
    import KinesisSequenceRangeIterator._
    val startTimeMs = System.currentTimeMillis()
    var retryCount = 0
    var waitTimeMs = MIN_RETRY_WAIT_TIME_MS
    var result: Option[T] = None
    var lastError: Throwable = null

    def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs
    def isMaxRetryDone = retryCount >= MAX_RETRIES

    while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
      if (retryCount > 0) {
        // wait only if this is a retry
        Thread.sleep(waitTimeMs)
        waitTimeMs *= 2 // if you have waited, then double wait time for next round
      }
      try {
        result = Some(body)
      } catch {
        case NonFatal(t) =>
          lastError = t
          t match {
            case ptee: ProvisionedThroughputExceededException =>
              logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee)
            case e: Throwable =>
              throw new SparkException(s"Error while $message", e)
          }
      }
      retryCount += 1
    }
    result.getOrElse {
      if (isTimedOut) {
        throw new SparkException(
          s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError)
      } else {
        throw new SparkException(
          s"Gave up after $retryCount retries while $message, last exception: ", lastError)
      }
    }
  }
}
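
For context, a sketch of how a caller such as KinesisSource might drive this fetcher. The constructor arguments and the fetch() signature come from the class above; the Shard values, the sequence-number map, and the StreamBlockId naming scheme are illustrative assumptions, not the actual implementation.

// Hypothetical caller-side sketch; shards, lastSeenSeqNums and sc are assumed to exist.
val fromSeqNums: Seq[(Shard, Option[String], BlockId)] = shards.zipWithIndex.map {
  case (shard, i) =>
    // None means no sequence number has been seen yet, so initialPositionInStream is used
    (shard, lastSeenSeqNums.get(shard), StreamBlockId(i, System.currentTimeMillis()))
}
val fetched: Array[(BlockId, SequenceNumberRange)] =
  new KinesisDataFetcher(credentials, endpointUrl, fromSeqNums, InitialPositionInStream.LATEST)
    .fetch(sc)
// Each (blockId, range) pair can then back an RDD partition that reads the block from the
// block manager first and falls back to re-fetching the range from Kinesis on recovery.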
@@ -41,7 +41,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
  override def getAWSSecretKey: String = secretKey
}

object SerializableAWSCredentials {
private[kinesis] object SerializableAWSCredentials {
  def apply(credentials: AWSCredentials): SerializableAWSCredentials = {
    new SerializableAWSCredentials(credentials.getAWSAccessKeyId, credentials.getAWSSecretKey)
  }