Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ object FunctionRegistry {
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
CreateStruct.registryEntry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -133,7 +133,29 @@ trait HigherOrderFunction extends Expression {
}
}

trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
/**
 * Helpers that derive the lambda argument types of a higher-order function from the data type
 * of its collection argument.
 */
object HigherOrderFunction {

  /**
   * Returns `(elementType, containsNull)` for an array argument. When `dt` is not an
   * [[ArrayType]], the components of `ArrayType.defaultConcreteType` are returned instead.
   */
  def arrayArgumentType(dt: DataType): (DataType, Boolean) = {
    val effectiveType = dt match {
      case _: ArrayType => dt
      case _ => ArrayType.defaultConcreteType
    }
    val ArrayType(elementType, containsNull) = effectiveType
    (elementType, containsNull)
  }

  /**
   * Returns `(keyType, valueType, valueContainsNull)` for a map argument. When `dt` is not a
   * [[MapType]], the components of `MapType.defaultConcreteType` are returned instead.
   */
  def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = {
    val effectiveType = dt match {
      case _: MapType => dt
      case _ => MapType.defaultConcreteType
    }
    val MapType(keyType, valueType, valueContainsNull) = effectiveType
    (keyType, valueType, valueContainsNull)
  }
}

/**
* Trait for higher-order functions that take one input argument and one lambda function.
*/
trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {

def input: Expression

Expand All @@ -145,23 +167,33 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu

def expectingFunctionType: AbstractDataType = AnyDataType

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)

@transient lazy val functionForEval: Expression = functionsForEval.head
}

object ArrayBasedHigherOrderFunction {
/**
 * Called by [[eval]] with the already-evaluated, non-null value of `input`. If a subclass keeps
 * the default nullability, it can override this method in order to save null-check code.
 *
 * The default implementation always fails: a subclass must override either [[eval]] or this
 * method.
 *
 * @param inputRow the row the expression is being evaluated against
 * @param input the non-null result of evaluating the `input` expression on `inputRow`
 */
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
  sys.error("SimpleHigherOrderFunction must override either eval or nullSafeEval")

def elementArgumentType(dt: DataType): (DataType, Boolean) = {
dt match {
case ArrayType(elementType, containsNull) => (elementType, containsNull)
case _ =>
val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
(elementType, containsNull)
/**
 * Evaluates `input` against the row. A null input short-circuits to null; any other value is
 * handed to `nullSafeEval`.
 */
override def eval(inputRow: InternalRow): Any = {
  input.eval(inputRow) match {
    case null => null
    case value => nullSafeEval(inputRow, value)
  }
}
}

/**
 * A [[SimpleHigherOrderFunction]] whose first argument must be an array; the second argument
 * must match `expectingFunctionType`.
 */
trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
  override def inputTypes: Seq[AbstractDataType] = ArrayType :: expectingFunctionType :: Nil
}

/**
 * A [[SimpleHigherOrderFunction]] whose first argument must be a map; the second argument
 * must match `expectingFunctionType`.
 */
trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
  override def inputTypes: Seq[AbstractDataType] = MapType :: expectingFunctionType :: Nil
}

/**
* Transform elements in an array using the transform function. This is similar to
* a `map` in functional programming.
Expand All @@ -179,14 +211,14 @@ object ArrayBasedHigherOrderFunction {
case class ArrayTransform(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
function match {
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
copy(function = f(function, elem :: (IntegerType, false) :: Nil))
Expand All @@ -205,29 +237,78 @@ case class ArrayTransform(
(elementVar, indexVar)
}

override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[ArrayData]
if (arr == null) {
null
} else {
val f = functionForEval
val result = new GenericArrayData(new Array[Any](arr.numElements))
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (indexVar.isDefined) {
indexVar.get.value.set(i)
}
result.update(i, f.eval(input))
i += 1
override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
val arr = inputValue.asInstanceOf[ArrayData]
val f = functionForEval
val result = new GenericArrayData(new Array[Any](arr.numElements))
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (indexVar.isDefined) {
indexVar.get.value.set(i)
}
result
result.update(i, f.eval(inputRow))
i += 1
}
result
}

override def prettyName: String = "transform"
}

/**
* Filters entries in a map using the provided function.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
[1 -> 0, 3 -> -1]
""",
since = "2.4.0")
case class MapFilter(
    input: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  // The (key, value) lambda variables, taken from the first two arguments of `function`.
  // The casts assume `function` is a LambdaFunction whose arguments are NamedLambdaVariables.
  @transient lazy val (keyVar, valueVar) = {
    val args = function.asInstanceOf[LambdaFunction].arguments
    (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
  }

  // Key/value types of the input map (or of the default map type when the input is not a map).
  // This must be lazy: a plain `@transient val` is computed eagerly when the expression is
  // constructed and is NOT restored when a serialized copy is deserialized, which would leave
  // these fields null/false for `nullSafeEval` and `bind` on that copy.
  @transient lazy val (keyType, valueType, valueContainsNull) =
    HigherOrderFunction.mapKeyValueArgumentType(input.dataType)

  /**
   * Binds the lambda with two arguments: a non-nullable key and a value whose nullability
   * follows `valueContainsNull`.
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  override def nullable: Boolean = input.nullable

  /**
   * Keeps every entry of the (non-null) input map for which the lambda evaluates to true.
   *
   * @param inputRow the row the lambda body is evaluated against
   * @param value the non-null input map, expected to be a [[MapData]]
   * @return a new map holding only the retained entries, in iteration order
   */
  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
    val m = value.asInstanceOf[MapData]
    val f = functionForEval
    val retKeys = new mutable.ListBuffer[Any]
    val retValues = new mutable.ListBuffer[Any]
    m.foreach(keyType, valueType, (k, v) => {
      // Expose the current entry to the lambda body before evaluating it.
      keyVar.value.set(k)
      valueVar.value.set(v)
      // A null lambda result unboxes to false, so such entries are dropped.
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        retKeys += k
        retValues += v
      }
    })
    ArrayBasedMapData(retKeys.toArray, retValues.toArray)
  }

  // Filtering never changes the map's type.
  override def dataType: DataType = input.dataType

  // The lambda must be a predicate.
  override def expectingFunctionType: AbstractDataType = BooleanType

  override def prettyName: String = "map_filter"
}

/**
* Filters the input array using the given lambda function.
*/
Expand All @@ -242,7 +323,7 @@ case class ArrayTransform(
case class ArrayFilter(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

Expand All @@ -251,29 +332,25 @@ case class ArrayFilter(
override def expectingFunctionType: AbstractDataType = BooleanType

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
copy(function = f(function, elem :: Nil))
}

@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[ArrayData]
if (arr == null) {
null
} else {
val f = functionForEval
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (f.eval(input).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
i += 1
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val arr = value.asInstanceOf[ArrayData]
val f = functionForEval
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (f.eval(inputRow).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
new GenericArrayData(buffer)
i += 1
}
new GenericArrayData(buffer)
}

override def prettyName: String = "filter"
Expand Down Expand Up @@ -334,7 +411,7 @@ case class ArrayAggregate(
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
// Be very conservative with nullable. We cannot be sure that the accumulator does not
// evaluate to null. So we always set nullable to true here.
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
val acc = zero.dataType -> true
val newMerge = f(merge, acc :: elem :: Nil)
val newFinish = f(finish, acc :: Nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,55 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq("[1, 3, 5]", null, "[4, 6]"))
}

test("MapFilter") {
  // Builds a MapFilter over `expr`, wrapping `f` in a (key, value) lambda typed after the map.
  def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
    val mapType = expr.dataType.asInstanceOf[MapType]
    val lambda =
      createLambda(mapType.keyType, false, mapType.valueType, mapType.valueContainsNull, f)
    MapFilter(expr, lambda)
  }

  // Int -> Int maps: without null values, with null values, and a null map.
  val intToInt = MapType(IntegerType, IntegerType, valueContainsNull = false)
  val intToNullableInt = MapType(IntegerType, IntegerType, valueContainsNull = true)
  val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), intToInt)
  val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), intToNullableInt)
  val miin = Literal.create(null, intToInt)

  val keyGreaterThanValue: (Expression, Expression) => Expression = (k, v) => k > v

  checkEvaluation(mapFilter(mii0, keyGreaterThanValue), Map(1 -> 0, 3 -> -1))
  checkEvaluation(mapFilter(mii1, keyGreaterThanValue), Map())
  checkEvaluation(mapFilter(miin, keyGreaterThanValue), null)

  val nullValued: (Expression, Expression) => Expression = (_, v) => v.isNull

  checkEvaluation(mapFilter(mii0, nullValued), Map())
  checkEvaluation(mapFilter(mii1, nullValued), Map(1 -> null, 3 -> null))
  checkEvaluation(mapFilter(miin, nullValued), null)

  // String -> Int maps.
  val stringToInt = MapType(StringType, IntegerType, valueContainsNull = false)
  val stringToNullableInt = MapType(StringType, IntegerType, valueContainsNull = true)
  val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), stringToInt)
  val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null), stringToNullableInt)
  val msin = Literal.create(null, stringToInt)

  val valueIsKeyLength: (Expression, Expression) => Expression = (k, v) => Length(k) === v

  checkEvaluation(mapFilter(msi0, valueIsKeyLength), Map("abcdf" -> 5, "" -> 0))
  checkEvaluation(mapFilter(msi1, valueIsKeyLength), Map("abcdf" -> 5))
  checkEvaluation(mapFilter(msin, valueIsKeyLength), null)

  // Int -> Array[Int] maps.
  val intToArray = MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)
  val intToNullableArray = MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true)
  val mia0 = Literal.create(
    Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)), intToArray)
  val mia1 = Literal.create(
    Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)), intToNullableArray)
  val mian = Literal.create(null, intToArray)

  val sizePlusKeyAboveThree: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3

  checkEvaluation(mapFilter(mia0, sizePlusKeyAboveThree), Map(1 -> Seq(0, 1, 2)))
  checkEvaluation(mapFilter(mia1, sizePlusKeyAboveThree), Map(1 -> Seq(0, 1, 2)))
  checkEvaluation(mapFilter(mian, sizePlusKeyAboveThree), null)
}

test("ArrayFilter") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
}

test("map_filter") {
  // Maps with primitive values.
  val intMaps = Seq(
    Map(1 -> 10, 2 -> 20, 3 -> 30),
    Map(1 -> -1, 2 -> -2, 3 -> -3),
    Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m")

  val expectedInts = Seq(
    Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
    Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
    Row(Map(1 -> 10), Map(3 -> -3)))

  checkAnswer(
    intMaps.selectExpr(
      "map_filter(m, (k, v) -> k * 10 = v)",
      "map_filter(m, (k, v) -> k = -v)"),
    expectedInts)

  // Maps with array values, including a null value and a null array element.
  val arrayMaps = Seq(
    Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
    Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")

  val expectedArrays = Seq(
    Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
    Row(Map(), Map(2 -> Seq(-2, -2))))

  checkAnswer(
    arrayMaps.selectExpr(
      "map_filter(m, (k, v) -> k = v[0])",
      "map_filter(m, (k, v) -> k = size(v))"),
    expectedArrays)

  // Invalid use cases
  val df = Seq(
    (Map(1 -> "a"), 1),
    (Map.empty[Int, String], 2),
    (null, 3)
  ).toDF("s", "i")

  // Each invalid expression paired with the fragment its analysis error must contain.
  val invalidCases = Seq(
    "map_filter(s, (x, y, z) -> x + y + z)" ->
      "The number of lambda function arguments '3' does not match",
    "map_filter(s, x -> x)" ->
      "The number of lambda function arguments '1' does not match",
    "map_filter(i, (k, v) -> k > v)" ->
      "data type mismatch: argument 1 requires map type")

  invalidCases.foreach { case (sqlExpr, expectedFragment) =>
    val ex = intercept[AnalysisException] {
      df.selectExpr(sqlExpr)
    }
    assert(ex.getMessage.contains(expectedFragment))
  }
}

test("filter function - array for primitive type not containing null") {
val df = Seq(
Seq(1, 9, 8, 7),
Expand Down