Skip to content

Commit 3a9d82c

Browse files
ash211 authored and pwendell committed
Merge pull request apache#506 from ash211/intersection. Closes apache#506.
SPARK-1062 Add rdd.intersection(otherRdd) method

Author: Andrew Ash <[email protected]>

== Merge branch commits ==

commit 5d9982b171b9572649e9828f37ef0b43f0242912
Author: Andrew Ash <[email protected]>
Date:   Thu Feb 6 18:11:45 2014 -0800

    Minor fixes
    - style: (v,null) => (v, null)
    - mention the shuffle in Javadoc

commit b86d02f14e810902719cef893cf6bfa18ff9acb0
Author: Andrew Ash <[email protected]>
Date:   Sun Feb 2 13:17:40 2014 -0800

    Overload .intersection() for numPartitions and custom Partitioner

commit bcaa34911fcc6bb5bc5e4f9fe46d1df73cb71c09
Author: Andrew Ash <[email protected]>
Date:   Sun Feb 2 13:05:40 2014 -0800

    Better naming of parameters in intersection's filter

commit b10a6af2d793ec6e9a06c798007fac3f6b860d89
Author: Andrew Ash <[email protected]>
Date:   Sat Jan 25 23:06:26 2014 -0800

    Follow spark code format conventions of tab => 2 spaces

commit 965256e4304cca514bb36a1a36087711dec535ec
Author: Andrew Ash <[email protected]>
Date:   Fri Jan 24 00:28:01 2014 -0800

    Add rdd.intersection(otherRdd) method
1 parent 1896c6e commit 3a9d82c

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,43 @@ abstract class RDD[T: ClassTag](
393393
*/
394394
def ++(other: RDD[T]): RDD[T] = this.union(other)
395395

396+
/**
 * Return the intersection of this RDD and another one. The output will not contain any duplicate
 * elements, even if the input RDDs did.
 *
 * Note that this method performs a shuffle internally.
 */
def intersection(other: RDD[T]): RDD[T] = {
  // Key every element by itself (value is irrelevant) so cogroup can match them up.
  val keyedThis = this.map(elem => (elem, null))
  val keyedOther = other.map(elem => (elem, null))
  // Keep only keys present on both sides; cogroup also deduplicates the keys.
  keyedThis.cogroup(keyedOther)
    .filter { case (_, (inThis, inOther)) => inThis.nonEmpty && inOther.nonEmpty }
    .keys
}
406+
407+
/**
 * Return the intersection of this RDD and another one. The output will not contain any duplicate
 * elements, even if the input RDDs did.
 *
 * Note that this method performs a shuffle internally.
 *
 * @param partitioner Partitioner to use for the resulting RDD
 */
def intersection(other: RDD[T], partitioner: Partitioner): RDD[T] = {
  // Key every element by itself (value is irrelevant) so cogroup can match them up.
  val keyedThis = this.map(elem => (elem, null))
  val keyedOther = other.map(elem => (elem, null))
  // Cogroup with the caller-supplied partitioner, then keep keys present on both sides.
  keyedThis.cogroup(keyedOther, partitioner)
    .filter { case (_, (inThis, inOther)) => inThis.nonEmpty && inOther.nonEmpty }
    .keys
}
419+
420+
/**
 * Return the intersection of this RDD and another one. The output will not contain any duplicate
 * elements, even if the input RDDs did. Performs a hash partition across the cluster.
 *
 * Note that this method performs a shuffle internally.
 *
 * @param numPartitions How many partitions to use in the resulting RDD
 */
def intersection(other: RDD[T], numPartitions: Int): RDD[T] = {
  // Key every element by itself (value is irrelevant) so cogroup can match them up.
  val keyedThis = this.map(elem => (elem, null))
  val keyedOther = other.map(elem => (elem, null))
  // Hash-partition into the requested number of partitions, then keep shared keys.
  keyedThis.cogroup(keyedOther, new HashPartitioner(numPartitions))
    .filter { case (_, (inThis, inOther)) => inThis.nonEmpty && inOther.nonEmpty }
    .keys
}
432+
396433
/**
397434
* Return an RDD created by coalescing all elements within each partition into an array.
398435
*/

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
373373
val prng42 = new Random(42)
374374
val prng43 = new Random(43)
375375
Array(1, 2, 3, 4, 5, 6).filter{i =>
376-
if (i < 4) 0 == prng42.nextInt(3)
377-
else 0 == prng43.nextInt(3)}
376+
if (i < 4) 0 == prng42.nextInt(3)
377+
else 0 == prng43.nextInt(3)}
378378
}
379379
assert(sample.size === checkSample.size)
380380
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
@@ -506,4 +506,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
506506
sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
507507
}
508508
}
509+
510+
test("intersection") {
  val fullRange = sc.parallelize(1 to 10)
  val evenNumbers = sc.parallelize(2 to 10 by 2)
  val expected = Array(2, 4, 6, 8, 10)

  // The operation must be commutative: a.intersection(b) == b.intersection(a).
  assert(fullRange.intersection(evenNumbers).collect.sorted === expected)
  assert(evenNumbers.intersection(fullRange).collect.sorted === expected)
}
519+
520+
test("intersection strips duplicates in an input") {
  // Duplicates on either side must not show up in the result.
  val withRepeats = sc.parallelize(Seq(1, 2, 3, 3))
  val moreRepeats = sc.parallelize(Seq(1, 1, 2, 3))
  val expected = Array(1, 2, 3)

  assert(withRepeats.intersection(moreRepeats).collect.sorted === expected)
  assert(moreRepeats.intersection(withRepeats).collect.sorted === expected)
}
509528
}

0 commit comments

Comments (0)