Skip to content

Commit be84c4d

Browse files
committed
refactor splitExpressions
1 parent 6a1eeca commit be84c4d

File tree

1 file changed

+64
-33
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen

1 file changed

+64
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import java.io.ByteArrayInputStream
2121
import java.util.{Map => JavaMap}
2222

2323
import scala.collection.JavaConverters._
24-
import scala.collection.immutable.ListMap
2524
import scala.collection.mutable
2625
import scala.collection.mutable.ArrayBuffer
2726
import scala.language.existentials
@@ -840,44 +839,76 @@ class CodegenContext {
840839
.filter(_.innerClassName.isEmpty)
841840
.map(_.functionName)
842841

843-
val argsString = arguments.map(_._2).mkString(", ")
844-
845-
// Here we handle all the methods which have been added to the inner classes and
846-
// not to the outer class.
847-
// Since they can be many, their direct invocation in the outer class adds many entries
848-
// to the outer class' constant pool. This can cause the constant pool to past JVM limit.
849-
// To avoid this problem, we group them and we call only the grouping methods in the
850-
// outer class.
851-
val innerClassToFunctions = functions
852-
.filter(_.innerClassName.isDefined)
853-
.foldLeft(ListMap.empty[(String, String), Seq[String]]) { case (acc, f) =>
854-
val key = (f.innerClassName.get, f.innerClassInstance.get)
855-
acc.updated(key, acc.getOrElse(key, Seq.empty[String]) ++ Seq(f.functionName))
856-
}
857-
val innerClassFunctions = innerClassToFunctions.flatMap {
858-
case ((innerClassName, innerClassInstance), innerClassFunctions) =>
859-
if (innerClassFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
860-
// Adding a new function to each inner class which contains
861-
// the invocation of all the ones which have been added to
862-
// that inner class
863-
val body = foldFunctions(innerClassFunctions.map(name => s"$name($argsString)"))
864-
val code = s"""
865-
|private $returnType $func($argString) {
866-
| ${makeSplitFunction(body)}
867-
|}
868-
""".stripMargin
869-
addNewFunctionToClass(func, code, innerClassName)
870-
Seq(s"$innerClassInstance.$func")
871-
} else {
872-
innerClassFunctions.map(f => s"$innerClassInstance.$f")
873-
}
874-
}
842+
val innerClassFunctions = generateInnerClassesMethodsCalls(
843+
functions.filter(_.innerClassName.nonEmpty),
844+
func,
845+
arguments,
846+
returnType,
847+
makeSplitFunction,
848+
foldFunctions)
875849

850+
val argsString = arguments.map(_._2).mkString(", ")
876851
foldFunctions((outerClassFunctions ++ innerClassFunctions).map(
877852
name => s"$name($argsString)"))
878853
}
879854
}
880855

856+
/**
857+
* Here we handle all the methods which have been added to the inner classes and
858+
* not to the outer class.
859+
* Since they can be many, their direct invocation in the outer class adds many entries
860+
* to the outer class' constant pool. This can cause the constant pool to past JVM limit.
861+
* Moreover, this can cause also the outer class method where all the invocations are
862+
* performed to grow beyond the 64k limit.
863+
* To avoid these problems, we group them and we call only the grouping methods in the
864+
* outer class.
865+
*
866+
* @param functions a [[Seq]] of [[NewFunctionSpec]] defined in the inner classes
867+
* @param funcName the split function name base.
868+
* @param arguments the list of (type, name) of the arguments of the split function.
869+
* @param returnType the return type of the split function.
870+
* @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
871+
* @param foldFunctions folds the split function calls.
872+
* @return an [[Iterable]] containing the methods' invocations
873+
*/
874+
private def generateInnerClassesMethodsCalls(
875+
functions: Seq[NewFunctionSpec],
876+
funcName: String,
877+
arguments: Seq[(String, String)],
878+
returnType: String,
879+
makeSplitFunction: String => String,
880+
foldFunctions: Seq[String] => String): Iterable[String] = {
881+
val innerClassToFunctions = mutable.ListMap.empty[(String, String), Seq[String]]
882+
functions.foreach(f => {
883+
val key = (f.innerClassName.get, f.innerClassInstance.get)
884+
innerClassToFunctions.update(key, f.functionName +:
885+
innerClassToFunctions.getOrElse(key, Seq.empty[String]))
886+
})
887+
888+
val argDefinitionString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
889+
val argInvocationString = arguments.map(_._2).mkString(", ")
890+
891+
innerClassToFunctions.flatMap {
892+
case ((innerClassName, innerClassInstance), innerClassFunctions) =>
893+
if (innerClassFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
894+
// Adding a new function to each inner class which contains
895+
// the invocation of all the ones which have been added to
896+
// that inner class
897+
val body = foldFunctions(innerClassFunctions.map(name =>
898+
s"$name($argInvocationString)"))
899+
val code = s"""
900+
|private $returnType $funcName($argDefinitionString) {
901+
| ${makeSplitFunction(body)}
902+
|}
903+
""".stripMargin
904+
addNewFunctionToClass(funcName, code, innerClassName)
905+
Seq(s"$innerClassInstance.$funcName")
906+
} else {
907+
innerClassFunctions.map(f => s"$innerClassInstance.$f")
908+
}
909+
}
910+
}
911+
881912
/**
882913
* Perform a function which generates a sequence of ExprCodes with a given mapping between
883914
* expressions and common expressions, instead of using the mapping in current context.

0 commit comments

Comments
 (0)