@@ -299,7 +299,8 @@ case class ArrayContains(left: Expression, right: Expression)
299299 Examples:
300300 > SELECT _FUNC_(array(array(1, 2), array(3, 4));
301301 [1,2,3,4]
302- """ )
302+ """ ,
303+ since = " 2.4.0" )
303304case class Flatten (child : Expression ) extends UnaryExpression {
304305
305306 override def nullable : Boolean = child.nullable || dataType.containsNull
@@ -310,18 +311,14 @@ case class Flatten(child: Expression) extends UnaryExpression {
310311 .elementType.asInstanceOf [ArrayType ]
311312 }
312313
313- override def checkInputDataTypes (): TypeCheckResult = {
314- if (
315- ArrayType .acceptsType(child.dataType) &&
316- ArrayType .acceptsType(child.dataType.asInstanceOf [ArrayType ].elementType)
317- ) {
314+ override def checkInputDataTypes (): TypeCheckResult = child.dataType match {
315+ case ArrayType (_ : ArrayType , _) =>
318316 TypeCheckResult .TypeCheckSuccess
319- } else {
317+ case _ =>
320318 TypeCheckResult .TypeCheckFailure (
321319 s " The argument should be an array of arrays, " +
322320 s " but ' ${child.sql}' is of ${child.dataType.simpleString} type. "
323321 )
324- }
325322 }
326323
327324 override def nullSafeEval (array : Any ): Any = {
@@ -339,8 +336,7 @@ case class Flatten(child: Expression) extends UnaryExpression {
339336
340337 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
341338 nullSafeCodeGen(ctx, ev, c => {
342- val code =
343- if (CodeGenerator .isPrimitiveType(dataType.elementType)) {
339+ val code = if (CodeGenerator .isPrimitiveType(dataType.elementType)) {
344340 genCodeForConcatOfPrimitiveElements(ctx, c, ev.value)
345341 } else {
346342 genCodeForConcatOfComplexElements(ctx, c, ev.value)
@@ -354,26 +350,25 @@ case class Flatten(child: Expression) extends UnaryExpression {
354350 childVariableName : String ,
355351 coreLogic : String ): String = {
356352 s """
357- |for(int z=0; z < $childVariableName.numElements(); z++) {
358- | ${ev.isNull} |= $childVariableName.isNullAt(z);
359- |}
360- |if(! ${ev.isNull}) {
361- | $coreLogic
362- |}
363- """ .stripMargin
353+ |for(int z=0; z < $childVariableName.numElements(); z++) {
354+ | ${ev.isNull} |= $childVariableName.isNullAt(z);
355+ |}
356+ |if(! ${ev.isNull}) {
357+ | $coreLogic
358+ |}
359+ """ .stripMargin
364360 }
365361
366362 private def genCodeForNumberOfElements (
367363 ctx : CodegenContext ,
368364 childVariableName : String ) : (String , String ) = {
369365 val variableName = ctx.freshName(" numElements" )
370- val code =
371- s """
372- |int $variableName = 0;
373- |for(int z=0; z < $childVariableName.numElements(); z++) {
374- | $variableName += $childVariableName.getArray(z).numElements();
375- |}
376- """ .stripMargin
366+ val code = s """
367+ |int $variableName = 0;
368+ |for(int z=0; z < $childVariableName.numElements(); z++) {
369+ | $variableName += $childVariableName.getArray(z).numElements();
370+ |}
371+ """ .stripMargin
377372 (code, variableName)
378373 }
379374
@@ -400,28 +395,28 @@ case class Flatten(child: Expression) extends UnaryExpression {
400395 val primitiveValueTypeName = CodeGenerator .primitiveTypeName(elementType)
401396
402397 s """
403- | $numElemCode
404- | $unsafeArraySizeInBytes
405- |byte[] $arrayName = new byte[ $arraySizeName];
406- |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
407- |Platform.putLong( $arrayName, $baseOffset, $numElemName);
408- | $tempArrayDataName.pointTo( $arrayName, $baseOffset, $arraySizeName);
409- |int $counter = 0;
410- |for(int k=0; k < $childVariableName.numElements(); k++) {
411- | ArrayData arr = $childVariableName.getArray(k);
412- | for(int l = 0; l < arr.numElements(); l++) {
413- | if(arr.isNullAt(l)) {
414- | $tempArrayDataName.setNullAt( $counter);
415- | } else {
416- | $tempArrayDataName.set $primitiveValueTypeName(
417- | $counter,
418- | arr.get $primitiveValueTypeName(l)
419- | );
420- | }
421- | $counter++;
422- | }
423- |}
424- | $arrayDataName = $tempArrayDataName;
398+ | $numElemCode
399+ | $unsafeArraySizeInBytes
400+ |byte[] $arrayName = new byte[ $arraySizeName];
401+ |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
402+ |Platform.putLong( $arrayName, $baseOffset, $numElemName);
403+ | $tempArrayDataName.pointTo( $arrayName, $baseOffset, $arraySizeName);
404+ |int $counter = 0;
405+ |for(int k=0; k < $childVariableName.numElements(); k++) {
406+ | ArrayData arr = $childVariableName.getArray(k);
407+ | for(int l = 0; l < arr.numElements(); l++) {
408+ | if(arr.isNullAt(l)) {
409+ | $tempArrayDataName.setNullAt( $counter);
410+ | } else {
411+ | $tempArrayDataName.set $primitiveValueTypeName(
412+ | $counter,
413+ | arr.get $primitiveValueTypeName(l)
414+ | );
415+ | }
416+ | $counter++;
417+ | }
418+ |}
419+ | $arrayDataName = $tempArrayDataName;
425420 """ .stripMargin
426421 }
427422
@@ -435,18 +430,18 @@ case class Flatten(child: Expression) extends UnaryExpression {
435430 val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
436431
437432 s """
438- | $numElemCode
439- |Object[] $arrayName = new Object[ $numElemName];
440- |int $counter = 0;
441- |for(int k=0; k < $childVariableName.numElements(); k++) {
442- | Object[] arr = $childVariableName.getArray(k).array();
443- | for(int l = 0; l < arr.length; l++) {
444- | $arrayName[ $counter] = arr[l];
445- | $counter++;
446- | }
447- |}
448- | $arrayDataName = new $genericArrayClass( $arrayName);
449- """ .stripMargin
433+ | $numElemCode
434+ |Object[] $arrayName = new Object[ $numElemName];
435+ |int $counter = 0;
436+ |for(int k=0; k < $childVariableName.numElements(); k++) {
437+ | Object[] arr = $childVariableName.getArray(k).array();
438+ | for(int l = 0; l < arr.length; l++) {
439+ | $arrayName[ $counter] = arr[l];
440+ | $counter++;
441+ | }
442+ |}
443+ | $arrayDataName = new $genericArrayClass( $arrayName);
444+ """ .stripMargin
450445 }
451446
452447 override def prettyName : String = " flatten"
0 commit comments