
Commit 2189247

Merge branch 'master' of github.com:apache/spark into pool-npe
2 parents 05ad9e9 + c33b8dc

File tree

18 files changed: +432 -44 lines changed

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 61 additions & 0 deletions
@@ -156,3 +156,64 @@ class RangePartitioner[K : Ordering : ClassTag, V](
       false
   }
 }
+
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions records into specified bounds
+ * Default value is 1000. Once all partitions have bounds elements, the partitioner
+ * allocates 1 element per partition so eventually the smaller partitions are at most
+ * off by 1 key compared to the larger partitions.
+ */
+class BoundaryPartitioner[K : Ordering : ClassTag, V](
+    partitions: Int,
+    @transient rdd: RDD[_ <: Product2[K,V]],
+    private val boundary: Int = 1000)
+  extends Partitioner {
+
+  // this array keeps track of keys assigned to a partition
+  // counts[0] refers to # of keys in partition 0 and so on
+  private val counts: Array[Int] = {
+    new Array[Int](numPartitions)
+  }
+
+  def numPartitions = math.abs(partitions)
+
+  /*
+   * Ideally, this should've been calculated based on # partitions and total keys
+   * But we are not calling count on RDD here to avoid calling an action.
+   * User has the flexibility of calling count and passing in any appropriate boundary
+   */
+  def keysPerPartition = boundary
+
+  var currPartition = 0
+
+  /*
+   * Pick current partition for the key until we hit the bound for keys / partition,
+   * start allocating to next partition at that time.
+   *
+   * NOTE: In case where we have lets say 2000 keys and user says 3 partitions with 500
+   * passed in as boundary, the first 500 will goto P1, 501-1000 go to P2, 1001-1500 go to P3,
+   * after that, next keys go to one partition at a time. So 1501 goes to P1, 1502 goes to P2,
+   * 1503 goes to P3 and so on.
+   */
+  def getPartition(key: Any): Int = {
+    val partition = currPartition
+    counts(partition) = counts(partition) + 1
+    /*
+     * Since we are filling up a partition before moving to next one (this helps in maintaining
+     * order of keys, in certain cases, it is possible to end up with empty partitions, like
+     * 3 partitions, 500 keys / partition and if rdd has 700 keys, 1 partition will be entirely
+     * empty.
+     */
+    if(counts(currPartition) >= keysPerPartition)
+      currPartition = (currPartition + 1) % numPartitions
+    partition
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case r: BoundaryPartitioner[_,_] =>
+      (r.counts.sameElements(counts) && r.boundary == boundary
+        && r.currPartition == currPartition)
+    case _ =>
+      false
+  }
+}
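For context, a minimal sketch of how the new partitioner could be driven from user code. The BoundaryPartitionerExample object, the local master setting, and the chosen element counts are illustrative assumptions, not part of this commit; only BoundaryPartitioner itself comes from the diff above.

import org.apache.spark.{BoundaryPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // brings the PairRDDFunctions implicits into scope

object BoundaryPartitionerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("boundary-example").setMaster("local[2]"))
    val pairs = sc.parallelize(1 to 2000).map(x => (x, x))
    // Fill each of the 4 partitions with up to 500 keys before moving on to the next one.
    val partitioned = pairs.partitionBy(new BoundaryPartitioner(4, pairs, 500))
    // glom() turns each partition into an array so the resulting partition sizes can be inspected.
    println(partitioned.glom().map(_.length).collect().mkString(", "))
    sc.stop()
  }
}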

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 3 additions & 3 deletions
@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. Uses the provided
+   * more accurate counts but increase the memory footprint and vice versa. Uses the provided
    * Partitioner to partition the output RDD.
    */
   def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key in this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. HashPartitions the
+   * more accurate counts but increase the memory footprint and vice versa. HashPartitions the
    * output RDD into numPartitions.
    *
    */
@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return approximate number of distinct values for each key this RDD.
    * The accuracy of approximation can be controlled through the relative standard deviation
    * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
-   * more accurate counts but increase the memory footprint and vise versa. The default value of
+   * more accurate counts but increase the memory footprint and vice versa. The default value of
    * relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
    * level.
    */
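A short usage sketch of the API whose docs are corrected above, illustrating the relativeSD trade-off the comments describe. The sample data and the specific relativeSD values are illustrative assumptions; sc is assumed to be an existing SparkContext.

import org.apache.spark.SparkContext._  // PairRDDFunctions implicits

// Lower relativeSD => more accurate per-key distinct counts, at the cost of more memory.
val pairs = sc.parallelize(1 to 10000).map(x => (x % 10, x % 100))
val rough   = pairs.countApproxDistinctByKey(0.1)      // cheaper, less precise
val precise = pairs.countApproxDistinctByKey(0.01, 4)  // more memory, hash-partitioned into 4 partitions
precise.collect().foreach { case (k, count) => println(s"key $k has roughly $count distinct values") }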

core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala

Lines changed: 28 additions & 8 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler
 import java.util.concurrent.{LinkedBlockingQueue, Semaphore}
 
 import org.apache.spark.Logging
+import org.apache.spark.util.Utils
 
 /**
  * Asynchronously passes SparkListenerEvents to registered SparkListeners.
@@ -42,7 +43,7 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
 
   private val listenerThread = new Thread("SparkListenerBus") {
     setDaemon(true)
-    override def run() {
+    override def run(): Unit = Utils.logUncaughtExceptions {
       while (true) {
         eventLock.acquire()
         // Atomically remove and process this event
@@ -77,11 +78,8 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
     val eventAdded = eventQueue.offer(event)
     if (eventAdded) {
       eventLock.release()
-    } else if (!queueFullErrorMessageLogged) {
-      logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
-        "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
-        "rate at which tasks are being started by the scheduler.")
-      queueFullErrorMessageLogged = true
+    } else {
+      logQueueFullErrorMessage()
     }
   }
 
@@ -96,13 +94,18 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
       if (System.currentTimeMillis > finishTime) {
        return false
       }
-      /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
-       * add overhead in the general case. */
+      /* Sleep rather than using wait/notify, because this is used only for testing and
+       * wait/notify add overhead in the general case. */
       Thread.sleep(10)
     }
     true
   }
 
+  /**
+   * For testing only. Return whether the listener daemon thread is still alive.
+   */
+  def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive }
+
   /**
    * Return whether the event queue is empty.
    *
@@ -111,6 +114,23 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
    */
   def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty }
 
+  /**
+   * Log an error message to indicate that the event queue is full. Do this only once.
+   */
+  private def logQueueFullErrorMessage(): Unit = {
+    if (!queueFullErrorMessageLogged) {
+      if (listenerThread.isAlive) {
+        logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+          "This likely means one of the SparkListeners is too slow and cannot keep up with" +
+          "the rate at which tasks are being started by the scheduler.")
+      } else {
+        logError("SparkListenerBus thread is dead! This means SparkListenerEvents have not" +
+          "been (and will no longer be) propagated to listeners for some time.")
+      }
+      queueFullErrorMessageLogged = true
+    }
+  }
+
   def stop() {
     if (!started) {
       throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
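The change above routes the queue-full case through a single helper so the error is logged only once, and logs a different message if the listener thread itself has died. A standalone sketch of the underlying bounded-queue pattern, stripped down from the class above; the names here are illustrative and this is not Spark's actual implementation.

import java.util.concurrent.{LinkedBlockingQueue, Semaphore}

// offer() returns false instead of blocking once the queue is full, so the producer
// can log a "dropping event" error exactly once rather than stalling the caller.
class DroppingEventQueue[E](capacity: Int) {
  private val queue = new LinkedBlockingQueue[E](capacity)
  private val eventLock = new Semaphore(0)
  @volatile private var fullErrorLogged = false

  def post(event: E): Unit = {
    if (queue.offer(event)) {
      eventLock.release()   // wake up the consumer thread
    } else if (!fullErrorLogged) {
      System.err.println("Dropping event because the queue is full")
      fullErrorLogged = true
    }
  }

  def take(): E = { eventLock.acquire(); queue.poll() }
}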

core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala

Lines changed: 34 additions & 16 deletions
@@ -20,10 +20,13 @@ package org.apache.spark.scheduler
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+
 /**
  * A SparkListenerEvent bus that relays events to its listeners
  */
-private[spark] trait SparkListenerBus {
+private[spark] trait SparkListenerBus extends Logging {
 
   // SparkListeners attached to this event bus
   protected val sparkListeners = new ArrayBuffer[SparkListener]
@@ -34,38 +37,53 @@ private[spark] trait SparkListenerBus {
   }
 
   /**
-   * Post an event to all attached listeners. This does nothing if the event is
-   * SparkListenerShutdown.
+   * Post an event to all attached listeners.
+   * This does nothing if the event is SparkListenerShutdown.
    */
   def postToAll(event: SparkListenerEvent) {
     event match {
       case stageSubmitted: SparkListenerStageSubmitted =>
-        sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
+        foreachListener(_.onStageSubmitted(stageSubmitted))
       case stageCompleted: SparkListenerStageCompleted =>
-        sparkListeners.foreach(_.onStageCompleted(stageCompleted))
+        foreachListener(_.onStageCompleted(stageCompleted))
       case jobStart: SparkListenerJobStart =>
-        sparkListeners.foreach(_.onJobStart(jobStart))
+        foreachListener(_.onJobStart(jobStart))
       case jobEnd: SparkListenerJobEnd =>
-        sparkListeners.foreach(_.onJobEnd(jobEnd))
+        foreachListener(_.onJobEnd(jobEnd))
       case taskStart: SparkListenerTaskStart =>
-        sparkListeners.foreach(_.onTaskStart(taskStart))
+        foreachListener(_.onTaskStart(taskStart))
       case taskGettingResult: SparkListenerTaskGettingResult =>
-        sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
+        foreachListener(_.onTaskGettingResult(taskGettingResult))
      case taskEnd: SparkListenerTaskEnd =>
-        sparkListeners.foreach(_.onTaskEnd(taskEnd))
+        foreachListener(_.onTaskEnd(taskEnd))
       case environmentUpdate: SparkListenerEnvironmentUpdate =>
-        sparkListeners.foreach(_.onEnvironmentUpdate(environmentUpdate))
+        foreachListener(_.onEnvironmentUpdate(environmentUpdate))
       case blockManagerAdded: SparkListenerBlockManagerAdded =>
-        sparkListeners.foreach(_.onBlockManagerAdded(blockManagerAdded))
+        foreachListener(_.onBlockManagerAdded(blockManagerAdded))
       case blockManagerRemoved: SparkListenerBlockManagerRemoved =>
-        sparkListeners.foreach(_.onBlockManagerRemoved(blockManagerRemoved))
+        foreachListener(_.onBlockManagerRemoved(blockManagerRemoved))
       case unpersistRDD: SparkListenerUnpersistRDD =>
-        sparkListeners.foreach(_.onUnpersistRDD(unpersistRDD))
+        foreachListener(_.onUnpersistRDD(unpersistRDD))
       case applicationStart: SparkListenerApplicationStart =>
-        sparkListeners.foreach(_.onApplicationStart(applicationStart))
+        foreachListener(_.onApplicationStart(applicationStart))
       case applicationEnd: SparkListenerApplicationEnd =>
-        sparkListeners.foreach(_.onApplicationEnd(applicationEnd))
+        foreachListener(_.onApplicationEnd(applicationEnd))
       case SparkListenerShutdown =>
     }
   }
+
+  /**
+   * Apply the given function to all attached listeners, catching and logging any exception.
+   */
+  private def foreachListener(f: SparkListener => Unit): Unit = {
+    sparkListeners.foreach { listener =>
+      try {
+        f(listener)
+      } catch {
+        case e: Exception =>
+          logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
+      }
+    }
+  }
+
 }
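To illustrate what the new foreachListener helper buys: a listener that throws no longer prevents later listeners from receiving the same event. A REPL-style sketch of that isolation pattern; the Listener trait and notifyAll helper are stand-ins rather than Spark types.

trait Listener { def onEvent(msg: String): Unit }

def notifyAll(listeners: Seq[Listener], msg: String): Unit =
  listeners.foreach { l =>
    try {
      l.onEvent(msg)
    } catch {
      case e: Exception => println(s"Listener ${l.getClass.getName} threw: $e")
    }
  }

notifyAll(Seq(
  new Listener { def onEvent(m: String) = throw new RuntimeException("boom") },
  new Listener { def onEvent(m: String) = println(s"got $m") }
), "stage submitted")
// The second listener still prints "got stage submitted" even though the first one failed.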

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 22 additions & 1 deletion
@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkException}
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -108,6 +108,9 @@ private[spark] object ClosureCleaner extends Logging {
     val outerObjects = getOuterObjects(func)
 
     val accessedFields = Map[Class[_], Set[String]]()
+
+    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+
     for (cls <- outerClasses)
       accessedFields(cls) = Set[String]()
     for (cls <- func.getClass :: innerClasses)
@@ -180,6 +183,24 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
+private[spark]
+class ReturnStatementFinder extends ClassVisitor(ASM4) {
+  override def visitMethod(access: Int, name: String, desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+    if (name.contains("apply")) {
+      new MethodVisitor(ASM4) {
+        override def visitTypeInsn(op: Int, tp: String) {
+          if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
+            throw new SparkException("Return statements aren't allowed in Spark closures")
+          }
+        }
+      }
+    } else {
+      new MethodVisitor(ASM4) {}
+    }
+  }
+}
+
 private[spark]
 class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
   override def visitMethod(access: Int, name: String, desc: String,
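For context, the new ReturnStatementFinder rejects closures containing a return statement, which the Scala compiler implements by constructing a scala.runtime.NonLocalReturnControl inside the closure's apply method. A sketch of the kind of user code this now fails fast on; the firstEven helper is illustrative, not from the codebase.

import org.apache.spark.rdd.RDD

def firstEven(rdd: RDD[Int]): Int =
  rdd.map { x =>
    if (x % 2 == 0) return x   // non-local return: compiles to a NonLocalReturnControl throw
    x
  }.first()
// With this change, such a closure is rejected when it is cleaned, with
// SparkException("Return statements aren't allowed in Spark closures").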

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 1 addition & 1 deletion
@@ -1128,7 +1128,7 @@ private[spark] object Utils extends Logging {
   }
 
   /**
-   * Executes the given block, printing and re-throwing any uncaught exceptions.
+   * Execute the given block, logging and re-throwing any uncaught exception.
    * This is particularly useful for wrapping code that runs in a thread, to ensure
    * that exceptions are printed, and to avoid having to catch Throwable.
    */
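Only the doc comment changes here; the method body is not part of this diff. For readers unfamiliar with the helper, a wrapper with this contract could look roughly like the following. This is an illustrative sketch, not the actual implementation.

def logUncaughtExceptions[T](f: => T): T = {
  try {
    f
  } catch {
    case t: Throwable =>
      // In Spark this would go through logError; stderr keeps the sketch self-contained.
      System.err.println(s"Uncaught exception in thread ${Thread.currentThread().getName}: $t")
      throw t   // re-throw so the failure is still visible to the caller / thread
  }
}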

core/src/test/scala/org/apache/spark/PartitioningSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -66,6 +66,40 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     assert(descendingP4 != p4)
   }
 
+  test("BoundaryPartitioner equality") {
+    // Make an RDD where all the elements are the same so that the partition range bounds
+    // are deterministically all the same.
+    val rdd = sc.parallelize(1.to(4000)).map(x => (x, x))
+
+    val p2 = new BoundaryPartitioner(2, rdd, 1000)
+    val p4 = new BoundaryPartitioner(4, rdd, 1000)
+    val anotherP4 = new BoundaryPartitioner(4, rdd)
+
+    assert(p2 === p2)
+    assert(p4 === p4)
+    assert(p2 != p4)
+    assert(p4 != p2)
+    assert(p4 === anotherP4)
+    assert(anotherP4 === p4)
+  }
+
+  test("BoundaryPartitioner getPartition") {
+    val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
+    val partitioner = new BoundaryPartitioner(4, rdd, 500)
+    1.to(2000).map { element => {
+      val partition = partitioner.getPartition(element)
+      if (element <= 500) {
+        assert(partition === 0)
+      } else if (element > 501 && element <= 1000) {
+        assert(partition === 1)
+      } else if (element > 1001 && element <= 1500) {
+        assert(partition === 2)
+      } else if (element > 1501 && element <= 2000) {
+        assert(partition === 3)
+      }
+    }}
+  }
+
   test("RangePartitioner getPartition") {
     val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
     // We have different behaviour of getPartition for partitions with less than 1000 and more than
