diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e6addeaf04a8..d6a4cd42f7604 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -40,7 +40,7 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.WholeTextFileInputFormat +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ @@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging { value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => executorEnvs("SPARK_PREPEND_CLASSES") = v } // The Mesos scheduler backend relies on this environment variable to set executor memory. @@ -510,6 +510,52 @@ class SparkContext(config: SparkConf) extends Logging { minPartitions).setName(path) } + + /** + * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file + * (useful for binary data) + * + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred, large file is also allowable, but may cause bad performance. + */ + @DeveloperApi + def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): + RDD[(String, PortableDataStream)] = { + val job = new NewHadoopJob(hadoopConfiguration) + NewFileInputFormat.addInputPath(job, new Path(path)) + val updateConf = job.getConfiguration + new BinaryFileRDD( + this, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + updateConf, + minPartitions).setName(path) + } + + /** + * Load data from a flat binary file, assuming each record is a set of numbers + * with the specified numerical format (see ByteBuffer), and the number of + * bytes per record is constant (see FixedLengthBinaryInputFormat) + * + * @param path Directory to the input data files + * @param recordLength The length at which to split the records + * @return An RDD of data with values, RDD[(Array[Byte])] + */ + def binaryRecords(path: String, recordLength: Int, + conf: Configuration = hadoopConfiguration): RDD[Array[Byte]] = { + conf.setInt("recordLength",recordLength) + val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, + classOf[FixedLengthBinaryInputFormat], + classOf[LongWritable], + classOf[BytesWritable], + conf=conf) + val data = br.map{ case (k, v) => v.getBytes} + data + } + /** * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), @@ -1208,7 +1254,7 @@ class SparkContext(config: SparkConf) extends Logging { * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException * if not. 
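[Illustrative usage, not part of the patch] Since binaryFiles and binaryRecords above are new public entry points on SparkContext, a brief usage sketch may help. It assumes a live SparkContext `sc`, a hypothetical input directory `/data/bin`, and an invented record layout of two big-endian doubles per 16-byte record:
{{{
import java.nio.ByteBuffer
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD

// Whole-file API: one (path, lazily opened stream) pair per input file.
val streams: RDD[(String, PortableDataStream)] = sc.binaryFiles("/data/bin", minPartitions = 4)
val fileSizes = streams.map { case (path, stream) => (path, stream.toArray().length) }

// Fixed-length-record API: one Array[Byte] per 16-byte record.
val records: RDD[Array[Byte]] = sc.binaryRecords("/data/bin", recordLength = 16)
val doublePairs = records.map { bytes =>
  val buf = ByteBuffer.wrap(bytes)    // numerical layout is up to the caller (see ByteBuffer)
  (buf.getDouble(), buf.getDouble())  // assumed layout: two big-endian doubles per record
}
}}}
binaryFiles keeps each file whole (one PortableDataStream per file), while binaryRecords splits files into fixed-size records via FixedLengthBinaryInputFormat.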
- * + * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability * @throws SparkException if checkSerializable is set but f is not diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8a5f8088a05ca..373e6f1d12864 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -20,6 +20,11 @@ package org.apache.spark.api.java import java.util import java.util.{Map => JMap} +import java.io.DataInputStream + +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat} + import scala.collection.JavaConversions import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -31,7 +36,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} +import org.apache.spark.SparkContext._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, RDD} @@ -180,6 +185,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def textFile(path: String, minPartitions: Int): JavaRDD[String] = sc.textFile(path, minPartitions) + + /** * Read a directory of text files from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI. Each file is read as a single record and returned in a @@ -220,6 +227,54 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def wholeTextFiles(path: String): JavaPairRDD[String, String] = new JavaPairRDD(sc.wholeTextFiles(path)) + /** + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + *

+   * For example, if you have the following files:
+   * {{{
+   *   hdfs://a-hdfs-path/part-00000
+   *   hdfs://a-hdfs-path/part-00001
+   *   ...
+   *   hdfs://a-hdfs-path/part-nnnnn
+   * }}}
+   *
+   * Do
+   * `JavaPairRDD<String, PortableDataStream> rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+   *
+   *

+   * then `rdd` contains
+   * {{{
+   *   (a-hdfs-path/part-00000, its content)
+   *   (a-hdfs-path/part-00001, its content)
+   *   ...
+   *   (a-hdfs-path/part-nnnnn, its content)
+   * }}}
+   *
+   * @note Small files are preferred; very large files are allowed but may cause bad performance.
+   *
+   * @param minPartitions A suggestion value of the minimal splitting number for input data.
+   */
+  def binaryFiles(path: String, minPartitions: Int):
+    JavaPairRDD[String, PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
+  def binaryFiles(path: String):
+    JavaPairRDD[String, PortableDataStream] =
+      new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+  /**
+   * Load data from a flat binary file, assuming each record is a set of numbers
+   * with the specified numerical format (see ByteBuffer), and the number of
+   * bytes per record is constant (see FixedLengthBinaryInputFormat).
+   *
+   * @param path Directory of the input data files
+   * @return A JavaRDD of byte arrays, one per fixed-length record
+   */
+  def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+    new JavaRDD(sc.binaryRecords(path, recordLength))
+  }
+
   /** Get an RDD for a Hadoop SequenceFile with given key and value types.
    *
    * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
new file mode 100644
index 0000000000000..852be54d181d4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+/**
+ * Custom Input Format for reading and splitting flat binary files that contain records,
+ * each of which is a fixed size in bytes. The fixed record size is specified through
+ * a parameter recordLength in the Hadoop configuration.
+ */ + +private[spark] object FixedLengthBinaryInputFormat { + /** + * This function retrieves the recordLength by checking the configuration parameter + * + */ + def getRecordLength(context: JobContext): Int = { + + // retrieve record length from configuration + context.getConfiguration.get("recordLength").toInt + } +} + +private[spark] class FixedLengthBinaryInputFormat + extends FileInputFormat[LongWritable, BytesWritable] { + /** + * Override of isSplitable to ensure initial computation of the record length + */ + override def isSplitable(context: JobContext, filename: Path): Boolean = { + + if (recordLength == -1) { + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + } + if (recordLength <= 0) { + println("record length is less than 0, file cannot be split") + false + } else { + true + } + } + + /** + * This input format overrides computeSplitSize() to make sure that each split + * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader + * will start at the first byte of a record, and the last byte will the last byte of a record. + */ + override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = { + val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize) + // If the default size is less than the length of a record, make it equal to it + // Otherwise, make sure the split size is as close to possible as the default size, + // but still contains a complete set of records, with the first record + // starting at the first byte in the split and the last record ending with the last byte + if (defaultSize < recordLength) { + recordLength.toLong + } else { + (Math.floor(defaultSize / recordLength) * recordLength).toLong + } + } + + /** + * Create a FixedLengthBinaryRecordReader + */ + override def createRecordReader(split: InputSplit, context: TaskAttemptContext): + RecordReader[LongWritable, BytesWritable] = { + new FixedLengthBinaryRecordReader + } + var recordLength = -1 +} diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala new file mode 100644 index 0000000000000..98988910831be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
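[Illustrative sketch, not part of the patch] To make the rounding in computeSplitSize above concrete, here is a small worked example with assumed numbers (6-byte records against a 128 MB default split): the split shrinks to the largest multiple of recordLength that fits, so every split starts and ends exactly on record boundaries. Since defaultSize / recordLength is already integer (Long) division, the Math.floor round-trip through Double is not strictly needed:
{{{
val recordLength = 6                    // assumed fixed record size in bytes
val defaultSize  = 128L * 1024 * 1024   // assumed default split size: 134217728 bytes
val splitSize =
  if (defaultSize < recordLength) recordLength.toLong
  else (defaultSize / recordLength) * recordLength
// splitSize == 134217726, i.e. 22369621 whole records; the 2 leftover bytes fall
// into the next split, so no record ever straddles a split boundary.
}}}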
+ */ + +package org.apache.spark.input + +import java.io.IOException + +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileSplit + +/** + * + * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. + * It uses the record length set in FixedLengthBinaryInputFormat to + * read one record at a time from the given InputSplit. + * + * Each call to nextKeyValue() updates the LongWritable KEY and BytesWritable VALUE. + * + * KEY = record index (Long) + * VALUE = the record itself (BytesWritable) + * + */ +private[spark] class FixedLengthBinaryRecordReader + extends RecordReader[LongWritable, BytesWritable] { + + override def close() { + if (fileInputStream != null) { + fileInputStream.close() + } + } + + override def getCurrentKey: LongWritable = { + recordKey + } + + override def getCurrentValue: BytesWritable = { + recordValue + } + + override def getProgress: Float = { + splitStart match { + case x if x == splitEnd => 0.0.toFloat + case _ => Math.min( + ((currentPosition - splitStart) / (splitEnd - splitStart)).toFloat, 1.0 + ).toFloat + } + } + + override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) { + + // the file input + val fileSplit = inputSplit.asInstanceOf[FileSplit] + + // the byte position this fileSplit starts at + splitStart = fileSplit.getStart + + // splitEnd byte marker that the fileSplit ends at + splitEnd = splitStart + fileSplit.getLength + + // the actual file we will be reading from + val file = fileSplit.getPath + // job configuration + val job = context.getConfiguration + // check compression + val codec = new CompressionCodecFactory(job).getCodec(file) + if (codec != null) { + throw new IOException("FixedLengthRecordReader does not support reading compressed files") + } + // get the record length + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + // get the filesystem + val fs = file.getFileSystem(job) + // open the File + fileInputStream = fs.open(file) + // seek to the splitStart position + fileInputStream.seek(splitStart) + // set our current position + currentPosition = splitStart + } + + override def nextKeyValue(): Boolean = { + if (recordKey == null) { + recordKey = new LongWritable() + } + // the key is a linear index of the record, given by the + // position the record starts divided by the record length + recordKey.set(currentPosition / recordLength) + // the recordValue to place the bytes into + if (recordValue == null) { + recordValue = new BytesWritable(new Array[Byte](recordLength)) + } + // read a record if the currentPosition is less than the split end + if (currentPosition < splitEnd) { + // setup a buffer to store the record + val buffer = recordValue.getBytes + fileInputStream.read(buffer, 0, recordLength) + // update our current position + currentPosition = currentPosition + recordLength + // return true + return true + } + false + } + var splitStart: Long = 0L + var splitEnd: Long = 0L + var currentPosition: Long = 0L + var recordLength: Int = 0 + var fileInputStream: FSDataInputStream = null + var recordKey: LongWritable = null + var recordValue: BytesWritable = null + +} diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala new file mode 
100644 index 0000000000000..9cfe3e5c1742a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} + +import scala.collection.JavaConversions._ + +/** + * A general format for reading whole files in as streams, byte arrays, + * or other functions to be added + */ +private[spark] abstract class StreamFileInputFormat[T] + extends CombineFileInputFormat[String, T] { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + /** + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API + * which is set through setMaxSplitSize + */ + def setMinPartitions(context: JobContext, minPartitions: Int) { + val files = listStatus(context) + val totalLen = files.map { file => + if (file.isDir) 0L else file.getLen + }.sum + + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + super.setMaxSplitSize(maxSplitSize) + } + + def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T] + +} + +/** + * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] + * to reading files out as streams + */ +private[spark] abstract class StreamBasedRecordReader[T]( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, T] { + + // True means the current file has been processed, then skip it. 
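[Illustrative sketch, not part of the patch] A similarly small worked example of the sizing done in setMinPartitions above, with made-up file sizes (10 MB, 20 MB and 30 MB): the max split size becomes the average file length, so CombineFileInputFormat groups whole files into roughly one split per file. Note that the minPartitions argument itself is not used in the computation as written, so the requested partition count only influences the result indirectly:
{{{
val fileLengths  = Seq(10L, 20L, 30L).map(_ * 1024 * 1024)  // assumed file sizes in bytes
val totalLen     = fileLengths.sum                          // 62914560
val maxSplitSize = Math.ceil(totalLen * 1.0 / fileLengths.length).toLong  // 20971520 (20 MB)
}}}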
+ private var processed = false + + private var key = "" + private var value: T = null.asInstanceOf[T] + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = new PortableDataStream(split, context, index) + value = parseStream(fileIn) + fileIn.close() // if it has not been open yet, close does nothing + key = fileIn.getPath + processed = true + true + } else { + false + } + } + + /** + * Parse the stream (and close it afterwards) and return the value as in type T + * @param inStream the stream to be read in + * @return the data formatted as + */ + def parseStream(inStream: PortableDataStream): T +} + +/** + * Reads the record in directly as a stream for other objects to manipulate and handle + */ +private[spark] class StreamRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends StreamBasedRecordReader[PortableDataStream](split, context, index) { + + def parseStream(inStream: PortableDataStream): PortableDataStream = inStream +} + +/** + * The format for the PortableDataStream files + */ +private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = + { + new CombineFileRecordReader[String, PortableDataStream]( + split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) + } +} + +/** + * A class that allows DataStreams to be serialized and moved around by not creating them + * until they need to be read + * @note TaskAttemptContext is not serializable resulting in the confBytes construct + * @note CombineFileSplit is not serializable resulting in the splitBytes construct + */ +class PortableDataStream(@transient isplit: CombineFileSplit, + @transient context: TaskAttemptContext, index: Integer) + extends Serializable { + // transient forces file to be reopened after being serialization + // it is also used for non-serializable classes + + @transient + private var fileIn: DataInputStream = null.asInstanceOf[DataInputStream] + @transient + private var isOpen = false + + private val confBytes = { + val baos = new ByteArrayOutputStream() + context.getConfiguration.write(new DataOutputStream(baos)) + baos.toByteArray + } + + private val splitBytes = { + val baos = new ByteArrayOutputStream() + isplit.write(new DataOutputStream(baos)) + baos.toByteArray + } + + @transient + private lazy val split = { + val bais = new ByteArrayInputStream(splitBytes) + val nsplit = new CombineFileSplit() + nsplit.readFields(new DataInputStream(bais)) + nsplit + } + + @transient + private lazy val conf = { + val bais = new ByteArrayInputStream(confBytes) + val nconf = new Configuration() + nconf.readFields(new DataInputStream(bais)) + nconf + } + /** + * Calculate the path name independently of opening the file + */ + @transient + private lazy val path = { + val pathp = split.getPath(index) + pathp.toString + } + + /** + * create a new DataInputStream from the split and context + */ + def open(): DataInputStream = { + if (!isOpen) { + val pathp = split.getPath(index) + val fs = pathp.getFileSystem(conf) + fileIn = fs.open(pathp) + isOpen = true + } + fileIn + } + + /** + * Read the file as a byte array + */ + def toArray(): Array[Byte] = { + open() + val innerBuffer = 
ByteStreams.toByteArray(fileIn) + close() + innerBuffer + } + + /** + * close the file (if it is already open) + */ + def close() = { + if (isOpen) { + try { + fileIn.close() + isOpen = false + } catch { + case ioe: java.io.IOException => // do nothing + } + } + } + def getPath(): String = path +} + diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 4cb450577796a..183bce3d8d8d3 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -48,9 +48,10 @@ private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[Str } /** - * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API. + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API, + * which is set through setMaxSplitSize */ - def setMaxSplitSize(context: JobContext, minPartitions: Int) { + def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context) val totalLen = files.map { file => if (file.isDir) 0L else file.getLen diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala new file mode 100644 index 0000000000000..c01196de6bad7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
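[Illustrative sketch, not part of the patch] A short usage example for PortableDataStream as defined above, assuming an RDD produced by sc.binaryFiles over a hypothetical directory `/data/bin` and files that begin with a 4-byte int header (the header is invented for illustration). Because the file is only opened when open() is called on the executor, the (path, stream) pairs stay cheap to serialize, cache, or shuffle:
{{{
val headers = sc.binaryFiles("/data/bin").map { case (path, pds) =>
  val in = pds.open()      // lazily opens the file on the executor, returns a DataInputStream
  try {
    (path, in.readInt())   // read the assumed 4-byte header
  } finally {
    pds.close()            // closes the underlying stream if it was opened
  }
}
headers.collect().foreach(println)
}}}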
+ */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{ Configurable, Configuration } +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.spark.input.StreamFileInputFormat +import org.apache.spark.{ Partition, SparkContext } + +private[spark] class BinaryFileRDD[T]( + sc: SparkContext, + inputFormatClass: Class[_ <: StreamFileInputFormat[T]], + keyClass: Class[String], + valueClass: Class[T], + @transient conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition( + id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f2b3a64bf1345..26bfb2d5fa8a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -182,7 +182,7 @@ private[spark] class WholeTextFileRDD( case _ => } val jobContext = newJobContext(conf, jobId) - inputFormat.setMaxSplitSize(jobContext, minPartitions) + inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e8bd65f8e4507..a39ee903ba3c9 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -18,9 +18,12 @@ package org.apache.spark; import java.io.*; +import java.nio.channels.FileChannel; +import java.nio.ByteBuffer; import java.net.URI; import java.util.*; +import org.apache.spark.input.PortableDataStream; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -836,6 +839,82 @@ public Tuple2 call(Tuple2 pair) { Assert.assertEquals(pairs, readRDD.collect()); } + @Test + public void binaryFiles() throws Exception { + // Reusing the wholeText files example + byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); + + String tempDirName = tempDir.getAbsolutePath(); + File file1 = new File(tempDirName + "/part-00000"); + + FileOutputStream fos1 = new FileOutputStream(file1); + + FileChannel channel1 = fos1.getChannel(); + ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + channel1.write(bbuf); + channel1.close(); + JavaPairRDD readRDD = sc.binaryFiles(tempDirName,3); + List> result = readRDD.collect(); + for (Tuple2 res : result) { + Assert.assertArrayEquals(content1, res._2().toArray()); + } + } + + @Test + public void binaryFilesCaching() throws Exception { + // Reusing the wholeText files example + byte[] content1 = "spark is easy to use.\n".getBytes("utf-8"); + + String tempDirName = tempDir.getAbsolutePath(); + File file1 = new File(tempDirName + "/part-00000"); + + FileOutputStream fos1 = new FileOutputStream(file1); + + FileChannel channel1 = 
fos1.getChannel(); + ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + channel1.write(bbuf); + channel1.close(); + + JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); + readRDD.foreach(new VoidFunction>() { + @Override + public void call(Tuple2 stringPortableDataStreamTuple2) throws Exception { + stringPortableDataStreamTuple2._2().toArray(); // force the file to read + } + }); + + List> result = readRDD.collect(); + for (Tuple2 res : result) { + Assert.assertArrayEquals(content1, res._2().toArray()); + } + } + + @Test + public void binaryRecords() throws Exception { + // Reusing the wholeText files example + byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8"); + int numOfCopies = 10; + String tempDirName = tempDir.getAbsolutePath(); + File file1 = new File(tempDirName + "/part-00000"); + + FileOutputStream fos1 = new FileOutputStream(file1); + + FileChannel channel1 = fos1.getChannel(); + + for (int i=0;i readRDD = sc.binaryRecords(tempDirName,content1.length); + Assert.assertEquals(numOfCopies,readRDD.count()); + List result = readRDD.collect(); + for (byte[] res : result) { + Assert.assertArrayEquals(content1, res); + } + } + @SuppressWarnings("unchecked") @Test public void writeWithNewAPIHadoopFile() { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index c70e22cf09433..3f4c7b0180486 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.storage.StorageLevel + import scala.io.Source import com.google.common.io.Files @@ -224,6 +227,187 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } + test("binary file input as byte array") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1,2,3,4,5,6) + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName) + val (infile: String, indata: PortableDataStream) = inRdd.collect.head + + // Make sure the name and array match + assert(infile.contains(outFileName)) // a prefix may get added + assert(indata.toArray === testOutput) + } + + test("portabledatastream caching tests") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1,2,3,4,5,6) + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName).cache() + inRdd.foreach{ + curData: (String, PortableDataStream) => + curData._2.toArray() // force the file to read + } + val mappedRdd = inRdd.map{ + curData: (String, PortableDataStream) => + (curData._2.getPath(),curData._2) + } + val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head + + // Try reading the output back as an 
object file + + assert(indata.toArray === testOutput) + } + + test("portabledatastream persist disk storage") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1,2,3,4,5,6) + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY) + inRdd.foreach{ + curData: (String, PortableDataStream) => + curData._2.toArray() // force the file to read + } + val mappedRdd = inRdd.map{ + curData: (String, PortableDataStream) => + (curData._2.getPath(),curData._2) + } + val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head + + // Try reading the output back as an object file + + assert(indata.toArray === testOutput) + } + + test("portabledatastream flatmap tests") { + sc = new SparkContext("local", "test") + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1,2,3,4,5,6) + val numOfCopies = 3 + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + channel.write(bbuf) + channel.close() + file.close() + + val inRdd = sc.binaryFiles(outFileName) + val mappedRdd = inRdd.map{ + curData: (String, PortableDataStream) => + (curData._2.getPath(),curData._2) + } + val copyRdd = mappedRdd.flatMap{ + curData: (String, PortableDataStream) => + for(i <- 1 to numOfCopies) yield (i,curData._2) + } + + val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() + + // Try reading the output back as an object file + assert(copyArr.length == numOfCopies) + copyArr.foreach{ + cEntry: (Int, PortableDataStream) => + assert(cEntry._2.toArray === testOutput) + } + + } + + test("fixed record length binary file as byte array") { + // a fixed length of 6 bytes + + sc = new SparkContext("local", "test") + + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1,2,3,4,5,6) + val testOutputCopies = 10 + + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + for(i <- 1 to testOutputCopies) { + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + channel.write(bbuf) + } + channel.close() + file.close() + + val inRdd = sc.binaryRecords(outFileName, testOutput.length) + // make sure there are enough elements + assert(inRdd.count == testOutputCopies) + + // now just compare the first one + val indata: Array[Byte] = inRdd.collect.head + assert(indata === testOutput) + } + + test ("negative binary record length should raise an exception") { + // a fixed length of 6 bytes + sc = new SparkContext("local", "test") + + val outFile = new File(tempDir, "record-bytestream-00000.bin") + val outFileName = outFile.getAbsolutePath() + + // create file + val testOutput = Array[Byte](1,2,3,4,5,6) + val testOutputCopies = 10 + + // write data to file + val file = new java.io.FileOutputStream(outFile) + val channel = file.getChannel + for(i <- 1 to testOutputCopies) { + val bbuf = java.nio.ByteBuffer.wrap(testOutput) + channel.write(bbuf) + } + channel.close() + file.close() + + val inRdd = 
sc.binaryRecords(outFileName, -1) + + intercept[SparkException] { + inRdd.count + } + } + test("file caching") { sc = new SparkContext("local", "test") val out = new FileWriter(tempDir + "/input")
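[Illustrative sketch, not part of the patch] The binary-file tests above exercise a single part file; a possible follow-on check in the same FileSuite style would confirm that one (path, stream) pair comes back per file when several files are present. The test name and file names below are invented:
{{{
  test("binary file input with multiple files") {
    sc = new SparkContext("local", "test")
    val payload = Array[Byte](1, 2, 3, 4, 5, 6)
    (0 until 3).foreach { i =>
      val out = new java.io.FileOutputStream(new File(tempDir, s"part-0000$i"))
      out.write(payload)
      out.close()
    }
    val results = sc.binaryFiles(tempDir.getAbsolutePath).collect().toMap
    assert(results.size === 3)
    results.values.foreach { pds => assert(pds.toArray === payload) }
  }
}}}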