Skip to content

Commit f1d5901

Browse files
committed
Merge remote-tracking branch 'upstream/master' into UDAF
2 parents 35b0520 + a4c83cb commit f1d5901

File tree

5 files changed

+60
-25
lines changed

5 files changed

+60
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ object FunctionRegistry {
168168
expression[StringLocate]("locate"),
169169
expression[StringLPad]("lpad"),
170170
expression[StringTrimLeft]("ltrim"),
171-
expression[StringFormat]("printf"),
171+
expression[FormatString]("format_string"),
172+
expression[FormatString]("printf"),
172173
expression[StringRPad]("rpad"),
173174
expression[StringRepeat]("repeat"),
174175
expression[StringReverse]("reverse"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -526,32 +526,69 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
526526
/**
527527
* Returns the input formatted according to printf-style format strings
528528
*/
529-
case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
529+
case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
530530

531-
require(children.nonEmpty, "printf() should take at least 1 argument")
531+
require(children.nonEmpty, "format_string() should take at least 1 argument")
532532

533533
override def foldable: Boolean = children.forall(_.foldable)
534534
override def nullable: Boolean = children(0).nullable
535535
override def dataType: DataType = StringType
536-
private def format: Expression = children(0)
537-
private def args: Seq[Expression] = children.tail
536+
537+
override def inputTypes: Seq[AbstractDataType] =
538+
StringType :: List.fill(children.size - 1)(AnyDataType)
538539

539540
override def eval(input: InternalRow): Any = {
540-
val pattern = format.eval(input)
541+
val pattern = children(0).eval(input)
541542
if (pattern == null) {
542543
null
543544
} else {
544545
val sb = new StringBuffer()
545546
val formatter = new java.util.Formatter(sb, Locale.US)
546547

547-
val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
548+
val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef])
548549
formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*)
549550

550551
UTF8String.fromString(sb.toString)
551552
}
552553
}
553554

554-
override def prettyName: String = "printf"
555+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
556+
val pattern = children.head.gen(ctx)
557+
558+
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
559+
val argListCode = argListGen.map(_._2.code + "\n")
560+
561+
val argListString = argListGen.foldLeft("")((s, v) => {
562+
val nullSafeString =
563+
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
564+
// Java primitives get boxed in order to allow null values.
565+
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
566+
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
567+
} else {
568+
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
569+
}
570+
s + "," + nullSafeString
571+
})
572+
573+
val form = ctx.freshName("formatter")
574+
val formatter = classOf[java.util.Formatter].getName
575+
val sb = ctx.freshName("sb")
576+
val stringBuffer = classOf[StringBuffer].getName
577+
s"""
578+
${pattern.code}
579+
boolean ${ev.isNull} = ${pattern.isNull};
580+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
581+
if (!${ev.isNull}) {
582+
${argListCode.mkString}
583+
$stringBuffer $sb = new $stringBuffer();
584+
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
585+
$form.format(${pattern.primitive}.toString() $argListString);
586+
${ev.primitive} = UTF8String.fromString($sb.toString());
587+
}
588+
"""
589+
}
590+
591+
override def prettyName: String = "format_string"
555592
}
556593

557594
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
351351
}
352352

353353
test("FORMAT") {
354-
val f = 'f.string.at(0)
355-
val d1 = 'd.int.at(1)
356-
val s1 = 's.int.at(2)
357-
358-
val row1 = create_row("aa%d%s", 12, "cc")
359-
val row2 = create_row(null, 12, "cc")
360-
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
361-
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
362-
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
363-
364-
checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
365-
checkEvaluation(StringFormat(f, d1, s1), null, row2)
354+
checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
355+
checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null))
356+
checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
357+
checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc")
358+
359+
checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null)
360+
checkEvaluation(
361+
FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
362+
checkEvaluation(
363+
FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
366364
}
367365

368366
test("INSTR") {

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,15 +1742,14 @@ object functions {
17421742
def rtrim(e: Column): Column = StringTrimRight(e.expr)
17431743

17441744
/**
1745-
* Format strings in printf-style.
1746-
* NOTE: `format` is the string value of the formatter, not column name.
1745+
* Formats the arguments in printf-style and returns the result as a string column.
17471746
*
17481747
* @group string_funcs
17491748
* @since 1.5.0
17501749
*/
17511750
@scala.annotation.varargs
1752-
def formatString(format: String, arguNames: String*): Column = {
1753-
StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*)
1751+
def format_string(format: String, arguments: Column*): Column = {
1752+
FormatString((lit(format) +: arguments).map(_.expr): _*)
17541753
}
17551754

17561755
/**

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class StringFunctionsSuite extends QueryTest {
126126
val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
127127

128128
checkAnswer(
129-
df.select(formatString("aa%d%s", "b", "c")),
129+
df.select(format_string("aa%d%s", $"b", $"c")),
130130
Row("aa123cc"))
131131

132132
checkAnswer(

0 commit comments

Comments
 (0)