From 3f88e2a927c22f4fc509b8ca96027ef381f7fe84 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 3 Aug 2018 17:16:11 +0200 Subject: [PATCH 1/7] [SPARK-23937][SQL] Add map_filter SQL function --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 82 ++++++++++++++++++- .../HigherOrderFunctionsSuite.scala | 49 +++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 46 +++++++++++ 4 files changed, 174 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f7517486e541..1746d807bc65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -441,6 +441,7 @@ object FunctionRegistry { expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), + expression[MapFilter]("map_filter"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index c5c3482afa13..eae19620e7b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import java.util.concurrent.atomic.AtomicReference +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -123,7 +125,10 @@ trait HigherOrderFunction extends Expression { } } -trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { +/** + * Trait for functions having as input one argument and one function. + */ +trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { def input: Expression @@ -135,9 +140,15 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu def expectingFunctionType: AbstractDataType = AnyDataType + @transient lazy val functionForEval: Expression = functionsForEval.head +} + +trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) +} - @transient lazy val functionForEval: Expression = functionsForEval.head +trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) } /** @@ -157,7 +168,7 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu case class ArrayTransform( input: Expression, function: Expression) - extends ArrayBasedHigherOrderFunction with CodegenFallback { + extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback { override def nullable: Boolean = input.nullable @@ -210,3 +221,66 @@ case class ArrayTransform( override def prettyName: String = "transform" } + +/** + * Filters entries in a map using the provided function. + */ +@ExpressionDescription( +usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", +examples = """ + Examples: + > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); + [1 -> 0, 3 -> -1] + """, +since = "2.4.0") +case class MapFilter( + input: Expression, + function: Expression) + extends MapBasedUnaryHigherOrderFunction with CodegenFallback { + + @transient val (keyType, valueType, valueContainsNull) = input.dataType match { + case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) + case _ => + val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType + (kType, vType, vContainsNull) + } + + @transient lazy val (keyVar, valueVar) = { + val args = function.asInstanceOf[LambdaFunction].arguments + (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { + function match { + case LambdaFunction(_, _, _) => + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + } + + override def nullable: Boolean = input.nullable + + override def eval(input: InternalRow): Any = { + val m = this.input.eval(input).asInstanceOf[MapData] + if (m == null) { + null + } else { + val retKeys = new mutable.ListBuffer[Any] + val retValues = new mutable.ListBuffer[Any] + m.foreach(keyType, valueType, (k, v) => { + keyVar.value.set(k) + valueVar.value.set(v) + if (functionForEval.eval(input).asInstanceOf[Boolean]) { + retKeys += k + retValues += v + } + }) + ArrayBasedMapData(retKeys.toArray, retValues.toArray) + } + } + + override def dataType: DataType = input.dataType + + override def expectingFunctionType: AbstractDataType = BooleanType + + override def prettyName: String = "map_filter" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index e987ea5b8a4d..45a2e3be31b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -94,4 +94,53 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)), Seq("[1, 3, 5]", null, "[4, 6]")) } + + test("MapFilter") { + def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val mt = expr.dataType.asInstanceOf[MapType] + MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f)) + } + val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v + + checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1)) + checkEvaluation(mapFilter(mii1, kGreaterThanV), Map()) + checkEvaluation(mapFilter(miin, kGreaterThanV), null) + + val valueNull: (Expression, Expression) => Expression = (_, v) => v.isNull + + checkEvaluation(mapFilter(mii0, valueNull), Map()) + checkEvaluation(mapFilter(mii1, valueNull), Map(1 -> null, 3 -> null)) + checkEvaluation(mapFilter(miin, valueNull), null) + + val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), + MapType(StringType, IntegerType, valueContainsNull = false)) + val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null), + MapType(StringType, IntegerType, valueContainsNull = true)) + val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false)) + + val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v + + checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0)) + checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5)) + checkEvaluation(mapFilter(msin, isLengthOfKey), null) + + val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)), + MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true)) + val mian = Literal.create( + null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) + + val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3 + + checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2))) + checkEvaluation(mapFilter(mian, customFunc), null) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 923482024b03..346cd2f7b16a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1800,6 +1800,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) } + test("map_filter") { + val dfInts = Seq( + Map(1 -> 10, 2 -> 20, 3 -> 30), + Map(1 -> -1, 2 -> -2, 3 -> -3), + Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m") + + checkAnswer(dfInts.selectExpr( + "map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"), + Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()), + Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)), + Row(Map(1 -> 10), Map(3 -> -3)))) + + val dfComplex = Seq( + Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))), + Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m") + + checkAnswer(dfComplex.selectExpr( + "map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"), + Seq( + Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))), + Row(Map(), Map(2 -> Seq(-2, -2))))) + + // Invalid use cases + val df = Seq( + (Map(1 -> "a"), 1), + (Map.empty[Int, String], 2), + (null, 3) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_filter(s, x -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_filter(i, (k, v) -> k > v)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 9bbaa3b18493fe5e77652b7f39bdc5a6732771bb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 6 Aug 2018 10:39:31 +0200 Subject: [PATCH 2/7] address comments --- .../expressions/higherOrderFunctions.scala | 75 +++++++++++-------- .../HigherOrderFunctionsSuite.scala | 8 +- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index eae19620e7b6..14e45cf4c933 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -141,6 +141,22 @@ trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputType def expectingFunctionType: AbstractDataType = AnyDataType @transient lazy val functionForEval: Expression = functionsForEval.head + + /** + * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method + * in order to save null-check code. + */ + protected def nullSafeEval(inputRow: InternalRow, input: Any): Any = + sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") + + override def eval(inputRow: InternalRow): Any = { + val value = input.eval(inputRow) + if (value == null) { + null + } else { + nullSafeEval(inputRow, value) + } + } } trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { @@ -199,24 +215,20 @@ case class ArrayTransform( (elementVar, indexVar) } - override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[ArrayData] - if (arr == null) { - null - } else { - val f = functionForEval - val result = new GenericArrayData(new Array[Any](arr.numElements)) - var i = 0 - while (i < arr.numElements) { - elementVar.value.set(arr.get(i, elementVar.dataType)) - if (indexVar.isDefined) { - indexVar.get.value.set(i) - } - result.update(i, f.eval(input)) - i += 1 + override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = { + val arr = inputValue.asInstanceOf[ArrayData] + val f = functionForEval + val result = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (indexVar.isDefined) { + indexVar.get.value.set(i) } - result + result.update(i, f.eval(inputRow)) + i += 1 } + result } override def prettyName: String = "transform" @@ -259,23 +271,20 @@ case class MapFilter( override def nullable: Boolean = input.nullable - override def eval(input: InternalRow): Any = { - val m = this.input.eval(input).asInstanceOf[MapData] - if (m == null) { - null - } else { - val retKeys = new mutable.ListBuffer[Any] - val retValues = new mutable.ListBuffer[Any] - m.foreach(keyType, valueType, (k, v) => { - keyVar.value.set(k) - valueVar.value.set(v) - if (functionForEval.eval(input).asInstanceOf[Boolean]) { - retKeys += k - retValues += v - } - }) - ArrayBasedMapData(retKeys.toArray, retValues.toArray) - } + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val m = value.asInstanceOf[MapData] + val f = functionForEval + val retKeys = new mutable.ListBuffer[Any] + val retValues = new mutable.ListBuffer[Any] + m.foreach(keyType, valueType, (k, v) => { + keyVar.value.set(k) + valueVar.value.set(v) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + retKeys += k + retValues += v + } + }) + ArrayBasedMapData(retKeys.toArray, retValues.toArray) } override def dataType: DataType = input.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 45a2e3be31b0..9a5375f7df0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -112,11 +112,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(mapFilter(mii1, kGreaterThanV), Map()) checkEvaluation(mapFilter(miin, kGreaterThanV), null) - val valueNull: (Expression, Expression) => Expression = (_, v) => v.isNull + val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull - checkEvaluation(mapFilter(mii0, valueNull), Map()) - checkEvaluation(mapFilter(mii1, valueNull), Map(1 -> null, 3 -> null)) - checkEvaluation(mapFilter(miin, valueNull), null) + checkEvaluation(mapFilter(mii0, valueIsNull), Map()) + checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null)) + checkEvaluation(mapFilter(miin, valueIsNull), null) val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), MapType(StringType, IntegerType, valueContainsNull = false)) From b58a1dec715a26aa8bd53efa102342afff44a896 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 6 Aug 2018 16:55:06 +0200 Subject: [PATCH 3/7] address comment --- .../sql/catalyst/expressions/higherOrderFunctions.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 07c9a3977d96..08bb21b7c905 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -278,10 +278,7 @@ case class MapFilter( } override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { - function match { - case LambdaFunction(_, _, _) => - copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) - } + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } override def nullable: Boolean = input.nullable From 9c25ae66b7fd0e3d5f11e3e097af32ef72a55e76 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 6 Aug 2018 17:54:26 +0200 Subject: [PATCH 4/7] address comment --- .../expressions/higherOrderFunctions.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 08bb21b7c905..89648fb40648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -173,6 +173,13 @@ trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + + @transient val (keyType, valueType, valueContainsNull) = input.dataType match { + case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) + case _ => + val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType + (kType, vType, vContainsNull) + } } object ArrayBasedHigherOrderFunction { @@ -265,13 +272,6 @@ case class MapFilter( function: Expression) extends MapBasedUnaryHigherOrderFunction with CodegenFallback { - @transient val (keyType, valueType, valueContainsNull) = input.dataType match { - case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) - case _ => - val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType - (kType, vType, vContainsNull) - } - @transient lazy val (keyVar, valueVar) = { val args = function.asInstanceOf[LambdaFunction].arguments (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) From 1823fb279b1e5ed7b55d6e27ede27982ce94d922 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 7 Aug 2018 11:38:49 +0200 Subject: [PATCH 5/7] address comments --- .../expressions/higherOrderFunctions.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 89648fb40648..5d30fc53c00d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -136,7 +136,7 @@ trait HigherOrderFunction extends Expression { /** * Trait for functions having as input one argument and one function. */ -trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { +trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { def input: Expression @@ -167,14 +167,14 @@ trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputType } } -trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { +trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) } -trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { +trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) - @transient val (keyType, valueType, valueContainsNull) = input.dataType match { + def keyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) case _ => val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType @@ -211,7 +211,7 @@ object ArrayBasedHigherOrderFunction { case class ArrayTransform( input: Expression, function: Expression) - extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { override def nullable: Boolean = input.nullable @@ -270,13 +270,15 @@ since = "2.4.0") case class MapFilter( input: Expression, function: Expression) - extends MapBasedUnaryHigherOrderFunction with CodegenFallback { + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { @transient lazy val (keyVar, valueVar) = { val args = function.asInstanceOf[LambdaFunction].arguments (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) } + @transient val (keyType, valueType, valueContainsNull) = keyValueArgumentType(input.dataType) + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } @@ -320,7 +322,7 @@ case class MapFilter( case class ArrayFilter( input: Expression, function: Expression) - extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback { + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { override def nullable: Boolean = input.nullable From 16d8b6418a56c6417cc16a0664898642c827c4c7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 7 Aug 2018 12:44:33 +0200 Subject: [PATCH 6/7] address comment --- .../expressions/higherOrderFunctions.scala | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 5d30fc53c00d..6ef5fe8111f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -133,6 +133,25 @@ trait HigherOrderFunction extends Expression { } } +object HigherOrderFunctionHelper { + + def arrayArgumentType(dt: DataType): (DataType, Boolean) = { + dt match { + case ArrayType(elementType, containsNull) => (elementType, containsNull) + case _ => + val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType + (elementType, containsNull) + } + } + + def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { + case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) + case _ => + val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType + (kType, vType, vContainsNull) + } +} + /** * Trait for functions having as input one argument and one function. */ @@ -173,25 +192,6 @@ trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) - - def keyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { - case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) - case _ => - val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType - (kType, vType, vContainsNull) - } -} - -object ArrayBasedHigherOrderFunction { - - def elementArgumentType(dt: DataType): (DataType, Boolean) = { - dt match { - case ArrayType(elementType, containsNull) => (elementType, containsNull) - case _ => - val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType - (elementType, containsNull) - } - } } /** @@ -218,7 +218,7 @@ case class ArrayTransform( override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val elem = HigherOrderFunctionHelper.arrayArgumentType(input.dataType) function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => copy(function = f(function, elem :: (IntegerType, false) :: Nil)) @@ -277,7 +277,8 @@ case class MapFilter( (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) } - @transient val (keyType, valueType, valueContainsNull) = keyValueArgumentType(input.dataType) + @transient val (keyType, valueType, valueContainsNull) = + HigherOrderFunctionHelper.mapKeyValueArgumentType(input.dataType) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) @@ -331,7 +332,7 @@ case class ArrayFilter( override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val elem = HigherOrderFunctionHelper.arrayArgumentType(input.dataType) copy(function = f(function, elem :: Nil)) } @@ -410,7 +411,7 @@ case class ArrayAggregate( override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. - val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val elem = HigherOrderFunctionHelper.arrayArgumentType(input.dataType) val acc = zero.dataType -> true val newMerge = f(merge, acc :: elem :: Nil) val newFinish = f(finish, acc :: Nil) From af79644cb4687b6acb9a10548f05aef980f1882a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 7 Aug 2018 12:46:26 +0200 Subject: [PATCH 7/7] rename to HigherOrderFunction --- .../catalyst/expressions/higherOrderFunctions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 6ef5fe8111f7..e0746b938662 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -133,7 +133,7 @@ trait HigherOrderFunction extends Expression { } } -object HigherOrderFunctionHelper { +object HigherOrderFunction { def arrayArgumentType(dt: DataType): (DataType, Boolean) = { dt match { @@ -218,7 +218,7 @@ case class ArrayTransform( override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = HigherOrderFunctionHelper.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => copy(function = f(function, elem :: (IntegerType, false) :: Nil)) @@ -278,7 +278,7 @@ case class MapFilter( } @transient val (keyType, valueType, valueContainsNull) = - HigherOrderFunctionHelper.mapKeyValueArgumentType(input.dataType) + HigherOrderFunction.mapKeyValueArgumentType(input.dataType) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) @@ -332,7 +332,7 @@ case class ArrayFilter( override def expectingFunctionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = HigherOrderFunctionHelper.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) copy(function = f(function, elem :: Nil)) } @@ -411,7 +411,7 @@ case class ArrayAggregate( override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. - val elem = HigherOrderFunctionHelper.arrayArgumentType(input.dataType) + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) val acc = zero.dataType -> true val newMerge = f(merge, acc :: elem :: Nil) val newFinish = f(finish, acc :: Nil)