Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
755d6db
[SPARK-23821][SQL] Collection function: flatten
Mar 26, 2018
ad46962
[SPARK-23821][SQL] Improving test cases
Apr 2, 2018
eeab727
[SPARK-23821][SQL] Merging the current master into the feature branch.
Apr 3, 2018
a50d42e
[SPARK-23821][SQL] Code-styling improvements
Apr 9, 2018
e213341
[SPARK-23821][SQL] Merging current master to the feature branch
Apr 10, 2018
b9d99f7
[SPARK-23821][SQL] Merging current master to the feature branch
Apr 10, 2018
0e0def4
Merge remote-tracking branch 'spark/master' into feature/array-api-fl…
Apr 12, 2018
0089e45
[SPARK-23821][SQL] Checks of max array size + Added more tests
Apr 16, 2018
1a3ec1f
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 16, 2018
2ceb53b
[SPARK-23821][SQL] Optimizing evaluation without codegen.
Apr 16, 2018
e21c306
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 17, 2018
207eb5a
[SPARK-23821][SQL] Improving codeGen.
Apr 17, 2018
10849d7
[SPARK-23821][SQL] Removing extra space from the exception message.
Apr 17, 2018
57f554b
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 18, 2018
9081291
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 18, 2018
f11aa7b
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 19, 2018
88c4971
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 19, 2018
37b68cd
[SPARK-23821][SQL] Small refactoring
Apr 19, 2018
508fee0
[SPARK-23821][SQL] Merging current master to the feature branch.
Apr 20, 2018
939fc23
Merge remote-tracking branch 'spark/master' into feature/array-api-fl…
Apr 20, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2191,6 +2191,23 @@ def reverse(col):
return Column(sc._jvm.functions.reverse(_to_java_column(col)))


@since(2.4)
def flatten(col):
    """
    Collection function: merges the arrays contained in an array column into one array.
    Only a single level of nesting is removed, so arrays nested deeper than
    two levels keep their remaining inner structure.
    As the example below shows, a row whose outer array contains a null
    element yields a null result.

    :param col: name of column or expression

    >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
    >>> df.select(flatten(df.data).alias('r')).collect()
    [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
    """
    # Resolve the JVM-side function first, then convert the argument and wrap the result.
    jvm_flatten = SparkContext._active_spark_context._jvm.functions.flatten
    java_column = _to_java_column(col)
    return Column(jvm_flatten(java_column))


@since(2.3)
def map_keys(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ object FunctionRegistry {
expression[ArrayMax]("array_max"),
expression[Reverse]("reverse"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression {

override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}

/**
 * Transforms an array of arrays into a single array by removing one level of nesting.
 *
 * The result is null when any element of the outer array is null. Evaluation throws a
 * RuntimeException when the flattened array would exceed the maximum supported array length.
 */
@ExpressionDescription(
  usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(array(1, 2), array(3, 4)));
       [1,2,3,4]
  """,
  since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {

  // Limit used both for the element count of the result and for the byte size of the
  // backing buffer built in the primitive code-generation path.
  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

  // Type of the input. The cast assumes the input already passed checkInputDataTypes.
  private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

  // A null outer array or a null inner array both produce a null result.
  override def nullable: Boolean = child.nullable || childDataType.containsNull

  // Flattening array<array<T>> yields array<T>, i.e. the element type of the input.
  override def dataType: DataType = childDataType.elementType

  // Type T of the elements stored in the flattened array.
  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

  /** Accepts only inputs whose type is an array of arrays. */
  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(_: ArrayType, _) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(
        "The argument should be an array of arrays, " +
        s"but '${child.sql}' is of ${child.dataType.simpleString} type."
      )
  }

  /**
   * Interpreted evaluation.
   *
   * @param child the outer array as ArrayData
   * @return a GenericArrayData holding all inner elements in order, or null if any inner
   *         array is null
   * @throws RuntimeException if the total number of elements exceeds MAX_ARRAY_LENGTH
   */
  override def nullSafeEval(child: Any): Any = {
    val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType)

    if (elements.contains(null)) {
      null
    } else {
      val arrayData = elements.map(_.asInstanceOf[ArrayData])
      // Summed as Long so the size check cannot be defeated by Int overflow.
      val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
      if (numberOfElements > MAX_ARRAY_LENGTH) {
        throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
          s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
      }
      // The element type is spelled out: an untyped `new Array(n)` would be inferred as
      // Array[Nothing], which is not a suitable target for the object copies below.
      val flattenedData = new Array[AnyRef](numberOfElements.toInt)
      var position = 0
      for (ad <- arrayData) {
        val arr = ad.toObjectArray(elementType)
        Array.copy(arr, 0, flattenedData, position, arr.length)
        position += arr.length
      }
      new GenericArrayData(flattenedData)
    }
  }

  /**
   * Generates Java code for the flattening. Primitive element types use an
   * UnsafeArrayData-backed result, other types a GenericArrayData; a null-element guard is
   * added only when the input type allows null inner arrays.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val code = if (CodeGenerator.isPrimitiveType(elementType)) {
        genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value)
      } else {
        genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
      }
      if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
    })
  }

  /**
   * Wraps `coreLogic` so that it only runs when no inner array is null; otherwise the
   * generated code marks the result as null.
   */
  private def nullElementsProtection(
      ev: ExprCode,
      childVariableName: String,
      coreLogic: String): String = {
    s"""
       |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
       |  ${ev.isNull} |= $childVariableName.isNullAt(z);
       |}
       |if (!${ev.isNull}) {
       |  $coreLogic
       |}
     """.stripMargin
  }

  /**
   * Generates code that sums the sizes of all inner arrays into a fresh `long` variable and
   * throws if the total exceeds MAX_ARRAY_LENGTH.
   *
   * @return the generated code and the name of the variable holding the total
   */
  private def genCodeForNumberOfElements(
      ctx: CodegenContext,
      childVariableName: String): (String, String) = {
    val variableName = ctx.freshName("numElements")
    val code = s"""
      |long $variableName = 0;
      |for (int z = 0; z < $childVariableName.numElements(); z++) {
      |  $variableName += $childVariableName.getArray(z).numElements();
      |}
      |if ($variableName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      |    $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
      |}
      """.stripMargin
    (code, variableName)
  }

  /**
   * Generates code that copies primitive elements into an UnsafeArrayData backed by a newly
   * allocated byte array, preserving null elements, and assigns it to `arrayDataName`.
   */
  private def genCodeForFlattenOfPrimitiveElements(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val arrayName = ctx.freshName("array")
    val arraySizeName = ctx.freshName("size")
    val counter = ctx.freshName("counter")
    val tempArrayDataName = ctx.freshName("tempArrayData")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    // Second size check: the byte size of the backing buffer must also fit in an array.
    val unsafeArraySizeInBytes = s"""
      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
      |  $numElemName,
      |  ${elementType.defaultSize});
      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      |    $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
      |    " bytes for UnsafeArrayData.");
      |}
      """.stripMargin
    val baseOffset = Platform.BYTE_ARRAY_OFFSET

    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

    s"""
       |$numElemCode
       |$unsafeArraySizeInBytes
       |byte[] $arrayName = new byte[(int)$arraySizeName];
       |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
       |Platform.putLong($arrayName, $baseOffset, $numElemName);
       |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
       |int $counter = 0;
       |for (int k = 0; k < $childVariableName.numElements(); k++) {
       |  ArrayData arr = $childVariableName.getArray(k);
       |  for (int l = 0; l < arr.numElements(); l++) {
       |    if (arr.isNullAt(l)) {
       |      $tempArrayDataName.setNullAt($counter);
       |    } else {
       |      $tempArrayDataName.set$primitiveValueTypeName(
       |        $counter,
       |        ${CodeGenerator.getValue("arr", elementType, "l")}
       |      );
       |    }
       |    $counter++;
       |  }
       |}
       |$arrayDataName = $tempArrayDataName;
     """.stripMargin
  }

  /**
   * Generates code that copies non-primitive elements into an Object[] and wraps it in a
   * GenericArrayData assigned to `arrayDataName`.
   */
  private def genCodeForFlattenOfNonPrimitiveElements(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val genericArrayClass = classOf[GenericArrayData].getName
    val arrayName = ctx.freshName("arrayObject")
    val counter = ctx.freshName("counter")
    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    s"""
       |$numElemCode
       |Object[] $arrayName = new Object[(int)$numElemName];
       |int $counter = 0;
       |for (int k = 0; k < $childVariableName.numElements(); k++) {
       |  ArrayData arr = $childVariableName.getArray(k);
       |  for (int l = 0; l < arr.numElements(); l++) {
       |    $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")};
       |    $counter++;
       |  }
       |}
       |$arrayDataName = new $genericArrayClass($arrayName);
     """.stripMargin
  }

  override def prettyName: String = "flatten"
}
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper

checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
}

test("Flatten") {
  // Evaluates Flatten over a literal of type `tpe` for every (input, expected) pair, in order.
  def verify(tpe: ArrayType)(cases: (Any, Any)*): Unit = {
    cases.foreach { case (input, expected) =>
      checkEvaluation(Flatten(Literal.create(input, tpe)), expected)
    }
  }

  // ---- Primitive element type ----
  val intArrayType = ArrayType(ArrayType(IntegerType))

  // Regular inputs
  verify(intArrayType)(
    (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), Seq(1, 2, 3, 4, 5, 6)),
    (Seq(Seq(1, 2, 3)), Seq(1, 2, 3)))

  // Inputs containing empty arrays, and an empty outer array
  verify(intArrayType)(
    (Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), Seq(1, 2, 3, 4)),
    (Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), Seq(1, 2, 3, 4)),
    (Seq(Seq(1, 2), Seq(3, 4), Seq.empty), Seq(1, 2, 3, 4)),
    (Seq(Seq.empty, Seq.empty, Seq.empty), Seq.empty),
    (Seq(Seq.empty), Seq.empty),
    (Seq.empty, Seq.empty))

  // Null elements inside the inner arrays are carried over
  verify(intArrayType)(
    (Seq(Seq(null, null, null), Seq(4, null)), Seq(null, null, null, 4, null)),
    (Seq(Seq(null, 2, null), Seq(null, null)), Seq(null, 2, null, null, null)),
    (Seq(Seq(null, null), Seq(null, null)), Seq(null, null, null, null)))

  // A null inner array, or a null input, gives a null result
  verify(intArrayType)(
    (Seq(null, Seq(1, 2)), null),
    (Seq(Seq(1, 2), null), null),
    (Seq(null), null),
    (null, null))

  // ---- Non-primitive element types ----
  val strArrayType = ArrayType(ArrayType(StringType))
  val arrArrayType = ArrayType(ArrayType(ArrayType(StringType)))

  // Regular inputs
  verify(strArrayType)(
    (Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), Seq("a", "b", "c", "d", "e", "f")),
    (Seq(Seq("a", "b")), Seq("a", "b")))
  // Three levels of nesting: only the outermost level is removed
  checkEvaluation(
    Flatten(
      Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType)),
    Seq(Seq("a", "b"), Seq("c"), Seq("d", "e")))

  // Inputs containing empty arrays, and an empty outer array
  verify(strArrayType)(
    (Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), Seq("a", "b", "c", "d")),
    (Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), Seq("a", "b", "c", "d")),
    (Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), Seq("a", "b", "c", "d")),
    (Seq(Seq.empty, Seq.empty, Seq.empty), Seq.empty),
    (Seq(Seq.empty), Seq.empty),
    (Seq.empty, Seq.empty))

  // Null elements inside the inner arrays are carried over
  verify(strArrayType)(
    (Seq(Seq(null, null, "c"), Seq(null, null)), Seq(null, null, "c", null, null)),
    (Seq(Seq(null, null, null), Seq("d", null)), Seq(null, null, null, "d", null)),
    (Seq(Seq(null, null), Seq(null, null)), Seq(null, null, null, null)))

  // A null inner array, or a null input, gives a null result
  verify(strArrayType)(
    (Seq(null, Seq("a", "b")), null),
    (Seq(Seq("a", "b"), null), null),
    (Seq(null), null),
    (null, null))
}
}
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3340,6 +3340,14 @@ object functions {
*/
def reverse(e: Column): Column = withExpr { Reverse(e.expr) }

/**
 * Collapses an array of arrays into one array. Exactly one level of nesting is removed, so
 * input nested deeper than two levels keeps its remaining inner arrays.
 * @group collection_funcs
 * @since 2.4.0
 */
def flatten(e: Column): Column = withExpr(Flatten(e.expr))

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
Expand Down
Loading