diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index a8bcb7dfe2f3c..f2a4422f05a66 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.scheduler
-import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 import scala.collection.mutable.HashMap
@@ -25,7 +25,7 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.TaskContext
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.io.{ByteBufferInputStream, FastByteArrayOutputStream}
 /**
  * A unit of execution. We have two kinds of Task's in Spark:
@@ -102,7 +102,7 @@ private[spark] object Task {
       serializer: SerializerInstance)
     : ByteBuffer = {
-    val out = new ByteArrayOutputStream(4096)
+    val out = new FastByteArrayOutputStream(4096)
     val dataOut = new DataOutputStream(out)
     // Write currentFiles
@@ -123,7 +123,7 @@ private[spark] object Task {
     dataOut.flush()
     val taskBytes = serializer.serialize(task).array()
     out.write(taskBytes)
-    ByteBuffer.wrap(out.toByteArray)
+    out.toByteBuffer
   }
   /**
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 5e5883554fcc1..31073ddc13e88 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.io.ByteBufferInputStream
 private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
   extends SerializationStream {
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2c8f9b6218d6..2fb2b191e3ace 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,12 +17,13 @@
 package org.apache.spark.serializer
-import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
 import java.nio.ByteBuffer
 import org.apache.spark.SparkEnv
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.NextIterator
+import org.apache.spark.util.io.{ByteBufferInputStream, FastByteArrayOutputStream}
 /**
  * :: DeveloperApi ::
@@ -71,9 +72,9 @@ trait SerializerInstance {
   def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
     // Default implementation uses serializeStream
-    val stream = new ByteArrayOutputStream()
+    val stream = new FastByteArrayOutputStream()
     serializeStream(stream).writeAll(iterator)
-    val buffer = ByteBuffer.wrap(stream.toByteArray)
+    val buffer = stream.toByteBuffer
     buffer.flip()
     buffer
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index f14017051fa07..ba4c9e6be94a9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{File, InputStream, OutputStream, BufferedOutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -33,6 +33,8 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util._
+import org.apache.spark.util.io.{ByteBufferInputStream, FastByteArrayOutputStream}
+
 private[spark] sealed trait Values
@@ -1001,9 +1003,9 @@ private[spark] class BlockManager(
       blockId: BlockId,
       values: Iterator[Any],
       serializer: Serializer = defaultSerializer): ByteBuffer = {
-    val byteStream = new ByteArrayOutputStream(4096)
+    val byteStream = new FastByteArrayOutputStream(4096)
     dataSerializeStream(blockId, byteStream, values, serializer)
-    ByteBuffer.wrap(byteStream.toByteArray)
+    byteStream.toByteBuffer
   }
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteBufferInputStream.scala
similarity index 96%
rename from core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
rename to core/src/main/scala/org/apache/spark/util/io/ByteBufferInputStream.scala
index 54de4d4ee8ca7..da45cfe2ce655 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ByteBufferInputStream.scala
@@ -15,11 +15,12 @@
  * limitations under the License.
  */
-package org.apache.spark.util
+package org.apache.spark.util.io
 import java.io.InputStream
 import java.nio.ByteBuffer
+// TODO(rxin): This file should not depend on BlockManager.
 import org.apache.spark.storage.BlockManager
 /**
diff --git a/core/src/main/scala/org/apache/spark/util/io/FastByteArrayOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/FastByteArrayOutputStream.scala
new file mode 100644
index 0000000000000..9c4ab274e041b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/FastByteArrayOutputStream.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.OutputStream
+import java.nio.ByteBuffer
+
+/**
+ * A simple, fast byte-array output stream that exposes the backing array,
+ * inspired by fastutil's FastByteArrayOutputStream.
+ *
+ * [[java.io.ByteArrayOutputStream]] is convenient, but retrieving its content requires
+ * allocating and copying into a new array on every call. This class avoids that copy
+ * by exposing its backing array directly.
+ *
+ * This class will automatically enlarge the backing array, doubling its
+ * size whenever new space is needed.
+ */
+private[spark] class FastByteArrayOutputStream(initialCapacity: Int = 16) extends OutputStream {
+
+  private[this] var _array = new Array[Byte](initialCapacity)
+
+  /** The current writing position. */
+  private[this] var _position: Int = 0
+
+  /** The number of valid bytes in the backing array. */
+  def length: Int = _position
+
+  override def write(b: Int): Unit = {
+    if (_position >= _array.length) {
+      _array = FastByteArrayOutputStream.growArray(_array, _position + 1, _position)
+    }
+    _array(_position) = b.toByte
+    _position += 1
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    if (off < 0) {
+      throw new ArrayIndexOutOfBoundsException(s"Offset ($off) is negative")
+    }
+    if (len < 0) {
+      throw new IllegalArgumentException(s"Length ($len) is negative")
+    }
+    if (off + len > b.length) {
+      throw new ArrayIndexOutOfBoundsException(
+        s"Last index (${off + len}) is greater than array length (${b.length})")
+    }
+    if (_position + len > _array.length) {
+      _array = FastByteArrayOutputStream.growArray(_array, _position + len, _position)
+    }
+    System.arraycopy(b, off, _array, _position, len)
+    _position += len
+  }
+
+  /** Return a ByteBuffer wrapping the filled portion of the underlying array (no copy is made). */
+  def toByteBuffer: ByteBuffer = {
+    ByteBuffer.wrap(_array, 0, _position)
+  }
+
+  /**
+   * Return a tuple, where the first element is the underlying array, and the second element
+   * is the length of the filled content.
+   */
+  def toArray: (Array[Byte], Int) = (_array, _position)
+}
+
+private object FastByteArrayOutputStream {
+  /**
+   * Grow the given array so it can hold at least `len` bytes. When growth is needed, the new
+   * length is the larger of `len` and twice the current length, and only the first `preserve`
+   * bytes are copied over.
+   *
+   * @param arr input array
+   * @param len the new minimum length for this array
+   * @param preserve the number of elements of the array that must be preserved
+   *                 in case a new allocation is necessary
+   */
+  private def growArray(arr: Array[Byte], len: Int, preserve: Int): Array[Byte] = {
+    if (len > arr.length) {
+      val maxArraySize = Integer.MAX_VALUE - 8
+      val newLen = math.min(math.max(2L * arr.length, len), maxArraySize).toInt
+      val newArr = new Array[Byte](newLen)
+      System.arraycopy(arr, 0, newArr, 0, preserve)
+      newArr
+    } else {
+      arr
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index e10ec7d2624a0..7dd931857aa73 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -31,7 +31,8 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
+import org.apache.spark.util.{AkkaUtils, SizeEstimator, Utils}
+import org.apache.spark.util.io.ByteBufferInputStream
 class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
   private val conf = new SparkConf(false)
diff --git a/core/src/test/scala/org/apache/spark/util/io/FastByteArrayOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/FastByteArrayOutputStreamSuite.scala
new file mode 100644
index 0000000000000..d143164d27259
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/io/FastByteArrayOutputStreamSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import org.scalatest.FunSuite
+
+
+class FastByteArrayOutputStreamSuite extends FunSuite {
+
+  test("write single byte") {
+    val out = new FastByteArrayOutputStream(initialCapacity = 4)
+    out.write(0)
+    out.write(1)
+    assert(out.toArray._1(0) === 0)
+    assert(out.toArray._1(1) === 1)
+    assert(out.toArray._2 === 2)
+    assert(out.length === 2)
+
+    out.write(2)
+    out.write(3)
+    assert(out.toArray._1(2) === 2)
+    assert(out.toArray._1(3) === 3)
+    assert(out.length === 4)
+
+    out.write(4)
+    assert(out.toArray._1(4) === 4)
+    assert(out.toArray._2 === 5)
+    assert(out.length === 5)
+
+    for (i <- 5 to 100) {
+      out.write(i)
+    }
+
+    for (i <- 5 to 100) {
+      assert(out.toArray._1(i) === i)
+    }
+  }
+
+  test("write multiple bytes") {
+    val out = new FastByteArrayOutputStream(initialCapacity = 4)
+    out.write(Array[Byte](0.toByte, 1.toByte))
+    assert(out.length === 2)
+    assert(out.toArray._1(0) === 0)
+    assert(out.toArray._1(1) === 1)
+
+    out.write(Array[Byte](2.toByte, 3.toByte, 4.toByte))
+    assert(out.length === 5)
+    assert(out.toArray._1(2) === 2)
+    assert(out.toArray._1(3) === 3)
+    assert(out.toArray._1(4) === 4)
+
+    // Write more than double the size of the current array
+    out.write((1 to 100).map(_.toByte).toArray)
+    assert(out.length === 105)
+    assert(out.toArray._1(104) === 100)
+  }
+
+  test("test large writes") {
+    val out = new FastByteArrayOutputStream(initialCapacity = 4096)
+    out.write(Array.tabulate[Byte](4096 * 1000)(_.toByte))
+    assert(out.length === 4096 * 1000)
+    assert(out.toArray._1(0) === 0)
+    assert(out.toArray._1(4096 * 1000 - 1) === (4096 * 1000 - 1).toByte)
+    assert(out.toArray._2 === 4096 * 1000)
+
+    out.write(Array.tabulate[Byte](4096 * 1000)(_.toByte))
+    assert(out.length === 2 * 4096 * 1000)
+    assert(out.toArray._1(0) === 0)
+    assert(out.toArray._1(4096 * 1000) === 0)
+    assert(out.toArray._1(2 * 4096 * 1000 - 1) === (4096 * 1000 - 1).toByte)
+    assert(out.toArray._2 === 2 * 4096 * 1000)
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala
index a7850812bd612..a7d0f1dbf914a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.streaming.util
-import java.io.{ByteArrayOutputStream, IOException}
+import java.io.IOException
 import java.net.ServerSocket
 import java.nio.ByteBuffer
@@ -26,6 +26,8 @@ import scala.io.Source
 import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.IntParam
+import org.apache.spark.util.io.FastByteArrayOutputStream
+
 /**
  * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a
@@ -43,18 +45,18 @@ object RawTextSender extends Logging {
     // Repeat the input data multiple times to fill in a buffer
     val lines = Source.fromFile(file).getLines().toArray
-    val bufferStream = new ByteArrayOutputStream(blockSize + 1000)
+    val bufferStream = new FastByteArrayOutputStream(blockSize + 1000)
     val ser = new KryoSerializer(new SparkConf()).newInstance()
     val serStream = ser.serializeStream(bufferStream)
     var i = 0
-    while (bufferStream.size < blockSize) {
+    while (bufferStream.length < blockSize) {
       serStream.writeObject(lines(i))
       i = (i + 1) % lines.length
     }
-    val array = bufferStream.toByteArray
+    val (array, len) = bufferStream.toArray
     val countBuf = ByteBuffer.wrap(new Array[Byte](4))
-    countBuf.putInt(array.length)
+    countBuf.putInt(len)
     countBuf.flip()
     val serverSocket = new ServerSocket(port)
@@ -67,7 +69,7 @@ object RawTextSender extends Logging {
     try {
       while (true) {
         out.write(countBuf.array)
-        out.write(array)
+        out.write(array, 0, len)  // the backing array starts at offset 0; only the first len bytes are valid
       }
     } catch {
       case e: IOException =>
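
Usage note (not part of the patch): a minimal sketch of how callers are expected to use the new
FastByteArrayOutputStream, assuming this change is applied. The object name and sample values are
hypothetical; the point is the copy-free toByteBuffer/toArray pattern that replaces
ByteArrayOutputStream.toByteArray in the files above.

import java.io.DataOutputStream
import org.apache.spark.util.io.FastByteArrayOutputStream

object FastByteArrayOutputStreamSketch {
  def main(args: Array[String]): Unit = {
    val out = new FastByteArrayOutputStream(4096)
    val dataOut = new DataOutputStream(out)
    dataOut.writeUTF("hello")
    dataOut.writeLong(42L)
    dataOut.flush()

    // Old pattern: ByteBuffer.wrap(out.toByteArray) allocates and copies the content.
    // New pattern: wrap the backing array directly; no copy is made.
    val buffer = out.toByteBuffer
    assert(buffer.remaining() == out.length)

    // When a raw array is needed, toArray returns (backingArray, validLength);
    // only the first validLength bytes are meaningful, so pass the length along explicitly.
    val (array, len) = out.toArray
    println(s"wrote $len valid bytes; backing array capacity is ${array.length}")
  }
}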