23 changes: 16 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -155,6 +155,9 @@ abstract class RDD[T: ClassTag](
/** A friendly name for this RDD */
@transient var name: String = _

/** Lock guarding access to partitions_ */
@transient private val partitionLock = new Object

/** Assign a name to this RDD */
def setName(_name: String): this.type = {
name = _name
@@ -228,6 +231,8 @@ abstract class RDD[T: ClassTag](
// Our dependencies and partitions will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = _

/** To be accessed only while holding partitionLock */
@transient private var partitions_ : Array[Partition] = _

/** An Option holding our checkpoint RDD, if we are checkpointed */
@@ -252,14 +257,16 @@
*/
final def partitions: Array[Partition] = {
checkpointRDD.map(_.partitions).getOrElse {
if (partitions_ == null) {
partitions_ = getPartitions
partitions_.zipWithIndex.foreach { case (partition, index) =>
require(partition.index == index,
s"partitions($index).partition == ${partition.index}, but it should equal $index")
partitionLock.synchronized {
if (partitions_ == null) {
partitions_ = getPartitions
partitions_.zipWithIndex.foreach { case (partition, index) =>
require(partition.index == index,
s"partitions($index).partition == ${partition.index}, but it should equal $index")
}
}
partitions_
}
partitions_
}
}

@@ -1769,7 +1776,9 @@ abstract class RDD[T: ClassTag](
*/
private[spark] def markCheckpointed(): Unit = {
clearDependencies()
partitions_ = null
partitionLock.synchronized {
partitions_ = null
}
deps = null // Forget the constructor argument for dependencies too
}
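
Taken together, the RDD.scala changes turn the lazy initialization of partitions_ into a lock-guarded check-then-compute: whichever thread wins partitionLock runs getPartitions once, and every other caller reuses the cached array. A minimal standalone sketch of the same pattern, with illustrative names and no Spark types (an assumption-level illustration, not code from this PR):

import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Stand-in for an RDD whose partition calculation is expensive.
class LazyPartitions {
  private val partitionLock = new Object
  private var computeCount = 0              // written only while holding partitionLock
  private var partitions_ : Array[Int] = _  // accessed only while holding partitionLock

  private def getPartitions: Array[Int] = {
    Thread.sleep(200)                       // simulate a slow getPartitions
    computeCount += 1
    Array.tabulate(4)(identity)
  }

  def partitions: Array[Int] = partitionLock.synchronized {
    if (partitions_ == null) {
      partitions_ = getPartitions           // the first caller pays the cost exactly once
    }
    partitions_
  }

  def computations: Int = partitionLock.synchronized(computeCount)
}

object LazyPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(4)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

    val rddLike = new LazyPartitions
    // Four threads race to materialize the partitions; only one computes them.
    val sizes = Future.sequence((1 to 4).map(_ => Future(rddLike.partitions.length)))
    println(Await.result(sizes, 5.seconds)) // Vector(4, 4, 4, 4)
    println(rddLike.computations)           // 1
    pool.shutdown()
  }
}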

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

import scala.annotation.tailrec
import scala.collection.Map
import scala.collection.{mutable, Map}
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import scala.concurrent.duration._
import scala.util.control.NonFatal
@@ -704,12 +704,34 @@ private[spark] class DAGScheduler(
assert(partitions.nonEmpty)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)
eagerPartitions(rdd)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, callSite, waiter,
SerializationUtils.clone(properties)))
waiter
}

/**
* Eagerly evaluates the partitions of the given RDD and of every RDD in its dependency
* graph. Takes effect only if <b>spark.rdd.eager.partitions</b> is true.
*
* @param rdd initial RDD whose partitions are to be evaluated
* @param visited set of dependency RDDs that have already been visited
*/
def eagerPartitions(
rdd: RDD[_],
visited: mutable.HashSet[RDD[_]] = new mutable.HashSet[RDD[_]]): Unit = {
try {
rdd.partitions
rdd.dependencies.filter(dep => !visited.contains(dep.rdd)).foreach { dep =>
visited.add(dep.rdd)
eagerPartitions(dep.rdd, visited)
}
} catch {
case NonFatal(t) => logError("Error in eager evaluation of partitions, ignoring it", t)
}
}
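
The scaladoc above mentions a spark.rdd.eager.partitions flag, but the diff as posted calls eagerPartitions unconditionally from the submission paths. A minimal sketch of how the call sites could honor that flag; the property name comes from the comment above, and the helper name and default value are illustrative assumptions, not existing Spark API:

// Hypothetical guard; `sc` is the SparkContext this DAGScheduler was constructed with.
private def maybeEagerPartitions(rdd: RDD[_]): Unit = {
  // The default used here is an assumption; the diff does not specify one.
  if (sc.getConf.getBoolean("spark.rdd.eager.partitions", defaultValue = true)) {
    eagerPartitions(rdd)
  }
}

With such a guard in place, the eagerPartitions(rdd) calls in submitJob above and in runApproximateJob below would call maybeEagerPartitions(rdd) instead.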

/**
* Run an action job on the given RDD and pass all the results to the resultHandler function as
* they arrive.
@@ -776,6 +798,7 @@
}
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
eagerPartitions(rdd)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener,
SerializationUtils.clone(properties)))
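
Both submission paths (submitJob above and runApproximateJob here) now materialize the partitions of the entire lineage on the calling thread before JobSubmitted reaches the single-threaded event loop. A self-contained driver-side sketch of the intended effect, spelling out by hand the walk that eagerPartitions performs for one fixed lineage (local master and app name are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object EagerLineageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("eager-lineage-sketch"))
    try {
      // A three-stage lineage; no RDD has computed its partitions yet.
      val base = sc.parallelize(1 to 1000, numSlices = 8)
      val doubled = base.map(_ * 2)
      val filtered = doubled.filter(_ % 3 == 0)

      // Calling .partitions caches the result inside each RDD, so the
      // DAGScheduler event loop later finds them already materialized.
      Seq(filtered, doubled, base).foreach(_.partitions)

      println(filtered.count()) // job submission reuses the cached partitions
    } finally {
      sc.stop()
    }
  }
}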
91 changes: 91 additions & 0 deletions core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerIntegrationSuite.scala
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler

import java.util.concurrent.{ConcurrentHashMap, Executors}

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._

import org.scalatest.concurrent.Eventually._

import org.apache.spark._
import org.apache.spark.rdd.RDD

class DAGSchedulerIntegrationSuite extends SparkFunSuite with LocalSparkContext {
implicit val submissionPool: ExecutionContext =
ExecutionContext.fromExecutor(Executors.newFixedThreadPool(3))

test("blocking of DAGEventQueue due to a heavy pause job") {
sc = new SparkContext("local", "DAGSchedulerIntegrationSuite")

// Build three RDDs: two whose dependencies compute partitions quickly
// and one whose dependency's getPartitions is very slow
val simpleRDD1 = new DelegateRDD(sc, new PauseRDD(sc, 100))
val heavyRDD = new DelegateRDD(sc, new PauseRDD(sc, 1000000))
val simpleRDD2 = new DelegateRDD(sc, new PauseRDD(sc, 100))

// submit all concurrently
val finishedRDDs = new ConcurrentHashMap[DelegateRDD, String]()
List(simpleRDD1, heavyRDD, simpleRDD2).foreach(rdd =>
Future {
rdd.collect()
finishedRDDs.put(rdd, rdd.toString)
})

// Within a bounded time the two quick jobs should finish, while the heavy one is still computing its partitions
eventually(timeout(10.seconds)) {
assert(finishedRDDs.size() == 2)
assert(
finishedRDDs.containsKey(simpleRDD1) &&
finishedRDDs.containsKey(simpleRDD2))
}
}
}

class DelegateRDD(sc: SparkContext, var dependency: PauseRDD)
extends RDD[(Int, Int)](sc, List(new OneToOneDependency(dependency)))
with Serializable {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
Iterator.empty
}

override protected def getPartitions: Array[Partition] = {
Seq(new Partition {
override def index: Int = 0
}).toArray
}

override def toString: String = "DelegateRDD " + id
}

class PauseRDD(sc: SparkContext, var pauseDuration: Long)
extends RDD[(Int, Int)](sc, Seq.empty)
with Serializable {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
Iterator.empty
}

override protected def getPartitions: Array[Partition] = {
Thread.sleep(pauseDuration)
Seq(new Partition {
override def index: Int = 0
}).toArray
}

override def toString: String = "PauseRDD " + id
}