
Commit b9d99f7

mn-mikke authored and committed
[SPARK-23821][SQL] Merging current master to the feature branch
2 parents: e213341 + 6498884

File tree: 57 files changed (771 additions, 306 deletions)


common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions;
 
 import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.types.UTF8String;
 
 /**
  * Simulates Hive's hashing function from Hive v1.2.1
@@ -51,4 +52,8 @@ public static int hashUnsafeBytesBlock(MemoryBlock mb) {
   public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) {
     return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes));
   }
+
+  public static int hashUTF8String(UTF8String str) {
+    return hashUnsafeBytesBlock(str.getMemoryBlock());
+  }
 }

common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@
 import com.google.common.primitives.Ints;
 
 import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.types.UTF8String;
 
 /**
  * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction.
@@ -82,6 +83,10 @@ public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) {
     return fmix(h1, lengthInBytes);
   }
 
+  public static int hashUTF8String(UTF8String str, int seed) {
+    return hashUnsafeBytesBlock(str.getMemoryBlock(), seed);
+  }
+
   public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
     return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed);
   }
@@ -91,7 +96,7 @@ public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes,
   }
 
   public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) {
-    // This is compatible with original and another implementations.
+    // This is compatible with original and other implementations.
     // Use this method for new components after Spark 2.3.
     int lengthInBytes = Ints.checkedCast(base.size());
     assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative";
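The two hunks above add the same convenience entry point to both hashers: instead of unpacking a UTF8String into (base, offset, length), callers hash its backing MemoryBlock directly. A minimal usage sketch in Scala (the seed value 42 is arbitrary, chosen only for illustration):

import org.apache.spark.sql.catalyst.expressions.HiveHasher
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.UTF8String

val s = UTF8String.fromString("spark")

// Hive-compatible hash of the string's bytes, read straight from its MemoryBlock.
val hiveHash: Int = HiveHasher.hashUTF8String(s)

// Seeded 32-bit Murmur3 hash of the same bytes.
val murmurHash: Int = Murmur3_x86_32.hashUTF8String(s, 42)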

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 13 additions & 13 deletions
@@ -480,6 +480,19 @@ private[spark] class Executor(
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
       } catch {
+        case t: TaskKilledException =>
+          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
+
+        case _: InterruptedException | NonFatal(_) if
+            task != null && task.reasonIfKilled.isDefined =>
+          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
+          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(
+            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
+
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
           if (!t.isInstanceOf[FetchFailedException]) {
@@ -494,19 +507,6 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
-        case t: TaskKilledException =>
-          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
-
-        case _: InterruptedException | NonFatal(_) if
-            task != null && task.reasonIfKilled.isDefined =>
-          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
-          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(
-            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
-
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskCommitDeniedReason
           setTaskFinishedAndClearInterruptStatus()
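The change above only reorders the catch cases: the TaskKilledException and interrupted-and-killed cases now come before the broad fetch-failure case, so a task killed while a fetch failure is pending is reported as KILLED rather than FAILED. A standalone sketch (hypothetical exception and labels, not Spark code) of why the ordering matters, since Scala tries catch cases top to bottom:

// Hypothetical stand-ins, only to illustrate catch-case ordering.
class FakeTaskKilled(val reason: String) extends RuntimeException(reason)

def classify(body: => Unit, fetchFailurePending: Boolean): String =
  try { body; "FINISHED" } catch {
    case t: FakeTaskKilled => s"KILLED(${t.reason})" // checked first, mirroring the new order
    case _: Throwable if fetchFailurePending => "FAILED(fetch failure)"
    case _: Throwable => "FAILED(exception)"
  }

// With the kill case first, a killed task with a pending fetch failure is
// reported as killed, not as a fetch failure.
assert(classify(throw new FakeTaskKilled("speculative"), fetchFailurePending = true)
  == "KILLED(speculative)")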

core/src/test/scala/org/apache/spark/JobCancellationSuite.scala

Lines changed: 18 additions & 13 deletions
@@ -332,13 +332,15 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
     import JobCancellationSuite._
     sc = new SparkContext("local[2]", "test interruptible iterator")
 
+    // Increase the number of elements to be processed to avoid this test being flaky.
+    val numElements = 10000
     val taskCompletedSem = new Semaphore(0)
 
     sc.addSparkListener(new SparkListener {
       override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
         // release taskCancelledSemaphore when cancelTasks event has been posted
         if (stageCompleted.stageInfo.stageId == 1) {
-          taskCancelledSemaphore.release(1000)
+          taskCancelledSemaphore.release(numElements)
         }
       }
 
@@ -349,36 +351,39 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
       }
     })
 
-    val f = sc.parallelize(1 to 1000).map { i => (i, i) }
+    // Explicitly disable interrupting the task thread on cancelling tasks, so the task thread can
+    // only be interrupted by `InterruptibleIterator`.
+    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+
+    val f = sc.parallelize(1 to numElements).map { i => (i, i) }
       .repartitionAndSortWithinPartitions(new HashPartitioner(1))
       .mapPartitions { iter =>
         taskStartedSemaphore.release()
         iter
       }.foreachAsync { x =>
-        if (x._1 >= 10) {
-          // This block of code is partially executed. It will be blocked when x._1 >= 10 and the
-          // next iteration will be cancelled if the source iterator is interruptible. Then in this
-          // case, the maximum num of increment would be 10(|1...10|)
-          taskCancelledSemaphore.acquire()
-        }
+        // Block this code from being executed until the job gets cancelled. In this case, if the
+        // source iterator is interruptible, the max number of increments should stay under
+        // `numElements`.
+        taskCancelledSemaphore.acquire()
         executionOfInterruptibleCounter.getAndIncrement()
       }
 
     taskStartedSemaphore.acquire()
     // Job is cancelled when:
     // 1. task in reduce stage has been started, guaranteed by previous line.
-    // 2. task in reduce stage is blocked after processing at most 10 records as
-    //    taskCancelledSemaphore is not released until cancelTasks event is posted
-    // After job being cancelled, task in reduce stage will be cancelled and no more iteration are
-    // executed.
+    // 2. task in reduce stage is blocked as taskCancelledSemaphore is not released until the
+    //    JobCancelled event is posted.
+    // After the job is cancelled, the task in the reduce stage will be cancelled asynchronously,
+    // so only part of the input should get processed (it's very unlikely that Spark can process
+    // 10000 elements between JobCancelled being posted and the task really being killed).
    f.cancel()
 
     val e = intercept[SparkException](f.get()).getCause
     assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
 
     // Make sure tasks are indeed completed.
     taskCompletedSem.acquire()
-    assert(executionOfInterruptibleCounter.get() <= 10)
+    assert(executionOfInterruptibleCounter.get() < numElements)
   }
 
   def testCount() {
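The rewritten test relies on one property: with SPARK_JOB_INTERRUPT_ON_CANCEL disabled, cancellation is observed only through the interruptible source iterator, so the counter must end well below numElements. A self-contained sketch of that property (a hypothetical CancellableIterator, not Spark's InterruptibleIterator):

import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

// An iterator that re-checks a cancellation flag on every hasNext call.
class CancellableIterator[T](underlying: Iterator[T], cancelled: AtomicBoolean)
    extends Iterator[T] {
  override def hasNext: Boolean = !cancelled.get() && underlying.hasNext
  override def next(): T = underlying.next()
}

val cancelled = new AtomicBoolean(false)
val processed = new AtomicInteger(0)
val numElements = 10000

new CancellableIterator((1 to numElements).iterator, cancelled).foreach { _ =>
  processed.incrementAndGet()
  if (processed.get() == 10) cancelled.set(true) // simulate the cancellation arriving mid-run
}

// The iterator stops yielding as soon as the flag flips, without interrupting the thread.
assert(processed.get() < numElements)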

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

Lines changed: 75 additions & 17 deletions
@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.Map
 import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // the fetch failure. The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,17 +174,48 @@
   }
 
   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  /**
+   * Helper for testing some cases where a FetchFailure should *not* get sent back, because it is
+   * superseded by another error, either an OOM or intentionally killing a task.
+   * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+   *            FetchFailure
+   */
+  private def testFetchFailureHandling(
+      oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
+    // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+    // does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
     val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
 
-    // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
-    // the fetch failure as a false positive, and just do normal OOM handling.
+    // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
+    // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)
 
-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }
 
   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }
 
   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    val timedOut = new AtomicBoolean(false)
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        val killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+              // now we can kill the task
+              executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
+            } else {
+              timedOut.set(true)
+            }
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
+      assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
     } finally {
       if (executor != null) {
         executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
     orderedMock.verify(mockBackend)
       .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
     orderedMock.verify(mockBackend)
-      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
     // first statusUpdate for RUNNING has empty data
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is set up correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal that our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented, we just wait to get killed.
+          ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+          throw new IllegalStateException("timed out waiting to be interrupted")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }
 
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {
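The new killTask path coordinates two threads through the pair of latches held in ExecutorSuiteHelper: latch1 signals that the task has hit the fetch failure and is safe to kill, while latch2 is never counted down, so the task thread simply parks there until the kill interrupts it. A minimal sketch of that handshake outside of Spark (the latch names mirror the diff; the rest is illustrative):

import java.util.concurrent.{CountDownLatch, TimeUnit}

val latch1 = new CountDownLatch(1) // "I've hit the fetch failure, kill me now"
val latch2 = new CountDownLatch(1) // never released; only an interrupt ends the wait

val fakeTask = new Thread("fake-task") {
  override def run(): Unit = {
    latch1.countDown()
    try {
      latch2.await(10, TimeUnit.SECONDS)       // parks here until "killed"
    } catch {
      case _: InterruptedException => ()       // the kill arrives as an interrupt
    }
  }
}
fakeTask.start()

assert(latch1.await(10, TimeUnit.SECONDS))     // wait for the signal, then deliver the "kill"
fakeTask.interrupt()
fakeTask.join()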

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 8 additions & 6 deletions
@@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
 @Since("1.4.0")
 class DecisionTreeClassificationModel private[ml] (
     @Since("1.4.0")override val uid: String,
-    @Since("1.4.0")override val rootNode: Node,
+    @Since("1.4.0")override val rootNode: ClassificationNode,
     @Since("1.6.0")override val numFeatures: Int,
     @Since("1.5.0")override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
@@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+  private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
   override def predict(features: Vector): Double = {
@@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession)
-      val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
+      val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
+      val model = new DecisionTreeClassificationModel(metadata.uid,
+        root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
       s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
    val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
     // Can't infer number of features from old model, so default to -1
-    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
+    new DecisionTreeClassificationModel(uid,
+      rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
   }
 }
