Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.types._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -97,15 +96,14 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(this).map { subExprState =>
// This expression is repeated which means that the code to evaluate it has already been added
// as a function before. In that case, we just re-use it.
val code = s"/* ${toCommentSafeString(this.toString)} */"
ExprCode(code, subExprState.isNull, subExprState.value)
ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val ve = doGenCode(ctx, ExprCode("", isNull, value))
if (ve.code.nonEmpty) {
// Add `this` in the comment.
ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim)
ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim)
} else {
ve
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,24 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.commons.lang3.StringUtils

/**
* A utility class that indents a block of code based on the curly braces and parentheses.
* This is used to prettify generated code when in debug mode (or exceptions).
*
* Written by Matei Zaharia.
*/
object CodeFormatter {
def format(code: String): String = new CodeFormatter().addLines(code).result()
def format(code: CodeAndComment): String = {
  // The comment map pairs each placeholder token found in the body with the comment text
  // that should appear in its place. Keys and values are taken from the same map, so the
  // two arrays line up index by index.
  val replacements = code.comment
  val placeHolders = replacements.keys.toArray
  val comments = replacements.values.toArray

  // Swap every placeholder for its comment first, then indent the expanded source.
  val expanded = StringUtils.replaceEach(code.body, placeHolders, comments)
  val formatter = new CodeFormatter()
  formatter.addLines(expanded).result()
}

def stripExtraNewLines(input: String): String = {
val code = new StringBuilder
var lastLine: String = "dummy"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ class CodegenContext {
*/
var freshNamePrefix = ""

/**
* A map from each placeholder to its corresponding comment.
*/
private val placeHolderToComments = new mutable.HashMap[String, String]

/**
* Returns a term name that is unique within this instance of a `CodegenContext`.
*/
Expand Down Expand Up @@ -706,6 +711,35 @@ class CodegenContext {
if (doSubexpressionElimination) subexpressionElimination(expressions)
expressions.map(e => e.genCode(this))
}

/**
* Returns the map from each placeholder to its corresponding comment.
*/
def getPlaceHolderToComments(): collection.Map[String, String] = {
  // The backing mutable.HashMap is handed out typed as the general collection.Map;
  // no copy is made.
  placeHolderToComments
}

/**
* Register a multi-line comment and return the corresponding place holder
*/
private def registerMultilineComment(text: String): String = {
  // Reserve a unique token of the form /*cN*/; this is what the caller embeds in the code.
  val token = "/*" + freshName("c") + "*/"
  // Break the text on any of the \r\n, \r or \n line terminators.
  val lines = text.split("(\r\n)|\r|\n")
  // Render the lines as a block comment with one " * "-prefixed row per input line.
  val blockComment = lines.mkString("/**\n * ", "\n * ", "\n */")
  placeHolderToComments(token) = blockComment
  token
}

/**
* Register a comment and return the corresponding place holder
*/
def registerComment(text: String): String = {
  // A line terminator cannot live inside a `//` comment, so such text takes the
  // block-comment path instead.
  val spansSeveralLines = text.exists(ch => ch == '\n' || ch == '\r')
  if (spansSeveralLines) {
    registerMultilineComment(text)
  } else {
    // Single-line text: record it as an end-of-line comment under a fresh /*cN*/ token.
    val token = "/*" + freshName("c") + "*/"
    placeHolderToComments(token) = "// " + text
    token
  }
}
}

/**
Expand All @@ -716,6 +750,19 @@ abstract class GeneratedClass {
def generate(references: Array[Any]): Any
}

/**
* A wrapper for the source code to be compiled by [[CodeGenerator]].
*/
class CodeAndComment(val body: String, val comment: collection.Map[String, String])
  extends Serializable {

  // Identity is defined by the code body alone: two instances with the same body are
  // equal even when their comment maps differ.
  override def equals(that: Any): Boolean = that match {
    case other: CodeAndComment => body == other.body
    case _ => false
  }

  // Kept consistent with equals: only the body contributes to the hash.
  override def hashCode(): Int = body.hashCode
}

/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
Expand Down Expand Up @@ -760,14 +807,14 @@ object CodeGenerator extends Logging {
/**
* Compile the Java source code into a Java class, using Janino.
*/
def compile(code: String): GeneratedClass = {
def compile(code: CodeAndComment): GeneratedClass = {
cache.get(code)
}

/**
* Compile the Java source code into a Java class, using Janino.
*/
private[this] def doCompile(code: String): GeneratedClass = {
private[this] def doCompile(code: CodeAndComment): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
// Cannot be under package codegen, or fail with java.lang.InstantiationException
Expand All @@ -788,7 +835,7 @@ object CodeGenerator extends Logging {
))
evaluator.setExtendedClass(classOf[GeneratedClass])

def formatted = CodeFormatter.format(code)
lazy val formatted = CodeFormatter.format(code)

logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
Expand All @@ -797,7 +844,7 @@ object CodeGenerator extends Logging {
})

try {
evaluator.cook("generated.java", code)
evaluator.cook("generated.java", code.body)
} catch {
case e: Exception =>
val msg = s"failed to compile: $e\n$formatted"
Expand All @@ -819,8 +866,8 @@ object CodeGenerator extends Logging {
private val cache = CacheBuilder.newBuilder()
.maximumSize(100)
.build(
new CacheLoader[String, GeneratedClass]() {
override def load(code: String): GeneratedClass = {
new CacheLoader[CodeAndComment, GeneratedClass]() {
override def load(code: CodeAndComment): GeneratedClass = {
val startTime = System.nanoTime()
val result = doCompile(code)
val endTime = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
import org.apache.spark.sql.catalyst.util.toCommentSafeString

/**
* A trait that can be used to provide a fallback mode for expression code generation.
Expand All @@ -36,9 +35,10 @@ trait CodegenFallback extends Expression {
val idx = ctx.references.length
ctx.references += this
val objectTerm = ctx.freshName("obj")
val placeHolder = ctx.registerComment(this.toString)
if (nullable) {
ev.copy(code = s"""
/* expression: ${toCommentSafeString(this.toString)} */
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
Expand All @@ -47,7 +47,7 @@ trait CodegenFallback extends Expression {
}""")
} else {
ev.copy(code = s"""
/* expression: ${toCommentSafeString(this.toString)} */
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
""", isNull = "false")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)

val code = s"""
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
return new SpecificMutableProjection(references);
}
Expand Down Expand Up @@ -133,6 +133,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line does not need to change

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, it's not clear to me why this line does not need to change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format() can accept a CodeAndComment now, and code is a CodeAndComment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. You mean we don't need to change the line to logDebug(s"code for ${expressions.mkString(",")}:\n$formatted").

logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = CodeGenerator.compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
val ctx = newCodeGenContext()
val comparisons = genComparisons(ctx, ordering)
val code = s"""
val codeBody = s"""
public SpecificOrdering generate(Object[] references) {
return new SpecificOrdering(references);
}
Expand All @@ -136,6 +136,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
}"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}")

CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
val ctx = newCodeGenContext()
val eval = predicate.genCode(ctx)
val code = s"""
val codeBody = s"""
public SpecificPredicate generate(Object[] references) {
return new SpecificPredicate(references);
}
Expand All @@ -61,6 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
}
}"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
"""
}
val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
val code = s"""
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
return new SpecificSafeProjection(references);
}
Expand All @@ -181,6 +181,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = CodeGenerator.compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val ctx = newCodeGenContext()
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)

val code = s"""
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
return new SpecificUnsafeProjection(references);
}
Expand Down Expand Up @@ -390,6 +390,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = CodeGenerator.compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
}.mkString("\n")

// ------------------------ Finally, put everything together --------------------------- //
val code = s"""
val codeBody = s"""
|public java.lang.Object generate(Object[] references) {
| return new SpecificUnsafeRowJoiner();
|}
Expand Down Expand Up @@ -193,7 +193,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| }
|}
""".stripMargin

val code = new CodeAndComment(codeBody, Map.empty)
logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")

val c = CodeGenerator.compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,27 +155,6 @@ package object util {

def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql

/**
* Returns the string representation of this expression that is safe to be put in
* code comments of generated code. The length is capped at 128 characters.
*/
def toCommentSafeString(str: String): String = {
  // Cap the text at 128 characters; longer input is cut and marked with "...".
  val maxLength = 128
  val truncated = str.length > maxLength
  val head = if (truncated) str.substring(0, maxLength) else str

  // Neutralize the comment terminator so the text cannot close the enclosing /* ... */.
  val terminatorEscaped = head.replace("*/", "*\\/")

  // Java translates Unicode escapes before comments are recognized, so a "\" + "u" sequence
  // in a comment is still interpreted by the compiler (see SPARK-15165). A "u" preceded by
  // an odd number of backslashes is a live escape; prepend one backslash to make the count
  // even. An even run is left alone because each backslash is already escaped by its
  // neighbour.
  val unicodeEscaped =
    terminatorEscaped.replaceAll("(^|[^\\\\])(\\\\(\\\\\\\\)*u)", "$1\\\\$2")

  if (truncated) unicodeEscaped + "..." else unicodeEscaped
}

/* FIX ME
implicit class debugLogging(a: Any) {
def debugLogging() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ class CodeFormatterSuite extends SparkFunSuite {

def testCase(name: String)(input: String)(expected: String): Unit = {
test(name) {
if (CodeFormatter.format(input).trim !== expected.trim) {
val sourceCode = new CodeAndComment(input, Map.empty)
if (CodeFormatter.format(sourceCode).trim !== expected.trim) {
fail(
s"""
|== FAIL: Formatted code doesn't match ===
|${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")}
|${sideBySide(CodeFormatter.format(sourceCode).trim, expected.trim).mkString("\n")}
""".stripMargin)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource}
import org.apache.spark.sql.execution.metric.SQLMetrics
Expand Down Expand Up @@ -252,7 +251,7 @@ private[sql] case class BatchedDataSourceScanExec(
val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
val valueVar = ctx.freshName("value")
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
val code = s"/* ${toCommentSafeString(str)} */\n" + (if (nullable) {
val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
s"""
boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal);
$javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value);
Expand Down
Loading