
Commit 662f62c

small test to find perf issues
1 parent 2f809ef commit 662f62c

1 file changed: 62 additions, 1 deletion

mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala

Lines changed: 62 additions & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.stat.SummaryBuilderImpl.Buffer
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
@@ -335,4 +335,65 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(Buffer.totalCount(summarizer) === 6)
   }
 
+  // TODO: this test should not be committed. It is here to isolate some performance hotspots.
+  test("perf test") {
+    val n = 10000000
+    val rdd1 = sc.parallelize(1 to n).map { idx =>
+      OldVectors.dense(idx.toDouble)
+    }
+    val tryouts = 10
+    rdd1.cache()
+    rdd1.count()
+    val rdd2 = sc.parallelize(1 to n).map { idx =>
+      Vectors.dense(idx.toDouble)
+    }
+    rdd2.cache()
+    rdd2.count()
+    val df = rdd2.map(Tuple1.apply).toDF("features")
+    df.cache()
+    df.count()
+    val x = df.select(
+      metrics("mean", "variance", "count", "numNonZeros", "max", "min", "normL1",
+        "normL2").summary($"features"))
+    val x1 = df.select(metrics("variance").summary($"features"))
+
+    var times_df: List[Long] = Nil
+    for (_ <- 1 to tryouts) {
+      val t21 = System.nanoTime()
+      x.head()
+      val t22 = System.nanoTime()
+      times_df ::= (t22 - t21)
+    }
+
+    var times_rdd: List[Long] = Nil
+    for (_ <- 1 to tryouts) {
+      val t21 = System.nanoTime()
+      Statistics.colStats(rdd1)
+      val t22 = System.nanoTime()
+      times_rdd ::= (t22 - t21)
+    }
+
+    var times_df_variance: List[Long] = Nil
+    for (_ <- 1 to tryouts) {
+      val t21 = System.nanoTime()
+      x1.head()
+      val t22 = System.nanoTime()
+      times_df_variance ::= (t22 - t21)
+    }
+
+    def print(name: String, l: List[Long]): Unit = {
+      def f(z: Long) = (1e6 * n.toDouble) / z
+      val min = f(l.max)
+      val max = f(l.min)
+      val med = f(l.sorted.drop(l.size / 2).head)
+
+      // scalastyle:off println
+      println(s"$name = [$min ~ $med ~ $max] records / milli")
+    }
+
+    print("RDD", times_rdd)
+    print("Dataframes (variance only)", times_df_variance)
+    print("Dataframes", times_df)
+  }
+
 }
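
Note on the reported numbers: the `f` helper in `print` converts an elapsed time `z` in nanoseconds into throughput as `(1e6 * n) / z`, i.e. records per millisecond, so the slowest run (`l.max`) yields the reported minimum and the fastest run (`l.min`) the maximum.

For orientation, here is a minimal standalone sketch of the two code paths this benchmark compares: the RDD-based `Statistics.colStats` and the DataFrame-based `Summarizer` aggregate. It is an assumption-laden illustration, not part of this commit: it presumes a `spark-shell` session (so `sc`, `spark`, and `spark.implicits._` are available) and uses the public `Summarizer.metrics` builder, which the suite's `metrics` helper is presumed to wrap.

// Sketch only; assumes spark-shell, where `spark` and `sc` are predefined.
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.Statistics
import spark.implicits._

// RDD path: one-pass column statistics over an RDD of mllib vectors.
val rdd = sc.parallelize(1 to 1000).map(i => OldVectors.dense(i.toDouble))
println(Statistics.colStats(rdd).mean)

// DataFrame path: the Summarizer aggregate over an ml vector column.
val df = sc.parallelize(1 to 1000)
  .map(i => Tuple1(Vectors.dense(i.toDouble)))
  .toDF("features")
df.select(Summarizer.metrics("mean", "variance").summary($"features"))
  .show(truncate = false)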
