
Commit 9614a0c

Additional tests to test protocol changes
1 parent 351ae53 commit 9614a0c

2 files changed: +54 −12 lines changed


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 7 additions & 12 deletions
```diff
@@ -289,12 +289,12 @@ private class ShuffleStatus(
       isLocal: Boolean,
       minBroadcastSize: Int,
       conf: SparkConf,
-      isMapOutput: Boolean): (Array[Byte], Array[Byte]) = {
+      isMapOnlyOutput: Boolean): (Array[Byte], Array[Byte]) = {
     var mapStatuses: Array[Byte] = null
     var mergeStatuses: Array[Byte] = null
 
     withReadLock {
-      if (isMapOutput) {
+      if (isMapOnlyOutput) {
         if (cachedSerializedMapStatus != null) {
           mapStatuses = cachedSerializedMapStatus
         }
@@ -309,7 +309,7 @@ private class ShuffleStatus(
       }
     }
 
-    if (isMapOutput) {
+    if (isMapOnlyOutput) {
       if (mapStatuses == null) {
         mapStatuses =
           serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
@@ -646,19 +646,14 @@ private[spark] class MapOutputTrackerMaster(
     private def handleStatusMessage(
         shuffleId: Int,
         context: RpcCallContext,
-        isMapOutput: Boolean): Unit = {
+        isMapOnlyOutput: Boolean): Unit = {
       val hostPort = context.senderAddress.hostPort
       val shuffleStatus = shuffleStatuses.get(shuffleId).head
-      val mapOrMerge = if (isMapOutput) {
-        "map"
-      } else {
-        "merge"
-      }
-      logDebug(s"Handling request to send $mapOrMerge output locations" +
-        s" for shuffle $shuffleId to $hostPort")
+      logDebug(s"Handling request to send ${if (isMapOnlyOutput) "map" else "map/merge"}" +
+        s" output locations for shuffle $shuffleId to $hostPort")
       context.reply(
         shuffleStatus.serializedOutputStatus(broadcastManager, isLocal,
-          minSizeForBroadcast, conf, isMapOutput = isMapOutput))
+          minSizeForBroadcast, conf, isMapOnlyOutput = isMapOnlyOutput))
     }
 
     override def run(): Unit = {
```
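With the rename, the flag's false case is broader than before: the old handler logged and served either "map" or "merge" locations, while the new one logs "map" vs. "map/merge" and always replies with a (Array[Byte], Array[Byte]) pair. A minimal, self-contained sketch of that reply contract, using stand-in serializers rather than the real ShuffleStatus internals, could look like this:

```scala
// Sketch only: stand-in helpers, not the Spark implementation.
object OutputStatusReplySketch {
  // Stand-ins for the serialized MapStatus / MergeStatus payloads.
  private def serializeMapStatuses(): Array[Byte] = Array[Byte](1, 2, 3)
  private def serializeMergeStatuses(): Array[Byte] = Array[Byte](4, 5)

  // isMapOnlyOutput = true  -> reply carries only the map-status bytes (old behavior)
  // isMapOnlyOutput = false -> reply bundles map-status and merge-status bytes together
  def serializedOutputStatus(isMapOnlyOutput: Boolean): (Array[Byte], Array[Byte]) = {
    val mapBytes = serializeMapStatuses()
    // Assumption of this sketch: a map-only request leaves the merge half empty;
    // the diff above only shows the map-side branch.
    val mergeBytes = if (isMapOnlyOutput) Array.emptyByteArray else serializeMergeStatuses()
    (mapBytes, mergeBytes)
  }

  def main(args: Array[String]): Unit = {
    val (m1, g1) = serializedOutputStatus(isMapOnlyOutput = true)
    val (m2, g2) = serializedOutputStatus(isMapOnlyOutput = false)
    println(s"map-only reply:  map=${m1.length} bytes, merge=${g1.length} bytes")
    println(s"map+merge reply: map=${m2.length} bytes, merge=${g2.length} bytes")
  }
}
```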

core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Lines changed: 47 additions & 0 deletions
```diff
@@ -589,6 +589,53 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("SPARK-32921: test new protocol changes fetching both Map and Merge status in single RPC") {
+    val newConf = new SparkConf
+    newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
+    newConf.set(RPC_ASK_TIMEOUT, "1") // Fail fast
+    newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize
+    newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    newConf.set(IS_TESTING, true)
+
+    // needs TorrentBroadcast so need a SparkContext
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
+      val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+      val rpcEnv = sc.env.rpcEnv
+      val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+      rpcEnv.stop(masterTracker.trackerEndpoint)
+      rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+      val bitmap1 = new RoaringBitmap()
+      bitmap1.add(1)
+
+      masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
+      (0 until 100).foreach { i =>
+        masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+          BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
+      }
+      masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000),
+        bitmap1, 1000L))
+
+      val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf))
+      val mapWorkerTracker = new MapOutputTrackerWorker(conf)
+      mapWorkerTracker.trackerEndpoint =
+        mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+      val fetchedBytes = mapWorkerTracker.trackerEndpoint
+        .askSync[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(20))
+      assert(masterTracker.getNumAvailableMergeResults(20) == 1)
+      assert(masterTracker.getNumAvailableOutputs(20) == 100)
+
+      val mapOutput =
+        MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, newConf)
+      val mergeOutput =
+        MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, newConf)
+      assert(mapOutput.length == 100)
+      assert(mergeOutput.length == 1)
+      mapWorkerTracker.stop()
+      masterTracker.stop()
+    }
+  }
+
   test("SPARK-32921: unregister merge result if it is present and contains the map Id") {
     val rpcEnv = createRpcEnv("test")
     val tracker = newTrackerMaster()
```
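The added test drives the new protocol end to end: after registering 100 map outputs and one merge result for shuffle 20, a worker-side tracker issues a single GetMapAndMergeResultStatuses(20) ask, receives one (mapBytes, mergeBytes) tuple, and deserializes each half separately. A stripped-down sketch of that fetch-and-deserialize pattern, using invented stub types and helpers in place of the Spark internals, might look like:

```scala
// Sketch only: stub types and helpers invented for illustration, not the Spark API.
object SingleRpcFetchSketch {
  final case class MapStatusStub(mapId: Int)
  final case class MergeStatusStub(reduceId: Int)

  // Stand-in for asking the tracker endpoint with GetMapAndMergeResultStatuses(shuffleId):
  // a single reply carries both serialized halves.
  private def askMapAndMergeStatuses(shuffleId: Int): (Array[Byte], Array[Byte]) =
    (Array.fill[Byte](100)(0), Array.fill[Byte](1)(0))

  // Stand-ins for MapOutputTracker.deserializeOutputStatuses[T](bytes, conf).
  private def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatusStub] =
    bytes.indices.map(i => MapStatusStub(i)).toArray
  private def deserializeMergeStatuses(bytes: Array[Byte]): Array[MergeStatusStub] =
    bytes.indices.map(i => MergeStatusStub(i)).toArray

  def main(args: Array[String]): Unit = {
    // One round trip returns both halves of the shuffle output metadata.
    val (mapBytes, mergeBytes) = askMapAndMergeStatuses(shuffleId = 20)
    val mapStatuses = deserializeMapStatuses(mapBytes)
    val mergeStatuses = deserializeMergeStatuses(mergeBytes)
    println(s"fetched ${mapStatuses.length} map statuses and ${mergeStatuses.length} merge statuses")
  }
}
```

Bundling both status arrays in one reply is the point of the single-RPC protocol change the test name describes: a reducer can obtain map and merge locations without a second round trip.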
