@@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite {
333333 Literal (" ^Ba*n" , StringType ) :: c2 :: Nil ), true , row)
334334 }
335335
336+ test(" case when" ) {
337+ val row = new GenericRow (Array [Any ](null , false , true , " a" , " b" , " c" ))
338+ val c1 = ' a .boolean.at(0 )
339+ val c2 = ' a .boolean.at(1 )
340+ val c3 = ' a .boolean.at(2 )
341+ val c4 = ' a .string.at(3 )
342+ val c5 = ' a .string.at(4 )
343+ val c6 = ' a .string.at(5 )
344+
345+ checkEvaluation(CaseWhen (Seq (c1, c4, c6)), " c" , row)
346+ checkEvaluation(CaseWhen (Seq (c2, c4, c6)), " c" , row)
347+ checkEvaluation(CaseWhen (Seq (c3, c4, c6)), " a" , row)
348+ checkEvaluation(CaseWhen (Seq (Literal (null , BooleanType ), c4, c6)), " c" , row)
349+ checkEvaluation(CaseWhen (Seq (Literal (false , BooleanType ), c4, c6)), " c" , row)
350+ checkEvaluation(CaseWhen (Seq (Literal (true , BooleanType ), c4, c6)), " a" , row)
351+
352+ checkEvaluation(CaseWhen (Seq (c3, c4, c2, c5, c6)), " a" , row)
353+ checkEvaluation(CaseWhen (Seq (c2, c4, c3, c5, c6)), " b" , row)
354+ checkEvaluation(CaseWhen (Seq (c1, c4, c2, c5, c6)), " c" , row)
355+ checkEvaluation(CaseWhen (Seq (c1, c4, c2, c5)), null , row)
356+
357+ assert(CaseWhen (Seq (c2, c4, c6)).nullable === true )
358+ assert(CaseWhen (Seq (c2, c4, c3, c5, c6)).nullable === true )
359+ assert(CaseWhen (Seq (c2, c4, c3, c5)).nullable === true )
360+
361+ val c4_notNull = ' a .boolean.notNull.at(3 )
362+ val c5_notNull = ' a .boolean.notNull.at(4 )
363+ val c6_notNull = ' a .boolean.notNull.at(5 )
364+
365+ assert(CaseWhen (Seq (c2, c4_notNull, c6_notNull)).nullable === false )
366+ assert(CaseWhen (Seq (c2, c4, c6_notNull)).nullable === true )
367+ assert(CaseWhen (Seq (c2, c4_notNull, c6)).nullable === true )
368+
369+ assert(CaseWhen (Seq (c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false )
370+ assert(CaseWhen (Seq (c2, c4, c3, c5_notNull, c6_notNull)).nullable === true )
371+ assert(CaseWhen (Seq (c2, c4_notNull, c3, c5, c6_notNull)).nullable === true )
372+ assert(CaseWhen (Seq (c2, c4_notNull, c3, c5_notNull, c6)).nullable === true )
373+
374+ assert(CaseWhen (Seq (c2, c4_notNull, c3, c5_notNull)).nullable === true )
375+ assert(CaseWhen (Seq (c2, c4, c3, c5_notNull)).nullable === true )
376+ assert(CaseWhen (Seq (c2, c4_notNull, c3, c5)).nullable === true )
377+ }
378+
336379 test(" complex type" ) {
337380 val row = new GenericRow (Array [Any ](
338381 " ^Ba*n" , // 0
0 commit comments