Skip to content

Commit 57a2352

Browse files
committed
[SPARK-8240][SQL] string function: concat
1 parent 1b4ff05 commit 57a2352

File tree

7 files changed

+128
-246
lines changed

7 files changed

+128
-246
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ object FunctionRegistry {
152152
// string functions
153153
expression[Ascii]("ascii"),
154154
expression[Base64]("base64"),
155+
expression[Concat]("concat"),
155156
expression[Encode]("encode"),
156157
expression[Decode]("decode"),
157158
expression[FormatNumber]("format_number"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,40 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.types._
2828
import org.apache.spark.unsafe.types.UTF8String
2929

30+
////////////////////////////////////////////////////////////////////////////////////////////////////
31+
// This file defines expressions for string operations.
32+
////////////////////////////////////////////////////////////////////////////////////////////////////
33+
34+
35+
/**
36+
* An expression that concatenates multiple input strings into a single string.
37+
* Input expressions that are evaluated to nulls are skipped.
38+
*
39+
* For example, `concat("a", null, "b")` is evaluated to `"ab"`.
40+
*/
41+
/**
 * An expression that concatenates multiple input strings into a single string.
 * Input expressions that evaluate to null are skipped rather than nulling the
 * whole result: `concat("a", null, "b")` evaluates to `"ab"`.
 */
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {

  // Every argument is implicitly coerced to a string.
  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
  override def dataType: DataType = StringType

  // NOTE(review): since null children are skipped, eval never actually returns
  // null; nullable is kept as the conservative original declaration.
  override def nullable: Boolean = children.exists(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)

  override def eval(input: InternalRow): Any = {
    // UTF8String.concat ignores null elements, implementing the skip semantics.
    val strings = for (child <- children) yield child.eval(input).asInstanceOf[UTF8String]
    UTF8String.concat(strings: _*)
  }

  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val childCodes = children.map(_.gen(ctx))
    // Each argument becomes `isNull ? null : value`; nulls are then dropped by
    // UTF8String.concat at runtime.
    val args = childCodes
      .map(c => s"${c.isNull} ? null : ${c.primitive}")
      .mkString(", ")
    childCodes.map(_.code).mkString("\n") + s"""
      boolean ${ev.isNull} = false;
      UTF8String ${ev.primitive} = UTF8String.concat($args);
    """
  }
}
63+
3064

3165
trait StringRegexExpression extends ImplicitCastInputTypes {
3266
self: BinaryExpression =>
Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,29 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.types._
2323

2424

25-
class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
25+
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
26+
27+
test("concat") {
  // Expected semantics: null inputs are dropped, the rest are joined in order.
  def check(parts: String*): Unit = {
    val expected = parts.collect { case s if s != null => s }.mkString
    checkEvaluation(Concat(parts.map(Literal.create(_, StringType))), expected, EmptyRow)
  }

  check()
  check(null)
  check("")
  check("ab")
  check("a", "b")
  check("a", "b", "C")
  check("a", null, "C")
  check("a", null, null)
  check(null, null, null)

  // scalastyle:off
  // Non-ASCII characters are not allowed in the code, so scalastyle is disabled here.
  check("数据", null, "砖头")
  // scalastyle:on
}
2648

2749
test("StringComparison") {
2850
val row = create_row("abc", null)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,28 @@ object functions {
17101710
// String functions
17111711
//////////////////////////////////////////////////////////////////////////////////////////////
17121712

1713+
/**
 * Concatenates input strings together into a single string.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat(exprs: Column*): Column = {
  // Delegate directly to the Concat expression over the underlying expressions.
  Concat(exprs.map(e => e.expr))
}
1721+
1722+
/**
 * Concatenates input strings together into a single string.
 *
 * This variant of concat accepts column names instead of [[Column]]s.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat(columnName: String, columnNames: String*): Column = {
  // Resolve each name to a Column, then reuse the Column-based overload.
  val allNames = columnName +: columnNames
  concat(allNames.map(name => Column(name)): _*)
}
1734+
17131735
/**
17141736
* Computes the length of a given string / binary value.
17151737
*

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 0 additions & 242 deletions
Original file line numberDiff line numberDiff line change
@@ -208,169 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
208208
Row(2743272264L, 2180413220L))
209209
}
210210

211-
test("Levenshtein distance") {
212-
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
213-
checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
214-
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
215-
}
216-
217-
test("string ascii function") {
218-
val df = Seq(("abc", "")).toDF("a", "b")
219-
checkAnswer(
220-
df.select(ascii($"a"), ascii("b")),
221-
Row(97, 0))
222-
223-
checkAnswer(
224-
df.selectExpr("ascii(a)", "ascii(b)"),
225-
Row(97, 0))
226-
}
227-
228-
test("string base64/unbase64 function") {
229-
val bytes = Array[Byte](1, 2, 3, 4)
230-
val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
231-
checkAnswer(
232-
df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
233-
Row("AQIDBA==", "AQIDBA==", bytes, bytes))
234-
235-
checkAnswer(
236-
df.selectExpr("base64(a)", "unbase64(b)"),
237-
Row("AQIDBA==", bytes))
238-
}
239-
240-
test("string encode/decode function") {
241-
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
242-
// scalastyle:off
243-
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
244-
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
245-
checkAnswer(
246-
df.select(
247-
encode($"a", "utf-8"),
248-
encode("a", "utf-8"),
249-
decode($"c", "utf-8"),
250-
decode("c", "utf-8")),
251-
Row(bytes, bytes, "大千世界", "大千世界"))
252-
253-
checkAnswer(
254-
df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
255-
Row(bytes, "大千世界"))
256-
// scalastyle:on
257-
}
258-
259-
test("string trim functions") {
260-
val df = Seq((" example ", "")).toDF("a", "b")
261-
262-
checkAnswer(
263-
df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
264-
Row("example ", " example", "example"))
265-
266-
checkAnswer(
267-
df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
268-
Row("example ", " example", "example"))
269-
}
270-
271-
test("string formatString function") {
272-
val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
273-
274-
checkAnswer(
275-
df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
276-
Row("aa123cc", "aa123cc"))
277-
278-
checkAnswer(
279-
df.selectExpr("printf(a, b, c)"),
280-
Row("aa123cc"))
281-
}
282-
283-
test("string instr function") {
284-
val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
285-
286-
checkAnswer(
287-
df.select(instr($"a", $"b"), instr("a", "b")),
288-
Row(1, 1))
289-
290-
checkAnswer(
291-
df.selectExpr("instr(a, b)"),
292-
Row(1))
293-
}
294-
295-
test("string locate function") {
296-
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
297-
298-
checkAnswer(
299-
df.select(
300-
locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
301-
locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
302-
Row(1, 1, 2, 2, 2, 2))
303-
304-
checkAnswer(
305-
df.selectExpr("locate(b, a)", "locate(b, a, d)"),
306-
Row(1, 2))
307-
}
308-
309-
test("string padding functions") {
310-
val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
311-
312-
checkAnswer(
313-
df.select(
314-
lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
315-
lpad($"a", 1, $"c"), rpad("a", 1, "c")),
316-
Row("???hi", "hi???", "h", "h"))
317-
318-
checkAnswer(
319-
df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
320-
Row("???hi", "hi???", "h", "h"))
321-
}
322-
323-
test("string repeat function") {
324-
val df = Seq(("hi", 2)).toDF("a", "b")
325-
326-
checkAnswer(
327-
df.select(
328-
repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
329-
Row("hihi", "hihi", "hihi", "hihi"))
330-
331-
checkAnswer(
332-
df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
333-
Row("hihi", "hihi"))
334-
}
335-
336-
test("string reverse function") {
337-
val df = Seq(("hi", "hhhi")).toDF("a", "b")
338-
339-
checkAnswer(
340-
df.select(reverse($"a"), reverse("b")),
341-
Row("ih", "ihhh"))
342-
343-
checkAnswer(
344-
df.selectExpr("reverse(b)"),
345-
Row("ihhh"))
346-
}
347-
348-
test("string space function") {
349-
val df = Seq((2, 3)).toDF("a", "b")
350-
351-
checkAnswer(
352-
df.select(space($"a"), space("b")),
353-
Row(" ", " "))
354-
355-
checkAnswer(
356-
df.selectExpr("space(b)"),
357-
Row(" "))
358-
}
359-
360-
test("string split function") {
361-
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
362-
363-
checkAnswer(
364-
df.select(
365-
split($"a", "[1-9]+"),
366-
split("a", "[1-9]+")),
367-
Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
368-
369-
checkAnswer(
370-
df.selectExpr("split(a, '[1-9]+')"),
371-
Row(Seq("aa", "bb", "cc")))
372-
}
373-
374211
test("conditional function: least") {
375212
checkAnswer(
376213
testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
@@ -430,83 +267,4 @@ class DataFrameFunctionsSuite extends QueryTest {
430267
)
431268
}
432269

433-
test("string / binary length function") {
434-
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
435-
checkAnswer(
436-
df.select(length($"a"), length("a"), length($"b"), length("b")),
437-
Row(3, 3, 4, 4))
438-
439-
checkAnswer(
440-
df.selectExpr("length(a)", "length(b)"),
441-
Row(3, 4))
442-
443-
intercept[AnalysisException] {
444-
checkAnswer(
445-
df.selectExpr("length(c)"), // int type of the argument is unacceptable
446-
Row("5.0000"))
447-
}
448-
}
449-
450-
test("number format function") {
451-
val tuple =
452-
("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
453-
3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
454-
val df =
455-
Seq(tuple)
456-
.toDF(
457-
"a", // string "aa"
458-
"b", // byte 1
459-
"c", // short 2
460-
"d", // float 3.13223f
461-
"e", // integer 4
462-
"f", // long 5L
463-
"g", // double 6.48173d
464-
"h") // decimal 7.128381
465-
466-
checkAnswer(
467-
df.select(
468-
format_number($"f", 4),
469-
format_number("f", 4)),
470-
Row("5.0000", "5.0000"))
471-
472-
checkAnswer(
473-
df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
474-
Row("1.0000"))
475-
476-
checkAnswer(
477-
df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
478-
Row("2.0000"))
479-
480-
checkAnswer(
481-
df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
482-
Row("3.1322"))
483-
484-
checkAnswer(
485-
df.selectExpr("format_number(e, e)"), // not convert anything
486-
Row("4.0000"))
487-
488-
checkAnswer(
489-
df.selectExpr("format_number(f, e)"), // not convert anything
490-
Row("5.0000"))
491-
492-
checkAnswer(
493-
df.selectExpr("format_number(g, e)"), // not convert anything
494-
Row("6.4817"))
495-
496-
checkAnswer(
497-
df.selectExpr("format_number(h, e)"), // not convert anything
498-
Row("7.1284"))
499-
500-
intercept[AnalysisException] {
501-
checkAnswer(
502-
df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
503-
Row("5.0000"))
504-
}
505-
506-
intercept[AnalysisException] {
507-
checkAnswer(
508-
df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
509-
Row("5.0000"))
510-
}
511-
}
512270
}

0 commit comments

Comments
 (0)