From 013c02f215de85d50b4c7125ee571b14801bdb47 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 13 Feb 2018 20:23:47 +0800 Subject: [PATCH] add a config to try to inline all mutable states during codegen --- .../expressions/codegen/CodeGenerator.scala | 17 ++++++++++------- .../org/apache/spark/sql/internal/SQLConf.scala | 10 ++++++++++ .../expressions/CodeGenerationSuite.scala | 15 +++++++++++++++ 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dcbb702893da..2992a0f3fe181 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -136,6 +136,9 @@ class CodegenContext { */ var currentVars: Seq[ExprCode] = null + // Reads the config here, to make it effective for the entire lifetime of this context. + private val tryInlineAllState = SQLConf.get.getConf(SQLConf.CODEGEN_TRY_INLINE_ALL_STATES) + /** * Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a * 2-tuple: java type, variable name. @@ -253,10 +256,11 @@ class CodegenContext { forceInline: Boolean = false, useFreshName: Boolean = true): String = { - // want to put a primitive type variable at outerClass for performance - val canInlinePrimitive = isPrimitiveType(javaType) && + // Puts a primitive type variable at outerClass for performance, or puts all variables at outer + // class if CODEGEN_TRY_INLINE_ALL_STATES is true, if we have't hit the threshold yet. + val canInline = (isPrimitiveType(javaType) || tryInlineAllState) && (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) - if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { + if (forceInline || javaType.contains("[][]") || canInline) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) inlinedMutableStates += ((javaType, varName)) @@ -1461,20 +1465,19 @@ object CodeGenerator extends Logging { CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - val stats = cf.methodInfos.asScala.flatMap { method => + cf.methodInfos.asScala.flatMap { method => method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) byteCodeSize } } - Some(stats) } catch { case NonFatal(e) => logWarning("Error calculating stats of compiled class.", e) - None + Nil } - }.flatten + } codeSizes.max } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7835dbaa58439..0cfa0153df862 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -675,6 +675,16 @@ object SQLConf { "disable logging or -1 to apply no limit.") .createWithDefault(1000) + val CODEGEN_TRY_INLINE_ALL_STATES = + buildConf("spark.sql.codegen.tryInlineAllStates") + .internal() + .doc("When adding mutable states during code generation, whether or not we should try to " + + "inline all the states. If this config is false, we only try to inline primitive stats, " + + "so that primitive states are more likely to be inlined. Set this config to true to make " + + "the behavior same as Spark 2.2.") + .booleanConf + .createWithDefault(false) + val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956ddc8..7550c4172e715 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ThreadUtils @@ -436,4 +437,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.addImmutableStateIfNotExists("String", mutableState2) assert(ctx.inlinedMutableStates.length == 2) } + + test("SPARK-23407: inline all mutable states if CODEGEN_TRY_INLINE_ALL_STATES is true") { + val conf = SQLConf.get + try { + conf.setConf(SQLConf.CODEGEN_TRY_INLINE_ALL_STATES, true) + val ctx = new CodegenContext + ctx.addMutableState(ctx.JAVA_INT, "i", v => s"$v = 1;") + ctx.addMutableState("String", "s", v => s"$v = null;") + assert(ctx.inlinedMutableStates.size == 2) + assert(ctx.arrayCompactedMutableStates.isEmpty) + } finally { + conf.unsetConf(SQLConf.CODEGEN_TRY_INLINE_ALL_STATES) + } + } }