Commit 163e3f1
[SPARK-8241][SQL] string function: concat_ws.
I also changed the semantics of concat w.r.t. null back to the same behavior as Hive. That is to say, concat now returns null if any input is null.

Author: Reynold Xin <[email protected]>

Closes apache#7504 from rxin/concat_ws and squashes the following commits:

83fd950 [Reynold Xin] Fixed type casting.
3ae85f7 [Reynold Xin] Write null better.
cdc7be6 [Reynold Xin] Added code generation for pure string mode.
a61c4e4 [Reynold Xin] Updated comments.
2d51406 [Reynold Xin] [SPARK-8241][SQL] string function: concat_ws.
1 parent 7a81245 commit 163e3f1
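
Before the file-by-file diff, a quick sketch of the semantics this commit settles on, written against the DataFrame API added below. This is illustrative only: it assumes a SQLContext and its implicits are in scope, and the DataFrame mirrors the one used in StringFunctionsSuite.

import org.apache.spark.sql.functions.{concat, concat_ws}

// df has string columns a = "a", b = "b", c = null
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

df.select(concat($"a", $"b", $"c"))          // null: concat now returns null on any null input (Hive behavior)
df.select(concat_ws("||", $"a", $"b", $"c")) // "a||b": concat_ws skips null inputs
df.select(concat_ws(null, $"a", $"b"))       // null: a null separator yields null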

File tree: 10 files changed, +256 −38 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 9 additions & 2 deletions

@@ -153,6 +153,7 @@ object FunctionRegistry {
     expression[Ascii]("ascii"),
     expression[Base64]("base64"),
     expression[Concat]("concat"),
+    expression[ConcatWs]("concat_ws"),
     expression[Encode]("encode"),
     expression[Decode]("decode"),
     expression[FormatNumber]("format_number"),
@@ -211,7 +212,10 @@
     val builder = (expressions: Seq[Expression]) => {
       if (varargCtor.isDefined) {
         // If there is an apply method that accepts Seq[Expression], use that one.
-        varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
+        Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) => throw new AnalysisException(e.getMessage)
+        }
       } else {
         // Otherwise, find an ctor method that matches the number of arguments, and use that.
         val params = Seq.fill(expressions.size)(classOf[Expression])
@@ -221,7 +225,10 @@
           case Failure(e) =>
             throw new AnalysisException(s"Invalid number of arguments for function $name")
         }
-        f.newInstance(expressions : _*).asInstanceOf[Expression]
+        Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) => throw new AnalysisException(e.getMessage)
+        }
       }
     }
     (name, builder)
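
Why both constructor paths are now wrapped in Try: expressions like the new ConcatWs validate their arguments in the constructor (require(children.nonEmpty, ...)), and those failures should reach the user as an analysis error rather than a raw reflection exception. A minimal sketch of the pattern follows; the helper name is made up, and RuntimeException stands in for Spark's AnalysisException, whose constructor is not public.

import scala.util.{Failure, Success, Try}

// Hypothetical helper showing the pattern above: whatever the constructor
// throws (e.g. a failed require) is rethrown with its message preserved.
def instantiate[T](build: => T): T = Try(build) match {
  case Success(e)  => e
  case Failure(ex) => throw new RuntimeException(ex.getMessage)
}

// instantiate(ConcatWs(Nil)) now fails with "requirement failed:
// concat_ws requires at least one argument." instead of a reflection error.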

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 64 additions & 8 deletions

@@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * An expression that concatenates multiple input strings into a single string.
- * Input expressions that are evaluated to nulls are skipped.
- *
- * For example, `concat("a", null, "b")` is evaluated to `"ab"`.
- *
- * Note that this is different from Hive since Hive outputs null if any input is null.
- * We never output null.
+ * If any input is null, concat returns null.
  */
 case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
   override def dataType: DataType = StringType
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = {
@@ -56,15 +51,76 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val evals = children.map(_.gen(ctx))
-    val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
+    val inputs = evals.map { eval =>
+      s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+    }.mkString(", ")
     evals.map(_.code).mkString("\n") + s"""
       boolean ${ev.isNull} = false;
       UTF8String ${ev.primitive} = UTF8String.concat($inputs);
+      if (${ev.primitive} == null) {
+        ${ev.isNull} = true;
+      }
     """
   }
 }
 
 
+/**
+ * An expression that concatenates multiple input strings or array of strings into a single string,
+ * using a given separator (the first child).
+ *
+ * Returns null if the separator is null. Otherwise, concat_ws skips all null values.
+ */
+case class ConcatWs(children: Seq[Expression])
+  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+  require(children.nonEmpty, s"$prettyName requires at least one argument.")
+
+  override def prettyName: String = "concat_ws"
+
+  /** The 1st child (separator) is str, and rest are either str or array of str. */
+  override def inputTypes: Seq[AbstractDataType] = {
+    val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
+    StringType +: Seq.fill(children.size - 1)(arrayOrStr)
+  }
+
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = children.head.nullable
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = {
+    val flatInputs = children.flatMap { child =>
+      child.eval(input) match {
+        case s: UTF8String => Iterator(s)
+        case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+        case null => Iterator(null.asInstanceOf[UTF8String])
+      }
+    }
+    UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    if (children.forall(_.dataType == StringType)) {
+      // All children are strings. In that case we can construct a fixed size array.
+      val evals = children.map(_.gen(ctx))
+
+      val inputs = evals.map { eval =>
+        s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+      }.mkString(", ")
+
+      evals.map(_.code).mkString("\n") + s"""
+        UTF8String ${ev.primitive} = UTF8String.concatWs($inputs);
+        boolean ${ev.isNull} = ${ev.primitive} == null;
+      """
+    } else {
+      // Contains a mix of strings and array<string>s. Fall back to interpreted mode for now.
+      super.genCode(ctx, ev)
+    }
+  }
+}
+
+
 trait StringRegexExpression extends ImplicitCastInputTypes {
   self: BinaryExpression =>
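
To make the eval path above concrete, here is a self-contained sketch of the same flatten-then-join logic, with plain Strings standing in for UTF8String. All names here are illustrative, not part of the patch.

// Mirrors ConcatWs.eval: flatten strings and arrays of strings into one
// sequence, then join, skipping nulls (as Hive does).
def concatWsSketch(sep: String, inputs: Any*): String = {
  if (sep == null) return null               // null separator => null result
  val flat = inputs.flatMap {
    case s: String   => Seq(s)
    case arr: Seq[_] => arr.map(_.asInstanceOf[String])
    case null        => Seq(null.asInstanceOf[String])
  }
  flat.filter(_ != null).mkString(sep)       // null values are skipped
}

// concatWsSketch("||", "a", Seq("b", null), null, "c") == "a||b||c"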

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = this
 
-  override private[sql] def acceptsType(other: DataType): Boolean = this == other
+  override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
 }
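
This one-line change matters for the new ArrayType(StringType) input type: plain equality on ArrayType is sensitive to containsNull, while sameType compares with nullability stripped, so both nullable and non-nullable string arrays are accepted without inserting a cast (see the coercion tests below). A sketch:

import org.apache.spark.sql.types.{ArrayType, StringType}

// Equality distinguishes containsNull, so the old acceptsType rejected a
// non-nullable string array where a nullable one was expected:
ArrayType(StringType, containsNull = false) == ArrayType(StringType, containsNull = true)  // false

// sameType ignores nullability; note it is private[sql], so the call below
// only compiles inside the org.apache.spark.sql package:
// ArrayType(StringType, false).sameType(ArrayType(StringType, true))  // true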

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 10 additions & 1 deletion

@@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest {
     shouldCast(NullType, IntegerType, IntegerType)
     shouldCast(NullType, DecimalType, DecimalType.Unlimited)
 
-    // TODO: write the entire implicit cast table out for test cases.
     shouldCast(ByteType, IntegerType, IntegerType)
     shouldCast(IntegerType, IntegerType, IntegerType)
     shouldCast(IntegerType, LongType, LongType)
@@ -86,6 +85,16 @@
       DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
       shouldCast(tpe, NumericType, tpe)
     }
+
+    shouldCast(
+      ArrayType(StringType, false),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, false))
+
+    shouldCast(
+      ArrayType(StringType, true),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, true))
   }
 
   test("ineligible implicit type cast") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 30 additions & 1 deletion

@@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("concat") {
     def testConcat(inputs: String*): Unit = {
-      val expected = inputs.filter(_ != null).mkString
+      val expected = if (inputs.contains(null)) null else inputs.mkString
      checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
    }
 
@@ -46,6 +46,35 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // scalastyle:on
   }
 
+  test("concat_ws") {
+    def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
+      val inputExprs = inputs.map {
+        case s: Seq[_] => Literal.create(s, ArrayType(StringType))
+        case null => Literal.create(null, StringType)
+        case s: String => Literal.create(s, StringType)
+      }
+      val sepExpr = Literal.create(sep, StringType)
+      checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow)
+    }
+
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+    testConcatWs(null, null)
+    testConcatWs(null, null, "a", "b")
+    testConcatWs("", "")
+    testConcatWs("ab", "哈哈", "ab")
+    testConcatWs("a哈哈b", "哈哈", "a", "b")
+    testConcatWs("a哈哈b", "哈哈", "a", null, "b")
+    testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")
+
+    testConcatWs("ab", "哈哈", Seq("ab"))
+    testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
+    testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d"))
+    testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
+    testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
+    // scalastyle:on
+  }
+
   test("StringComparison") {
     val row = create_row("abc", null)
     val c1 = 'a.string.at(0)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 24 additions & 0 deletions

@@ -1732,6 +1732,30 @@
     concat((columnName +: columnNames).map(Column.apply): _*)
   }
 
+  /**
+   * Concatenates input strings together into a single string, using the given separator.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: String, exprs: Column*): Column = {
+    ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr))
+  }
+
+  /**
+   * Concatenates input strings together into a single string, using the given separator.
+   *
+   * This is the variant of concat_ws that takes in the column names.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
+    concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*)
+  }
+
   /**
    * Computes the length of a given string / binary value.
    *
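
Both overloads resolve to the same ConcatWs expression; the name-based variant simply wraps each name in a Column. A usage sketch against a hypothetical DataFrame df with string columns a and b (assumes the SQLContext implicits are imported for the $ syntax):

df.select(concat_ws("||", $"a", $"b"))  // Column-based variant
df.select(concat_ws("||", "a", "b"))    // column-name variant, same result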

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 15 additions & 4 deletions

@@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
-      df.select(concat($"a", $"b", $"c")),
-      Row("ab"))
+      df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
+      Row("ab", null))
 
     checkAnswer(
-      df.selectExpr("concat(a, b, c)"),
-      Row("ab"))
+      df.selectExpr("concat(a, b)", "concat(a, b, c)"),
+      Row("ab", null))
   }
 
+  test("string concat_ws") {
+    val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(concat_ws("||", $"a", $"b", $"c")),
+      Row("a||b"))
+
+    checkAnswer(
+      df.selectExpr("concat_ws('||', a, b, c)"),
+      Row("a||b"))
+  }
 
   test("string Levenshtein distance") {
     val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala

Lines changed: 1 addition & 3 deletions

@@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "timestamp_2",
     "timestamp_udf",
 
-    // Hive outputs NULL if any concat input has null. We never output null for concat.
-    "udf_concat",
-
     // Unlike Hive, we do support log base in (0, 1.0], therefore disable this
     "udf7"
   )
@@ -856,6 +853,7 @@
     "udf_case",
     "udf_ceil",
     "udf_ceiling",
+    "udf_concat",
    "udf_concat_insert1",
    "udf_concat_insert2",
    "udf_concat_ws",

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 52 additions & 6 deletions

@@ -397,33 +397,79 @@ public UTF8String lpad(int len, UTF8String pad) {
   }
 
   /**
-   * Concatenates input strings together into a single string. A null input is skipped.
-   * For example, concat("a", null, "c") would yield "ac".
+   * Concatenates input strings together into a single string. Returns null if any input is null.
    */
   public static UTF8String concat(UTF8String... inputs) {
-    if (inputs == null) {
-      return fromBytes(new byte[0]);
-    }
-
     // Compute the total length of the result.
     int totalLength = 0;
     for (int i = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         totalLength += inputs[i].numBytes;
+      } else {
+        return null;
       }
     }
 
     // Allocate a new byte array, and copy the inputs one by one into it.
     final byte[] result = new byte[totalLength];
     int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].numBytes;
+      PlatformDependent.copyMemory(
+        inputs[i].base, inputs[i].offset,
+        result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return fromBytes(result);
+  }
+
+  /**
+   * Concatenates input strings together into a single string using the separator.
+   * A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c".
+   */
+  public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
+    if (separator == null) {
+      return null;
+    }
+
+    int numInputBytes = 0;  // total number of bytes from the inputs
+    int numInputs = 0;      // number of non-null inputs
     for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        numInputBytes += inputs[i].numBytes;
+        numInputs++;
+      }
+    }
+
+    if (numInputs == 0) {
+      // Return an empty string if there is no input, or all the inputs are null.
+      return fromBytes(new byte[0]);
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it.
+    // The size of the new array is the size of all inputs, plus the separators.
+    final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
+    int offset = 0;
+
+    for (int i = 0, j = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         int len = inputs[i].numBytes;
         PlatformDependent.copyMemory(
           inputs[i].base, inputs[i].offset,
           result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
           len);
         offset += len;
+
+        j++;
+        // Add separator if this is not the last input.
+        if (j < numInputs) {
+          PlatformDependent.copyMemory(
+            separator.base, separator.offset,
+            result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+            separator.numBytes);
+          offset += separator.numBytes;
+        }
       }
     }
     return fromBytes(result);