Skip to content

Commit 3ebe288

Browse files
update as feedback
1 parent 52274f7 commit 3ebe288

File tree

4 files changed

+19
-20
lines changed

4 files changed

+19
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,6 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy
567567
child.dataType match {
568568
case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
569569
case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
570-
case NullType => defineCodeGen(ctx, ev, c => s"-1")
571570
}
572571
}
573572

@@ -685,8 +684,6 @@ case class FormatNumber(x: Expression, d: Expression)
685684
override def right: Expression = d
686685
override def dataType: DataType = StringType
687686
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
688-
override def foldable: Boolean = x.foldable && d.foldable
689-
override def nullable: Boolean = x.nullable || d.nullable
690687

691688
@transient
692689
private var lastDValue: Int = -100
@@ -706,8 +703,7 @@ case class FormatNumber(x: Expression, d: Expression)
706703
val dObject = d.eval(input)
707704

708705
if (dObject == null || dObject.asInstanceOf[Int] < 0) {
709-
throw new IllegalArgumentException(
710-
s"Argument 2 of function FORMAT_NUMBER must be >= 0, but $dObject was found")
706+
return null
711707
}
712708
val dValue = dObject.asInstanceOf[Int]
713709

@@ -742,5 +738,7 @@ case class FormatNumber(x: Expression, d: Expression)
742738
UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
743739
}
744740
}
741+
742+
override def prettyName: String = "format_number"
745743
}
746744

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
441441

442442
checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
443443
checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
444-
445-
checkEvaluation(Length(Literal.create(null, NullType)), null, create_row(null))
446444
}
447445

448446
test("number format") {
@@ -453,6 +451,7 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
453451
checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
454452
checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
455453
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
454+
checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null)
456455
checkEvaluation(
457456
FormatNumber(
458457
Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,41 +1685,43 @@ object functions {
16851685
//////////////////////////////////////////////////////////////////////////////////////////////
16861686

16871687
/**
1688-
* Computes the length of a given string / binary value
1688+
* Computes the length of a given string / binary value.
16891689
*
16901690
* @group string_funcs
16911691
* @since 1.5.0
16921692
*/
16931693
def length(e: Column): Column = Length(e.expr)
16941694

16951695
/**
1696-
* Computes the length of a given string / binary column
1696+
* Computes the length of a given string / binary column.
16971697
*
16981698
* @group string_funcs
16991699
* @since 1.5.0
17001700
*/
17011701
def length(columnName: String): Column = length(Column(columnName))
17021702

17031703
/**
1704-
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
1705-
* and returns the result as a string. If D is 0, the result has no decimal point or
1706-
* fractional part.
1704+
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
1705+
* and returns the result as a string.
1706+
* If d is 0, the result has no decimal point or fractional part.
1707+
* If d < 0, the result will be null.
17071708
*
17081709
* @group string_funcs
17091710
* @since 1.5.0
17101711
*/
1711-
def formatNumber(x: Column, d: Column): Column = FormatNumber(x.expr, d.expr)
1712+
def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
17121713

17131714
/**
1714-
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
1715-
* and returns the result as a string. If D is 0, the result has no decimal point or
1716-
* fractional part.
1715+
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
1716+
* and returns the result as a string.
1717+
* If d is 0, the result has no decimal point or fractional part.
1718+
* If d < 0, the result will be null.
17171719
*
17181720
* @group string_funcs
17191721
* @since 1.5.0
17201722
*/
1721-
def formatNumber(columnXName: String, columnDName: String): Column = {
1722-
formatNumber(Column(columnXName), Column(columnDName))
1723+
def format_number(columnXName: String, d: Int): Column = {
1724+
format_number(Column(columnXName), d)
17231725
}
17241726

17251727
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,8 @@ class DataFrameFunctionsSuite extends QueryTest {
465465

466466
checkAnswer(
467467
df.select(
468-
formatNumber($"f", $"e"),
469-
formatNumber("f", "e")),
468+
format_number($"f", 4),
469+
format_number("f", 4)),
470470
Row("5.0000", "5.0000"))
471471

472472
checkAnswer(

0 commit comments

Comments
 (0)