From 240d5a379ad40e726e4a36a157883a69293cfeba Mon Sep 17 00:00:00 2001 From: ALeksander Eskilson Date: Wed, 11 Jan 2017 13:33:26 -0600 Subject: [PATCH 1/6] adding initial class splitting [class_splitting] increasing stack size for Catalyst tests class_splitting with variable splitting adding mutable state changes finished enhancement poc, requires cleanup fixing mutable state for several classes --- sql/catalyst/pom.xml | 7 + .../MonotonicallyIncreasingID.scala | 13 +- .../sql/catalyst/expressions/ScalaUDF.scala | 19 +- .../expressions/SparkPartitionID.scala | 7 +- .../expressions/codegen/CodeGenerator.scala | 285 +++++++++++++++--- .../codegen/GenerateMutableProjection.scala | 19 +- .../codegen/GenerateOrdering.scala | 3 + .../codegen/GeneratePredicate.scala | 3 + .../codegen/GenerateSafeProjection.scala | 13 +- .../codegen/GenerateUnsafeProjection.scala | 66 ++-- .../expressions/complexTypeCreator.scala | 34 +-- .../expressions/conditionalExpressions.scala | 13 +- .../expressions/datetimeExpressions.scala | 22 +- .../sql/catalyst/expressions/generators.scala | 11 +- .../spark/sql/catalyst/expressions/hash.scala | 22 +- .../expressions/objects/objects.scala | 74 +++-- .../sql/catalyst/expressions/predicates.scala | 11 +- .../expressions/randomExpressions.scala | 14 +- .../expressions/regexpExpressions.scala | 60 ++-- .../expressions/stringExpressions.scala | 43 +-- .../sql/execution/ColumnarBatchScan.scala | 44 +-- .../sql/execution/DataSourceScanExec.scala | 14 +- .../apache/spark/sql/execution/SortExec.scala | 30 +- .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../aggregate/HashAggregateExec.scala | 81 ++--- .../aggregate/HashMapGenerator.scala | 10 +- .../execution/basicPhysicalOperators.scala | 77 ++--- .../columnar/GenerateColumnAccessor.scala | 25 +- .../joins/BroadcastHashJoinExec.scala | 4 +- .../execution/joins/SortMergeJoinExec.scala | 66 ++-- .../apache/spark/sql/execution/limit.scala | 14 +- .../spark/sql/DataFrameComplexTypeSuite.scala | 13 +- 32 files changed, 700 insertions(+), 422 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 8d80f8eca5dba..0f75852d51661 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -131,6 +131,13 @@ + + org.scalatest + scalatest-maven-plugin + + -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + + org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 84027b53dca27..079c271912180 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -67,14 +67,15 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") - ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") - ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") + val countTermAccessor = ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") + val partitionMaskTermAccessor = ctx.addMutableState(ctx.JAVA_LONG, 
partitionMaskTerm, "") + ctx.addPartitionInitializationStatement(s"$countTermAccessor = 0L;") + ctx.addPartitionInitializationStatement( + s"$partitionMaskTermAccessor = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTermAccessor + $countTermAccessor; + $countTermAccessor++;""", isNull = "false") } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index af1eba26621bd..afabc998e2551 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -987,11 +987,11 @@ case class ScalaUDF( val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 - ctx.addMutableState(converterClassName, converterTerm, - s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + val converterTermAccessor = ctx.addMutableState(converterClassName, converterTerm, + s"$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - converterTerm + converterTermAccessor } override def doGenCode( @@ -1004,8 +1004,9 @@ case class ScalaUDF( // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - ctx.addMutableState(converterClassName, catalystConverterTerm, - s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + val catalystTermAccessor = + ctx.addMutableState(converterClassName, catalystConverterTerm, + s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1018,8 +1019,8 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") + val funcTermAccessor = ctx.addMutableState(funcClassName, funcTerm, + s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1036,12 +1037,12 @@ case class ScalaUDF( (convert, argTerm) }.unzip - val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" + val getFuncResult = s"$funcTermAccessor.apply(${funcArguments.mkString(", ")})" val callFunc = s""" ${ctx.boxedType(dataType)} $resultTerm = null; try { - $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + $resultTerm = (${ctx.boxedType(dataType)})$catalystTermAccessor.apply($getFuncResult); } catch (Exception e) { throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8db7efdbb5dd4..4f9dba3fb8b9b 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, "") - ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + val idTermAccessor = ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addPartitionInitializationStatement(s"$idTermAccessor = partitionIndex;") + ev.copy( + code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTermAccessor;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 760ead42c762c..289888937adbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -110,8 +110,8 @@ class CodegenContext { val idx = references.length references += obj val clsName = Option(className).getOrElse(obj.getClass.getName) - addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") - term + val termAccessor = addMutableState(clsName, term, s"$term = ($clsName) references[$idx];") + termAccessor } /** @@ -145,11 +145,71 @@ class CodegenContext { * * They will be kept as member variables in generated classes like `SpecificProjection`. */ - val mutableStates: mutable.ArrayBuffer[(String, String, String)] = - mutable.ArrayBuffer.empty[(String, String, String)] + val mutableState: mutable.ListBuffer[(String, String, String)] = + mutable.ListBuffer.empty[(String, String, String)] - def addMutableState(javaType: String, variableName: String, initCode: String): Unit = { - mutableStates += ((javaType, variableName, initCode)) + var mutableStateCount: Int = 0 + + var mutableStateArrayIdx: mutable.Map[(String, String), Int] = + mutable.Map.empty[(String, String), Int] + + var mutableStateArrayNames: mutable.Map[(String, String), String] = + mutable.Map.empty[(String, String), String] + + var mutableStateArrayInitCodes: mutable.Map[(String, String), String] = + mutable.Map.empty[(String, String), String] + + /** + * Adds an instance of globally-accessible mutable state. Mutable state may either be inlined + * as a private member variable to the class, or it may be compacted into arrays of the same + * type and initialization if the amount of mutable state would grow past 10k, in order to avoid + * Constant Pool limit errors for both state declaration and initialization. + * + * We compact state into arrays when we can anticipate variables of the same type and initCode + * may appear numerous times. Variable names with integer suffixes (as given by the `freshName` + * function), that are either simply assigned (null or no initialization) or are primitive are + * good candidates for array compaction, as these variables types are likely to appear numerous + * times, and can be easily initialized in loops. 
+   *
+   * @param javaType the Java type of the variable
+   * @param variableName the name of the variable
+   * @param initCode the initialization code for the variable
+   * @return the name of the mutable state variable, which is either the original name if the
+   *         variable is inlined to the class, or an array access if the variable is to be stored
+   *         in an array of variables of the same type and initialization.
+   */
+  def addMutableState(
+      javaType: String,
+      variableName: String,
+      initCode: String,
+      inLine: Boolean = false): String = {
+    if (!inLine && variableName.matches(".*\\d+.*") &&
+      (initCode.matches("(^.*\\s*=\\s*null;$|^$)") || isPrimitiveType(javaType))) {
+      val initCodeKey = initCode.replaceAll(variableName, "*VALUE*")
+      if (mutableStateArrayIdx.contains((javaType, initCodeKey))) {
+        val arrayName = mutableStateArrayNames((javaType, initCodeKey))
+        val idx = mutableStateArrayIdx((javaType, initCodeKey)) + 1
+
+        mutableStateArrayIdx.update(
+          (javaType, initCodeKey),
+          mutableStateArrayIdx((javaType, initCodeKey)) + 1)
+
+        s"$arrayName[$idx]"
+      } else {
+        val arrayName = freshName("mutableStateArray")
+        val qualifiedInitCode = initCode.replaceAll(variableName, s"$arrayName[i]")
+        mutableStateArrayNames += Tuple2(javaType, initCodeKey) -> arrayName
+        mutableStateArrayIdx += Tuple2(javaType, initCodeKey) -> 0
+        mutableStateArrayInitCodes += Tuple2(javaType, initCodeKey) -> qualifiedInitCode
+
+        s"$arrayName[0]"
+      }
+    } else {
+      mutableStateCount += 1
+      mutableState += Tuple3(javaType, variableName, initCode)
+
+      variableName
+    }
   }
 
   /**
@@ -159,30 +219,54 @@ class CodegenContext {
    */
   def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
     val value = freshName(variableName)
-    addMutableState(javaType(dataType), value, "")
+    val valueAccessor = addMutableState(javaType(dataType), value, "")
     val code = dataType match {
-      case StringType => s"$value = $initCode.clone();"
-      case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
-      case _ => s"$value = $initCode;"
+      case StringType => s"$valueAccessor = $initCode.clone();"
+      case _: StructType | _: ArrayType | _: MapType => s"$valueAccessor = $initCode.copy();"
+      case _ => s"$valueAccessor = $initCode;"
     }
-    ExprCode(code, "false", value)
+    ExprCode(code, "false", valueAccessor)
   }
 
   def declareMutableStates(): String = {
     // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
     // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
-    mutableStates.distinct.map { case (javaType, variableName, _) =>
+    val inlinedStates = mutableState.distinct.map { case (javaType, variableName, _) =>
       s"private $javaType $variableName;"
-    }.mkString("\n")
+    }
+    val arrayStates = mutableStateArrayNames.map { case ((javaType, initCode), arrayName) =>
+      val length = mutableStateArrayIdx((javaType, initCode)) + 1
+      if (javaType.matches("^.*\\[\\]$")) {
+        val baseType = javaType.substring(0, javaType.length - 2)
+        s"private $javaType[] $arrayName = new $baseType[$length][];"
+      } else {
+        s"private $javaType[] $arrayName = new $javaType[$length];"
+      }
+    }
+
+    (inlinedStates ++ arrayStates).mkString("\n")
  }
 
   def initMutableStates(): String = {
     // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
     // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
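
To make the compaction concrete, here is a minimal, self-contained sketch of the bookkeeping described above -- not Spark's actual CodegenContext; the object and names are illustrative. Repeated (javaType, initCode) pairs collapse into one array field plus one init loop, mirroring what declareMutableStates() and initMutableStates() emit:

import scala.collection.mutable

object CompactionSketch {
  // (javaType, initCode template) -> (arrayName, number of slots handed out so far)
  private val arrays = mutable.Map.empty[(String, String), (String, Int)]

  // Returns an accessor such as "mutableStateArray0[1]" instead of a fresh field name.
  def addMutableState(javaType: String, initCode: String): String = {
    val key = (javaType, initCode)
    val (name, idx) = arrays.getOrElse(key, (s"mutableStateArray${arrays.size}", 0))
    arrays(key) = (name, idx + 1)
    s"$name[$idx]"
  }

  // One declaration plus one init loop per distinct (type, init) pair,
  // shaped like the members of the generated Java class.
  def declareAndInit(): String =
    arrays.map { case ((javaType, init), (name, count)) =>
      s"""private $javaType[] $name = new $javaType[$count];
         |for (int i = 0; i < $name.length; i++) { $name[i] = $init; }""".stripMargin
    }.mkString("\n")

  def main(args: Array[String]): Unit = {
    println(addMutableState("UTF8String", "null"))  // mutableStateArray0[0]
    println(addMutableState("UTF8String", "null"))  // mutableStateArray0[1], same array
    println(declareAndInit())
  }
}

Call sites must use the accessor the method returns rather than the name they passed in, which is why the rest of this patch threads the returned *Accessor values through every doGenCode implementation.
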
- val initCodes = mutableStates.distinct.map(_._3 + "\n") + val initCodes = mutableState.distinct.map(_._3 + "\n") + // Array state is initialized in loops + val arrayInitCodes = mutableStateArrayNames.map { case ((javaType, initCode), arrayName) => + val qualifiedInitCode = mutableStateArrayInitCodes((javaType, initCode)) + if (qualifiedInitCode.equals("")) { + "" + } else { + s""" + for (int i = 0; i < $arrayName.length; i++) { + $qualifiedInitCode + } + """ + } + } // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. - splitExpressions(initCodes, "init", Nil) + splitExpressions(initCodes ++ arrayInitCodes, "init", Nil) } /** @@ -199,16 +283,6 @@ class CodegenContext { partitionInitializationStatements.mkString("\n") } - /** - * Holding all the functions those will be added into generated class. - */ - val addedFunctions: mutable.Map[String, String] = - mutable.Map.empty[String, String] - - def addNewFunction(funcName: String, funcCode: String): Unit = { - addedFunctions += ((funcName, funcCode)) - } - /** * Holds expressions that are equivalent. Used to perform subexpression elimination * during codegen. @@ -230,10 +304,125 @@ class CodegenContext { // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] + /** + * The Class and instance names generated. `OuterClass` is a placeholder standing for whatever + * class is generated as the outermost class. All other classes and instance names in this list + * are private nested classes. + */ + private val classes: mutable.ListBuffer[(String, String)] = + mutable.ListBuffer[(String, String)](("OuterClass", null)) + + // A map holding the current size in bytes of each class. + private val classSize: mutable.Map[String, Int] = + mutable.Map[String, Int](("OuterClass", 0)) + + // A map holding all functions and their names belonging to each class + private val classFunctions: mutable.Map[String, mutable.Map[String, String]] = + mutable.Map(("OuterClass", mutable.Map.empty[String, String])) + + // Returns the size of the most recently added class + private def currClassSize(): Int = classSize(classes.head._1) + + private def currClass(): (String, String) = classes.head + + // Adds a new class. Requires the class' name, and its instance name + private def addClass(className: String, classInstance: String): Unit = { + classes.prepend(Tuple2(className, classInstance)) + classSize += className -> 0 + classFunctions += className -> mutable.Map.empty[String, String] + } + + /** + * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the + * function will be inlined into a new private nested class, and a class-qualified name for + * the function will be returned. Otherwise, the function will be inlined to the `OuterClass` + * and the simple `funcName` will be returned. + * + * @param funcName the class-unqualified name of the function + * @param funcCode the body of the function + * @return the name of the function, qualified by class if it will be inlined to a private + * nested class + */ + def addNewFunction( + funcName: String, + funcCode: String, + inlineToOuterClass: Boolean = false): String = { + // The number of named constants that can exist in the class is limited by the Constant Pool + // limit, 65536. 
We cannot know how many constants will be inserted for a class, so we use a
+    // threshold of 1600K bytes to determine when a function should be inlined into a private
+    // NestedClass.
+    val classInfo = if (inlineToOuterClass) {
+      ("OuterClass", "")
+    } else if (currClassSize > 1600000) {
+      val className = freshName("NestedClass")
+      val classInstance = freshName("nestedClassInstance")
+
+      addClass(className, classInstance)
+
+      Tuple2(className, classInstance)
+    } else {
+      currClass()
+    }
+    val name = classInfo._1
+
+    classSize.update(name, classSize(name) + funcCode.length)
+    classFunctions.update(name, classFunctions(name) += funcName -> funcCode)
+    if (name.equals("OuterClass")) {
+      funcName
+    } else {
+      val classInstance = classInfo._2
+
+      s"$classInstance.$funcName"
+    }
+  }
+
+  /**
+   * Instantiates all nested private classes as fields of the OuterClass
+   */
+  def initNestedClasses(): String = {
+    // Nested private classes have no mutable state (though they do reference the outer class's
+    // mutable state), so we declare and initialize them inline in the OuterClass
+    classes.map {
+      case (className, classInstance) =>
+        if (className.equals("OuterClass")) {
+          ""
+        } else {
+          s"private $className $classInstance = new $className();"
+        }
+    }.mkString("\n")
+  }
+
+  /**
+   * Declares all functions that should be inlined into the `OuterClass`
+   */
   def declareAddedFunctions(): String = {
-    addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+    classFunctions("OuterClass").map {
+      case (funcName, funcCode) => funcCode
+    }.mkString("\n")
   }
 
+  /**
+   * Declares all nested private classes and the functions that should be inlined into them
+   */
+  def declareNestedClasses(): String = {
+    classFunctions.map {
+      case (className, functions) =>
+        if (className.equals("OuterClass")) {
+          ""
+        } else {
+          val code = functions.map {
+            case (_, funcCode) => funcCode
+          }.mkString("\n")
+          s"""
+             |private class $className {
+             |  $code
+             |}
+           """.stripMargin
+        }
+    }.mkString("\n")
+  }
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -553,8 +742,7 @@ class CodegenContext {
         return 0;
       }
     """
-      addNewFunction(compareFunc, funcCode)
-      s"this.$compareFunc($c1, $c2)"
+      s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
     case schema: StructType =>
       val comparisons = GenerateOrdering.genComparisons(this, schema)
       val compareFunc = freshName("compareStruct")
@@ -570,8 +758,7 @@ class CodegenContext {
         return 0;
       }
     """
-      addNewFunction(compareFunc, funcCode)
-      s"this.$compareFunc($c1, $c2)"
+      s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
     case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
     case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
     case _ =>
@@ -641,7 +828,9 @@ class CodegenContext {
   /**
    * Splits the generated code of expressions into multiple functions, because function has
-   * 64kb code size limit in JVM
+   * 64kb code size limit in JVM. If the class the function is to be inlined into would grow
+   * beyond 1600KB, a private nested class is declared, and the function is inlined into it,
+   * because classes have a constant pool limit of 65536 named values.
    *
    * @param expressions the codes to evaluate expressions.
    * @param funcName the split function name base.
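
For intuition, the following hand-written sketch (illustrative only, not actual generated output) shows the shape of a generated class once addNewFunction starts spilling functions into nested classes: callers use the instance-qualified names returned above, while initNestedClasses() and declareNestedClasses() contribute the instance fields and the class bodies.

object GeneratedShapeSketch {
  val shape: String =
    """class SpecificProjection extends UnsafeProjection {
      |  private Object[] references;
      |  private long[] mutableStateArray0 = new long[2];   // compacted mutable state
      |
      |  private void apply_0(InternalRow i) { /* ... */ }  // still fits in the outer class
      |
      |  public UnsafeRow apply(InternalRow i) {
      |    apply_0(i);
      |    nestedClassInstance1.apply_1(i);                 // class-qualified call
      |    /* ... */
      |  }
      |
      |  // ctx.initNestedClasses()
      |  private NestedClass1 nestedClassInstance1 = new NestedClass1();
      |
      |  // ctx.declareNestedClasses()
      |  private class NestedClass1 {
      |    private void apply_1(InternalRow i) { /* ... */ }  // overflow past the size threshold
      |  }
      |}""".stripMargin

  def main(args: Array[String]): Unit = println(shape)
}

Because each class, outer or nested, carries its own 65536-entry constant pool, spreading functions across nested classes bounds both pool pressure and per-class bytecode size.
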
@@ -685,8 +874,8 @@ class CodegenContext { | ${makeSplitFunction(body)} |} """.stripMargin + addNewFunction(name, code) - name } foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) @@ -759,19 +948,6 @@ class CodegenContext { val isNull = s"${fnName}IsNull" val value = s"${fnName}Value" - // Generate the code for this expression tree and wrap it in a function. - val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} - """.stripMargin - - addNewFunction(fnName, fn) - // Add a state and a mapping of the common subexpressions that are associate with this // state. Adding this expression to subExprEliminationExprMap means it will call `fn` // when it is code generated. This decision should be a cost based one. @@ -785,12 +961,23 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState("boolean", isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, + val isNullAccessor = addMutableState("boolean", isNull, s"$isNull = false;") + val valueAccessor = addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subexprFunctions += s"$fnName($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + // Generate the code for this expression tree and wrap it in a function. + val eval = expr.genCode(this) + val fn = + s""" + |private void $fnName(InternalRow $INPUT_ROW) { + | ${eval.code.trim} + | $isNullAccessor = ${eval.isNull}; + | $valueAccessor = ${eval.value}; + |} + """.stripMargin + + subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" + val state = SubExprEliminationState(isNullAccessor, valueAccessor) e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4d732445544a8..fa7b353dea5bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + val isNullAccessor = ctx.addMutableState("boolean", isNull, s"$isNull = true;") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$isNull = ${ev.isNull}; - this.$value = ${ev.value}; + $isNullAccessor = ${ev.isNull}; + $valueAccessor = ${ev.value}; """ } else { val value = s"value_$i" - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, + s"$value = ${ctx.defaultValue(e.dataType)};") s""" ${ev.code} - this.$value = ${ev.value}; + $valueAccessor = ${ev.value}; """ } } @@ -135,6 
+135,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP $allUpdates return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f7fc2d54a047b..a31943255b995 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -179,6 +179,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR $comparisons return 0; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dcd1ed96a298e..b400783bb5e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -72,6 +72,9 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { ${eval.code} return !${eval.isNull} && ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} }""" val code = CodeFormatter.stripOverlappingComments( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b1cb6edefb852..8964b3a6bdada 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"this.$values = null;") + val valuesAccessor = ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -58,17 +58,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] s""" if (!$tmp.isNullAt($i)) { ${converter.code} - $values[$i] = ${converter.value}; + $valuesAccessor[$i] = ${converter.value}; } """ } val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - this.$values = new Object[${schema.length}]; + $valuesAccessor = new Object[${schema.length}]; $allFields - final InternalRow $output = new $rowClass($values); - this.$values = null; + final InternalRow $output = new $rowClass($valuesAccessor); + $valuesAccessor = null; """ ExprCode(code, "false", output) @@ -184,6 +184,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] $allExpressions return mutableRow; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e4c9089a2cb9..f48b695aa8cd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -74,8 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, - s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + val rowWriterAccessor = ctx.addMutableState(rowWriterClass, rowWriter, + s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -86,10 +86,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // need to clear it out every time. "" } else { - s"$rowWriter.zeroOutNullBytes();" + s"$rowWriterAccessor.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriterAccessor.reset();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -103,8 +103,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => // Can't call setNullAt() for DecimalType with precision larger than 18. - s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" - case _ => s"$rowWriter.setNullAt($index);" + s"$rowWriterAccessor.write($index, (Decimal) null, ${t.precision}, ${t.scale});" + case _ => s"$rowWriterAccessor.setNullAt($index);" } val writeField = dt match { @@ -114,7 +114,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // written later. final int $tmpCursor = $bufferHolder.cursor; ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $rowWriterAccessor.setOffsetAndSize( + $index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case a @ ArrayType(et, _) => @@ -123,7 +124,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // written later. final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $rowWriterAccessor.setOffsetAndSize( + $index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => @@ -132,15 +134,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // written later. 
final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $rowWriterAccessor.setOffsetAndSize( + $index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" + s"$rowWriterAccessor.write($index, ${input.value}, ${t.precision}, ${t.scale});" case NullType => "" - case _ => s"$rowWriter.write($index, ${input.value});" + case _ => s"$rowWriterAccessor.write($index, ${input.value});" } if (input.isNull == "false") { @@ -174,8 +177,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro bufferHolder: String): String = { val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") - ctx.addMutableState(arrayWriterClass, arrayWriter, - s"this.$arrayWriter = new $arrayWriterClass();") + val arrayWriterAccessor = ctx.addMutableState(arrayWriterClass, arrayWriter, + s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val element = ctx.freshName("element") @@ -199,29 +202,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final int $tmpCursor = $bufferHolder.cursor; ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $arrayWriterAccessor.setOffsetAndSize( + $index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case a @ ArrayType(et, _) => s""" final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $arrayWriterAccessor.setOffsetAndSize( + $index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => s""" final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $arrayWriterAccessor.setOffsetAndSize( + $index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" + s"$arrayWriterAccessor.write($index, $element, ${t.precision}, ${t.scale});" case NullType => "" - case _ => s"$arrayWriter.write($index, $element);" + case _ => s"$arrayWriterAccessor.write($index, $element);" } val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" @@ -230,11 +236,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} } else { final int $numElements = $input.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriterAccessor.initialize($bufferHolder, $numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { if ($input.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); + $arrayWriterAccessor.setNull$primitiveTypeName($index); } else { final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement @@ -309,29 +315,30 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, 
s"$result = new UnsafeRow(${expressions.length});") + val resultAccessor = + ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, holder, - s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + val holderAccessor = ctx.addMutableState(holderClass, holder, + s"$holder = new $holderClass($resultAccessor, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" } else { - s"$holder.reset();" + s"$holderAccessor.reset();" } val updateRowSize = if (numVarLenFields == 0) { "" } else { - s"$result.setTotalSize($holder.totalSize());" + s"$resultAccessor.setTotalSize($holderAccessor.totalSize());" } // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holderAccessor, isTopLevel = true) val code = s""" @@ -340,7 +347,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", resultAccessor) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = @@ -395,6 +402,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${eval.code.trim} return ${eval.value}; } + + ${ctx.initNestedClasses()} + ${ctx.declareNestedClasses()} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b6675a84ece48..f18bc909e1b69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -92,12 +92,12 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName - ctx.addMutableState("Object[]", arrayName, - s"this.$arrayName = new Object[${numElements}];") + val arrayNameAccessor = ctx.addMutableState("Object[]", arrayName, + s"$arrayName = new Object[${numElements}];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { - s"$arrayName[$i] = null;" + s"$arrayNameAccessor[$i] = null;" } else { "throw new RuntimeException(\"Cannot use null as map key!\");" } @@ -105,26 +105,26 @@ private [sql] object GenArrayData { if (${eval.isNull}) { $isNullAssignment } else { - $arrayName[$i] = ${eval.value}; + $arrayNameAccessor[$i] = ${eval.value}; } """ } ("", assignments, - s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", + s"final ArrayData $arrayDataName = new $genericArrayClass($arrayNameAccessor);", arrayDataName) } else { val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); + val arrayDataNameAccessor = ctx.addMutableState("UnsafeArrayData", arrayDataName, ""); val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val 
assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { - s"$arrayDataName.setNullAt($i);" + s"$arrayDataNameAccessor.setNullAt($i);" } else { "throw new RuntimeException(\"Cannot use null as map key!\");" } @@ -132,20 +132,20 @@ private [sql] object GenArrayData { if (${eval.isNull}) { $isNullAssignment } else { - $arrayDataName.set$primitiveValueTypeName($i, ${eval.value}); + $arrayDataNameAccessor.set$primitiveValueTypeName($i, ${eval.value}); } """ } (s""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - $arrayDataName = new UnsafeArrayData(); + $arrayDataNameAccessor = new UnsafeArrayData(); Platform.putLong($arrayName, $baseOffset, $numElements); - $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); + $arrayDataNameAccessor.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, assignments, "", - arrayDataName) + arrayDataNameAccessor) } } } @@ -340,24 +340,24 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"this.$values = null;") + val valuesAccessor = ctx.addMutableState("Object[]", values, s"$values = null;") ev.copy(code = s""" - $values = new Object[${valExprs.size}];""" + + $valuesAccessor = new Object[${valExprs.size}];""" + ctx.splitExpressions( ctx.INPUT_ROW, valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { - $values[$i] = null; + $valuesAccessor[$i] = null; } else { - $values[$i] = ${eval.value}; + $valuesAccessor[$i] = ${eval.value}; }""" }) + s""" - final InternalRow ${ev.value} = new $rowClass($values); - this.$values = null; + final InternalRow ${ev.value} = new $rowClass($valuesAccessor); + $valuesAccessor = null; """, isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ee365fe636614..44a3e167d7c59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -118,21 +118,22 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi dataType: DataType, baseFuncName: String): (String, String, String) = { val globalIsNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalIsNullAccessor = + ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") val globalValue = ctx.freshName("value") - ctx.addMutableState(ctx.javaType(dataType), globalValue, + val globalValueAccessor = ctx.addMutableState(ctx.javaType(dataType), globalValue, s"$globalValue = ${ctx.defaultValue(dataType)};") val funcName = ctx.freshName(baseFuncName) val funcBody = s""" |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { | ${ev.code.trim} - | $globalIsNull = ${ev.isNull}; - | $globalValue = ${ev.value}; + | $globalIsNullAccessor = ${ev.isNull}; + | $globalValueAccessor = ${ev.value}; |} """.stripMargin - ctx.addNewFunction(funcName, funcBody) - (funcName, globalIsNull, globalValue) + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + (fullFuncName, globalIsNullAccessor, 
globalValueAccessor) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 43ca2cff58825..2f65fce761285 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -432,15 +432,15 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa val cal = classOf[Calendar].getName val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, + val cAccessor = ctx.addMutableState(cal, c, s""" $c = $cal.getInstance($dtu.getTimeZone("UTC")); $c.setFirstDayOfWeek($cal.MONDAY); $c.setMinimalDaysInFirstWeek(4); """) s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.WEEK_OF_YEAR); + $cAccessor.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $cAccessor.get($cal.WEEK_OF_YEAR); """ }) } @@ -956,15 +956,17 @@ case class FromUTCTimestamp(left: Expression, right: Expression) val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTermAccessor = + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") + val utcTermAccessor = + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = $dtu.convertTz(${eval.value}, $utcTerm, $tzTerm); + | ${ev.value} = $dtu.convertTz(${eval.value}, $utcTermAccessor, $tzTermAccessor); |} """.stripMargin) } @@ -1128,15 +1130,17 @@ case class ToUTCTimestamp(left: Expression, right: Expression) val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTermAccessor = + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") + val utcTermAccessor = + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = $dtu.convertTz(${eval.value}, $tzTerm, $utcTerm); + | ${ev.value} = $dtu.convertTz(${eval.value}, $tzTermAccessor, $utcTermAccessor); |} """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e84796f2edad0..ba6833ab79d2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -181,7 +181,8 @@ case class Stack(children: Seq[Expression]) extends Generator { override 
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];") + val rowDataAccessor = + ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => @@ -190,16 +191,16 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\nthis.$rowData[$row] = ${eval.value};" + s"${eval.code}\n$rowDataAccessor[$row] = ${eval.value};" }) // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName - ctx.addMutableState( + val valueAccessor = ctx.addMutableState( s"$wrapperClass", ev.value, - s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);") - ev.copy(code = code, isNull = "false") + s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowDataAccessor);") + ev.copy(code = code, isNull = "false", value = valueAccessor) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 2a5963d37f5e8..41dad6493e728 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -269,17 +269,17 @@ abstract class HashExpression[E] extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" + val valueAccessor = ctx.addMutableState(ctx.javaType(dataType), ev.value, "") val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { - computeHash(childGen.value, child.dataType, ev.value, ctx) + computeHash(childGen.value, child.dataType, valueAccessor, ctx) } }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + $valueAccessor = $seed; + $childrenHash""", value = valueAccessor) } protected def nullSafeElementHash( @@ -606,19 +606,19 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" val childHash = ctx.freshName("childHash") + val valueAccessor = ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + val childHashAccessor = ctx.addMutableState("int", childHash, s"$childHash = 0;") val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { - computeHash(childGen.value, child.dataType, childHash, ctx) - } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" + - s"\n$childHash = 0;" + computeHash(childGen.value, child.dataType, childHashAccessor, ctx) + } + s"$valueAccessor = (31 * $valueAccessor) + $childHashAccessor;" + + s"\n$childHashAccessor = 0;" }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") - ctx.addMutableState("int", childHash, s"$childHash = 0;") ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + $valueAccessor 
= $seed;
+      $childrenHash""", value = valueAccessor)
   }
 
   override def eval(input: InternalRow = null): Int = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 1a202ecf745c9..a07d9daf77edf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 import java.lang.reflect.Modifier
 
 import scala.collection.mutable.Builder
+import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 
@@ -61,15 +62,15 @@ trait InvokeLike extends Expression with NonSQLExpression {
 
     val resultIsNull = if (needNullCheck) {
       val resultIsNull = ctx.freshName("resultIsNull")
-      ctx.addMutableState("boolean", resultIsNull, "")
-      resultIsNull
+      val resultIsNullAccessor = ctx.addMutableState("boolean", resultIsNull, "")
+      resultIsNullAccessor
     } else {
       "false"
     }
     val argValues = arguments.map { e =>
       val argValue = ctx.freshName("argValue")
-      ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
-      argValue
+      val argValueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+      argValueAccessor
     }
 
     val argCodes = if (needNullCheck) {
@@ -140,10 +141,13 @@ case class StaticInvoke(
 
     val callFunc = s"$objectName.$functionName($argString)"
 
+    val valueAccessor =
+      ctx.addMutableState(javaType, ev.value, s"${ev.value} = ${ctx.defaultValue(dataType)};")
+
     // If the function can return null, we do an extra check to make sure our null bit is still set
     // correctly.
     val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
+      s"${ev.isNull} = $valueAccessor == null;"
     } else {
       ""
     }
@@ -151,10 +155,10 @@ case class StaticInvoke(
     val code = s"""
      $argCode
      boolean ${ev.isNull} = $resultIsNull;
-     final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
+     $valueAccessor = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
      $postNullCheck
    """
-    ev.copy(code = code)
+    ev.copy(code = code, value = valueAccessor)
   }
 }
 
@@ -322,8 +326,11 @@ case class NewInstance(
 
     ev.isNull = resultIsNull
 
+    val valueAccessor =
+      ctx.addMutableState(javaType, ev.value, s"${ev.value} = ${ctx.defaultValue(javaType)};")
+
     val constructorCall = outer.map { gen =>
       s"${gen.value}.new ${cls.getSimpleName}($argString)"
     }.getOrElse {
       s"new $className($argString)"
     }
@@ -331,9 +338,9 @@ case class NewInstance(
     val code = s"""
      $argCode
      ${outer.map(_.code).getOrElse("")}
-     final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
+     $valueAccessor = ${ev.isNull} ? 
${ctx.defaultValue(javaType)} : $constructorCall; """ - ev.copy(code = code) + ev.copy(code = code, value = valueAccessor) } override def toString: String = s"newInstance($cls)" @@ -516,8 +523,8 @@ case class MapObjects private( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val elementJavaType = ctx.javaType(loopVarDataType) - ctx.addMutableState("boolean", loopIsNull, "") - ctx.addMutableState(elementJavaType, loopValue, "") + val loopIsNullAccessor = ctx.addMutableState("boolean", loopIsNull, "", inLine = true) + val loopValueAccessor = ctx.addMutableState(elementJavaType, loopValue, "", inLine = true) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -588,11 +595,11 @@ case class MapObjects private( } val loopNullCheck = inputDataType match { - case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + case _: ArrayType => s"$loopIsNullAccessor = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"$loopIsNull = false" - case _ => s"$loopIsNull = $loopValue == null;" + s"$loopIsNullAccessor = false" + case _ => s"$loopIsNullAccessor = $loopValueAccessor == null;" } val (initCollection, addElement, getResult): (String, String => String, String) = @@ -632,7 +639,7 @@ case class MapObjects private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $loopValue = ($elementJavaType) ($getLoopVar); + $loopValueAccessor = ($elementJavaType) ($getLoopVar); $loopNullCheck ${genFunction.code} @@ -666,15 +673,18 @@ object ExternalMapToCatalyst { val keyName = "ExternalMapToCatalyst_key" + id val valueName = "ExternalMapToCatalyst_value" + id val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + val mapValuesMap: mutable.Map[String, String] = mutable.Map.empty[String, String] ExternalMapToCatalyst( keyName, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType, false)), + keyConverter( + LambdaVariable(keyName, "false", keyType, false)), valueName, valueIsNull, valueType, - valueConverter(LambdaVariable(valueName, valueIsNull, valueType, valueNullable)), + valueConverter( + LambdaVariable(valueName, valueIsNull, valueType, valueNullable)), inputMap ) } @@ -829,15 +839,15 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, "") + val valuesAccessor = ctx.addMutableState("Object[]", values, "") val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { - $values[$i] = null; + $valuesAccessor[$i] = null; } else { - $values[$i] = ${eval.value}; + $valuesAccessor[$i] = ${eval.value}; } """ } @@ -846,9 +856,9 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" - $values = new Object[${children.size}]; + $valuesAccessor = new Object[${children.size}]; $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + final ${classOf[Row].getName} ${ev.value} = new $rowClass($valuesAccessor, $schemaField); """ ev.copy(code = code, isNull = "false") } @@ -885,12 +895,13 @@ case class 
EncodeUsingSerializer(child: Expression, kryo: Boolean) $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializerAccessor = + ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) // Code to serialize. val input = child.genCode(ctx) val javaType = ctx.javaType(dataType) - val serialize = s"$serializer.serialize(${input.value}, null).array()" + val serialize = s"$serializerAccessor.serialize(${input.value}, null).array()" val code = s""" ${input.code} @@ -931,13 +942,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializerAccessor = + ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) // Code to deserialize. val input = child.genCode(ctx) val javaType = ctx.javaType(dataType) val deserialize = - s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + s"($javaType) $serializerAccessor.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} @@ -967,26 +979,26 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val javaBeanInstance = ctx.freshName("javaBean") val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) - ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "") + val javaBeanInstanceAccessor = ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "") val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.genCode(ctx) s""" ${fieldGen.code} - ${javaBeanInstance}.$setterMethod(${fieldGen.value}); + ${javaBeanInstanceAccessor}.$setterMethod(${fieldGen.value}); """ } val initializeCode = ctx.splitExpressions(ctx.INPUT_ROW, initialize.toSeq) val code = s""" ${instanceGen.code} - this.${javaBeanInstance} = ${instanceGen.value}; + ${javaBeanInstanceAccessor} = ${instanceGen.value}; if (!${instanceGen.isNull}) { $initializeCode } """ - ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) + ev.copy(code = code, isNull = instanceGen.isNull, value = javaBeanInstanceAccessor) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5034566132f7a..1e4e9a0a44e93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -269,16 +269,17 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ctx.references += this val hsetTerm = ctx.freshName("hset") val hasNullTerm = ctx.freshName("hasNull") - ctx.addMutableState(setName, hsetTerm, - s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") - ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") + val hsetTermAccessor = ctx.addMutableState(setName, + hsetTerm, s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") + val hasNullTermAccessor = ctx.addMutableState( + "boolean", hasNullTerm, s"$hasNullTerm = $hsetTermAccessor.contains(null);") ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = 
${childGen.isNull}; boolean ${ev.value} = false; if (!${ev.isNull}) { - ${ev.value} = $hsetTerm.contains(${childGen.value}); - if (!${ev.value} && $hasNullTerm) { + ${ev.value} = $hsetTermAccessor.contains(${childGen.value}); + if (!${ev.value} && $hasNullTermAccessor) { ${ev.isNull} = true; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 1d7a3c7356075..5138011c3b9c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -79,11 +79,12 @@ case class Rand(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + val rngTermAccessor = ctx.addMutableState(className, rngTerm, "") ctx.addPartitionInitializationStatement( - s"$rngTerm = new $className(${seed}L + partitionIndex);") + s"$rngTermAccessor = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $rngTermAccessor.nextDouble();""", + isNull = "false") } } @@ -114,11 +115,12 @@ case class Randn(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + val rngTermAccessor = ctx.addMutableState(className, rngTerm, "") ctx.addPartitionInitializationStatement( - s"$rngTerm = new $className(${seed}L + partitionIndex);") + s"$rngTermAccessor = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $rngTermAccessor.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index aa5a1b5448c6d..02d9eb31212a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -118,7 +118,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, + val patternAccessor = ctx.addMutableState(patternClass, pattern, s"""$pattern = ${patternClass}.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
@@ -128,7 +128,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); + ${ev.value} = $patternAccessor.matcher(${eval.value}.toString()).matches(); } """) } else { @@ -191,7 +191,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, + val patternAccessor = ctx.addMutableState(patternClass, pattern, s"""$pattern = ${patternClass}.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); + ${ev.value} = $patternAccessor.matcher(${eval.value}.toString()).find(0); } """) } else { @@ -326,12 +326,16 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val matcher = ctx.freshName("matcher") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, + val termLastRegexAccessor = + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + val termPatternAccessor = + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val termLastReplacementAccessor = + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") + val termLastReplacementInUTF8Accessor = + ctx.addMutableState( + "UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + val termResultAccessor = ctx.addMutableState(classNameStringBuffer, termResult, s"${termResult} = new $classNameStringBuffer();") val setEvNotNull = if (nullable) { @@ -342,24 +346,24 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals(${termLastRegexAccessor})) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + ${termLastRegexAccessor} = $regexp.clone(); + ${termPatternAccessor} = ${classNamePattern}.compile(${termLastRegexAccessor}.toString()); } - if (!$rep.equals(${termLastReplacementInUTF8})) { + if (!$rep.equals(${termLastReplacementInUTF8Accessor})) { // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + ${termLastReplacementInUTF8Accessor} = $rep.clone(); + ${termLastReplacementAccessor} = ${termLastReplacementInUTF8Accessor}.toString(); } - ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); + ${termResultAccessor}.delete(0, 
${termResultAccessor}.length()); + java.util.regex.Matcher ${matcher} = ${termPatternAccessor}.matcher($subject.toString()); while (${matcher}.find()) { - ${matcher}.appendReplacement(${termResult}, ${termLastReplacement}); + ${matcher}.appendReplacement(${termResultAccessor}, ${termLastReplacementAccessor}); } - ${matcher}.appendTail(${termResult}); - ${ev.value} = UTF8String.fromString(${termResult}.toString()); + ${matcher}.appendTail(${termResultAccessor}); + ${ev.value} = UTF8String.fromString(${termResultAccessor}.toString()); $setEvNotNull """ }) @@ -419,8 +423,10 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val termLastRegexAccessor = + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + val termPatternAccessor = + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -430,13 +436,13 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals(${termLastRegexAccessor})) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + ${termLastRegexAccessor} = $regexp.clone(); + ${termPatternAccessor} = ${classNamePattern}.compile(${termLastRegexAccessor}.toString()); } java.util.regex.Matcher ${matcher} = - ${termPattern}.matcher($subject.toString()); + ${termPatternAccessor}.matcher($subject.toString()); if (${matcher}.find()) { java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); if (${matchResult}.group($idx) == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5598a146997ca..e687e91f9962f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -396,24 +396,27 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") - ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") + val termLastMatchingAccessor = + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + val termLastReplaceAccessor = + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + val termDictAccessor = ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { s"$termDict == null" } else { - s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" + s"!$matching.equals($termLastMatchingAccessor) " + + s"|| !$replace.equals($termLastReplaceAccessor)" } 
s"""if ($check) { // Not all of them is literal or matching or replace value changed - $termLastMatching = $matching.clone(); - $termLastReplace = $replace.clone(); - $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict($termLastMatching, $termLastReplace); + $termLastMatchingAccessor = $matching.clone(); + $termLastReplaceAccessor = $replace.clone(); + $termDictAccessor = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatchingAccessor, $termLastReplaceAccessor); } - ${ev.value} = $src.translate($termDict); + ${ev.value} = $src.translate($termDictAccessor); """ }) } @@ -1505,27 +1508,27 @@ case class FormatNumber(x: Expression, d: Expression) val numberFormat = ctx.freshName("numberFormat") val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") - ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, + val lastDValueAccessor = ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") + val patternAccessor = ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") + val numberFormatAccessor = ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); + $patternAccessor.delete(0, $patternAccessor.length()); + if ($d != $lastDValueAccessor) { + $patternAccessor.append("#,###,###,###,###,###,##0"); if ($d > 0) { - $pattern.append("."); + $patternAccessor.append("."); for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + $patternAccessor.append("0"); } } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); + $lastDValueAccessor = $d; + $numberFormatAccessor.applyLocalizedPattern($patternAccessor.toString()); } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + ${ev.value} = UTF8String.fromString($numberFormatAccessor.format(${typeHelper(num)})); } else { ${ev.value} = null; ${ev.isNull} = true; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index e86116680a57a..a216eb00ec704 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -71,25 +71,29 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { override protected def doProduce(ctx: CodegenContext): String = { val input = ctx.freshName("input") // PhysicalRDD always just has one input - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = ctx.addMutableState( + "scala.collection.Iterator", input, s"$input = inputs[0];") // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + val scanTimeTotalNsAccessor = + ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" val batch = ctx.freshName("batch") - ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + val batchAccessor = ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") 
val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" val idx = ctx.freshName("batchIdx") - ctx.addMutableState("int", idx, s"$idx = 0;") - val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) - val columnAssigns = colVars.zipWithIndex.map { case (name, i) => + val idxAccessor = ctx.addMutableState("int", idx, s"$idx = 0;") + val colVars = output.indices.map(i => { + val name = ctx.freshName("colInstance" + i) ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = $batch.column($i);" + }) + val columnAssigns = colVars.zipWithIndex.map { case (nameAccessor, i) => + s"$nameAccessor = $batchAccessor.column($i);" } val nextBatch = ctx.freshName("nextBatch") @@ -97,13 +101,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { s""" |private void $nextBatch() throws java.io.IOException { | long getBatchStart = System.nanoTime(); - | if ($input.hasNext()) { - | $batch = ($columnarBatchClz)$input.next(); - | $numOutputRows.add($batch.numRows()); - | $idx = 0; + | if ($inputAccessor.hasNext()) { + | $batchAccessor = ($columnarBatchClz)$inputAccessor.next(); + | $numOutputRows.add($batchAccessor.numRows()); + | $idxAccessor = 0; | ${columnAssigns.mkString("", "\n", "\n")} | } - | $scanTimeTotalNs += System.nanoTime() - getBatchStart; + | $scanTimeTotalNsAccessor += System.nanoTime() - getBatchStart; |}""".stripMargin) ctx.currentVars = null @@ -120,23 +124,23 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { "// shouldStop check is eliminated" } s""" - |if ($batch == null) { + |if ($batchAccessor == null) { | $nextBatch(); |} - |while ($batch != null) { - | int $numRows = $batch.numRows(); + |while ($batchAccessor != null) { + | int $numRows = $batchAccessor.numRows(); | int $localEnd = $numRows - $idx; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; + | int $rowidx = $idxAccessor + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } - | $idx = $numRows; - | $batch = null; + | $idxAccessor = $numRows; + | $batchAccessor = null; | $nextBatch(); |} - |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); - |$scanTimeTotalNs = 0; + |$scanTimeMetric.add($scanTimeTotalNsAccessor / (1000 * 1000)); + |$scanTimeTotalNsAccessor = 0; """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 74fc23a52a141..5e68db2334b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -118,7 +118,8 @@ case class RowDataSourceScanExec( val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -128,8 +129,8 @@ case class RowDataSourceScanExec( val columnsRowInput = exprRows.map(_.genCode(ctx)) val inputRow = if (outputUnsafeRows) row else null s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); + |while ($inputAccessor.hasNext()) { + | InternalRow $row = (InternalRow) $inputAccessor.next(); | $numOutputRows.add(1); 
| ${consume(ctx, columnsRowInput, inputRow).trim} | if (shouldStop()) return; @@ -345,7 +346,8 @@ case class FileSourceScanExec( val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -355,8 +357,8 @@ case class FileSourceScanExec( val columnsRowInput = exprRows.map(_.genCode(ctx)) val inputRow = if (needsUnsafeRowConversion) null else row s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); + |while ($inputAccessor.hasNext()) { + | InternalRow $row = (InternalRow) $inputAccessor.next(); | $numOutputRows.add(1); | ${consume(ctx, columnsRowInput, inputRow).trim} | if (shouldStop()) return; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index f98ae82574d20..4863727dac78d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -126,22 +126,24 @@ case class SortExec( override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") - ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + val needToSortAccessor = ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. 
val thisPlan = ctx.addReferenceObj("plan", this) sorterVariable = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, + // Reset sorterVariable value to potentially class-qualified form + sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, s"$sorterVariable = $thisPlan.createSorter();") val metrics = ctx.freshName("metrics") - ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + val metricsAccessor = ctx.addMutableState(classOf[TaskMetrics].getName, metrics, s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") val sortedIterator = ctx.freshName("sortedIter") - ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + val sortedIteratorAccessor = + ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") val addToSorter = ctx.freshName("addToSorter") - ctx.addNewFunction(addToSorter, + val addToSorterFunc = ctx.addNewFunction(addToSorter, s""" | private void $addToSorter() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} @@ -158,19 +160,19 @@ case class SortExec( val spillSizeBefore = ctx.freshName("spillSizeBefore") val sortTime = metricTerm(ctx, "sortTime") s""" - | if ($needToSort) { - | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | $addToSorter(); - | $sortedIterator = $sorterVariable.sort(); + | if ($needToSortAccessor) { + | long $spillSizeBefore = $metricsAccessor.memoryBytesSpilled(); + | $addToSorterFunc(); + | $sortedIteratorAccessor = $sorterVariable.sort(); | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); - | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); - | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); - | $needToSort = false; + | $spillSize.add($metricsAccessor.memoryBytesSpilled() - $spillSizeBefore); + | $metricsAccessor.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSortAccessor = false; | } | - | while ($sortedIterator.hasNext()) { - | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | while ($sortedIteratorAccessor.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIteratorAccessor.next(); | ${consume(ctx, null, outputRow)} | if (shouldStop()) return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c1e1a631c677e..ec46c6ee27dc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -255,11 +255,12 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one input RDD. 
+ val inputAccessor = ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val row = ctx.freshName("row") s""" - | while ($input.hasNext() && !stopEarly()) { - | InternalRow $row = (InternalRow) $input.next(); + | while ($inputAccessor.hasNext() && !stopEarly()) { + | InternalRow $row = (InternalRow) $inputAccessor.next(); | ${consume(ctx, null, row).trim} | if (shouldStop()) return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 68c8e6ce62cbb..d2cecf5cd4dd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -162,7 +162,7 @@ case class HashAggregateExec( private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + val initAggAccessor = ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -170,15 +170,15 @@ case class HashAggregateExec( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + val isNullAccessor = ctx.addMutableState("boolean", isNull, "") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | $isNullAccessor = ${ev.isNull}; + | $valueAccessor = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, isNullAccessor, valueAccessor) } val initBufVar = evaluateVariables(bufVars) @@ -209,7 +209,7 @@ case class HashAggregateExec( } val doAgg = ctx.freshName("doAggregateWithoutKey") - ctx.addNewFunction(doAgg, + val doAggFunc = ctx.addNewFunction(doAgg, s""" | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer @@ -223,10 +223,10 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") s""" - | while (!$initAgg) { - | $initAgg = true; + | while (!$initAggAccessor) { + | $initAggAccessor = true; | long $beforeAgg = System.nanoTime(); - | $doAgg(); + | $doAggFunc(); | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); | | // output the result @@ -459,11 +459,11 @@ case class HashAggregateExec( } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // This should be the last operator in a stage, we should output UnsafeRow directly val joinerTerm = ctx.freshName("unsafeRowJoiner") - ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + val joinerTermAccessor = ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, s"$joinerTerm = $plan.createUnsafeJoiner();") val resultRow = ctx.freshName("resultRow") s""" - UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + UnsafeRow $resultRow = $joinerTermAccessor.join($keyTerm, $bufferTerm); ${consume(ctx, null, resultRow)} """ @@ -521,7 +521,7 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): 
String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + val initAggAccessor = ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -546,18 +546,20 @@ case class HashAggregateExec( // Create a name for iterator from vectorized HashMap val iterTermForFastHashMap = ctx.freshName("fastHashMapIter") + var iterTermForFastHashMapAccessor: String = iterTermForFastHashMap if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, + // Reassign fastHashMapTerm value to potentially class-qualified name + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName();") - ctx.addMutableState( + iterTermForFastHashMapAccessor = ctx.addMutableState( "java.util.Iterator", iterTermForFastHashMap, "") } else { - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName(" + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") - ctx.addMutableState( + iterTermForFastHashMapAccessor = ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", iterTermForFastHashMap, "") } @@ -566,13 +568,14 @@ case class HashAggregateExec( // create hashMap hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, "") + hashMapTerm = ctx.addMutableState(hashMapClassName, hashMapTerm, "") sorterTerm = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") // Create a name for iterator from HashMap val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + val iterTermAccessor = + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") val doAgg = ctx.freshName("doAggregateWithKeys") val peakMemory = metricTerm(ctx, "peakMemory") @@ -592,7 +595,7 @@ case class HashAggregateExec( } else "" } - ctx.addNewFunction(doAgg, + val doAggFunc = ctx.addNewFunction(doAgg, s""" ${generateGenerateCode} private void $doAgg() throws java.io.IOException { @@ -600,9 +603,10 @@ case class HashAggregateExec( ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} ${if (isFastHashMapEnabled) { - s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} + s"$iterTermForFastHashMapAccessor = $fastHashMapTerm.rowIterator();"} else ""} - $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize); + $iterTermAccessor = $thisPlan.finishAggregate( + $hashMapTerm, $sorterTerm, $peakMemory, $spillSize); } """) @@ -628,10 +632,10 @@ case class HashAggregateExec( def outputFromRowBasedMap: String = { s""" - while ($iterTermForFastHashMap.next()) { + while ($iterTermForFastHashMapAccessor.next()) { $numOutput.add(1); - UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); - UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); + UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMapAccessor.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMapAccessor.getValue(); $outputCode if (shouldStop()) return; @@ 
-650,11 +654,11 @@ case class HashAggregateExec( val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }) s""" - | while ($iterTermForFastHashMap.hasNext()) { + | while ($iterTermForFastHashMapAccessor.hasNext()) { | $numOutput.add(1); | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) - | $iterTermForFastHashMap.next(); + | $iterTermForFastHashMapAccessor.next(); | ${generateRow.code} | ${consume(ctx, Seq.empty, {generateRow.value})} | @@ -669,26 +673,26 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") s""" - if (!$initAgg) { - $initAgg = true; + if (!$initAggAccessor) { + $initAggAccessor = true; long $beforeAgg = System.nanoTime(); - $doAgg(); + $doAggFunc(); $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); } // output the result ${outputFromGeneratedMap} - while ($iterTerm.next()) { + while ($iterTermAccessor.next()) { $numOutput.add(1); - UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); - UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + UnsafeRow $keyTerm = (UnsafeRow) $iterTermAccessor.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTermAccessor.getValue(); $outputCode if (shouldStop()) return; } - $iterTerm.close(); + $iterTermAccessor.close(); if ($sorterTerm == null) { $hashMapTerm.free(); } @@ -728,9 +732,10 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.freshName("fallbackCounter") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") - (s"$countTerm < ${testFallbackStartsAt.get._1}", - s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") + val counterTermAccessor = ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$counterTermAccessor < ${testFallbackStartsAt.get._1}", + s"$counterTermAccessor < ${testFallbackStartsAt.get._2}", + s"$counterTermAccessor = 0;", s"$counterTermAccessor += 1;") } else { ("true", "true", "", "") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 90deb20e97244..12c835fc445d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -48,15 +48,15 @@ abstract class HashMapGenerator( initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + val isNullAccessor = ctx.addMutableState("boolean", isNull, "") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, "") val ev = e.genCode(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | $isNullAccessor = ${ev.isNull}; + | $valueAccessor = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, isNullAccessor, valueAccessor) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 85096dcc40f5d..504203cc8cc7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -281,10 +281,8 @@ case class SampleExec( val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") ctx.copyResult = true - ctx.addMutableState(s"$samplerClass", sampler, - s"$initSampler();") - ctx.addNewFunction(initSampler, + val initSamplerFunc = ctx.addNewFunction(initSampler, s""" | private void $initSampler() { | $sampler = new $samplerClass($upperBound - $lowerBound, false); @@ -299,9 +297,12 @@ case class SampleExec( | } """.stripMargin.trim) + val samplerAccessor = ctx.addMutableState(s"$samplerClass", sampler, + s"$initSamplerFunc();") + val samplingCount = ctx.freshName("samplingCount") s""" - | int $samplingCount = $sampler.sample(); + | int $samplingCount = $samplerAccessor.sample(); | while ($samplingCount-- > 0) { | $numOutput.add(1); | ${consume(ctx, input)} @@ -309,14 +310,14 @@ case class SampleExec( """.stripMargin.trim } else { val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName - ctx.addMutableState(s"$samplerClass", sampler, + val samplerAccessor = ctx.addMutableState(s"$samplerClass", sampler, s""" | $sampler = new $samplerClass($lowerBound, $upperBound, false); | $sampler.setSeed(${seed}L + partitionIndex); """.stripMargin.trim) s""" - | if ($sampler.sample() == 0) continue; + | if ($samplerAccessor.sample() == 0) continue; | $numOutput.add(1); | ${consume(ctx, input)} """.stripMargin.trim @@ -355,19 +356,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val numOutput = metricTerm(ctx, "numOutputRows") val initTerm = ctx.freshName("initRange") - ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + val initTermAccessor = ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") val number = ctx.freshName("number") - ctx.addMutableState("long", number, s"$number = 0L;") + val numberAccessor = ctx.addMutableState("long", number, s"$number = 0L;") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val taskContext = ctx.freshName("taskContext") - ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();") + val taskContextAccessor = + ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();") val inputMetrics = ctx.freshName("inputMetrics") - ctx.addMutableState("InputMetrics", inputMetrics, - s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();") + val inputMetricsAccessor = ctx.addMutableState("InputMetrics", inputMetrics, + s"$inputMetrics = $taskContextAccessor.taskMetrics().inputMetrics();") // In order to periodically update the metrics without inflicting performance penalty, this // operator produces elements in batches. After a batch is complete, the metrics are updated @@ -378,11 +380,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // Once number == batchEnd, it's time to progress to the next batch. val batchEnd = ctx.freshName("batchEnd") - ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;") + val batchEndAccessor = ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;") // How many values should still be generated by this range operator. 
val numElementsTodo = ctx.freshName("numElementsTodo") - ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;") + val numElementsTodoAccessor = + ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") @@ -390,7 +393,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // The default size of a batch, which must be positive integer val batchSize = 1000 - ctx.addNewFunction("initRange", + val initRangeFunc = ctx.addNewFunction("initRange", s""" | private void initRange(int idx) { | $BigInt index = $BigInt.valueOf(idx); @@ -402,13 +405,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; + | $numberAccessor = Long.MAX_VALUE; | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; + | $numberAccessor = Long.MIN_VALUE; | } else { - | $number = st.longValue(); + | $numberAccessor = st.longValue(); | } - | $batchEnd = $number; + | $batchEndAccessor = $numberAccessor; | | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) | .multiply(step).add(start); @@ -421,12 +424,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract( - | $BigInt.valueOf($number)); - | $numElementsTodo = startToEnd.divide(step).longValue(); - | if ($numElementsTodo < 0) { - | $numElementsTodo = 0; + | $BigInt.valueOf($numberAccessor)); + | $numElementsTodoAccessor = startToEnd.divide(step).longValue(); + | if ($numElementsTodoAccessor < 0) { + | $numElementsTodoAccessor = 0; | } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) { - | $numElementsTodo++; + | $numElementsTodoAccessor++; | } | } """.stripMargin) @@ -439,44 +442,44 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") val shouldStop = if (isShouldStopRequired) { - s"if (shouldStop()) { $number = $value + ${step}L; return; }" + s"if (shouldStop()) { $numberAccessor = $value + ${step}L; return; }" } else { "// shouldStop check is eliminated" } s""" | // initialize Range - | if (!$initTerm) { - | $initTerm = true; - | initRange(partitionIndex); + | if (!$initTermAccessor) { + | $initTermAccessor = true; + | $initRangeFunc(partitionIndex); | } | | while (true) { - | long $range = $batchEnd - $number; + | long $range = $batchEndAccessor - $numberAccessor; | if ($range != 0L) { | int $localEnd = (int)($range / ${step}L); | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | long $value = ((long)$localIdx * ${step}L) + $number; + | long $value = ((long)$localIdx * ${step}L) + $numberAccessor; | ${consume(ctx, Seq(ev))} | $shouldStop | } - | $number = $batchEnd; + | $numberAccessor = $batchEndAccessor; | } | - | $taskContext.killTaskIfInterrupted(); + | $taskContextAccessor.killTaskIfInterrupted(); | | long $nextBatchTodo; - | if ($numElementsTodo > ${batchSize}L) { + | if ($numElementsTodoAccessor > ${batchSize}L) { | $nextBatchTodo = ${batchSize}L; - | $numElementsTodo -= ${batchSize}L; + | $numElementsTodoAccessor -= ${batchSize}L; | } else { - | $nextBatchTodo = $numElementsTodo; - | 
$numElementsTodo = 0; + | $nextBatchTodo = $numElementsTodoAccessor; + | $numElementsTodoAccessor = 0; | if ($nextBatchTodo == 0) break; | } | $numOutput.add($nextBatchTodo); - | $inputMetrics.incRecordsRead($nextBatchTodo); + | $inputMetricsAccessor.incRecordsRead($nextBatchTodo); | - | $batchEnd += $nextBatchTodo * ${step}L; + | $batchEndAccessor += $nextBatchTodo * ${step}L; | } """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 14024d6c10558..bc4f549538917 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -90,19 +90,22 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName, "") + val accessorNameAccessor = ctx.addMutableState(accessorCls, accessorName, "") val createCode = dt match { case t if ctx.isPrimitiveType(dt) => - s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + s"$accessorNameAccessor = " + + s"new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => - s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + s"$accessorNameAccessor = " + + s"new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case other => - s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), - (${dt.getClass.getName}) columnTypes[$index]);""" + s"$accessorNameAccessor = " + + s"new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder)," + + s"(${dt.getClass.getName}) columnTypes[$index]);" } - val extract = s"$accessorName.extractTo(mutableRow, $index);" + val extract = s"$accessorNameAccessor.extractTo(mutableRow, $index);" val patch = dt match { case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS => // For large Decimal, it should have 16 bytes for future update even it's null now. 
@@ -128,9 +131,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } else { val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) - var groupedAccessorsLength = 0 - groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => - groupedAccessorsLength += 1 + val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) => val funcName = s"accessors$i" val funcCode = s""" |private void $funcName() { @@ -139,7 +140,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => + val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { @@ -148,8 +149,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), - (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) + (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"), + extractorNames.map { extractorName => s"$extractorName();" }.mkString("\n")) } val codeBody = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0bc261d593df4..a7b3d8908f229 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -99,12 +99,12 @@ case class BroadcastHashJoinExec( val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName - ctx.addMutableState(clsName, relationTerm, + val relationTermAccessor = ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.estimatedSize()); """.stripMargin) - (broadcastRelation, relationTerm) + (broadcastRelation, relationTermAccessor) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 26fb6103953fc..778934884e6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -402,14 +402,14 @@ case class SortMergeJoinExec( private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. val leftRow = ctx.freshName("leftRow") - ctx.addMutableState("InternalRow", leftRow, "") + val leftRowAccessor = ctx.addMutableState("InternalRow", leftRow, "") val rightRow = ctx.freshName("rightRow") - ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + val rightRowAccessor = ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") // Create variables for join keys from both sides. 
-    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
+    val leftKeyVars = createJoinKey(ctx, leftRowAccessor, leftKeys, left.output)
     val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
-    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
+    val rightKeyTmpVars = createJoinKey(ctx, rightRowAccessor, rightKeys, right.output)
     val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
     // Copy the right key as class members so they could be used in next function call.
     val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
@@ -420,7 +420,7 @@ case class SortMergeJoinExec(
 
     val spillThreshold = getSpillThreshold
 
-    ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
+    val matchesAccessor = ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
     // Copy the left keys as class members so they could be used in next function call.
     val matchedKeyVars = copyKeys(ctx, leftKeyVars)
 
@@ -429,58 +429,58 @@ case class SortMergeJoinExec(
       |private boolean findNextInnerJoinRows(
       |    scala.collection.Iterator leftIter,
       |    scala.collection.Iterator rightIter) {
-      |  $leftRow = null;
+      |  $leftRowAccessor = null;
       |  int comp = 0;
-      |  while ($leftRow == null) {
+      |  while ($leftRowAccessor == null) {
       |    if (!leftIter.hasNext()) return false;
-      |    $leftRow = (InternalRow) leftIter.next();
+      |    $leftRowAccessor = (InternalRow) leftIter.next();
       |    ${leftKeyVars.map(_.code).mkString("\n")}
       |    if ($leftAnyNull) {
-      |      $leftRow = null;
+      |      $leftRowAccessor = null;
       |      continue;
       |    }
-      |    if (!$matches.isEmpty()) {
+      |    if (!$matchesAccessor.isEmpty()) {
       |      ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
       |      if (comp == 0) {
       |        return true;
       |      }
-      |      $matches.clear();
+      |      $matchesAccessor.clear();
       |    }
       |
       |    do {
-      |      if ($rightRow == null) {
+      |      if ($rightRowAccessor == null) {
       |        if (!rightIter.hasNext()) {
       |          ${matchedKeyVars.map(_.code).mkString("\n")}
-      |          return !$matches.isEmpty();
+      |          return !$matchesAccessor.isEmpty();
       |        }
-      |        $rightRow = (InternalRow) rightIter.next();
+      |        $rightRowAccessor = (InternalRow) rightIter.next();
       |        ${rightKeyTmpVars.map(_.code).mkString("\n")}
       |        if ($rightAnyNull) {
-      |          $rightRow = null;
+      |          $rightRowAccessor = null;
       |          continue;
       |        }
       |        ${rightKeyVars.map(_.code).mkString("\n")}
       |      }
       |      ${genComparision(ctx, leftKeyVars, rightKeyVars)}
       |      if (comp > 0) {
-      |        $rightRow = null;
+      |        $rightRowAccessor = null;
       |      } else if (comp < 0) {
-      |        if (!$matches.isEmpty()) {
+      |        if (!$matchesAccessor.isEmpty()) {
       |          ${matchedKeyVars.map(_.code).mkString("\n")}
       |          return true;
       |        }
-      |        $leftRow = null;
+      |        $leftRowAccessor = null;
       |      } else {
-      |        $matches.add((UnsafeRow) $rightRow);
-      |        $rightRow = null;;
+      |        $matchesAccessor.add((UnsafeRow) $rightRowAccessor);
+      |        $rightRowAccessor = null;
       |      }
-      |    } while ($leftRow != null);
+      |    } while ($leftRowAccessor != null);
       |  }
       |  return false; // unreachable
       |}
-      """.stripMargin)
+      """.stripMargin, inlineToOuterClass = true)
 
-    (leftRow, matches)
+    (leftRowAccessor, matchesAccessor)
   }
 
   /**
@@ -496,18 +496,18 @@ case class SortMergeJoinExec(
       val value = ctx.freshName("value")
       val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
      // declare it as class member, so we can access the column before or in the loop.
- ctx.addMutableState(ctx.javaType(a.dataType), value, "") + val valueAccessor = ctx.addMutableState(ctx.javaType(a.dataType), value, "") if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", isNull, "") + val isNullAccessor = ctx.addMutableState("boolean", isNull, "") val code = s""" - |$isNull = $leftRow.isNullAt($i); - |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + |$isNullAccessor = $leftRow.isNullAt($i); + |$valueAccessor = $isNullAccessor ? ${ctx.defaultValue(a.dataType)} : ($valueCode); """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, isNullAccessor, valueAccessor) } else { - ExprCode(s"$value = $valueCode;", "false", value) + ExprCode(s"$valueAccessor = $valueCode;", "false", valueAccessor) } } } @@ -549,9 +549,11 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { ctx.copyResult = true val leftInput = ctx.freshName("leftInput") - ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") + val leftInputAccessor = + ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") val rightInput = ctx.freshName("rightInput") - ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + val rightInputAccessor = + ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") val (leftRow, matches) = genScanner(ctx) @@ -592,7 +594,7 @@ case class SortMergeJoinExec( } s""" - |while (findNextInnerJoinRows($leftInput, $rightInput)) { + |while (findNextInnerJoinRows($leftInputAccessor, $rightInputAccessor)) { | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 757fe2185d302..b996b46a46d48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -68,22 +68,22 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + val stopEarlyAccessor = ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") ctx.addNewFunction("stopEarly", s""" @Override protected boolean stopEarly() { - return $stopEarly; + return $stopEarlyAccessor; } - """) + """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + val countTermAccessor = ctx.addMutableState("int", countTerm, s"$countTerm = 0;") s""" - | if ($countTerm < $limit) { - | $countTerm += 1; + | if ($countTermAccessor < $limit) { + | $countTermAccessor += 1; | ${consume(ctx, input)} | } else { - | $stopEarly = true; + | $stopEarlyAccessor = true; | } """.stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 1230b921aa279..e1c95a4814e2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import 
org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} /** * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). @@ -64,6 +65,16 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val ds100_5 = Seq(S100_5()).toDS() ds100_5.rdd.count } + + test("SPARK-18016 Constant Pool Past Limit for Wide/Nested Dataset") { + val schema = StructType((0 to 8000).map(n => StructField(s"column_$n", StringType))) + + val values = schema.map(_ => null) + val rows = spark.sparkContext.parallelize(Seq(Row(values: _*))) + val frame = spark.sqlContext.createDataFrame(rows, schema) + + frame.count() + } } class S100( @@ -97,5 +108,3 @@ extends DefinedByConstructorParams case class S100_5( s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(), s4: S100 = new S100(), s5: S100 = new S100()) - - From 75052c2bedead78bf53471b249233497a4579d5c Mon Sep 17 00:00:00 2001 From: ALeksander Eskilson Date: Thu, 30 Mar 2017 13:15:26 -0500 Subject: [PATCH 2/6] class_splitting mutable projections mutable state fix --- .../codegen/GenerateMutableProjection.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index fa7b353dea5bc..817a173ab726a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -66,32 +66,32 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val isNullAccessor = ctx.addMutableState("boolean", isNull, s"$isNull = true;") val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, s"$value = ${ctx.defaultValue(e.dataType)};") - s""" + (s""" ${ev.code} $isNullAccessor = ${ev.isNull}; $valueAccessor = ${ev.value}; - """ + """, isNullAccessor, valueAccessor, i) } else { val value = s"value_$i" val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, s"$value = ${ctx.defaultValue(e.dataType)};") - s""" + (s""" ${ev.code} $valueAccessor = ${ev.value}; - """ + """, ev.isNull, valueAccessor, i) } } // Evaluate all the subexpressions. 
val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(index).map { - case (e, i) => - val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + val updates = validExpr.zip(projectionCodes).map { + case (e, (_, isNullAccessor, valueAccessor, i)) => + val ev = ExprCode("", s"$isNullAccessor", s"$valueAccessor") ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val codeBody = s""" From 0128e1eba5a7ad4ea37808006e125a8f462de32e Mon Sep 17 00:00:00 2001 From: ALeksander Eskilson Date: Thu, 6 Apr 2017 16:17:01 -0500 Subject: [PATCH 3/6] class_splitting increasing memory during tests, fixing accessor references --- pom.xml | 4 ++-- .../spark/sql/catalyst/expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/expressions/stringExpressions.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index 0533a8dcf2e0a..ea7aa4715467b 100644 --- a/pom.xml +++ b/pom.xml @@ -2005,7 +2005,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize} + -Xmx4g -Xss4096k -XX:ReservedCodeCacheSize=${CodeCacheSize}
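
Note on the pattern above: every call site now threads the string returned by
CodegenContext.addMutableState (and addNewFunction) back into the generated code
instead of interpolating the raw freshName, so state can later be relocated into
nested classes once the outer class nears the JVM's 64KB constant-pool limit
(SPARK-18016). The sketch below is a minimal, hedged approximation of that idea
only -- it is not the patch's actual CodegenContext, and the bucket size, the
"nested<i>.<name>" qualification scheme, and all class names are illustrative
assumptions.

import scala.collection.mutable

// Toy stand-in for the accessor-returning addMutableState. Once a generated class
// would hold too many fields, further state lands on a nested-class instance and
// the returned accessor becomes a qualified reference.
class SplittingContext(statesPerClass: Int = 10000) {
  // (javaType, variableName, initCode) triples, in registration order.
  private val states = mutable.ArrayBuffer.empty[(String, String, String)]

  // Registers a mutable state and returns the accessor generated code must use
  // from now on: either the bare field name or "nested<i>.<name>".
  def addMutableState(javaType: String, name: String, initCode: String): String = {
    states += ((javaType, name, initCode))
    val bucket = (states.size - 1) / statesPerClass
    if (bucket == 0) name else s"nested$bucket.$name"
  }

  // One field-declaration block per generated class (outer class first).
  def declareMutableStates(): Seq[String] =
    states.grouped(statesPerClass).toSeq.map { group =>
      group.map { case (jt, n, _) => s"private $jt $n;" }.mkString("\n")
    }
}

object SplittingContextDemo extends App {
  val ctx = new SplittingContext(statesPerClass = 2)
  val a = ctx.addMutableState("long", "count_0", "count_0 = 0L;")
  val b = ctx.addMutableState("boolean", "isNull_0", "isNull_0 = false;")
  val c = ctx.addMutableState("int", "batchIdx_0", "batchIdx_0 = 0;")
  // a == "count_0", b == "isNull_0", c == "nested1.batchIdx_0"
  println(Seq(a, b, c).mkString(", "))
}

Under a scheme like this, a term such as batchIdx_0 compiles into a field of a
nested helper instance once the outer class fills up, which is why the diffs
above must stop interpolating the raw fresh name and use the returned accessor.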