diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 60f6f537c1d54..8883e17bf3164 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -28,9 +28,9 @@ import org.apache.spark.annotation.Private
  * removed.
  *
  * The underlying implementation uses Scala compiler's specialization to generate optimized
- * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
- * while incurring much less memory overhead. This can serve as building blocks for higher level
- * data structures such as an optimized HashMap.
+ * storage for four primitive types (Long, Int, Double, and Float). It is much faster than Java's
+ * standard HashSet while incurring much less memory overhead. It can serve as a building block
+ * for higher-level data structures such as an optimized HashMap.
  *
  * This OpenHashSet is designed to serve as building blocks for higher level data structures
  * such as an optimized hash map. Compared with standard hash set implementations, this class
@@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private
  * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
  */
 @Private
-class OpenHashSet[@specialized(Long, Int) T: ClassTag](
+class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
     initialCapacity: Int,
     loadFactor: Double)
   extends Serializable {
@@ -77,6 +77,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
       (new LongHasher).asInstanceOf[Hasher[T]]
     } else if (mt == ClassTag.Int) {
       (new IntHasher).asInstanceOf[Hasher[T]]
+    } else if (mt == ClassTag.Double) {
+      (new DoubleHasher).asInstanceOf[Hasher[T]]
+    } else if (mt == ClassTag.Float) {
+      (new FloatHasher).asInstanceOf[Hasher[T]]
     } else {
       new Hasher[T]
     }
@@ -293,7 +297,7 @@ object OpenHashSet {
   /**
    * A set of specialized hash function implementation to avoid boxing hash code computation
    * in the specialized implementation of OpenHashSet.
    */
-  sealed class Hasher[@specialized(Long, Int) T] extends Serializable {
+  sealed class Hasher[@specialized(Long, Int, Double, Float) T] extends Serializable {
     def hash(o: T): Int = o.hashCode()
   }
@@ -305,6 +309,17 @@ object OpenHashSet {
     override def hash(o: Int): Int = o
   }
 
+  class DoubleHasher extends Hasher[Double] {
+    override def hash(o: Double): Int = {
+      val bits = java.lang.Double.doubleToLongBits(o)
+      (bits ^ (bits >>> 32)).toInt
+    }
+  }
+
+  class FloatHasher extends Hasher[Float] {
+    override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
+  }
+
   private def grow1(newSize: Int) {}
 
   private def move1(oldPos: Int, newPos: Int) { }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 210bc5c099742..b887f937a9da9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -112,6 +112,80 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers {
     assert(!set.contains(10000L))
   }
 
+  test("primitive float") {
+    val set = new OpenHashSet[Float]
+    assert(set.size === 0)
+    assert(!set.contains(10.1F))
+    assert(!set.contains(50.5F))
+    assert(!set.contains(999.9F))
+    assert(!set.contains(10000.1F))
+
+    set.add(10.1F)
+    assert(set.size === 1)
+    assert(set.contains(10.1F))
+    assert(!set.contains(50.5F))
+    assert(!set.contains(999.9F))
+    assert(!set.contains(10000.1F))
+
+    set.add(50.5F)
+    assert(set.size === 2)
+    assert(set.contains(10.1F))
+    assert(set.contains(50.5F))
+    assert(!set.contains(999.9F))
+    assert(!set.contains(10000.1F))
+
+    set.add(999.9F)
+    assert(set.size === 3)
+    assert(set.contains(10.1F))
+    assert(set.contains(50.5F))
+    assert(set.contains(999.9F))
+    assert(!set.contains(10000.1F))
+
+    set.add(50.5F)
+    assert(set.size === 3)
+    assert(set.contains(10.1F))
+    assert(set.contains(50.5F))
+    assert(set.contains(999.9F))
+    assert(!set.contains(10000.1F))
+  }
+
+  test("primitive double") {
+    val set = new OpenHashSet[Double]
+    assert(set.size === 0)
+    assert(!set.contains(10.1D))
+    assert(!set.contains(50.5D))
+    assert(!set.contains(999.9D))
+    assert(!set.contains(10000.1D))
+
+    set.add(10.1D)
+    assert(set.size === 1)
+    assert(set.contains(10.1D))
+    assert(!set.contains(50.5D))
+    assert(!set.contains(999.9D))
+    assert(!set.contains(10000.1D))
+
+    set.add(50.5D)
+    assert(set.size === 2)
+    assert(set.contains(10.1D))
+    assert(set.contains(50.5D))
+    assert(!set.contains(999.9D))
+    assert(!set.contains(10000.1D))
+
+    set.add(999.9D)
+    assert(set.size === 3)
+    assert(set.contains(10.1D))
+    assert(set.contains(50.5D))
+    assert(set.contains(999.9D))
+    assert(!set.contains(10000.1D))
+
+    set.add(50.5D)
+    assert(set.size === 3)
+    assert(set.contains(10.1D))
+    assert(set.contains(50.5D))
+    assert(set.contains(999.9D))
+    assert(!set.contains(10000.1D))
+  }
+
   test("non-primitive") {
     val set = new OpenHashSet[String]
     assert(set.size === 0)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 0a88e482787ff..778fa787ed8ca 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2052,6 +2052,25 @@ def array_union(col1, col2):
     return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2)))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def array_except(col1, col2):
+    """
+    Collection function: returns an array of the elements in col1 but not in col2,
+    without duplicates.
+
+    :param col1: name of column containing array
+    :param col2: name of column containing array
+
+    >>> from pyspark.sql import Row
+    >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+    >>> df.select(array_except(df.c1, df.c2)).collect()
+    [Row(array_except(c1, c2)=[u'b'])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_except(_to_java_column(col1), _to_java_column(col2)))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index adc4837276793..b8b311219ca8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -414,6 +414,7 @@ object FunctionRegistry {
     expression[ArrayJoin]("array_join"),
     expression[ArrayPosition]("array_position"),
     expression[ArraySort]("array_sort"),
+    expression[ArrayExcept]("array_except"),
     expression[ArrayUnion]("array_union"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index dcb9c96ca3b2d..773aefc0ac1f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -709,7 +709,7 @@ trait ComplexTypeMergingExpression extends Expression {
   @transient
   lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
 
-  override def dataType: DataType = {
+  def dataTypeCheck: Unit = {
     require(
       inputTypesForMerging.nonEmpty,
       "The collection of input data types must not be empty.")
@@ -717,6 +717,10 @@ trait ComplexTypeMergingExpression extends Expression {
       TypeCoercion.haveSameType(inputTypesForMerging),
       "All input types must be the same except nullable, containsNull, valueContainsNull flags." +
         s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}")
+  }
+
+  override def dataType: DataType = {
+    dataTypeCheck
     inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b1d91ffbe86e0..b03bd7d942d72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3651,14 +3651,10 @@ case class ArrayDistinct(child: Expression)
 }
 
 /**
- * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept.
+ * Will become common base class for [[ArrayUnion]], ArrayIntersect, and [[ArrayExcept]].
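+ * Concrete subclasses now define their own `dataType` (see [[ArrayUnion]] and [[ArrayExcept]]).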
  */
 abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
-  override def dataType: DataType = {
-    val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
-    ArrayType(elementType, dataTypes.exists(_.containsNull))
-  }
-
   override def checkInputDataTypes(): TypeCheckResult = {
     val typeCheckResult = super.checkInputDataTypes()
     if (typeCheckResult.isSuccess) {
@@ -3702,7 +3697,8 @@ object ArraySetLike {
       array(1, 2, 3, 5)
   """,
   since = "2.4.0")
-case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike {
+case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
+  with ComplexTypeMergingExpression {
   var hsInt: OpenHashSet[Int] = _
   var hsLong: OpenHashSet[Long] = _
 
@@ -3968,3 +3964,297 @@ object ArrayUnion {
     new GenericArrayData(arrayBuffer)
   }
 }
+
+/**
+ * Returns an array of the elements in array1 but not in array2, without duplicates.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2,
+      without duplicates.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+       array(2)
+  """,
+  since = "2.4.0")
+case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
+  with ComplexTypeMergingExpression {
+  override def dataType: DataType = {
+    dataTypeCheck
+    left.dataType
+  }
+
+  @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
+    if (elementTypeSupportEquals) {
+      (array1, array2) =>
+        // First pass: collect the elements of array2 in a hash set,
+        // remembering whether array2 contains a null element.
+        val hs = new OpenHashSet[Any]
+        var notFoundNullElement = true
+        var i = 0
+        while (i < array2.numElements()) {
+          if (array2.isNullAt(i)) {
+            notFoundNullElement = false
+          } else {
+            val elem = array2.get(i, elementType)
+            hs.add(elem)
+          }
+          i += 1
+        }
+        // Second pass: keep the elements of array1 absent from array2,
+        // de-duplicating the result via the same hash set.
+        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+        i = 0
+        while (i < array1.numElements()) {
+          if (array1.isNullAt(i)) {
+            if (notFoundNullElement) {
+              arrayBuffer += null
+              notFoundNullElement = false
+            }
+          } else {
+            val elem = array1.get(i, elementType)
+            if (!hs.contains(elem)) {
+              arrayBuffer += elem
+              hs.add(elem)
+            }
+          }
+          i += 1
+        }
+        new GenericArrayData(arrayBuffer)
+    } else {
+      (array1, array2) =>
+        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+        var scannedNullElements = false
+        var i = 0
+        while (i < array1.numElements()) {
+          var found = false
+          val elem1 = array1.get(i, elementType)
+          if (elem1 == null) {
+            if (!scannedNullElements) {
+              var j = 0
+              while (!found && j < array2.numElements()) {
+                found = array2.isNullAt(j)
+                j += 1
+              }
+              // array2 is scanned only once for null element
+              scannedNullElements = true
+            } else {
+              found = true
+            }
+          } else {
+            var j = 0
+            while (!found && j < array2.numElements()) {
+              val elem2 = array2.get(j, elementType)
+              if (elem2 != null) {
+                found = ordering.equiv(elem1, elem2)
+              }
+              j += 1
+            }
+            if (!found) {
+              // check whether elem1 is already stored in arrayBuffer
+              var k = 0
+              while (!found && k < arrayBuffer.size) {
+                val va = arrayBuffer(k)
+                found = (va != null) && ordering.equiv(va, elem1)
+                k += 1
+              }
+            }
+          }
+          if (!found) {
+            arrayBuffer += elem1
+          }
+          i += 1
+        }
+        new GenericArrayData(arrayBuffer)
+    }
+  }
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val array1 = input1.asInstanceOf[ArrayData]
+    val array2 = input2.asInstanceOf[ArrayData]
+
+    evalExcept(array1, array2)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val arrayData = classOf[ArrayData].getName
+    val i = ctx.freshName("i")
+    val pos = ctx.freshName("pos")
+    val value = ctx.freshName("value")
+    val hsValue = ctx.freshName("hsValue")
+    val size = ctx.freshName("size")
+    if (elementTypeSupportEquals) {
+      val ptName = CodeGenerator.primitiveTypeName(elementType)
+      val unsafeArray = ctx.freshName("unsafeArray")
+      val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
+        getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) =
+        elementType match {
+          case ByteType | ShortType | IntegerType =>
+            ("$mcI$sp", "Int", "int", s"(int) $value",
+              s"get$ptName($i)", s"set$ptName($pos, $value)",
+              CodeGenerator.javaType(elementType), ptName,
+              s"""
+                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
+                 |${ev.value} = $unsafeArray;
+               """.stripMargin)
+          case LongType | FloatType | DoubleType =>
+            val signature = elementType match {
+              case LongType => "$mcJ$sp"
+              case FloatType => "$mcF$sp"
+              case DoubleType => "$mcD$sp"
+            }
+            (signature, CodeGenerator.boxedType(elementType),
+              CodeGenerator.javaType(elementType), value,
+              s"get$ptName($i)", s"set$ptName($pos, $value)",
+              CodeGenerator.javaType(elementType), ptName,
+              s"""
+                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
+                 |${ev.value} = $unsafeArray;
+               """.stripMargin)
+          case _ =>
+            val genericArrayData = classOf[GenericArrayData].getName
+            val et = ctx.addReferenceObj("elementType", elementType)
+            ("", "Object", "Object", value,
+              s"get($i, $et)", s"update($pos, $value)", "Object", "Ref",
+              s"${ev.value} = new $genericArrayData(new Object[$size]);")
+        }
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val notFoundNullElement = ctx.freshName("notFoundNullElement")
+        val nullElementIndex = ctx.freshName("nullElementIndex")
+        val builder = ctx.freshName("builder")
+        val array = ctx.freshName("array")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
+        val hs = ctx.freshName("hs")
+        val genericArrayData = classOf[GenericArrayData].getName
+        val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
+        val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName"
+        val arrayBuilderClassTag = if (primitiveTypeName != "Ref") {
+          s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()"
+        } else {
+          s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()"
+        }
+
+        def withArray2NullCheck(body: String) =
+          if (right.dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($array2.isNullAt($i)) {
+               |  $notFoundNullElement = false;
+               |} else {
+               |  $body
+               |}
+             """.stripMargin
+          } else {
+            body
+          }
+        val array2Body =
+          s"""
+             |$javaTypeName $value = $array2.$getter;
+             |$hsJavaTypeName $hsValue = $genHsValue;
+             |$hs.add$postFix($hsValue);
+           """.stripMargin
+
+        def withArray1NullAssignment(body: String) =
+          if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($array1.isNullAt($i)) {
+               |  if ($notFoundNullElement) {
+               |    $nullElementIndex = $size;
+               |    $notFoundNullElement = false;
+               |    $size++;
+               |  }
+               |} else {
+               |  $body
+               |}
+             """.stripMargin
+          } else {
+            body
+          }
+        val array1Body =
+          s"""
+             |$javaTypeName $value = $array1.$getter;
+             |$hsJavaTypeName $hsValue = $genHsValue;
+             |if (!$hs.contains($hsValue)) {
+             |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+             |    break;
+             |  }
+             |  $hs.add$postFix($hsValue);
+             |  $builder.$$plus$$eq($value);
+             |}
+           """.stripMargin
+
+        val nonNullArrayDataBuild = {
+          val build = if (postFix != "") {
+            val defaultSize = elementType.defaultSize
+            s"""
+               |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) {
+               |  ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result());
+               |} else {
+               |  ${ev.value} = new $genericArrayData($builder.result());
+               |}
+             """.stripMargin
+          } else {
+            s"${ev.value} = new $genericArrayData($builder.result());"
+          }
+          s"""
+             |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+             |  throw new RuntimeException("Unsuccessful attempt to create array with " + $size +
+             |  " elements due to exceeding the limit of " +
+             |  "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." +
+             |  " $prettyName failed.");
+             |}
+             |$build
+           """.stripMargin
+        }
+
+        def buildResultArrayData(nonNullArrayDataBuild: String) =
+          if (dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($nullElementIndex < 0) {
+               |  // result has no null element
+               |  $nonNullArrayDataBuild
+               |} else {
+               |  // result has null element
+               |  $arrayDataBuilder
+               |  $javaTypeName[] $array = $builder.result();
+               |  for (int $i = 0, $pos = 0; $pos < $size; $pos++) {
+               |    if ($pos == $nullElementIndex) {
+               |      ${ev.value}.setNullAt($pos);
+               |    } else {
+               |      $javaTypeName $value = $array[$i++];
+               |      ${ev.value}.$setter;
+               |    }
+               |  }
+               |}
+             """.stripMargin
+          } else {
+            nonNullArrayDataBuild
+          }
+
+        s"""
+           |$openHashSet $hs = new $openHashSet$postFix($classTag);
+           |boolean $notFoundNullElement = true;
+           |for (int $i = 0; $i < $array2.numElements(); $i++) {
+           |  ${withArray2NullCheck(array2Body)}
+           |}
+           |$arrayBuilderClass $builder =
+           |  ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
+           |int $nullElementIndex = -1;
+           |int $size = 0;
+           |for (int $i = 0; $i < $array1.numElements(); $i++) {
+           |  ${withArray1NullAssignment(array1Body)}
+           |}
+           |${buildResultArrayData(nonNullArrayDataBuild)}
+         """.stripMargin
+      })
+    } else {
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val expr = ctx.addReferenceObj("arrayExceptExpr", this)
+        s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
+      })
+    }
+  }
+
+  override def prettyName: String = "array_except"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 5c5728548e646..2f6f9064f9e62 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1503,4 +1503,119 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(!shuffle.fastEquals(shuffle.freshCopy()))
     assert(!shuffle.fastEquals(Shuffle(ai0, seed2)))
   }
+
+  test("Array Except") {
+    val a00 = Literal.create(Seq(1, 2, 4, 3), ArrayType(IntegerType, false))
+    val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false))
+    val a02 = Literal.create(Seq(1, 2, 4, 2), ArrayType(IntegerType, false))
+    val a03 = Literal.create(Seq(4, 2, 4), ArrayType(IntegerType, false))
+    val a04 = Literal.create(Seq(1, 2, null, 4, 5, 1), ArrayType(IntegerType, true))
+    val a05 = Literal.create(Seq(-5, 4, null, 2, -1), ArrayType(IntegerType, true))
+    val a06 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, false))
+    val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false))
+    val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false))
+    val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false))
+    val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false))
+    val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false))
+    val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false))
+    val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false))
+    val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false))
+    val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false))
+    val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false))
+
+    val a10 = Literal.create(Seq(1L, 2L, 4L, 3L), ArrayType(LongType, false))
+    val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false))
+    val a12 = Literal.create(Seq(1L, 2L, 4L, 2L), ArrayType(LongType, false))
+    val a13 = Literal.create(Seq(4L, 2L), ArrayType(LongType, false))
+    val a14 = Literal.create(Seq(1L, 2L, null, 4L, 5L, 1L), ArrayType(LongType, true))
+    val a15 = Literal.create(Seq(-5L, 4L, null, 2L, -1L), ArrayType(LongType, true))
+    val a16 = Literal.create(Seq.empty[Long], ArrayType(LongType, false))
+
+    val a20 = Literal.create(Seq("b", "a", "c", "d"), ArrayType(StringType, false))
+    val a21 = Literal.create(Seq("c", "a"), ArrayType(StringType, false))
+    val a22 = Literal.create(Seq("b", "a", "c", "a"), ArrayType(StringType, false))
+    val a23 = Literal.create(Seq("c", "a", "c"), ArrayType(StringType, false))
+    val a24 = Literal.create(Seq("c", null, "a", "f", "c"), ArrayType(StringType, true))
+    val a25 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, true))
+    val a26 = Literal.create(Seq.empty[String], ArrayType(StringType, false))
+
+    val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+    val a31 = Literal.create(null, ArrayType(StringType))
+
+    checkEvaluation(ArrayExcept(a00, a01), Seq(1, 3))
+    checkEvaluation(ArrayExcept(a02, a01), Seq(1))
+    checkEvaluation(ArrayExcept(a02, a02), Seq.empty)
+    checkEvaluation(ArrayExcept(a02, a03), Seq(1))
+    checkEvaluation(ArrayExcept(a04, a02), Seq(null, 5))
+    checkEvaluation(ArrayExcept(a04, a05), Seq(1, 5))
+    checkEvaluation(ArrayExcept(a04, a06), Seq(1, 2, null, 4, 5))
+    checkEvaluation(ArrayExcept(a06, a04), Seq.empty)
+    checkEvaluation(ArrayExcept(abl0, abl1), Seq[Boolean](true))
+    checkEvaluation(ArrayExcept(ab0, ab1), Seq[Byte](1, 3))
+    checkEvaluation(ArrayExcept(as0, as1), Seq[Short](1, 3))
+    checkEvaluation(ArrayExcept(af0, af1), Seq[Float](1.1F, 3.3F))
+    checkEvaluation(ArrayExcept(ad0, ad1), Seq[Double](1.1, 3.3))
+
+    checkEvaluation(ArrayExcept(a10, a11), Seq(1L, 3L))
+    checkEvaluation(ArrayExcept(a12, a11), Seq(1L))
+    checkEvaluation(ArrayExcept(a12, a12), Seq.empty)
+    checkEvaluation(ArrayExcept(a12, a13), Seq(1L))
+    checkEvaluation(ArrayExcept(a14, a12), Seq(null, 5L))
+    checkEvaluation(ArrayExcept(a14, a15), Seq(1L, 5L))
+    checkEvaluation(ArrayExcept(a14, a16), Seq(1L, 2L, null, 4L, 5L))
+    checkEvaluation(ArrayExcept(a16, a14), Seq.empty)
+
+    checkEvaluation(ArrayExcept(a20, a21), Seq("b", "d"))
+    checkEvaluation(ArrayExcept(a22, a21), Seq("b"))
+    checkEvaluation(ArrayExcept(a22, a22), Seq.empty)
+    checkEvaluation(ArrayExcept(a22, a23), Seq("b"))
+    checkEvaluation(ArrayExcept(a24, a22), Seq(null, "f"))
+    checkEvaluation(ArrayExcept(a24, a25), Seq("c", "f"))
+    checkEvaluation(ArrayExcept(a24, a26), Seq("c", null, "a", "f"))
+    checkEvaluation(ArrayExcept(a26, a24), Seq.empty)
+
+    checkEvaluation(ArrayExcept(a30, a30), Seq.empty)
+    checkEvaluation(ArrayExcept(a20, a31), null)
+    checkEvaluation(ArrayExcept(a31, a20), null)
+
+    val b0 = Literal.create(
+      Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](7, 8)),
+      ArrayType(BinaryType))
+    val b1 = Literal.create(
+      Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), Array[Byte](5, 6)),
+      ArrayType(BinaryType))
+    val b2 = Literal.create(
+      Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4), Array[Byte](1, 2)),
+      ArrayType(BinaryType))
+    val b3 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](3, 4), null),
+      ArrayType(BinaryType))
+    val b4 = Literal.create(Seq[Array[Byte]](null, Array[Byte](3, 4), null), ArrayType(BinaryType))
+    val b5 = Literal.create(Seq.empty, ArrayType(BinaryType))
+    val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
+
+    checkEvaluation(ArrayExcept(b0, b1), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](7, 8)))
+    checkEvaluation(ArrayExcept(b1, b0), Seq[Array[Byte]](Array[Byte](2, 1)))
+    checkEvaluation(ArrayExcept(b0, b2), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](7, 8)))
+    checkEvaluation(ArrayExcept(b2, b0), Seq.empty)
+    checkEvaluation(ArrayExcept(b2, b3), Seq[Array[Byte]](Array[Byte](1, 2)))
+    checkEvaluation(ArrayExcept(b3, b2), Seq[Array[Byte]](Array[Byte](2, 1), null))
+    checkEvaluation(ArrayExcept(b3, b4), Seq[Array[Byte]](Array[Byte](2, 1)))
+    checkEvaluation(ArrayExcept(b4, b3), Seq.empty)
+    checkEvaluation(ArrayExcept(b4, b5), Seq[Array[Byte]](null, Array[Byte](3, 4)))
+    checkEvaluation(ArrayExcept(b5, b4), Seq.empty)
+    checkEvaluation(ArrayExcept(b4, arrayWithBinaryNull), Seq[Array[Byte]](Array[Byte](3, 4)))
+
+    val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2)),
+      ArrayType(ArrayType(IntegerType)))
+    val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](3, 4), Seq[Int](2, 1), Seq[Int](3, 4)),
+      ArrayType(ArrayType(IntegerType)))
+    checkEvaluation(ArrayExcept(aa0, aa1), Seq[Seq[Int]](Seq[Int](1, 2)))
+    checkEvaluation(ArrayExcept(aa1, aa0), Seq[Seq[Int]](Seq[Int](2, 1)))
+
+    assert(ArrayExcept(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(ArrayExcept(a04, a02).dataType.asInstanceOf[ArrayType].containsNull === true)
+    assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull === true)
+    assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
+    assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bcd0c946ab996..760b2219cb888 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3229,6 +3229,17 @@ object functions {
     ArrayUnion(col1.expr, col2.expr)
   }
 
+  /**
+   * Returns an array of the elements in the first array but not in the second array,
+   * without duplicates. The order of elements in the result is not deterministic.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_except(col1: Column, col2: Column): Column = withExpr {
+    ArrayExcept(col1.expr, col2.expr)
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
   *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 299c96f74af22..e550b142c738d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1578,6 +1578,75 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     testNonPrimitiveType()
   }
 
+  test("array_except functions") {
+    val df1 = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b")
+    val ans1 = Row(Seq(1))
+    checkAnswer(df1.select(array_except($"a", $"b")), ans1)
+    checkAnswer(df1.selectExpr("array_except(a, b)"), ans1)
+
+    val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array[Integer](-5, 4, null, 2, -1)))
+      .toDF("a", "b")
+    val ans2 = Row(Seq(1, 5))
+    checkAnswer(df2.select(array_except($"a", $"b")), ans2)
+    checkAnswer(df2.selectExpr("array_except(a, b)"), ans2)
+
+    val df3 = Seq((Array(1L, 2L, 4L), Array(4L, 2L))).toDF("a", "b")
+    val ans3 = Row(Seq(1L))
+    checkAnswer(df3.select(array_except($"a", $"b")), ans3)
+    checkAnswer(df3.selectExpr("array_except(a, b)"), ans3)
+
+    val df4 = Seq(
+      (Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array[java.lang.Long](-5L, 4L, null, 2L, -1L)))
+      .toDF("a", "b")
+    val ans4 = Row(Seq(1L, 5L))
+    checkAnswer(df4.select(array_except($"a", $"b")), ans4)
+    checkAnswer(df4.selectExpr("array_except(a, b)"), ans4)
+
+    val df5 = Seq((Array("c", null, "a", "f"), Array("b", null, "a", "g"))).toDF("a", "b")
+    val ans5 = Row(Seq("c", "f"))
+    checkAnswer(df5.select(array_except($"a", $"b")), ans5)
+    checkAnswer(df5.selectExpr("array_except(a, b)"), ans5)
+
+    val df6 = Seq((null, null)).toDF("a", "b")
+    intercept[AnalysisException] {
+      df6.select(array_except($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df6.selectExpr("array_except(a, b)")
+    }
+    val df7 = Seq((Array(1), Array("a"))).toDF("a", "b")
+    intercept[AnalysisException] {
+      df7.select(array_except($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df7.selectExpr("array_except(a, b)")
+    }
+    val df8 = Seq((Array("a"), null)).toDF("a", "b")
+    intercept[AnalysisException] {
+      df8.select(array_except($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df8.selectExpr("array_except(a, b)")
+    }
+    val df9 = Seq((null, Array("a"))).toDF("a", "b")
+    intercept[AnalysisException] {
+      df9.select(array_except($"a", $"b"))
+    }
+    intercept[AnalysisException] {
+      df9.selectExpr("array_except(a, b)")
+    }
+
+    val df10 = Seq(
+      (Array[Integer](1, 2), Array[Integer](2)),
+      (Array[Integer](1, 2), Array[Integer](1, null)),
+      (Array[Integer](1, null, 3), Array[Integer](1, 2)),
+      (Array[Integer](1, null), Array[Integer](2, null))
+    ).toDF("a", "b")
+    val result10 = df10.select(array_except($"a", $"b"))
+    val expectedType10 = ArrayType(IntegerType, containsNull = true)
+    assert(result10.first.schema(0).dataType === expectedType10)
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
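For readers who want to exercise the new API end to end, the following minimal sketch is distilled from the tests in this patch. It is illustrative only: it assumes a Spark 2.4 spark-shell session (so `spark`, `$`-interpolation and `toDF` from `spark.implicits._` are in scope), and the expected results in the comments are taken from the suites above rather than re-verified.

    import org.apache.spark.sql.functions.array_except

    // DataFrame API: elements of "a" that do not occur in "b", without duplicates
    val df = Seq((Array(1, 2, 4), Array(4, 2))).toDF("a", "b")
    df.select(array_except($"a", $"b")).collect()  // expected: Row([1]), per DataFrameFunctionsSuite

    // SQL surface, registered through FunctionRegistry
    spark.sql("SELECT array_except(array(1, 2, 3), array(1, 3, 5))").collect()  // expected: Row([2])

    // The Double/Float specialization of OpenHashSet backing the primitive fast path
    import org.apache.spark.util.collection.OpenHashSet
    val set = new OpenHashSet[Double]
    set.add(10.1D)
    assert(set.contains(10.1D))   // hashed via DoubleHasher, without boxing
    assert(!set.contains(50.5D))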