Skip to content

Commit be2cd6b

Browse files
committed
WIP: Remove old method for reference binding, more work on configuration.
1 parent bc88ecd commit be2cd6b

File tree

18 files changed

+174
-179
lines changed

18 files changed

+174
-179
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ package object dsl {
172172
// Protobuf terminology
173173
def required = a.withNullability(false)
174174

175-
def at(ordinal: Int) = BoundReference(ordinal, a)
175+
def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
176176
}
177177
}
178178

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,72 +17,40 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst.trees
2120
import org.apache.spark.sql.catalyst.errors.attachTree
2221
import org.apache.spark.sql.catalyst.plans.QueryPlan
2322
import org.apache.spark.sql.catalyst.rules.Rule
23+
import org.apache.spark.sql.catalyst.types._
24+
import org.apache.spark.sql.catalyst.trees
25+
2426
import org.apache.spark.sql.Logging
2527

2628
/**
2729
* A bound reference points to a specific slot in the input tuple, allowing the actual value
2830
* to be retrieved more efficiently. However, since operations like column pruning can change
2931
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
3032
*/
31-
case class BoundReference(ordinal: Int, baseReference: Attribute)
32-
extends Attribute with trees.LeafNode[Expression] {
33+
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
34+
extends Expression with trees.LeafNode[Expression] {
3335

3436
type EvaluatedType = Any
3537

36-
override def nullable = baseReference.nullable
37-
override def dataType = baseReference.dataType
38-
override def exprId = baseReference.exprId
39-
override def qualifiers = baseReference.qualifiers
40-
override def name = baseReference.name
38+
def references = Set.empty
4139

42-
override def newInstance = BoundReference(ordinal, baseReference.newInstance)
43-
override def withNullability(newNullability: Boolean) =
44-
BoundReference(ordinal, baseReference.withNullability(newNullability))
45-
override def withQualifiers(newQualifiers: Seq[String]) =
46-
BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
47-
48-
override def toString = s"$baseReference:$ordinal"
40+
override def toString = s"input[$ordinal]"
4941

5042
override def eval(input: Row): Any = input(ordinal)
5143
}
5244

53-
/**
54-
* Used to denote operators that do their own binding of attributes internally.
55-
*/
56-
trait NoBind { self: trees.TreeNode[_] => }
57-
58-
class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
59-
import BindReferences._
60-
61-
def apply(plan: TreeNode): TreeNode = {
62-
plan.transform {
63-
case n: NoBind => n.asInstanceOf[TreeNode]
64-
case leafNode if leafNode.children.isEmpty => leafNode
65-
case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
66-
bindReference(e, unaryNode.children.head.output)
67-
}
68-
}
69-
}
70-
}
71-
7245
object BindReferences extends Logging {
7346
def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
7447
expression.transform { case a: AttributeReference =>
7548
attachTree(a, "Binding attribute") {
7649
val ordinal = input.indexWhere(_.exprId == a.exprId)
7750
if (ordinal == -1) {
78-
// TODO: This fallback is required because some operators (such as ScriptTransform)
79-
// produce new attributes that can't be bound. Likely the right thing to do is remove
80-
// this rule and require all operators to explicitly bind to the input schema that
81-
// they specify.
82-
logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
83-
a
51+
sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
8452
} else {
85-
BoundReference(ordinal, a)
53+
BoundReference(ordinal, a.dataType, a.nullable)
8654
}
8755
}
8856
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.codegen
1919

20+
import com.google.common.cache.{CacheLoader, CacheBuilder}
21+
2022
import scala.language.existentials
2123

2224
import org.apache.spark.Logging
@@ -26,27 +28,53 @@ import org.apache.spark.sql.catalyst.types._
2628

2729
/**
2830
* A base class for generators of byte code that performs expression evaluation. Includes helpers
29-
* for refering to Catalyst types and building trees that perform evaluation of individual
31+
* for referring to Catalyst types and building trees that perform evaluation of individual
3032
* expressions.
3133
*/
32-
abstract class CodeGenerator extends Logging {
34+
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
3335
import scala.reflect.runtime.{universe => ru}
3436
import scala.reflect.runtime.universe._
3537

3638
import scala.tools.reflect.ToolBox
3739

38-
val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
40+
protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
3941

40-
val rowType = typeOf[Row]
41-
val mutableRowType = typeOf[MutableRow]
42-
val genericRowType = typeOf[GenericRow]
43-
val genericMutableRowType = typeOf[GenericMutableRow]
42+
protected val rowType = typeOf[Row]
43+
protected val mutableRowType = typeOf[MutableRow]
44+
protected val genericRowType = typeOf[GenericRow]
45+
protected val genericMutableRowType = typeOf[GenericMutableRow]
4446

45-
val projectionType = typeOf[Projection]
46-
val mutableProjectionType = typeOf[MutableProjection]
47+
protected val projectionType = typeOf[Projection]
48+
protected val mutableProjectionType = typeOf[MutableProjection]
4749

4850
private val curId = new java.util.concurrent.atomic.AtomicInteger()
49-
private val javaSeperator = "$"
51+
private val javaSeparator = "$"
52+
53+
/**
54+
* Generates a class for a given input expression. Called when there is no cached code
55+
* already available.
56+
*/
57+
protected def create(in: InType): OutType
58+
59+
/** Canonicalizes an input expression. */
60+
protected def canonicalize(in: InType): InType
61+
62+
/** Binds an input expression to a given input schema. */
63+
protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
64+
65+
protected val cache = CacheBuilder.newBuilder()
66+
.maximumSize(1000)
67+
.build(
68+
new CacheLoader[InType, OutType]() {
69+
override def load(in: InType): OutType = globalLock.synchronized {
70+
create(in)
71+
}
72+
})
73+
74+
def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType=
75+
apply(bind(expressions, inputSchema))
76+
77+
def apply(expressions: InType): OutType = cache.get(canonicalize(expressions))
5078

5179
/**
5280
* Returns a term name that is unique within this instance of a `CodeGenerator`.
@@ -55,7 +83,7 @@ abstract class CodeGenerator extends Logging {
5583
* function.)
5684
*/
5785
protected def freshName(prefix: String): TermName = {
58-
newTermName(s"$prefix$javaSeperator${curId.getAndIncrement}")
86+
newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
5987
}
6088

6189
/**
@@ -66,7 +94,7 @@ abstract class CodeGenerator extends Logging {
6694
* to null.
6795
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
6896
* valid if `nullTerm` is set to `true`.
69-
* @param objectTerm An possibly boxed version of the result of evaluating this expression.
97+
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
7098
*/
7199
protected case class EvaluatedExpression(
72100
code: Seq[Tree],
@@ -87,7 +115,7 @@ abstract class CodeGenerator extends Logging {
87115
def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
88116
val eval = expressionEvaluator(e)
89117
eval.code ++
90-
q"""
118+
q"""
91119
val $nullTerm = ${eval.nullTerm}
92120
val $primitiveTerm =
93121
if($nullTerm)
@@ -119,7 +147,7 @@ abstract class CodeGenerator extends Logging {
119147
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
120148

121149
eval1.code ++ eval2.code ++
122-
q"""
150+
q"""
123151
val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
124152
val $primitiveTerm: ${termForType(resultType)} =
125153
if($nullTerm) {
@@ -135,14 +163,15 @@ abstract class CodeGenerator extends Logging {
135163

136164
// TODO: Skip generation of null handling code when expressions are not nullable.
137165
val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
138-
case b @ BoundReference(ordinal, _) =>
166+
case b @ BoundReference(ordinal, dataType, nullable) =>
167+
val nullValue = if (nullable) q"$inputTuple.isNullAt($ordinal)" else q"false"
139168
q"""
140-
val $nullTerm: Boolean = $inputTuple.isNullAt($ordinal)
141-
val $primitiveTerm: ${termForType(b.dataType)} =
169+
val $nullTerm: Boolean = $nullValue
170+
val $primitiveTerm: ${termForType(dataType)} =
142171
if($nullTerm)
143-
${defaultPrimitive(e.dataType)}
172+
${defaultPrimitive(dataType)}
144173
else
145-
${getColumn(inputTuple, b.dataType, ordinal)}
174+
${getColumn(inputTuple, dataType, ordinal)}
146175
""".children
147176

148177
case expressions.Literal(null, dataType) =>
@@ -162,11 +191,13 @@ abstract class CodeGenerator extends Logging {
162191
val $nullTerm = ${value == null}
163192
val $primitiveTerm: ${termForType(dataType)} = $value
164193
""".children
194+
165195
case expressions.Literal(value: Int, dataType) =>
166196
q"""
167197
val $nullTerm = ${value == null}
168198
val $primitiveTerm: ${termForType(dataType)} = $value
169199
""".children
200+
170201
case expressions.Literal(value: Long, dataType) =>
171202
q"""
172203
val $nullTerm = ${value == null}
@@ -176,7 +207,7 @@ abstract class CodeGenerator extends Logging {
176207
case Cast(e @ BinaryType(), StringType) =>
177208
val eval = expressionEvaluator(e)
178209
eval.code ++
179-
q"""
210+
q"""
180211
val $nullTerm = ${eval.nullTerm}
181212
val $primitiveTerm =
182213
if($nullTerm)
@@ -200,7 +231,7 @@ abstract class CodeGenerator extends Logging {
200231
case Cast(e, StringType) =>
201232
val eval = expressionEvaluator(e)
202233
eval.code ++
203-
q"""
234+
q"""
204235
val $nullTerm = ${eval.nullTerm}
205236
val $primitiveTerm =
206237
if($nullTerm)
@@ -251,7 +282,7 @@ abstract class CodeGenerator extends Logging {
251282
val eval2 = expressionEvaluator(e2)
252283

253284
eval1.code ++ eval2.code ++
254-
q"""
285+
q"""
255286
var $nullTerm = false
256287
var $primitiveTerm: ${termForType(BooleanType)} = false
257288

@@ -272,7 +303,7 @@ abstract class CodeGenerator extends Logging {
272303
val eval2 = expressionEvaluator(e2)
273304

274305
eval1.code ++ eval2.code ++
275-
q"""
306+
q"""
276307
var $nullTerm = false
277308
var $primitiveTerm: ${termForType(BooleanType)} = false
278309

@@ -360,10 +391,10 @@ abstract class CodeGenerator extends Logging {
360391
log.debug(s"No rules to generate $e")
361392
val tree = reify { e }
362393
q"""
363-
val $objectTerm = $tree.eval(i)
364-
val $nullTerm = $objectTerm == null
365-
val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
366-
""".children
394+
val $objectTerm = $tree.eval(i)
395+
val $nullTerm = $objectTerm == null
396+
val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
397+
""".children
367398
}
368399

369400
EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
@@ -377,10 +408,10 @@ abstract class CodeGenerator extends Logging {
377408
}
378409

379410
protected def setColumn(
380-
destinationRow: TermName,
381-
dataType: DataType,
382-
ordinal: Int,
383-
value: TermName) = {
411+
destinationRow: TermName,
412+
dataType: DataType,
413+
ordinal: Int,
414+
value: TermName) = {
384415
dataType match {
385416
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
386417
case _ => q"$destinationRow.update($ordinal, $value)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,24 @@ import org.apache.spark.sql.catalyst.expressions._
2323
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
2424
* input [[Row]] for a fixed set of [[Expression Expressions]].
2525
*/
26-
object GenerateMutableProjection extends CodeGenerator {
26+
object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
2727
import scala.reflect.runtime.{universe => ru}
2828
import scala.reflect.runtime.universe._
2929

30-
// TODO: Should be weak references... bounded in size.
31-
val projectionCache = new collection.mutable.HashMap[Seq[Expression], () => MutableProjection]
32-
33-
def apply(expressions: Seq[Expression], inputSchema: Seq[Attribute]): (() => MutableProjection) =
34-
apply(expressions.map(BindReferences.bindReference(_, inputSchema)))
30+
val mutableRowName = newTermName("mutableRow")
3531

36-
// TODO: Safe to fire up multiple instances of the compiler?
37-
def apply(expressions: Seq[Expression]): () => MutableProjection =
38-
globalLock.synchronized {
39-
val cleanedExpressions = expressions.map(ExpressionCanonicalizer(_))
40-
projectionCache.getOrElseUpdate(cleanedExpressions, createProjection(cleanedExpressions))
41-
}
32+
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
33+
in.map(ExpressionCanonicalizer(_))
4234

43-
val mutableRowName = newTermName("mutableRow")
35+
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
36+
in.map(BindReferences.bindReference(_, inputSchema))
4437

45-
def createProjection(expressions: Seq[Expression]): (() => MutableProjection) = {
38+
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
4639
val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) =>
4740
val evaluationCode = expressionEvaluator(e)
4841

4942
evaluationCode.code :+
50-
q"""
43+
q"""
5144
if(${evaluationCode.nullTerm})
5245
mutableRow.setNullAt($i)
5346
else
@@ -57,25 +50,25 @@ object GenerateMutableProjection extends CodeGenerator {
5750

5851
val code =
5952
q"""
60-
() => { new $mutableProjectionType {
53+
() => { new $mutableProjectionType {
6154

62-
private[this] var $mutableRowName: $mutableRowType =
63-
new $genericMutableRowType(${expressions.size})
55+
private[this] var $mutableRowName: $mutableRowType =
56+
new $genericMutableRowType(${expressions.size})
6457

65-
def target(row: $mutableRowType): $mutableProjectionType = {
66-
$mutableRowName = row
67-
this
68-
}
58+
def target(row: $mutableRowType): $mutableProjectionType = {
59+
$mutableRowName = row
60+
this
61+
}
6962

70-
/* Provide immutable access to the last projected row. */
71-
def currentValue: $rowType = mutableRow
63+
/* Provide immutable access to the last projected row. */
64+
def currentValue: $rowType = mutableRow
7265

73-
def apply(i: $rowType): $rowType = {
74-
..$projectionCode
75-
mutableRow
76-
}
77-
} }
78-
"""
66+
def apply(i: $rowType): $rowType = {
67+
..$projectionCode
68+
mutableRow
69+
}
70+
} }
71+
"""
7972

8073
log.debug(s"code for ${expressions.mkString(",")}:\n$code")
8174
toolBox.eval(code).asInstanceOf[() => MutableProjection]

0 commit comments

Comments
 (0)