diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index 5ea52b6ad36a0..9e896a95fbc70 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -57,10 +57,35 @@ class KafkaRDD[
     messageHandler: MessageAndMetadata[K, V] => R
   ) extends RDD[R](sc, Nil) with Logging with HasOffsetRanges {
   override def getPartitions: Array[Partition] = {
-    offsetRanges.zipWithIndex.map { case (o, i) =>
-        val (host, port) = leaders(TopicAndPartition(o.topic, o.partition))
-        new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port)
+
+    val subconcurrency = if (kafkaParams.contains("topic.partition.subconcurrency"))
+      kafkaParams.getOrElse("topic.partition.subconcurrency", "1").toInt
+    else 1
+    val numPartitions = offsetRanges.length
+
+    val subOffsetRanges: Array[OffsetRange] = new Array[OffsetRange](subconcurrency * numPartitions)
+    for (i <- 0 until numPartitions) {
+      val offsetRange = offsetRanges(i)
+      val step = (offsetRange.untilOffset - offsetRange.fromOffset) / subconcurrency
+
+      var from = -1L
+      var until = -1L
+
+      for (j <- 0 until subconcurrency) {
+        from = offsetRange.fromOffset + j * step
+        until = offsetRange.fromOffset + (j + 1) * step
+        if (j == subconcurrency - 1) {  // last slice runs to the (exclusive) untilOffset
+          until = offsetRange.untilOffset
+        }
+        subOffsetRanges(i * subconcurrency + j) = OffsetRange.create(offsetRange.topic, offsetRange.partition, from, until)
+      }
+    }
+
+    subOffsetRanges.zipWithIndex.map { case (o, i) =>
+      val (host, port) = leaders(TopicAndPartition(o.topic, o.partition))
+      new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset, host, port)
     }.toArray
+
   }

   override def count(): Long = offsetRanges.map(_.count).sum
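
For context, a minimal sketch of how a job might opt into this behaviour once the patch is applied. The broker list, topic name, batch interval, and object name below are placeholder assumptions, not part of the patch; the only piece the patched getPartitions actually reads is the "topic.partition.subconcurrency" entry in kafkaParams, which it turns into that many offset sub-ranges (and thus Spark partitions) per Kafka partition.

import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object SubconcurrencyExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SubconcurrencyExample")
    val ssc = new StreamingContext(conf, Seconds(5))

    // "topic.partition.subconcurrency" is the key read by the patched getPartitions above;
    // "4" asks for four KafkaRDD partitions (up to four concurrent tasks) per Kafka partition.
    // Broker addresses and the topic name are placeholders.
    val kafkaParams = Map(
      "metadata.broker.list" -> "broker1:9092,broker2:9092",
      "topic.partition.subconcurrency" -> "4")

    val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Set("events"))

    stream.map(_._2).count().print()

    ssc.start()
    ssc.awaitTermination()
  }
}

Note that splitting one Kafka partition across several concurrent tasks gives up the in-order, one-task-per-partition processing that the direct stream otherwise provides, so it only suits jobs that do not rely on per-partition ordering.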