diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 14fa9d8135afe..4f3081433a542 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
       partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
     fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3],
+      partitioner: Partitioner)
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner)))
+
   /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
@@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
     fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3])
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3)))
+
   /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
@@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
     fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3],
+      numPartitions: Int)
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions)))
+
   /** Alias for cogroup. */
   def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
     fromRDD(cogroupResultToJava(rdd.groupWith(other)))
@@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
     fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
 
+  /** Alias for cogroup. */
+  def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1],
+      other2: JavaPairRDD[K, W2],
+      other3: JavaPairRDD[K, W3])
+  : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
+    fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3)))
+
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
    * RDD has a known partitioner by only searching the partition that the key maps to.
@@ -786,6 +828,15 @@ object JavaPairRDD {
       .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
   }
 
+  private[spark]
+  def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3](
+      rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))])
+  : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = {
+    rddToPairRDDFunctions(rdd)
+      .mapValues(x =>
+        (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4)))
+  }
+
   def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
     new JavaPairRDD[K, V](rdd)
   }
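The Java-facing overloads above mirror the existing two- and three-RDD variants, with `cogroupResult3ToJava` converting each Scala `Iterable` in the result tuple to a `java.lang.Iterable`. A minimal usage sketch, written in Scala against the Java API through the public `fromRDD` shown above (assumes a live `SparkContext` named `sc`; the data mirrors the new tests):

    import org.apache.spark.api.java.JavaPairRDD

    // Wrap plain pair RDDs in the Java API, then cogroup all four in one call.
    val categories = JavaPairRDD.fromRDD(sc.parallelize(Seq(("Oranges", "Fruit"), ("Oranges", "Citrus"))))
    val prices     = JavaPairRDD.fromRDD(sc.parallelize(Seq(("Oranges", 2))))
    val quantities = JavaPairRDD.fromRDD(sc.parallelize(Seq(("Oranges", 21))))
    val countries  = JavaPairRDD.fromRDD(sc.parallelize(Seq(("Oranges", "BR"))))

    // Value type: (JIterable[String], JIterable[Int], JIterable[Int], JIterable[String])
    val grouped = categories.cogroup(prices, quantities, countries)
    grouped.lookup("Oranges").get(0)._1  // java.lang.Iterable over Fruit, Citrus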
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index bff77b4ecbf27..92d2e24a68834 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     new FlatMappedValuesRDD(self, cleanF)
   }
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
+      other2: RDD[(K, W2)],
+      other3: RDD[(K, W3)],
+      partitioner: Partitioner)
+  : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
+      throw new SparkException("Default partitioner cannot partition array keys.")
+    }
+    val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
+    cg.mapValues { case Seq(vs, w1s, w2s, w3s) =>
+      (vs.asInstanceOf[Seq[V]],
+        w1s.asInstanceOf[Seq[W1]],
+        w2s.asInstanceOf[Seq[W2]],
+        w3s.asInstanceOf[Seq[W3]])
+    }
+  }
+
   /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
@@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     }
   }
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
+  : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
+  }
+
   /**
    * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
    * list of values for that key in `this` as well as `other`.
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     cogroup(other1, other2, new HashPartitioner(numPartitions))
   }
 
+  /**
+   * For each key k in `this` or `other1` or `other2` or `other3`,
+   * return a resulting RDD that contains a tuple with the list of values
+   * for that key in `this`, `other1`, `other2` and `other3`.
+   */
+  def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
+      other2: RDD[(K, W2)],
+      other3: RDD[(K, W3)],
+      numPartitions: Int)
+  : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    cogroup(other1, other2, other3, new HashPartitioner(numPartitions))
+  }
+
   /** Alias for cogroup. */
   def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
     cogroup(other, defaultPartitioner(self, other))
@@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     cogroup(other1, other2, defaultPartitioner(self, other1, other2))
   }
 
+  /** Alias for cogroup. */
+  def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
+  : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
+    cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
+  }
+
   /**
    * Return an RDD with the pairs from `this` whose keys are not in `other`.
    *
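At a call site, the new Scala overloads read as follows — a minimal sketch, assuming a live `SparkContext` named `sc` and the pair-RDD implicits (`import org.apache.spark.SparkContext._` at the time of this patch); variable names are illustrative and mirror the new Java tests:

    import org.apache.spark.SparkContext._  // brings rddToPairRDDFunctions into scope
    import org.apache.spark.rdd.RDD

    val categories = sc.parallelize(Seq(("Apples", "Fruit"), ("Oranges", "Fruit"), ("Oranges", "Citrus")))
    val prices     = sc.parallelize(Seq(("Oranges", 2), ("Apples", 3)))
    val quantities = sc.parallelize(Seq(("Oranges", 21), ("Apples", 42)))
    val countries  = sc.parallelize(Seq(("Oranges", "BR"), ("Apples", "US")))

    // A single CoGroupedRDD shuffles all four inputs together; each key maps to a
    // Tuple4 of Iterables, one per input, empty where the key is absent from that input.
    val grouped: RDD[(String, (Iterable[String], Iterable[Int], Iterable[Int], Iterable[String]))] =
      categories.cogroup(prices, quantities, countries)

`groupWith(other1, other2, other3)` is the same operation with the default partitioner chosen across all four inputs.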
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e46298c6a9e63..761f2d6a77d33 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -21,6 +21,9 @@
 import java.util.*;
 
 import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Iterators;
@@ -304,6 +307,66 @@ public void cogroup() {
     cogrouped.collect();
   }
 
+  @SuppressWarnings("unchecked")
+  @Test
+  public void cogroup3() {
+    JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Apples", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Citrus")
+    ));
+    JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 2),
+      new Tuple2<String, Integer>("Apples", 3)
+    ));
+    JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 21),
+      new Tuple2<String, Integer>("Apples", 42)
+    ));
+
+    JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped =
+        categories.cogroup(prices, quantities);
+    Assert.assertEquals("[Fruit, Citrus]",
+        Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+
+    cogrouped.collect();
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void cogroup4() {
+    JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Apples", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Fruit"),
+      new Tuple2<String, String>("Oranges", "Citrus")
+    ));
+    JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 2),
+      new Tuple2<String, Integer>("Apples", 3)
+    ));
+    JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, Integer>("Oranges", 21),
+      new Tuple2<String, Integer>("Apples", 42)
+    ));
+    JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<String, String>("Oranges", "BR"),
+      new Tuple2<String, String>("Apples", "US")
+    ));
+
+    JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>,
+        Iterable<String>>> cogrouped = categories.cogroup(prices, quantities, countries);
+    Assert.assertEquals("[Fruit, Citrus]",
+        Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+    Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));
+
+    cogrouped.collect();
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void leftOuterJoin() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 0b9004448a63e..447e38ec9dbd0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     ))
   }
 
+  test("groupWith3") {
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
+    val joined = rdd1.groupWith(rdd2, rdd3).collect()
+    assert(joined.size === 4)
+    val joinedSet = joined.map(x => (x._1,
+      (x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet
+    assert(joinedSet === Set(
+      (1, (List(1, 2), List('x'), List('a'))),
+      (2, (List(1), List('y', 'z'), List())),
+      (3, (List(1), List(), List('b'))),
+      (4, (List(), List('w'), List('c', 'd')))
+    ))
+  }
+
+  test("groupWith4") {
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
+    val rdd4 = sc.parallelize(Array((2, '@')))
+    val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect()
+    assert(joined.size === 4)
+    val joinedSet = joined.map(x => (x._1,
+      (x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet
+    assert(joinedSet === Set(
+      (1, (List(1, 2), List('x'), List('a'), List())),
+      (2, (List(1), List('y', 'z'), List(), List('@'))),
+      (3, (List(1), List(), List('b'), List())),
+      (4, (List(), List('w'), List('c', 'd'), List()))
+    ))
+  }
+
   test("zero-partition RDD") {
     val emptyDir = Files.createTempDir()
     emptyDir.deleteOnExit()
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 6f94d26ef86a9..5f3a7e71f7866 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -79,15 +79,15 @@ def dispatch(seq):
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
-def python_cogroup(rdd, other, numPartitions):
-    vs = rdd.map(lambda (k, v): (k, (1, v)))
-    ws = other.map(lambda (k, v): (k, (2, v)))
+def python_cogroup(rdds, numPartitions):
+    def make_mapper(i):
+        return lambda (k, v): (k, (i, v))
+    vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+    union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
+    rdd_len = len(vrdds)
     def dispatch(seq):
-        vbuf, wbuf = [], []
+        bufs = [[] for i in range(rdd_len)]
         for (n, v) in seq:
-            if n == 1:
-                vbuf.append(v)
-            elif n == 2:
-                wbuf.append(v)
-        return (ResultIterable(vbuf), ResultIterable(wbuf))
-    return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
+            bufs[n].append(v)
+        return tuple(map(ResultIterable, bufs))
+    return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
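The generalized `python_cogroup` replaces the hard-coded tags (`1` for `self`, `2` for `other`) with each input's position: every value is tagged with its RDD's index, the tagged RDDs are unioned, grouped by key, and each key's values are dispatched into one buffer per input, so absent keys yield empty buffers. A minimal, Spark-free Scala sketch of that scheme (names are illustrative, not part of the patch):

    // Local stand-in for the tag/union/group/dispatch scheme used by python_cogroup.
    def cogroupLocal[K, V](datasets: Seq[Seq[(K, V)]]): Map[K, Seq[Seq[V]]] = {
      // Tag each value with the index of the dataset it came from (the "mapper" step).
      val tagged = datasets.zipWithIndex.flatMap { case (ds, i) =>
        ds.map { case (k, v) => (k, (i, v)) }
      }
      // Group by key, then dispatch each (index, value) pair into its dataset's buffer.
      tagged.groupBy(_._1).map { case (k, pairs) =>
        val bufs = Array.fill(datasets.length)(Seq.newBuilder[V])
        for ((_, (i, v)) <- pairs) bufs(i) += v
        k -> bufs.map(_.result()).toSeq
      }
    }

    // cogroupLocal(Seq(Seq(1 -> "a"), Seq(1 -> "b", 2 -> "c")))
    //   == Map(1 -> Seq(Seq("a"), Seq("b")), 2 -> Seq(Seq(), Seq("c")))

Doing a single `groupByKey` over the union keeps the cogroup to one shuffle no matter how many RDDs participate, matching the tuple of `ResultIterable`s the Python version builds per key.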
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index a0b2c744f0e7f..41bede576146d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator):
                 combiners[k] = mergeCombiners(combiners[k], v)
             return combiners.iteritems()
         return shuffled.mapPartitions(_mergeCombiners)
-    
+
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
         """
         Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
         """
         def createZero():
             return copy.deepcopy(zeroValue)
-    
+
         return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
 
     def foldByKey(self, zeroValue, func, numPartitions=None):
@@ -1323,12 +1323,20 @@ def mapValues(self, f):
         map_values_fn = lambda (k, v): (k, f(v))
         return self.map(map_values_fn, preservesPartitioning=True)
 
-    # TODO: support varargs cogroup of several RDDs.
-    def groupWith(self, other):
+    def groupWith(self, other, *others):
         """
-        Alias for cogroup.
+        Alias for cogroup but with support for multiple RDDs.
+
+        >>> w = sc.parallelize([("a", 5), ("b", 6)])
+        >>> x = sc.parallelize([("a", 1), ("b", 4)])
+        >>> y = sc.parallelize([("a", 2)])
+        >>> z = sc.parallelize([("b", 42)])
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
+                sorted(list(w.groupWith(x, y, z).collect())))
+        [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
+
         """
-        return self.cogroup(other)
+        return python_cogroup((self, other) + others, numPartitions=None)
 
     # TODO: add variant with custom parittioner
     def cogroup(self, other, numPartitions=None):
@@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None):
         >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
         [('a', ([1], [2])), ('b', ([4], []))]
         """
-        return python_cogroup(self, other, numPartitions)
+        return python_cogroup((self, other), numPartitions)
 
     def subtractByKey(self, other, numPartitions=None):
         """