55 changes: 55 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/AssumedPartitionedRDD.scala
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.spark.{Partition, Partitioner, SparkException, TaskContext}

private[spark] class AssumedPartitionedRDD[K: ClassTag, V: ClassTag](
    parent: RDD[(K, V)],
    part: Partitioner,
    val verify: Boolean
  ) extends RDD[(K, V)](parent) {

  override val partitioner = Some(part)

  override def getPartitions: Array[Partition] = firstParent[(K, V)].partitions

  if (verify && getPartitions.size != part.numPartitions) {
    throw new SparkException(s"Assumed Partitioner $part expects ${part.numPartitions} " +
      s"partitions, but there are ${getPartitions.size} partitions. If you are assuming a" +
      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input" +
      " format")
  }

  override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
    if (verify) {
      firstParent[(K, V)].iterator(split, context).map { case (k, v) =>
        if (partitioner.get.getPartition(k) != split.index) {
          throw new SparkException(s"key $k in split ${split.index} was not in the assumed " +
            "partition. If you are assuming a partitioner on a HadoopRDD, you might need to " +
            "disable input splits with a custom input format")
        }
        (k, v)
      }
    } else {
      firstParent[(K, V)].iterator(split, context)
    }
  }

}
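A minimal usage sketch (not part of this patch) of how the two checks behave: the partition-count check above runs eagerly when the RDD is constructed, while the per-record check only runs inside compute(), so a wrong assumption with the right number of partitions only surfaces once an action is invoked. This assumes a live SparkContext named sc and the usual pair-RDD implicits in scope:

import org.apache.spark.HashPartitioner

val pairs = sc.parallelize(1 to 8, 4).map(x => x -> x)           // 4 partitions, but not hash-partitioned
val assumed = pairs.assumePartitionedBy(new HashPartitioner(4))  // passes the eager partition-count check
assumed.count()                                                  // per-record check throws SparkException here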
18 changes: 17 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -200,7 +200,10 @@ class HadoopRDD[K, V](
    if (inputFormat.isInstanceOf[Configurable]) {
      inputFormat.asInstanceOf[Configurable].setConf(jobConf)
    }
    val inputSplits = inputFormat.getSplits(jobConf, minPartitions)
    // We have to sort the input splits here so that part-00000 goes to partition 0, etc. This
    // lets us use the same partitioner after we save an RDD to HDFS and then read it back
    // (SPARK-1061).
    val inputSplits = inputFormat.getSplits(jobConf, minPartitions).sorted(SplitOrdering)
    val array = new Array[Partition](inputSplits.size)
    for (i <- 0 until inputSplits.size) {
      array(i) = new HadoopPartition(id, i, inputSplits(i))
@@ -416,3 +419,16 @@ private[spark] object HadoopRDD extends Logging {
    out.seq
  }
}

private[spark] object SplitOrdering extends Ordering[InputSplit] {
  def compare(x: InputSplit, y: InputSplit): Int = {
    (x, y) match {
      case (xSplit: FileSplit, ySplit: FileSplit) =>
        fileSplitOrdering.compare(xSplit, ySplit)
      case _ => 1
    }
  }

  val fileSplitOrdering: Ordering[FileSplit] = Ordering.by { fileSplit =>
    (fileSplit.getPath.toString, fileSplit.getStart) }
}
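For illustration (hypothetical paths, not part of the patch): SplitOrdering orders FileSplits by the full path string and then by start offset, so Hadoop's zero-padded part files come out in numeric order; anything that is not a FileSplit falls through to the case _ => 1 branch and is left effectively unordered.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.FileSplit

val splits = Seq(
  new FileSplit(new Path("/data/part-00001"), 0L, 64L, Array.empty[String]),
  new FileSplit(new Path("/data/part-00000"), 0L, 64L, Array.empty[String]))
splits.sorted(SplitOrdering).map(_.getPath.getName)  // Seq(part-00000, part-00001)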
17 changes: 17 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1115,6 +1115,23 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
   */
  def values: RDD[V] = self.map(_._2)

  /**
   * Return an RDD that assumes the data in this RDD has already been partitioned by the given
   * partitioner. This creates a narrow dependency on this RDD, so combining it with other RDDs
   * that use the same partitioner (e.g. via cogroup or join) can avoid a shuffle. It is
   * especially useful when loading an RDD from HDFS that was partitioned before it was saved;
   * note that you must then ensure input splits are disabled with a custom FileInputFormat.
   *
   * If verify == true, every record is checked against the partitioner; if any record is not in
   * the expected partition, an exception is thrown when the RDD is computed.
   *
   * If verify == false and the RDD is not partitioned correctly, the behavior is undefined:
   * there may be a runtime error, or simply wildly inaccurate results with no warning.
   */
  def assumePartitionedBy(partitioner: Partitioner, verify: Boolean = true): RDD[(K, V)] = {
    new AssumedPartitionedRDD[K, V](self, partitioner, verify)
  }

  private[spark] def keyClass: Class[_] = kt.runtimeClass

  private[spark] def valueClass: Class[_] = vt.runtimeClass
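As a usage sketch for assumePartitionedBy (illustrative only, mirroring the test added below: the HDFS path is hypothetical, sc is a live SparkContext, the usual Spark and Hadoop imports are assumed, and NoSplitSequenceFileInputFormat is the helper defined at the end of PairRDDFunctionsSuite in this diff):

val part = new HashPartitioner(10)
val saved = sc.parallelize(1 to 100).map(x => x -> x).partitionBy(part)
saved.saveAsSequenceFile("hdfs:///tmp/partitioned")

// Re-read with input splits disabled, so part-00000 maps back to partition 0, part-00001 to 1, ...
val reloaded = sc.hadoopFile(
  "hdfs:///tmp/partitioned",
  classOf[NoSplitSequenceFileInputFormat[IntWritable, IntWritable]],
  classOf[IntWritable],
  classOf[IntWritable],
  10).map { case (k, v) => k.get() -> v.get() }

// Declare the partitioning the data already has: the join below becomes a narrow
// dependency on both sides instead of a shuffle.
val assumed = reloaded.assumePartitionedBy(part)
val joined = saved.join(assumed)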
38 changes: 38 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/HadoopRDDSuite.scala
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rdd

import java.util

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.FileSplit
import org.scalatest.{FunSuite, Matchers}

class HadoopRDDSuite extends FunSuite with Matchers {
  test("file split ordering") {
    val splits = (0 until 10).map { idx =>
      new FileSplit(new Path("/foo/bar/part-0000" + idx), 0L, 0L, Array[String]()) }

    val javaShuffledSplits = new util.ArrayList[FileSplit]()
    splits.foreach { s => javaShuffledSplits.add(s) }
    java.util.Collections.shuffle(javaShuffledSplits)
    val scalaShuffledSplits = javaShuffledSplits.asScala
    scalaShuffledSplits.sorted(SplitOrdering) should be (splits)
  }
}
core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,18 +17,22 @@

package org.apache.spark.rdd

import org.apache.hadoop.fs.FileSystem
import java.io.File

import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.mapred._
import org.apache.hadoop.util.Progressable

import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random

import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
OutputCommitter => NewOutputCommitter,
OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
TaskAttemptContext => NewTaskAttempContext}
import org.apache.spark.{Partitioner, SharedSparkContext}
import org.apache.spark.{SparkException, HashPartitioner, Partitioner, SharedSparkContext}
import org.apache.spark.util.Utils

import org.scalatest.FunSuite
@@ -551,6 +555,68 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
    intercept[IllegalArgumentException] {shuffled.lookup(-1)}
  }

  test("assumePartitioned") {
    val nGroups = 20
    val nParts = 10
    val rdd: RDD[(Int, Int)] = sc.parallelize(1 to 100, nParts).map { x => (x % nGroups) -> x }.
      partitionBy(new HashPartitioner(nParts))
    val tempDir = Utils.createTempDir()
    val f = new File(tempDir, "assumedPartitionedSeqFile")
    val path = f.getAbsolutePath
    rdd.saveAsSequenceFile(path)

    // this is basically sc.sequenceFile[Int, Int], but with input splits turned off
    val reloaded: RDD[(Int, Int)] = sc.hadoopFile(
      path,
      classOf[NoSplitSequenceFileInputFormat[IntWritable, IntWritable]],
      classOf[IntWritable],
      classOf[IntWritable],
      nParts
    ).map { case (k, v) => k.get() -> v.get() }

    val assumedPartitioned = reloaded.assumePartitionedBy(rdd.partitioner.get)
    assumedPartitioned.count() // need an action to run the verify step

    val j1: RDD[(Int, (Iterable[Int], Iterable[Int]))] = rdd.cogroup(assumedPartitioned)
    assert(j1.getNarrowAncestors.contains(rdd))
    assert(j1.getNarrowAncestors.contains(assumedPartitioned))

    j1.foreach { case (group, (left, right)) =>
      // check that we've got the same groups in both RDDs
      val leftSet = left.toSet
      val rightSet = right.toSet
      if (leftSet != rightSet) throw new RuntimeException("left not equal to right")
      // and check that the groups are correct
      leftSet.foreach { x => if (x % nGroups != group) throw new RuntimeException("bad grouping") }
    }

    // this is just to make sure the test is actually useful, and would catch a mistake if it was
    // *not* a narrow dependency
    val j2 = rdd.cogroup(reloaded)
    assert(!j2.getNarrowAncestors.contains(reloaded))
  }

  test("assumePartitioned -- verify works") {
    val rdd = sc.parallelize(1 to 100, 10).groupBy(identity)
    // catch wrong number of partitions immediately
    val exc1 = intercept[SparkException] {rdd.assumePartitionedBy(new HashPartitioner(5))}
    assert(exc1.getMessage() == ("Assumed Partitioner org.apache.spark.HashPartitioner@5 expects" +
      " 5 partitions, but there are 10 partitions. If you are assuming a partitioner on a" +
      " HadoopRDD, you might need to disable input splits with a custom input format"))

    // also catch wrong partitioner (w/ right number of partitions) during action
    val assumedPartitioned = rdd.assumePartitionedBy(new Partitioner {
      override def numPartitions: Int = 10
      override def getPartition(key: Any): Int = 3
    })
    val exc2 = intercept[SparkException] {assumedPartitioned.collect()}
    assert(exc2.getMessage().contains(" was not in the assumed partition. If you are assuming a" +
      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input" +
      " format"))
  }

  private object StratifiedAuxiliary {
    def stratifier (fractionPositive: Double) = {
      (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
@@ -747,3 +813,7 @@ class ConfigTestFormat() extends NewFakeFormat() with Configurable {
    super.getRecordWriter(p1)
  }
}

class NoSplitSequenceFileInputFormat[K, V] extends SequenceFileInputFormat[K, V] {
  override def isSplitable(fs: FileSystem, file: Path) = false
}