
Commit d44ea5f

function to_date, trunc
1 parent 708794e commit d44ea5f

6 files changed: +205 / -16 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions

@@ -218,6 +218,8 @@ object FunctionRegistry {
     expression[NextDay]("next_day"),
     expression[Quarter]("quarter"),
     expression[Second]("second"),
+    expression[ToDate]("to_date"),
+    expression[Trunc]("trunc"),
     expression[WeekOfYear]("weekofyear"),
     expression[Year]("year"),
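
With both expressions registered in FunctionRegistry, to_date and trunc are resolvable by name from SQL text as well as from the DataFrame API. A minimal sketch, assuming a SQLContext named ctx and a registered temp table logs with a timestamp column t (both names are illustrative):

  // Resolved through the ToDate and Trunc registry entries added above;
  // the timestamp column is implicitly cast to DateType.
  ctx.sql("SELECT to_date(t), trunc(t, 'MM') FROM logs").show()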

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 65 additions & 0 deletions

@@ -272,6 +272,24 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
   override def prettyName: String = "last_day"
 }
 
+/**
+ * Returns the date part of a timestamp string.
+ */
+case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+  override def dataType: DataType = DateType
+
+  override def eval(input: InternalRow): Any = {
+    child.eval(input)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    defineCodeGen(ctx, ev, (time) => time)
+  }
+}
+
 /**
  * Returns the first date which is later than startDate and named as dayOfWeek.
  * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first
@@ -283,6 +301,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
   extends BinaryExpression with ImplicitCastInputTypes {
 
   override def left: Expression = startDate
+
   override def right: Expression = dayOfWeek
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
@@ -330,3 +349,49 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
 
   override def prettyName: String = "next_day"
 }
+
+/**
+ * Returns date truncated to the unit specified by the format.
+ */
+case class Trunc(date: Expression, format: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+  override def left: Expression = date
+  override def right: Expression = format
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(d: Any, fmt: Any): Any = {
+    val minItem = DateTimeUtils.getFmt(fmt.asInstanceOf[UTF8String])
+    if (minItem == -1) {
+      // unknown format
+      null
+    } else {
+      val days = d.asInstanceOf[Int]
+      if (minItem == Calendar.YEAR) {
+        days - DateTimeUtils.getDayInYear(days) + 1
+      } else {
+        // trunc to MONTH
+        days - DateTimeUtils.getDayOfMonth(days) + 1
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
+      val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+      val form = ctx.freshName("form")
+      s"""
+        int $form = $dtu.getFmt($fmt);
+        if ($form == ${Calendar.YEAR}) {
+          ${ev.primitive} = $dateVal - $dtu.getDayInYear($dateVal) + 1;
+        } else if ($form == ${Calendar.MONTH}) {
+          ${ev.primitive} = $dateVal - $dtu.getDayOfMonth($dateVal) + 1;
+        } else {
+          ${ev.isNull} = true;
+        }
+      """
+    })
+  }
+}
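
Both the interpreted path and the generated code work directly on the DateType physical value, a day count since the Unix epoch: subtracting the 1-based day-of-month (or day-of-year) and adding 1 lands on the first day of that month (or year). A standalone sketch of the same arithmetic, using java.time purely for illustration (it is not the helper used above):

  import java.time.LocalDate

  object TruncSketch {
    // Same arithmetic as Trunc, expressed on epoch-day counts.
    def truncToMonth(days: Long): Long = {
      val date = LocalDate.ofEpochDay(days)
      days - date.getDayOfMonth + 1   // first day of the month
    }

    def truncToYear(days: Long): Long = {
      val date = LocalDate.ofEpochDay(days)
      days - date.getDayOfYear + 1    // first day of the year
    }

    def main(args: Array[String]): Unit = {
      val d = LocalDate.of(2015, 7, 22).toEpochDay
      println(LocalDate.ofEpochDay(truncToMonth(d)))  // 2015-07-01
      println(LocalDate.ofEpochDay(truncToYear(d)))   // 2015-01-01
    }
  }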

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 13 additions & 0 deletions

@@ -640,4 +640,17 @@ object DateTimeUtils {
     }
     date + (lastDayOfMonthInYear - dayInYear)
   }
+
+  /**
+   * Returns the truncate level: [[Calendar.MONTH]], [[Calendar.YEAR]], or -1,
+   * where -1 means an unsupported truncate level.
+   */
+  def getFmt(string: UTF8String): Int = {
+    val fmtString = string.toString.toUpperCase
+    fmtString match {
+      case "MON" | "MONTH" | "MM" => Calendar.MONTH
+      case "YEAR" | "YYYY" | "YY" => Calendar.YEAR
+      case _ => -1
+    }
+  }
 }
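
getFmt is effectively case-insensitive because the input is upper-cased before matching, and any other string maps to -1, which Trunc turns into a null result. A small sketch, assuming the usual UTF8String and DateTimeUtils package locations:

  import java.util.Calendar
  import org.apache.spark.unsafe.types.UTF8String
  import org.apache.spark.sql.catalyst.util.DateTimeUtils

  // "yyyy" upper-cases to "YYYY" and hits the YEAR branch.
  assert(DateTimeUtils.getFmt(UTF8String.fromString("yyyy")) == Calendar.YEAR)
  // Day-level truncation is unsupported, so "DD" maps to -1.
  assert(DateTimeUtils.getFmt(UTF8String.fromString("DD")) == -1)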

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 31 additions & 0 deletions

@@ -303,4 +303,35 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
   }
+
+  test("datetime function current_date") {
+    val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+    val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int]
+    val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+    assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1)
+  }
+
+  test("datetime function current_timestamp") {
+    val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long])
+    val t1 = System.currentTimeMillis()
+    assert(math.abs(t1 - ct.getTime) < 5000)
+  }
+
+  test("function to_date") {
+    checkEvaluation(
+      ToDate(Literal(Date.valueOf("2015-07-22"))),
+      DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
+  }
+
+  test("function trunc") {
+    checkEvaluation(EqualTo(
+      Trunc(Literal(Date.valueOf("2015-07-22")), Literal("YYYY")),
+      Trunc(Literal(Date.valueOf("2015-01-01")), Literal("YEAR"))), true)
+
+    checkEvaluation(EqualTo(
+      Trunc(Literal(Date.valueOf("2015-07-22")), Literal("MONTH")),
+      Trunc(Literal(Date.valueOf("2015-07-01")), Literal("mm"))), true)
+
+    checkEvaluation(Trunc(Literal(Date.valueOf("2015-07-22")), Literal("DD")), null)
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 32 additions & 16 deletions

@@ -1015,22 +1015,6 @@ object functions {
    */
   def cosh(columnName: String): Column = cosh(Column(columnName))
 
-  /**
-   * Returns the current date.
-   *
-   * @group datetime_funcs
-   * @since 1.5.0
-   */
-  def current_date(): Column = CurrentDate()
-
-  /**
-   * Returns the current timestamp.
-   *
-   * @group datetime_funcs
-   * @since 1.5.0
-   */
-  def current_timestamp(): Column = CurrentTimestamp()
-
   /**
    * Computes the exponential of the given value.
    *
@@ -1916,6 +1900,22 @@ object functions {
   // DateTime functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Returns the current date.
+   *
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def current_date(): Column = CurrentDate()
+
+  /**
+   * Returns the current timestamp.
+   *
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def current_timestamp(): Column = CurrentTimestamp()
+
   /**
    * Converts a date/timestamp/string to a value of string in the format specified by the date
    * format given by the second argument.
@@ -2099,6 +2099,22 @@ object functions {
    */
   def weekofyear(columnName: String): Column = weekofyear(Column(columnName))
 
+  /**
+   * Returns the date part of a date/timestamp/string column.
+   *
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def to_date(e: Column): Column = ToDate(e.expr)
+
+  /**
+   * Returns date truncated to the unit specified by the format.
+   *
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def trunc(date: Column, format: Column): Column = Trunc(date.expr, format.expr)
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Collection functions
   //////////////////////////////////////////////////////////////////////////////////////////////
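
Taken together, the new functions are used from the DataFrame API like any other datetime function. A small usage sketch mirroring the test suite below, where df and its timestamp column t are illustrative names:

  import org.apache.spark.sql.functions.{col, lit, to_date, trunc}

  df.select(
    to_date(col("t")),           // date part of the timestamp
    trunc(col("t"), lit("MM")),  // first day of the month
    trunc(col("t"), lit("YY"))   // first day of the year
  ).show()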

sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala

Lines changed: 62 additions & 0 deletions

@@ -228,4 +228,66 @@ class DateFunctionsSuite extends QueryTest {
       Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30"))))
   }
 
+  test("function current_date") {
+    val df = Seq((1, 2), (3, 1)).toDF("a", "b")
+    val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+    val d1 = DateTimeUtils.fromJavaDate(df.select(current_date()).collect().head.getDate(0))
+    val d2 = DateTimeUtils.fromJavaDate(
+      ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
+    val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
+    assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
+  }
+
+  test("function current_timestamp") {
+    val df = Seq((1, 2), (3, 1)).toDF("a", "b")
+    checkAnswer(df.select(countDistinct(current_timestamp())), Row(1))
+    // TODO SPARK-9196: Execution in one query should return the same value
+    assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
+      0).getTime - System.currentTimeMillis()) < 5000)
+  }
+
+  test("function to_date") {
+    val d1 = Date.valueOf("2015-07-22")
+    val d2 = Date.valueOf("2015-07-01")
+    val t1 = Timestamp.valueOf("2015-07-22 10:00:00")
+    val t2 = Timestamp.valueOf("2014-12-31 23:59:59")
+    val s1 = "2015-07-22 10:00:00"
+    val s2 = "2014-12-31"
+    val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s")
+
+    checkAnswer(
+      df.select(to_date(col("t"))),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+    checkAnswer(
+      df.select(to_date(col("d"))),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
+    checkAnswer(
+      df.select(to_date(col("s"))),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+
+    checkAnswer(
+      df.selectExpr("to_date(t)"),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+    checkAnswer(
+      df.selectExpr("to_date(d)"),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
+    checkAnswer(
+      df.selectExpr("to_date(s)"),
+      Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
+  }
+
+  test("function trunc") {
+    val df = Seq(
+      (1, Timestamp.valueOf("2015-07-22 10:00:00")),
+      (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t")
+
+    checkAnswer(
+      df.select(trunc(col("t"), lit("YY"))),
+      Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01"))))
+
+    checkAnswer(
+      df.selectExpr("trunc(t, 'Month')"),
+      Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
+  }
 }
