
Commit 1abf7dc

adrian-wang authored and davies committed
[SPARK-8186] [SPARK-8187] [SPARK-8194] [SPARK-8198] [SPARK-9133] [SPARK-9290] [SQL] functions: date_add, date_sub, add_months, months_between, time-interval calculation
This PR is based on apache#7589, thanks to adrian-wang.

Added the SQL functions date_add, date_sub, add_months, and months_between, plus a rule for add/subtract of date/timestamp and interval.

Closes apache#7589

cc rxin

Author: Daoyuan Wang <[email protected]>
Author: Davies Liu <[email protected]>

Closes apache#7754 from davies/date_add and squashes the following commits:

e8c633a [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add
9e8e085 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add
6224ce4 [Davies Liu] fix conflict
bd18cd4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add
e47ff2c [Davies Liu] add python api, fix date functions
01943d0 [Davies Liu] Merge branch 'master' into date_add
522e91a [Daoyuan Wang] fix
e8a639a [Daoyuan Wang] fix
42df486 [Daoyuan Wang] fix style
87c4b77 [Daoyuan Wang] function add_months, months_between and some fixes
1a68e03 [Daoyuan Wang] poc of time interval calculation
c506661 [Daoyuan Wang] function date_add, date_sub
1 parent d8cfd53 commit 1abf7dc

File tree

10 files changed: +791 −162 lines


python/pyspark/sql/functions.py

Lines changed: 64 additions & 12 deletions
@@ -59,7 +59,7 @@
 __all__ += ['lag', 'lead', 'ntile']

 __all__ += [
-    'date_format',
+    'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
     'year', 'quarter', 'month', 'hour', 'minute', 'second',
     'dayofmonth', 'dayofyear', 'weekofyear']

@@ -716,7 +716,7 @@ def date_format(dateCol, format):
     [Row(date=u'04/08/2015')]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.date_format(dateCol, format))
+    return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))


 @since(1.5)
@@ -729,7 +729,7 @@ def year(col):
     [Row(year=2015)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.year(col))
+    return Column(sc._jvm.functions.year(_to_java_column(col)))


 @since(1.5)
@@ -742,7 +742,7 @@ def quarter(col):
     [Row(quarter=2)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.quarter(col))
+    return Column(sc._jvm.functions.quarter(_to_java_column(col)))


 @since(1.5)
@@ -755,7 +755,7 @@ def month(col):
     [Row(month=4)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.month(col))
+    return Column(sc._jvm.functions.month(_to_java_column(col)))


 @since(1.5)
@@ -768,7 +768,7 @@ def dayofmonth(col):
     [Row(day=8)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.dayofmonth(col))
+    return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))


 @since(1.5)
@@ -781,7 +781,7 @@ def dayofyear(col):
     [Row(day=98)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.dayofyear(col))
+    return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))


 @since(1.5)
@@ -794,7 +794,7 @@ def hour(col):
     [Row(hour=13)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.hour(col))
+    return Column(sc._jvm.functions.hour(_to_java_column(col)))


 @since(1.5)
@@ -807,7 +807,7 @@ def minute(col):
     [Row(minute=8)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.minute(col))
+    return Column(sc._jvm.functions.minute(_to_java_column(col)))


 @since(1.5)
@@ -820,7 +820,7 @@ def second(col):
     [Row(second=15)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.second(col))
+    return Column(sc._jvm.functions.second(_to_java_column(col)))


 @since(1.5)
@@ -829,11 +829,63 @@ def weekofyear(col):
     Extract the week number of a given date as integer.

     >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
-    >>> df.select(weekofyear('a').alias('week')).collect()
+    >>> df.select(weekofyear(df.a).alias('week')).collect()
     [Row(week=15)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.weekofyear(col))
+    return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))
+
+
+@since(1.5)
+def date_add(start, days):
+    """
+    Returns the date that is `days` days after `start`
+
+    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+    >>> df.select(date_add(df.d, 1).alias('d')).collect()
+    [Row(d=datetime.date(2015, 4, 9))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
+
+
+@since(1.5)
+def date_sub(start, days):
+    """
+    Returns the date that is `days` days before `start`
+
+    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+    >>> df.select(date_sub(df.d, 1).alias('d')).collect()
+    [Row(d=datetime.date(2015, 4, 7))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
+
+
+@since(1.5)
+def add_months(start, months):
+    """
+    Returns the date that is `months` months after `start`
+
+    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
+    >>> df.select(add_months(df.d, 1).alias('d')).collect()
+    [Row(d=datetime.date(2015, 5, 8))]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
+
+
+@since(1.5)
+def months_between(date1, date2):
+    """
+    Returns the number of months between date1 and date2.

+    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
+    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
+    [Row(months=3.9495967...)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))


 @since(1.5)

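The doctests above double as usage documentation. As a quick end-to-end sketch of the new wrappers (the sqlContext, column names, and literal dates below are illustrative assumptions, not part of the commit), they compose like any other column function:

    # Minimal usage sketch of the new Python wrappers; assumes a live `sqlContext`
    # as in the doctest setup. Column names and dates are made up for illustration.
    from pyspark.sql.functions import date_add, date_sub, add_months, months_between

    df = sqlContext.createDataFrame([('2015-04-08', '2015-05-10')], ['d', 'd2'])
    df.select(
        date_add(df.d, 7).alias('plus_week'),          # 2015-04-15
        date_sub(df.d, 7).alias('minus_week'),         # 2015-04-01
        add_months(df.d, 3).alias('plus_quarter'),     # 2015-07-08
        months_between(df.d2, df.d).alias('months')    # fractional months, Hive-style
    ).show()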
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 4 additions & 0 deletions
@@ -205,9 +205,12 @@ object FunctionRegistry {
     expression[Upper]("upper"),

     // datetime functions
+    expression[AddMonths]("add_months"),
     expression[CurrentDate]("current_date"),
     expression[CurrentTimestamp]("current_timestamp"),
+    expression[DateAdd]("date_add"),
     expression[DateFormatClass]("date_format"),
+    expression[DateSub]("date_sub"),
     expression[DayOfMonth]("day"),
     expression[DayOfYear]("dayofyear"),
     expression[DayOfMonth]("dayofmonth"),
@@ -216,6 +219,7 @@ object FunctionRegistry {
     expression[LastDay]("last_day"),
     expression[Minute]("minute"),
     expression[Month]("month"),
+    expression[MonthsBetween]("months_between"),
     expression[NextDay]("next_day"),
     expression[Quarter]("quarter"),
     expression[Second]("second"),

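Registering the expressions in FunctionRegistry under these names is what makes them resolvable from SQL text as well as from the DataFrame API. A hedged sketch of that path (the sqlContext, the temp table name events, and its column d are assumptions for illustration):

    # Sketch: once registered, the same functions resolve by name in SQL queries.
    df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    df.registerTempTable('events')
    sqlContext.sql("""
        SELECT date_add(d, 30)      AS due,
               date_sub(d, 1)       AS prev_day,
               add_months(d, 12)    AS next_year,
               months_between(d, d) AS zero_months
        FROM events
    """).collect()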
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 22 additions & 0 deletions
@@ -47,6 +47,7 @@ object HiveTypeCoercion {
     Division ::
     PropagateTypes ::
     ImplicitTypeCasts ::
+    DateTimeOperations ::
     Nil

   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -638,6 +639,27 @@ object HiveTypeCoercion {
     }
   }

+  /**
+   * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
+   * into TimeAdd/TimeSub.
+   */
+  object DateTimeOperations extends Rule[LogicalPlan] {
+
+    private val acceptedTypes = Seq(DateType, TimestampType, StringType)
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Skip nodes whose children have not been resolved yet.
+      case e if !e.childrenResolved => e
+
+      case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) =>
+        Cast(TimeAdd(r, l), r.dataType)
+      case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
+        Cast(TimeAdd(l, r), l.dataType)
+      case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
+        Cast(TimeSub(l, r), l.dataType)
+    }
+  }
+
   /**
    * Casts types according to the expected input types for [[Expression]]s.
    */

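The DateTimeOperations rule is what lets plain + and - between a date/timestamp/string operand and a calendar interval work: the analyzer rewrites such expressions to TimeAdd/TimeSub and casts the result back to the non-interval operand's type. A sketch of the user-visible effect, assuming a sqlContext and that the SQL parser accepts INTERVAL literals (added separately in the 1.5 line); the table, column, and interval values are illustrative:

    # Illustrative only: `ts`, column `t`, and the intervals are assumptions.
    df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['t'])
    df.registerTempTable('ts')
    sqlContext.sql(
        "SELECT cast(t AS timestamp) + interval 1 month AS plus_month, "
        "       cast(t AS timestamp) - interval 10 days AS minus_days "
        "FROM ts").collect()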
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 154 additions & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

 import scala.util.Try

@@ -63,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
   }
 }

+/**
+ * Adds a number of days to startdate.
+ */
+case class DateAdd(startDate: Expression, days: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = startDate
+  override def right: Expression = days
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, d: Any): Any = {
+    start.asInstanceOf[Int] + d.asInstanceOf[Int]
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (sd, d) => {
+      s"""${ev.primitive} = $sd + $d;"""
+    })
+  }
+}
+
+/**
+ * Subtracts a number of days from startdate.
+ */
+case class DateSub(startDate: Expression, days: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+  override def left: Expression = startDate
+  override def right: Expression = days
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, d: Any): Any = {
+    start.asInstanceOf[Int] - d.asInstanceOf[Int]
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (sd, d) => {
+      s"""${ev.primitive} = $sd - $d;"""
+    })
+  }
+}
+
 case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

   override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -543,3 +590,109 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)

   override def prettyName: String = "next_day"
 }
+
+/**
+ * Adds an interval to timestamp.
+ */
+case class TimeAdd(start: Expression, interval: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = start
+  override def right: Expression = interval
+
+  override def toString: String = s"$left + $right"
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
+
+  override def dataType: DataType = TimestampType
+
+  override def nullSafeEval(start: Any, interval: Any): Any = {
+    val itvl = interval.asInstanceOf[CalendarInterval]
+    DateTimeUtils.timestampAddInterval(
+      start.asInstanceOf[Long], itvl.months, itvl.microseconds)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(ctx, ev, (sd, i) => {
+      s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
+    })
+  }
+}
+
+/**
+ * Subtracts an interval from timestamp.
+ */
+case class TimeSub(start: Expression, interval: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = start
+  override def right: Expression = interval
+
+  override def toString: String = s"$left - $right"
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
+
+  override def dataType: DataType = TimestampType
+
+  override def nullSafeEval(start: Any, interval: Any): Any = {
+    val itvl = interval.asInstanceOf[CalendarInterval]
+    DateTimeUtils.timestampAddInterval(
+      start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(ctx, ev, (sd, i) => {
+      s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
+    })
+  }
+}
+
+/**
+ * Returns the date that is num_months after start_date.
+ */
+case class AddMonths(startDate: Expression, numMonths: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = startDate
+  override def right: Expression = numMonths
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, months: Any): Any = {
+    DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(ctx, ev, (sd, m) => {
+      s"""$dtu.dateAddMonths($sd, $m)"""
+    })
+  }
+}
+
+/**
+ * Returns number of months between dates date1 and date2.
+ */
+case class MonthsBetween(date1: Expression, date2: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = date1
+  override def right: Expression = date2
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)
+
+  override def dataType: DataType = DoubleType
+
+  override def nullSafeEval(t1: Any, t2: Any): Any = {
+    DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long])
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(ctx, ev, (l, r) => {
+      s"""$dtu.monthsBetween($l, $r)"""
+    })
+  }
+}

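What makes DateAdd/DateSub a plain Int addition (and TimeAdd/TimeSub a Long computation) is Catalyst's internal representation: DateType values are stored as an Int counting days since 1970-01-01, and TimestampType values as a Long of microseconds since the epoch. A small Python sketch of that convention (the helper below is illustrative, not part of Spark):

    import datetime

    EPOCH = datetime.date(1970, 1, 1)

    def to_internal_date(d):
        # Days since the Unix epoch -- the Int that DateAdd/DateSub operate on.
        return (d - EPOCH).days

    start = to_internal_date(datetime.date(2015, 4, 8))              # 16533
    assert to_internal_date(datetime.date(2015, 4, 9)) == start + 1  # what date_add(d, 1) computes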