Commit fea9f34

Commit message: update
1 parent: 4775db2


sql/core/src/test/scala/org/apache/spark/sql/DataFrameCacheBenchmark.scala

Lines changed: 27 additions & 23 deletions
@@ -33,17 +33,9 @@ import org.apache.spark.util.Benchmark
  * sql/core/target/spark-sql_*-tests.jar
  * [float datasize scale] [double datasize scale] [master URL]
  */
-case class DataFrameCacheBenchmark(masterURL: String) {
-  val conf = new SparkConf()
-  val sc = new SparkContext(
-    (if (masterURL == null) "local[1]" else masterURL), "test-sql-context", conf)
-  val sqlContext = new SQLContext(sc)
-  import sqlContext.implicits._
+class DataFrameCacheBenchmark {

-  // Set default configs. Individual cases will change them if necessary.
-  sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
-
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+  def withSQLConf(sqlContext: SQLContext, pairs: (String, String)*)(f: => Unit): Unit = {
     val (keys, values) = pairs.unzip
     val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
     (keys, values).zipped.foreach(sqlContext.conf.setConfString)
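The hunk cuts off before withSQLConf's restore logic. A minimal sketch of the full helper after this change, assuming the standard Spark test-utility try/finally save-and-restore pattern for the part not shown in the diff:

    import scala.util.Try
    import org.apache.spark.sql.SQLContext

    def withSQLConf(sqlContext: SQLContext, pairs: (String, String)*)(f: => Unit): Unit = {
      val (keys, values) = pairs.unzip
      // Remember the current value (if any) of each key so it can be restored afterwards.
      val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
      (keys, values).zipped.foreach(sqlContext.conf.setConfString)
      try f finally {
        // Assumed restore logic: put back previous values, unset keys that had none.
        keys.zip(currentValues).foreach {
          case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
          case (key, None)        => sqlContext.conf.unsetConf(key)
        }
      }
    }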
@@ -55,18 +47,20 @@ case class DataFrameCacheBenchmark(masterURL: String) {
     }
   }

-  def floatSumBenchmark(values: Int, iters: Int = 5): Unit = {
+  def floatSumBenchmark(sqlContext: SQLContext, values: Int, iters: Int = 5): Unit = {
+    import sqlContext.implicits._
+
     val suites = Seq(("InternalRow", "false"), ("ColumnVector", "true"))

     val benchmarkPT = new Benchmark("Float Sum with PassThrough cache", values, iters)
     val rand1 = new Random(511)
-    val dfPassThrough = sc.parallelize(0 to values - 1, 1)
+    val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1)
       .map(i => rand1.nextFloat()).toDF().cache()
     dfPassThrough.count() // force to create df.cache()
     suites.foreach {
       case (str, value) =>
         benchmarkPT.addCase(s"$str codegen") { iter =>
-          withSQLConf(SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
+          withSQLConf(sqlContext, SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
             dfPassThrough.agg(sum("value")).collect
           }
         }
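The dfPassThrough.count() call matters because cache() is lazy: it only marks the plan for in-memory columnar storage, and the cache is materialized by the first action, so the timed cases read already-cached data. A small standalone illustration of the same idiom (sqlContext and the column name are illustrative):

    // cache() only marks the DataFrame for in-memory storage; nothing is built yet.
    val df = sqlContext.range(0, 1024 * 1024).toDF("value").cache()
    // The first action materializes the cache, keeping cache construction
    // out of the timed benchmark region.
    df.count()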
@@ -86,18 +80,20 @@ case class DataFrameCacheBenchmark(masterURL: String) {
     System.gc()
   }

-  def doubleSumBenchmark(values: Int, iters: Int = 5): Unit = {
+  def doubleSumBenchmark(sqlContext: SQLContext, values: Int, iters: Int = 5): Unit = {
+    import sqlContext.implicits._
+
     val suites = Seq(("InternalRow", "false"), ("ColumnVector", "true"))

     val benchmarkPT = new Benchmark("Double Sum with PassThrough cache", values, iters)
     val rand1 = new Random(511)
-    val dfPassThrough = sc.parallelize(0 to values - 1, 1)
+    val dfPassThrough = sqlContext.sparkContext.parallelize(0 to values - 1, 1)
       .map(i => rand1.nextDouble()).toDF().cache()
     dfPassThrough.count() // force to create df.cache()
     suites.foreach {
       case (str, value) =>
         benchmarkPT.addCase(s"$str codegen") { iter =>
-          withSQLConf(SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
+          withSQLConf(sqlContext, SQLConf.COLUMN_VECTOR_CODEGEN.key -> value) {
             dfPassThrough.agg(sum("value")).collect
           }
         }
@@ -117,19 +113,27 @@ case class DataFrameCacheBenchmark(masterURL: String) {
     System.gc()
   }

-  def run(f: Int, d: Int): Unit = {
-    floatSumBenchmark(1024 * 1024 * f)
-    doubleSumBenchmark(1024 * 1024 * d)
+  def run(sqlContext: SQLContext, f: Int, d: Int): Unit = {
+    sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
+
+    floatSumBenchmark(sqlContext, 1024 * 1024 * f)
+    doubleSumBenchmark(sqlContext, 1024 * 1024 * d)
   }
 }

 object DataFrameCacheBenchmark {
+  val F = 30
+  val D = 15
   def main(args: Array[String]): Unit = {
-    val f = if (args.length > 0) args(0).toInt else 30
-    val d = if (args.length > 1) args(1).toInt else 15
+    val f = if (args.length > 0) args(0).toInt else F
+    val d = if (args.length > 1) args(1).toInt else D
     val masterURL = if (args.length > 2) args(2) else "local[1]"

-    val benchmark = DataFrameCacheBenchmark(masterURL)
-    benchmark.run(f, d)
+    val conf = new SparkConf()
+    val sc = new SparkContext(masterURL, "DataFrameCacheBenchmark", conf)
+    val sqlContext = new SQLContext(sc)
+
+    val benchmark = new DataFrameCacheBenchmark
+    benchmark.run(sqlContext, f, d)
   }
 }
