diff --git a/core/src/main/scala/org/apache/spark/rdd/AssumedPartitionedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/AssumedPartitionedRDD.scala
new file mode 100644
index 0000000000000..5a311fb8aaa77
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/AssumedPartitionedRDD.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rdd
+
+import org.apache.spark.{SparkException, TaskContext, Partition, Partitioner}
+
+import scala.reflect.ClassTag
+
+private[spark] class AssumedPartitionedRDD[K: ClassTag, V: ClassTag](
+    parent: RDD[(K,V)],
+    part: Partitioner,
+    val verify: Boolean
+  ) extends RDD[(K,V)](parent) {
+
+  override val partitioner = Some(part)
+
+  override def getPartitions: Array[Partition] = firstParent[(K,V)].partitions
+
+  if (verify && getPartitions.size != part.numPartitions) {
+    throw new SparkException(s"Assumed Partitioner $part expects ${part.numPartitions} " +
+      s"partitions, but there are ${getPartitions.size} partitions. If you are assuming a" +
+      s" partitioner on a HadoopRDD, you might need to disable input splits with a custom input" +
+      s" format")
+  }
+
+  override def compute(split: Partition, context: TaskContext) = {
+    if (verify) {
+      firstParent[(K,V)].iterator(split, context).map{ case(k,v) =>
+        if (partitioner.get.getPartition(k) != split.index) {
+          throw new SparkException(s"key $k in split ${split.index} was not in the assumed " +
+            s"partition. If you are assuming a partitioner on a HadoopRDD, you might need to " +
+            s"disable input splits with a custom input format")
+        }
+        (k,v)
+      }
+    } else {
+      firstParent[(K,V)].iterator(split, context)
+    }
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 486e86ce1bb19..d646a4b6699a3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -200,7 +200,10 @@ class HadoopRDD[K, V](
     if (inputFormat.isInstanceOf[Configurable]) {
       inputFormat.asInstanceOf[Configurable].setConf(jobConf)
     }
-    val inputSplits = inputFormat.getSplits(jobConf, minPartitions)
+    // we have to sort the splits here so that part-0000 goes to partition 0, etc. This is
+    // so we can use the same partitioner after we save an RDD to HDFS and then read it back.
+    // SPARK-1061
+    val inputSplits = inputFormat.getSplits(jobConf, minPartitions).sorted(SplitOrdering)
     val array = new Array[Partition](inputSplits.size)
     for (i <- 0 until inputSplits.size) {
       array(i) = new HadoopPartition(id, i, inputSplits(i))
@@ -416,3 +419,16 @@ private[spark] object HadoopRDD extends Logging {
     out.seq
   }
 }
+
+private[spark] object SplitOrdering extends Ordering[InputSplit] {
+  def compare(x: InputSplit, y: InputSplit): Int = {
+    (x, y) match {
+      case (xSplit: FileSplit, ySplit: FileSplit) =>
+        fileSplitOrdering.compare(xSplit, ySplit)
+      case _ => 1
+    }
+  }
+
+  val fileSplitOrdering: Ordering[FileSplit] = Ordering.by{fileSplit =>
+    (fileSplit.getPath.toString, fileSplit.getStart)}
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 955b42c3baaa1..8b1bcb6035e72 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1115,6 +1115,23 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    */
   def values: RDD[V] = self.map(_._2)
 
+  /**
+   * Return an RDD that assumes the data in this RDD is already partitioned by the given
+   * Partitioner, creating a narrow dependency and avoiding a shuffle. This is especially useful
+   * when loading an RDD from HDFS that was partitioned before it was saved. Note that when
+   * loading a file from HDFS, you must ensure that input splits are disabled with a custom
+   * FileInputFormat, so that each part file maps to exactly one partition.
+   *
+   * If verify == true, every record is checked against the Partitioner. If any record is not
+   * in the correct partition, an exception is thrown when the RDD is computed.
+   *
+   * If verify == false and the RDD is not partitioned correctly, the behavior is undefined. There
+   * may be a runtime error, or there may simply be wildly inaccurate results with no warning.
+   */
+  def assumePartitionedBy(partitioner: Partitioner, verify: Boolean = true): RDD[(K,V)] = {
+    new AssumedPartitionedRDD[K,V](self, partitioner, verify)
+  }
+
   private[spark] def keyClass: Class[_] = kt.runtimeClass
 
   private[spark] def valueClass: Class[_] = vt.runtimeClass
diff --git a/core/src/test/scala/org/apache/spark/rdd/HadoopRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/HadoopRDDSuite.scala
new file mode 100644
index 0000000000000..bc88453efdb29
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/HadoopRDDSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rdd
+
+import java.util
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.FileSplit
+import org.scalatest.{Matchers, FunSuite}
+
+import scala.collection.JavaConverters._
+
+class HadoopRDDSuite extends FunSuite with Matchers {
+  test("file split ordering") {
+    val splits = (0 until 10).map{idx =>
+      new FileSplit(new Path("/foo/bar/part-0000" + idx), 0L, 0L, Array[String]())}
+
+    val javaShuffledSplits = new util.ArrayList[FileSplit]()
+    splits.foreach{s => javaShuffledSplits.add(s)}
+    java.util.Collections.shuffle(javaShuffledSplits)
+    val scalaShuffledSplits = javaShuffledSplits.asScala
+    scalaShuffledSplits.sorted(SplitOrdering) should be (splits)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 108f70af43f37..043a2da26acf7 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.rdd
 
-import org.apache.hadoop.fs.FileSystem
+import java.io.File
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.io.IntWritable
 import org.apache.hadoop.mapred._
 import org.apache.hadoop.util.Progressable
 
@@ -25,10 +28,11 @@ import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
 import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
+OutputCommitter => NewOutputCommitter,
 OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
 TaskAttemptContext => NewTaskAttempContext}
-import org.apache.spark.{Partitioner, SharedSparkContext}
+import org.apache.spark.{SparkException, HashPartitioner, Partitioner, SharedSparkContext}
 import org.apache.spark.util.Utils
 
 import org.scalatest.FunSuite
@@ -551,6 +555,68 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     intercept[IllegalArgumentException] {shuffled.lookup(-1)}
   }
 
+  test("assumePartitioned") {
+    val nGroups = 20
+    val nParts = 10
+    val rdd: RDD[(Int, Int)] = sc.parallelize(1 to 100, nParts).map{x => (x % nGroups) -> x}.
+      partitionBy(new HashPartitioner(nParts))
+    val tempDir = Utils.createTempDir()
+    val f = new File(tempDir, "assumedPartitionedSeqFile")
+    val path = f.getAbsolutePath
+    rdd.saveAsSequenceFile(path)
+
+    // this is basically sc.sequenceFile[Int,Int], but with input splits turned off
+    val reloaded: RDD[(Int,Int)] = sc.hadoopFile(
+      path,
+      classOf[NoSplitSequenceFileInputFormat[IntWritable,IntWritable]],
+      classOf[IntWritable],
+      classOf[IntWritable],
+      nParts
+    ).map{case(k,v) => k.get() -> v.get()}
+
+
+    val assumedPartitioned = reloaded.assumePartitionedBy(rdd.partitioner.get)
+    assumedPartitioned.count() // need an action to run the verify step
+
+    val j1: RDD[(Int, (Iterable[Int], Iterable[Int]))] = rdd.cogroup(assumedPartitioned)
+    assert(j1.getNarrowAncestors.contains(rdd))
+    assert(j1.getNarrowAncestors.contains(assumedPartitioned))
+
+    j1.foreach{case(group, (left, right)) =>
+      // check that we've got the same groups in both RDDs
+      val leftSet = left.toSet
+      val rightSet = right.toSet
+      if (leftSet != rightSet) throw new RuntimeException("left not equal to right")
+      // and check that the groups are correct
+      leftSet.foreach{x => if (x % nGroups != group) throw new RuntimeException("bad grouping")}
+    }
+
+
+    // this is just to make sure the test is actually useful, and would catch a mistake if it was
+    // *not* a narrow dependency
+    val j2 = rdd.cogroup(reloaded)
+    assert(!j2.getNarrowAncestors.contains(reloaded))
+  }
+
+  test("assumePartitioned -- verify works") {
+    val rdd = sc.parallelize(1 to 100, 10).groupBy(identity)
+    // catch wrong number of partitions immediately
+    val exc1 = intercept[SparkException]{rdd.assumePartitionedBy(new HashPartitioner(5))}
+    assert(exc1.getMessage() == ("Assumed Partitioner org.apache.spark.HashPartitioner@5 expects" +
+      " 5 partitions, but there are 10 partitions. If you are assuming a partitioner on a" +
+      " HadoopRDD, you might need to disable input splits with a custom input format"))
+
+    // also catch wrong partitioner (w/ right number of partitions) during action
+    val assumedPartitioned = rdd.assumePartitionedBy(new Partitioner {
+      override def numPartitions: Int = 10
+      override def getPartition(key: Any): Int = 3
+    })
+    val exc2 = intercept[SparkException] {assumedPartitioned.collect()}
+    assert(exc2.getMessage().contains(" was not in the assumed partition. If you are assuming a" +
+      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input" +
+      " format"))
+  }
+
   private object StratifiedAuxiliary {
     def stratifier (fractionPositive: Double) = {
       (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
@@ -747,3 +813,7 @@ class ConfigTestFormat() extends NewFakeFormat() with Configurable {
     super.getRecordWriter(p1)
   }
 }
+
+class NoSplitSequenceFileInputFormat[K,V] extends SequenceFileInputFormat[K,V] {
+  override def isSplitable(fs: FileSystem, file: Path) = false
+}
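
For reference, a minimal end-to-end sketch of how the API added by this patch is meant to be used, mirroring the assumePartitioned test above. It assumes a running SparkContext named sc (e.g. in spark-shell), an illustrative HDFS path, and a non-splittable input format such as the NoSplitSequenceFileInputFormat helper defined at the end of PairRDDFunctionsSuite.scala (in user code you would supply your own equivalent):

import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.hadoop.io.IntWritable

// Hash-partition a pair RDD and save it; each partition is written to one part file.
val partitioner = new HashPartitioner(10)
val original: RDD[(Int, Int)] = sc.parallelize(1 to 1000).map(x => (x, x * x))
  .partitionBy(partitioner)
original.saveAsSequenceFile("hdfs:///tmp/partitioned-data")

// Reload with input splits disabled, so that part-00000 maps back to partition 0, etc.
val reloaded: RDD[(Int, Int)] = sc.hadoopFile(
  "hdfs:///tmp/partitioned-data",
  classOf[NoSplitSequenceFileInputFormat[IntWritable, IntWritable]],
  classOf[IntWritable],
  classOf[IntWritable],
  10
).map { case (k, v) => k.get() -> v.get() }

// Declare that the reloaded data still follows the original partitioner; with the default
// verify = true, every key is checked against the partitioner when the RDD is first computed.
val assumed = reloaded.assumePartitionedBy(partitioner)

// Joins and cogroups against RDDs that use the same partitioner now get narrow dependencies
// instead of a shuffle.
val joined = original.join(assumed)

Since verify = true touches every record once, it is mainly a sanity check; once a pipeline is trusted, assumePartitionedBy(partitioner, verify = false) skips the per-record check, at the cost of undefined behavior if the assumption turns out to be wrong.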