@@ -33,17 +33,9 @@ import org.apache.spark.util.Benchmark
  * sql/core/target/spark-sql_*-tests.jar
  * [float datasize scale] [double datasize scale] [master URL]
  */
-case class DataFrameCacheBenchmark(masterURL: String) {
-  val conf = new SparkConf()
-  val sc = new SparkContext(
-    (if (masterURL == null) "local[1]" else masterURL), "test-sql-context", conf)
-  val sqlContext = new SQLContext(sc)
-  import sqlContext.implicits._
+class DataFrameCacheBenchmark {
 
-  // Set default configs. Individual cases will change them if necessary.
-  sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
-
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+  def withSQLConf(sqlContext: SQLContext, pairs: (String, String)*)(f: => Unit): Unit = {
     val (keys, values) = pairs.unzip
     val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
     (keys, values).zipped.foreach(sqlContext.conf.setConfString)
@@ -55,18 +47,20 @@ case class DataFrameCacheBenchmark(masterURL: String) {
     }
   }
 
-  def floatSumBenchmark(values: Int, iters: Int = 5): Unit = {
+  def floatSumBenchmark(sqlContext: SQLContext, values: Int, iters: Int = 5): Unit = {
+    import sqlContext.implicits._
+
     val suites = Seq(("InternalRow", "false"), ("ColumnVector", "true"))
 
     val benchmarkPT = new Benchmark("Float Sum with PassThrough cache", values, iters)
     val rand1 = new Random(511)
-    val dfPassThrough = sc.parallelize(0 to values - 1, 1)
+    val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1)
       .map(i => rand1.nextFloat()).toDF().cache()
     dfPassThrough.count() // force to create df.cache()
     suites.foreach {
       case (str, value) =>
         benchmarkPT.addCase(s"$str codegen") { iter =>
-          withSQLConf(SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
+          withSQLConf(sqlContext, SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
             dfPassThrough.agg(sum("value")).collect
           }
         }
@@ -86,18 +80,20 @@ case class DataFrameCacheBenchmark(masterURL: String) {
     System.gc()
   }
 
-  def doubleSumBenchmark(values: Int, iters: Int = 5): Unit = {
+  def doubleSumBenchmark(sqlContext: SQLContext, values: Int, iters: Int = 5): Unit = {
+    import sqlContext.implicits._
+
     val suites = Seq(("InternalRow", "false"), ("ColumnVector", "true"))
 
     val benchmarkPT = new Benchmark("Double Sum with PassThrough cache", values, iters)
     val rand1 = new Random(511)
-    val dfPassThrough = sc.parallelize(0 to values - 1, 1)
+    val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1)
      .map(i => rand1.nextDouble()).toDF().cache()
     dfPassThrough.count() // force to create df.cache()
     suites.foreach {
       case (str, value) =>
         benchmarkPT.addCase(s"$str codegen") { iter =>
-          withSQLConf(SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
+          withSQLConf(sqlContext, SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
             dfPassThrough.agg(sum("value")).collect
           }
         }
@@ -117,19 +113,27 @@ case class DataFrameCacheBenchmark(masterURL: String) {
     System.gc()
   }
 
-  def run(f: Int, d: Int): Unit = {
-    floatSumBenchmark(1024 * 1024 * f)
-    doubleSumBenchmark(1024 * 1024 * d)
+  def run(sqlContext: SQLContext, f: Int, d: Int): Unit = {
+    sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
+
+    floatSumBenchmark(sqlContext, 1024 * 1024 * f)
+    doubleSumBenchmark(sqlContext, 1024 * 1024 * d)
   }
 }
 
 object DataFrameCacheBenchmark {
+  val F = 30
+  val D = 15
   def main(args: Array[String]): Unit = {
-    val f = if (args.length > 0) args(0).toInt else 30
-    val d = if (args.length > 1) args(1).toInt else 15
+    val f = if (args.length > 0) args(0).toInt else F
+    val d = if (args.length > 1) args(1).toInt else D
     val masterURL = if (args.length > 2) args(2) else "local[1]"
 
-    val benchmark = DataFrameCacheBenchmark(masterURL)
-    benchmark.run(f, d)
+    val conf = new SparkConf()
+    val sc = new SparkContext(masterURL, "DataFrameCacheBenchmark", conf)
+    val sqlContext = new SQLContext(sc)
+
+    val benchmark = new DataFrameCacheBenchmark
+    benchmark.run(sqlContext, f, d)
   }
 }
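
The core of the change is that `withSQLConf` now takes the `SQLContext` as an explicit parameter and follows a save/restore pattern: snapshot the current value of each key, apply the overrides, run the body, then put the originals back. Below is a minimal, Spark-free sketch of the same idea, using a plain mutable map as a stand-in for `sqlContext.conf`; the object name, `withConf` helper, and config key are illustrative and not part of this patch.

import scala.collection.mutable
import scala.util.Try

object WithConfSketch {
  // Stand-in for sqlContext.conf (illustrative only).
  val conf = mutable.Map[String, String]()

  // Save/restore, mirroring the visible part of the patched withSQLConf:
  // remember current values, apply the overrides, run the body,
  // then restore each key (or unset it if it had no previous value).
  def withConf(pairs: (String, String)*)(f: => Unit): Unit = {
    val (keys, values) = pairs.unzip
    val currentValues = keys.map(key => Try(conf(key)).toOption)
    keys.zip(values).foreach { case (k, v) => conf(k) = v }
    try f finally {
      keys.zip(currentValues).foreach {
        case (key, Some(value)) => conf(key) = value // restore previous value
        case (key, None)        => conf.remove(key)  // key was unset before
      }
    }
  }

  def main(args: Array[String]): Unit = {
    conf("spark.sql.codegen.wholeStage") = "false"
    withConf("spark.sql.codegen.wholeStage" -> "true") {
      assert(conf("spark.sql.codegen.wholeStage") == "true") // override in effect
    }
    assert(conf("spark.sql.codegen.wholeStage") == "false")  // original restored
  }
}

Moving the `SparkContext`/`SQLContext` construction out of the former case class constructor and into `main` also removes construction-time side effects: instantiating the benchmark class no longer starts a context, and `main` owns the context's lifecycle.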