diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 22449517d100f..865d2ddc19140 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -193,15 +193,16 @@ class DAGScheduler(
     eventProcessActor ! TaskSetFailed(taskSet, reason)
   }
 
-  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
-    if (!cacheLocs.contains(rdd.id)) {
-      val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
-      val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
-      cacheLocs(rdd.id) = blockIds.map { id =>
-        locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
-      }
+  private def getLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
+    val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
+    val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
+    blockIds.map { id =>
+      locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
     }
-    cacheLocs(rdd.id)
+  }
+
+  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
+    cacheLocs.getOrElseUpdate(rdd.id, getLocs(rdd))
   }
 
   private def clearCacheLocs() {
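
For reviewers unfamiliar with the idiom: the refactor replaces the explicit contains-check-then-insert sequence with `mutable.HashMap.getOrElseUpdate`, whose second argument is passed by name, so `getLocs(rdd)` is only evaluated on a cache miss. Below is a minimal, self-contained sketch of that pattern, not the Spark code itself; `computeLocs` and the `String`-based location type are hypothetical stand-ins for `BlockManager.blockIdsToBlockManagers` and `TaskLocation`.

```scala
import scala.collection.mutable

object GetOrElseUpdateSketch {
  // Analogous to DAGScheduler.cacheLocs: RDD id -> per-partition locations.
  private val cacheLocs = new mutable.HashMap[Int, Array[Seq[String]]]

  // Hypothetical stand-in for the expensive block-manager lookup.
  private def computeLocs(rddId: Int): Array[Seq[String]] = {
    println(s"computing locations for RDD $rddId")
    Array(Seq(s"host-of-$rddId"))
  }

  // getOrElseUpdate looks the key up and, only on a miss, evaluates the
  // by-name default, stores it under the key, and returns it.
  def getCacheLocs(rddId: Int): Array[Seq[String]] =
    cacheLocs.getOrElseUpdate(rddId, computeLocs(rddId))

  def main(args: Array[String]): Unit = {
    getCacheLocs(7) // miss: prints "computing locations for RDD 7"
    getCacheLocs(7) // hit: computeLocs is not called again
  }
}
```

Because the default is by-name, this preserves the old behavior (compute once, then serve from `cacheLocs`) while collapsing the lookup-and-populate logic into a single expression and keeping `getLocs` as a pure helper.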