From f92e18cda61cb4d01fa3eb985d9d0733dfa6ca24 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 13 Apr 2018 11:25:19 -0700 Subject: [PATCH 1/7] [SPARK-23920][SQL]add array_remove to remove all elements that equal element from array --- python/pyspark/sql/functions.py | 16 ++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 41 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 26 ++++++++++++ .../org/apache/spark/sql/functions.scala | 9 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++++++++ 6 files changed, 117 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8490081facc5a..6d8aea1f16085 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1964,6 +1964,22 @@ def element_at(col, extraction): return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) +@since(2.4) +def array_remove(col, element): + """ + Collection function: Remove all elements that equal to element from the given array. + + :param col: name of column containing array + :param element: element to be removed from the array + + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 867c2d5eab53d..8e5833bd7de5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -429,6 +429,7 @@ object FunctionRegistry { expression[Concat]("concat"), expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), + expression[ArrayRemove]("array_remove"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c82db839438ed..e8563b0eef2d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1882,3 +1882,44 @@ case class ArrayRepeat(left: Expression, right: Expression) } } + +/** + * Remove all elements that equal to element from the given array + */ +@ExpressionDescription( + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] + """, since = "2.4.0") +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + nullSafeEval(value1, value2) + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + val elementType = left.dataType.asInstanceOf[ArrayType].elementType + val data = arr.asInstanceOf[ArrayData].toArray[AnyRef](elementType).filter(_ != value) + new GenericArrayData(data.asInstanceOf[Array[Any]]) + } + + override def prettyName: String = "array_remove" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6ae1ac18c4dc4..69b710c7f81b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -552,4 +552,30 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + + test("Array remove") { + val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) + val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a4 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) + checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + + checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) + checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) + + checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) + checkEvaluation(ArrayRemove(a2, Literal(null)), Seq("", "")) + + checkEvaluation(ArrayRemove(a3, Literal("a")), Seq.empty[Integer]) + + checkEvaluation(ArrayRemove(a4, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2a8fe583b83bc..8f3723c51c0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3169,6 +3169,15 @@ object functions { */ def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** + * Remove all elements that equal to element from the given array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_remove(column: Column, element: Any): Column = withExpr { + ArrayRemove(column.expr, Literal(element)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d08982a138bc5..d7333af59608d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -948,6 +948,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } + test("array remove") { + val df = Seq( + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), + (Array.empty[Int], Array.empty[String], Array.empty[String]), + (null, null, null) + ).toDF("a", "b", "c") + checkAnswer( + df.select(array_remove(df("a"), 2), array_remove(df("b"), "a"), array_remove(df("c"), "")), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", + "array_remove(c, \"\")"), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From f6a629bd048b8ac9939aa93eb849108a765049e5 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 18 Apr 2018 17:19:35 -0700 Subject: [PATCH 2/7] add doGenCode --- .../expressions/collectionOperations.scala | 54 ++++++++++++++----- .../CollectionExpressionsSuite.scala | 9 +++- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e8563b0eef2d3..69f90287d95ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1901,25 +1901,51 @@ case class ArrayRemove(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) - override def nullable: Boolean = { - left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull - } - - override def eval(input: InternalRow): Any = { - val value1 = left.eval(input) - if (value1 == null) { - null - } else { - val value2 = right.eval(input) - nullSafeEval(value1, value2) - } - } - override def nullSafeEval(arr: Any, value: Any): Any = { val elementType = left.dataType.asInstanceOf[ArrayType].elementType val data = arr.asInstanceOf[ArrayData].toArray[AnyRef](elementType).filter(_ != value) new GenericArrayData(data.asInstanceOf[Array[Any]]) } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val elementType = left.dataType.asInstanceOf[ArrayType].elementType + nullSafeCodeGen(ctx, ev, (arr, value) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("arrayPosition") + val numsToRemove = ctx.freshName("newArrLen") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |int $numsToRemove = 0; + |Object[] $values; + | + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $numsToRemove = $numsToRemove + 1; + | } + |} + |$values = new Object[$arr.numElements() - $numsToRemove]; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if ($arr.isNullAt($i)) { + | $values[$pos] = null; + | $pos = $pos + 1; + | } + | else { + | if (${ctx.genEqual(right.dataType, value, getValue)}) { + | ; + | } + | else { + | $values[$pos] = ${CodeGenerator.getValue(arr, elementType, s"$i")}; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + }) + } + override def prettyName: String = "array_remove" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 69b710c7f81b0..91e31a290ea27 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -559,6 +559,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) val a4 = Literal.create(null, ArrayType(StringType)) + val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) + val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) @@ -572,10 +574,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) - checkEvaluation(ArrayRemove(a2, Literal(null)), Seq("", "")) + checkEvaluation(ArrayRemove(a2, Literal.create(null, StringType)), null) - checkEvaluation(ArrayRemove(a3, Literal("a")), Seq.empty[Integer]) + checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) checkEvaluation(ArrayRemove(a4, Literal("a")), null) + + checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) + checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) } } From 1c247209599174cc7ee1ddc98bf36dd9a5572058 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 15 May 2018 20:28:41 -0700 Subject: [PATCH 3/7] resolve conflicts and address comments --- .../expressions/collectionOperations.scala | 88 +++++++++++++------ .../CollectionExpressionsSuite.scala | 3 +- .../spark/sql/DataFrameFunctionsSuite.scala | 2 +- 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 69f90287d95ac..8c5fa3dad3459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1894,13 +1894,15 @@ case class ArrayRepeat(left: Expression, right: Expression) [1,2,null] """, since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends BinaryExpression with ImplicitCastInputTypes { override def dataType: DataType = left.dataType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + override def nullSafeEval(arr: Any, value: Any): Any = { val elementType = left.dataType.asInstanceOf[ArrayType].elementType val data = arr.asInstanceOf[ArrayData].toArray[AnyRef](elementType).filter(_ != value) @@ -1908,43 +1910,75 @@ case class ArrayRemove(left: Expression, right: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementType = left.dataType.asInstanceOf[ArrayType].elementType nullSafeCodeGen(ctx, ev, (arr, value) => { - val arrayClass = classOf[GenericArrayData].getName - val values = ctx.freshName("values") + val numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") val i = ctx.freshName("i") - val pos = ctx.freshName("arrayPosition") - val numsToRemove = ctx.freshName("newArrLen") - val getValue = CodeGenerator.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) s""" - |int $pos = 0; |int $numsToRemove = 0; - |Object[] $values; - | |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | if (!$arr.isNullAt($i) && $isEqual) { | $numsToRemove = $numsToRemove + 1; | } |} - |$values = new Object[$arr.numElements() - $numsToRemove]; - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | if ($arr.isNullAt($i)) { - | $values[$pos] = null; - | $pos = $pos + 1; - | } - | else { - | if (${ctx.genEqual(right.dataType, value, getValue)}) { - | ; - | } - | else { - | $values[$pos] = ${CodeGenerator.getValue(arr, elementType, s"$i")}; - | $pos = $pos + 1; - | } - | } + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($isEqual) { + | ; + | } + | else { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } |} |${ev.value} = new $arrayClass($values); """.stripMargin - }) + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if ($isEqual) { + | ; + | } + | else { + | $values.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } } override def prettyName: String = "array_remove" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 91e31a290ea27..ab9a1e13f2835 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -567,6 +567,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) @@ -574,7 +575,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) - checkEvaluation(ArrayRemove(a2, Literal.create(null, StringType)), null) + checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d7333af59608d..5c2f9fc6489f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -955,7 +955,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { (null, null, null) ).toDF("a", "b", "c") checkAnswer( - df.select(array_remove(df("a"), 2), array_remove(df("b"), "a"), array_remove(df("c"), "")), + df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), Seq( Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), From 89b4f48245e4247d946bf61d6d4d557716e2854b Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 16 May 2018 10:33:58 -0700 Subject: [PATCH 4/7] addres comments (2) --- .../catalyst/expressions/collectionOperations.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8c5fa3dad3459..809e4cdd9d629 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1946,10 +1946,7 @@ case class ArrayRemove(left: Expression, right: Expression) |int $pos = 0; |Object[] $values = new Object[$newArraySize]; |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($isEqual) { - | ; - | } - | else { + | if (!($isEqual)) { | $values[$pos] = $getValue; | $pos = $pos + 1; | } @@ -1967,10 +1964,7 @@ case class ArrayRemove(left: Expression, right: Expression) | $pos = $pos + 1; | } | else { - | if ($isEqual) { - | ; - | } - | else { + | if (!($isEqual)) { | $values.set$primitiveValueTypeName($pos, $getValue); | $pos = $pos + 1; | } From 9281ae233dc54dd961e99e345be559929232c148 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 22 May 2018 10:01:34 -0700 Subject: [PATCH 5/7] add complex type support --- .../expressions/collectionOperations.scala | 31 +++++++++++++++++-- .../CollectionExpressionsSuite.scala | 24 ++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 809e4cdd9d629..09fbbafbca4c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1903,10 +1903,35 @@ case class ArrayRemove(left: Expression, right: Expression) lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + override def nullSafeEval(arr: Any, value: Any): Any = { - val elementType = left.dataType.asInstanceOf[ArrayType].elementType - val data = arr.asInstanceOf[ArrayData].toArray[AnyRef](elementType).filter(_ != value) - new GenericArrayData(data.asInstanceOf[Array[Any]]) + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + if (value != null) { + newArray(pos) = null + pos += 1 + } + } + else if (!ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 + } + ) + new GenericArrayData(newArray.slice(0, pos)) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index ab9a1e13f2835..91356f26dacf4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -583,5 +583,29 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + + val dataToRemoved1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemoved1), + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b0, nullBinary), null) + checkEvaluation(ArrayRemove(b1, dataToRemoved1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemoved1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val dataToRemoved2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemoved2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemoved2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) } } From 074ed887e9e0c233b29635974c3d52cad67808b7 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 24 May 2018 09:33:29 -0700 Subject: [PATCH 6/7] remove redundant null check --- .../sql/catalyst/expressions/collectionOperations.scala | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 09fbbafbca4c2..091fc09fdd9f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1920,13 +1920,7 @@ case class ArrayRemove(left: Expression, right: Expression) val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) var pos = 0 arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == null) { - if (value != null) { - newArray(pos) = null - pos += 1 - } - } - else if (!ordering.equiv(v, value)) { + if (v == null || !ordering.equiv(v, value)) { newArray(pos) = v pos += 1 } From 52d230880b9dbdb72f0a558e32cdc0a34618c5e4 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 31 May 2018 15:32:43 -0700 Subject: [PATCH 7/7] address comments --- .../expressions/collectionOperations.scala | 29 ++++++++++++------- .../CollectionExpressionsSuite.scala | 16 +++++----- .../spark/sql/DataFrameFunctionsSuite.scala | 5 ++++ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 091fc09fdd9f9..513fa428d9056 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1898,8 +1898,13 @@ case class ArrayRemove(left: Expression, right: Expression) override def dataType: DataType = left.dataType - override def inputTypes: Seq[AbstractDataType] = - Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @@ -1907,12 +1912,10 @@ case class ArrayRemove(left: Expression, right: Expression) TypeUtils.getInterpretedOrdering(right.dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - "Arguments must be an array followed by a value of same type as the array members") - } else { - TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -1965,10 +1968,16 @@ case class ArrayRemove(left: Expression, right: Expression) |int $pos = 0; |Object[] $values = new Object[$newArraySize]; |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if (!($isEqual)) { - | $values[$pos] = $getValue; + | if ($inputArray.isNullAt($i)) { + | $values[$pos] = null; | $pos = $pos + 1; | } + | else { + | if (!($isEqual)) { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } |} |${ev.value} = new $arrayClass($values); """.stripMargin diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 91356f26dacf4..b8799b898f9ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -593,19 +593,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(BinaryType)) val nullBinary = Literal.create(null, BinaryType) - val dataToRemoved1 = Literal.create(Array[Byte](5, 6), BinaryType) - checkEvaluation(ArrayRemove(b0, dataToRemoved1), + val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemove1), Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) checkEvaluation(ArrayRemove(b0, nullBinary), null) - checkEvaluation(ArrayRemove(b1, dataToRemoved1), Seq[Array[Byte]](Array[Byte](2, 1), null)) - checkEvaluation(ArrayRemove(b2, dataToRemoved1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2))) val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), ArrayType(ArrayType(IntegerType))) val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) - val dataToRemoved2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) - checkEvaluation(ArrayRemove(c0, dataToRemoved2), Seq[Seq[Int]](Seq[Int](3, 4))) - checkEvaluation(ArrayRemove(c1, dataToRemoved2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) + val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5c2f9fc6489f1..998faf49beba2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -970,6 +970,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), Row(null, null, null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {