22 changes: 11 additions & 11 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -210,12 +210,22 @@ object SparkEnv extends Logging {
"MapOutputTracker",
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
"sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)

val shuffleMemoryManager = new ShuffleMemoryManager(conf)

val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf)

val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
serializer, conf, securityManager, mapOutputTracker)
serializer, conf, securityManager, mapOutputTracker, shuffleManager)

val connectionManager = blockManager.connectionManager

@@ -250,16 +260,6 @@ object SparkEnv extends Logging {
"."
}

// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
"sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)

val shuffleMemoryManager = new ShuffleMemoryManager(conf)

// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
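Note: the block moved up in this hunk is the piece that resolves the user-facing `spark.shuffle.manager` setting before the `BlockManager` is built. As a rough, self-contained sketch of that resolution logic (not code from this patch; only the names visible in the diff are assumed):

```scala
import org.apache.spark.SparkConf

// Users may pick a shuffle implementation by short name or by fully qualified class name.
val conf = new SparkConf().set("spark.shuffle.manager", "sort")

val shortShuffleMgrNames = Map(
  "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
  "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")

// Unrecognized values fall through unchanged, so a custom ShuffleManager class name still works.
val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
// shuffleMgrClass == "org.apache.spark.shuffle.sort.SortShuffleManager"
// SparkEnv then instantiates this class reflectively (via the instantiateClass helper used above)
// and hands the resulting instance to the BlockManager constructed right after it.
```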
11 changes: 7 additions & 4 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,6 +33,7 @@ import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._

private[spark] sealed trait BlockValues
@@ -57,11 +58,12 @@ private[spark] class BlockManager(
maxMemory: Long,
val conf: SparkConf,
securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker)
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager)
extends Logging {

private val port = conf.getInt("spark.blockManager.port", 0)
val shuffleBlockManager = new ShuffleBlockManager(this)
val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager)
val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
val connectionManager =
@@ -142,9 +144,10 @@ private[spark] class BlockManager(
serializer: Serializer,
conf: SparkConf,
securityManager: SecurityManager,
mapOutputTracker: MapOutputTracker) = {
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
conf, securityManager, mapOutputTracker)
conf, securityManager, mapOutputTracker, shuffleManager)
}

/**
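For context, a simplified sketch of the construction order that results from the new parameter (assuming the hash manager was selected; the argument names are the ones shown in this diff, and the real wiring lives in `SparkEnv.create`):

```scala
// The shuffle manager is created first...
val shuffleManager = new HashShuffleManager(conf)
// ...then passed into the BlockManager...
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
  serializer, conf, securityManager, mapOutputTracker, shuffleManager)
// ...which forwards the same instance to its internal ShuffleBlockManager:
//   val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager)
```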
core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConversions._

import org.apache.spark.Logging
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
@@ -62,7 +63,8 @@ private[spark] trait ShuffleWriterGroup {
*/
// TODO: Factor this into a separate class for each ShuffleManager implementation
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
class ShuffleBlockManager(blockManager: BlockManager,
shuffleManager: ShuffleManager) extends Logging {
def conf = blockManager.conf

// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
@@ -71,8 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
conf.getBoolean("spark.shuffle.consolidateFiles", false)

// Are we using sort-based shuffle?
val sortBasedShuffle =
conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName
val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager]

private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

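One practical effect of the `sortBasedShuffle` change above: the old string comparison against the fully qualified class name misses the short name `sort` that SparkEnv now accepts, while checking the type of the already-resolved manager handles both spellings. A small illustration (the two helper functions are hypothetical, written only to contrast the checks):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager

// Old-style check: compare the raw config value to the class name.
def sortBasedShuffleOld(conf: SparkConf): Boolean =
  conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName

// New-style check: inspect the ShuffleManager instance SparkEnv already resolved.
def sortBasedShuffleNew(shuffleManager: ShuffleManager): Boolean =
  shuffleManager.isInstanceOf[SortShuffleManager]

val conf = new SparkConf().set("spark.shuffle.manager", "sort")
sortBasedShuffleOld(conf)  // false, even though "sort" resolves to SortShuffleManager
// sortBasedShuffleNew(...) returns true for whatever SortShuffleManager instance SparkEnv created,
// regardless of whether the user wrote "sort" or the full class name.
```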
core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.util.concurrent.ArrayBlockingQueue

import akka.actor._
import org.apache.spark.shuffle.hash.HashShuffleManager
import util.Random

import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
@@ -101,7 +102,7 @@ private[spark] object ThreadingTest {
conf)
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
new SecurityManager(conf), new MapOutputTrackerMaster(conf))
new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit
import akka.actor._
import akka.pattern.ask
import akka.util.Timeout
import org.apache.spark.shuffle.hash.HashShuffleManager

import org.mockito.invocation.InvocationOnMock
import org.mockito.Matchers.any
@@ -61,6 +62,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
conf.set("spark.authenticate", "false")
val securityMgr = new SecurityManager(conf)
val mapOutputTracker = new MapOutputTrackerMaster(conf)
val shuffleManager = new HashShuffleManager(conf)

// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
conf.set("spark.kryoserializer.buffer.mb", "1")
@@ -71,8 +73,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)

private def makeBlockManager(maxMem: Long, name: String = "<driver>"): BlockManager = {
new BlockManager(
name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker)
new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr,
mapOutputTracker, shuffleManager)
}

before {
@@ -791,7 +793,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
securityMgr, mapOutputTracker)
securityMgr, mapOutputTracker, shuffleManager)

// The put should fail since a1 is not serializable.
class UnserializableClass
@@ -1007,7 +1009,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter

test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
securityMgr, mapOutputTracker, shuffleManager)

val worker = spy(new BlockManagerWorker(store))
val connManagerId = mock(classOf[ConnectionManagerId])
@@ -1054,7 +1056,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter

test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
securityMgr, mapOutputTracker)
securityMgr, mapOutputTracker, shuffleManager)

val worker = spy(new BlockManagerWorker(store))
val connManagerId = mock(classOf[ConnectionManagerId])
core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.storage

import java.io.{File, FileWriter}

import org.apache.spark.shuffle.hash.HashShuffleManager

import scala.collection.mutable
import scala.language.reflectiveCalls

@@ -42,7 +44,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
// so we coerce consolidation if not already enabled.
testConf.set("spark.shuffle.consolidateFiles", "true")

val shuffleBlockManager = new ShuffleBlockManager(null) {
private val shuffleManager = new HashShuffleManager(testConf.clone)

val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) {
override def conf = testConf.clone
var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
@@ -148,7 +152,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))),
confCopy)
val store = new BlockManager("<driver>", actorSystem, master , serializer, confCopy,
securityManager, null)
securityManager, null, shuffleManager)

try {
