Skip to content

Commit 83670fc

Browse files
adrian-wang authored and rxin committed
[SPARK-8176] [SPARK-8197] [SQL] function to_date/ trunc
This PR is based on apache#6988 , thanks to adrian-wang . This brings two SQL functions: to_date() and trunc(). Closes apache#6988 Author: Daoyuan Wang <[email protected]> Author: Davies Liu <[email protected]> Closes apache#7805 from davies/to_date and squashes the following commits: 2c7beba [Davies Liu] Merge branch 'master' of github.com:apache/spark into to_date 310dd55 [Daoyuan Wang] remove dup test in rebase 980b092 [Daoyuan Wang] resolve rebase conflict a476c5a [Daoyuan Wang] address comments from davies d44ea5f [Daoyuan Wang] function to_date, trunc
1 parent 9307f56 commit 83670fc

File tree

8 files changed

+245
-2
lines changed

8 files changed

+245
-2
lines changed

python/pyspark/sql/functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,36 @@ def months_between(date1, date2):
888888
return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
889889

890890

891+
@since(1.5)
def to_date(col):
    """
    Converts the column of StringType or TimestampType into DateType.

    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_date(df.t).alias('date')).collect()
    [Row(date=datetime.date(1997, 2, 28))]
    """
    # The actual conversion is performed by the JVM-side to_date function.
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.to_date(_to_java_column(col))
    return Column(jc)
903+
904+
@since(1.5)
def trunc(date, format):
    """
    Returns date truncated to the unit specified by the format.

    :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'

    >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
    >>> df.select(trunc(df.d, 'year').alias('year')).collect()
    [Row(year=datetime.date(1997, 1, 1))]
    >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
    [Row(month=datetime.date(1997, 2, 1))]
    """
    # Delegate truncation to the JVM-side trunc function; the format string
    # is passed through as-is and validated on the Scala side.
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.trunc(_to_java_column(date), format)
    return Column(jc)
920+
891921
@since(1.5)
892922
def size(col):
893923
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ object FunctionRegistry {
223223
expression[NextDay]("next_day"),
224224
expression[Quarter]("quarter"),
225225
expression[Second]("second"),
226+
expression[ToDate]("to_date"),
227+
expression[TruncDate]("trunc"),
226228
expression[UnixTimestamp]("unix_timestamp"),
227229
expression[WeekOfYear]("weekofyear"),
228230
expression[Year]("year"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression)
507507
})
508508
}
509509
}
510-
511510
}
512511

513512
/**
@@ -696,3 +695,90 @@ case class MonthsBetween(date1: Expression, date2: Expression)
696695
})
697696
}
698697
}
698+
699+
/**
 * Returns the date part of a timestamp or string.
 */
case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  // Spark's implicit casting accepts strings in both date and timestamp
  // formats, as well as TimestampType, and inserts the cast to DateType.
  override def inputTypes: Seq[AbstractDataType] = Seq(DateType)

  override def dataType: DataType = DateType

  // The analyzer has already cast the child to DateType, so evaluation
  // is a plain pass-through of the child's value.
  override def eval(input: InternalRow): Any = child.eval(input)

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
    defineCodeGen(ctx, ev, d => d)
}
716+
717+
/**
 * Returns date truncated to the unit specified by the format.
 */
case class TruncDate(date: Expression, format: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {
  override def left: Expression = date
  override def right: Expression = format

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
  override def dataType: DataType = DateType
  override def prettyName: String = "trunc"

  // Parsed once when `format` is foldable; TRUNC_INVALID (-1) for null or
  // unrecognized formats.
  lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])

  override def eval(input: InternalRow): Any = {
    val minItem = if (format.foldable) {
      minItemConst
    } else {
      // BUGFIX: a non-foldable format must be evaluated against the current
      // input row (the codegen path below already does so via nullSafeCodeGen).
      DateTimeUtils.parseTruncLevel(format.eval(input).asInstanceOf[UTF8String])
    }
    if (minItem == -1) {
      // unknown format
      null
    } else {
      val d = date.eval(input)
      if (d == null) {
        null
      } else {
        DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem)
      }
    }
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

    if (format.foldable) {
      if (minItemConst == -1) {
        // Constant unknown format: the whole expression is statically null.
        s"""
          boolean ${ev.isNull} = true;
          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
        """
      } else {
        // Constant valid format: inline the parsed trunc level.
        val d = date.gen(ctx)
        s"""
          ${d.code}
          boolean ${ev.isNull} = ${d.isNull};
          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
          if (!${ev.isNull}) {
            ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst);
          }
        """
      }
    } else {
      // Non-constant format: parse the trunc level per row.
      nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
        val form = ctx.freshName("form")
        s"""
          int $form = $dtu.parseTruncLevel($fmt);
          if ($form == -1) {
            ${ev.isNull} = true;
          } else {
            ${ev.primitive} = $dtu.truncDate($dateVal, $form);
          }
        """
      })
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,4 +779,38 @@ object DateTimeUtils {
779779
}
780780
date + (lastDayOfMonthInYear - dayInYear)
781781
}
782+
783+
private val TRUNC_TO_YEAR = 1
private val TRUNC_TO_MONTH = 2
private val TRUNC_INVALID = -1

/**
 * Returns the trunc date from original date and trunc level.
 * Trunc level should be generated using `parseTruncLevel()`, should only be
 * TRUNC_TO_YEAR (1) or TRUNC_TO_MONTH (2).
 */
def truncDate(d: Int, level: Int): Int = {
  if (level == TRUNC_TO_YEAR) {
    // Rewind to Jan 1st of the same year (d is days since epoch).
    d - DateTimeUtils.getDayInYear(d) + 1
  } else if (level == TRUNC_TO_MONTH) {
    // Rewind to the 1st of the same month.
    d - DateTimeUtils.getDayOfMonth(d) + 1
  } else {
    // Defensive: callers must filter out TRUNC_INVALID via parseTruncLevel first.
    throw new IllegalArgumentException(s"Invalid trunc level: $level")
  }
}

/**
 * Returns the truncate level: TRUNC_TO_YEAR, TRUNC_TO_MONTH, or TRUNC_INVALID,
 * TRUNC_INVALID means unsupported truncate level.
 */
def parseTruncLevel(format: UTF8String): Int = {
  if (format == null) {
    TRUNC_INVALID
  } else {
    // Use a locale-independent upper-casing so e.g. the Turkish default
    // locale cannot change how format strings are matched.
    import java.util.Locale
    format.toString.toUpperCase(Locale.ROOT) match {
      case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
      case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
      case _ => TRUNC_INVALID
    }
  }
}
782816
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
351351
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
352352
}
353353

354+
test("function to_date") {
  // ToDate is a pass-through at eval time; the analyzer inserts the cast.
  val d = Date.valueOf("2015-07-22")
  checkEvaluation(ToDate(Literal(d)), DateTimeUtils.fromJavaDate(d))
  // Null dates propagate to a null result.
  checkEvaluation(ToDate(Literal.create(null, DateType)), null)
}
360+
361+
test("function trunc") {
  // Check both the foldable (constant-folded level) and non-foldable
  // (per-row parsed level) code paths for each format.
  def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
    checkEvaluation(
      TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
      expected)
    checkEvaluation(
      TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
      expected)
  }

  val date = Date.valueOf("2015-07-22")
  val yearFormats = Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY")
  val monthFormats = Seq("month", "MONTH", "mon", "MON", "mm", "MM")
  yearFormats.foreach(fmt => testTrunc(date, fmt, Date.valueOf("2015-01-01")))
  monthFormats.foreach(fmt => testTrunc(date, fmt, Date.valueOf("2015-07-01")))

  // Unknown formats and null inputs all yield null.
  testTrunc(date, "DD", null)
  testTrunc(date, null, null)
  testTrunc(null, "MON", null)
  testTrunc(null, null, null)
}
381+
354382
test("from_unixtime") {
355383
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
356384
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
@@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
405433
checkEvaluation(
406434
UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
407435
}
408-
409436
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,8 @@ object NonFoldableLiteral {
4747
val lit = Literal(value)
4848
NonFoldableLiteral(lit.value, lit.dataType)
4949
}
50+
// Builds a NonFoldableLiteral with an explicitly requested data type,
// reusing Literal.create for the value conversion.
def create(value: Any, dataType: DataType): NonFoldableLiteral = {
  val converted = Literal.create(value, dataType)
  NonFoldableLiteral(converted.value, converted.dataType)
}
5054
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,6 +2181,22 @@ object functions {
21812181
*/
21822182
def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p))
21832183

2184+
/**
 * Converts the column into DateType.
 *
 * @group datetime_funcs
 * @since 1.5.0
 */
def to_date(e: Column): Column = ToDate(e.expr)

/**
 * Returns date truncated to the unit specified by the format.
 *
 * @group datetime_funcs
 * @since 1.5.0
 */
def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format))
2199+
21842200
//////////////////////////////////////////////////////////////////////////////////////////////
21852201
// Collection functions
21862202
//////////////////////////////////////////////////////////////////////////////////////////////

sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest {
345345
Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30"))))
346346
}
347347

348+
test("function to_date") {
  val df = Seq(
    (Date.valueOf("2015-07-22"), Timestamp.valueOf("2015-07-22 10:00:00"),
      "2015-07-22 10:00:00"),
    (Date.valueOf("2015-07-01"), Timestamp.valueOf("2014-12-31 23:59:59"),
      "2014-12-31")).toDF("d", "t", "s")

  val fromTimestamp = Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))
  val fromDate = Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))
  val fromString = Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))

  // DataFrame API.
  checkAnswer(df.select(to_date(col("t"))), fromTimestamp)
  checkAnswer(df.select(to_date(col("d"))), fromDate)
  checkAnswer(df.select(to_date(col("s"))), fromString)

  // SQL expression form.
  checkAnswer(df.selectExpr("to_date(t)"), fromTimestamp)
  checkAnswer(df.selectExpr("to_date(d)"), fromDate)
  checkAnswer(df.selectExpr("to_date(s)"), fromString)
}
377+
378+
test("function trunc") {
  val df = Seq(
    (1, Timestamp.valueOf("2015-07-22 10:00:00")),
    (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t")

  // DataFrame API truncates to the year.
  checkAnswer(
    df.select(trunc(col("t"), "YY")),
    Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01"))))

  // SQL form truncates to the month; format matching is case-insensitive.
  checkAnswer(
    df.selectExpr("trunc(t, 'Month')"),
    Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
}
391+
348392
test("from_unixtime") {
349393
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
350394
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"

0 commit comments

Comments
 (0)