1717
1818package org .apache .spark .sql .catalyst .expressions .codegen
1919
20+ import com .google .common .cache .{CacheLoader , CacheBuilder }
21+
2022import scala .language .existentials
2123
2224import org .apache .spark .Logging
@@ -26,27 +28,53 @@ import org.apache.spark.sql.catalyst.types._
2628
2729/**
2830 * A base class for generators of byte code that performs expression evaluation. Includes helpers
29- * for refering to Catalyst types and building trees that perform evaluation of individual
31+ * for referring to Catalyst types and building trees that perform evaluation of individual
3032 * expressions.
3133 */
32- abstract class CodeGenerator extends Logging {
34+ abstract class CodeGenerator [ InType <: AnyRef , OutType <: AnyRef ] extends Logging {
3335 import scala .reflect .runtime .{universe => ru }
3436 import scala .reflect .runtime .universe ._
3537
3638 import scala .tools .reflect .ToolBox
3739
38- val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
40+ protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
3941
40- val rowType = typeOf[Row ]
41- val mutableRowType = typeOf[MutableRow ]
42- val genericRowType = typeOf[GenericRow ]
43- val genericMutableRowType = typeOf[GenericMutableRow ]
42+ protected val rowType = typeOf[Row ]
43+ protected val mutableRowType = typeOf[MutableRow ]
44+ protected val genericRowType = typeOf[GenericRow ]
45+ protected val genericMutableRowType = typeOf[GenericMutableRow ]
4446
45- val projectionType = typeOf[Projection ]
46- val mutableProjectionType = typeOf[MutableProjection ]
47+ protected val projectionType = typeOf[Projection ]
48+ protected val mutableProjectionType = typeOf[MutableProjection ]
4749
4850 private val curId = new java.util.concurrent.atomic.AtomicInteger ()
49- private val javaSeperator = " $"
51+ private val javaSeparator = " $"
52+
53+ /**
54+ * Generates a class for a given input expression. Called when there is not a cached code
55+ * already available.
56+ */
57+ protected def create (in : InType ): OutType
58+
59+ /** Canonicalizes an input expression. */
60+ protected def canonicalize (in : InType ): InType
61+
62+ /** Binds an input expression to a given input schema */
63+ protected def bind (in : InType , inputSchema : Seq [Attribute ]): InType
64+
65+ protected val cache = CacheBuilder .newBuilder()
66+ .maximumSize(1000 )
67+ .build(
68+ new CacheLoader [InType , OutType ]() {
69+ override def load (in : InType ): OutType = globalLock.synchronized {
70+ create(in)
71+ }
72+ })
73+
74+ def apply (expressions : InType , inputSchema : Seq [Attribute ]): OutType =
75+ apply(bind(expressions, inputSchema))
76+
77+ def apply (expressions : InType ): OutType = cache.get(canonicalize(expressions))
5078
5179 /**
5280 * Returns a term name that is unique within this instance of a `CodeGenerator`.
@@ -55,7 +83,7 @@ abstract class CodeGenerator extends Logging {
5583 * function.)
5684 */
5785 protected def freshName (prefix : String ): TermName = {
58- newTermName(s " $prefix$javaSeperator ${curId.getAndIncrement}" )
86+ newTermName(s " $prefix$javaSeparator ${curId.getAndIncrement}" )
5987 }
6088
6189 /**
@@ -66,7 +94,7 @@ abstract class CodeGenerator extends Logging {
6694 * to null.
6795 * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
6896 * valid if `nullTerm` is set to `false`.
69- * @param objectTerm An possibly boxed version of the result of evaluating this expression.
97+ * @param objectTerm A possibly boxed version of the result of evaluating this expression.
7098 */
7199 protected case class EvaluatedExpression (
72100 code : Seq [Tree ],
@@ -87,7 +115,7 @@ abstract class CodeGenerator extends Logging {
87115 def castOrNull (f : TermName => Tree , dataType : DataType ): Seq [Tree ] = {
88116 val eval = expressionEvaluator(e)
89117 eval.code ++
90- q """
118+ q """
91119 val $nullTerm = ${eval.nullTerm}
92120 val $primitiveTerm =
93121 if( $nullTerm)
@@ -119,7 +147,7 @@ abstract class CodeGenerator extends Logging {
119147 val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
120148
121149 eval1.code ++ eval2.code ++
122- q """
150+ q """
123151 val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
124152 val $primitiveTerm: ${termForType(resultType)} =
125153 if( $nullTerm) {
@@ -135,14 +163,15 @@ abstract class CodeGenerator extends Logging {
135163
136164 // TODO: Skip generation of null handling code when expression are not nullable.
137165 val primitiveEvaluation : PartialFunction [Expression , Seq [Tree ]] = {
138- case b @ BoundReference (ordinal, _) =>
166+ case b @ BoundReference (ordinal, dataType, nullable) =>
167+ val nullValue = if (nullable) q " $inputTuple.isNullAt( $ordinal) " else q " false "
139168 q """
140- val $nullTerm: Boolean = $inputTuple .isNullAt( $ordinal )
141- val $primitiveTerm: ${termForType(b. dataType)} =
169+ val $nullTerm: Boolean = $nullValue
170+ val $primitiveTerm: ${termForType(dataType)} =
142171 if( $nullTerm)
143- ${defaultPrimitive(e. dataType)}
172+ ${defaultPrimitive(dataType)}
144173 else
145- ${getColumn(inputTuple, b. dataType, ordinal)}
174+ ${getColumn(inputTuple, dataType, ordinal)}
146175 """ .children
147176
148177 case expressions.Literal (null , dataType) =>
@@ -162,11 +191,13 @@ abstract class CodeGenerator extends Logging {
162191 val $nullTerm = ${value == null }
163192 val $primitiveTerm: ${termForType(dataType)} = $value
164193 """ .children
194+
165195 case expressions.Literal (value : Int , dataType) =>
166196 q """
167197 val $nullTerm = ${value == null }
168198 val $primitiveTerm: ${termForType(dataType)} = $value
169199 """ .children
200+
170201 case expressions.Literal (value : Long , dataType) =>
171202 q """
172203 val $nullTerm = ${value == null }
@@ -176,7 +207,7 @@ abstract class CodeGenerator extends Logging {
176207 case Cast (e @ BinaryType (), StringType ) =>
177208 val eval = expressionEvaluator(e)
178209 eval.code ++
179- q """
210+ q """
180211 val $nullTerm = ${eval.nullTerm}
181212 val $primitiveTerm =
182213 if( $nullTerm)
@@ -200,7 +231,7 @@ abstract class CodeGenerator extends Logging {
200231 case Cast (e, StringType ) =>
201232 val eval = expressionEvaluator(e)
202233 eval.code ++
203- q """
234+ q """
204235 val $nullTerm = ${eval.nullTerm}
205236 val $primitiveTerm =
206237 if( $nullTerm)
@@ -251,7 +282,7 @@ abstract class CodeGenerator extends Logging {
251282 val eval2 = expressionEvaluator(e2)
252283
253284 eval1.code ++ eval2.code ++
254- q """
285+ q """
255286 var $nullTerm = false
256287 var $primitiveTerm: ${termForType(BooleanType )} = false
257288
@@ -272,7 +303,7 @@ abstract class CodeGenerator extends Logging {
272303 val eval2 = expressionEvaluator(e2)
273304
274305 eval1.code ++ eval2.code ++
275- q """
306+ q """
276307 var $nullTerm = false
277308 var $primitiveTerm: ${termForType(BooleanType )} = false
278309
@@ -360,10 +391,10 @@ abstract class CodeGenerator extends Logging {
360391 log.debug(s " No rules to generate $e" )
361392 val tree = reify { e }
362393 q """
363- val $objectTerm = $tree.eval(i)
364- val $nullTerm = $objectTerm == null
365- val $primitiveTerm = $objectTerm.asInstanceOf[ ${termForType(e.dataType)}]
366- """ .children
394+ val $objectTerm = $tree.eval(i)
395+ val $nullTerm = $objectTerm == null
396+ val $primitiveTerm = $objectTerm.asInstanceOf[ ${termForType(e.dataType)}]
397+ """ .children
367398 }
368399
369400 EvaluatedExpression (code, nullTerm, primitiveTerm, objectTerm)
@@ -377,10 +408,10 @@ abstract class CodeGenerator extends Logging {
377408 }
378409
379410 protected def setColumn (
380- destinationRow : TermName ,
381- dataType : DataType ,
382- ordinal : Int ,
383- value : TermName ) = {
411+ destinationRow : TermName ,
412+ dataType : DataType ,
413+ ordinal : Int ,
414+ value : TermName ) = {
384415 dataType match {
385416 case dt @ NativeType () => q " $destinationRow. ${mutatorForType(dt)}( $ordinal, $value) "
386417 case _ => q " $destinationRow.update( $ordinal, $value) "
0 commit comments