
Commit aff44f9

Davies Liu authored and committed
[SPARK-8029] Robust shuffle writer
Currently, every shuffle writer writes to the target path directly, so the file can be corrupted by another attempt of the same partition on the same executor. Writers should instead write to a temporary file and then rename it to the target path, as the output committer does. To make the rename atomic, the temporary file is created in the same local directory (file system). This PR is based on #9214, thanks to squito.

Closes #9214

Author: Davies Liu <[email protected]>

Closes #9610 from davies/safe_shuffle.
1 parent 55faab5 commit aff44f9
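
The gist of the change, as a minimal sketch (illustrative only: tempSibling and commitOrDiscard are hypothetical names, with tempSibling standing in for what Spark's Utils.tempFileWith appears to provide): the temporary file is created next to the final output so the later rename never crosses file systems, and a losing attempt simply discards its temporary file.

import java.io.{File, IOException}
import java.util.UUID

object ShuffleCommitSketch {
  // Hypothetical stand-in for Utils.tempFileWith: a sibling temp file in the
  // same directory as the target, so renameTo never crosses file systems.
  def tempSibling(target: File): File =
    new File(target.getAbsolutePath + "." + UUID.randomUUID())

  // Sketch of the commit step: if another attempt already published its
  // output, keep it and drop ours; otherwise rename our temp file into place.
  def commitOrDiscard(tmp: File, target: File): Unit = {
    if (target.exists()) {
      tmp.delete()
    } else if (!tmp.renameTo(target)) {
      throw new IOException(s"fail to rename $tmp to $target")
    }
  }
}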

File tree: 16 files changed, +402 -52 lines


core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 5 additions & 4 deletions
@@ -125,7 +125,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
       partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
       mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
       return;
     }
@@ -155,9 +155,10 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
       writer.commitAndClose();
     }

-    partitionLengths =
-      writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    File tmp = Utils.tempFileWith(output);
+    partitionLengths = writePartitionedFile(tmp);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 7 additions & 6 deletions
@@ -41,7 +41,7 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -53,7 +53,7 @@
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;

 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -206,16 +206,18 @@ void closeAndWriteOutput() throws IOException {
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
     final long[] partitionLengths;
+    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final File tmp = Utils.tempFileWith(output);
     try {
-      partitionLengths = mergeSpills(spills);
+      partitionLengths = mergeSpills(spills, tmp);
     } finally {
       for (SpillInfo spill : spills) {
         if (spill.file.exists() && ! spill.file.delete()) {
           logger.error("Error while deleting spill file {}", spill.file.getPath());
         }
       }
     }
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
@@ -248,8 +250,7 @@ void forceSorterToSpill() throws IOException {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
-    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
     final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =

core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala

Lines changed: 4 additions & 13 deletions
@@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue

 import scala.collection.JavaConverters._

-import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}

 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -84,17 +84,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
       Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
         val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
         val blockFile = blockManager.diskBlockManager.getFile(blockId)
-        // Because of previous failures, the shuffle file may already exist on this machine.
-        // If so, remove it.
-        if (blockFile.exists) {
-          if (blockFile.delete()) {
-            logInfo(s"Removed existing shuffle file $blockFile")
-          } else {
-            logWarning(s"Failed to remove existing shuffle file $blockFile")
-          }
-        }
-        blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
-          writeMetrics)
+        val tmp = Utils.tempFileWith(blockFile)
+        blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
       }
     }
     // Creating the file to write to and creating a disk writer both involve interacting with

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala

Lines changed: 95 additions & 7 deletions
@@ -21,13 +21,12 @@ import java.io._

 import com.google.common.io.ByteStreams

-import org.apache.spark.{SparkConf, SparkEnv, Logging}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{SparkEnv, Logging, SparkConf}

 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -40,10 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
  */
 // Note: Changes to the format in this file should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
+private[spark] class IndexShuffleBlockResolver(
+    conf: SparkConf,
+    _blockManager: BlockManager = null)
+  extends ShuffleBlockResolver
   with Logging {

-  private lazy val blockManager = SparkEnv.get.blockManager
+  private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)

   private val transportConf = SparkTransportConf.fromSparkConf(conf)

@@ -74,14 +76,69 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     }
   }

+  /**
+   * Check whether the given index and data files match each other.
+   * If so, return the partition lengths in the data file. Otherwise return null.
+   */
+  private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+    // the index file should have `block + 1` longs as offset.
+    if (index.length() != (blocks + 1) * 8) {
+      return null
+    }
+    val lengths = new Array[Long](blocks)
+    // Read the lengths of blocks
+    val in = try {
+      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+    } catch {
+      case e: IOException =>
+        return null
+    }
+    try {
+      // Convert the offsets into lengths of each block
+      var offset = in.readLong()
+      if (offset != 0L) {
+        return null
+      }
+      var i = 0
+      while (i < blocks) {
+        val off = in.readLong()
+        lengths(i) = off - offset
+        offset = off
+        i += 1
+      }
+    } catch {
+      case e: IOException =>
+        return null
+    } finally {
+      in.close()
+    }
+
+    // the size of data file should match with index file
+    if (data.length() == lengths.sum) {
+      lengths
+    } else {
+      null
+    }
+  }
+
   /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
    * begins and ends.
+   *
+   * It will commit the data and index file as an atomic operation, use the existing ones, or
+   * replace them with new ones.
+   *
+   * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
    * */
-  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+  def writeIndexFileAndCommit(
+      shuffleId: Int,
+      mapId: Int,
+      lengths: Array[Long],
+      dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    val indexTmp = Utils.tempFileWith(indexFile)
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
     Utils.tryWithSafeFinally {
       // We take in lengths of each block, need to convert it to offsets.
       var offset = 0L
@@ -93,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     } {
       out.close()
     }
+
+    val dataFile = getDataFile(shuffleId, mapId)
+    // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
+    // the following check and rename are atomic.
+    synchronized {
+      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+      if (existingLengths != null) {
+        // Another attempt for the same task has already written our map outputs successfully,
+        // so just use the existing partition lengths and delete our temporary map outputs.
+        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+        if (dataTmp != null && dataTmp.exists()) {
+          dataTmp.delete()
+        }
+        indexTmp.delete()
+      } else {
+        // This is the first successful attempt in writing the map outputs for this task,
+        // so override any existing index and data files with the ones we wrote.
+        if (indexFile.exists()) {
+          indexFile.delete()
+        }
+        if (dataFile.exists()) {
+          dataFile.delete()
+        }
+        if (!indexTmp.renameTo(indexFile)) {
+          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
+        }
+        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
+        }
+      }
+    }
   }

   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
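
For reference, a small self-contained sketch of the index-file layout that checkIndexAndDataFile validates (the helper name and path below are illustrative, not Spark APIs): the index holds `blocks + 1` longs, the first offset is 0, and the last offset must equal the data file's length.

import java.io.{BufferedOutputStream, DataOutputStream, File, FileOutputStream}

object IndexLayoutSketch {
  // Write an index file for the given partition lengths: blocks + 1 offsets,
  // starting at 0; each consecutive pair of offsets delimits one partition.
  def writeIndexSketch(index: File, lengths: Array[Long]): Unit = {
    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(index)))
    try {
      var offset = 0L
      out.writeLong(offset)
      for (len <- lengths) {
        offset += len
        out.writeLong(offset)
      }
    } finally {
      out.close()
    }
  }
}

// Example: lengths Array(10L, 0L, 25L) produce offsets 0, 10, 10, 35, so the
// matching data file must be exactly 35 bytes long for the check to pass.
// IndexLayoutSketch.writeIndexSketch(new File("/tmp/example.index"), Array(10L, 0L, 25L))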

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 25 additions & 0 deletions
@@ -17,6 +17,8 @@

 package org.apache.spark.shuffle.hash

+import java.io.IOException
+
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V](
       writer.commitAndClose()
       writer.fileSegment().length
     }
+    // rename all shuffle files to final paths
+    // Note: there is only one ShuffleBlockResolver in executor
+    shuffleBlockResolver.synchronized {
+      shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+        val output = blockManager.diskBlockManager.getFile(writer.blockId)
+        if (sizes(i) > 0) {
+          if (output.exists()) {
+            // Use length of existing file and delete our own temporary one
+            sizes(i) = output.length()
+            writer.file.delete()
+          } else {
+            // Commit by renaming our temporary file to something the fetcher expects
+            if (!writer.file.renameTo(output)) {
+              throw new IOException(s"fail to rename ${writer.file} to $output")
+            }
+          }
+        } else {
+          if (output.exists()) {
+            output.delete()
+          }
+        }
+      }
+    }
     MapStatus(blockManager.shuffleServerId, sizes)
   }
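
For clarity, a standalone sketch of the per-bucket decision the hash writer makes while holding the resolver lock (plain java.io, not Spark code; commitBucket is a hypothetical name): keep the first committed output, discard later temporaries, and remove stale outputs for empty buckets.

import java.io.{File, IOException}

object HashCommitSketch {
  // Returns the size to report for this bucket after attempting to commit.
  def commitBucket(tmp: File, output: File, size: Long): Long = {
    if (size > 0) {
      if (output.exists()) {
        // Another attempt already committed; report its length, drop our temp.
        val existing = output.length()
        tmp.delete()
        existing
      } else {
        // First successful attempt: publish our temporary file.
        if (!tmp.renameTo(output)) {
          throw new IOException(s"fail to rename $tmp to $output")
        }
        size
      }
    } else {
      // Nothing written for this bucket; make sure no stale file remains.
      if (output.exists()) output.delete()
      0L
    }
  }
}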

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 6 additions & 5 deletions
@@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter

 private[spark] class SortShuffleWriter[K, V, C](
@@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val tmp = Utils.tempFileWith(output)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
-    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
+    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 4 additions & 5 deletions
@@ -21,10 +21,10 @@ import java.io._
 import java.nio.{ByteBuffer, MappedByteBuffer}

 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{ExecutionContext, Await, Future}
 import scala.concurrent.duration._
-import scala.util.control.NonFatal
+import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.util.Random
+import scala.util.control.NonFatal

 import sun.nio.ch.DirectBuffer

@@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{Serializer, SerializerInstance}
 import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.util._

 private[spark] sealed trait BlockValues
@@ -660,7 +659,7 @@ private[spark] class BlockManager(
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
-      syncWrites, writeMetrics)
+      syncWrites, writeMetrics, blockId)
   }

   /**

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 3 additions & 2 deletions
@@ -34,14 +34,15 @@ import org.apache.spark.util.Utils
  * reopened again.
  */
 private[spark] class DiskBlockObjectWriter(
-    file: File,
+    val file: File,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics)
+    writeMetrics: ShuffleWriteMetrics,
+    val blockId: BlockId = null)
   extends OutputStream
   with Logging {
