From 7eba87a7e73873d6fea5e4c61662a6287ed6cb8c Mon Sep 17 00:00:00 2001
From: mcheah
Date: Thu, 12 Sep 2019 18:12:43 -0700
Subject: [PATCH 1/3] [SPARK-29072] Put back usage of TimeTrackingOutputStream
 for UnsafeShuffleWriter and ShufflePartitionPairsWriter.

---
 .../spark/shuffle/sort/UnsafeShuffleWriter.java     |  2 ++
 .../spark/shuffle/ShufflePartitionPairsWriter.scala | 13 +++++++++++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f59bddc993639..52e6339ffa2f2 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -25,6 +25,7 @@
 import java.nio.channels.WritableByteChannel;
 import java.util.Iterator;
 
+import org.apache.spark.storage.TimeTrackingOutputStream;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConverters;
@@ -382,6 +383,7 @@ private void mergeSpillsWithFileStream(
         ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
         OutputStream partitionOutput = writer.openStream();
         try {
+          partitionOutput = new TimeTrackingOutputStream(writeMetrics, partitionOutput);
           partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
           if (compressionCodec != null) {
             partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
index a988c5e126a76..e83254025b883 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -21,7 +21,7 @@ import java.io.{Closeable, IOException, OutputStream}
 
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.api.ShufflePartitionWriter
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.PairsWriter
 
@@ -39,6 +39,7 @@ private[spark] class ShufflePartitionPairsWriter(
 
   private var isClosed = false
   private var partitionStream: OutputStream = _
+  private var timeTrackingStream: OutputStream = _
   private var wrappedStream: OutputStream = _
   private var objOut: SerializationStream = _
   private var numRecordsWritten = 0
@@ -59,6 +60,7 @@ private[spark] class ShufflePartitionPairsWriter(
   private def open(): Unit = {
     try {
       partitionStream = partitionWriter.openStream
+      timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
       wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
       objOut = serializerInstance.serializeStream(wrappedStream)
     } catch {
@@ -78,6 +80,7 @@ private[spark] class ShufflePartitionPairsWriter(
         // Setting these to null will prevent the underlying streams from being closed twice
         // just in case any stream's close() implementation is not idempotent.
         wrappedStream = null
+        timeTrackingStream = null
         partitionStream = null
       } {
         // Normally closing objOut would close the inner streams as well, but just in case there
@@ -86,9 +89,15 @@
           wrappedStream = closeIfNonNull(wrappedStream)
           // Same as above - if wrappedStream closes then assume it closes underlying
           // partitionStream and don't close again in the finally
+          timeTrackingStream = null
           partitionStream = null
         } {
-          partitionStream = closeIfNonNull(partitionStream)
+          Utils.tryWithSafeFinally {
+            timeTrackingStream = closeIfNonNull(timeTrackingStream)
+            partitionStream = null
+          } {
+            partitionStream = closeIfNonNull(partitionStream)
+          }
         }
       }
       updateBytesWritten()

From fe5b93fb11b02a07aa8941516d15544d6a6f1290 Mon Sep 17 00:00:00 2001
From: mcheah
Date: Thu, 12 Sep 2019 18:17:27 -0700
Subject: [PATCH 2/3] Import ordering

---
 .../java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 52e6339ffa2f2..4d11abd36985e 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -25,7 +25,6 @@
 import java.nio.channels.WritableByteChannel;
 import java.util.Iterator;
 
-import org.apache.spark.storage.TimeTrackingOutputStream;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConverters;
@@ -58,6 +57,7 @@
 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
 import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.Utils;
 

From 08a6068a0261a37e52d02e62ee3fc5395e653c91 Mon Sep 17 00:00:00 2001
From: mcheah
Date: Thu, 12 Sep 2019 18:27:03 -0700
Subject: [PATCH 3/3] Fix stream reference

---
 .../org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
index e83254025b883..e0affb858c359 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -61,7 +61,7 @@ private[spark] class ShufflePartitionPairsWriter(
     try {
       partitionStream = partitionWriter.openStream
      timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
-      wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
+      wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream)
       objOut = serializerInstance.serializeStream(wrappedStream)
     } catch {
       case e: Exception =>
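
Note: TimeTrackingOutputStream is a thin OutputStream decorator that times each
write/flush/close call and charges the elapsed nanoseconds to the shuffle write
metrics. As a rough illustration of the pattern these patches reinstate -- not
Spark's actual implementation; the WriteMetrics interface and the class name
below are hypothetical stand-ins -- a minimal sketch:

    import java.io.IOException;
    import java.io.OutputStream;

    // Hypothetical stand-in for Spark's shuffle write metrics sink.
    interface WriteMetrics {
      void incWriteTime(long nanos);
    }

    // Decorator that times every call on the wrapped stream and accumulates
    // the elapsed nanoseconds into the supplied metrics sink.
    final class TimedOutputStream extends OutputStream {
      private final WriteMetrics metrics;
      private final OutputStream out;

      TimedOutputStream(WriteMetrics metrics, OutputStream out) {
        this.metrics = metrics;
        this.out = out;
      }

      @Override
      public void write(int b) throws IOException {
        long start = System.nanoTime();
        out.write(b);
        metrics.incWriteTime(System.nanoTime() - start);
      }

      @Override
      public void write(byte[] b, int off, int len) throws IOException {
        long start = System.nanoTime();
        out.write(b, off, len);
        metrics.incWriteTime(System.nanoTime() - start);
      }

      @Override
      public void flush() throws IOException {
        long start = System.nanoTime();
        out.flush();
        metrics.incWriteTime(System.nanoTime() - start);
      }

      @Override
      public void close() throws IOException {
        long start = System.nanoTime();
        out.close();
        metrics.incWriteTime(System.nanoTime() - start);
      }
    }

Ordering matters at both call sites: the tracker wraps the raw partition stream
beneath the encryption and compression wrappers (patch 1) and beneath
serializerManager.wrapStream (patches 1 and 3), so the recorded time reflects
bytes actually pushed to the underlying stream rather than serialization or
compression work. Patch 3 exists because patch 1 passed partitionStream instead
of timeTrackingStream to wrapStream, which would have left the tracker out of
the write path entirely.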