Skip to content

Commit d4c7a7a

Browse files
tarekbeckermarmbrus
authored and committed
[SPARK-9154] [SQL] codegen StringFormat
Jira: https://issues.apache.org/jira/browse/SPARK-9154 fixes bug of apache#7546 marmbrus I can't reopen the other PR, because I didn't close it. Can you trigger Jenkins? Author: Tarek Auel <[email protected]> Closes apache#7571 from tarekauel/SPARK-9154 and squashes the following commits: dcae272 [Tarek Auel] [SPARK-9154][SQL] build fix 1487602 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-9154 f512c5f [Tarek Auel] [SPARK-9154][SQL] build fix a943d3e [Tarek Auel] [SPARK-9154] implicit input cast, added tests for null, support for null primitives 10b4de8 [Tarek Auel] [SPARK-9154][SQL] codegen removed fallback trait cd8322b [Tarek Auel] [SPARK-9154][SQL] codegen string format 086caba [Tarek Auel] [SPARK-9154][SQL] codegen string format
1 parent c07838b commit d4c7a7a

File tree

4 files changed

+70
-11
lines changed

4 files changed

+70
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
526526
/**
527527
* Returns the input formatted according to printf-style format strings
528528
*/
529-
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
529+
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
530530

531531
require(children.nonEmpty, "printf() should take at least 1 argument")
532532

@@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
536536
private def format: Expression = children(0)
537537
private def args: Seq[Expression] = children.tail
538538

539+
override def inputTypes: Seq[AbstractDataType] =
540+
StringType :: List.fill(children.size - 1)(AnyDataType)
541+
542+
539543
override def eval(input: InternalRow): Any = {
540544
val pattern = format.eval(input)
541545
if (pattern == null) {
@@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
551555
}
552556
}
553557

558+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
559+
val pattern = children.head.gen(ctx)
560+
561+
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
562+
val argListCode = argListGen.map(_._2.code + "\n")
563+
564+
val argListString = argListGen.foldLeft("")((s, v) => {
565+
val nullSafeString =
566+
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
567+
// Java primitives get boxed in order to allow null values.
568+
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
569+
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
570+
} else {
571+
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
572+
}
573+
s + "," + nullSafeString
574+
})
575+
576+
val form = ctx.freshName("formatter")
577+
val formatter = classOf[java.util.Formatter].getName
578+
val sb = ctx.freshName("sb")
579+
val stringBuffer = classOf[StringBuffer].getName
580+
s"""
581+
${pattern.code}
582+
boolean ${ev.isNull} = ${pattern.isNull};
583+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
584+
if (!${ev.isNull}) {
585+
${argListCode.mkString}
586+
$stringBuffer $sb = new $stringBuffer();
587+
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
588+
$form.format(${pattern.primitive}.toString() $argListString);
589+
${ev.primitive} = UTF8String.fromString($sb.toString());
590+
}
591+
"""
592+
}
593+
554594
override def prettyName: String = "printf"
555595
}
556596

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
351351
}
352352

353353
test("FORMAT") {
354-
val f = 'f.string.at(0)
355-
val d1 = 'd.int.at(1)
356-
val s1 = 's.int.at(2)
357-
358-
val row1 = create_row("aa%d%s", 12, "cc")
359-
val row2 = create_row(null, 12, "cc")
360-
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
354+
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
361355
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
362-
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
356+
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
357+
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
363358

364-
checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
365-
checkEvaluation(StringFormat(f, d1, s1), null, row2)
359+
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
360+
checkEvaluation(
361+
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
362+
checkEvaluation(
363+
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
366364
}
367365

368366
test("INSTR") {

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,17 @@ object functions {
17411741
*/
17421742
def rtrim(e: Column): Column = StringTrimRight(e.expr)
17431743

1744+
/**
1745+
* Format strings in printf-style.
1746+
*
1747+
* @group string_funcs
1748+
* @since 1.5.0
1749+
*/
1750+
@scala.annotation.varargs
1751+
def formatString(format: Column, arguments: Column*): Column = {
1752+
StringFormat((format +: arguments).map(_.expr): _*)
1753+
}
1754+
17441755
/**
17451756
* Format strings in printf-style.
17461757
* NOTE: `format` is the string value of the formatter, not column name.

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest {
132132
checkAnswer(
133133
df.selectExpr("printf(a, b, c)"),
134134
Row("aa123cc"))
135+
136+
val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
137+
138+
checkAnswer(
139+
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
140+
Row("aa123cc", "aa123cc"))
141+
142+
checkAnswer(
143+
df2.selectExpr("printf(a, b, c)"),
144+
Row("aa123cc"))
135145
}
136146

137147
test("string instr function") {

0 commit comments

Comments
 (0)