20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -477,6 +477,26 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}

/**
* Returns the maximum element from this RDD as defined by the specified
* Comparator[T].
* @param comp the comparator that defines ordering
* @return the maximum of the RDD
*/
def max(comp: Comparator[T]): T = {
rdd.max()(Ordering.comparatorToOrdering(comp))
}

Contributor review comment: "Add doc comments to these and the Scala versions"

/**
* Returns the minimum element from this RDD as defined by the specified
* Comparator[T].
* @param comp the comparator that defines ordering
* @return the minimum of the RDD
*/
def min(comp: Comparator[T]): T = {
rdd.min()(Ordering.comparatorToOrdering(comp))
}
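
As an aside, here is a minimal, self-contained sketch of the bridge these wrappers rely on: Ordering.comparatorToOrdering adapts a java.util.Comparator into the Scala Ordering that the underlying rdd.max()/rdd.min() calls expect. The byLength comparator below is hypothetical, purely for illustration:

import java.util.Comparator

// Hypothetical comparator that orders strings by length (illustration only).
val byLength: Comparator[String] = new Comparator[String] {
  def compare(a: String, b: String): Int = a.length - b.length
}

// comparatorToOrdering turns the Java Comparator into a Scala Ordering,
// the implicit that RDD.max()/RDD.min() expect.
val ord: Ordering[String] = Ordering.comparatorToOrdering(byLength)
assert(ord.max("spark", "rdd") == "spark")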

/**
* Returns the first K elements from this RDD using the
* natural ordering for T while maintaining the order.
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -958,6 +958,18 @@ abstract class RDD[T: ClassTag](
*/
def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse)

/**
* Returns the max of this RDD as defined by the implicit Ordering[T].
* @return the maximum element of the RDD
*/
def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max)

/**
* Returns the min of this RDD as defined by the implicit Ordering[T].
* @return the minimum element of the RDD
*/
def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
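
A quick usage sketch of the new methods (assumes an existing SparkContext named sc; the sample values are made up):

val nums = sc.parallelize(Seq(2, 7, 4, 1))
nums.max()  // 7, via the implicit Ordering[Int]
nums.min()  // 1

// An alternative Ordering can be supplied explicitly, e.g. by absolute value:
sc.parallelize(Seq(-9, 3, 5)).max()(Ordering.by((x: Int) => math.abs(x)))  // -9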

/**
* Save this RDD as a text file, using string representations of elements.
*/
18 changes: 16 additions & 2 deletions core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -29,6 +29,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
private var n: Long = 0 // Running count of our values
private var mu: Double = 0 // Running mean of our values
private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
private var maxValue: Double = Double.NegativeInfinity // Running max of our values
private var minValue: Double = Double.PositiveInfinity // Running min of our values

merge(values)

@@ -41,6 +43,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
n += 1
mu += delta / n
m2 += delta * (value - mu)
maxValue = math.max(maxValue, value)
minValue = math.min(minValue, value)
this
}

@@ -58,7 +62,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (n == 0) {
mu = other.mu
m2 = other.m2
n = other.n
maxValue = other.maxValue
minValue = other.minValue
} else if (other.n != 0) {
val delta = other.mu - mu
if (other.n * 10 < n) {
@@ -70,6 +76,8 @@
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
maxValue = math.max(maxValue, other.maxValue)
minValue = math.min(minValue, other.minValue)
}
this
}
@@ -81,6 +89,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
other.n = n
other.mu = mu
other.m2 = m2
other.maxValue = maxValue
other.minValue = minValue
other
}

@@ -90,6 +100,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {

def sum: Double = n * mu

def max: Double = maxValue

def min: Double = minValue

/** Return the variance of the values. */
def variance: Double = {
if (n == 0) {
@@ -121,7 +135,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
def sampleStdev: Double = math.sqrt(sampleVariance)

override def toString: String = {
"(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
"(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
}
}

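A small end-to-end sketch of the new running max/min (sample values made up; the printed stdev is elided):

import org.apache.spark.util.StatCounter

val a = new StatCounter(Seq(1.0, 2.0))
val b = new StatCounter(Seq(10.0))

// merge() folds b into a, combining count/mean/m2 and taking the running
// max and min across both counters.
val merged = a.merge(b)
println(merged)  // (count: 3, mean: 4.333333, stdev: ..., max: 10.000000, min: 1.000000)

Because maxValue and minValue start at Double.NegativeInfinity and Double.PositiveInfinity, the identities of math.max and math.min, merging an empty counter leaves max and min unchanged.
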
2 changes: 2 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -171,6 +171,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
assert(abs(6.0/2 - rdd.mean) < 0.01)
assert(abs(1.0 - rdd.variance) < 0.01)
assert(abs(1.0 - rdd.stdev) < 0.01)
assert(stats.max === 4.0)
assert(stats.min === 2.0)

// Add other tests here for classes that should be able to handle empty partitions correctly
}
2 changes: 2 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -47,6 +47,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
assert(nums.max() === 4)
assert(nums.min() === 1)
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))

19 changes: 19 additions & 0 deletions python/pyspark/rdd.py
@@ -534,7 +534,26 @@ def func(iterator):
return reduce(op, vals, zeroValue)

# TODO: aggregate


def max(self):
"""
Find the maximum item in this RDD.

>>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
43.0
"""
return self.reduce(max)

def min(self):
"""
Find the minimum item in this RDD.

>>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
1.0
"""
return self.reduce(min)

def sum(self):
"""
Add up the elements in this RDD.
25 changes: 22 additions & 3 deletions python/pyspark/statcounter.py
@@ -26,7 +26,9 @@ def __init__(self, values=[]):
self.n = 0L # Running count of our values
self.mu = 0.0 # Running mean of our values
self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2)

self.maxValue = float("-inf")
self.minValue = float("inf")

for v in values:
self.merge(v)

@@ -36,6 +38,11 @@ def merge(self, value):
self.n += 1
self.mu += delta / self.n
self.m2 += delta * (value - self.mu)
if self.maxValue < value:
self.maxValue = value
if self.minValue > value:
self.minValue = value

return self

# Merge another StatCounter into this one, adding up the internal statistics.
@@ -49,7 +56,10 @@ def mergeStats(self, other):
if self.n == 0:
self.mu = other.mu
self.m2 = other.m2
self.n = other.n
self.maxValue = other.maxValue
self.minValue = other.minValue

elif other.n != 0:
delta = other.mu - self.mu
if other.n * 10 < self.n:
@@ -58,6 +68,9 @@
self.mu = other.mu - (delta * self.n) / (self.n + other.n)
else:
self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)

self.maxValue = max(self.maxValue, other.maxValue)
self.minValue = min(self.minValue, other.minValue)

self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
self.n += other.n
@@ -76,6 +89,12 @@ def mean(self):
def sum(self):
return self.n * self.mu

def min(self):
return self.minValue

def max(self):
return self.maxValue

# Return the variance of the values.
def variance(self):
if self.n == 0:
@@ -105,5 +124,5 @@ def sampleStdev(self):
return math.sqrt(self.sampleVariance())

def __repr__(self):
return "(count: %s, mean: %s, stdev: %s)" % (self.count(), self.mean(), self.stdev())
return "(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min())