diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 35f8de1328b5..64e6bbd0e8d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -441,6 +441,7 @@ object FunctionRegistry {
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
+    expression[MapFilter]("map_filter"),
     expression[ArrayFilter]("filter"),
     expression[ArrayAggregate]("aggregate"),
     CreateStruct.registryEntry,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 20c7f7d43b9d..e0746b938662 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 
 /**
@@ -133,7 +133,29 @@ trait HigherOrderFunction extends Expression {
   }
 }
 
-trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
+object HigherOrderFunction {
+
+  def arrayArgumentType(dt: DataType): (DataType, Boolean) = {
+    dt match {
+      case ArrayType(elementType, containsNull) => (elementType, containsNull)
+      case _ =>
+        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
+        (elementType, containsNull)
+    }
+  }
+
+  def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
+    case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
+    case _ =>
+      val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
+      (kType, vType, vContainsNull)
+  }
+}
+
+/**
+ * Trait for functions that take one argument and one lambda function as input.
+ */
+trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
 
   def input: Expression
 
@@ -145,23 +167,33 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
 
   def expectingFunctionType: AbstractDataType = AnyDataType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
-
   @transient lazy val functionForEval: Expression = functionsForEval.head
-}
 
-object ArrayBasedHigherOrderFunction {
+  /**
+   * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
+   * in order to save null-check code.
+   */
+  protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
+    sys.error("SimpleHigherOrderFunction must override either eval or nullSafeEval")
 
-  def elementArgumentType(dt: DataType): (DataType, Boolean) = {
-    dt match {
-      case ArrayType(elementType, containsNull) => (elementType, containsNull)
-      case _ =>
-        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
-        (elementType, containsNull)
+  override def eval(inputRow: InternalRow): Any = {
+    val value = input.eval(inputRow)
+    if (value == null) {
+      null
+    } else {
+      nullSafeEval(inputRow, value)
     }
   }
 }
 
+trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
+}
+
+trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
+  override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
+}
+
 /**
  * Transform elements in an array using the transform function. This is similar to
  * a `map` in functional programming.
@@ -179,14 +211,14 @@ object ArrayBasedHigherOrderFunction {
 case class ArrayTransform(
     input: Expression,
     function: Expression)
-  extends ArrayBasedHigherOrderFunction with CodegenFallback {
+  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
   override def nullable: Boolean = input.nullable
 
   override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
-    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
+    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
     function match {
       case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
         copy(function = f(function, elem :: (IntegerType, false) :: Nil))
@@ -205,29 +237,78 @@ case class ArrayTransform(
     (elementVar, indexVar)
   }
 
-  override def eval(input: InternalRow): Any = {
-    val arr = this.input.eval(input).asInstanceOf[ArrayData]
-    if (arr == null) {
-      null
-    } else {
-      val f = functionForEval
-      val result = new GenericArrayData(new Array[Any](arr.numElements))
-      var i = 0
-      while (i < arr.numElements) {
-        elementVar.value.set(arr.get(i, elementVar.dataType))
-        if (indexVar.isDefined) {
-          indexVar.get.value.set(i)
-        }
-        result.update(i, f.eval(input))
-        i += 1
+  override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
+    val arr = inputValue.asInstanceOf[ArrayData]
+    val f = functionForEval
+    val result = new GenericArrayData(new Array[Any](arr.numElements))
+    var i = 0
+    while (i < arr.numElements) {
+      elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (indexVar.isDefined) {
+        indexVar.get.value.set(i)
       }
-      result
+      result.update(i, f.eval(inputRow))
+      i += 1
     }
+    result
   }
 
   override def prettyName: String = "transform"
 }
 
+/**
+ * Filters entries in a map using the provided function.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
+       [1 -> 0, 3 -> -1]
+  """,
+  since = "2.4.0")
+case class MapFilter(
+    input: Expression,
+    function: Expression)
+  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
+
+  @transient lazy val (keyVar, valueVar) = {
+    val args = function.asInstanceOf[LambdaFunction].arguments
+    (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
+  }
+
+  @transient lazy val (keyType, valueType, valueContainsNull) =
+    HigherOrderFunction.mapKeyValueArgumentType(input.dataType)
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
+    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
+  }
+
+  override def nullable: Boolean = input.nullable
+
+  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
+    val m = value.asInstanceOf[MapData]
+    val f = functionForEval
+    val retKeys = new mutable.ListBuffer[Any]
+    val retValues = new mutable.ListBuffer[Any]
+    m.foreach(keyType, valueType, (k, v) => {
+      keyVar.value.set(k)
+      valueVar.value.set(v)
+      if (f.eval(inputRow).asInstanceOf[Boolean]) {
+        retKeys += k
+        retValues += v
+      }
+    })
+    ArrayBasedMapData(retKeys.toArray, retValues.toArray)
+  }
+
+  override def dataType: DataType = input.dataType
+
+  override def expectingFunctionType: AbstractDataType = BooleanType
+
+  override def prettyName: String = "map_filter"
+}
+
 /**
  * Filters the input array using the given lambda function.
  */
@@ -242,7 +323,7 @@ case class ArrayTransform(
 case class ArrayFilter(
     input: Expression,
     function: Expression)
-  extends ArrayBasedHigherOrderFunction with CodegenFallback {
+  extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
   override def nullable: Boolean = input.nullable
 
@@ -251,29 +332,25 @@ case class ArrayFilter(
   override def expectingFunctionType: AbstractDataType = BooleanType
 
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
-    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
+    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
     copy(function = f(function, elem :: Nil))
   }
 
   @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
 
-  override def eval(input: InternalRow): Any = {
-    val arr = this.input.eval(input).asInstanceOf[ArrayData]
-    if (arr == null) {
-      null
-    } else {
-      val f = functionForEval
-      val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
-      var i = 0
-      while (i < arr.numElements) {
-        elementVar.value.set(arr.get(i, elementVar.dataType))
-        if (f.eval(input).asInstanceOf[Boolean]) {
-          buffer += elementVar.value.get
-        }
-        i += 1
+  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
+    val arr = value.asInstanceOf[ArrayData]
+    val f = functionForEval
+    val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
+    var i = 0
+    while (i < arr.numElements) {
+      elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (f.eval(inputRow).asInstanceOf[Boolean]) {
+        buffer += elementVar.value.get
       }
-      new GenericArrayData(buffer)
+      i += 1
     }
+    new GenericArrayData(buffer)
   }
 
   override def prettyName: String = "filter"
@@ -334,7 +411,7 @@ case class ArrayAggregate(
   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
     // Be very conservative with nullable. We cannot be sure that the accumulator does not
     // evaluate to null. So we always set nullable to true here.
-    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
+    val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
     val acc = zero.dataType -> true
     val newMerge = f(merge, acc :: elem :: Nil)
     val newFinish = f(finish, acc :: Nil)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index 40cfc0ccc7c0..f7e84b875791 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -121,6 +121,55 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq("[1, 3, 5]", null, "[4, 6]"))
   }
 
+  test("MapFilter") {
+    def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
+      val mt = expr.dataType.asInstanceOf[MapType]
+      MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f))
+    }
+    val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
+      MapType(IntegerType, IntegerType, valueContainsNull = true))
+    val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))
+
+    val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v
+
+    checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1))
+    checkEvaluation(mapFilter(mii1, kGreaterThanV), Map())
+    checkEvaluation(mapFilter(miin, kGreaterThanV), null)
+
+    val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull
+
+    checkEvaluation(mapFilter(mii0, valueIsNull), Map())
+    checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null))
+    checkEvaluation(mapFilter(miin, valueIsNull), null)
+
+    val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0),
+      MapType(StringType, IntegerType, valueContainsNull = false))
+    val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null),
+      MapType(StringType, IntegerType, valueContainsNull = true))
+    val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false))
+
+    val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v
+
+    checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0))
+    checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5))
+    checkEvaluation(mapFilter(msin, isLengthOfKey), null)
+
+    val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)),
+      MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
+    val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)),
+      MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true))
+    val mian = Literal.create(
+      null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
+
+    val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3
+
+    checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2)))
+    checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2)))
+    checkEvaluation(mapFilter(mian, customFunc), null)
+  }
+
   test("ArrayFilter") {
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
     val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index af3301b1599a..662a7b643e49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1800,6 +1800,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
   }
 
+  test("map_filter") {
+    val dfInts = Seq(
+      Map(1 -> 10, 2 -> 20, 3 -> 30),
+      Map(1 -> -1, 2 -> -2, 3 -> -3),
+      Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m")
+
+    checkAnswer(dfInts.selectExpr(
+      "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"),
+      Seq(
+        Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
+        Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
+        Row(Map(1 -> 10), Map(3 -> -3))))
+
+    val dfComplex = Seq(
+      Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
+      Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")
+
+    checkAnswer(dfComplex.selectExpr(
+      "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"),
+      Seq(
+        Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
+        Row(Map(), Map(2 -> Seq(-2, -2)))))
+
+    // Invalid use cases
+    val df = Seq(
+      (Map(1 -> "a"), 1),
+      (Map.empty[Int, String], 2),
+      (null, 3)
+    ).toDF("s", "i")
+
+    val ex1 = intercept[AnalysisException] {
+      df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)")
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      df.selectExpr("map_filter(s, x -> x)")
+    }
+    assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("map_filter(i, (k, v) -> k > v)")
+    }
+    assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
+  }
+
   test("filter function - array for primitive type not containing null") {
     val df = Seq(
       Seq(1, 9, 8, 7),
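
For reference, a minimal end-to-end sketch of the new function (a hypothetical spark-shell session against a build that includes this patch; the expected result is taken from the ExpressionDescription example above):

  // map_filter keeps exactly those entries for which the (key, value) predicate returns true.
  spark.sql("SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) AS m").collect()
  // expected: Array(Row(Map(1 -> 0, 3 -> -1)))

A null map input yields null, and an entry whose predicate evaluates to null (for example a null value compared with `>`) is dropped, matching the MapFilter tests above.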