Skip to content

Commit 37b68cd

Browse files
mn-mikkemn-mikke
authored andcommitted
[SPARK-23821][SQL] Small refactoring
1 parent 88c4971 commit 37b68cd

File tree

1 file changed

+104
-103
lines changed

1 file changed

+104
-103
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 104 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,110 @@ case class ArrayPosition(left: Expression, right: Expression)
564564
}
565565
}
566566

567+
/**
568+
* Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
569+
*/
570+
@ExpressionDescription(
571+
usage = """
572+
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
573+
accesses elements from the last to the first. Returns NULL if the index exceeds the length
574+
of the array.
575+
576+
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
577+
""",
578+
examples = """
579+
Examples:
580+
> SELECT _FUNC_(array(1, 2, 3), 2);
581+
2
582+
> SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
583+
"b"
584+
""",
585+
since = "2.4.0")
586+
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
587+
588+
override def dataType: DataType = left.dataType match {
589+
case ArrayType(elementType, _) => elementType
590+
case MapType(_, valueType, _) => valueType
591+
}
592+
593+
override def inputTypes: Seq[AbstractDataType] = {
594+
Seq(TypeCollection(ArrayType, MapType),
595+
left.dataType match {
596+
case _: ArrayType => IntegerType
597+
case _: MapType => left.dataType.asInstanceOf[MapType].keyType
598+
}
599+
)
600+
}
601+
602+
override def nullable: Boolean = true
603+
604+
override def nullSafeEval(value: Any, ordinal: Any): Any = {
605+
left.dataType match {
606+
case _: ArrayType =>
607+
val array = value.asInstanceOf[ArrayData]
608+
val index = ordinal.asInstanceOf[Int]
609+
if (array.numElements() < math.abs(index)) {
610+
null
611+
} else {
612+
val idx = if (index == 0) {
613+
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
614+
} else if (index > 0) {
615+
index - 1
616+
} else {
617+
array.numElements() + index
618+
}
619+
if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
620+
null
621+
} else {
622+
array.get(idx, dataType)
623+
}
624+
}
625+
case _: MapType =>
626+
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
627+
}
628+
}
629+
630+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
631+
left.dataType match {
632+
case _: ArrayType =>
633+
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
634+
val index = ctx.freshName("elementAtIndex")
635+
val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
636+
s"""
637+
|if ($eval1.isNullAt($index)) {
638+
| ${ev.isNull} = true;
639+
|} else
640+
""".stripMargin
641+
} else {
642+
""
643+
}
644+
s"""
645+
|int $index = (int) $eval2;
646+
|if ($eval1.numElements() < Math.abs($index)) {
647+
| ${ev.isNull} = true;
648+
|} else {
649+
| if ($index == 0) {
650+
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
651+
| } else if ($index > 0) {
652+
| $index--;
653+
| } else {
654+
| $index += $eval1.numElements();
655+
| }
656+
| $nullCheck
657+
| {
658+
| ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
659+
| }
660+
|}
661+
""".stripMargin
662+
})
663+
case _: MapType =>
664+
doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
665+
}
666+
}
667+
668+
override def prettyName: String = "element_at"
669+
}
670+
567671
/**
568672
* Transforms an array of arrays into a single array.
569673
*/
@@ -740,106 +844,3 @@ case class Flatten(child: Expression) extends UnaryExpression {
740844
override def prettyName: String = "flatten"
741845
}
742846

743-
/**
744-
* Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
745-
*/
746-
@ExpressionDescription(
747-
usage = """
748-
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
749-
accesses elements from the last to the first. Returns NULL if the index exceeds the length
750-
of the array.
751-
752-
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
753-
""",
754-
examples = """
755-
Examples:
756-
> SELECT _FUNC_(array(1, 2, 3), 2);
757-
2
758-
> SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
759-
"b"
760-
""",
761-
since = "2.4.0")
762-
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
763-
764-
override def dataType: DataType = left.dataType match {
765-
case ArrayType(elementType, _) => elementType
766-
case MapType(_, valueType, _) => valueType
767-
}
768-
769-
override def inputTypes: Seq[AbstractDataType] = {
770-
Seq(TypeCollection(ArrayType, MapType),
771-
left.dataType match {
772-
case _: ArrayType => IntegerType
773-
case _: MapType => left.dataType.asInstanceOf[MapType].keyType
774-
}
775-
)
776-
}
777-
778-
override def nullable: Boolean = true
779-
780-
override def nullSafeEval(value: Any, ordinal: Any): Any = {
781-
left.dataType match {
782-
case _: ArrayType =>
783-
val array = value.asInstanceOf[ArrayData]
784-
val index = ordinal.asInstanceOf[Int]
785-
if (array.numElements() < math.abs(index)) {
786-
null
787-
} else {
788-
val idx = if (index == 0) {
789-
throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
790-
} else if (index > 0) {
791-
index - 1
792-
} else {
793-
array.numElements() + index
794-
}
795-
if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
796-
null
797-
} else {
798-
array.get(idx, dataType)
799-
}
800-
}
801-
case _: MapType =>
802-
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
803-
}
804-
}
805-
806-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
807-
left.dataType match {
808-
case _: ArrayType =>
809-
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
810-
val index = ctx.freshName("elementAtIndex")
811-
val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
812-
s"""
813-
|if ($eval1.isNullAt($index)) {
814-
| ${ev.isNull} = true;
815-
|} else
816-
""".stripMargin
817-
} else {
818-
""
819-
}
820-
s"""
821-
|int $index = (int) $eval2;
822-
|if ($eval1.numElements() < Math.abs($index)) {
823-
| ${ev.isNull} = true;
824-
|} else {
825-
| if ($index == 0) {
826-
| throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
827-
| } else if ($index > 0) {
828-
| $index--;
829-
| } else {
830-
| $index += $eval1.numElements();
831-
| }
832-
| $nullCheck
833-
| {
834-
| ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
835-
| }
836-
|}
837-
""".stripMargin
838-
})
839-
case _: MapType =>
840-
doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
841-
}
842-
}
843-
844-
override def prettyName: String = "element_at"
845-
}

0 commit comments

Comments
 (0)