diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ec014a5b39c31..eaecf284b51f1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2033,6 +2033,25 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_intersect(col1, col2): + """ + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=[u'a', u'c'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_intersect(_to_java_column(col1), _to_java_column(col2))) + + @ignore_unicode_prefix @since(2.4) def array_union(col1, col2): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d0efe975f81ce..10f89c9ee02db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -411,6 +411,7 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), expression[ArraysOverlap]("arrays_overlap"), + expression[ArrayIntersect]("array_intersect"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3f94f25796634..e385c2d9782e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3651,7 +3651,7 @@ case class ArrayDistinct(child: Expression) } /** - * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]]. + * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { override def checkInputDataTypes(): TypeCheckResult = { @@ -3672,6 +3672,75 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { case _: AtomicType => true case _ => false } + + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + protected def genGetValue(array: String, i: String): String = + CodeGenerator.getValue(array, elementType, i) + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName (elementType) + elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } + + // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will + // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. + @transient protected lazy val nullValueHolder = elementType match { + case ByteType => "(byte) 0" + case ShortType => "(short) 0" + case _ => "0" + } + + protected def withResultArrayNullCheck( + body: String, + value: String, + nullElementIndex: String): String = { + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |$body + |if ($nullElementIndex >= 0) { + | // result has null element + | $value.setNullAt($nullElementIndex); + |} + """.stripMargin + } else { + body + } + } + + def buildResultArray( + builder: String, + value : String, + size : String, + nullElementIndex : String): String = withResultArrayNullCheck( + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Cannot create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); + |} + | + |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { + | $value = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | $value = new ${classOf[GenericArrayData].getName}($builder.result()); + |} + """.stripMargin, value, nullElementIndex) } object ArraySetLike { @@ -3965,6 +4034,248 @@ object ArrayUnion { } } +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 3) + """, + since = "2.4.0") +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + } + + @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } + } else { + (array1, array2) => + if (array1.numElements() != 0 && array2.numElements() != 0) { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + new GenericArrayData(Array.emptyObjectArray) + } + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalIntersect(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val hashSetResult = ctx.freshName("hashSetResult") + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + s""" + |if (!$array1.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if ($hashSet.contains($hsValueCast$value) && + | !$hashSetResult.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSetResult.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$openHashSet $hashSetResult = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet + |} + |$arrayBuilderClass $builder = new $arrayBuilderClass(); + |int $size = 0; + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $processArray1 + |} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} + """.stripMargin + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayIntersectExpr", this) + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + }) + } + } + + override def prettyName: String = "array_intersect" +} + /** * Returns an array of the elements in the intersect of x and y, without duplicates */ @@ -4065,7 +4376,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike i += 1 } new GenericArrayData(arrayBuffer) - } + } } override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -4080,31 +4391,10 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") - val canUseSpecializedHashSet = elementType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } if (canUseSpecializedHashSet) { val jt = CodeGenerator.javaType(elementType) val ptName = CodeGenerator.primitiveTypeName(jt) - def genGetValue(array: String): String = - CodeGenerator.getValue(array, elementType, i) - - val (hsPostFix, hsTypeName) = elementType match { - // we cast byte/short to int when writing to the hash set. - case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") - case LongType => ("$mcJ$sp", ptName) - case FloatType => ("$mcF$sp", ptName) - case DoubleType => ("$mcD$sp", ptName) - } - - // we cast byte/short to int when writing to the hash set. - val hsValueCast = elementType match { - case ByteType | ShortType => "(int) " - case _ => "" - } - nullSafeCodeGen(ctx, ev, (array1, array2) => { val notFoundNullElement = ctx.freshName("notFoundNullElement") val nullElementIndex = ctx.freshName("nullElementIndex") @@ -4112,10 +4402,8 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") - val genericArrayData = classOf[GenericArrayData].getName - val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" def withArray2NullCheck(body: String): String = if (right.dataType.asInstanceOf[ArrayType].containsNull) { @@ -4141,18 +4429,10 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val writeArray2ToHashSet = withArray2NullCheck( s""" - |$jt $value = ${genGetValue(array2)}; + |$jt $value = ${genGetValue(array2, i)}; |$hashSet.add$hsPostFix($hsValueCast$value); """.stripMargin) - // When hitting a null value, put a null holder in the ArrayBuilder. Finally we will - // convert ArrayBuilder to ArrayData and setNull on the slot with null holder. - val nullValueHolder = elementType match { - case ByteType => "(byte) 0" - case ShortType => "(short) 0" - case _ => "0" - } - def withArray1NullAssignment(body: String) = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" @@ -4173,7 +4453,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike val processArray1 = withArray1NullAssignment( s""" - |$jt $value = ${genGetValue(array1)}; + |$jt $value = ${genGetValue(array1, i)}; |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; @@ -4183,35 +4463,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |} """.stripMargin) - def withResultArrayNullCheck(body: String): String = { - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |$body - |if ($nullElementIndex >= 0) { - | // result has null element - | ${ev.value}.setNullAt($nullElementIndex); - |} - """.stripMargin - } else { - body - } - } - - val buildResultArray = withResultArrayNullCheck( - s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Cannot create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for ArrayData."); - |} - | - |if (!UnsafeArrayData.shouldUseGenericArrayData(${elementType.defaultSize}, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); - |} - """.stripMargin) - // Only need to track null element index when array1's element is nullable. val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) { s""" @@ -4228,13 +4479,12 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike |for (int $i = 0; $i < $array2.numElements(); $i++) { | $writeArray2ToHashSet |} - |$arrayBuilderClass $builder = - | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |$arrayBuilderClass $builder = new $arrayBuilderClass(); |int $size = 0; |for (int $i = 0; $i < $array1.numElements(); $i++) { | $processArray1 |} - |$buildResultArray + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 2f6f9064f9e62..4daa113869b5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1618,4 +1618,116 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true) } + + test("Array Intersect") { + val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false)) + val a02 = Literal.create(Seq(1, 2, 1, 4), ArrayType(IntegerType, false)) + val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false)) + val a04 = Literal.create(Seq(1, 2, null, 4, 5, null), ArrayType(IntegerType, true)) + val a05 = Literal.create(Seq(-5, 4, null, 2, -1, null), ArrayType(IntegerType, true)) + val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false)) + val abl0 = Literal.create(Seq[Boolean](true, false, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, containsNull = false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, containsNull = false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, containsNull = false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, containsNull = false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) + + val a10 = Literal.create(Seq(1L, 2L, 4L), ArrayType(LongType, false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false)) + val a12 = Literal.create(Seq(1L, 2L, 1L, 4L), ArrayType(LongType, false)) + val a13 = Literal.create(Seq(4L, 2L, 4L), ArrayType(LongType, false)) + val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, null), ArrayType(LongType, true)) + val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L, null), ArrayType(LongType, true)) + val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, false)) + val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false)) + val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false)) + val a23 = Literal.create(Seq("c", "a", null, "f"), ArrayType(StringType, true)) + val a24 = Literal.create(Seq("b", null, "a", "g", null), ArrayType(StringType, true)) + val a25 = Literal.create(Seq.empty[String], ArrayType(StringType, false)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayIntersect(a00, a01), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a01, a00), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a02, a03), Seq(2, 4)) + checkEvaluation(ArrayIntersect(a03, a02), Seq(4, 2)) + checkEvaluation(ArrayIntersect(a00, a04), Seq(1, 2, 4)) + checkEvaluation(ArrayIntersect(a04, a05), Seq(2, null, 4)) + checkEvaluation(ArrayIntersect(a02, a06), Seq.empty) + checkEvaluation(ArrayIntersect(a06, a04), Seq.empty) + checkEvaluation(ArrayIntersect(abl0, abl1), Seq[Boolean](true)) + checkEvaluation(ArrayIntersect(ab0, ab1), Seq[Byte](2)) + checkEvaluation(ArrayIntersect(as0, as1), Seq[Short](2)) + checkEvaluation(ArrayIntersect(af0, af1), Seq[Float](2.2F)) + checkEvaluation(ArrayIntersect(ad0, ad1), Seq[Double](2.2D)) + + checkEvaluation(ArrayIntersect(a10, a11), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a11, a10), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a12, a13), Seq(2L, 4L)) + checkEvaluation(ArrayIntersect(a13, a12), Seq(4L, 2L)) + checkEvaluation(ArrayIntersect(a14, a15), Seq(2L, null, 4L)) + checkEvaluation(ArrayIntersect(a12, a16), Seq.empty) + checkEvaluation(ArrayIntersect(a16, a14), Seq.empty) + + checkEvaluation(ArrayIntersect(a20, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a20), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a22, a21), Seq("a", "c")) + checkEvaluation(ArrayIntersect(a21, a22), Seq("c", "a")) + checkEvaluation(ArrayIntersect(a23, a24), Seq("a", null)) + checkEvaluation(ArrayIntersect(a24, a23), Seq(null, "a")) + checkEvaluation(ArrayIntersect(a24, a25), Seq.empty) + checkEvaluation(ArrayIntersect(a25, a24), Seq.empty) + + checkEvaluation(ArrayIntersect(a30, a30), Seq(null)) + checkEvaluation(ArrayIntersect(a20, a31), null) + checkEvaluation(ArrayIntersect(a31, a20), null) + + val b0 = Literal.create( + Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create( + Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)), + ArrayType(BinaryType)) + val b2 = Literal.create( + Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), null), + ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArrayIntersect(b0, b1), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b1, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](5, 6))) + checkEvaluation(ArrayIntersect(b0, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b2, b0), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b2, b3), Seq[Array[Byte]](Array[Byte](3, 4), Array[Byte](1, 2))) + checkEvaluation(ArrayIntersect(b3, b2), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b3, b4), Seq[Array[Byte]](Array[Byte](3, 4), null)) + checkEvaluation(ArrayIntersect(b4, b3), Seq[Array[Byte]](null, Array[Byte](3, 4))) + checkEvaluation(ArrayIntersect(b4, b5), Seq.empty) + checkEvaluation(ArrayIntersect(b5, b4), Seq.empty) + checkEvaluation(ArrayIntersect(b4, arrayWithBinaryNull), Seq[Array[Byte]](null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayIntersect(aa0, aa1), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayIntersect(aa1, aa0), Seq[Seq[Int]](Seq[Int](3, 4))) + + assert(ArrayIntersect(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a00, a04).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cc739b85f555c..310e428b69819 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3233,6 +3233,17 @@ object functions { */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the intersection of the given two arrays, + * without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_intersect(col1: Column, col2: Column): Column = withExpr { + ArrayIntersect(col1.expr, col2.expr) + } + /** * Returns an array of the elements in the union of the given two arrays, without duplicates. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1d5707a2c7047..2e6ef11a54d2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1647,6 +1647,60 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(result10.first.schema(0).dataType === expectedType10) } + test("array_intersect functions") { + val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(2, 4)) + checkAnswer(df1.select(array_intersect($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_intersect(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1))) + .toDF("a", "b") + val ans2 = Row(Seq(2, null, 4)) + checkAnswer(df2.select(array_intersect($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_intersect(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(2L, 4L)) + checkAnswer(df3.select(array_intersect($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_intersect(a, b)"), ans3) + + val df4 = Seq( + (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(2L, null, 4L)) + checkAnswer(df4.select(array_intersect($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_intersect(a, b)"), ans4) + + val df5 = Seq((Array("c", null, "a", "f"), Array("b", "a", null, "g"))).toDF("a", "b") + val ans5 = Row(Seq(null, "a")) + checkAnswer(df5.select(array_intersect($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_intersect(a, b)"), ans5) + + val df6 = Seq((null, null)).toDF("a", "b") + assert(intercept[AnalysisException] { + df6.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df6.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + + val df7 = Seq((Array(1), Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df7.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df7.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + + val df8 = Seq((null, Array("a"))).toDF("a", "b") + assert(intercept[AnalysisException] { + df8.select(array_intersect($"a", $"b")) + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { + df8.selectExpr("array_intersect(a, b)") + }.getMessage.contains("data type mismatch")) + } + test("transform function - array for primitive type not containing null") { val df = Seq( Seq(1, 9, 8, 7),