Skip to content

Commit a20fea9

Browse files
chenghao-intel authored and
rxin committed
[Spark-1461] Deferred Expression Evaluation (short-circuit evaluation)
This patch unifies the foldable & nullable interface for Expression. 1) Non-deterministic UDFs (like Rand()) cannot be folded. 2) Short-circuit evaluation significantly improves the performance of expression evaluation; however, a stateful UDF must not be skipped by short-circuit evaluation (e.g. in the expression: col1 > 0 and row_sequence() < 1000, row_sequence() cannot be skipped even if col1 > 0 is false). I borrowed the concept of DeferredObject from Hive, which has 2 kinds of child classes (EagerResult / DeferredResult): the former requires the evaluation to be triggered before it is created, while the latter triggers the evaluation when its get() method is first called. Author: Cheng Hao <[email protected]> Closes apache#446 from chenghao-intel/expression_deferred_evaluation and squashes the following commits: d2729de [Cheng Hao] Fix the codestyle issues a08f09c [Cheng Hao] fix bug in or/and short-circuit evaluation af2236b [Cheng Hao] revert the short-circuit expression evaluation for IF b7861d2 [Cheng Hao] Add Support for Deferred Expression Evaluation
1 parent bb98eca commit a20fea9

File tree

2 files changed

+53
-22
lines changed

2 files changed

+53
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,19 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
9898

9999
override def eval(input: Row): Any = {
100100
val l = left.eval(input)
101-
val r = right.eval(input)
102-
if (l == false || r == false) {
103-
false
104-
} else if (l == null || r == null ) {
105-
null
101+
if (l == false) {
102+
false
106103
} else {
107-
true
104+
val r = right.eval(input)
105+
if (r == false) {
106+
false
107+
} else {
108+
if (l != null && r != null) {
109+
true
110+
} else {
111+
null
112+
}
113+
}
108114
}
109115
}
110116
}
@@ -114,13 +120,19 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
114120

115121
override def eval(input: Row): Any = {
116122
val l = left.eval(input)
117-
val r = right.eval(input)
118-
if (l == true || r == true) {
123+
if (l == true) {
119124
true
120-
} else if (l == null || r == null) {
121-
null
122125
} else {
123-
false
126+
val r = right.eval(input)
127+
if (r == true) {
128+
true
129+
} else {
130+
if (l != null && r != null) {
131+
false
132+
} else {
133+
null
134+
}
135+
}
124136
}
125137
}
126138
}
@@ -133,8 +145,12 @@ case class Equals(left: Expression, right: Expression) extends BinaryComparison
133145
def symbol = "="
134146
override def eval(input: Row): Any = {
135147
val l = left.eval(input)
136-
val r = right.eval(input)
137-
if (l == null || r == null) null else l == r
148+
if (l == null) {
149+
null
150+
} else {
151+
val r = right.eval(input)
152+
if (r == null) null else l == r
153+
}
138154
}
139155
}
140156

@@ -162,7 +178,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
162178
extends Expression {
163179

164180
def children = predicate :: trueValue :: falseValue :: Nil
165-
def nullable = trueValue.nullable || falseValue.nullable
181+
override def nullable = trueValue.nullable || falseValue.nullable
166182
def references = children.flatMap(_.references).toSet
167183
override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
168184
def dataType = {
@@ -175,8 +191,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
175191
}
176192

177193
type EvaluatedType = Any
194+
178195
override def eval(input: Row): Any = {
179-
if (predicate.eval(input).asInstanceOf[Boolean]) {
196+
if (true == predicate.eval(input)) {
180197
trueValue.eval(input)
181198
} else {
182199
falseValue.eval(input)

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,17 +248,31 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
248248
isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
249249
}
250250

251+
protected lazy val deferedObjects = Array.fill[DeferredObject](children.length)({
252+
new DeferredObjectAdapter
253+
})
254+
255+
// Adapter from Catalyst ExpressionResult to Hive DeferredObject
256+
class DeferredObjectAdapter extends DeferredObject {
257+
private var func: () => Any = _
258+
def set(func: () => Any) {
259+
this.func = func
260+
}
261+
override def prepare(i: Int) = {}
262+
override def get(): AnyRef = wrap(func())
263+
}
264+
251265
val dataType: DataType = inspectorToDataType(returnInspector)
252266

253267
override def eval(input: Row): Any = {
254268
returnInspector // Make sure initialized.
255-
val args = children.map { v =>
256-
new DeferredObject {
257-
override def prepare(i: Int) = {}
258-
override def get(): AnyRef = wrap(v.eval(input))
259-
}
260-
}.toArray
261-
unwrap(function.evaluate(args))
269+
var i = 0
270+
while (i < children.length) {
271+
val idx = i
272+
deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(() => {children(idx).eval(input)})
273+
i += 1
274+
}
275+
unwrap(function.evaluate(deferedObjects))
262276
}
263277
}
264278

0 commit comments

Comments
 (0)