
Commit d58a087

move treeReduce and treeAggregate to mllib
1 parent 8a2a59c commit d58a087

7 files changed, +83 -81 lines changed

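The two tree operations move from core's RDD into mllib's RDDFunctions, so callers now reach them through the implicit conversion brought in by import org.apache.spark.mllib.rdd.RDDFunctions._ (the import added to RowMatrix, GradientDescent, and LBFGS below, and used the same way by the moved tests). A minimal sketch of the call-site pattern, assuming an existing SparkContext named sc (as in spark-shell); this snippet is illustrative and not part of the commit:

  import org.apache.spark.mllib.rdd.RDDFunctions._  // implicit RDD => RDDFunctions conversion

  // Hypothetical example data: integers spread over 10 partitions.
  val rdd = sc.parallelize(1 to 1000, 10)

  // Before this commit the calls resolved to the @DeveloperApi methods on core RDD;
  // after it, the same expressions resolve through RDDFunctions in mllib.
  val sum = rdd.treeReduce(_ + _, 2)                            // merge partial sums in a depth-2 tree
  val total = rdd.treeAggregate(0L)((c, x) => c + x, _ + _, 2)  // Long accumulator over Int elements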

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 0 additions & 63 deletions
@@ -839,39 +839,6 @@ abstract class RDD[T: ClassTag](
     jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
   }
 
-  /**
-   * :: DeveloperApi ::
-   * Reduces the elements of this RDD in a tree pattern.
-   * @param depth suggested depth of the tree
-   * @see [[org.apache.spark.rdd.RDD#reduce]]
-   */
-  @DeveloperApi
-  def treeReduce(f: (T, T) => T, depth: Int): T = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
-    val cleanF = sc.clean(f)
-    val reducePartition: Iterator[T] => Option[T] = iter => {
-      if (iter.hasNext) {
-        Some(iter.reduceLeft(cleanF))
-      } else {
-        None
-      }
-    }
-    val local = this.mapPartitions(it => Iterator(reducePartition(it)))
-    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
-      if (c.isDefined && x.isDefined) {
-        Some(cleanF(c.get, x.get))
-      } else if (c.isDefined) {
-        c
-      } else if (x.isDefined) {
-        x
-      } else {
-        None
-      }
-    }
-    local.treeAggregate(Option.empty[T])(op, op, depth)
-      .getOrElse(throw new UnsupportedOperationException("empty collection"))
-  }
-
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
    * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -907,36 +874,6 @@ abstract class RDD[T: ClassTag](
     jobResult
   }
 
-  /**
-   * :: DeveloperApi ::
-   * Aggregates the elements of this RDD in a tree pattern.
-   * @param depth suggested depth of the tree
-   * @see [[org.apache.spark.rdd.RDD#aggregate]]
-   */
-  @DeveloperApi
-  def treeAggregate[U: ClassTag](zeroValue: U)(
-      seqOp: (U, T) => U,
-      combOp: (U, U) => U,
-      depth: Int): U = {
-    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
-    if (this.partitions.size == 0) {
-      return Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
-    }
-    val cleanSeqOp = sc.clean(seqOp)
-    val cleanCombOp = sc.clean(combOp)
-    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
-    var local = this.mapPartitions(it => Iterator(aggregatePartition(it)))
-    var numPartitions = local.partitions.size
-    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
-    while (numPartitions > scale + numPartitions / scale) {
-      numPartitions /= scale
-      local = local.mapPartitionsWithIndex { (i, iter) =>
-        iter.map((i % numPartitions, _))
-      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
-    }
-    local.reduce(cleanCombOp)
-  }
-
   /**
    * Return the number of elements in the RDD.
    */

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 0 additions & 18 deletions
@@ -820,22 +820,4 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       mutableDependencies += dep
     }
   }
-
-  test("treeAggregate") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    def seqOp = (c: Long, x: Int) => c + x
-    def combOp = (c1: Long, c2: Long) => c1 + c2
-    for (level <- 1 until 10) {
-      val sum = rdd.treeAggregate(0L)(seqOp, combOp, level)
-      assert(sum === -1000L)
-    }
-  }
-
-  test("treeReduce") {
-    val rdd = sc.makeRDD(-1000 until 1000, 10)
-    for (level <- 1 until 10) {
-      val sum = rdd.treeReduce(_ + _, level)
-      assert(sum === -1000)
-    }
-  }
 }

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
 
 /**

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.rdd.RDDFunctions._
 
 /**
  * :: DeveloperApi ::
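The three mllib files above gain only the RDDFunctions._ import; the aggregations in those files (not shown in this excerpt) can then call treeAggregate on a plain RDD through the same implicit conversion. A hypothetical sketch of that kind of use, not taken from this commit: folding each partition into a (sum, count) pair and merging the pairs up a depth-2 tree, again assuming an existing SparkContext sc.

  import org.apache.spark.mllib.rdd.RDDFunctions._

  // Hypothetical data: one Double per record, spread over 10 partitions.
  val data = sc.parallelize((1 to 100).map(_.toDouble), 10)

  val (sum, count) = data.treeAggregate((0.0, 0L))(
    seqOp = (acc, x) => (acc._1 + x, acc._2 + 1L),  // fold a partition into (running sum, count)
    combOp = (a, b) => (a._1 + b._1, a._2 + b._2),  // merge two partial (sum, count) pairs
    depth = 2)
  val mean = sum / count  // 50.5 for the values 1.0 through 100.0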

mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala

Lines changed: 62 additions & 0 deletions
@@ -20,7 +20,10 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
+import org.apache.spark.HashPartitioner
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 /**
  * Machine learning specific RDD functions.
@@ -44,6 +47,65 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
       new SlidingRDD[T](self, windowSize)
     }
   }
+
+  /**
+   * Reduces the elements of this RDD in a tree pattern.
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.rdd.RDD#reduce]]
+   */
+  def treeReduce(f: (T, T) => T, depth: Int): T = {
+    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+    val cleanF = self.context.clean(f)
+    val reducePartition: Iterator[T] => Option[T] = iter => {
+      if (iter.hasNext) {
+        Some(iter.reduceLeft(cleanF))
+      } else {
+        None
+      }
+    }
+    val local = self.mapPartitions(it => Iterator(reducePartition(it)))
+    val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+      if (c.isDefined && x.isDefined) {
+        Some(cleanF(c.get, x.get))
+      } else if (c.isDefined) {
+        c
+      } else if (x.isDefined) {
+        x
+      } else {
+        None
+      }
+    }
+    RDDFunctions.fromRDD(local).treeAggregate(Option.empty[T])(op, op, depth)
+      .getOrElse(throw new UnsupportedOperationException("empty collection"))
+  }
+
+  /**
+   * Aggregates the elements of this RDD in a tree pattern.
+   * @param depth suggested depth of the tree
+   * @see [[org.apache.spark.rdd.RDD#aggregate]]
+   */
+  def treeAggregate[U: ClassTag](zeroValue: U)(
+      seqOp: (U, T) => U,
+      combOp: (U, U) => U,
+      depth: Int): U = {
+    require(depth >= 1, s"Depth must be greater than 1 but got $depth.")
+    if (self.partitions.size == 0) {
+      return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
+    }
+    val cleanSeqOp = self.context.clean(seqOp)
+    val cleanCombOp = self.context.clean(combOp)
+    val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+    var local = self.mapPartitions(it => Iterator(aggregatePartition(it)))
+    var numPartitions = local.partitions.size
+    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+    while (numPartitions > scale + numPartitions / scale) {
+      numPartitions /= scale
+      local = local.mapPartitionsWithIndex { (i, iter) =>
+        iter.map((i % numPartitions, _))
+      }.reduceByKey(new HashPartitioner(numPartitions), cleanCombOp).values
+    }
+    local.reduce(cleanCombOp)
+  }
 }
 
 private[mllib]
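The while loop in treeAggregate above sets the merge schedule: scale is roughly the depth-th root of the partition count, and each round shuffles the partial results into numPartitions / scale reducers until a final reduce over the survivors is cheap. A standalone trace of that arithmetic for the 10-partition, depth-2 case exercised by the tests below; it mirrors the loop above but is not code from the commit:

  // Trace of the merge schedule for numPartitions = 10, depth = 2.
  var numPartitions = 10
  val depth = 2
  val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)  // ceil(sqrt(10)) = 4
  while (numPartitions > scale + numPartitions / scale) {  // 10 > 4 + 2, so one merge round runs
    numPartitions /= scale                                 // 10 / 4 = 2 (integer division)
    println(s"reduceByKey into $numPartitions partitions")
  }
  // The loop exits at 2 partitions (2 > 4 + 0 is false); local.reduce merges the last two partials.
  // With depth = 1, scale = 10 and the loop never runs, matching a plain one-level aggregate.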

mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -46,4 +46,22 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val expected = data.flatMap(x => x).sliding(3).toList
     assert(sliding.collect().toList === expected)
   }
+
+  test("treeAggregate") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    def seqOp = (c: Long, x: Int) => c + x
+    def combOp = (c1: Long, c2: Long) => c1 + c2
+    for (level <- 1 until 10) {
+      val sum = rdd.treeAggregate(0L)(seqOp, combOp, level)
+      assert(sum === -1000L)
+    }
+  }
+
+  test("treeReduce") {
+    val rdd = sc.makeRDD(-1000 until 1000, 10)
+    for (level <- 1 until 10) {
+      val sum = rdd.treeReduce(_ + _, level)
+      assert(sum === -1000)
+    }
+  }
 }
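A note on the expected value in both new tests: sc.makeRDD(-1000 until 1000, 10) covers the integers -1000 through 999, and each pair (-k, k) for k = 1 to 999 cancels, so the sum is -1000 regardless of the tree depth. A one-line check of that arithmetic in plain Scala, independent of Spark:

  println((-1000 until 1000).sum)  // -1000: the pairs cancel and the unpaired -1000 is left over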
