Skip to content

Commit 2d51406

Browse files
committed
[SPARK-8241][SQL] string function: concat_ws.
I also changed the semantics of concat w.r.t. null back to the same behavior as Hive. That is to say, concat now returns null if any input is null.
1 parent a803ac3 commit 2d51406

File tree

8 files changed

+220
-36
lines changed

8 files changed

+220
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ object FunctionRegistry {
153153
expression[Ascii]("ascii"),
154154
expression[Base64]("base64"),
155155
expression[Concat]("concat"),
156+
expression[ConcatWs]("concat_ws"),
156157
expression[Encode]("encode"),
157158
expression[Decode]("decode"),
158159
expression[FormatNumber]("format_number"),
@@ -211,7 +212,10 @@ object FunctionRegistry {
211212
val builder = (expressions: Seq[Expression]) => {
212213
if (varargCtor.isDefined) {
213214
// If there is an apply method that accepts Seq[Expression], use that one.
214-
varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
215+
Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
216+
case Success(e) => e
217+
case Failure(e) => throw new AnalysisException(e.getMessage)
218+
}
215219
} else {
216220
// Otherwise, find an ctor method that matches the number of arguments, and use that.
217221
val params = Seq.fill(expressions.size)(classOf[Expression])
@@ -221,7 +225,10 @@ object FunctionRegistry {
221225
case Failure(e) =>
222226
throw new AnalysisException(s"Invalid number of arguments for function $name")
223227
}
224-
f.newInstance(expressions : _*).asInstanceOf[Expression]
228+
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
229+
case Success(e) => e
230+
case Failure(e) => throw new AnalysisException(e.getMessage)
231+
}
225232
}
226233
}
227234
(name, builder)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String
3434

3535
/**
3636
* An expression that concatenates multiple input strings into a single string.
37-
* Input expressions that are evaluated to nulls are skipped.
38-
*
39-
* For example, `concat("a", null, "b")` is evaluated to `"ab"`.
40-
*
41-
* Note that this is different from Hive since Hive outputs null if any input is null.
42-
* We never output null.
37+
* If any input is null, concat returns null.
4338
*/
4439
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
4540

4641
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
4742
override def dataType: DataType = StringType
4843

49-
override def nullable: Boolean = false
44+
override def nullable: Boolean = children.exists(_.nullable)
5045
override def foldable: Boolean = children.forall(_.foldable)
5146

5247
override def eval(input: InternalRow): Any = {
@@ -56,15 +51,50 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
5651

5752
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
5853
val evals = children.map(_.gen(ctx))
59-
val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ")
54+
val inputs = evals.map { eval =>
55+
s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
56+
}.mkString(", ")
6057
evals.map(_.code).mkString("\n") + s"""
6158
boolean ${ev.isNull} = false;
6259
UTF8String ${ev.primitive} = UTF8String.concat($inputs);
60+
if (${ev.primitive} == null) {
61+
${ev.isNull} = true;
62+
}
6363
"""
6464
}
6565
}
6666

6767

68+
case class ConcatWs(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes with CodegenFallback {
69+
require(children.nonEmpty, s"$prettyName requires at least one argument.")
70+
71+
override def prettyName: String = "concat_ws"
72+
73+
override def inputTypes: Seq[AbstractDataType] = {
74+
Seq.fill(children.size)(TypeCollection(
75+
ArrayType(StringType, true),
76+
ArrayType(StringType, false),
77+
StringType))
78+
}
79+
80+
override def dataType: DataType = StringType
81+
82+
override def nullable: Boolean = children.head.nullable
83+
override def foldable: Boolean = children.forall(_.foldable)
84+
85+
override def eval(input: InternalRow): Any = {
86+
val flatInputs = children.flatMap { child =>
87+
child.eval(input) match {
88+
case s: UTF8String => Iterator(s)
89+
case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
90+
case null => Iterator(null.asInstanceOf[UTF8String])
91+
}
92+
}
93+
UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
94+
}
95+
}
96+
97+
6898
trait StringRegexExpression extends ImplicitCastInputTypes {
6999
self: BinaryExpression =>
70100

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
2626

2727
test("concat") {
2828
def testConcat(inputs: String*): Unit = {
29-
val expected = inputs.filter(_ != null).mkString
29+
val expected = if (inputs.contains(null)) null else inputs.mkString
3030
checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
3131
}
3232

@@ -46,6 +46,36 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
4646
// scalastyle:on
4747
}
4848

49+
test("concat_ws") {
50+
def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
51+
val inputExprs = inputs.map {
52+
case s: Seq[_] => Literal.create(s, ArrayType(StringType))
53+
case null => Literal.create(null, StringType)
54+
case s: String => Literal.create(s, StringType)
55+
}
56+
val sepExpr = Literal.create(sep, StringType)
57+
checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow)
58+
}
59+
60+
// scalastyle:off
61+
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
62+
testConcatWs(null, null)
63+
testConcatWs(null, null, "a", "b")
64+
testConcatWs("", "")
65+
testConcatWs("ab", "哈哈", "ab")
66+
testConcatWs("a哈哈b", "哈哈", "a", "b")
67+
testConcatWs("a哈哈b", "哈哈", "a", null, "b")
68+
testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")
69+
70+
testConcatWs("ab", "哈哈", Seq("ab"))
71+
testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
72+
testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null, "d"))
73+
testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
74+
testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
75+
// scalastyle:on
76+
77+
}
78+
4979
test("StringComparison") {
5080
val row = create_row("abc", null)
5181
val c1 = 'a.string.at(0)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,30 @@ object functions {
17321732
concat((columnName +: columnNames).map(Column.apply): _*)
17331733
}
17341734

1735+
/**
1736+
* Concatenates input strings together into a single string, using the given separator.
1737+
*
1738+
* @group string_funcs
1739+
* @since 1.5.0
1740+
*/
1741+
@scala.annotation.varargs
1742+
def concat_ws(sep: String, exprs: Column*): Column = {
1743+
ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr))
1744+
}
1745+
1746+
/**
1747+
* Concatenates input strings together into a single string, using the given separator.
1748+
*
1749+
* This is the variant of concat_ws that takes in the column names.
1750+
*
1751+
* @group string_funcs
1752+
* @since 1.5.0
1753+
*/
1754+
@scala.annotation.varargs
1755+
def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
1756+
concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*)
1757+
}
1758+
17351759
/**
17361760
* Computes the length of a given string / binary value.
17371761
*

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
3030
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
3131

3232
checkAnswer(
33-
df.select(concat($"a", $"b", $"c")),
34-
Row("ab"))
33+
df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
34+
Row("ab", null))
3535

3636
checkAnswer(
37-
df.selectExpr("concat(a, b, c)"),
38-
Row("ab"))
37+
df.selectExpr("concat(a, b)", "concat(a, b, c)"),
38+
Row("ab", null))
3939
}
4040

41+
test("string concat_ws") {
42+
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
43+
44+
checkAnswer(
45+
df.select(concat_ws("||", $"a", $"b", $"c")),
46+
Row("a||b"))
47+
48+
checkAnswer(
49+
df.selectExpr("concat_ws('||', a, b, c)"),
50+
Row("a||b"))
51+
}
4152

4253
test("string Levenshtein distance") {
4354
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
263263
"timestamp_2",
264264
"timestamp_udf",
265265

266-
// Hive outputs NULL if any concat input has null. We never output null for concat.
267-
"udf_concat",
268-
269266
// Unlike Hive, we do support log base in (0, 1.0], therefore disable this
270267
"udf7"
271268
)
@@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
856853
"udf_case",
857854
"udf_ceil",
858855
"udf_ceiling",
856+
"udf_concat",
859857
"udf_concat_insert1",
860858
"udf_concat_insert2",
861859
"udf_concat_ws",

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -397,33 +397,79 @@ public UTF8String lpad(int len, UTF8String pad) {
397397
}
398398

399399
/**
400-
* Concatenates input strings together into a single string. A null input is skipped.
401-
* For example, concat("a", null, "c") would yield "ac".
400+
* Concatenates input strings together into a single string. Returns null if any input is null.
402401
*/
403402
public static UTF8String concat(UTF8String... inputs) {
404-
if (inputs == null) {
405-
return fromBytes(new byte[0]);
406-
}
407-
408403
// Compute the total length of the result.
409404
int totalLength = 0;
410405
for (int i = 0; i < inputs.length; i++) {
411406
if (inputs[i] != null) {
412407
totalLength += inputs[i].numBytes;
408+
} else {
409+
return null;
413410
}
414411
}
415412

416413
// Allocate a new byte array, and copy the inputs one by one into it.
417414
final byte[] result = new byte[totalLength];
418415
int offset = 0;
416+
for (int i = 0; i < inputs.length; i++) {
417+
int len = inputs[i].numBytes;
418+
PlatformDependent.copyMemory(
419+
inputs[i].base, inputs[i].offset,
420+
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
421+
len);
422+
offset += len;
423+
}
424+
return fromBytes(result);
425+
}
426+
427+
/**
428+
* Concatenates input strings together into a single string using the separator.
429+
* A null input is skipped. For example, concat(",", "a", null, "c") would yield "a,c".
430+
*/
431+
public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
432+
if (separator == null) {
433+
return null;
434+
}
435+
436+
int numInputBytes = 0; // total number of bytes from the inputs
437+
int numInputs = 0; // number of non-null inputs
419438
for (int i = 0; i < inputs.length; i++) {
439+
if (inputs[i] != null) {
440+
numInputBytes += inputs[i].numBytes;
441+
numInputs++;
442+
}
443+
}
444+
445+
if (numInputs == 0) {
446+
// Return an empty string if there is no input, or all the inputs are null.
447+
return fromBytes(new byte[0]);
448+
}
449+
450+
// Allocate a new byte array, and copy the inputs one by one into it.
451+
// The size of the new array is the size of all inputs, plus the separators.
452+
final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
453+
int offset = 0;
454+
455+
for (int i = 0, j = 0; i < inputs.length; i++) {
420456
if (inputs[i] != null) {
421457
int len = inputs[i].numBytes;
422458
PlatformDependent.copyMemory(
423459
inputs[i].base, inputs[i].offset,
424460
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
425461
len);
426462
offset += len;
463+
464+
j++;
465+
// Add separator if this is not the last input.
466+
if (j < numInputs) {
467+
PlatformDependent.copyMemory(
468+
separator.base, separator.offset,
469+
result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
470+
separator.numBytes);
471+
offset += separator.numBytes;
472+
}
427473
}
428474
}
429475
return fromBytes(result);

unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,50 @@ public void upperAndLower() {
8888

8989
@Test
9090
public void concatTest() {
91-
assertEquals(concat(), fromString(""));
92-
assertEquals(concat(null), fromString(""));
93-
assertEquals(concat(fromString("")), fromString(""));
94-
assertEquals(concat(fromString("ab")), fromString("ab"));
95-
assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
96-
assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
97-
assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
98-
assertEquals(concat(fromString("a"), null, null), fromString("a"));
99-
assertEquals(concat(null, null, null), fromString(""));
100-
assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
91+
assertEquals(fromString(""), concat());
92+
assertEquals(null, concat((UTF8String) null));
93+
assertEquals(fromString(""), concat(fromString("")));
94+
assertEquals(fromString("ab"), concat(fromString("ab")));
95+
assertEquals(fromString("ab"), concat(fromString("a"), fromString("b")));
96+
assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c")));
97+
assertEquals(null, concat(fromString("a"), null, fromString("c")));
98+
assertEquals(null, concat(fromString("a"), null, null));
99+
assertEquals(null, concat(null, null, null));
100+
assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头")));
101+
}
102+
103+
@Test
104+
public void concatWsTest() {
105+
// Returns null if the separator is null
106+
assertEquals(null, concatWs(null, (UTF8String)null));
107+
assertEquals(null, concatWs(null, fromString("a")));
108+
109+
// If separator is null, concatWs should skip all null inputs and never return null.
110+
UTF8String sep = fromString("哈哈");
111+
assertEquals(
112+
fromString(""),
113+
concatWs(sep, fromString("")));
114+
assertEquals(
115+
fromString("ab"),
116+
concatWs(sep, fromString("ab")));
117+
assertEquals(
118+
fromString("a哈哈b"),
119+
concatWs(sep, fromString("a"), fromString("b")));
120+
assertEquals(
121+
fromString("a哈哈b哈哈c"),
122+
concatWs(sep, fromString("a"), fromString("b"), fromString("c")));
123+
assertEquals(
124+
fromString("a哈哈c"),
125+
concatWs(sep, fromString("a"), null, fromString("c")));
126+
assertEquals(
127+
fromString("a"),
128+
concatWs(sep, fromString("a"), null, null));
129+
assertEquals(
130+
fromString(""),
131+
concatWs(sep, null, null, null));
132+
assertEquals(
133+
fromString("数据哈哈砖头"),
134+
concatWs(sep, fromString("数据"), fromString("砖头")));
101135
}
102136

103137
@Test
@@ -215,14 +249,18 @@ public void pad() {
215249
assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
216250
assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者")));
217251
assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7, fromString("孙行者")));
218-
assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12, fromString("孙行者")));
252+
assertEquals(
253+
fromString("孙行者孙行者孙行数据砖头"),
254+
fromString("数据砖头").lpad(12, fromString("孙行者")));
219255

220256
assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????")));
221257
assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????")));
222258
assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????")));
223259
assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者")));
224260
assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7, fromString("孙行者")));
225-
assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者")));
261+
assertEquals(
262+
fromString("数据砖头孙行者孙行者孙行"),
263+
fromString("数据砖头").rpad(12, fromString("孙行者")));
226264
}
227265

228266
@Test

0 commit comments

Comments
 (0)