diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
index 60e600d8dbd8f..7b398f424cead 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
@@ -89,6 +89,14 @@ object CodeFormatter {
     }
     new CodeAndComment(code.result().trim(), map)
   }
+
+  def stripExtraNewLinesAndComments(input: String): String = {
+    val commentReg =
+      ("""([ \t]*?\/\*[\s\S]*?\*\/[ \t]*?)|""" + // strip /*comment*/
+        """([ \t]*?\/\/[\s\S]*?\n)""").r         // strip //comment
+    val codeWithoutComment = commentReg.replaceAllIn(input, "")
+    codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip extra new lines
+  }
 }
 
 private class CodeFormatter {
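The helper above removes both comment styles first and then collapses runs of blank lines, so the later line count reflects only real code. Below is a self-contained sketch of the same logic for illustration; the object name StripDemo and the sample input are ours, not part of the patch, while the regexes are copied from the hunk:

    object StripDemo {
      // Mirrors CodeFormatter.stripExtraNewLinesAndComments above.
      def strip(input: String): String = {
        val commentReg =
          ("""([ \t]*?\/\*[\s\S]*?\*\/[ \t]*?)|""" + // strip /*comment*/
            """([ \t]*?\/\/[\s\S]*?\n)""").r         // strip //comment
        val codeWithoutComment = commentReg.replaceAllIn(input, "")
        codeWithoutComment.replaceAll("""\n\s*\n""", "\n") // strip extra new lines
      }

      def main(args: Array[String]): Unit = {
        val in = "public function() {\n/* note */\ncode_body\n\ncode_body\n}"
        // Prints: public function() {\ncode_body\ncode_body\n}
        println(strip(in))
      }
    }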
" + + "The default value 2667 is the max length of byte code JIT supported " + + "for a single function(8000) divided by 3.") + .intConf + .createWithDefault(2667) + val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -1014,6 +1024,8 @@ class SQLConf extends Serializable with Logging { def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) + def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION) + def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9d0a41661beaa..a0f1a64b0ab08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite { assert(reducedCode.body === "/*project_c4*/") } + test("removing extra new lines and comments") { + val code = + """ + |/* + | * multi + | * line + | * comments + | */ + | + |public function() { + |/*comment*/ + | /*comment_with_space*/ + |code_body + |//comment + |code_body + | //comment_with_space + | + |code_body + |} + """.stripMargin + + val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code) + assert(reducedCode === + """ + |public function() { + |code_body + |code_body + |code_body + |} + """.stripMargin) + } + testCase("basic example") { """ |class A { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 34134db278ad8..bacb7090a70ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -370,6 +370,14 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def doExecute(): RDD[InternalRow] = { val (ctx, cleanedSource) = doCodeGen() + if (ctx.isTooLongGeneratedFunction) { + logWarning("Found too long generated codes and JIT optimization might not work, " + + "Whole-stage codegen disabled for this plan, " + + "You can change the config spark.sql.codegen.MaxFunctionLength " + + "to adjust the function length limit:\n " + + s"$treeString") + return child.execute() + } // try to compile and fallback if it failed try { CodeGenerator.compile(cleanedSource) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 183c68fd3c016..beeee6a97c8dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
index 9d0a41661beaa..a0f1a64b0ab08 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
@@ -53,6 +53,38 @@ class CodeFormatterSuite extends SparkFunSuite {
     assert(reducedCode.body === "/*project_c4*/")
   }
 
+  test("removing extra new lines and comments") {
+    val code =
+      """
+        |/*
+        | * multi
+        | * line
+        | * comments
+        | */
+        |
+        |public function() {
+        |/*comment*/
+        |  /*comment_with_space*/
+        |code_body
+        |//comment
+        |code_body
+        |  //comment_with_space
+        |
+        |code_body
+        |}
+      """.stripMargin
+
+    val reducedCode = CodeFormatter.stripExtraNewLinesAndComments(code)
+    assert(reducedCode ===
+      """
+        |public function() {
+        |code_body
+        |code_body
+        |code_body
+        |}
+      """.stripMargin)
+  }
+
   testCase("basic example") {
     """
       |class A {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 34134db278ad8..bacb7090a70ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -370,6 +370,14 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
 
   override def doExecute(): RDD[InternalRow] = {
     val (ctx, cleanedSource) = doCodeGen()
+    if (ctx.isTooLongGeneratedFunction) {
+      logWarning("Found a generated function that is too long, so JIT optimization might " +
        "not work. Whole-stage codegen is disabled for this plan. " +
+        "You can change the config spark.sql.codegen.maxLinesPerFunction " +
+        "to adjust the function length limit:\n " +
+        s"$treeString")
+      return child.execute()
+    }
     // try to compile and fallback if it failed
     try {
       CodeGenerator.compile(cleanedSource)
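When the check trips, doExecute() returns child.execute() directly, so only this subtree loses whole-stage fusion; the rest of the plan keeps its generated code. A hedged sketch of forcing the fallback (assumes a SparkSession named spark running with this patch; the column shapes are illustrative):

    // Any generated function now exceeds the threshold, tripping the check.
    spark.conf.set("spark.sql.codegen.maxLinesPerFunction", "0")
    val wide = (1 to 30).map(i => s"case when id > $i then 1 else 0 end as v$i")
    val rows = spark.range(100)
      .selectExpr("id" +: wide: _*)
      .groupBy("id")
      .sum()
      .collect() // logs the warning above and still returns correct rows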
end as v9", + "case when id > 1000 and id <= 1100 then 1 else 0 end as v10", + "case when id > 1100 and id <= 1200 then 1 else 0 end as v11", + "case when id > 1200 and id <= 1300 then 1 else 0 end as v12", + "case when id > 1300 and id <= 1400 then 1 else 0 end as v13", + "case when id > 1400 and id <= 1500 then 1 else 0 end as v14", + "case when id > 1500 and id <= 1600 then 1 else 0 end as v15", + "case when id > 1600 and id <= 1700 then 1 else 0 end as v16", + "case when id > 1700 and id <= 1800 then 1 else 0 end as v17", + "case when id > 1800 and id <= 1900 then 1 else 0 end as v18") + .groupBy("k1", "k2", "k3") + .sum() + .collect() + + benchmark.addCase(s"codegen = F") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000") + f() + } + + benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1 + Intel64 Family 6 Model 58 Stepping 9, GenuineIntel + max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + codegen = F 462 / 533 1.4 704.4 1.0X + codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X + codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 1.0X + */ + } + ignore("cube") { val N = 5 << 20