@@ -19,16 +19,17 @@ package org.apache.spark.rdd
 
 import java.io.File
 
-import org.apache.commons.io.FileUtils
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.io.IntWritable
 import org.apache.hadoop.mapred._
 import org.apache.hadoop.util.Progressable
 
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
 import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
+  OutputCommitter => NewOutputCommitter,
   OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
   TaskAttemptContext => NewTaskAttempContext}
 import org.apache.spark.{SparkException, HashPartitioner, Partitioner, SharedSparkContext}
@@ -555,17 +556,33 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
   }
 
   test("assumePartitioned") {
-    val rdd: RDD[(Int, String)] = sc.parallelize(1 to 100, 10).groupBy{x => x % 10}.
-      mapPartitions({itr => itr.map{case (k,v) => k -> v.mkString(",")}}, true)
-    val path = "tmp_txt_file"
-    val f = new File(path)
-    if (f.exists) FileUtils.deleteDirectory(f)
+    val nGroups = 20
+    val nParts = 10
+    val rdd: RDD[(Int, Int)] = sc.parallelize(1 to 100, nParts).groupBy{x => x % nGroups}.
+      mapPartitions({itr =>
+        itr.flatMap{case (k,vals) =>
+          vals.map{v => k -> v}
+        }
+      },
+      true)
+    val tempDir = Utils.createTempDir()
+    val f = new File(tempDir, "assumedPartitionedSeqFile")
+    val path = f.getAbsolutePath
     rdd.saveAsSequenceFile(path)
-    val reloaded = sc.sequenceFile[Int, String](path)
+
+    // this is basically sc.sequenceFile[Int,Int], but with input splits turned off
+    val reloaded: RDD[(Int,Int)] = sc.hadoopFile(
+      path,
+      classOf[NoSplitSequenceFileInputFormat[IntWritable,IntWritable]],
+      classOf[IntWritable],
+      classOf[IntWritable],
+      nParts
+    ).map{case (k,v) => k.get() -> v.get()}
+
     val assumedPartitioned = reloaded.assumePartitionedBy(rdd.partitioner.get)
-    assumedPartitioned.collect()
+    assumedPartitioned.count() // need an action to run the verify step
 
-    val j1: RDD[(Int, (Iterable[String], Iterable[String]))] = rdd.cogroup(assumedPartitioned)
+    val j1: RDD[(Int, (Iterable[Int], Iterable[Int]))] = rdd.cogroup(assumedPartitioned)
     assert(j1.getNarrowAncestors.contains(rdd))
     assert(j1.getNarrowAncestors.contains(assumedPartitioned))
 
@@ -575,14 +592,12 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
       val rightSet = right.toSet
       if (leftSet != rightSet) throw new RuntimeException("left not equal to right")
       // and check that the groups are correct
-      leftSet.foreach{str => str.split(",").foreach{
-        x => if (x.toInt % 10 != group) throw new RuntimeException(s"bad grouping")
-      }}
+      leftSet.foreach{x => if (x % nGroups != group) throw new RuntimeException(s"bad grouping")}
     }
 
 
-    // this is just to make sure the test is actually useful, and would catch a mistake if it was *not*
-    // a narrow dependency
+    // this is just to make sure the test is actually useful, and would catch a mistake if it was
+    // *not* a narrow dependency
     val j2 = rdd.cogroup(reloaded)
     assert(!j2.getNarrowAncestors.contains(reloaded))
   }
@@ -591,9 +606,9 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     val rdd = sc.parallelize(1 to 100, 10).groupBy(identity)
     // catch wrong number of partitions immediately
     val exc1 = intercept[SparkException]{rdd.assumePartitionedBy(new HashPartitioner(5))}
-    assert(exc1.getMessage() == ("Assumed Partitioner org.apache.spark.HashPartitioner@5 expects 5 partitions, but " +
-      "there are 10 partitions. If you are assuming a partitioner on a HadoopRDD, you might need to disable input " +
-      "splits with a custom input format"))
+    assert(exc1.getMessage() == ("Assumed Partitioner org.apache.spark.HashPartitioner@5 expects" +
+      " 5 partitions, but there are 10 partitions. If you are assuming a partitioner on a" +
+      " HadoopRDD, you might need to disable input splits with a custom input format"))
 
     // also catch wrong partitioner (w/ right number of partitions) during action
     val assumedPartitioned = rdd.assumePartitionedBy(new Partitioner {
@@ -602,7 +617,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     })
     val exc2 = intercept[SparkException] {assumedPartitioned.collect()}
     assert(exc2.getMessage().contains("was not in the assumed partition. If you are assuming a" +
-      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input format"))
+      " partitioner on a HadoopRDD, you might need to disable input splits with a custom input" +
+      " format"))
   }
 
   private object StratifiedAuxiliary {
@@ -801,3 +817,7 @@ class ConfigTestFormat() extends NewFakeFormat() with Configurable {
     super.getRecordWriter(p1)
   }
 }
+
+class NoSplitSequenceFileInputFormat[K,V] extends SequenceFileInputFormat[K,V] {
+  override def isSplitable(fs: FileSystem, file: Path) = false
+}
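
A minimal sketch of the usage pattern this test exercises, assuming the assumePartitionedBy API and the NoSplitSequenceFileInputFormat helper introduced by this patch; the path, key/value types, and explicit HashPartitioner below are illustrative only, not taken from the test itself.

import org.apache.hadoop.io.IntWritable
import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.RDD

// 1. Write out an RDD that already has a known partitioner.
val partitioned: RDD[(Int, Int)] = sc.parallelize(1 to 100, 10)
  .map(x => (x % 20, x))
  .partitionBy(new HashPartitioner(10))
partitioned.saveAsSequenceFile("/tmp/partitioned-seq-file")  // hypothetical path

// 2. Reload it with input splits disabled, so on-disk partition i maps to in-memory partition i.
val reloaded: RDD[(Int, Int)] = sc.hadoopFile(
  "/tmp/partitioned-seq-file",
  classOf[NoSplitSequenceFileInputFormat[IntWritable, IntWritable]],
  classOf[IntWritable],
  classOf[IntWritable],
  10
).map { case (k, v) => k.get() -> v.get() }

// 3. Re-assert the original partitioner; cogroups/joins against RDDs with the same
//    partitioner can then be planned as narrow dependencies (no shuffle), which is
//    what the getNarrowAncestors assertions above verify.
val assumed = reloaded.assumePartitionedBy(partitioned.partitioner.get)
val cogrouped = partitioned.cogroup(assumed)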