external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -17,25 +17,29 @@

package org.apache.spark.streaming.kafka

import java.io.OutputStream
import java.lang.{Integer => JInt, Long => JLong}
import java.util.{List => JList, Map => JMap, Set => JSet}

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import com.google.common.base.Charsets.UTF_8
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder}
import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler}

import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.streaming.api.java._
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream}

object KafkaUtils {
/**
@@ -184,6 +188,27 @@ object KafkaUtils {
}
}

private[kafka] def getFromOffsets(
kc: KafkaCluster,
kafkaParams: Map[String, String],
topics: Set[String]
): Map[TopicAndPartition, Long] = {
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
} else {
kc.getLatestLeaderOffsets(topicPartitions)
}).right
} yield {
leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
}
KafkaCluster.checkErrors(result)
}
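
The new `getFromOffsets` helper lifts the `auto.offset.reset` resolution out of `createDirectStream` so the Scala path and the Python helper below can share it. A minimal standalone sketch of the same right-projected `Either` pattern, with `lookupPartitions` and `lookupOffsets` as hypothetical stand-ins for `KafkaCluster.getPartitions` and the leader-offset calls:

```scala
// Sketch only: hypothetical stand-ins for KafkaCluster's metadata lookups.
object FromOffsetsSketch {
  type Err = Seq[String]

  def lookupPartitions(topics: Set[String]): Either[Err, Set[String]] =
    if (topics.nonEmpty) Right(topics.map(t => s"$t-0")) else Left(Seq("no topics given"))

  def lookupOffsets(parts: Set[String], earliest: Boolean): Either[Err, Map[String, Long]] =
    Right(parts.map(p => p -> (if (earliest) 0L else 100L)).toMap)

  // Mirrors getFromOffsets: chain the lookups with right projections, then
  // unwrap or fail with the accumulated errors (KafkaCluster.checkErrors
  // plays that unwrapping role above).
  def fromOffsets(topics: Set[String], reset: Option[String]): Map[String, Long] = {
    val result = for {
      parts   <- lookupPartitions(topics).right
      offsets <- lookupOffsets(parts, reset == Some("smallest")).right
    } yield offsets
    result.fold(errs => sys.error(errs.mkString(", ")), identity)
  }

  def main(args: Array[String]): Unit =
    println(fromOffsets(Set("logs"), Some("smallest"))) // Map(logs-0 -> 0)
}
```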

/**
* Create an RDD from Kafka using offset ranges for each topic and partition.
*
@@ -246,7 +271,7 @@ object KafkaUtils {
// This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
leaders.map {
case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
}.toMap
}
}
val cleanedHandler = sc.clean(messageHandler)
checkOffsets(kc, offsetRanges)
@@ -406,23 +431,9 @@
): InputDStream[(K, V)] = {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
val kc = new KafkaCluster(kafkaParams)
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)

val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
} else {
kc.getLatestLeaderOffsets(topicPartitions)
}).right
} yield {
val fromOffsets = leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}
KafkaCluster.checkErrors(result)
val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}
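
Callers see no behavioral change from this refactor: the stream still starts from the offsets implied by `auto.offset.reset`. A hedged caller-side sketch of the public API this method backs (broker address, topic name, and batch interval are assumptions, not part of this patch):

```scala
import kafka.serializer.StringDecoder
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object DirectStreamSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("direct-stream-sketch").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(2))
    // "auto.offset.reset" -> "smallest" makes getFromOffsets start from the
    // earliest available offsets instead of the latest.
    val kafkaParams = Map(
      "metadata.broker.list" -> "localhost:9092",
      "auto.offset.reset" -> "smallest")
    val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Set("logs"))
    stream.map(_._2).print() // print message payloads per batch
    ssc.start()
    ssc.awaitTermination()
  }
}
```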

/**
@@ -550,6 +561,8 @@ object KafkaUtils {
* takes care of known parameters instead of passing them from Python
*/
private[kafka] class KafkaUtilsPythonHelper {
import KafkaUtilsPythonHelper._

def createStream(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
@@ -566,86 +579,92 @@ private[kafka] class KafkaUtilsPythonHelper {
storageLevel)
}

def createRDD(
def createRDDWithoutMessageHandler(
jsc: JavaSparkContext,
kafkaParams: JMap[String, String],
offsetRanges: JList[OffsetRange],
leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = {
val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
(Array[Byte], Array[Byte])] {
def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
(t1.key(), t1.message())
}
leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = {
val messageHandler =
(mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler))
}

val jrdd = KafkaUtils.createRDD[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
(Array[Byte], Array[Byte])](
jsc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
classOf[(Array[Byte], Array[Byte])],
kafkaParams,
offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
leaders,
messageHandler
)
new JavaPairRDD(jrdd.rdd)
def createRDDWithMessageHandler(
jsc: JavaSparkContext,
kafkaParams: JMap[String, String],
offsetRanges: JList[OffsetRange],
leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = {
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(
mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler).
mapPartitions(picklerIterator)
new JavaRDD(rdd)
}

def createDirectStream(
private def createRDD[V: ClassTag](
jsc: JavaSparkContext,
kafkaParams: JMap[String, String],
offsetRanges: JList[OffsetRange],
leaders: JMap[TopicAndPartition, Broker],
messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = {
KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
jsc.sc,
kafkaParams.asScala.toMap,
offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
leaders.asScala.toMap,
messageHandler
)
}
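
The two Python-facing RDD entry points now differ only in the message handler they pass to this shared `createRDD`: one keeps `(key, message)` pairs, the other wraps the full metadata for pickling. A standalone sketch of that pattern, with `Record` as a hypothetical stand-in for `MessageAndMetadata`:

```scala
// Hypothetical stand-in for Kafka's MessageAndMetadata.
case class Record(topic: String, partition: Int, offset: Long,
    key: Array[Byte], message: Array[Byte])

object HandlerSketch {
  // Shared core, parameterized by the handler (like the private createRDD).
  def transform[V](records: Seq[Record])(handler: Record => V): Seq[V] =
    records.map(handler)

  def main(args: Array[String]): Unit = {
    val records = Seq(Record("logs", 0, 42L, "k".getBytes("UTF-8"), "v".getBytes("UTF-8")))
    // "Without message handler": only the payload pair survives.
    val pairs = transform(records)(r => (r.key, r.message))
    // "With message handler": keep topic/partition/offset for the Python side.
    val rich = transform(records)(r => (r.topic, r.partition, r.offset))
    println(pairs.head._2.length) // 1
    println(rich.head)            // (logs,0,42)
  }
}
```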

def createDirectStreamWithoutMessageHandler(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = {
val messageHandler =
(mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler))
}

def createDirectStreamWithMessageHandler(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
fromOffsets: JMap[TopicAndPartition, JLong]
): JavaPairInputDStream[Array[Byte], Array[Byte]] = {
fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = {
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler).
mapPartitions(picklerIterator)
new JavaDStream(stream)
}

if (!fromOffsets.isEmpty) {
private def createDirectStream[V: ClassTag](
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
fromOffsets: JMap[TopicAndPartition, JLong],
messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = {

val currentFromOffsets = if (!fromOffsets.isEmpty) {
val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
if (topicsFromOffsets != topics.asScala.toSet) {
throw new IllegalStateException(
s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " +
s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}")
}
}

if (fromOffsets.isEmpty) {
KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
jssc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
kafkaParams,
topics)
Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*)
} else {
val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
(Array[Byte], Array[Byte])] {
def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
(t1.key(), t1.message())
}

val jstream = KafkaUtils.createDirectStream[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
(Array[Byte], Array[Byte])](
jssc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
classOf[(Array[Byte], Array[Byte])],
kafkaParams,
fromOffsets,
messageHandler)
new JavaPairInputDStream(jstream.inputDStream)
val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*))
KafkaUtils.getFromOffsets(
kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
}

KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
jssc.ssc,
Map(kafkaParams.asScala.toSeq: _*),
Map(currentFromOffsets.toSeq: _*),
messageHandler)
}
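
The explicit-offsets path now validates that `fromOffsets` covers exactly the requested topics before trusting it, and otherwise falls back to `getFromOffsets`. The same check, reduced to a standalone sketch (all names here are illustrative):

```scala
object ResolveOffsetsSketch {
  type TopicPartition = (String, Int)

  // explicit: offsets handed in from Python; resolve: the auto.offset.reset path.
  def resolveOffsets(
      explicit: Map[TopicPartition, Long],
      topics: Set[String],
      resolve: Set[String] => Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
    if (explicit.nonEmpty) {
      val offsetTopics = explicit.keySet.map(_._1)
      require(offsetTopics == topics,
        s"topics ${topics.mkString(" ")} do not match the offsets' topics " +
          s"${offsetTopics.mkString(" ")}")
      explicit
    } else {
      resolve(topics)
    }
  }

  def main(args: Array[String]): Unit = {
    val fallback = (ts: Set[String]) => ts.map(t => (t, 0) -> 0L).toMap
    println(resolveOffsets(Map.empty, Set("logs"), fallback))          // fallback path
    println(resolveOffsets(Map(("logs", 0) -> 7L), Set("logs"), fallback)) // explicit path
  }
}
```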

def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
@@ -669,3 +688,57 @@ private[kafka] class KafkaUtilsPythonHelper {
kafkaRDD.offsetRanges.toSeq.asJava
}
}

private object KafkaUtilsPythonHelper {
private var initialized = false

def initialize(): Unit = {
SerDeUtil.initialize()
synchronized {
if (!initialized) {
new PythonMessageAndMetadataPickler().register()
initialized = true
}
}
}

initialize()

def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
new SerDeUtil.AutoBatchedPickler(iter)
}
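
`SerDeUtil.AutoBatchedPickler` is internal to Spark; the idea, as assumed here, is to pickle records in batches so the Python side reads lists of objects rather than paying one round trip per record. A minimal sketch of that batching shape:

```scala
object BatchSketch {
  // serialize is any Seq[T] => Array[Byte] codec; the real pickler also
  // adapts the batch size to the serialized payload size.
  def batchedIterator[T](iter: Iterator[T], batchSize: Int)(
      serialize: Seq[T] => Array[Byte]): Iterator[Array[Byte]] =
    iter.grouped(batchSize).map(batch => serialize(batch))

  def main(args: Array[String]): Unit = {
    val blobs = batchedIterator(Iterator.range(0, 10), 4)(_.mkString(",").getBytes("UTF-8"))
    blobs.foreach(b => println(new String(b, "UTF-8"))) // 0,1,2,3 / 4,5,6,7 / 8,9
  }
}
```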

case class PythonMessageAndMetadata(
topic: String,
partition: JInt,
offset: JLong,
key: Array[Byte],
message: Array[Byte])

class PythonMessageAndMetadataPickler extends IObjectPickler {
private val module = "pyspark.streaming.kafka"

def register(): Unit = {
Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this)
Pickler.registerCustomPickler(this.getClass, this)
}

def pickle(obj: Object, out: OutputStream, pickler: Pickler) {
if (obj == this) {
out.write(Opcodes.GLOBAL)
out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8))
} else {
pickler.save(this)
val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata]
out.write(Opcodes.MARK)
pickler.save(msgAndMetaData.topic)
pickler.save(msgAndMetaData.partition)
pickler.save(msgAndMetaData.offset)
pickler.save(msgAndMetaData.key)
pickler.save(msgAndMetaData.message)
out.write(Opcodes.TUPLE)
out.write(Opcodes.REDUCE)
}
}
}
}
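
Since `KafkaUtilsPythonHelper` is package-private, a round-trip sketch has to live inside `org.apache.spark.streaming.kafka`; `dumps` is the stock `net.razorvine.pickle.Pickler` entry point, and the emitted `GLOBAL`/`REDUCE` opcodes reconstruct `pyspark.streaming.kafka.KafkaMessageAndMetadata` on the Python side. A sketch, assuming it is compiled into that package:

```scala
package org.apache.spark.streaming.kafka

import net.razorvine.pickle.Pickler

// Sketch only: exercises the custom pickler registered above.
object PicklerRoundTripSketch {
  def main(args: Array[String]): Unit = {
    KafkaUtilsPythonHelper.initialize() // registers PythonMessageAndMetadataPickler
    val record = KafkaUtilsPythonHelper.PythonMessageAndMetadata(
      "logs", 0, 42L, "key".getBytes("UTF-8"), "value".getBytes("UTF-8"))
    val bytes = new Pickler().dumps(record)
    println(s"pickled record into ${bytes.length} bytes for the Python driver")
  }
}
```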
project/MimaExcludes.scala (6 changes: 6 additions, 0 deletions)
@@ -136,6 +136,12 @@ object MimaExcludes {
// SPARK-11766 add toJson to Vector
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Vector.toJson")
) ++ Seq(
// SPARK-9065 Support message handler in Kafka Python API
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD")
)
case v if v.startsWith("1.5") =>
Seq(