ShuffleDataIO.java
@@ -16,11 +16,13 @@
  */
 package org.apache.spark.shuffle.api;
 
+import java.io.IOException;
+
 public interface ShuffleDataIO {
 
-  void initialize();
+  void initialize() throws IOException;
 
-  ShuffleReadSupport readSupport();
+  ShuffleReadSupport readSupport() throws IOException;
 
-  ShuffleWriteSupport writeSupport();
+  ShuffleWriteSupport writeSupport() throws IOException;
 }

ShuffleMapOutputWriter.java
@@ -17,11 +17,13 @@
 
 package org.apache.spark.shuffle.api;
 
+import java.io.IOException;
+
 public interface ShuffleMapOutputWriter {
 
-  ShufflePartitionWriter newPartitionWriter(int partitionId);
+  ShufflePartitionWriter newPartitionWriter(int partitionId) throws IOException;
 
-  void commitAllPartitions();
+  void commitAllPartitions() throws IOException;
 
-  void abort(Exception exception);
+  void abort(Exception exception) throws IOException;
 }

ShufflePartitionReader.java
@@ -17,12 +17,14 @@
 
 package org.apache.spark.shuffle.api;
 
-import org.apache.spark.storage.ShuffleLocation;
-
 import java.io.InputStream;
+import java.io.IOException;
 import java.util.Optional;
 
+import org.apache.spark.storage.ShuffleLocation;
+
 public interface ShufflePartitionReader {
 
-  InputStream fetchPartition(int reduceId, Optional<ShuffleLocation> shuffleLocation);
+  InputStream fetchPartition(int reduceId, Optional<ShuffleLocation> shuffleLocation)
+      throws IOException;
 }

ShufflePartitionWriter.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.api;
 
+import java.io.IOException;
 import java.io.OutputStream;
 
 /**
@@ -27,18 +28,18 @@ public interface ShufflePartitionWriter {
   /**
    * Return a stream that should persist the bytes for this partition.
    */
-  OutputStream openPartitionStream();
+  OutputStream openPartitionStream() throws IOException;
 
   /**
    * Indicate that the partition was written successfully and there are no more incoming bytes.
    * Returns a {@link CommittedPartition} indicating information about that written partition.
    */
-  CommittedPartition commitPartition();
+  CommittedPartition commitPartition() throws IOException;
 
   /**
    * Indicate that the write has failed for some reason and the implementation can handle the
    * failure reason. After this method is called, this writer will be discarded; it's expected that
    * the implementation will close any underlying resources.
    */
-  void abort(Exception failureReason);
+  void abort(Exception failureReason) throws IOException;
 }

ShuffleReadSupport.java
@@ -17,8 +17,11 @@
 
 package org.apache.spark.shuffle.api;
 
+import java.io.IOException;
+
 public interface ShuffleReadSupport {
 
-  ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId);
+  ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId)
+      throws IOException;
 
 }
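
For context, a read-side caller sketch (not code from this PR): with the checked exception added above, both newPartitionReader and fetchPartition can now fail with IOException, so callers either propagate it or translate it into a fetch failure. The class and method names below are hypothetical.

import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;

import org.apache.spark.shuffle.api.ShufflePartitionReader;
import org.apache.spark.shuffle.api.ShuffleReadSupport;
import org.apache.spark.storage.ShuffleLocation;

// Hypothetical caller; not part of the PR.
public final class ReadSideExample {

  // Opens the stream for one reduce partition, letting IOException propagate to the caller.
  public static InputStream openReducePartition(
      ShuffleReadSupport readSupport,
      String appId,
      int shuffleId,
      int mapId,
      int reduceId,
      Optional<ShuffleLocation> location) throws IOException {
    ShufflePartitionReader reader = readSupport.newPartitionReader(appId, shuffleId, mapId);
    return reader.fetchPartition(reduceId, location);
  }
}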

ShuffleWriteSupport.java
@@ -17,7 +17,10 @@
 
 package org.apache.spark.shuffle.api;
 
+import java.io.IOException;
+
 public interface ShuffleWriteSupport {
 
-  ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId);
+  ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId)
+      throws IOException;
 }
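
The write-side interfaces compose the same way. Below is a minimal caller sketch (again not code from this PR) showing the intended sequencing: open a per-partition stream, commit each partition, commit the map output as a whole, and fall back to abort() on any failure so the plugin can clean up. The class name, the helper structure, and the choice to only flush (rather than close) the per-partition stream are assumptions.

import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.ShuffleWriteSupport;

// Hypothetical caller; not part of the PR.
public final class WriteSideExample {

  public static void writeMapOutput(
      ShuffleWriteSupport writeSupport,
      String appId,
      int shuffleId,
      int mapId,
      int numPartitions) throws IOException {
    ShuffleMapOutputWriter mapOutputWriter =
        writeSupport.newMapOutputWriter(appId, shuffleId, mapId);
    try {
      for (int partitionId = 0; partitionId < numPartitions; partitionId++) {
        ShufflePartitionWriter partitionWriter = mapOutputWriter.newPartitionWriter(partitionId);
        OutputStream out = partitionWriter.openPartitionStream();
        // ... serialize the records for this partition into `out` ...
        out.flush();
        partitionWriter.commitPartition();
      }
      mapOutputWriter.commitAllPartitions();
    } catch (Exception e) {
      // Give the plugin a chance to clean up; abort() itself may now throw IOException.
      mapOutputWriter.abort(e);
      throw e;
    }
  }
}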

@@ -1,5 +1,7 @@
 package org.apache.spark.shuffle.external;
 
+import scala.compat.java8.OptionConverters;
+
 import com.google.common.collect.Lists;
 import org.apache.spark.MapOutputTracker;
 import org.apache.spark.network.TransportContext;
@@ -13,7 +15,6 @@
 import org.apache.spark.storage.ShuffleLocation;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.compat.java8.OptionConverters;
 
 import java.util.List;
 import java.util.Optional;

@@ -20,7 +20,7 @@ package org.apache.spark.shuffle
 import scala.compat.java8.OptionConverters
 
 import org.apache.spark._
-import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.api.ShuffleReadSupport
 import org.apache.spark.storage._

ShufflePartitionWriterOutputStream.scala
@@ -20,38 +20,36 @@ package org.apache.spark.shuffle
 import java.io.{InputStream, OutputStream}
 import java.nio.ByteBuffer
 
-import org.apache.spark.network.util.LimitedInputStream
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.api.ShufflePartitionWriter
+import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.{ByteBufferInputStream, Utils}
 
 class ShufflePartitionWriterOutputStream(
-    partitionWriter: ShufflePartitionWriter, buffer: ByteBuffer, bufferSize: Int)
-  extends OutputStream {
+    blockId: ShuffleBlockId,
+    partitionWriter: ShufflePartitionWriter,
+    buffer: ByteBuffer,
+    serializerManager: SerializerManager)
+    extends OutputStream {
 
-  private var currentChunkSize = 0
-  private val bufferForRead = buffer.asReadOnlyBuffer()
   private var underlyingOutputStream: OutputStream = _
 
   override def write(b: Int): Unit = {
-    buffer.putInt(b)
-    currentChunkSize += 1
-    if (currentChunkSize == bufferSize) {
+    buffer.put(b.asInstanceOf[Byte])
+    if (buffer.remaining() == 0) {
      pushBufferedBytesToUnderlyingOutput()
    }
  }
 
   private def pushBufferedBytesToUnderlyingOutput(): Unit = {
-    bufferForRead.reset()
-    var bufferInputStream: InputStream = new ByteBufferInputStream(bufferForRead)
-    if (currentChunkSize < bufferSize) {
-      bufferInputStream = new LimitedInputStream(bufferInputStream, currentChunkSize)
-    }
+    buffer.flip()
+    var bufferInputStream: InputStream = new ByteBufferInputStream(buffer)
     if (underlyingOutputStream == null) {
-      underlyingOutputStream = partitionWriter.openPartitionStream()
+      underlyingOutputStream = serializerManager.wrapStream(blockId,
+        partitionWriter.openPartitionStream())
     }
     Utils.copyStream(bufferInputStream, underlyingOutputStream, false, false)
-    buffer.reset()
-    currentChunkSize = 0
+    buffer.clear()
   }
 
   override def flush(): Unit = {
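
The core of the ShufflePartitionWriterOutputStream change above is the buffer lifecycle: instead of tracking a separate currentChunkSize and re-reading through a LimitedInputStream, the rewritten class writes single bytes with put(), treats remaining() == 0 as "buffer full", and drains with flip() followed by clear(). A standalone Java illustration of that put/flip/clear cycle (not code from the PR; the tiny 4-byte buffer and the ByteArrayOutputStream sink are stand-ins for the real buffer and the plugin's partition stream):

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;

public final class BufferCycleDemo {
  public static void main(String[] args) throws IOException {
    ByteBuffer buffer = ByteBuffer.allocate(4);      // tiny buffer to force flushes
    OutputStream sink = new ByteArrayOutputStream(); // stands in for openPartitionStream()

    byte[] payload = "hello buffer".getBytes();
    for (byte b : payload) {
      buffer.put(b);                                 // same role as write(b) in the diff
      if (buffer.remaining() == 0) {
        drain(buffer, sink);                         // buffer full: push it downstream
      }
    }
    drain(buffer, sink);                             // final partial chunk, as flush() would
  }

  private static void drain(ByteBuffer buffer, OutputStream sink) throws IOException {
    buffer.flip();                                   // limit = bytes written, position = 0
    sink.write(buffer.array(), buffer.position(), buffer.remaining());
    buffer.clear();                                  // re-arm the buffer for the next chunk
  }
}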

SortShuffleWriter.scala
@@ -71,7 +71,7 @@ private[spark] class SortShuffleWriter[K, V, C](
     try {
       val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
       val committedPartitions = pluggableWriteSupport.map { writeSupport =>
-        sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport)
+        sorter.writePartitionedToExternalShuffleWriteSupport(blockId, writeSupport)
       }.getOrElse(sorter.writePartitionedFile(blockId, tmp))
       if (pluggableWriteSupport.isEmpty) {
         shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId,

@@ -18,7 +18,7 @@
 package org.apache.spark.storage
 
 import java.io._
-import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue}
+import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference}
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.util.Collections
@@ -31,11 +31,12 @@ import scala.concurrent.duration._
 import scala.reflect.ClassTag
 import scala.util.Random
 import scala.util.control.NonFatal
+
 import com.codahale.metrics.{MetricRegistry, MetricSet}
 
 import org.apache.spark._
 import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.metrics.source.Source
 import org.apache.spark.network._

ShufflePartitionObjectWriter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
 
 import java.nio.ByteBuffer
 
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream
 import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter}
 
@@ -30,10 +30,12 @@
  * left to the implementation of the underlying implementation of the writer plugin.
  */
 private[spark] class ShufflePartitionObjectWriter(
+    blockId: ShuffleBlockId,
     bufferSize: Int,
     serializerInstance: SerializerInstance,
+    serializerManager: SerializerManager,
     mapOutputWriter: ShuffleMapOutputWriter)
-  extends PairsWriter {
+    extends PairsWriter {
 
   // Reused buffer. Experiments should be done with off-heap at some point.
   private val buffer = ByteBuffer.allocate(bufferSize)
@@ -44,10 +46,9 @@ private[spark] class ShufflePartitionObjectWriter(
   def startNewPartition(partitionId: Int): Unit = {
     require(buffer.position() == 0,
       "Buffer was not flushed to the underlying output on the previous partition.")
-    buffer.reset()
     currentWriter = mapOutputWriter.newPartitionWriter(partitionId)
     val currentWriterStream = new ShufflePartitionWriterOutputStream(
-      currentWriter, buffer, bufferSize)
+      blockId, currentWriter, buffer, serializerManager)
     objectOutputStream = serializerInstance.serializeStream(currentWriterStream)
   }
 
@@ -56,7 +57,7 @@
     require(currentWriter != null, "Cannot commit a partition that has not been started.")
     objectOutputStream.close()
     val committedPartition = currentWriter.commitPartition()
-    buffer.reset()
+    buffer.clear()
     currentWriter = null
     objectOutputStream = null
     committedPartition

ExternalSorter.scala
@@ -24,13 +24,13 @@ import com.google.common.io.ByteStreams
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.{util, _}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer._
 import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport}
 import org.apache.spark.shuffle.sort.LocalCommittedPartition
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShuffleLocation, ShufflePartitionObjectWriter}
-import org.apache.spark.{util, _}
+import org.apache.spark.storage._
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -727,14 +727,18 @@ private[spark] class ExternalSorter[K, V, C](
    * Write all partitions to some backend that is pluggable.
    */
  def writePartitionedToExternalShuffleWriteSupport(
-      mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = {
+      blockId: ShuffleBlockId,
+      writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = {
 
     // Track location of each range in the output file
     val committedPartitions = new Array[CommittedPartition](numPartitions)
-    val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId)
+    val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, blockId.shuffleId,
+      blockId.mapId)
     val writer = new ShufflePartitionObjectWriter(
+      blockId,
       Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt,
       serInstance,
+      serializerManager,
       mapOutputWriter)
 
     try {
@@ -781,6 +785,7 @@ private[spark] class ExternalSorter[K, V, C](
       mapOutputWriter.commitAllPartitions()
     } catch {
       case e: Exception =>
+        logError("Error writing shuffle data.", e)
        util.Utils.tryLogNonFatalError {
          writer.abortCurrentPartition(e)
          mapOutputWriter.abort(e)

KubernetesShuffleServiceAddressProvider.scala
@@ -139,10 +139,4 @@ class KubernetesShuffleServiceAddressProvider(
 
     override def onClose(e: KubernetesClientException): Unit = {}
   }
-
-  private implicit def toRunnable(func: () => Unit): Runnable = {
-    new Runnable {
-      override def run(): Unit = func()
-    }
-  }
 }

MesosClusterManager.scala
@@ -60,8 +60,9 @@ private[spark] class MesosClusterManager extends ExternalClusterManager {
 
   override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
     scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
   }
 
-  override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider =
+  def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = {
     DefaultShuffleServiceAddressProvider
+  }
 }

YarnClusterManager.scala
@@ -54,6 +54,8 @@ private[spark] class YarnClusterManager extends ExternalClusterManager {
   override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
     scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
   }
-  def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider =
+
+  def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = {
     DefaultShuffleServiceAddressProvider
+  }
 }