Skip to content

Commit 36f90cf

Browse files
committed
Address comment.
1 parent fe2a1cd commit 36f90cf

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodegenObjectFactory.scala

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,48 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.codehaus.commons.compiler.CompileException
2121
import org.codehaus.janino.InternalCompilerException
2222

23-
object CodegenObjectFactory {
24-
def codegenOrInterpreted[T](codegenCreator: () => T, interpretedCreator: () => T): T = {
25-
try {
26-
codegenCreator()
27-
} catch {
28-
// Catch compile error related exceptions
29-
case e: InternalCompilerException => interpretedCreator()
30-
case e: CompileException => interpretedCreator()
31-
}
23+
/**
24+
* Catches compile error during code generation.
25+
*/
26+
object CodegenError {
27+
def unapply(throwable: Throwable): Option[Exception] = throwable match {
28+
case e: InternalCompilerException => Some(e)
29+
case e: CompileException => Some(e)
30+
case _ => None
3231
}
3332
}
3433

35-
object UnsafeProjectionFactory extends UnsafeProjectionCreator {
36-
import CodegenObjectFactory._
34+
/**
35+
* A factory class which can be used to create objects that have both codegen and interpreted
36+
* implementations. This tries to create codegen object first, if any compile error happens,
37+
* it fallbacks to interpreted version.
38+
*/
39+
abstract class CodegenObjectFactory[IN, OUT] {
3740

38-
private val codegenCreator = UnsafeProjection
39-
private lazy val interpretedCreator = InterpretedUnsafeProjection
41+
def createObject(in: IN): OUT = try {
42+
createCodeGeneratedObject(in)
43+
} catch {
44+
case CodegenError(_) => createInterpretedObject(in)
45+
}
4046

41-
/**
42-
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
43-
*/
44-
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
45-
codegenOrInterpreted[UnsafeProjection](() => codegenCreator.createProjection(exprs),
46-
() => interpretedCreator.createProjection(exprs))
47+
protected def createCodeGeneratedObject(in: IN): OUT
48+
protected def createInterpretedObject(in: IN): OUT
49+
}
50+
51+
object UnsafeProjectionFactory extends CodegenObjectFactory[Seq[Expression], UnsafeProjection]
52+
with UnsafeProjectionCreator {
53+
54+
override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
55+
UnsafeProjection.createProjection(in)
4756
}
4857

58+
override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
59+
InterpretedUnsafeProjection.createProjection(in)
60+
}
61+
62+
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection =
63+
createObject(exprs)
64+
4965
/**
5066
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
5167
* The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example,
@@ -54,10 +70,10 @@ object UnsafeProjectionFactory extends UnsafeProjectionCreator {
5470
def create(
5571
exprs: Seq[Expression],
5672
inputSchema: Seq[Attribute],
57-
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
58-
codegenOrInterpreted[UnsafeProjection](
59-
() => codegenCreator.create(exprs, inputSchema, subexpressionEliminationEnabled),
60-
() => interpretedCreator.create(exprs, inputSchema))
73+
subexpressionEliminationEnabled: Boolean): UnsafeProjection = try {
74+
UnsafeProjection.create(exprs, inputSchema, subexpressionEliminationEnabled)
75+
} catch {
76+
case CodegenError(_) => InterpretedUnsafeProjection.create(exprs, inputSchema)
6177
}
6278
}
6379

0 commit comments

Comments
 (0)