
Commit ea90ea6

mccheah authored and Marcelo Vanzin committed
[SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter
## What changes were proposed in this pull request?

Use the shuffle writer APIs introduced in SPARK-28209 in the sort shuffle writer.

## How was this patch tested?

Existing unit tests were changed to use the plugin instead, and they used the local disk version to ensure that there were no regressions.

Closes apache#25342 from mccheah/shuffle-writer-refactor-sort-shuffle-writer.

Lead-authored-by: mcheah <[email protected]>
Co-authored-by: mccheah <[email protected]>
Signed-off-by: Marcelo Vanzin <[email protected]>
1 parent 92cabf6 commit ea90ea6
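
For orientation, the plugin surface this change builds on (the shuffle writer API from SPARK-28209) is exercised through three hooks: the executor components create one map output writer per map task, the map output writer hands out a partition writer per reduce partition, and each partition writer exposes a raw byte stream plus a running byte count. Below is a minimal Scala sketch of just that subset, reconstructed from how the methods are called in this diff; the real definitions are Java interfaces under org.apache.spark.shuffle.api and may carry additional methods, checked exceptions, and different modifiers.

// Hedged sketch, not the actual org.apache.spark.shuffle.api sources: trait names carry a
// "Sketch" suffix to mark them as reconstructions. Method names and parameters mirror the
// call sites in this diff (createMapOutputWriter, getPartitionWriter, commitAllPartitions,
// openStream, getNumBytesWritten).
import java.io.OutputStream

trait ShuffleExecutorComponentsSketch {
  // Called once per map task by SortShuffleWriter.write().
  def createMapOutputWriter(
      shuffleId: Int,
      mapId: Int,
      mapTaskAttemptId: Long,
      numPartitions: Int): ShuffleMapOutputWriterSketch
}

trait ShuffleMapOutputWriterSketch {
  // One writer per reduce partition, requested in ascending partition order by ExternalSorter.
  def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriterSketch

  // Finalizes the map output and returns the per-partition lengths used to build the MapStatus.
  def commitAllPartitions(): Array[Long]
}

trait ShufflePartitionWriterSketch {
  // Raw byte sink; ShufflePartitionPairsWriter wraps it via SerializerManager.wrapStream
  // (compression/encryption) and a SerializationStream before writing key-value pairs.
  def openStream: OutputStream

  // Running byte count, polled every 16384 records to update shuffle write metrics.
  def getNumBytesWritten: Long
}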

File tree

8 files changed: +265 -31 lines changed

core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle

import java.io.{Closeable, IOException, OutputStream}

import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.api.ShufflePartitionWriter
import org.apache.spark.storage.BlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.PairsWriter

/**
 * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
 * arbitrary partition writer instead of writing to local disk through the block manager.
 */
private[spark] class ShufflePartitionPairsWriter(
    partitionWriter: ShufflePartitionWriter,
    serializerManager: SerializerManager,
    serializerInstance: SerializerInstance,
    blockId: BlockId,
    writeMetrics: ShuffleWriteMetricsReporter)
  extends PairsWriter with Closeable {

  private var isClosed = false
  private var partitionStream: OutputStream = _
  private var wrappedStream: OutputStream = _
  private var objOut: SerializationStream = _
  private var numRecordsWritten = 0
  private var curNumBytesWritten = 0L

  override def write(key: Any, value: Any): Unit = {
    if (isClosed) {
      throw new IOException("Partition pairs writer is already closed.")
    }
    if (objOut == null) {
      open()
    }
    objOut.writeKey(key)
    objOut.writeValue(value)
    recordWritten()
  }

  private def open(): Unit = {
    try {
      partitionStream = partitionWriter.openStream
      wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
      objOut = serializerInstance.serializeStream(wrappedStream)
    } catch {
      case e: Exception =>
        Utils.tryLogNonFatalError {
          close()
        }
        throw e
    }
  }

  override def close(): Unit = {
    if (!isClosed) {
      Utils.tryWithSafeFinally {
        Utils.tryWithSafeFinally {
          objOut = closeIfNonNull(objOut)
          // Setting these to null will prevent the underlying streams from being closed twice
          // just in case any stream's close() implementation is not idempotent.
          wrappedStream = null
          partitionStream = null
        } {
          // Normally closing objOut would close the inner streams as well, but just in case there
          // was an error in initialization etc. we make sure we clean the other streams up too.
          Utils.tryWithSafeFinally {
            wrappedStream = closeIfNonNull(wrappedStream)
            // Same as above - if wrappedStream closes then assume it closes underlying
            // partitionStream and don't close again in the finally
            partitionStream = null
          } {
            partitionStream = closeIfNonNull(partitionStream)
          }
        }
        updateBytesWritten()
      } {
        isClosed = true
      }
    }
  }

  private def closeIfNonNull[T <: Closeable](closeable: T): T = {
    if (closeable != null) {
      closeable.close()
    }
    null.asInstanceOf[T]
  }

  /**
   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
   */
  private def recordWritten(): Unit = {
    numRecordsWritten += 1
    writeMetrics.incRecordsWritten(1)

    if (numRecordsWritten % 16384 == 0) {
      updateBytesWritten()
    }
  }

  private def updateBytesWritten(): Unit = {
    val numBytesWritten = partitionWriter.getNumBytesWritten
    val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
    writeMetrics.incBytesWritten(bytesWrittenDiff)
    curNumBytesWritten = numBytesWritten
  }
}

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 2 additions & 1 deletion
@@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
         metrics,
         shuffleExecutorComponents)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+        new SortShuffleWriter(
+          shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
     }
   }

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 8 additions & 15 deletions
@@ -21,15 +21,15 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
-import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
     shuffleBlockResolver: IndexShuffleBlockResolver,
     handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
-    context: TaskContext)
+    context: TaskContext,
+    shuffleExecutorComponents: ShuffleExecutorComponents)
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
@@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val tmp = Utils.tempFileWith(output)
-    try {
-      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
-      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
-      }
-    }
+    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
+      dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)
+    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    val partitionLengths = mapOutputWriter.commitAllPartitions()
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
   /** Close this writer, passing along whether the map completed */

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 4 additions & 2 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter
 
 /**
  * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
     writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
-  with Logging {
+  with Logging
+  with PairsWriter {
 
   /**
    * Guards against close calls, e.g. from a wrapping stream.
@@ -232,7 +234,7 @@ private[spark] class DiskBlockObjectWriter(
   /**
    * Writes a key-value pair.
    */
-  def write(key: Any, value: Any) {
+  override def write(key: Any, value: Any) {
     if (!streamOpen) {
       open()
     }

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 80 additions & 8 deletions
@@ -23,13 +23,16 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.ByteStreams
+import com.google.common.io.{ByteStreams, Closeables}
 
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.shuffle.ShufflePartitionPairsWriter
+import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
+import org.apache.spark.util.{Utils => TryUtils}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -670,11 +673,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL
+   * project. We should figure out an alternative way to test that so that we can remove this
+   * otherwise unused code path.
    */
   def writePartitionedFile(
       blockId: BlockId,
@@ -718,6 +719,77 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int,
+      mapId: Int,
+      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    var nextPartitionId = 0
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        TryUtils.tryWithSafeFinally {
+          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
+          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          while (it.hasNext && it.nextPartition() == partitionId) {
+            it.writeNext(partitionPairsWriter)
+          }
+        } {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        nextPartitionId = partitionId + 1
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        val blockId = ShuffleBlockId(shuffleId, mapId, id)
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        TryUtils.tryWithSafeFinally {
+          partitionWriter = mapOutputWriter.getPartitionWriter(id)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          if (elements.hasNext) {
+            for (elem <- elements) {
+              partitionPairsWriter.write(elem._1, elem._2)
+            }
+          }
+        } {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        nextPartitionId = id + 1
+      }
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
@@ -781,7 +853,7 @@ private[spark] class ExternalSorter[K, V, C](
       val inMemoryIterator = new WritablePartitionedIterator {
         private[this] var cur = if (upstream.hasNext) upstream.next() else null
 
-        def writeNext(writer: DiskBlockObjectWriter): Unit = {
+        def writeNext(writer: PairsWriter): Unit = {
           writer.write(cur._1._2, cur._2)
           cur = if (upstream.hasNext) upstream.next() else null
         }

core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.collection

/**
 * An abstraction of a consumer of key-value pairs, primarily used when
 * persisting partitioned data, either through the shuffle writer plugins
 * or via DiskBlockObjectWriter.
 */
private[spark] trait PairsWriter {

  def write(key: Any, value: Any): Unit
}

core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
     new WritablePartitionedIterator {
       private[this] var cur = if (it.hasNext) it.next() else null
 
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (it.hasNext) it.next() else null
       }
@@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection {
  * has an associated partition.
  */
 private[spark] trait WritablePartitionedIterator {
-  def writeNext(writer: DiskBlockObjectWriter): Unit
+  def writeNext(writer: PairsWriter): Unit
 
   def hasNext(): Boolean
 