Skip to content

Commit a943d3e

Browse files
committed
[SPARK-9154] implicit input cast, added tests for null, support for null primitives
1 parent 10b4de8 commit a943d3e

File tree

3 files changed

+37
-15
lines changed

3 files changed

+37
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
476476
/**
477477
* Returns the input formatted according to printf-style format strings
478478
*/
479-
case class StringFormat(children: Expression*) extends Expression {
479+
case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
480480

481481
require(children.nonEmpty, "printf() should take at least 1 argument")
482482

@@ -486,6 +486,10 @@ case class StringFormat(children: Expression*) extends Expression {
486486
private def format: Expression = children(0)
487487
private def args: Seq[Expression] = children.tail
488488

489+
override def inputTypes: Seq[AbstractDataType] =
490+
children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType)
491+
492+
489493
override def eval(input: InternalRow): Any = {
490494
val pattern = format.eval(input)
491495
if (pattern == null) {
@@ -504,15 +508,25 @@ case class StringFormat(children: Expression*) extends Expression {
504508
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
505509
val pattern = children.head.gen(ctx)
506510

507-
val argListGen = children.tail.map(_.gen(ctx))
508-
val argListCode = argListGen.map(_.code + "\n")
509-
val argListString = argListGen.foldLeft("")((s, v) => s + s", ${v.primitive}")
511+
val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
512+
val argListCode = argListGen.map(_._2.code + "\n")
513+
514+
val argListString = argListGen.foldLeft("")((s, v) => {
515+
val nullSafeString =
516+
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
517+
// Java primitives get boxed in order to allow null values.
518+
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
519+
s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
520+
} else {
521+
s"(${v._2.isNull}) ? null : ${v._2.primitive}"
522+
}
523+
s + "," + nullSafeString
524+
})
510525

511526
val form = ctx.freshName("formatter")
512527
val formatter = classOf[java.util.Formatter].getName
513528
val sb = ctx.freshName("sb")
514529
val stringBuffer = classOf[StringBuffer].getName
515-
516530
s"""
517531
${pattern.code}
518532
boolean ${ev.isNull} = ${pattern.isNull};

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
351351
}
352352

353353
test("FORMAT") {
354-
val f = 'f.string.at(0)
355-
val d1 = 'd.int.at(1)
356-
val s1 = 's.string.at(2)
357-
358-
val row1 = create_row("aa%d%s", 12, "cc")
359-
val row2 = create_row(null, 12, "cc")
360-
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
354+
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
361355
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
362-
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
356+
checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
357+
checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
363358

364-
checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
365-
checkEvaluation(StringFormat(f, d1, s1), null, row2)
359+
checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
360+
checkEvaluation(
361+
StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
362+
checkEvaluation(
363+
StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
366364
}
367365

368366
test("INSTR") {

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ class StringFunctionsSuite extends QueryTest {
120120
checkAnswer(
121121
df.selectExpr("printf(a, b, c)"),
122122
Row("aa123cc"))
123+
124+
val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
125+
126+
checkAnswer(
127+
df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
128+
Row("aa123cc", "aa123cc"))
129+
130+
checkAnswer(
131+
df2.selectExpr("printf(a, b, c)"),
132+
Row("aa123cc"))
123133
}
124134

125135
test("string instr function") {

0 commit comments

Comments
 (0)