51 changes: 51 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -543,6 +543,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3],
partitioner: Partitioner)
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, partitioner)))

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -558,6 +570,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3)))

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -574,6 +597,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3],
numPartitions: Int)
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.cogroup(other1, other2, other3, numPartitions)))

/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
@@ -583,6 +618,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))

/** Alias for cogroup. */
def groupWith[W1, W2, W3](other1: JavaPairRDD[K, W1],
other2: JavaPairRDD[K, W2],
other3: JavaPairRDD[K, W3])
: JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3])] =
fromRDD(cogroupResult3ToJava(rdd.groupWith(other1, other2, other3)))

/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
@@ -786,6 +828,15 @@ object JavaPairRDD {
.mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3)))
}

private[spark]
def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3](
rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))])
: RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = {
rddToPairRDDFunctions(rdd)
.mapValues(x =>
(asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4)))
}

def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = {
new JavaPairRDD[K, V](rdd)
}
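The new cogroupResult3ToJava helper mirrors the existing cogroupResult2ToJava one arity up: inside mapValues, each Scala Iterable slot of the value tuple is wrapped as a java.lang.Iterable. A minimal sketch of that conversion in isolation (the toJava4 name is ours, for illustration only, not part of the patch):

    import java.lang.{Iterable => JIterable}
    import scala.collection.JavaConversions.asJavaIterable

    // The per-tuple conversion cogroupResult3ToJava applies inside mapValues:
    // each Scala Iterable slot becomes a java.lang.Iterable, position by position.
    def toJava4[V, W1, W2, W3](
        t: (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))
      : (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]) =
      (asJavaIterable(t._1), asJavaIterable(t._2), asJavaIterable(t._3), asJavaIterable(t._4))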
51 changes: 51 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -567,6 +567,28 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
new FlatMappedValuesRDD(self, cleanF)
}

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
other2: RDD[(K, W2)],
other3: RDD[(K, W3)],
partitioner: Partitioner)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner)
cg.mapValues { case Seq(vs, w1s, w2s, w3s) =>
(vs.asInstanceOf[Seq[V]],
w1s.asInstanceOf[Seq[W1]],
w2s.asInstanceOf[Seq[W2]],
w3s.asInstanceOf[Seq[W3]])
}
}

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -599,6 +621,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
}

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
}

/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
@@ -633,6 +665,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
cogroup(other1, other2, new HashPartitioner(numPartitions))
}

/**
* For each key k in `this` or `other1` or `other2` or `other3`,
* return a resulting RDD that contains a tuple with the list of values
* for that key in `this`, `other1`, `other2` and `other3`.
*/
def cogroup[W1, W2, W3](other1: RDD[(K, W1)],
other2: RDD[(K, W2)],
other3: RDD[(K, W3)],
numPartitions: Int)
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
cogroup(other1, other2, other3, new HashPartitioner(numPartitions))
}

/** Alias for cogroup. */
def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = {
cogroup(other, defaultPartitioner(self, other))
@@ -644,6 +689,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}

/** Alias for cogroup. */
def groupWith[W1, W2, W3](other1: RDD[(K, W1)], other2: RDD[(K, W2)], other3: RDD[(K, W3)])
: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))] = {
cogroup(other1, other2, other3, defaultPartitioner(self, other1, other2, other3))
}

/**
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
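Taken together, the three overloads added above follow the existing cogroup surface: explicit Partitioner, default partitioner, and numPartitions variants. A minimal usage sketch, assuming a local SparkContext named sc and the pre-1.3 pair-RDD implicits from SparkContext._ (output order may vary):

    import org.apache.spark.SparkContext._  // pair-RDD implicits

    val categories = sc.parallelize(Seq("Apples" -> "Fruit", "Oranges" -> "Citrus"))
    val prices     = sc.parallelize(Seq("Apples" -> 3, "Oranges" -> 2))
    val quantities = sc.parallelize(Seq("Apples" -> 42, "Oranges" -> 21))
    val countries  = sc.parallelize(Seq("Oranges" -> "BR"))

    // One Iterable slot per input RDD; empty where a key is absent.
    val grouped = categories.cogroup(prices, quantities, countries)
    grouped.collect().foreach { case (k, (cs, ps, qs, os)) =>
      println(s"$k -> (${cs.toList}, ${ps.toList}, ${qs.toList}, ${os.toList})")
    }
    // Apples  -> (List(Fruit), List(3), List(42), List())
    // Oranges -> (List(Citrus), List(2), List(21), List(BR))

groupWith(prices, quantities, countries) is the same call through the new alias.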
63 changes: 63 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -21,6 +21,9 @@
import java.util.*;

import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;

import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
@@ -304,6 +307,66 @@ public void cogroup() {
cogrouped.collect();
}

@SuppressWarnings("unchecked")
@Test
public void cogroup3() {
JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, String>("Apples", "Fruit"),
new Tuple2<String, String>("Oranges", "Fruit"),
new Tuple2<String, String>("Oranges", "Citrus")
));
JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 2),
new Tuple2<String, Integer>("Apples", 3)
));
JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 21),
new Tuple2<String, Integer>("Apples", 42)
));

JavaPairRDD<String, Tuple3<Iterable<String>, Iterable<Integer>, Iterable<Integer>>> cogrouped =
categories.cogroup(prices, quantities);
Assert.assertEquals("[Fruit, Citrus]",
Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));

cogrouped.collect();
}

@SuppressWarnings("unchecked")
@Test
public void cogroup4() {
JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, String>("Apples", "Fruit"),
new Tuple2<String, String>("Oranges", "Fruit"),
new Tuple2<String, String>("Oranges", "Citrus")
));
JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 2),
new Tuple2<String, Integer>("Apples", 3)
));
JavaPairRDD<String, Integer> quantities = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, Integer>("Oranges", 21),
new Tuple2<String, Integer>("Apples", 42)
));
JavaPairRDD<String, String> countries = sc.parallelizePairs(Arrays.asList(
new Tuple2<String, String>("Oranges", "BR"),
new Tuple2<String, String>("Apples", "US")
));

JavaPairRDD<String, Tuple4<Iterable<String>, Iterable<Integer>, Iterable<Integer>, Iterable<String>>> cogrouped =
categories.cogroup(prices, quantities, countries);
Assert.assertEquals("[Fruit, Citrus]",
Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));

cogrouped.collect();
}

@SuppressWarnings("unchecked")
@Test
public void leftOuterJoin() {
core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -249,6 +249,39 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
))
}

test("groupWith3") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
val joined = rdd1.groupWith(rdd2, rdd3).collect()
assert(joined.size === 4)
val joinedSet = joined.map(x => (x._1,
(x._2._1.toList, x._2._2.toList, x._2._3.toList))).toSet
assert(joinedSet === Set(
(1, (List(1, 2), List('x'), List('a'))),
(2, (List(1), List('y', 'z'), List())),
(3, (List(1), List(), List('b'))),
(4, (List(), List('w'), List('c', 'd')))
))
}

test("groupWith4") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val rdd3 = sc.parallelize(Array((1, 'a'), (3, 'b'), (4, 'c'), (4, 'd')))
val rdd4 = sc.parallelize(Array((2, '@')))
val joined = rdd1.groupWith(rdd2, rdd3, rdd4).collect()
assert(joined.size === 4)
val joinedSet = joined.map(x => (x._1,
(x._2._1.toList, x._2._2.toList, x._2._3.toList, x._2._4.toList))).toSet
assert(joinedSet === Set(
(1, (List(1, 2), List('x'), List('a'), List())),
(2, (List(1), List('y', 'z'), List(), List('@'))),
(3, (List(1), List(), List('b'), List())),
(4, (List(), List('w'), List('c', 'd'), List()))
))
}

test("zero-partition RDD") {
val emptyDir = Files.createTempDir()
emptyDir.deleteOnExit()
20 changes: 10 additions & 10 deletions python/pyspark/join.py
@@ -79,15 +79,15 @@ def dispatch(seq):
return _do_python_join(rdd, other, numPartitions, dispatch)


-def python_cogroup(rdd, other, numPartitions):
-    vs = rdd.map(lambda (k, v): (k, (1, v)))
-    ws = other.map(lambda (k, v): (k, (2, v)))
+def python_cogroup(rdds, numPartitions):

[Review comment (Contributor)] Will this break compatibility for users who were building against the previous API? It seems like this is a public API, so we might need to make a second version rather than replace the current one.

[Review comment (Contributor)] Okay, I looked yet again; this entire file is not exposed in e.g. the docs, so I guess this isn't public.

+    def make_mapper(i):
+        return lambda (k, v): (k, (i, v))
+    vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+    union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
+    rdd_len = len(vrdds)
     def dispatch(seq):
-        vbuf, wbuf = [], []
+        bufs = [[] for i in range(rdd_len)]
         for (n, v) in seq:
-            if n == 1:
-                vbuf.append(v)
-            elif n == 2:
-                wbuf.append(v)
-        return (ResultIterable(vbuf), ResultIterable(wbuf))
-    return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
+            bufs[n].append(v)
+        return tuple(map(ResultIterable, bufs))
+    return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
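The rewritten python_cogroup generalizes the old two-input special case: tag each value with the index of its source RDD, union all the tagged RDDs, group by key, then bucket values back by tag. A rough Scala rendering of the same idea, assuming inputs that share one value type (the Scala-side cogroup in this patch instead uses CoGroupedRDD, which keeps each slot's type distinct):

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD
    import org.apache.spark.SparkContext._  // pair-RDD implicits

    def cogroupViaUnion[K: ClassTag, V: ClassTag](
        rdds: Seq[RDD[(K, V)]], numPartitions: Int): RDD[(K, Seq[List[V]])] = {
      val n = rdds.size
      // Tag every value with the index of the RDD it came from.
      val tagged = rdds.zipWithIndex.map { case (rdd, i) => rdd.mapValues(v => (i, v)) }
      tagged.reduce(_ union _)              // one RDD holding all tagged pairs
        .groupByKey(numPartitions)          // gather every tagged value per key
        .mapValues { vals =>
          val bufs = Array.fill(n)(List.newBuilder[V])
          vals.foreach { case (i, v) => bufs(i) += v }  // bucket by source index
          bufs.map(_.result()).toSeq
        }
    }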
22 changes: 15 additions & 7 deletions python/pyspark/rdd.py
@@ -1233,7 +1233,7 @@ def _mergeCombiners(iterator):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)

def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
Aggregate the values of each key, using given combine functions and a neutral "zero value".
Expand All @@ -1245,7 +1245,7 @@ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
def createZero():
return copy.deepcopy(zeroValue)

return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)

def foldByKey(self, zeroValue, func, numPartitions=None):
@@ -1323,12 +1323,20 @@ def mapValues(self, f):
map_values_fn = lambda (k, v): (k, f(v))
return self.map(map_values_fn, preservesPartitioning=True)

-    # TODO: support varargs cogroup of several RDDs.
-    def groupWith(self, other):
+    def groupWith(self, other, *others):
         """
-        Alias for cogroup.
+        Alias for cogroup but with support for multiple RDDs.
+
+        >>> w = sc.parallelize([("a", 5), ("b", 6)])
+        >>> x = sc.parallelize([("a", 1), ("b", 4)])
+        >>> y = sc.parallelize([("a", 2)])
+        >>> z = sc.parallelize([("b", 42)])
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \
+                sorted(list(w.groupWith(x, y, z).collect())))
+        [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))]
+
         """
-        return self.cogroup(other)
+        return python_cogroup((self, other) + others, numPartitions=None)

# TODO: add variant with custom partitioner
def cogroup(self, other, numPartitions=None):
Expand All @@ -1342,7 +1350,7 @@ def cogroup(self, other, numPartitions=None):
>>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
[('a', ([1], [2])), ('b', ([4], []))]
"""
-        return python_cogroup(self, other, numPartitions)
+        return python_cogroup((self, other), numPartitions)

def subtractByKey(self, other, numPartitions=None):
"""