Skip to content

Commit 537fe88

Browse files
committed
Revert the renaming in SubDate and make the returned type as input type (date/timestamp)
1 parent 8c50b2c commit 537fe88

File tree

5 files changed

+59
-35
lines changed

5 files changed

+59
-35
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ object FunctionRegistry {
324324
expression[DateDiff]("datediff"),
325325
expression[AddDays]("date_add"),
326326
expression[DateFormatClass]("date_format"),
327-
expression[SubDays]("date_sub"),
327+
expression[DateSub]("date_sub"),
328328
expression[DayOfMonth]("day"),
329329
expression[DayOfYear]("dayofyear"),
330330
expression[DayOfMonth]("dayofmonth"),
@@ -342,7 +342,7 @@ object FunctionRegistry {
342342
expression[ToDate]("to_date"),
343343
expression[ToUnixTimestamp]("to_unix_timestamp"),
344344
expression[ToUTCTimestamp]("to_utc_timestamp"),
345-
expression[TruncateTimestamp]("trunc"),
345+
expression[TruncInstant]("trunc"),
346346
expression[UnixTimestamp]("unix_timestamp"),
347347
expression[WeekOfYear]("weekofyear"),
348348
expression[Year]("year"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ case class AddDays(instant: Expression, days: Expression) extends AddDaysBase(in
127127
@ExpressionDescription(
128128
usage = "_FUNC_(instant, num_days) - Returns the date/timestamp that is num_days before instant.",
129129
extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'")
130-
case class SubDays(instant: Expression, days: Expression) extends AddDaysBase(instant, days) {
130+
case class DateSub(instant: Expression, days: Expression) extends AddDaysBase(instant, days) {
131131

132132
override def signModifier: Int = -1
133133

@@ -935,18 +935,19 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
935935
*/
936936
// scalastyle:off line.size.limit
937937
@ExpressionDescription(
938-
usage = "_FUNC_(timestamp, fmt) - Returns returns timestamp with the time portion truncated to the unit specified by the format model fmt.",
938+
usage = "_FUNC_(instant, fmt) - Returns returns date/timestamp with the time portion truncated to the unit specified by the format model fmt.",
939939
extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01 00:00:00'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01 00:00:00'")
940940
// scalastyle:on line.size.limit
941-
case class TruncateTimestamp(timestamp: Expression, format: Expression)
941+
case class TruncInstant(instant: Expression, format: Expression)
942942
extends BinaryExpression with ImplicitCastInputTypes {
943943

944-
override def left: Expression = timestamp
944+
override def left: Expression = instant
945945
override def right: Expression = format
946946

947-
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
947+
override def inputTypes: Seq[AbstractDataType] =
948+
Seq(TypeCollection(DateType, TimestampType), StringType)
948949

949-
override def dataType: DataType = TimestampType
950+
override def dataType: DataType = instant.dataType
950951

951952
override def nullable: Boolean = true
952953

@@ -965,11 +966,12 @@ case class TruncateTimestamp(timestamp: Expression, format: Expression)
965966
// unknown format
966967
null
967968
} else {
968-
val ts = timestamp.eval(input)
969-
if (ts == null) {
970-
null
971-
} else {
972-
DateTimeUtils.truncateTimestamp(ts.asInstanceOf[Long], level)
969+
(instant.dataType, instant.eval(input)) match {
970+
case (_: DateType, date: Int) =>
971+
DateTimeUtils.truncateInstant(date, level)
972+
case (_: TimestampType, timestamp: Long) =>
973+
DateTimeUtils.truncateInstant(timestamp, level)
974+
case (_, null) => null
973975
}
974976
}
975977
}
@@ -983,13 +985,13 @@ case class TruncateTimestamp(timestamp: Expression, format: Expression)
983985
boolean ${ev.isNull} = true;
984986
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
985987
} else {
986-
val ts = timestamp.genCode(ctx)
988+
val ist = instant.genCode(ctx)
987989
ev.copy(code = s"""
988-
${ts.code}
989-
boolean ${ev.isNull} = ${ts.isNull};
990+
${ist.code}
991+
boolean ${ev.isNull} = ${ist.isNull};
990992
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
991993
if (!${ev.isNull}) {
992-
${ev.value} = $dtu.truncateTimestamp(${ts.value}, $truncLevel);
994+
${ev.value} = $dtu.truncateInstant(${ist.value}, $truncLevel);
993995
}""")
994996
}
995997
} else {
@@ -1000,7 +1002,7 @@ case class TruncateTimestamp(timestamp: Expression, format: Expression)
10001002
if ($form == -1) {
10011003
${ev.isNull} = true;
10021004
} else {
1003-
${ev.value} = $dtu.truncateTimestamp($dateVal, $form);
1005+
${ev.value} = $dtu.truncateInstant($dateVal, $form);
10041006
}
10051007
"""
10061008
})

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -839,9 +839,9 @@ object DateTimeUtils {
839839
* Returns the trunc timestamp from original timestamp and trunc level.
840840
* Trunc level should be generated using `parseTruncLevel()`, should only be 1 - 6.
841841
*/
842-
def truncateTimestamp(ts: SQLTimestamp, level: Int): SQLTimestamp = {
842+
def truncateInstant(ts: SQLTimestamp, level: Int): SQLTimestamp = {
843843
if (level == TRUNC_TO_YEAR || level == TRUNC_TO_MONTH) {
844-
daysToMillis(truncateDate(millisToDays(ts / 1000L), level)) * 1000L
844+
daysToMillis(truncateInstant(millisToDays(ts / 1000L), level)) * 1000L
845845
} else {
846846
truncateTime(ts, level)
847847
}
@@ -851,7 +851,7 @@ object DateTimeUtils {
851851
* Returns the trunc date from original date and trunc level.
852852
* Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2.
853853
*/
854-
private def truncateDate(d: SQLDate, level: Int): SQLDate = {
854+
def truncateInstant(d: SQLDate, level: Int): SQLDate = {
855855
if (level == TRUNC_TO_YEAR) {
856856
d - DateTimeUtils.getDayInYear(d) + 1
857857
} else if (level == TRUNC_TO_MONTH) {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -234,22 +234,22 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
234234

235235
test("date_sub") {
236236
checkEvaluation(
237-
SubDays(Literal(Date.valueOf("2015-01-01")), Literal(1)),
237+
DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)),
238238
DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31")))
239239
checkEvaluation(
240-
SubDays(Literal(Date.valueOf("2015-01-01")), Literal(-1)),
240+
DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)),
241241
DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02")))
242-
checkEvaluation(SubDays(Literal.create(null, DateType), Literal(1)), null)
243-
checkEvaluation(SubDays(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)),
242+
checkEvaluation(DateSub(Literal.create(null, DateType), Literal(1)), null)
243+
checkEvaluation(DateSub(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)),
244244
null)
245-
checkEvaluation(SubDays(Literal.create(null, DateType), Literal.create(null, IntegerType)),
245+
checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)),
246246
null)
247247
checkEvaluation(
248-
SubDays(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909)
248+
DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909)
249249
checkEvaluation(
250-
SubDays(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628)
251-
checkConsistencyBetweenInterpretedAndCodegen(SubDays, DateType, IntegerType)
252-
checkEvaluation(SubDays(Literal(Timestamp.valueOf("2015-01-10 12:00:00")), Literal(1)),
250+
DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628)
251+
checkConsistencyBetweenInterpretedAndCodegen(DateSub, DateType, IntegerType)
252+
checkEvaluation(DateSub(Literal(Timestamp.valueOf("2015-01-10 12:00:00")), Literal(1)),
253253
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-01-09 12:00:00")))
254254
}
255255

@@ -395,13 +395,35 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
395395
checkConsistencyBetweenInterpretedAndCodegen(ToDate, DateType)
396396
}
397397

398-
test("function trunc") {
398+
test("function trunc - date") {
399+
def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
400+
checkEvaluation(TruncInstant(Literal.create(input, DateType), Literal.create(fmt, StringType)),
401+
expected)
402+
checkEvaluation(
403+
TruncInstant(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
404+
expected)
405+
}
406+
407+
val date = Date.valueOf("2015-07-22")
408+
Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
409+
testTrunc(date, fmt, Date.valueOf("2015-01-01"))
410+
}
411+
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
412+
testTrunc(date, fmt, Date.valueOf("2015-07-01"))
413+
}
414+
415+
testTrunc(date, null, null)
416+
testTrunc(null, "MON", null)
417+
testTrunc(null, null, null)
418+
}
419+
420+
test("function trunc - timestamp") {
399421
def testTrunc(input: Timestamp, fmt: String, expected: Timestamp): Unit = {
400422
checkEvaluation(
401-
TruncateTimestamp(Literal.create(input, TimestampType), Literal.create(fmt, StringType)),
423+
TruncInstant(Literal.create(input, TimestampType), Literal.create(fmt, StringType)),
402424
expected)
403425
checkEvaluation(
404-
TruncateTimestamp(
426+
TruncInstant(
405427
Literal.create(input, TimestampType), NonFoldableLiteral.create(fmt, StringType)),
406428
expected)
407429
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2381,7 +2381,7 @@ object functions {
23812381
* @group datetime_funcs
23822382
* @since 1.5.0
23832383
*/
2384-
def date_sub(start: Column, days: Int): Column = withExpr { SubDays(start.expr, Literal(days)) }
2384+
def date_sub(start: Column, days: Int): Column = withExpr { DateSub(start.expr, Literal(days)) }
23852385

23862386
/**
23872387
* Returns the number of days from `start` to `end`.
@@ -2561,7 +2561,7 @@ object functions {
25612561
* @since 1.5.0
25622562
*/
25632563
def trunc(date: Column, format: String): Column = withExpr {
2564-
TruncateTimestamp(date.expr, Literal(format))
2564+
TruncInstant(date.expr, Literal(format))
25652565
}
25662566

25672567
/**

0 commit comments

Comments
 (0)