@@ -230,24 +230,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
230230 }
231231 }
232232
233+ private def evalElse (input : InternalRow ): Any = {
234+ if (branchesArr.length % 2 == 0 ) {
235+ null
236+ } else {
237+ branchesArr(branchesArr.length - 1 ).eval(input)
238+ }
239+ }
240+
233241 /** Written in imperative fashion for performance considerations. */
234242 override def eval (input : InternalRow ): Any = {
235243 val evaluatedKey = key.eval(input)
236- val len = branchesArr.length
237- var i = 0
238- // If all branches fail and an elseVal is not provided, the whole statement
239- // defaults to null, according to Hive's semantics.
240- while (i < len - 1 ) {
241- if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
242- return branchesArr(i + 1 ).eval(input)
244+ // If key is null, we can just return the else part or null if there is no else.
245+ // If key is not null but doesn't match any when part, we need to return
246+ // the else part or null if there is no else, according to Hive's semantics.
247+ if (evaluatedKey != null ) {
248+ val len = branchesArr.length
249+ var i = 0
250+ while (i < len - 1 ) {
251+ if (evaluatedKey == branchesArr(i).eval(input)) {
252+ return branchesArr(i + 1 ).eval(input)
253+ }
254+ i += 2
243255 }
244- i += 2
245256 }
246- var res : Any = null
247- if (i == len - 1 ) {
248- res = branchesArr(i).eval(input)
249- }
250- return res
257+ evalElse(input)
251258 }
252259
253260 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
@@ -261,8 +268,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
261268 s """
262269 if (! $got) {
263270 ${cond.code}
264- if (! ${keyEval.isNull} && ! ${cond.isNull}
265- && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
271+ if (! ${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
266272 $got = true;
267273 ${res.code}
268274 ${ev.isNull} = ${res.isNull};
@@ -290,19 +296,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
290296 boolean ${ev.isNull} = true;
291297 ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
292298 ${keyEval.code}
293- $cases
299+ if (! ${keyEval.isNull}) {
300+ $cases
301+ }
294302 $other
295303 """
296304 }
297305
298- private def threeValueEquals (l : Any , r : Any ) = {
299- if (l == null || r == null ) {
300- false
301- } else {
302- l == r
303- }
304- }
305-
306306 override def toString : String = {
307307 s " CASE $key" + branches.sliding(2 , 2 ).map {
308308 case Seq (cond, value) => s " WHEN $cond THEN $value"
0 commit comments