
Commit f7bdcf7

mn-mikke authored and committed
[SPARK-23736][SQL] Merging current master to the feature branch.
2 parents 5a4cc8c + d5bec48

File tree: 13 files changed, +219 -39 lines

python/pyspark/sql/functions.py
Lines changed: 17 additions & 0 deletions

@@ -1849,6 +1849,23 @@ def concat(*cols):
     return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
 
 
+@since(2.4)
+def array_position(col, value):
+    """
+    Collection function: Locates the position of the first occurrence of the given value
+    in the given array. Returns null if either of the arguments are null.
+
+    .. note:: The position is not zero based, but 1 based index. Returns 0 if the given
+        value could not be found in the array.
+
+    >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
+    >>> df.select(array_position(df.data, "a")).collect()
+    [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

python/pyspark/streaming/kafka.py
Lines changed: 2 additions & 1 deletion

@@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
         :param topics: list of topic_name to consume.
         :param kafkaParams: Additional params for Kafka.
         :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting
-                            point of the stream.
+                            point of the stream (a dictionary mapping `TopicAndPartition` to
+                            integers).
         :param keyDecoder: A function used to decode key (default is utf8_decoder).
         :param valueDecoder: A function used to decode value (default is utf8_decoder).
         :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess

python/pyspark/streaming/listener.py
Lines changed: 6 additions & 0 deletions

@@ -23,6 +23,12 @@ class StreamingListener(object):
     def __init__(self):
         pass
 
+    def onStreamingStarted(self, streamingStarted):
+        """
+        Called when the streaming has been started.
+        """
+        pass
+
     def onReceiverStarted(self, receiverStarted):
         """
         Called when a receiver has been started
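The new hook is overridden like the existing ones; a small sketch (the listener class name is made up, and `ssc` is assumed to be a live StreamingContext):

    from pyspark.streaming.listener import StreamingListener

    class StartupListener(StreamingListener):
        def onStreamingStarted(self, streamingStarted):
            # streamingStarted.time carries the start timestamp,
            # as exercised by the test change below.
            print("streaming started at %d" % streamingStarted.time)

    # Registered like any other listener:
    # ssc.addStreamingListener(StartupListener())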

python/pyspark/streaming/tests.py
Lines changed: 7 additions & 0 deletions

@@ -507,6 +507,10 @@ def __init__(self):
         self.batchInfosCompleted = []
         self.batchInfosStarted = []
         self.batchInfosSubmitted = []
+        self.streamingStartedTime = []
+
+    def onStreamingStarted(self, streamingStarted):
+        self.streamingStartedTime.append(streamingStarted.time)
 
     def onBatchSubmitted(self, batchSubmitted):
         self.batchInfosSubmitted.append(batchSubmitted.batchInfo())

@@ -530,9 +534,12 @@ def func(dstream):
         batchInfosSubmitted = batch_collector.batchInfosSubmitted
         batchInfosStarted = batch_collector.batchInfosStarted
         batchInfosCompleted = batch_collector.batchInfosCompleted
+        streamingStartedTime = batch_collector.streamingStartedTime
 
         self.wait_for(batchInfosCompleted, 4)
 
+        self.assertEqual(len(streamingStartedTime), 1)
+
         self.assertGreaterEqual(len(batchInfosSubmitted), 4)
         for info in batchInfosSubmitted:
             self.assertGreaterEqual(info.batchTime().milliseconds(), 0)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
Lines changed: 1 addition & 0 deletions

@@ -401,6 +401,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[MapKeys]("map_keys"),
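Registering the expression exposes it to SQL under the name array_position; a quick check, assuming a running SparkSession bound to `spark`:

    # 1-based lookup: the value 1 sits at index 3 of array(3, 2, 1).
    spark.sql("SELECT array_position(array(3, 2, 1), 1)").show()
    # An absent value yields 0 rather than null.
    spark.sql("SELECT array_position(array(3, 2, 1), 5)").show()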

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
Lines changed: 4 additions & 2 deletions

@@ -582,8 +582,10 @@ class CodegenContext {
    */
   def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
     case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
-    case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
-    case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
+    case FloatType =>
+      s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)"
+    case DoubleType =>
+      s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
     case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
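Why the extra parentheses matter: genEqual's snippet gets spliced into larger boolean expressions by callers such as EqualNullSafe's codegen, and since Java's && binds tighter than ||, the unparenthesized float/double form could change its meaning inside the surrounding expression, which is the wrong-result scenario SPARK-24007 names (see the PredicateSuite test below). The user-visible semantics, sketched in Python against an assumed `spark` session:

    # Null-safe equality is one caller that embeds the generated snippet;
    # null <=> -1.0 must be false, which the fixed codegen guarantees.
    spark.sql("SELECT CAST(NULL AS DOUBLE) <=> CAST(-1.0 AS DOUBLE)").show()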

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
Lines changed: 56 additions & 0 deletions

@@ -508,6 +508,62 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
   override def prettyName: String = "array_max"
 }
 
+
+/**
+ * Returns the position of the first occurrence of element in the given array as long.
+ * Returns 0 if the given value could not be found in the array. Returns null if either of
+ * the arguments are null
+ *
+ * NOTE: that this is not zero based, but 1-based index. The first element in the array has
+ * index 1.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(3, 2, 1), 1);
+       3
+  """,
+  since = "2.4.0")
+case class ArrayPosition(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = LongType
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
+
+  override def nullSafeEval(arr: Any, value: Any): Any = {
+    arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+      if (v == value) {
+        return (i + 1).toLong
+      }
+    )
+    0L
+  }
+
+  override def prettyName: String = "array_position"
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (arr, value) => {
+      val pos = ctx.freshName("arrayPosition")
+      val i = ctx.freshName("i")
+      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
+      s"""
+         |int $pos = 0;
+         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
+         |    $pos = $i + 1;
+         |    break;
+         |  }
+         |}
+         |${ev.value} = (long) $pos;
+       """.stripMargin
+    })
+  }
+}
+
 /**
  * Concatenates multiple input columns together into a single column.
  * The function works with strings, binary and compatible array columns.
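One behavioral detail worth noting: a null array element can never match (the generated loop guards with !$arr.isNullAt($i), and in the interpreted path a null element never equals a non-null value), yet it still occupies its slot in the 1-based numbering. Sketched at the SQL level against an assumed `spark` session, and grounded by the suite below:

    # Nulls keep their index: 3 is found at position 4, not 3.
    spark.sql("SELECT array_position(array(1, CAST(NULL AS INT), 2, 3), 3)").show()
    # A null search value returns null, per the function's null contract.
    spark.sql("SELECT array_position(array(1, 2, 3), CAST(NULL AS INT))").show()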

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
Lines changed: 4 additions & 3 deletions

@@ -141,28 +141,29 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
 
   def toArray[T: ClassTag](elementType: DataType): Array[T] = {
     val size = numElements()
+    val accessor = InternalRow.getAccessor(elementType)
     val values = new Array[T](size)
     var i = 0
     while (i < size) {
       if (isNullAt(i)) {
         values(i) = null.asInstanceOf[T]
       } else {
-        values(i) = get(i, elementType).asInstanceOf[T]
+        values(i) = accessor(this, i).asInstanceOf[T]
       }
       i += 1
     }
     values
   }
 
-  // todo: specialize this.
   def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = {
     val size = numElements()
+    val accessor = InternalRow.getAccessor(elementType)
     var i = 0
     while (i < size) {
       if (isNullAt(i)) {
         f(i, null)
       } else {
-        f(i, get(i, elementType))
+        f(i, accessor(this, i))
       }
       i += 1
     }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
Lines changed: 22 additions & 0 deletions

@@ -170,6 +170,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
   }
 
+  test("Array Position") {
+    val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    checkEvaluation(ArrayPosition(a0, Literal(3)), 4L)
+    checkEvaluation(ArrayPosition(a0, Literal(1)), 1L)
+    checkEvaluation(ArrayPosition(a0, Literal(0)), 0L)
+    checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null)
+
+    checkEvaluation(ArrayPosition(a1, Literal("")), 2L)
+    checkEvaluation(ArrayPosition(a1, Literal("a")), 0L)
+    checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null)
+
+    checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L)
+    checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null)
+
+    checkEvaluation(ArrayPosition(a3, Literal("")), null)
+    checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+  }
+
   test("Concat") {
     // Primitive-type elements
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
Lines changed: 7 additions & 0 deletions

@@ -442,4 +442,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
     assert(ctx.inlinedMutableStates.isEmpty)
   }
+
+  test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") {
+    checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false)
+    checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false)
+    checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false)
+    checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false)
+  }
 }
