@@ -24,7 +24,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
2424import org .apache .spark .ml .stat .SummaryBuilderImpl .Buffer
2525import org .apache .spark .ml .util .TestingUtils ._
2626import org .apache .spark .mllib .linalg .{Vector => OldVector , Vectors => OldVectors }
27- import org .apache .spark .mllib .stat .MultivariateOnlineSummarizer
27+ import org .apache .spark .mllib .stat .{ MultivariateOnlineSummarizer , Statistics }
2828import org .apache .spark .mllib .util .MLlibTestSparkContext
2929import org .apache .spark .sql .{DataFrame , Row }
3030import org .apache .spark .sql .catalyst .expressions .GenericRowWithSchema
@@ -335,4 +335,65 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
335335 assert(Buffer .totalCount(summarizer) === 6 )
336336 }
337337
338+ // TODO: this test should not be committed. It is here to isolate some performance hotspots.
test("perf test") {
  // Benchmark only: compares the throughput of the DataFrame-based summarizer with the
  // RDD-based Statistics.colStats on n single-element dense vectors. It makes no assertions;
  // it only prints throughput figures.
  val n = 10000000
  val tryouts = 10

  // Old-API (mllib) vectors; cached and counted up front so the timed runs below do not
  // include generating the data.
  val rdd1 = sc.parallelize(1 to n).map { idx =>
    OldVectors.dense(idx.toDouble)
  }
  rdd1.cache()
  rdd1.count()

  // New-API (ml) vectors, likewise cached and counted before any timing starts.
  val rdd2 = sc.parallelize(1 to n).map { idx =>
    Vectors.dense(idx.toDouble)
  }
  rdd2.cache()
  rdd2.count()

  val df = rdd2.map(Tuple1.apply).toDF("features")
  df.cache()
  df.count()

  // Query requesting every listed metric at once.
  val allMetrics = df.select(
    metrics("mean", "variance", "count", "numNonZeros", "max", "min", "normL1",
      "normL2").summary($"features"))
  // Query requesting the variance only.
  val varianceOnly = df.select(metrics("variance").summary($"features"))

  // Evaluates `body` `tryouts` times and returns the wall-clock duration of each run in
  // nanoseconds.
  def timeRuns[T](body: => T): List[Long] = List.fill(tryouts) {
    val start = System.nanoTime()
    body
    System.nanoTime() - start
  }

  // Execution order is kept: full DataFrame summary, then RDD colStats, then variance only.
  val timesDf = timeRuns(allMetrics.head())
  val timesRdd = timeRuns(Statistics.colStats(rdd1))
  val timesDfVariance = timeRuns(varianceOnly.head())

  // Prints [slowest ~ median ~ fastest] throughput in records per millisecond, labelled with
  // `name` so that the three result lines can be told apart.
  def report(name: String, durations: List[Long]): Unit = {
    // n records in `nanos` nanoseconds => n * 1e6 / nanos records per millisecond.
    def throughput(nanos: Long): Double = (1e6 * n.toDouble) / nanos
    val sorted = durations.sorted
    val min = throughput(sorted.last)
    val max = throughput(sorted.head)
    val med = throughput(sorted(sorted.size / 2))

    // scalastyle:off println
    println(s"$name = [$min ~ $med ~ $max] records / milli")
    // scalastyle:on println
  }

  report("RDD", timesRdd)
  report("Dataframes (variance only)", timesDfVariance)
  report("Dataframes", timesDf)
}
398+
338399}
0 commit comments