diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 4b4a8eb3815e1..66c0ff09ea4a2 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -126,6 +126,13 @@
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+        <configuration>
+          <argLine>-Xmx4g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m</argLine>
+        </configuration>
+      </plugin>
      <plugin>
        <groupId>org.antlr</groupId>
        <artifactId>antlr4-maven-plugin</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 228f4b756c8b4..5c68f9ffc691c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -988,7 +988,7 @@ case class ScalaUDF(
val converterTerm = ctx.freshName("converter")
val expressionIdx = ctx.references.size - 1
ctx.addMutableState(converterClassName, converterTerm,
- s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" +
+ s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
converterTerm
@@ -1005,7 +1005,7 @@ case class ScalaUDF(
// Generate codes used to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
ctx.addMutableState(converterClassName, catalystConverterTerm,
- s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
+ s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToCatalystConverter($scalaUDF.dataType());")
val resultTerm = ctx.freshName("result")
@@ -1019,7 +1019,7 @@ case class ScalaUDF(
val funcTerm = ctx.freshName("udf")
ctx.addMutableState(funcClassName, funcTerm,
- s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
+ s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
// codegen for children expressions
val evals = children.map(_.genCode(ctx))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 683b9cbb343c8..22ce3f7e7c52e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -109,7 +109,7 @@ class CodegenContext {
val idx = references.length
references += obj
val clsName = Option(className).getOrElse(obj.getClass.getName)
- addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];")
+ addMutableState(clsName, term, s"$term = ($clsName) references[$idx];")
term
}
@@ -198,41 +198,139 @@ class CodegenContext {
partitionInitializationStatements.mkString("\n")
}
+ /**
+ * Holds expressions that are equivalent. Used to perform subexpression elimination
+ * during codegen.
+ *
+ * For expressions that appear more than once, generate additional code to prevent
+ * recomputing the value.
+ *
+ * For example, consider the two expressions generated from this SQL statement:
+ * SELECT (col1 + col2), (col1 + col2) / col3.
+ *
+ * equivalentExpressions will match the tree containing `col1 + col2` and it will only
+ * be evaluated once.
+ */
+ val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+
+ // For each expression participating in subexpression elimination, the state to use.
+ val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+
+ // The collection of sub-expression result resetting methods that need to be called on each row.
+ val subexprFunctions = mutable.ArrayBuffer.empty[String]
+
+ private val outerClassName = "OuterClass"
+
/**
- * Holding all the functions those will be added into generated class.
+ * Holds the class and instance names to be generated, where `OuterClass` is a placeholder
+ * standing for whichever class is generated as the outermost class and which will contain any
+ * nested sub-classes. All other classes and instance names in this list will represent private,
+ * nested sub-classes.
*/
- val addedFunctions: mutable.Map[String, String] =
- mutable.Map.empty[String, String]
+ private val classes: mutable.ListBuffer[(String, String)] =
+ mutable.ListBuffer[(String, String)](outerClassName -> null)
+
+ // A map holding the current size in bytes of each class to be generated.
+ private val classSize: mutable.Map[String, Int] =
+ mutable.Map[String, Int](outerClassName -> 0)
+
+ // Nested maps holding function names and their code belonging to each class.
+ private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
+ mutable.Map(outerClassName -> mutable.Map.empty[String, String])
- def addNewFunction(funcName: String, funcCode: String): Unit = {
- addedFunctions += ((funcName, funcCode))
+ // Returns the size of the most recently added class.
+ private def currClassSize(): Int = classSize(classes.head._1)
+
+ // Returns the class name and instance name for the most recently added class.
+ private def currClass(): (String, String) = classes.head
+
+ // Adds a new class. Requires the class' name and its instance name.
+ private def addClass(className: String, classInstance: String): Unit = {
+ classes.prepend(className -> classInstance)
+ classSize += className -> 0
+ classFunctions += className -> mutable.Map.empty[String, String]
}
/**
- * Holds expressions that are equivalent. Used to perform subexpression elimination
- * during codegen.
- *
- * For expressions that appear more than once, generate additional code to prevent
- * recomputing the value.
+ * Adds a function to the generated class. If the code for the `OuterClass` grows too large, the
+ * function will be inlined into a new private, nested class, and an instance-qualified name for
+ * the function will be returned. Otherwise, the function will be inlined to the `OuterClass` and
+ * the simple `funcName` will be returned.
*
- * For example, consider two expression generated from this SQL statement:
- * SELECT (col1 + col2), (col1 + col2) / col3.
- *
- * equivalentExpressions will match the tree containing `col1 + col2` and it will only
- * be evaluated once.
+ * @param funcName the class-unqualified name of the function
+ * @param funcCode the body of the function
+ * @param inlineToOuterClass whether the given code must be inlined to the `OuterClass`. This
+ * can be necessary when a function is declared outside of the context
+ * in which it is eventually referenced and a returned instance-qualified
+ * function name cannot otherwise be accessed.
+ * @return the name of the function, instance-qualified if it will be inlined to a private,
+ * nested sub-class
*/
- val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+ def addNewFunction(
+ funcName: String,
+ funcCode: String,
+ inlineToOuterClass: Boolean = false): String = {
+ // The number of named constants that can exist in a class is limited by the constant pool
+ // limit, 65,536. We cannot know ahead of time how many constants a class will need, so we use
+ // a threshold of 1600000 bytes of accumulated function code to determine when further
+ // functions should be inlined into a new private, nested sub-class.
+ val (className, classInstance) = if (inlineToOuterClass) {
+ outerClassName -> ""
+ } else if (currClassSize > 1600000) {
+ val className = freshName("NestedClass")
+ val classInstance = freshName("nestedClassInstance")
+
+ addClass(className, classInstance)
+
+ className -> classInstance
+ } else {
+ currClass()
+ }
- // Foreach expression that is participating in subexpression elimination, the state to use.
- val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+ classSize(className) += funcCode.length
+ classFunctions(className) += funcName -> funcCode
- // The collection of sub-expression result resetting methods that need to be called on each row.
- val subexprFunctions = mutable.ArrayBuffer.empty[String]
+ if (className == outerClassName) {
+ funcName
+ } else {
- def declareAddedFunctions(): String = {
- addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n")
+ s"$classInstance.$funcName"
+ }
}
+ /**
+ * Instantiates all nested, private sub-classes as member objects of the `OuterClass`.
+ */
+ private[sql] def initNestedClasses(): String = {
+ // Nested, private sub-classes have no mutable state (though they do reference the outer class'
+ // mutable state), so we declare and initialize them inline to the OuterClass.
+ classes.filter(_._1 != outerClassName).map {
+ case (className, classInstance) =>
+ s"private $className $classInstance = new $className();"
+ }.mkString("\n")
+ }
+
+ /**
+ * Declares all function code that should be inlined to the `OuterClass`.
+ */
+ private[sql] def declareAddedFunctions(): String = {
+ classFunctions(outerClassName).values.mkString("\n")
+ }
+
+ /**
+ * Declares all nested, private sub-classes and the function code that should be inlined to them.
+ */
+ private[sql] def declareNestedClasses(): String = {
+ classFunctions.filterKeys(_ != outerClassName).map {
+ case (className, functions) =>
+ s"""
+ |private class $className {
+ | ${functions.values.mkString("\n")}
+ |}
+ """.stripMargin
+ }.mkString("\n")
+ }
+
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
@@ -552,8 +650,7 @@ class CodegenContext {
return 0;
}
"""
- addNewFunction(compareFunc, funcCode)
- s"this.$compareFunc($c1, $c2)"
+ s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
@@ -569,8 +666,7 @@ class CodegenContext {
return 0;
}
"""
- addNewFunction(compareFunc, funcCode)
- s"this.$compareFunc($c1, $c2)"
+ s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
@@ -640,7 +736,9 @@ class CodegenContext {
/**
* Splits the generated code of expressions into multiple functions, because function has
- * 64kb code size limit in JVM
+ * 64kb code size limit in JVM. If inlining a function would grow its target class beyond
+ * 1600kb, we instead declare a private, nested sub-class and inline the function into it,
+ * because each class has a constant pool limit of 65,536 named values.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.
@@ -685,7 +783,6 @@ class CodegenContext {
|}
""".stripMargin
addNewFunction(name, code)
- name
}
foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})"))
@@ -769,8 +866,6 @@ class CodegenContext {
|}
""".stripMargin
- addNewFunction(fnName, fn)
-
// Add a state and a mapping of the common subexpressions that are associate with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
@@ -791,7 +886,7 @@ class CodegenContext {
addMutableState(javaType(expr.dataType), value,
s"$value = ${defaultValue(expr.dataType)};")
- subexprFunctions += s"$fnName($INPUT_ROW);"
+ subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
e.foreach(subExprEliminationExprs.put(_, state))
}
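
A condensed, self-contained model of the bookkeeping `CodegenContext` now performs (a sketch only: `inlineToOuterClass` is omitted, `addClass` is folded in, and `SplittingContext` is an illustrative stand-in, not part of this patch):

```scala
import scala.collection.mutable

class SplittingContext {
  private val outerClassName = "OuterClass"
  // Most recently added class first; OuterClass is the permanent tail.
  private val classes = mutable.ListBuffer[(String, String)](outerClassName -> "")
  private val classSize = mutable.Map[String, Int](outerClassName -> 0)
  private val classFunctions =
    mutable.Map(outerClassName -> mutable.Map.empty[String, String])
  private var id = 0
  private def freshName(prefix: String): String = { id += 1; s"$prefix$id" }

  def addNewFunction(funcName: String, funcCode: String): String = {
    val (className, classInstance) =
      if (classSize(classes.head._1) > 1600000) {
        // Current class is full: start a new private, nested sub-class.
        val name = freshName("NestedClass")
        val instance = freshName("nestedClassInstance")
        classes.prepend(name -> instance)
        classSize += name -> 0
        classFunctions += name -> mutable.Map.empty[String, String]
        name -> instance
      } else {
        classes.head
      }
    classSize(className) += funcCode.length
    classFunctions(className) += funcName -> funcCode
    // Callers must use the returned name: it is instance-qualified when the
    // function did not land on the outer class.
    if (className == outerClassName) funcName else s"$classInstance.$funcName"
  }
}
```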
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 4d732445544a8..635766835029b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -63,21 +63,21 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
if (e.nullable) {
val isNull = s"isNull_$i"
val value = s"value_$i"
- ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+ ctx.addMutableState("boolean", isNull, s"$isNull = true;")
ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"$value = ${ctx.defaultValue(e.dataType)};")
s"""
${ev.code}
- this.$isNull = ${ev.isNull};
- this.$value = ${ev.value};
+ $isNull = ${ev.isNull};
+ $value = ${ev.value};
"""
} else {
val value = s"value_$i"
ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"$value = ${ctx.defaultValue(e.dataType)};")
s"""
${ev.code}
- this.$value = ${ev.value};
+ $value = ${ev.value};
"""
}
}
@@ -87,7 +87,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val updates = validExpr.zip(index).map {
case (e, i) =>
- val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i")
+ val ev = ExprCode("", s"isNull_$i", s"value_$i")
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
@@ -135,6 +135,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
$allUpdates
return mutableRow;
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index f7fc2d54a047b..a31943255b995 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -179,6 +179,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
$comparisons
return 0;
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}"""
val code = CodeFormatter.stripOverlappingComments(
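
Each generator now appends `${ctx.initNestedClasses()}` and `${ctx.declareNestedClasses()}` at the tail of the generated class. A hand-written sketch (not actual generator output) of roughly how the emitted Java is shaped once functions overflow into a nested class, which is also why the `this.` qualifiers are dropped throughout this patch: inside `NestedClass1`, `this` is the nested instance, so `this.value_0` would not compile, while the unqualified name still resolves through the enclosing instance:

```scala
// Illustrative only; class and member names are made up.
val generatedSkeleton =
  """
    |class SpecificOrdering extends BaseOrdering {
    |  private int value_0;   // mutable state always lives on the outer class
    |
    |  // functions small enough to stay on the outer class ...
    |
    |  // ctx.initNestedClasses():
    |  private NestedClass1 nestedClassInstance1 = new NestedClass1();
    |
    |  // ctx.declareNestedClasses():
    |  private class NestedClass1 {
    |    private void apply_4(InternalRow i) {
    |      value_0 = 42;   // outer field, assigned without `this.`
    |    }
    |  }
    |}
  """.stripMargin
```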
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index dcd1ed96a298e..b400783bb5e55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -72,6 +72,9 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
${eval.code}
return !${eval.isNull} && ${eval.value};
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}"""
val code = CodeFormatter.stripOverlappingComments(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index b1cb6edefb852..f708aeff2b146 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
// These expressions could be split into multiple functions
- ctx.addMutableState("Object[]", values, s"this.$values = null;")
+ ctx.addMutableState("Object[]", values, s"$values = null;")
val rowClass = classOf[GenericInternalRow].getName
@@ -65,10 +65,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val allFields = ctx.splitExpressions(tmp, fieldWriters)
val code = s"""
final InternalRow $tmp = $input;
- this.$values = new Object[${schema.length}];
+ $values = new Object[${schema.length}];
$allFields
final InternalRow $output = new $rowClass($values);
- this.$values = null;
+ $values = null;
"""
ExprCode(code, "false", output)
@@ -184,6 +184,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
$allExpressions
return mutableRow;
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index b358102d914bd..febfe3124f2bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -82,7 +82,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val rowWriterClass = classOf[UnsafeRowWriter].getName
val rowWriter = ctx.freshName("rowWriter")
ctx.addMutableState(rowWriterClass, rowWriter,
- s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+ s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
@@ -182,7 +182,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
val arrayWriter = ctx.freshName("arrayWriter")
ctx.addMutableState(arrayWriterClass, arrayWriter,
- s"this.$arrayWriter = new $arrayWriterClass();")
+ s"$arrayWriter = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
val element = ctx.freshName("element")
@@ -321,7 +321,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val holder = ctx.freshName("holder")
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, holder,
- s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
+ s"$holder = new $holderClass($result, ${numVarLenFields * 32});")
val resetBufferHolder = if (numVarLenFields == 0) {
""
@@ -402,6 +402,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${eval.code.trim}
return ${eval.value};
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 3df2ed8be0650..04e32bda6b0d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -58,10 +58,10 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"this.$values = null;")
+ ctx.addMutableState("Object[]", values, s"$values = null;")
ev.copy(code = s"""
- this.$values = new Object[${children.size}];""" +
+ $values = new Object[${children.size}];""" +
ctx.splitExpressions(
ctx.INPUT_ROW,
children.zipWithIndex.map { case (e, i) =>
@@ -76,7 +76,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
}) +
s"""
final ArrayData ${ev.value} = new $arrayClass($values);
- this.$values = null;
+ $values = null;
""", isNull = "false")
}
@@ -137,8 +137,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
val mapClass = classOf[ArrayBasedMapData].getName
val keyArray = ctx.freshName("keyArray")
val valueArray = ctx.freshName("valueArray")
- ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
- ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
+ ctx.addMutableState("Object[]", keyArray, s"$keyArray = null;")
+ ctx.addMutableState("Object[]", valueArray, s"$valueArray = null;")
val keyData = s"new $arrayClass($keyArray)"
val valueData = s"new $arrayClass($valueArray)"
@@ -173,8 +173,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
}) +
s"""
final MapData ${ev.value} = new $mapClass($keyData, $valueData);
- this.$keyArray = null;
- this.$valueArray = null;
+ $keyArray = null;
+ $valueArray = null;
""", isNull = "false")
}
@@ -296,7 +296,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericInternalRow].getName
val values = ctx.freshName("values")
- ctx.addMutableState("Object[]", values, s"this.$values = null;")
+ ctx.addMutableState("Object[]", values, s"$values = null;")
ev.copy(code = s"""
$values = new Object[${valExprs.size}];""" +
@@ -313,7 +313,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
}) +
s"""
final InternalRow ${ev.value} = new $rowClass($values);
- this.$values = null;
+ $values = null;
""", isNull = "false")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index bacedec1ae203..092c5de08df70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -131,8 +131,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
| $globalValue = ${ev.value};
|}
""".stripMargin
- ctx.addNewFunction(funcName, funcBody)
- (funcName, globalIsNull, globalValue)
+ val fullFuncName = ctx.addNewFunction(funcName, funcBody)
+ (fullFuncName, globalIsNull, globalValue)
}
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
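
This is the pattern every call site follows from here on: splice the name returned by `addNewFunction` into the generated code, never the raw `funcName`, since the function may have landed in a nested class and must then be called through its instance. A sketch (`exampleCallSite` and `doSomething` are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

def exampleCallSite(ctx: CodegenContext): String = {
  val funcName = ctx.freshName("doSomething")
  val qualifiedName = ctx.addNewFunction(funcName,
    s"""
       |private void $funcName(InternalRow i) {
       |  // ... generated body ...
       |}
     """.stripMargin)
  // qualifiedName is either something like "doSomething_0" or, after
  // splitting, "nestedClassInstance1.doSomething_0"; both are valid targets.
  s"$qualifiedName(i);"
}
```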
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 256de74d410e4..5009bf8e96e83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -912,7 +912,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
val code = s"""
${instanceGen.code}
- this.${javaBeanInstance} = ${instanceGen.value};
+ ${javaBeanInstance} = ${instanceGen.value};
if (!${instanceGen.isNull}) {
$initializeCode
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
index b69b74b4240bd..7bfdf550bc376 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -33,10 +33,10 @@ class GeneratedProjectionSuite extends SparkFunSuite {
test("generated projections on wider table") {
val N = 1000
- val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
+ val wideRow1 = new GenericInternalRow((0 until N).toArray[Any])
val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
val wideRow2 = new GenericInternalRow(
- (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+ (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
val joined = new JoinedRow(wideRow1, wideRow2)
val joinedSchema = StructType(schema1 ++ schema2)
@@ -48,12 +48,12 @@ class GeneratedProjectionSuite extends SparkFunSuite {
val unsafeProj = UnsafeProjection.create(nestedSchema)
val unsafe: UnsafeRow = unsafeProj(nested)
(0 until N).foreach { i =>
- val s = UTF8String.fromString((i + 1).toString)
- assert(i + 1 === unsafe.getInt(i + 2))
+ val s = UTF8String.fromString(i.toString)
+ assert(i === unsafe.getInt(i + 2))
assert(s === unsafe.getUTF8String(i + 2 + N))
- assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
+ assert(i === unsafe.getStruct(0, N * 2).getInt(i))
assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
- assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
+ assert(i === unsafe.getStruct(1, N * 2).getInt(i))
assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
}
@@ -62,13 +62,63 @@ class GeneratedProjectionSuite extends SparkFunSuite {
val result = safeProj(unsafe)
// Can't compare GenericInternalRow with JoinedRow directly
(0 until N).foreach { i =>
- val r = i + 1
- val s = UTF8String.fromString((i + 1).toString)
- assert(r === result.getInt(i + 2))
+ val s = UTF8String.fromString(i.toString)
+ assert(i === result.getInt(i + 2))
assert(s === result.getUTF8String(i + 2 + N))
- assert(r === result.getStruct(0, N * 2).getInt(i))
+ assert(i === result.getStruct(0, N * 2).getInt(i))
assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
- assert(r === result.getStruct(1, N * 2).getInt(i))
+ assert(i === result.getStruct(1, N * 2).getInt(i))
+ assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // test generated MutableProjection
+ val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val mutableProj = GenerateMutableProjection.generate(exprs)
+ val row1 = mutableProj(result)
+ assert(result === row1)
+ val row2 = mutableProj(result)
+ assert(result === row2)
+ }
+
+ test("generated projections on wider table requiring class-splitting") {
+ val N = 4000
+ val wideRow1 = new GenericInternalRow((0 until N).toArray[Any])
+ val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
+ val wideRow2 = new GenericInternalRow(
+ (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+ val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
+ val joined = new JoinedRow(wideRow1, wideRow2)
+ val joinedSchema = StructType(schema1 ++ schema2)
+ val nested = new JoinedRow(InternalRow(joined, joined), joined)
+ val nestedSchema = StructType(
+ Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
+
+ // test generated UnsafeProjection
+ val unsafeProj = UnsafeProjection.create(nestedSchema)
+ val unsafe: UnsafeRow = unsafeProj(nested)
+ (0 until N).foreach { i =>
+ val s = UTF8String.fromString(i.toString)
+ assert(i === unsafe.getInt(i + 2))
+ assert(s === unsafe.getUTF8String(i + 2 + N))
+ assert(i === unsafe.getStruct(0, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(i === unsafe.getStruct(1, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // test generated SafeProjection
+ val safeProj = FromUnsafeProjection(nestedSchema)
+ val result = safeProj(unsafe)
+ // Can't compare GenericInternalRow with JoinedRow directly
+ (0 until N).foreach { i =>
+ val s = UTF8String.fromString(i.toString)
+ assert(i === result.getInt(i + 2))
+ assert(s === result.getUTF8String(i + 2 + N))
+ assert(i === result.getStruct(0, N * 2).getInt(i))
+ assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(i === result.getStruct(1, N * 2).getInt(i))
assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index b4aed23218357..0cfdc83573936 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -363,7 +363,7 @@ case class FileSourceScanExec(
}
val nextBatch = ctx.freshName("nextBatch")
- ctx.addNewFunction(nextBatch,
+ val nextBatchFuncName = ctx.addNewFunction(nextBatch,
s"""
|private void $nextBatch() throws java.io.IOException {
| long getBatchStart = System.nanoTime();
@@ -383,7 +383,7 @@ case class FileSourceScanExec(
}
s"""
|if ($batch == null) {
- | $nextBatch();
+ | $nextBatchFuncName();
|}
|while ($batch != null) {
| int numRows = $batch.numRows();
@@ -393,7 +393,7 @@ case class FileSourceScanExec(
| if (shouldStop()) return;
| }
| $batch = null;
- | $nextBatch();
+ | $nextBatchFuncName();
|}
|$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
|$scanTimeTotalNs = 0;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index cc576bbc4c802..9d3dbc2571610 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -141,7 +141,7 @@ case class SortExec(
ctx.addMutableState("scala.collection.Iterator", sortedIterator, "")
val addToSorter = ctx.freshName("addToSorter")
- ctx.addNewFunction(addToSorter,
+ val addToSorterFuncName = ctx.addNewFunction(addToSorter,
s"""
| private void $addToSorter() throws java.io.IOException {
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -160,7 +160,7 @@ case class SortExec(
s"""
| if ($needToSort) {
| long $spillSizeBefore = $metrics.memoryBytesSpilled();
- | $addToSorter();
+ | $addToSorterFuncName();
| $sortedIterator = $sorterVariable.sort();
| $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000);
| $peakMemory.add($sorterVariable.getPeakMemoryUsage());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 2ead8f6baae6b..f3931b8e47d15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -339,6 +339,9 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
protected void processNext() throws java.io.IOException {
${code.trim}
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}
""".trim
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4529ed067e565..1c6d4f8b18fa5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -209,7 +209,7 @@ case class HashAggregateExec(
}
val doAgg = ctx.freshName("doAggregateWithoutKey")
- ctx.addNewFunction(doAgg,
+ val doAggFuncName = ctx.addNewFunction(doAgg,
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
@@ -226,7 +226,7 @@ case class HashAggregateExec(
| while (!$initAgg) {
| $initAgg = true;
| long $beforeAgg = System.nanoTime();
- | $doAgg();
+ | $doAggFuncName();
| $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);
|
| // output the result
@@ -590,7 +590,7 @@ case class HashAggregateExec(
} else ""
}
- ctx.addNewFunction(doAgg,
+ val doAggFuncName = ctx.addNewFunction(doAgg,
s"""
${generateGenerateCode}
private void $doAgg() throws java.io.IOException {
@@ -670,7 +670,7 @@ case class HashAggregateExec(
if (!$initAgg) {
$initAgg = true;
long $beforeAgg = System.nanoTime();
- $doAgg();
+ $doAggFuncName();
$aggTime.add((System.nanoTime() - $beforeAgg) / 1000000);
}
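
To check where functions end up, the generated source can be dumped with the existing debug helpers. A minimal sketch, assuming a running `SparkSession` named `spark` (e.g. in spark-shell):

```scala
import org.apache.spark.sql.execution.debug._

// An aggregation that goes through whole-stage codegen / HashAggregateExec.
val df = spark.range(0, 1000)
  .selectExpr("id % 10 AS k", "id AS v")
  .groupBy("k")
  .sum("v")

df.debugCodegen()  // prints the generated class, including any nested classes
```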
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index b00223a86d4d4..6176e6d55f784 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -281,10 +281,8 @@ case class SampleExec(
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
ctx.copyResult = true
- ctx.addMutableState(s"$samplerClass", sampler,
- s"$initSampler();")
- ctx.addNewFunction(initSampler,
+ val initSamplerFuncName = ctx.addNewFunction(initSampler,
s"""
| private void $initSampler() {
| $sampler = new $samplerClass($upperBound - $lowerBound, false);
@@ -299,6 +297,8 @@ case class SampleExec(
| }
""".stripMargin.trim)
+ ctx.addMutableState(s"$samplerClass", sampler, s"$initSamplerFuncName();")
+
val samplingCount = ctx.freshName("samplingCount")
s"""
| int $samplingCount = $sampler.sample();
@@ -370,7 +370,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
s"$number > $partitionEnd"
}
- ctx.addNewFunction("initRange",
+ val initRangeFuncName = ctx.addNewFunction("initRange",
s"""
| private void initRange(int idx) {
| $BigInt index = $BigInt.valueOf(idx);
@@ -409,7 +409,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
- | initRange(partitionIndex);
+ | $initRangeFuncName(partitionIndex);
| }
|
| while (!$overflow && $checkEnd) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 14024d6c10558..f4566496fca5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -128,9 +128,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
} else {
val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold)
val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold)
- var groupedAccessorsLength = 0
- groupedAccessorsItr.zipWithIndex.foreach { case (body, i) =>
- groupedAccessorsLength += 1
+ val accessorNames = groupedAccessorsItr.zipWithIndex.map { case (body, i) =>
val funcName = s"accessors$i"
val funcCode = s"""
|private void $funcName() {
@@ -139,7 +137,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
""".stripMargin
ctx.addNewFunction(funcName, funcCode)
}
- groupedExtractorsItr.zipWithIndex.foreach { case (body, i) =>
+ val extractorNames = groupedExtractorsItr.zipWithIndex.map { case (body, i) =>
val funcName = s"extractors$i"
val funcCode = s"""
|private void $funcName() {
@@ -148,8 +146,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
""".stripMargin
ctx.addNewFunction(funcName, funcCode)
}
- ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"),
- (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n"))
+ (accessorNames.map { accessorName => s"$accessorName();" }.mkString("\n"),
+ extractorNames.map { extractorName => s"$extractorName();" }.mkString("\n"))
}
val codeBody = s"""
@@ -184,9 +182,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
${ctx.declareMutableStates()}
public SpecificColumnarIterator() {
- this.nativeOrder = ByteOrder.nativeOrder();
- this.buffers = new byte[${columnTypes.length}][];
- this.mutableRow = new MutableUnsafeRow(rowWriter);
+ nativeOrder = ByteOrder.nativeOrder();
+ buffers = new byte[${columnTypes.length}][];
+ mutableRow = new MutableUnsafeRow(rowWriter);
}
public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
@@ -224,6 +222,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
unsafeRow.setTotalSize(bufferHolder.totalSize());
return unsafeRow;
}
+
+ ${ctx.initNestedClasses()}
+ ${ctx.declareNestedClasses()}
}"""
val code = CodeFormatter.stripOverlappingComments(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 89a9b38132732..f8e9a91592c0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -446,7 +446,7 @@ case class SortMergeJoinExec(
| }
| return false; // unreachable
|}
- """.stripMargin)
+ """.stripMargin, inlineToOuterClass = true)
(leftRow, matches)
}
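
`inlineToOuterClass = true` is for functions whose bare name escapes the scope in which they are declared, as with `findNextInnerJoinRows` above (and `stopEarly()` in limit.scala below): their call sites are generated elsewhere with the unqualified name, so the function must be pinned to the outer class rather than receive an instance-qualified name. An illustrative sketch (`declareScanner` is not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

def declareScanner(ctx: CodegenContext): String = {
  val funcName = "findNextInnerJoinRows"
  ctx.addNewFunction(funcName,
    s"""
       |private boolean $funcName(
       |    scala.collection.Iterator leftIter,
       |    scala.collection.Iterator rightIter) {
       |  // ... generated scanning body ...
       |}
     """.stripMargin, inlineToOuterClass = true)
  funcName // callers were generated with the bare name, so it stays on OuterClass
}
```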
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 757fe2185d302..73a0f8735ed45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -75,7 +75,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
protected boolean stopEarly() {
return $stopEarly;
}
- """)
+ """, inlineToOuterClass = true)
val countTerm = ctx.freshName("count")
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
s"""