
Commit 0e98abe

be sure to turn off input splits in test
1 parent 943984f commit 0e98abe
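
Why: the assumePartitioned test previously reloaded the saved SequenceFile with sc.sequenceFile, which lets Hadoop compute its own input splits. If the split boundaries differ from the saved partition boundaries, the reloaded RDD no longer lines up partition-for-partition with the original, and the assumed partitioner is wrong. The test now reads through a custom input format whose isSplitable returns false, so each saved part-file maps back to exactly one partition.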

File tree

1 file changed (+40 −20 lines)


core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 40 additions & 20 deletions
@@ -19,16 +19,17 @@ package org.apache.spark.rdd
 
 import java.io.File
 
-import org.apache.commons.io.FileUtils
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.io.IntWritable
 import org.apache.hadoop.mapred._
 import org.apache.hadoop.util.Progressable
 
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
 import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
+  OutputCommitter => NewOutputCommitter,
   OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
   TaskAttemptContext => NewTaskAttempContext}
 import org.apache.spark.{SparkException, HashPartitioner, Partitioner, SharedSparkContext}
@@ -555,17 +556,33 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
   }
 
   test("assumePartitioned") {
-    val rdd: RDD[(Int, String)] = sc.parallelize(1 to 100, 10).groupBy{x => x % 10}.
-      mapPartitions({itr => itr.map{case(k,v) => k -> v.mkString(",")}}, true)
-    val path = "tmp_txt_file"
-    val f = new File(path)
-    if (f.exists) FileUtils.deleteDirectory(f)
+    val nGroups = 20
+    val nParts = 10
+    val rdd: RDD[(Int, Int)] = sc.parallelize(1 to 100, nParts).groupBy{x => x % nGroups}.
+      mapPartitions({itr =>
+        itr.flatMap{case(k,vals) =>
+          vals.map{v => k -> v}
+        }
+      },
+      true)
+    val tempDir = Utils.createTempDir()
+    val f = new File(tempDir, "assumedPartitionedSeqFile")
+    val path = f.getAbsolutePath
     rdd.saveAsSequenceFile(path)
-    val reloaded = sc.sequenceFile[Int, String](path)
+
+    // this is basically sc.sequenceFile[Int,Int], but with input splits turned off
+    val reloaded: RDD[(Int,Int)] = sc.hadoopFile(
+      path,
+      classOf[NoSplitSequenceFileInputFormat[IntWritable,IntWritable]],
+      classOf[IntWritable],
+      classOf[IntWritable],
+      nParts
+    ).map{case(k,v) => k.get() -> v.get()}
+
     val assumedPartitioned = reloaded.assumePartitionedBy(rdd.partitioner.get)
-    assumedPartitioned.collect()
+    assumedPartitioned.count() // need an action to run the verify step
 
-    val j1: RDD[(Int, (Iterable[String], Iterable[String]))] = rdd.cogroup(assumedPartitioned)
+    val j1: RDD[(Int, (Iterable[Int], Iterable[Int]))] = rdd.cogroup(assumedPartitioned)
     assert(j1.getNarrowAncestors.contains(rdd))
     assert(j1.getNarrowAncestors.contains(assumedPartitioned))
 
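For reference, here is the pattern this test exercises, distilled into a standalone sketch. It assumes the assumePartitionedBy API introduced on this branch and the NoSplitSequenceFileInputFormat helper added at the end of this diff; the output path is hypothetical.

  import org.apache.hadoop.io.IntWritable
  import org.apache.spark.HashPartitioner

  val path = "/tmp/partitioned-seq-file"  // hypothetical output location
  // Save a partitioned RDD, reload it with input splits disabled so each saved
  // part-file comes back as exactly one partition, then re-assert the original
  // partitioner so downstream cogroups/joins stay narrow.
  val saved = sc.parallelize(1 to 100).map(x => x -> x).partitionBy(new HashPartitioner(10))
  saved.saveAsSequenceFile(path)
  val reloaded = sc.hadoopFile(
      path,
      classOf[NoSplitSequenceFileInputFormat[IntWritable, IntWritable]],
      classOf[IntWritable],
      classOf[IntWritable],
      10  // minPartitions: one per saved part-file
    ).map { case (k, v) => k.get() -> v.get() }
    .assumePartitionedBy(saved.partitioner.get)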
@@ -575,14 +592,12 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
       val rightSet = right.toSet
       if (leftSet != rightSet) throw new RuntimeException("left not equal to right")
       // and check that the groups are correct
-      leftSet.foreach{str => str.split(",").foreach{
-        x => if (x.toInt % 10 != group) throw new RuntimeException(s"bad grouping")
-      }}
+      leftSet.foreach{x => if (x % nGroups != group) throw new RuntimeException(s"bad grouping")}
     }
 
 
-    // this is just to make sure the test is actually useful, and would catch a mistake if it was *not*
-    // a narrow dependency
+    // this is just to make sure the test is actually useful, and would catch a mistake if it was
+    // *not* a narrow dependency
     val j2 = rdd.cogroup(reloaded)
     assert(!j2.getNarrowAncestors.contains(reloaded))
   }
@@ -591,9 +606,9 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     val rdd = sc.parallelize(1 to 100, 10).groupBy(identity)
     // catch wrong number of partitions immediately
     val exc1 = intercept[SparkException]{rdd.assumePartitionedBy(new HashPartitioner(5))}
-    assert(exc1.getMessage() == ("Assumed Partitioner org.apache.spark.HashPartitioner@5 expects 5 partitions, but" +
-      " there are 10 partitions. If you are assuming a partitioner on a HadoopRDD, you might need to disable input" +
-      " splits with a custom input format"))
+    assert(exc1.getMessage() == ("Assumed Partitioner org.apache.spark.HashPartitioner@5 expects" +
+      " 5 partitions, but there are 10 partitions. If you are assuming a partitioner on a" +
+      " HadoopRDD, you might need to disable input splits with a custom input format"))
 
     // also catch wrong partitioner (w/ right number of partitions) during action
     val assumedPartitioned = rdd.assumePartitionedBy(new Partitioner {
@@ -602,7 +617,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     })
     val exc2 = intercept[SparkException] {assumedPartitioned.collect()}
     assert(exc2.getMessage().contains(" was not in the assumed partition. If you are assuming a" +
-      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input format"))
+      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input" +
+      " format"))
   }
 
   private object StratifiedAuxiliary {
@@ -801,3 +817,7 @@ class ConfigTestFormat() extends NewFakeFormat() with Configurable {
     super.getRecordWriter(p1)
   }
 }
+
+class NoSplitSequenceFileInputFormat[K,V] extends SequenceFileInputFormat[K,V] {
+  override def isSplitable(fs: FileSystem, file: Path) = false
+}
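
The helper above overrides isSplitable on the old (mapred) API, which is what sc.hadoopFile expects. For reads through the new (mapreduce) API via sc.newAPIHadoopFile, a hypothetical counterpart (not part of this commit) would look like this:

  import org.apache.hadoop.fs.Path
  import org.apache.hadoop.mapreduce.JobContext
  import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat

  // Hypothetical new-API analogue of NoSplitSequenceFileInputFormat; the commit
  // itself only adds the mapred version used by sc.hadoopFile.
  class NoSplitNewApiSequenceFileInputFormat[K, V] extends SequenceFileInputFormat[K, V] {
    override def isSplitable(context: JobContext, file: Path): Boolean = false
  }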
