Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1268,11 +1268,15 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI

override def dataType: DataType = child.dataType

@transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
override def nullSafeEval(input: Any): Any = doReverse(input)

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
case s: UTF8String => s.reverse()
@transient private lazy val doReverse: Any => Any = dataType match {
case ArrayType(elementType, _) =>
input => {
val arrayData = input.asInstanceOf[ArrayData]
new GenericArrayData(arrayData.toObjectArray(elementType).reverse)
}
case StringType => _.asInstanceOf[UTF8String].reverse()
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand All @@ -1294,6 +1298,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
val i = ctx.freshName("i")
val j = ctx.freshName("j")

val elementType = dataType.asInstanceOf[ArrayType].elementType
val initialization = CodeGenerator.createArrayData(
arrayData, elementType, numElements, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(
Expand Down Expand Up @@ -2160,9 +2165,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti

override def nullable: Boolean = true

override def nullSafeEval(value: Any, ordinal: Any): Any = {
left.dataType match {
case _: ArrayType =>
override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)

@transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match {
case _: ArrayType =>
(value, ordinal) => {
val array = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Int]
if (array.numElements() < math.abs(index)) {
Expand All @@ -2181,9 +2188,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
array.get(idx, dataType)
}
}
case _: MapType =>
getValueEval(value, ordinal, mapKeyType, ordering)
}
}
case _: MapType =>
(value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down Expand Up @@ -2274,33 +2281,41 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio

override def foldable: Boolean = children.forall(_.foldable)

override def eval(input: InternalRow): Any = dataType match {
override def eval(input: InternalRow): Any = doConcat(input)

@transient private lazy val doConcat: InternalRow => Any = dataType match {
case BinaryType =>
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
input => {
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
}
case StringType =>
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs : _*)
input => {
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs: _*)
}
case ArrayType(elementType, _) =>
val inputs = children.toStream.map(_.eval(input))
if (inputs.contains(null)) {
null
} else {
val arrayData = inputs.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
" elements due to exceeding the array size limit " +
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val finalData = new Array[AnyRef](numberOfElements.toInt)
var position = 0
for(ad <- arrayData) {
val arr = ad.toObjectArray(elementType)
Array.copy(arr, 0, finalData, position, arr.length)
position += arr.length
input => {
val inputs = children.toStream.map(_.eval(input))
if (inputs.contains(null)) {
null
} else {
val arrayData = inputs.map(_.asInstanceOf[ArrayData])
val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
" elements due to exceeding the array size limit " +
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val finalData = new Array[AnyRef](numberOfElements.toInt)
var position = 0
for (ad <- arrayData) {
val arr = ad.toObjectArray(elementType)
Array.copy(arr, 0, finalData, position, arr.length)
position += arr.length
}
new GenericArrayData(finalData)
}
new GenericArrayData(finalData)
}
}

Expand Down