Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,17 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(this).map { subExprState =>
// This expression is repeated meaning the code to evaluated has already been added
// as a function and called in advance. Just use it.
val code = s"/* ${this.toCommentSafeString} */"
GeneratedExpressionCode(code, subExprState.isNull, subExprState.value)
GeneratedExpressionCode(
ctx.registerComment(this.toString),
subExprState.isNull,
subExprState.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
// Add `this` in the comment.
ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim)
ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim)
}
}

Expand Down Expand Up @@ -215,14 +217,6 @@ abstract class Expression extends TreeNode[Expression] {
override def simpleString: String = toString

override def toString: String = prettyName + flatArguments.mkString("(", ",", ")")

/**
* Returns the string representation of this expression that is safe to be put in
* code comments of generated code.
*/
protected def toCommentSafeString: String = this.toString
.replace("*/", "\\*\\/")
.replace("\\u", "\\\\u")
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,23 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.commons.lang3.StringUtils

/**
* An utility class that indents a block of code based on the curly braces and parentheses.
* This is used to prettify generated code when in debug mode (or exceptions).
*
* Written by Matei Zaharia.
*/
object CodeFormatter {
def format(code: String): String = new CodeFormatter().addLines(code).result()
def format(code: CodeAndComment): String = {
new CodeFormatter().addLines(
StringUtils.replaceEach(
code.body,
code.comment.keys.toArray,
code.comment.values.toArray)
).result
}
}

private class CodeFormatter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class CodeGenContext {

private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* The map from a place holder to a corresponding comment
*/
private val placeHolderToComments = new mutable.HashMap[String, String]

/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
Expand Down Expand Up @@ -458,6 +463,35 @@ class CodeGenContext {
if (doSubexpressionElimination) subexpressionElimination(expressions)
expressions.map(e => e.gen(this))
}

/**
* get a map of the pair of a place holder and a corresponding comment
*/
def getPlaceHolderToComments(): collection.Map[String, String] = placeHolderToComments

/**
* Register a multi-line comment and return the corresponding place holder
*/
private def registerMultilineComment(text: String): String = {
val placeHolder = s"/*${freshName("c")}*/"
val comment = text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */")
placeHolderToComments += (placeHolder -> comment)
placeHolder
}

/**
* Register a comment and return the corresponding place holder
*/
def registerComment(text: String): String = {
if (text.contains("\n") || text.contains("\r")) {
registerMultilineComment(text)
} else {
val placeHolder = s"/*${freshName("c")}*/"
val safeComment = s"// $text"
placeHolderToComments += (placeHolder -> safeComment)
placeHolder
}
}
}

/**
Expand All @@ -468,6 +502,19 @@ abstract class GeneratedClass {
def generate(expressions: Array[Expression]): Any
}

/**
* A wrapper for the source code to be compiled by [[CodeGenerator]].
*/
class CodeAndComment(val body: String, val comment: collection.Map[String, String])
extends Serializable {
override def equals(that: Any): Boolean = that match {
case t: CodeAndComment if t.body == body => true
case _ => false
}

override def hashCode(): Int = body.hashCode
}

/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
Expand Down Expand Up @@ -511,14 +558,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/**
* Compile the Java source code into a Java class, using Janino.
*/
protected def compile(code: String): GeneratedClass = {
protected def compile(code: CodeAndComment): GeneratedClass = {
cache.get(code)
}

/**
* Compile the Java source code into a Java class, using Janino.
*/
private[this] def doCompile(code: String): GeneratedClass = {
private[this] def doCompile(code: CodeAndComment): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
// Cannot be under package codegen, or fail with java.lang.InstantiationException
Expand All @@ -538,7 +585,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
))
evaluator.setExtendedClass(classOf[GeneratedClass])

def formatted = CodeFormatter.format(code)
lazy val formatted = CodeFormatter.format(code)

logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
Expand All @@ -547,7 +594,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
})

try {
evaluator.cook("generated.java", code)
evaluator.cook("generated.java", code.body)
} catch {
case e: Exception =>
val msg = s"failed to compile: $e\n$formatted"
Expand All @@ -569,8 +616,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private val cache = CacheBuilder.newBuilder()
.maximumSize(100)
.build(
new CacheLoader[String, GeneratedClass]() {
override def load(code: String): GeneratedClass = {
new CacheLoader[CodeAndComment, GeneratedClass]() {
override def load(code: CodeAndComment): GeneratedClass = {
val startTime = System.nanoTime()
val result = doCompile(code)
val endTime = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ trait CodegenFallback extends Expression {

ctx.references += this
val objectTerm = ctx.freshName("obj")
val placeHolder = ctx.registerComment(this.toString)
s"""
/* expression: ${this.toCommentSafeString} */
$placeHolder
java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)

val code = s"""
val codeBody = s"""
public java.lang.Object generate($exprType[] expr) {
return new SpecificMutableProjection(expr);
}
Expand Down Expand Up @@ -119,6 +119,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
val ctx = newCodeGenContext()
val comparisons = genComparisons(ctx, ordering)
val code = s"""
val codeBody = s"""
public SpecificOrdering generate($exprType[] expr) {
return new SpecificOrdering(expr);
}
Expand All @@ -133,6 +133,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
}"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}")

compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
val ctx = newCodeGenContext()
val eval = predicate.gen(ctx)
val code = s"""
val codeBody = s"""
public SpecificPredicate generate($exprType[] expr) {
return new SpecificPredicate(expr);
}
Expand All @@ -61,6 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
}
}"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
s"""if (!nullBits[$i]) arr[$i] = c$i;"""
}.mkString("\n")

val code = s"""
val codeBody = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
}
Expand Down Expand Up @@ -230,6 +230,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" +
CodeFormatter.format(code))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
"""
}
val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
val code = s"""
val codeBody = s"""
public java.lang.Object generate($exprType[] expr) {
return new SpecificSafeProjection(expr);
}
Expand All @@ -173,6 +173,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val ctx = newCodeGenContext()
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)

val code = s"""
val codeBody = s"""
public java.lang.Object generate($exprType[] exprs) {
return new SpecificUnsafeProjection(exprs);
}
Expand Down Expand Up @@ -353,6 +353,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
}.mkString("\n")

// ------------------------ Finally, put everything together --------------------------- //
val code = s"""
val codeBody = s"""
|public java.lang.Object generate($exprType[] exprs) {
| return new SpecificUnsafeRowJoiner();
|}
Expand Down Expand Up @@ -195,6 +195,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|}
""".stripMargin

val code = new CodeAndComment(codeBody, Map.empty)
logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")

val c = compile(code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,47 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
true,
InternalRow(UTF8String.fromString("\\u")))
}

test("check compilation error doesn't occur caused by specific literal") {
// The end of comment (*/) should be escaped.
GenerateUnsafeProjection.generate(
Literal.create("*/Compilation error occurs/*", StringType) :: Nil)

// `\u002A` is `*` and `\u002F` is `/`
// so if the end of comment consists of those characters in queries, we need to escape them.
GenerateUnsafeProjection.generate(
Literal.create("\\u002A/Compilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\\\u002A/Compilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\u002a/Compilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\\\u002a/Compilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("*\\u002FCompilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("*\\\\u002FCompilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("*\\002fCompilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("*\\\\002fCompilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\002A\\002FCompilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\\\002A\\002FCompilation error occurs/*", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\002A\\\\002FCompilation error occurs/*", StringType) :: Nil)

// \ u002X is an invalid unicode literal so it should be escaped.
GenerateUnsafeProjection.generate(
Literal.create("\\u002X/Compilation error occurs", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\\\u002X/Compilation error occurs", StringType) :: Nil)

// \ u001 is an invalid unicode literal so it should be escaped.
GenerateUnsafeProjection.generate(
Literal.create("\\u001/Compilation error occurs", StringType) :: Nil)
GenerateUnsafeProjection.generate(
Literal.create("\\\\u001/Compilation error occurs", StringType) :: Nil)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class CodeFormatterSuite extends SparkFunSuite {

def testCase(name: String)(input: String)(expected: String): Unit = {
test(name) {
assert(CodeFormatter.format(input).trim === expected.trim)
val sourceCode = new CodeAndComment(input, Map.empty)
assert(CodeFormatter.format(sourceCode).trim === expected.trim)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator, UnsafeRowWriter}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -152,7 +152,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
(0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n"))
}

val code = s"""
val codeBody = s"""
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import scala.collection.Iterator;
Expand Down Expand Up @@ -226,6 +226,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
}
}"""

val code = new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}")

compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator]
Expand Down
Loading