Closed

Changes from all commits — 25 commits
282e724  [SPARK-23736][SQL] Implementation of the concat_arrays function conca… (Mar 13, 2018)
aa5a089  [SPARK-23736][SQL] Code style fixes. (Mar 26, 2018)
90d3ab7  [SPARK-23736][SQL] Improving the description of the ConcatArrays expr… (Mar 26, 2018)
bb46c3d  [SPARK-23736][SQL] Merging concat and concat_arrays into one function. (Mar 26, 2018)
11205af  [SPARK-23736][SQL] Adding new line at the end of the unresolved.scala… (Mar 26, 2018)
753499d  [SPARK-23736][SQL] Fixing failing unit test from DDLSuite. (Mar 26, 2018)
2efdd77  [SPARK-23736][SQL] Changing method styling according to the standards. (Mar 27, 2018)
fd84bee  [SPARK-23736][SQL] Changing data type to ArrayType(StringType) for th… (Mar 27, 2018)
116f91f  [SPARK-23736][SQL] Fixing a SparkR unit test by filtering out Unresol… (Mar 27, 2018)
e199ac5  [SPARK-23736][SQL] Merging the current master into the feature branch. (Mar 28, 2018)
067c2db  [SPARK-23736][SQL] Merging the current master to the feature branch. (Mar 29, 2018)
090929f  [SPARK-23736][SQL] Merging string concat and array concat into one ex… (Apr 6, 2018)
8abd1a8  [SPARK-23736][SQL] Adding more test cases (Apr 7, 2018)
367ee22  [SPARK-23736][SQL] Optimizing null elements protection. (Apr 7, 2018)
6bb33e6  [SPARK-23736][SQL] Protection against the length limit of Java functions (Apr 12, 2018)
57b250c  Merge remote-tracking branch 'spark/master' into feature/array-api-co… (Apr 12, 2018)
944e0c9  [SPARK-23736][SQL] Adding test for the limit of Java function size. (Apr 12, 2018)
7f5124b  [SPARK-23736][SQL] Adding more tests (Apr 13, 2018)
0201e4b  [SPARK-23736][SQL] Checks of max array size + Rewriting codegen using… (Apr 16, 2018)
600ae89  [SPARK-23736][SQL] Merging current master into the feature branch. (Apr 16, 2018)
f2a67e8  [SPARK-23736][SQL] Fixing exception messages (Apr 17, 2018)
8a125d9  [SPARK-23736][SQL] Small refactoring (Apr 18, 2018)
5a4cc8c  [SPARK-23736][SQL] Merging current master to the feature branch (Apr 18, 2018)
f7bdcf7  [SPARK-23736][SQL] Merging current master to the feature branch. (Apr 19, 2018)
36d5d25  [SPARK-23736][SQL] Merging current master to the feature branch. (Apr 19, 2018)
common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) {
   }
 
   public static int roundNumberOfBytesToNearestWord(int numBytes) {
-    int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
+    return (int)roundNumberOfBytesToNearestWord((long)numBytes);
+  }
+
+  public static long roundNumberOfBytesToNearestWord(long numBytes) {
+    long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
     if (remainder == 0) {
       return numBytes;
     } else {
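For intuition, the bitmask trick above rounds a byte count up to the next multiple of 8, the word size used by Spark's unsafe memory layout. A minimal Scala sketch of the same arithmetic (names are illustrative, not part of the patch):

    // Round numBytes up to the nearest multiple of 8, mirroring the Java helper above.
    def roundToNearestWord(numBytes: Long): Long = {
      val remainder = numBytes & 0x07L // same as numBytes % 8 for non-negative input
      if (remainder == 0) numBytes else numBytes + (8 - remainder)
    }

    // roundToNearestWord(20) == 24; roundToNearestWord(24) == 24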
34 changes: 19 additions & 15 deletions python/pyspark/sql/functions.py
@@ -1425,21 +1425,6 @@ def hash(*cols):
 del _name, _doc
 
 
-@since(1.5)
-@ignore_unicode_prefix
-def concat(*cols):
-    """
-    Concatenates multiple input columns together into a single column.
-    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-
-    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
-    >>> df.select(concat(df.s, df.d).alias('s')).collect()
-    [Row(s=u'abcd123')]
-    """
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
-
-
 @since(1.5)
 @ignore_unicode_prefix
 def concat_ws(sep, *cols):
@@ -1845,6 +1830,25 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+    """
+    Concatenates multiple input columns together into a single column.
+    The function works with strings, binary and compatible array columns.
+
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df.select(concat(df.s, df.d).alias('s')).collect()
+    [Row(s=u'abcd123')]
+
+    >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+    >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
+    [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
Member:
Why did we move this down?

Contributor Author:
The whole file is divided into sections according to groups of functions. Based on @gatorsmile's suggestion, the concat function should be categorized as a collection function, so I moved it to comply with the file structure.

 @since(2.4)
 def array_position(col, value):
     """
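The Scala side of the DataFrame API gains the same behaviour through the existing functions.concat. A sketch only, assuming a SparkSession named `spark` with this patch applied; the data and expected output are illustrative:

    import org.apache.spark.sql.functions.concat
    import spark.implicits._

    // concat now accepts array columns as well as strings and binary.
    val df = Seq((Seq(1, 2), Seq(3, 4))).toDF("a", "b")
    df.select(concat($"a", $"b").as("arr")).collect()
    // expected: Array(Row(Seq(1, 2, 3, 4)))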
sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -56,9 +56,19 @@
 public final class UnsafeArrayData extends ArrayData {
 
   public static int calculateHeaderPortionInBytes(int numFields) {
+    return (int)calculateHeaderPortionInBytes((long)numFields);
+  }
+
+  public static long calculateHeaderPortionInBytes(long numFields) {
     return 8 + ((numFields + 63) / 64) * 8;
   }
 
+  public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) {
+    long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize);
+    return size;
+  }
+
   private Object baseObject;
   private long baseOffset;
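As a worked example of the sizing arithmetic (a sketch, not part of the patch): for 5 int elements the header is 8 + ceil(5/64) * 8 = 16 bytes, and the element data rounds 5 * 4 = 20 bytes up to the 24-byte word boundary, so the backing byte array is 40 bytes. In Scala, with the word-rounding written in its equivalent mask form:

    // Mirrors the Java helpers above: header plus word-rounded element data.
    def headerInBytes(numFields: Long): Long = 8 + ((numFields + 63) / 64) * 8
    def roundToWord(numBytes: Long): Long = (numBytes + 7) & ~7L
    def sizeOfUnderlyingByteArray(numFields: Long, elementSize: Int): Long =
      headerInBytes(numFields) + roundToWord(numFields * elementSize)

    assert(sizeOfUnderlyingByteArray(5, 4) == 40) // 16-byte header + 24 bytes of data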
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -308,7 +308,6 @@ object FunctionRegistry {
     expression[BitLength]("bit_length"),
     expression[Length]("char_length"),
     expression[Length]("character_length"),
-    expression[Concat]("concat"),
     expression[ConcatWs]("concat_ws"),
     expression[Decode]("decode"),
     expression[Elt]("elt"),
@@ -413,6 +412,7 @@ object FunctionRegistry {
     expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
     expression[Reverse]("reverse"),
+    expression[Concat]("concat"),
     CreateStruct.registryEntry,
 
     // misc functions
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -504,6 +504,14 @@ object TypeCoercion {
         case None => a
       }
 
+      case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
+          !haveSameType(children) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
+          case None => c
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {
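A sketch of what this rule buys (illustrative, assuming a SparkSession named `spark` with the patch applied): arrays with different but compatible element types are widened to a common element type before Concat is resolved.

    // array<int> and array<double> are coerced to array<double> here.
    val widened = spark.sql("SELECT concat(array(1, 2), array(3.5D)) AS arr")
    widened.collect()
    // expected: Array(Row(Seq(1.0, 2.0, 3.5)))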
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 /**
  * Given an array or map, returns its size. Returns -1 if null.
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil

  override def prettyName: String = "element_at"
}

/**
 * Concatenates multiple input columns together into a single column.
 * The function works with strings, binary and compatible array columns.
 */
@ExpressionDescription(
  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
  examples = """
    Examples:
      > SELECT _FUNC_('Spark', 'SQL');
       SparkSQL
      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
       [1,2,3,4,5,6]
  """)
case class Concat(children: Seq[Expression]) extends Expression {

  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

  val allowedTypes = Seq(StringType, BinaryType, ArrayType)

  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.isEmpty) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      val childTypes = children.map(_.dataType)
      if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
        return TypeCheckResult.TypeCheckFailure(
          s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
            s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
      }
      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
    }
  }

  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

  lazy val javaType: String = CodeGenerator.javaType(dataType)
Member:
We can move this into the doGenCode() method.

Contributor Author:
Good point! But I think it would be better to reuse javaType also in genCodeForPrimitiveArrays and genCodeForNonPrimitiveArrays.
  override def nullable: Boolean = children.exists(_.nullable)

  override def foldable: Boolean = children.forall(_.foldable)

  override def eval(input: InternalRow): Any = dataType match {
Contributor:
This pattern match will probably cause a significant regression in the interpreted (non-codegen) mode, due to the way Scala pattern matching is implemented.

Contributor Author:
Thanks! I've created #22471 to call the pattern matching only once.

WDYT about Reverse? It looks like a similar problem.
    case BinaryType =>
      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
      ByteArray.concat(inputs: _*)
    case StringType =>
      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
      UTF8String.concat(inputs: _*)
    case ArrayType(elementType, _) =>
      val inputs = children.toStream.map(_.eval(input))
      if (inputs.contains(null)) {
        null
      } else {
        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
        if (numberOfElements > MAX_ARRAY_LENGTH) {
          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
            s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
        }
        val finalData = new Array[AnyRef](numberOfElements.toInt)
        var position = 0
        for (ad <- arrayData) {
          val arr = ad.toObjectArray(elementType)
          Array.copy(arr, 0, finalData, position, arr.length)
          position += arr.length
        }
        new GenericArrayData(finalData)
      }
  }
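  // Illustration (not part of the patch): children.toStream keeps evaluation
  // lazy, so the null check above short-circuits at the first null input.
  // concat(array(1, 2), null, array(3)) therefore evaluates to null, while
  // concat(array(1, 2), array(3)) yields GenericArrayData([1, 2, 3]).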

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val evals = children.map(_.genCode(ctx))
    val args = ctx.freshName("args")

    val inputs = evals.zipWithIndex.map { case (eval, index) =>
      s"""
        ${eval.code}
        if (!${eval.isNull}) {
          $args[$index] = ${eval.value};
        }
      """
    }

    val (concatenator, initCode) = dataType match {
      case BinaryType =>
        (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
      case StringType =>
        ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
      case ArrayType(elementType, _) =>
        val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
          genCodeForPrimitiveArrays(ctx, elementType)
        } else {
          genCodeForNonPrimitiveArrays(ctx, elementType)
        }
        (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
    }
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = inputs,
      funcName = "valueConcat",
      extraArguments = (s"$javaType[]", args) :: Nil)
    ev.copy(s"""
      $initCode
      $codes
      $javaType ${ev.value} = $concatenator.concat($args);
      boolean ${ev.isNull} = ${ev.value} == null;
    """)
  }

  private def genCodeForNumberOfElements(ctx: CodegenContext): (String, String) = {
    val numElements = ctx.freshName("numElements")
    val code = s"""
      |long $numElements = 0L;
      |for (int z = 0; z < ${children.length}; z++) {
      |  $numElements += args[z].numElements();
      |}
      |if ($numElements > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
      |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
      |}
      """.stripMargin

    (code, numElements)
  }

  private def nullArgumentProtection(): String = {
    if (nullable) {
      s"""
        |for (int z = 0; z < ${children.length}; z++) {
        |  if (args[z] == null) return null;
        |}
      """.stripMargin
    } else {
      ""
    }
  }

  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
    val arrayName = ctx.freshName("array")
    val arraySizeName = ctx.freshName("size")
    val counter = ctx.freshName("counter")
    val arrayData = ctx.freshName("arrayData")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

    val unsafeArraySizeInBytes = s"""
      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
      |  $numElemName,
      |  ${elementType.defaultSize});
      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
      |    " for UnsafeArrayData.");
      |}
      """.stripMargin
    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

    s"""
      |new Object() {
      |  public ArrayData concat($javaType[] args) {
      |    ${nullArgumentProtection()}
      |    $numElemCode
      |    $unsafeArraySizeInBytes
      |    byte[] $arrayName = new byte[(int)$arraySizeName];
      |    UnsafeArrayData $arrayData = new UnsafeArrayData();
      |    Platform.putLong($arrayName, $baseOffset, $numElemName);
      |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
      |    int $counter = 0;
      |    for (int y = 0; y < ${children.length}; y++) {
      |      for (int z = 0; z < args[y].numElements(); z++) {
      |        if (args[y].isNullAt(z)) {
      |          $arrayData.setNullAt($counter);
      |        } else {
      |          $arrayData.set$primitiveValueTypeName(
      |            $counter,
      |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
      |          );
      |        }
      |        $counter++;
      |      }
      |    }
      |    return $arrayData;
      |  }
      |}""".stripMargin.stripPrefix("\n")
  }
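  // Illustration (not part of the patch): for Array[Int] inputs the generated
  // code is the moral equivalent of this hand-written loop,
  //   val out = new Array[Int](numElements); var pos = 0
  //   for (arr <- args) { Array.copy(arr, 0, out, pos, arr.length); pos += arr.length }
  // except that it writes element by element into an UnsafeArrayData backed by
  // a single word-aligned byte array, preserving per-element null bits.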

  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
    val genericArrayClass = classOf[GenericArrayData].getName
    val arrayData = ctx.freshName("arrayObjects")
    val counter = ctx.freshName("counter")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

    s"""
      |new Object() {
      |  public ArrayData concat($javaType[] args) {
      |    ${nullArgumentProtection()}
      |    $numElemCode
      |    Object[] $arrayData = new Object[(int)$numElemName];
      |    int $counter = 0;
      |    for (int y = 0; y < ${children.length}; y++) {
      |      for (int z = 0; z < args[y].numElements(); z++) {
      |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
      |        $counter++;
      |      }
      |    }
      |    return new $genericArrayClass($arrayData);
      |  }
      |}""".stripMargin.stripPrefix("\n")
  }

  override def toString: String = s"concat(${children.mkString(", ")})"

  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}
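End to end, string, binary and array inputs now all resolve to this single Concat expression. A hedged sketch, assuming a SparkSession named `spark` with the patch applied; the expected results are illustrative:

    // Same function name, three input classes.
    spark.sql("SELECT concat('Spark', 'SQL')").collect()
    // expected: Array(Row("SparkSQL"))
    spark.sql("SELECT concat(array(1, 2), array(3))").collect()
    // expected: Array(Row(Seq(1, 2, 3)))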