diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2cb66599076a9..58738b52b299f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -77,6 +77,22 @@ case class SubExprEliminationState(isNull: String, value: String)
  */
 case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
 
+/**
+ * The main information about a newly added function.
+ *
+ * @param functionName String representing the name of the function
+ * @param innerClassName Optional value which is empty if the function is added to
+ *                       the outer class, otherwise it contains the name of the
+ *                       inner class in which the function has been added.
+ * @param innerClassInstance Optional value which is empty if the function is added to
+ *                           the outer class, otherwise it contains the name of the
+ *                           instance of the inner class in the outer class.
+ */
+private[codegen] case class NewFunctionSpec(
+    functionName: String,
+    innerClassName: Option[String],
+    innerClassInstance: Option[String])
+
 /**
  * A context for codegen, tracking a list of objects that could be passed into generated Java
  * function.
@@ -228,8 +244,8 @@ class CodegenContext {
   /**
    * Holds the class and instance names to be generated, where `OuterClass` is a placeholder
    * standing for whichever class is generated as the outermost class and which will contain any
-   * nested sub-classes. All other classes and instance names in this list will represent private,
-   * nested sub-classes.
+   * inner sub-classes. All other classes and instance names in this list will represent private,
+   * inner sub-classes.
    */
   private val classes: mutable.ListBuffer[(String, String)] =
     mutable.ListBuffer[(String, String)](outerClassName -> null)
@@ -260,8 +276,8 @@ class CodegenContext {
 
   /**
    * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the
-   * function will be inlined into a new private, nested class, and a class-qualified name for the
-   * function will be returned. Otherwise, the function will be inined to the `OuterClass` the
+   * function will be inlined into a new private, inner class, and a class-qualified name for the
+   * function will be returned. Otherwise, the function will be inlined to the `OuterClass` and the
    * simple `funcName` will be returned.
    *
    * @param funcName the class-unqualified name of the function
@@ -271,19 +287,27 @@ class CodegenContext {
    *                          it is eventually referenced and a returned qualified function name
    *                          cannot otherwise be accessed.
    * @return the name of the function, qualified by class if it will be inlined to a private,
-   *         nested sub-class
+   *         inner class
    */
   def addNewFunction(
       funcName: String,
       funcCode: String,
       inlineToOuterClass: Boolean = false): String = {
-    // The number of named constants that can exist in the class is limited by the Constant Pool
-    // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a
-    // threshold of 1600k bytes to determine when a function should be inlined to a private, nested
-    // sub-class.
+    val newFunction = addNewFunctionInternal(funcName, funcCode, inlineToOuterClass)
+    newFunction match {
+      case NewFunctionSpec(functionName, None, None) => functionName
+      case NewFunctionSpec(functionName, Some(_), Some(innerClassInstance)) =>
+        innerClassInstance + "." + functionName
+    }
+  }
+
+  private[this] def addNewFunctionInternal(
+      funcName: String,
+      funcCode: String,
+      inlineToOuterClass: Boolean): NewFunctionSpec = {
     val (className, classInstance) = if (inlineToOuterClass) {
       outerClassName -> ""
-    } else if (currClassSize > 1600000) {
+    } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) {
       val className = freshName("NestedClass")
       val classInstance = freshName("nestedClassInstance")
 
@@ -294,17 +318,23 @@ class CodegenContext {
       currClass()
     }
 
-    classSize(className) += funcCode.length
-    classFunctions(className) += funcName -> funcCode
+    addNewFunctionToClass(funcName, funcCode, className)
 
     if (className == outerClassName) {
-      funcName
+      NewFunctionSpec(funcName, None, None)
     } else {
-
-      s"$classInstance.$funcName"
+      NewFunctionSpec(funcName, Some(className), Some(classInstance))
     }
   }
 
+  private[this] def addNewFunctionToClass(
+      funcName: String,
+      funcCode: String,
+      className: String) = {
+    classSize(className) += funcCode.length
+    classFunctions(className) += funcName -> funcCode
+  }
+
   /**
    * Declares all function code. If the added functions are too many, split them into nested
    * sub-classes to avoid hitting Java compiler constant pool limitation.
@@ -738,7 +768,7 @@ class CodegenContext {
   /**
    * Splits the generated code of expressions into multiple functions, because function has
    * 64kb code size limit in JVM. If the class to which the function would be inlined would grow
-   * beyond 1600kb, we declare a private, nested sub-class, and the function is inlined to it
+   * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
    * instead, because classes have a constant pool limit of 65,536 named values.
    *
    * @param row the variable name of row that is used by expressions
@@ -801,10 +831,90 @@ class CodegenContext {
            |  ${makeSplitFunction(body)}
            |}
          """.stripMargin
-        addNewFunction(name, code)
+        addNewFunctionInternal(name, code, inlineToOuterClass = false)
       }
 
-      foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})"))
+      val (outerClassFunctions, innerClassFunctions) = functions.partition(_.innerClassName.isEmpty)
+
+      val argsString = arguments.map(_._2).mkString(", ")
+      val outerClassFunctionCalls = outerClassFunctions.map(f => s"${f.functionName}($argsString)")
+
+      val innerClassFunctionCalls = generateInnerClassesFunctionCalls(
+        innerClassFunctions,
+        func,
+        arguments,
+        returnType,
+        makeSplitFunction,
+        foldFunctions)
+
+      foldFunctions(outerClassFunctionCalls ++ innerClassFunctionCalls)
+    }
+  }
+
+  /**
+   * Here we handle all the methods which have been added to the inner classes and
+   * not to the outer class.
+   * Since they can be many, their direct invocation in the outer class adds many entries
+   * to the outer class' constant pool. This can cause the constant pool to exceed the JVM limit.
+   * Moreover, it can also cause the outer class method in which all the invocations are
+   * performed to grow beyond the 64k limit.
+   * To avoid these problems, we group them and we call only the grouping methods in the
+   * outer class.
+   *
+   * @param functions a [[Seq]] of [[NewFunctionSpec]] defined in the inner classes
+   * @param funcName the split function name base.
+   * @param arguments the list of (type, name) of the arguments of the split function.
+   * @param returnType the return type of the split function.
+   * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
+   * @param foldFunctions folds the split function calls.
+   * @return an [[Iterable]] containing the methods' invocations
+   */
+  private def generateInnerClassesFunctionCalls(
+      functions: Seq[NewFunctionSpec],
+      funcName: String,
+      arguments: Seq[(String, String)],
+      returnType: String,
+      makeSplitFunction: String => String,
+      foldFunctions: Seq[String] => String): Iterable[String] = {
+    val innerClassToFunctions = mutable.LinkedHashMap.empty[(String, String), Seq[String]]
+    functions.foreach(f => {
+      val key = (f.innerClassName.get, f.innerClassInstance.get)
+      val value = f.functionName +: innerClassToFunctions.getOrElse(key, Seq.empty[String])
+      innerClassToFunctions.put(key, value)
+    })
+
+    val argDefinitionString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
+    val argInvocationString = arguments.map(_._2).mkString(", ")
+
+    innerClassToFunctions.flatMap {
+      case ((innerClassName, innerClassInstance), innerClassFunctions) =>
+        // for performance reasons, the functions are prepended, instead of appended,
+        // thus here they are in reversed order
+        val orderedFunctions = innerClassFunctions.reverse
+        if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
+          // Adding a new function to each inner class which contains the invocation of all the
+          // ones which have been added to that inner class. For example,
+          //   private class NestedClass {
+          //     private void apply_862(InternalRow i) { ... }
+          //     private void apply_863(InternalRow i) { ... }
+          //     ...
+          //     private void apply(InternalRow i) {
+          //       apply_862(i);
+          //       apply_863(i);
+          //       ...
+          //     }
+          //   }
+          val body = foldFunctions(orderedFunctions.map(name => s"$name($argInvocationString)"))
+          val code = s"""
+              |private $returnType $funcName($argDefinitionString) {
+              |  ${makeSplitFunction(body)}
+              |}
+            """.stripMargin
+          addNewFunctionToClass(funcName, code, innerClassName)
+          Seq(s"$innerClassInstance.$funcName($argInvocationString)")
+        } else {
+          orderedFunctions.map(f => s"$innerClassInstance.$f($argInvocationString)")
+        }
     }
   }
 
@@ -1013,6 +1123,16 @@ object CodeGenerator extends Logging {
   // This is the value of HugeMethodLimit in the OpenJDK JVM settings
   val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000
 
+  // This is the threshold over which the methods in an inner class are grouped in a single
+  // method which is going to be called by the outer class instead of the many small ones
+  val MERGE_SPLIT_METHODS_THRESHOLD = 3
+
+  // The number of named constants that can exist in the class is limited by the Constant Pool
+  // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a
+  // threshold of 1000k bytes to determine when a function should be inlined to a private, inner
+  // class.
+  val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
+
   /**
    * Compile the Java source code into a Java class, using Janino.
    *
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 7ea0bec145481..1e6f7b65e7e72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -201,6 +201,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-22226: group splitted expressions into one method per nested class") {
+    val length = 10000
+    val expressions = Seq.fill(length) {
+      ToUTCTimestamp(
+        Literal.create(Timestamp.valueOf("2017-10-10 00:00:00"), TimestampType),
+        Literal.create("PST", StringType))
+    }
+    val plan = GenerateMutableProjection.generate(expressions)
+    val actual = plan(new GenericInternalRow(length)).toSeq(expressions.map(_.dataType))
+    val expected = Seq.fill(length)(
+      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2017-10-10 07:00:00")))
+
+    if (actual != expected) {
+      fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
+    }
+  }
+
   test("test generated safe and unsafe projection") {
     val schema = new StructType(Array(
       StructField("a", StringType, true),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 473c355cf3c7f..17c88b0690800 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2106,6 +2106,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
   }
 
+  test("SPARK-22226: splitExpressions should not generate codes beyond 64KB") {
+    val colNumber = 10000
+    val input = spark.range(2).rdd.map(_ => Row(1 to colNumber: _*))
+    val df = sqlContext.createDataFrame(input, StructType(
+      (1 to colNumber).map(colIndex => StructField(s"_$colIndex", IntegerType, false))))
+    val newCols = (1 to colNumber).flatMap { colIndex =>
+      Seq(expr(s"if(1000 < _$colIndex, 1000, _$colIndex)"),
+        expr(s"sqrt(_$colIndex)"))
+    }
+    df.select(newCols: _*).collect()
+  }
+
   test("SPARK-22271: mean overflows and returns null for some decimal variables") {
     val d = 0.034567890
     val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol")