
Commit a7fcc31

jerryshao authored and tdas committed
[SPARK-9065][STREAMING][PYSPARK] Add MessageHandler for Kafka Python API
Fixed the merge conflicts in #7410

Closes #7410

Author: Shixiong Zhu <[email protected]>
Author: jerryshao <[email protected]>
Author: jerryshao <[email protected]>

Closes #9742 from zsxwing/pr7410.

(cherry picked from commit 75a2922)
Signed-off-by: Tathagata Das <[email protected]>
1 parent 3133d8b commit a7fcc31
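For context, this change exposes Kafka's MessageAndMetadata to PySpark through a user-supplied message handler. A minimal usage sketch follows; it is not part of this commit's diff, and it assumes the Python-side createDirectStream accepts a messageHandler keyword argument whose records carry topic, partition, offset, key, and message fields, mirroring the PythonMessageAndMetadata helper added in the Scala code below.

# Hypothetical PySpark usage sketch -- the Python-side signature is assumed,
# not shown in this diff. The metadata fields mirror PythonMessageAndMetadata
# from the Scala helper below.
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

sc = SparkContext(appName="kafka-message-handler-sketch")
ssc = StreamingContext(sc, 2)

def handler(m):
    # Keep the Kafka metadata alongside the payload instead of just (key, value).
    return (m.topic, m.partition, m.offset, m.key, m.message)

stream = KafkaUtils.createDirectStream(
    ssc,
    ["my-topic"],                                # topics (assumed name)
    {"metadata.broker.list": "localhost:9092"},  # kafkaParams
    messageHandler=handler)                      # new in SPARK-9065

stream.pprint()
ssc.start()
ssc.awaitTermination()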

File tree

4 files changed (+299, -98 lines)


external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala

Lines changed: 159 additions & 86 deletions
@@ -17,25 +17,29 @@
 
 package org.apache.spark.streaming.kafka
 
+import java.io.OutputStream
 import java.lang.{Integer => JInt, Long => JLong}
 import java.util.{List => JList, Map => JMap, Set => JSet}
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
+import com.google.common.base.Charsets.UTF_8
 import kafka.common.TopicAndPartition
 import kafka.message.MessageAndMetadata
-import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder}
+import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
+import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler}
 
 import org.apache.spark.api.java.function.{Function => JFunction}
-import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.streaming.util.WriteAheadLogUtils
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
-import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
-import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.streaming.api.java._
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream}
 
 object KafkaUtils {
   /**
@@ -184,6 +188,27 @@ object KafkaUtils {
     }
   }
 
+  private[kafka] def getFromOffsets(
+      kc: KafkaCluster,
+      kafkaParams: Map[String, String],
+      topics: Set[String]
+    ): Map[TopicAndPartition, Long] = {
+    val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
+    val result = for {
+      topicPartitions <- kc.getPartitions(topics).right
+      leaderOffsets <- (if (reset == Some("smallest")) {
+        kc.getEarliestLeaderOffsets(topicPartitions)
+      } else {
+        kc.getLatestLeaderOffsets(topicPartitions)
+      }).right
+    } yield {
+      leaderOffsets.map { case (tp, lo) =>
+        (tp, lo.offset)
+      }
+    }
+    KafkaCluster.checkErrors(result)
+  }
+
   /**
    * Create a RDD from Kafka using offset ranges for each topic and partition.
    *
@@ -246,7 +271,7 @@
       // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
       leaders.map {
         case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
-      }.toMap
+      }
     }
     val cleanedHandler = sc.clean(messageHandler)
     checkOffsets(kc, offsetRanges)
@@ -406,23 +431,9 @@ object KafkaUtils {
   ): InputDStream[(K, V)] = {
     val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
     val kc = new KafkaCluster(kafkaParams)
-    val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
-
-    val result = for {
-      topicPartitions <- kc.getPartitions(topics).right
-      leaderOffsets <- (if (reset == Some("smallest")) {
-        kc.getEarliestLeaderOffsets(topicPartitions)
-      } else {
-        kc.getLatestLeaderOffsets(topicPartitions)
-      }).right
-    } yield {
-      val fromOffsets = leaderOffsets.map { case (tp, lo) =>
-        (tp, lo.offset)
-      }
-      new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
-        ssc, kafkaParams, fromOffsets, messageHandler)
-    }
-    KafkaCluster.checkErrors(result)
+    val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
+    new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
+      ssc, kafkaParams, fromOffsets, messageHandler)
   }
 
   /**
@@ -550,6 +561,8 @@ object KafkaUtils {
  * takes care of known parameters instead of passing them from Python
  */
private[kafka] class KafkaUtilsPythonHelper {
+  import KafkaUtilsPythonHelper._
+
   def createStream(
       jssc: JavaStreamingContext,
       kafkaParams: JMap[String, String],
@@ -566,86 +579,92 @@ private[kafka] class KafkaUtilsPythonHelper {
       storageLevel)
   }
 
-  def createRDD(
+  def createRDDWithoutMessageHandler(
       jsc: JavaSparkContext,
       kafkaParams: JMap[String, String],
       offsetRanges: JList[OffsetRange],
-      leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = {
-    val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
-      (Array[Byte], Array[Byte])] {
-      def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
-        (t1.key(), t1.message())
-    }
+      leaders: JMap[TopicAndPartition, Broker]): JavaRDD[(Array[Byte], Array[Byte])] = {
+    val messageHandler =
+      (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+    new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler))
+  }
 
-    val jrdd = KafkaUtils.createRDD[
-      Array[Byte],
-      Array[Byte],
-      DefaultDecoder,
-      DefaultDecoder,
-      (Array[Byte], Array[Byte])](
-      jsc,
-      classOf[Array[Byte]],
-      classOf[Array[Byte]],
-      classOf[DefaultDecoder],
-      classOf[DefaultDecoder],
-      classOf[(Array[Byte], Array[Byte])],
-      kafkaParams,
-      offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
-      leaders,
-      messageHandler
-    )
-    new JavaPairRDD(jrdd.rdd)
+  def createRDDWithMessageHandler(
+      jsc: JavaSparkContext,
+      kafkaParams: JMap[String, String],
+      offsetRanges: JList[OffsetRange],
+      leaders: JMap[TopicAndPartition, Broker]): JavaRDD[Array[Byte]] = {
+    val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+      new PythonMessageAndMetadata(
+        mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+    val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler).
+      mapPartitions(picklerIterator)
+    new JavaRDD(rdd)
   }
 
-  def createDirectStream(
+  private def createRDD[V: ClassTag](
+      jsc: JavaSparkContext,
+      kafkaParams: JMap[String, String],
+      offsetRanges: JList[OffsetRange],
+      leaders: JMap[TopicAndPartition, Broker],
+      messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = {
+    KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+      jsc.sc,
+      kafkaParams.asScala.toMap,
+      offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
+      leaders.asScala.toMap,
+      messageHandler
+    )
+  }
+
+  def createDirectStreamWithoutMessageHandler(
+      jssc: JavaStreamingContext,
+      kafkaParams: JMap[String, String],
+      topics: JSet[String],
+      fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = {
+    val messageHandler =
+      (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+    new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler))
+  }
+
+  def createDirectStreamWithMessageHandler(
       jssc: JavaStreamingContext,
       kafkaParams: JMap[String, String],
       topics: JSet[String],
-      fromOffsets: JMap[TopicAndPartition, JLong]
-    ): JavaPairInputDStream[Array[Byte], Array[Byte]] = {
+      fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = {
+    val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+      new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+    val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler).
+      mapPartitions(picklerIterator)
+    new JavaDStream(stream)
+  }
 
-    if (!fromOffsets.isEmpty) {
+  private def createDirectStream[V: ClassTag](
+      jssc: JavaStreamingContext,
+      kafkaParams: JMap[String, String],
+      topics: JSet[String],
+      fromOffsets: JMap[TopicAndPartition, JLong],
+      messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = {
+
+    val currentFromOffsets = if (!fromOffsets.isEmpty) {
       val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
       if (topicsFromOffsets != topics.asScala.toSet) {
         throw new IllegalStateException(
           s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " +
             s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}")
       }
-    }
-
-    if (fromOffsets.isEmpty) {
-      KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
-        jssc,
-        classOf[Array[Byte]],
-        classOf[Array[Byte]],
-        classOf[DefaultDecoder],
-        classOf[DefaultDecoder],
-        kafkaParams,
-        topics)
+      Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*)
     } else {
-      val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
-        (Array[Byte], Array[Byte])] {
-        def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
-          (t1.key(), t1.message())
-      }
-
-      val jstream = KafkaUtils.createDirectStream[
-        Array[Byte],
-        Array[Byte],
-        DefaultDecoder,
-        DefaultDecoder,
-        (Array[Byte], Array[Byte])](
-        jssc,
-        classOf[Array[Byte]],
-        classOf[Array[Byte]],
-        classOf[DefaultDecoder],
-        classOf[DefaultDecoder],
-        classOf[(Array[Byte], Array[Byte])],
-        kafkaParams,
-        fromOffsets,
-        messageHandler)
-      new JavaPairInputDStream(jstream.inputDStream)
+      val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*))
+      KafkaUtils.getFromOffsets(
+        kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
     }
+
+    KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+      jssc.ssc,
+      Map(kafkaParams.asScala.toSeq: _*),
+      Map(currentFromOffsets.toSeq: _*),
+      messageHandler)
   }
 
   def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
@@ -669,3 +688,57 @@ private[kafka] class KafkaUtilsPythonHelper {
     kafkaRDD.offsetRanges.toSeq.asJava
   }
 }
+
+private object KafkaUtilsPythonHelper {
+  private var initialized = false
+
+  def initialize(): Unit = {
+    SerDeUtil.initialize()
+    synchronized {
+      if (!initialized) {
+        new PythonMessageAndMetadataPickler().register()
+        initialized = true
+      }
+    }
+  }
+
+  initialize()
+
+  def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
+    new SerDeUtil.AutoBatchedPickler(iter)
+  }
+
+  case class PythonMessageAndMetadata(
+      topic: String,
+      partition: JInt,
+      offset: JLong,
+      key: Array[Byte],
+      message: Array[Byte])
+
+  class PythonMessageAndMetadataPickler extends IObjectPickler {
+    private val module = "pyspark.streaming.kafka"
+
+    def register(): Unit = {
+      Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this)
+      Pickler.registerCustomPickler(this.getClass, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler) {
+      if (obj == this) {
+        out.write(Opcodes.GLOBAL)
+        out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8))
+      } else {
+        pickler.save(this)
+        val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata]
+        out.write(Opcodes.MARK)
+        pickler.save(msgAndMetaData.topic)
+        pickler.save(msgAndMetaData.partition)
+        pickler.save(msgAndMetaData.offset)
+        pickler.save(msgAndMetaData.key)
+        pickler.save(msgAndMetaData.message)
+        out.write(Opcodes.TUPLE)
+        out.write(Opcodes.REDUCE)
+      }
+    }
+  }
+}
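A note on the pickler added at the end of this file: for each record it writes a GLOBAL opcode naming pyspark.streaming.kafka.KafkaMessageAndMetadata, then MARK, the five fields, TUPLE, and REDUCE, so the Python unpickler ends up calling KafkaMessageAndMetadata(topic, partition, offset, key, message). The Python module itself is not part of this excerpt; a shape that would satisfy that contract is, roughly:

# Illustrative sketch only -- the real pyspark.streaming.kafka class is not
# shown here. Any callable taking (topic, partition, offset, key, message)
# would satisfy the REDUCE call emitted by PythonMessageAndMetadataPickler.
from collections import namedtuple

KafkaMessageAndMetadata = namedtuple(
    "KafkaMessageAndMetadata",
    ["topic", "partition", "offset", "key", "message"])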

project/MimaExcludes.scala

Lines changed: 6 additions & 0 deletions
@@ -136,6 +136,12 @@ object MimaExcludes {
         // SPARK-11766 add toJson to Vector
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.mllib.linalg.Vector.toJson")
+      ) ++ Seq(
+        // SPARK-9065 Support message handler in Kafka Python API
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD")
       )
     case v if v.startsWith("1.5") =>
       Seq(
