From 8e8ed6334dc616349d4017e4be550d73446a5e3e Mon Sep 17 00:00:00 2001
From: Ajith
Date: Fri, 9 Aug 2019 11:04:50 +0530
Subject: [PATCH] dag

---
 .../apache/spark/scheduler/DAGScheduler.scala | 23 +++++
 .../DAGSchedulerIntegrationSuite.scala        | 91 +++++++++++++++++++
 2 files changed, 114 insertions(+)
 create mode 100644 core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerIntegrationSuite.scala

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7bf363dd71c1..5ffceab3fa37 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -711,12 +711,34 @@ private[spark] class DAGScheduler(
     assert(partitions.nonEmpty)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter[U](this, jobId, partitions.size, resultHandler)
+    eagerPartitions(rdd)
     eventProcessLoop.post(JobSubmitted(
       jobId, rdd, func2, partitions.toArray, callSite, waiter,
       Utils.cloneProperties(properties)))
     waiter
   }
 
+  /**
+   * Eagerly evaluates the partitions of the given RDD and of every RDD in its lineage.
+   * Takes effect only if spark.rdd.eager.partitions is true.
+   *
+   * @param rdd initial RDD whose partitions should be evaluated
+   * @param visited set of dependency RDDs that have already been visited
+   */
+  def eagerPartitions(
+      rdd: RDD[_],
+      visited: mutable.HashSet[RDD[_]] = new mutable.HashSet[RDD[_]]): Unit = {
+    try {
+      rdd.partitions
+      rdd.dependencies.filterNot(dep => visited.contains(dep.rdd)).foreach { dep =>
+        visited.add(dep.rdd)
+        eagerPartitions(dep.rdd, visited)
+      }
+    } catch {
+      case t: Throwable => logError("Error in eager evaluation of partitions, ignoring", t)
+    }
+  }
+
   /**
    * Run an action job on the given RDD and pass all the results to the resultHandler function as
    * they arrive.
@@ -783,6 +805,7 @@ private[spark] class DAGScheduler(
     }
     val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+    eagerPartitions(rdd)
     eventProcessLoop.post(JobSubmitted(
       jobId, rdd, func2, rdd.partitions.indices.toArray, callSite, listener,
       Utils.cloneProperties(properties)))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerIntegrationSuite.scala
new file mode 100644
index 000000000000..7ab4a52f6728
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerIntegrationSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.concurrent.{ConcurrentHashMap, Executors}
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration._
+
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+
+class DAGSchedulerIntegrationSuite extends SparkFunSuite with LocalSparkContext {
+  implicit val submissionPool: ExecutionContext =
+    ExecutionContext.fromExecutor(Executors.newFixedThreadPool(3))
+
+  test("DAG event queue is not blocked by a job with heavy partition calculation") {
+    sc = new SparkContext("local", "DAGSchedulerIntegrationSuite")
+
+    // Create 3 RDDs: 2 quick ones and 1 whose dependency has a heavy partition calculation.
+    val simpleRDD1 = new DelegateRDD(sc, new PauseRDD(sc, 100))
+    val heavyRDD = new DelegateRDD(sc, new PauseRDD(sc, 1000000))
+    val simpleRDD2 = new DelegateRDD(sc, new PauseRDD(sc, 100))
+
+    // Submit all three jobs concurrently.
+    val finishedRDDs = new ConcurrentHashMap[DelegateRDD, String]()
+    List(simpleRDD1, heavyRDD, simpleRDD2).foreach(rdd =>
+      Future {
+        rdd.collect()
+        finishedRDDs.put(rdd, rdd.toString)
+      })
+
+    // Wait and verify that the quick jobs finish while the heavy job is still running.
+    eventually(timeout(10.seconds)) {
+      assert(finishedRDDs.size() == 2)
+      assert(
+        finishedRDDs.containsKey(simpleRDD1) &&
+          finishedRDDs.containsKey(simpleRDD2))
+    }
+  }
+}
+
+class DelegateRDD(sc: SparkContext, val dependency: PauseRDD)
+  extends RDD[(Int, Int)](sc, List(new OneToOneDependency(dependency)))
+  with Serializable {
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
+    Iterator.empty
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Seq(new Partition {
+      override def index: Int = 0
+    }).toArray
+  }
+
+  override def toString: String = "DelegateRDD " + id
+}
+
+class PauseRDD(sc: SparkContext, val pauseDuration: Long)
+  extends RDD[(Int, Int)](sc, Seq.empty)
+  with Serializable {
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
+    Iterator.empty
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Thread.sleep(pauseDuration) // simulate an expensive partition calculation
+    Seq(new Partition {
+      override def index: Int = 0
+    }).toArray
+  }
+
+  override def toString: String = "PauseRDD " + id
+}
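
For context, a rough driver-side sketch (not part of the patch) of the scenario the change targets: an RDD whose getPartitions is expensive sits in the lineage of a job, and eagerPartitions now computes those partitions on the thread calling the action instead of inside the DAGScheduler event loop, so concurrently submitted quick jobs are not held up. The class name SlowListingRDD and the parameter listingMillis below are illustrative only.

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// An RDD whose partition calculation is expensive, e.g. a slow remote listing.
class SlowListingRDD(sc: SparkContext, listingMillis: Long) extends RDD[Int](sc, Seq.empty) {

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator.empty

  override protected def getPartitions: Array[Partition] = {
    Thread.sleep(listingMillis) // stands in for the costly listing
    Array(new Partition { override def index: Int = 0 })
  }
}

// Usage sketch: the slow RDD sits one level down the lineage, so the recursive
// eagerPartitions walk is what reaches it before JobSubmitted is posted.
//   val sc = new SparkContext("local[2]", "eager-partitions-sketch")
//   val slow = new SlowListingRDD(sc, 600000L).map(_ + 1)
//   new Thread(() => slow.collect()).start()                    // blocks only its own thread
//   new Thread(() => sc.parallelize(1 to 10).collect()).start() // finishes promptly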