diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index f816bb43a5b4..5687eebfe3dc 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -153,6 +153,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
 
   /** Add up the elements in this RDD. */
   def sum(): JDouble = srdd.sum()
+
+  /** Max of the elements in this RDD. */
+  def max(): JDouble = srdd.max()
+
+  /** Min of the elements in this RDD. */
+  def min(): JDouble = srdd.min()
+
   /**
    * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
    * count of the RDD's elements in one operation.
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a7b6b3b5146c..406d1c9b759f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -35,8 +35,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
   }
 
   /**
-   * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and
-   * count of the RDD's elements in one operation.
+   * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance,
+   * count, min and max of the RDD's elements in one operation.
    */
   def stats(): StatCounter = {
     self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
@@ -51,6 +51,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
   /** Compute the standard deviation of this RDD's elements. */
   def stdev(): Double = stats().stdev
 
+  /** Find the minimum of this RDD's elements. */
+  def min(): Double = stats().min
+
+  /** Find the maximum of this RDD's elements. */
+  def max(): Double = stats().max
+
   /**
    * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
    * estimating the standard deviation by dividing by N-1 instead of N).
@@ -86,14 +92,9 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
    * If the elements in RDD do not vary (max == min) always returns a single bucket.
    */
   def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
-    // Compute the minimum and the maxium
-    val (max: Double, min: Double) = self.mapPartitions { items =>
-      Iterator(items.foldRight(Double.NegativeInfinity,
-        Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
-        (x._1.max(e), x._2.min(e))))
-    }.reduce { (maxmin1, maxmin2) =>
-      (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
-    }
+    // Compute the minimum and the maximum in a single stats() pass
+    val _stats = stats()
+    val (max: Double, min: Double) = (_stats.max, _stats.min)
     if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) {
       throw new UnsupportedOperationException(
         "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
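Aside, not part of the patch: a minimal sketch of how the new DoubleRDDFunctions API reads from user code, assuming a local master; the app name and sample values are invented for illustration.

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._   // implicit conversion to DoubleRDDFunctions

    // Illustrative values only: exercise the new min()/max() and the
    // single-pass stats() aggregate they are built on.
    val sc = new SparkContext("local", "MinMaxSketch")
    val rdd = sc.parallelize(Seq(1.0, 2.0, 6.0, 8.0))
    println(rdd.min())   // 1.0
    println(rdd.max())   // 8.0
    println(rdd.stats()) // count, mean, stdev, min and max from one job
    sc.stop()

Since min() and max() are each defined as stats().min / stats().max, every call runs a full stats() job; callers that want several of these statistics should call stats() once and read the fields off the returned StatCounter.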
diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
index f837dc7ccc86..4cd0057f1bb3 100644
--- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala
+++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -29,6 +29,12 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
   private var n: Long = 0     // Running count of our values
   private var mu: Double = 0  // Running mean of our values
   private var m2: Double = 0  // Running variance numerator (sum of (x - mean)^2)
+  private var _min: Double = Double.PositiveInfinity
+  private var _max: Double = Double.NegativeInfinity
+
+  def min: Double = _min
+
+  def max: Double = _max
 
   merge(values)
 
@@ -41,6 +47,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
       n += 1
       mu += delta / n
       m2 += delta * (value - mu)
+      _min = math.min(value, _min)
+      _max = math.max(value, _max)
       this
     }
 
@@ -58,7 +66,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
       if (n == 0) {
         mu = other.mu
         m2 = other.m2
-        n = other.n
+        n = other.n
+        _min = other.min
+        _max = other.max
       } else if (other.n != 0) {
         val delta = other.mu - mu
         if (other.n * 10 < n) {
@@ -70,6 +80,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
         }
         m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
         n += other.n
+        _min = math.min(other.min, _min)
+        _max = math.max(other.max, _max)
       }
       this
     }
@@ -81,6 +93,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
     other.n = n
     other.mu = mu
     other.m2 = m2
+    other._min = _min
+    other._max = _max
     other
   }
 
@@ -120,9 +134,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
    */
   def sampleStdev: Double = math.sqrt(sampleVariance)
 
-  override def toString: String = {
-    "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
-  }
+  override def toString: String = {
+    "(count: %d, mean: %f, stdev: %f, min: %f, max: %f)".format(count, mean, stdev, min, max)
+  }
 }
 
 object StatCounter {
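The merge path above is where correctness is easiest to get wrong, so a small self-check sketch (sample values are mine, not from the patch): two counters built on disjoint slices and then merged should agree with one counter built over the concatenation. The extrema merge as math.min/math.max of the parts, while m2 combines via the parallel-variance term (delta * delta * n * other.n) / (n + other.n) shown in the hunk.

    import org.apache.spark.util.StatCounter

    // Two "partitions" aggregated independently, then merged.
    val left  = StatCounter(Seq(2.0, 4.0))
    val right = StatCounter(Seq(1.0, 8.0))
    val whole = StatCounter(Seq(2.0, 4.0, 1.0, 8.0)) // single pass, for comparison

    val merged = left.merge(right) // note: merge mutates and returns `left`
    assert(merged.count == whole.count)
    assert(merged.min == whole.min && merged.min == 1.0)
    assert(merged.max == whole.max && merged.max == 8.0)
    assert(math.abs(merged.variance - whole.variance) < 1e-9)

Seeding _min at Double.PositiveInfinity and _max at Double.NegativeInfinity makes merging an empty counter a no-op for the extrema, matching how the n == 0 branch simply copies them over.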
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 40e853c39ca9..241b610b7f5e 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -395,7 +395,8 @@ public Boolean call(Double x) {
     Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01);
     Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
     Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01);
-
+    Assert.assertEquals(1.0, rdd.min(), 0.01);
+    Assert.assertEquals(8.0, rdd.max(), 0.01);
     rdd.first();
     rdd.take(5);
   }
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 4305686d3a6d..741a7fc1408e 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -171,6 +171,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     assert(abs(6.0/2 - rdd.mean) < 0.01)
     assert(abs(1.0 - rdd.variance) < 0.01)
     assert(abs(1.0 - rdd.stdev) < 0.01)
+    assert(abs(2.0 - stats.min) < 0.01)
+    assert(abs(4.0 - stats.max) < 0.01)
 
     // Add other tests here for classes that should be able to handle empty partitions correctly
   }
diff --git a/project/build.properties b/project/build.properties
index 4b52bb928a66..43aca56fceaa 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-sbt.version=0.13.1
+sbt.version=0.13.2-M1
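One more hypothetical end-to-end check, again not part of the patch: histogram(bucketCount) now derives its bounds from a single stats() pass, so those bounds should match a manual reduce over the same data (sample values invented).

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    val sc = new SparkContext("local", "HistogramSketch")
    val data = sc.parallelize(Seq(1.0, 2.0, 6.0, 8.0))
    val s = data.stats()
    assert(s.min == data.reduce((a, b) => math.min(a, b)))
    assert(s.max == data.reduce((a, b) => math.max(a, b)))
    // Buckets evenly span [min, max]; the counts come from a further pass.
    val (buckets, counts) = data.histogram(4)
    sc.stop()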