diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index cc9edcfd41d0..c863be32f662 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1268,11 +1268,15 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + override def nullSafeEval(input: Any): Any = doReverse(input) - override def nullSafeEval(input: Any): Any = input match { - case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) - case s: UTF8String => s.reverse() + @transient private lazy val doReverse: Any => Any = dataType match { + case ArrayType(elementType, _) => + input => { + val arrayData = input.asInstanceOf[ArrayData] + new GenericArrayData(arrayData.toObjectArray(elementType).reverse) + } + case StringType => _.asInstanceOf[UTF8String].reverse() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -1294,6 +1298,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI val i = ctx.freshName("i") val j = ctx.freshName("j") + val elementType = dataType.asInstanceOf[ArrayType].elementType val initialization = CodeGenerator.createArrayData( arrayData, elementType, numElements, s" $prettyName failed.") val assignment = CodeGenerator.createArrayAssignment( @@ -2160,9 +2165,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def nullable: Boolean = true - override def nullSafeEval(value: Any, ordinal: Any): Any = { - left.dataType match { - case _: ArrayType => + override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal) + + @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match { + case _: ArrayType => + (value, ordinal) => { val array = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Int] if (array.numElements() < math.abs(index)) { @@ -2181,9 +2188,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti array.get(idx, dataType) } } - case _: MapType => - getValueEval(value, ordinal, mapKeyType, ordering) - } + } + case _: MapType => + (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -2274,33 +2281,41 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio override def foldable: Boolean = children.forall(_.foldable) - override def eval(input: InternalRow): Any = dataType match { + override def eval(input: InternalRow): Any = doConcat(input) + + @transient private lazy val doConcat: InternalRow => Any = dataType match { case BinaryType => - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) + input => { + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + } case StringType => - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) + input => { + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs: _*) + } case ArrayType(elementType, _) => - val inputs = children.toStream.map(_.eval(input)) - if (inputs.contains(null)) { - null - } else { - val arrayData = inputs.map(_.asInstanceOf[ArrayData]) - val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - " elements due to exceeding the array size limit " + - ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") - } - val finalData = new Array[AnyRef](numberOfElements.toInt) - var position = 0 - for(ad <- arrayData) { - val arr = ad.toObjectArray(elementType) - Array.copy(arr, 0, finalData, position, arr.length) - position += arr.length + input => { + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) } - new GenericArrayData(finalData) } }