Skip to content

Commit 59d820a

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-9029] [SQL] shortcut CaseKeyWhen if key is null
Author: Wenchen Fan <[email protected]> Closes apache#7389 from cloud-fan/case-when and squashes the following commits: ea4b6ba [Wenchen Fan] shortcut for case key when
1 parent 257236c commit 59d820a

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -230,24 +230,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
230230
}
231231
}
232232

233+
private def evalElse(input: InternalRow): Any = {
234+
if (branchesArr.length % 2 == 0) {
235+
null
236+
} else {
237+
branchesArr(branchesArr.length - 1).eval(input)
238+
}
239+
}
240+
233241
/** Written in imperative fashion for performance considerations. */
234242
override def eval(input: InternalRow): Any = {
235243
val evaluatedKey = key.eval(input)
236-
val len = branchesArr.length
237-
var i = 0
238-
// If all branches fail and an elseVal is not provided, the whole statement
239-
// defaults to null, according to Hive's semantics.
240-
while (i < len - 1) {
241-
if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
242-
return branchesArr(i + 1).eval(input)
244+
// If key is null, we can just return the else part or null if there is no else.
245+
// If key is not null but doesn't match any when part, we need to return
246+
// the else part or null if there is no else, according to Hive's semantics.
247+
if (evaluatedKey != null) {
248+
val len = branchesArr.length
249+
var i = 0
250+
while (i < len - 1) {
251+
if (evaluatedKey == branchesArr(i).eval(input)) {
252+
return branchesArr(i + 1).eval(input)
253+
}
254+
i += 2
243255
}
244-
i += 2
245256
}
246-
var res: Any = null
247-
if (i == len - 1) {
248-
res = branchesArr(i).eval(input)
249-
}
250-
return res
257+
evalElse(input)
251258
}
252259

253260
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -261,8 +268,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
261268
s"""
262269
if (!$got) {
263270
${cond.code}
264-
if (!${keyEval.isNull} && !${cond.isNull}
265-
&& ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
271+
if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
266272
$got = true;
267273
${res.code}
268274
${ev.isNull} = ${res.isNull};
@@ -290,19 +296,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
290296
boolean ${ev.isNull} = true;
291297
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
292298
${keyEval.code}
293-
$cases
299+
if (!${keyEval.isNull}) {
300+
$cases
301+
}
294302
$other
295303
"""
296304
}
297305

298-
private def threeValueEquals(l: Any, r: Any) = {
299-
if (l == null || r == null) {
300-
false
301-
} else {
302-
l == r
303-
}
304-
}
305-
306306
override def toString: String = {
307307
s"CASE $key" + branches.sliding(2, 2).map {
308308
case Seq(cond, value) => s" WHEN $cond THEN $value"

0 commit comments

Comments
 (0)