
Commit 36d5d25

mn-mikke authored and committed

[SPARK-23736][SQL] Merging current master to the feature branch.

2 parents f7bdcf7 + 46bb2b5

7 files changed, with 276 additions and 24 deletions.

python/pyspark/sql/functions.py
Lines changed: 24 additions & 0 deletions

@@ -1866,6 +1866,30 @@ def array_position(col, value):
     return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def element_at(col, extraction):
+    """
+    Collection function: Returns element of array at given index in extraction if col is array.
+    Returns value for the given key in extraction if col is map.
+
+    :param col: name of column containing array or map
+    :param extraction: index to check for in array or key to check for in map
+
+    .. note:: The position is not zero based, but 1 based index.
+
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+    >>> df.select(element_at(df.data, 1)).collect()
+    [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]
+
+    >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
+    >>> df.select(element_at(df.data, "a")).collect()
+    [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
Lines changed: 1 addition & 0 deletions

@@ -404,6 +404,7 @@ object FunctionRegistry {
     expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
+    expression[ElementAt]("element_at"),
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[Size]("size"),
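This one-line registration is what makes the expression resolvable by name from SQL. A minimal sketch of exercising the newly registered function through spark.sql, assuming Spark 2.4+ and a local SparkSession (the demo object and values are illustrative, not part of this commit):

import org.apache.spark.sql.SparkSession

// Hypothetical demo object; not part of this commit.
object ElementAtSqlDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("element_at-demo").getOrCreate()

    // 1-based positive index into an array literal.
    spark.sql("SELECT element_at(array(1, 2, 3), 2)").show()       // 2

    // A negative index counts from the end; out of range yields NULL.
    spark.sql("SELECT element_at(array(1, 2, 3), -1)").show()      // 3
    spark.sql("SELECT element_at(array(1, 2, 3), 5)").show()       // NULL

    // Map lookup by key; a missing key yields NULL.
    spark.sql("SELECT element_at(map(1, 'a', 2, 'b'), 2)").show()  // b

    spark.stop()
  }
}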

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
Lines changed: 104 additions & 0 deletions

@@ -564,6 +564,110 @@ case class ArrayPosition(left: Expression, right: Expression)
   }
 }
 
+/**
+ * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
+      accesses elements from the last to the first. Returns NULL if the index exceeds the length
+      of the array.
+
+    _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 2);
+       2
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
+       "b"
+  """,
+  since = "2.4.0")
+case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+
+  override def dataType: DataType = left.dataType match {
+    case ArrayType(elementType, _) => elementType
+    case MapType(_, valueType, _) => valueType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    Seq(TypeCollection(ArrayType, MapType),
+      left.dataType match {
+        case _: ArrayType => IntegerType
+        case _: MapType => left.dataType.asInstanceOf[MapType].keyType
+      }
+    )
+  }
+
+  override def nullable: Boolean = true
+
+  override def nullSafeEval(value: Any, ordinal: Any): Any = {
+    left.dataType match {
+      case _: ArrayType =>
+        val array = value.asInstanceOf[ArrayData]
+        val index = ordinal.asInstanceOf[Int]
+        if (array.numElements() < math.abs(index)) {
+          null
+        } else {
+          val idx = if (index == 0) {
+            throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
+          } else if (index > 0) {
+            index - 1
+          } else {
+            array.numElements() + index
+          }
+          if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
+            null
+          } else {
+            array.get(idx, dataType)
+          }
+        }
+      case _: MapType =>
+        getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    left.dataType match {
+      case _: ArrayType =>
+        nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+          val index = ctx.freshName("elementAtIndex")
+          val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($eval1.isNullAt($index)) {
+               |  ${ev.isNull} = true;
+               |} else
+             """.stripMargin
+          } else {
+            ""
+          }
+          s"""
+             |int $index = (int) $eval2;
+             |if ($eval1.numElements() < Math.abs($index)) {
+             |  ${ev.isNull} = true;
+             |} else {
+             |  if ($index == 0) {
+             |    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
+             |  } else if ($index > 0) {
+             |    $index--;
+             |  } else {
+             |    $index += $eval1.numElements();
+             |  }
+             |  $nullCheck
+             |  {
+             |    ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
+             |  }
+             |}
+           """.stripMargin
+        })
+      case _: MapType =>
+        doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
+    }
+  }
+
+  override def prettyName: String = "element_at"
+}
+
 /**
  * Concatenates multiple input columns together into a single column.
  * The function works with strings, binary and compatible array columns.
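The array branch of nullSafeEval above encodes three rules: indices are 1-based, a negative index counts back from the end, and any |index| larger than the array length yields NULL (index 0 throws). A plain-Scala mirror of that index arithmetic, written against ordinary collections rather than Catalyst's ArrayData (the helper name is hypothetical):

// Hypothetical standalone mirror of ElementAt's array branch; not part of this commit.
def elementAtSketch[T](arr: IndexedSeq[T], index: Int): Option[T] = {
  if (arr.length < math.abs(index)) {
    None  // |index| beyond the array length => NULL in SQL terms
  } else if (index == 0) {
    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
  } else if (index > 0) {
    Option(arr(index - 1))           // 1-based indexing from the front
  } else {
    Option(arr(arr.length + index))  // negative index counts from the back
  }
}

// elementAtSketch(Vector(1, 2, 3), 2)  == Some(2)
// elementAtSketch(Vector(1, 2, 3), -1) == Some(3)
// elementAtSketch(Vector(1, 2, 3), 4)  == None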

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
Lines changed: 40 additions & 24 deletions

@@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 }
 
 /**
- * Returns the value of key `key` in Map `child`.
- *
- * We need to do type checking here as `key` expression maybe unresolved.
+ * Common base class for [[GetMapValue]] and [[ElementAt]].
  */
-case class GetMapValue(child: Expression, key: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {
-
-  private def keyType = child.dataType.asInstanceOf[MapType].keyType
-
-  // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
-
-  override def toString: String = s"$child[$key]"
-  override def sql: String = s"${child.sql}[${key.sql}]"
-
-  override def left: Expression = child
-  override def right: Expression = key
-
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
 
+abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
   // todo: current search is O(n), improve it.
-  protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
+  def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
     val map = value.asInstanceOf[MapData]
     val length = map.numElements()
     val keys = map.keyArray()
@@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression)
     }
   }
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+  def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
     val index = ctx.freshName("index")
     val length = ctx.freshName("length")
     val keys = ctx.freshName("keys")
     val found = ctx.freshName("found")
     val key = ctx.freshName("key")
     val values = ctx.freshName("values")
-    val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) {
+    val keyType = mapType.keyType
+    val nullCheck = if (mapType.valueContainsNull) {
       s" || $values.isNullAt($index)"
     } else {
       ""
@@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression)
     })
   }
 }
+
+/**
+ * Returns the value of key `key` in Map `child`.
+ *
+ * We need to do type checking here as `key` expression maybe unresolved.
+ */
+case class GetMapValue(child: Expression, key: Expression)
+  extends GetMapValueUtil with ExtractValue with NullIntolerant {
+
+  private def keyType = child.dataType.asInstanceOf[MapType].keyType
+
+  // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
+
+  override def toString: String = s"$child[$key]"
+  override def sql: String = s"${child.sql}[${key.sql}]"
+
+  override def left: Expression = child
+  override def right: Expression = key
+
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+
+  // todo: current search is O(n), improve it.
+  override def nullSafeEval(value: Any, ordinal: Any): Any = {
+    getValueEval(value, ordinal, keyType)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType])
+  }
+}
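Note that this refactor only moves code: getValueEval and doGetValueGenCode keep the linear key scan flagged by the "todo: current search is O(n), improve it." comment, now shared between GetMapValue and ElementAt. A hypothetical plain-Scala analogue of that scan, reflecting how Catalyst stores a map as parallel key and value arrays (not part of this commit):

// Hypothetical analogue of GetMapValueUtil.getValueEval; not part of this commit.
def getMapValueSketch[K, V](keys: IndexedSeq[K], values: IndexedSeq[V], key: K): Option[V] = {
  var i = 0
  while (i < keys.length) {     // O(n) scan over the key array
    if (keys(i) == key) {
      return Option(values(i))  // a null value maps to None, matching valueContainsNull
    }
    i += 1
  }
  None                          // absent key => NULL in SQL terms
}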

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
Lines changed: 48 additions & 0 deletions

@@ -192,6 +192,54 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
   }
 
+  test("elementAt") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    intercept[Exception] {
+      checkEvaluation(ElementAt(a0, Literal(0)), null)
+    }.getMessage.contains("SQL array indices start at 1")
+    intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }
+    checkEvaluation(ElementAt(a0, Literal(4)), null)
+    checkEvaluation(ElementAt(a0, Literal(-4)), null)
+
+    checkEvaluation(ElementAt(a0, Literal(1)), 1)
+    checkEvaluation(ElementAt(a0, Literal(2)), 2)
+    checkEvaluation(ElementAt(a0, Literal(3)), 3)
+    checkEvaluation(ElementAt(a0, Literal(-3)), 1)
+    checkEvaluation(ElementAt(a0, Literal(-2)), 2)
+    checkEvaluation(ElementAt(a0, Literal(-1)), 3)
+
+    checkEvaluation(ElementAt(a1, Literal(1)), null)
+    checkEvaluation(ElementAt(a1, Literal(2)), "")
+    checkEvaluation(ElementAt(a1, Literal(-2)), null)
+    checkEvaluation(ElementAt(a1, Literal(-1)), "")
+
+    checkEvaluation(ElementAt(a2, Literal(1)), null)
+
+    checkEvaluation(ElementAt(a3, Literal(1)), null)
+
+
+    val m0 =
+      Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val m2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(ElementAt(m0, Literal(1.0)), null)
+
+    checkEvaluation(ElementAt(m0, Literal("d")), null)
+
+    checkEvaluation(ElementAt(m1, Literal("a")), null)
+
+    checkEvaluation(ElementAt(m0, Literal("a")), "1")
+    checkEvaluation(ElementAt(m0, Literal("b")), "2")
+    checkEvaluation(ElementAt(m0, Literal("c")), null)
+
+    checkEvaluation(ElementAt(m2, Literal("a")), null)
+  }
+
   test("Concat") {
     // Primitive-type elements
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))

sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Lines changed: 11 additions & 0 deletions

@@ -3052,6 +3052,17 @@ object functions {
     ArrayPosition(column.expr, Literal(value))
   }
 
+  /**
+   * Returns element of array at given index in value if column is array. Returns value for
+   * the given key in value if column is map.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def element_at(column: Column, value: Any): Column = withExpr {
+    ElementAt(column.expr, Literal(value))
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
    *
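The Scaladoc above carries no usage example, so here is a minimal sketch of the new DataFrame-side function, assuming Spark 2.4+ with an active SparkSession named spark (the data is illustrative, not from this commit):

import org.apache.spark.sql.functions.element_at
import spark.implicits._

// Array column: 1-based index; a negative index counts from the end.
val arrays = Seq(Seq("a", "b", "c"), Seq.empty[String]).toDF("data")
arrays.select(element_at($"data", 1), element_at($"data", -1)).show()
// Row 1: a, c    Row 2 (empty array): null, null

// Map column: lookup by key; an absent key yields null.
val maps = Seq(Map("a" -> 1.0), Map.empty[String, Double]).toDF("data")
maps.select(element_at($"data", "a")).show()
// Row 1: 1.0     Row 2 (empty map): null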

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
Lines changed: 48 additions & 0 deletions

@@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("element_at function") {
+    val df = Seq(
+      (Seq[String]("1", "2", "3")),
+      (Seq[String](null, "")),
+      (Seq[String]())
+    ).toDF("a")
+
+    intercept[Exception] {
+      checkAnswer(
+        df.select(element_at(df("a"), 0)),
+        Seq(Row(null), Row(null), Row(null))
+      )
+    }.getMessage.contains("SQL array indices start at 1")
+    intercept[Exception] {
+      checkAnswer(
+        df.select(element_at(df("a"), 1.1)),
+        Seq(Row(null), Row(null), Row(null))
+      )
+    }
+    checkAnswer(
+      df.select(element_at(df("a"), 4)),
+      Seq(Row(null), Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.select(element_at(df("a"), 1)),
+      Seq(Row("1"), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(element_at(df("a"), -1)),
+      Seq(Row("3"), Row(""), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("element_at(a, 4)"),
+      Seq(Row(null), Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("element_at(a, 1)"),
+      Seq(Row("1"), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.selectExpr("element_at(a, -1)"),
+      Seq(Row("3"), Row(""), Row(null))
+    )
+  }
+
   test("concat function - arrays") {
     val nseqi : Seq[Int] = null
     val nseqs : Seq[String] = null
