
Commit 48aa7c3

If the user doesn't set messageHandler, use the old approach to speed things up
1 parent 59df9ee commit 48aa7c3

2 files changed: 103 additions & 56 deletions
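
For orientation (not part of the commit): after this change, leaving messageHandler unset keeps the old fast path, in which records cross the JVM/Python boundary as raw (key, value) byte pairs with no pickling, while supplying a handler opts into the slower path that carries full message metadata. A minimal PySpark sketch of the two call styles, assuming a hypothetical broker at localhost:9092 and a hypothetical topic "test":

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

sc = SparkContext(appName="kafka-direct-sketch")
ssc = StreamingContext(sc, 2)
kafkaParams = {"metadata.broker.list": "localhost:9092"}  # hypothetical broker

# Fast path: no messageHandler, so only the key/value decoders run in Python.
plain = KafkaUtils.createDirectStream(ssc, ["test"], kafkaParams)

# Metadata path: a messageHandler pickles each record as
# KafkaMessageAndMetadata, exposing topic/partition/offset per record.
with_meta = KafkaUtils.createDirectStream(
    ssc, ["test"], kafkaParams,
    messageHandler=lambda m: (m.topic, m.partition, m.offset, m.message))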

external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala

Lines changed: 62 additions & 30 deletions
@@ -39,7 +39,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.api.java._
-import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream}

 object KafkaUtils {
   /**
@@ -579,7 +579,18 @@ private[kafka] class KafkaUtilsPythonHelper {
       storageLevel)
   }

-  def createRDD(
+  def createRDDWithoutMessageHandler(
+      jsc: JavaSparkContext,
+      kafkaParams: JMap[String, String],
+      offsetRanges: JList[OffsetRange],
+      leaders: JMap[TopicAndPartition, Broker]
+    ): JavaRDD[(Array[Byte], Array[Byte])] = {
+    val messageHandler =
+      (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+    new JavaRDD(createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler))
+  }
+
+  def createRDDWithMessageHandler(
       jsc: JavaSparkContext,
       kafkaParams: JMap[String, String],
       offsetRanges: JList[OffsetRange],
@@ -588,26 +599,57 @@ private[kafka] class KafkaUtilsPythonHelper {
     val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
       new PythonMessageAndMetadata(
         mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+    val rdd = createRDD(jsc, kafkaParams, offsetRanges, leaders, messageHandler).
+      mapPartitions(picklerIterator)
+    new JavaRDD(rdd)
+  }

-    KafkaUtils.createRDD[
-      Array[Byte],
-      Array[Byte],
-      DefaultDecoder,
-      DefaultDecoder,
-      PythonMessageAndMetadata](
-      jsc.sc,
-      Map(kafkaParams.asScala.toSeq: _*),
-      offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
-      Map(leaders.asScala.toSeq: _*),
-      messageHandler).mapPartitions { iter => picklerIterator(iter) }
+  private def createRDD[V: ClassTag](
+      jsc: JavaSparkContext,
+      kafkaParams: JMap[String, String],
+      offsetRanges: JList[OffsetRange],
+      leaders: JMap[TopicAndPartition, Broker],
+      messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): RDD[V] = {
+    KafkaUtils.createRDD[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+      jsc.sc,
+      kafkaParams.asScala.toMap,
+      offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
+      leaders.asScala.toMap,
+      messageHandler
+    )
+  }
+
+  def createDirectStreamWithoutMessageHandler(
+      jssc: JavaStreamingContext,
+      kafkaParams: JMap[String, String],
+      topics: JSet[String],
+      fromOffsets: JMap[TopicAndPartition, JLong]
+    ): JavaDStream[(Array[Byte], Array[Byte])] = {
+    val messageHandler =
+      (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message)
+    new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler))
   }

-  def createDirectStream(
+  def createDirectStreamWithMessageHandler(
       jssc: JavaStreamingContext,
       kafkaParams: JMap[String, String],
       topics: JSet[String],
       fromOffsets: JMap[TopicAndPartition, JLong]
-    ): JavaDStream[Array[Byte]] = {
+    ): JavaDStream[Array[Byte]] = {
+    val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
+      new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
+    val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler).
+      mapPartitions(picklerIterator)
+    new JavaDStream(stream)
+  }
+
+  private def createDirectStream[V: ClassTag](
+      jssc: JavaStreamingContext,
+      kafkaParams: JMap[String, String],
+      topics: JSet[String],
+      fromOffsets: JMap[TopicAndPartition, JLong],
+      messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V
+    ): DStream[V] = {

     val currentFromOffsets = if (!fromOffsets.isEmpty) {
       val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
@@ -623,21 +665,11 @@ private[kafka] class KafkaUtilsPythonHelper {
         kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
     }

-    val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
-      new PythonMessageAndMetadata(
-        mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())
-
-    val stream = KafkaUtils.createDirectStream[
-      Array[Byte],
-      Array[Byte],
-      DefaultDecoder,
-      DefaultDecoder,
-      PythonMessageAndMetadata](
-      jssc.ssc,
-      Map(kafkaParams.asScala.toSeq: _*),
-      Map(currentFromOffsets.toSeq: _*),
-      messageHandler).mapPartitions { iter => picklerIterator(iter) }
-    new JavaDStream(stream)
+    KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder, V](
+      jssc.ssc,
+      Map(kafkaParams.asScala.toSeq: _*),
+      Map(currentFromOffsets.toSeq: _*),
+      messageHandler)
   }

   def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
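
A note on the split above (not part of the diff): the *WithoutMessageHandler variants hand each record to Python as a raw (Array[Byte], Array[Byte]) pair, while the *WithMessageHandler variants wrap every record in PythonMessageAndMetadata and push it through picklerIterator. The Python wrapper therefore has to pair each helper method with the matching deserializer; a minimal sketch of that pairing, using the existing classes in pyspark.serializers:

from pyspark.serializers import (AutoBatchedSerializer, NoOpSerializer,
                                 PairDeserializer, PickleSerializer)

# For createRDDWithoutMessageHandler / createDirectStreamWithoutMessageHandler:
# each record is two length-prefixed byte arrays, so no-op serializers
# suffice and nothing is pickled on either side.
raw_pair_ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())

# For createRDDWithMessageHandler / createDirectStreamWithMessageHandler:
# records arrive as PythonMessageAndMetadata objects pickled on the JVM
# side, so Python must unpickle them in batches.
pickled_ser = AutoBatchedSerializer(PickleSerializer())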

python/pyspark/streaming/kafka.py

Lines changed: 41 additions & 26 deletions
@@ -36,16 +36,6 @@ def utf8_decoder(s):
     return s.decode('utf-8')


-def default_message_handler(s):
-    """
-    Function for translating each message and metadata into the desired type
-
-    :param s: A KafkaMessageAndMetadata object includes message and metadata
-    :return: A tuple of Kafka key and message
-    """
-    return s and (s.key, s.message)
-
-
 class KafkaUtils(object):

     @staticmethod
@@ -95,7 +85,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None,
     @staticmethod
     def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
                            keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
-                           messageHandler=default_message_handler):
+                           messageHandler=None):
         """
         .. note:: Experimental

@@ -120,6 +110,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
             point of the stream.
         :param keyDecoder: A function used to decode key (default is utf8_decoder).
         :param valueDecoder: A function used to decode value (default is utf8_decoder).
+        :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can
+                               access the metadata through messageHandler (default is None).
         :return: A DStream object
         """
         if fromOffsets is None:
@@ -129,32 +121,43 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
         if not isinstance(kafkaParams, dict):
             raise TypeError("kafkaParams should be dict")

+        def funcWithoutMessageHandler(k_v):
+            return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+        def funcWithMessageHandler(m):
+            m._set_key_decoder(keyDecoder)
+            m._set_value_decoder(valueDecoder)
+            return messageHandler(m)
+
         try:
             helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
             helper = helperClass.newInstance()

             jfromOffsets = dict([(k._jTopicAndPartition(helper),
                                   v) for (k, v) in fromOffsets.items()])
-            jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+            if messageHandler is None:
+                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+                func = funcWithoutMessageHandler
+                jstream = helper.createDirectStreamWithoutMessageHandler(
+                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+            else:
+                ser = AutoBatchedSerializer(PickleSerializer())
+                func = funcWithMessageHandler
+                jstream = helper.createDirectStreamWithMessageHandler(
+                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(ssc.sparkContext)
             raise e

-        def func(m):
-            m._set_key_decoder(keyDecoder)
-            m._set_value_decoder(valueDecoder)
-            return messageHandler(m)
-
-        ser = AutoBatchedSerializer(PickleSerializer())
         stream = DStream(jstream, ssc, ser).map(func)
         return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)

     @staticmethod
     def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
                   keyDecoder=utf8_decoder, valueDecoder=utf8_decoder,
-                  messageHandler=default_message_handler):
+                  messageHandler=None):
         """
         .. note:: Experimental

@@ -167,6 +170,8 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
             map, in which case leaders will be looked up on the driver.
         :param keyDecoder: A function used to decode key (default is utf8_decoder)
         :param valueDecoder: A function used to decode value (default is utf8_decoder)
+        :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can
+                               access the metadata through messageHandler (default is None).
         :return: A RDD object
         """
         if leaders is None:
@@ -176,25 +181,35 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None,
         if not isinstance(offsetRanges, list):
             raise TypeError("offsetRanges should be list")

+        def funcWithoutMessageHandler(k_v):
+            return (keyDecoder(k_v[0]), valueDecoder(k_v[1]))
+
+        def funcWithMessageHandler(m):
+            m._set_key_decoder(keyDecoder)
+            m._set_value_decoder(valueDecoder)
+            return messageHandler(m)
+
         try:
             helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
             helper = helperClass.newInstance()
             joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
             jleaders = dict([(k._jTopicAndPartition(helper),
                               v._jBroker(helper)) for (k, v) in leaders.items()])
-            jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            if messageHandler is None:
+                jrdd = helper.createRDDWithoutMessageHandler(
+                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
+                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+                rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+            else:
+                jrdd = helper.createRDDWithMessageHandler(
+                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
+                rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(sc)
             raise e

-        def func(m):
-            m._set_key_decoder(keyDecoder)
-            m._set_value_decoder(valueDecoder)
-            return messageHandler(m)
-
-        rdd = RDD(jrdd, sc).map(func)
         return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)

     @staticmethod
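
The batch API follows the same dispatch. A hedged sketch of calling createRDD both ways, reusing the SparkContext from the first sketch and the same hypothetical broker, topic, and offsets:

from pyspark.streaming.kafka import KafkaUtils, OffsetRange

kafkaParams = {"metadata.broker.list": "localhost:9092"}  # hypothetical
ranges = [OffsetRange("test", 0, fromOffset=0, untilOffset=100)]

# Without a messageHandler: raw byte pairs come back and the default
# utf8 decoders turn them into (key, value) strings.
rdd = KafkaUtils.createRDD(sc, kafkaParams, ranges)

# With a messageHandler: records come back as pickled
# KafkaMessageAndMetadata objects, so per-record offsets are visible.
rdd_meta = KafkaUtils.createRDD(
    sc, kafkaParams, ranges,
    messageHandler=lambda m: (m.offset, m.message))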
