@@ -376,45 +376,6 @@ case class MapObjects private(
376376 lambdaFunction : Expression ,
377377 inputData : Expression ) extends Expression with NonSQLExpression {
378378
379- @ tailrec
380- private def itemAccessorMethod (dataType : DataType ): String => String = dataType match {
381- case NullType =>
382- val nullTypeClassName = NullType .getClass.getName + " .MODULE$"
383- (i : String ) => s " .get( $i, $nullTypeClassName) "
384- case IntegerType => (i : String ) => s " .getInt( $i) "
385- case LongType => (i : String ) => s " .getLong( $i) "
386- case FloatType => (i : String ) => s " .getFloat( $i) "
387- case DoubleType => (i : String ) => s " .getDouble( $i) "
388- case ByteType => (i : String ) => s " .getByte( $i) "
389- case ShortType => (i : String ) => s " .getShort( $i) "
390- case BooleanType => (i : String ) => s " .getBoolean( $i) "
391- case StringType => (i : String ) => s " .getUTF8String( $i) "
392- case s : StructType => (i : String ) => s " .getStruct( $i, ${s.size}) "
393- case a : ArrayType => (i : String ) => s " .getArray( $i) "
394- case _ : MapType => (i : String ) => s " .getMap( $i) "
395- case udt : UserDefinedType [_] => itemAccessorMethod(udt.sqlType)
396- case DecimalType .Fixed (p, s) => (i : String ) => s " .getDecimal( $i, $p, $s) "
397- case DateType => (i : String ) => s " .getInt( $i) "
398- }
399-
400- private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
401- case ObjectType (cls) if classOf [Seq [_]].isAssignableFrom(cls) =>
402- (" .size()" , (i : String ) => s " .apply( $i) " , false )
403- case ObjectType (cls) if cls.isArray =>
404- (" .length" , (i : String ) => s " [ $i] " , false )
405- case ObjectType (cls) if classOf [java.util.List [_]].isAssignableFrom(cls) =>
406- (" .size()" , (i : String ) => s " .get( $i) " , false )
407- case ArrayType (t, _) =>
408- val (sqlType, primitiveElement) = t match {
409- case m : MapType => (m, false )
410- case s : StructType => (s, false )
411- case s : StringType => (s, false )
412- case udt : UserDefinedType [_] => (udt.sqlType, false )
413- case o => (o, true )
414- }
415- (" .numElements()" , itemAccessorMethod(sqlType), primitiveElement)
416- }
417-
418379 override def nullable : Boolean = true
419380
420381 override def children : Seq [Expression ] = lambdaFunction :: inputData :: Nil
@@ -425,7 +386,6 @@ case class MapObjects private(
425386 override def dataType : DataType = ArrayType (lambdaFunction.dataType)
426387
427388 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
428- val javaType = ctx.javaType(dataType)
429389 val elementJavaType = ctx.javaType(loopVar.dataType)
430390 ctx.addMutableState(" boolean" , loopVar.isNull, " " )
431391 ctx.addMutableState(elementJavaType, loopVar.value, " " )
@@ -448,27 +408,61 @@ case class MapObjects private(
448408 s " new $convertedType[ $dataLength] "
449409 }
450410
451- val loopNullCheck = if (primitiveElement) {
452- s " ${loopVar.isNull} = ${genInputData.value}.isNullAt( $loopIndex); "
453- } else {
454- s " ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null; "
411+ // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
412+ // of input collection at runtime for this case.
413+ val seq = ctx.freshName(" seq" )
414+ val array = ctx.freshName(" array" )
415+ val determineCollectionType = inputData.dataType match {
416+ case ObjectType (cls) if cls == classOf [Object ] =>
417+ val seqClass = classOf [Seq [_]].getName
418+ s """
419+ $seqClass $seq = null;
420+ $elementJavaType[] $array = null;
421+ if ( ${genInputData.value}.getClass().isArray()) {
422+ $array = ( $elementJavaType[]) ${genInputData.value};
423+ } else {
424+ $seq = ( $seqClass) ${genInputData.value};
425+ }
426+ """
427+ case _ => " "
428+ }
429+
430+
431+ val (getLength, getLoopVar) = inputData.dataType match {
432+ case ObjectType (cls) if classOf [Seq [_]].isAssignableFrom(cls) =>
433+ s " ${genInputData.value}.size() " -> s " ${genInputData.value}.apply( $loopIndex) "
434+ case ObjectType (cls) if cls.isArray =>
435+ s " ${genInputData.value}.length " -> s " ${genInputData.value}[ $loopIndex] "
436+ case ObjectType (cls) if classOf [java.util.List [_]].isAssignableFrom(cls) =>
437+ s " ${genInputData.value}.size() " -> s " ${genInputData.value}.get( $loopIndex) "
438+ case ArrayType (et, _) =>
439+ s " ${genInputData.value}.numElements() " -> ctx.getValue(genInputData.value, et, loopIndex)
440+ case ObjectType (cls) if cls == classOf [Object ] =>
441+ s " $seq == null ? $array.length : $seq.size() " ->
442+ s " $seq == null ? $array[ $loopIndex] : $seq.apply( $loopIndex) "
443+ }
444+
445+ val loopNullCheck = inputData.dataType match {
446+ case _ : ArrayType => s " ${loopVar.isNull} = ${genInputData.value}.isNullAt( $loopIndex); "
447+ // The element of primitive array will never be null.
448+ case ObjectType (cls) if cls.isArray && cls.getComponentType.isPrimitive =>
449+ s " ${loopVar.isNull} = false "
450+ case _ => s " ${loopVar.isNull} = ${loopVar.value} == null; "
455451 }
456452
457453 val code = s """
458454 ${genInputData.code}
455+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
459456
460- boolean ${ev.isNull} = ${genInputData.value} == null;
461- $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
462-
463- if (! ${ev.isNull}) {
457+ if (! ${genInputData.isNull}) {
458+ $determineCollectionType
464459 $convertedType[] $convertedArray = null;
465- int $dataLength = ${genInputData.value}$lengthFunction ;
460+ int $dataLength = $getLength ;
466461 $convertedArray = $arrayConstructor;
467462
468463 int $loopIndex = 0;
469464 while ( $loopIndex < $dataLength) {
470- ${loopVar.value} =
471- ( $elementJavaType) ${genInputData.value}${itemAccessor(loopIndex)};
465+ ${loopVar.value} = ( $elementJavaType) ( $getLoopVar);
472466 $loopNullCheck
473467
474468 ${genFunction.code}
@@ -481,11 +475,10 @@ case class MapObjects private(
481475 $loopIndex += 1;
482476 }
483477
484- ${ev.isNull} = false;
485478 ${ev.value} = new ${classOf [GenericArrayData ].getName}( $convertedArray);
486479 }
487480 """
488- ev.copy(code = code)
481+ ev.copy(code = code, isNull = genInputData.isNull )
489482 }
490483}
491484
0 commit comments