Skip to content

Commit e47ff2c

Browse files
author
Davies Liu
committed
add python api, fix date functions
1 parent 01943d0 commit e47ff2c

File tree

8 files changed

+380
-349
lines changed

8 files changed

+380
-349
lines changed

python/pyspark/sql/functions.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
__all__ += ['lag', 'lead', 'ntile']
6060

6161
__all__ += [
62-
'date_format',
62+
'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
6363
'year', 'quarter', 'month', 'hour', 'minute', 'second',
6464
'dayofmonth', 'dayofyear', 'weekofyear']
6565

@@ -716,7 +716,7 @@ def date_format(dateCol, format):
716716
[Row(date=u'04/08/2015')]
717717
"""
718718
sc = SparkContext._active_spark_context
719-
return Column(sc._jvm.functions.date_format(dateCol, format))
719+
return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))
720720

721721

722722
@since(1.5)
@@ -729,7 +729,7 @@ def year(col):
729729
[Row(year=2015)]
730730
"""
731731
sc = SparkContext._active_spark_context
732-
return Column(sc._jvm.functions.year(col))
732+
return Column(sc._jvm.functions.year(_to_java_column(col)))
733733

734734

735735
@since(1.5)
@@ -742,7 +742,7 @@ def quarter(col):
742742
[Row(quarter=2)]
743743
"""
744744
sc = SparkContext._active_spark_context
745-
return Column(sc._jvm.functions.quarter(col))
745+
return Column(sc._jvm.functions.quarter(_to_java_column(col)))
746746

747747

748748
@since(1.5)
@@ -755,7 +755,7 @@ def month(col):
755755
[Row(month=4)]
756756
"""
757757
sc = SparkContext._active_spark_context
758-
return Column(sc._jvm.functions.month(col))
758+
return Column(sc._jvm.functions.month(_to_java_column(col)))
759759

760760

761761
@since(1.5)
@@ -768,7 +768,7 @@ def dayofmonth(col):
768768
[Row(day=8)]
769769
"""
770770
sc = SparkContext._active_spark_context
771-
return Column(sc._jvm.functions.dayofmonth(col))
771+
return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))
772772

773773

774774
@since(1.5)
@@ -781,7 +781,7 @@ def dayofyear(col):
781781
[Row(day=98)]
782782
"""
783783
sc = SparkContext._active_spark_context
784-
return Column(sc._jvm.functions.dayofyear(col))
784+
return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))
785785

786786

787787
@since(1.5)
@@ -794,7 +794,7 @@ def hour(col):
794794
[Row(hour=13)]
795795
"""
796796
sc = SparkContext._active_spark_context
797-
return Column(sc._jvm.functions.hour(col))
797+
return Column(sc._jvm.functions.hour(_to_java_column(col)))
798798

799799

800800
@since(1.5)
@@ -807,7 +807,7 @@ def minute(col):
807807
[Row(minute=8)]
808808
"""
809809
sc = SparkContext._active_spark_context
810-
return Column(sc._jvm.functions.minute(col))
810+
return Column(sc._jvm.functions.minute(_to_java_column(col)))
811811

812812

813813
@since(1.5)
@@ -820,7 +820,7 @@ def second(col):
820820
[Row(second=15)]
821821
"""
822822
sc = SparkContext._active_spark_context
823-
return Column(sc._jvm.functions.second(col))
823+
return Column(sc._jvm.functions.second(_to_java_column(col)))
824824

825825

826826
@since(1.5)
@@ -829,11 +829,63 @@ def weekofyear(col):
829829
Extract the week number of a given date as integer.
830830
831831
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
832-
>>> df.select(weekofyear('a').alias('week')).collect()
832+
>>> df.select(weekofyear(df.a).alias('week')).collect()
833833
[Row(week=15)]
834834
"""
835835
sc = SparkContext._active_spark_context
836-
return Column(sc._jvm.functions.weekofyear(col))
836+
return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))
837+
838+
839+
@since(1.5)
840+
def date_add(start, days):
841+
"""
842+
Returns the date that is `days` days after `start`
843+
844+
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
845+
>>> df.select(date_add(df.d, 1).alias('d')).collect()
846+
[Row(d=datetime.date(2015, 4, 9))]
847+
"""
848+
sc = SparkContext._active_spark_context
849+
return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
850+
851+
852+
@since(1.5)
853+
def date_sub(start, days):
854+
"""
855+
Returns the date that is `days` days before `start`
856+
857+
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
858+
>>> df.select(date_sub(df.d, 1).alias('d')).collect()
859+
[Row(d=datetime.date(2015, 4, 7))]
860+
"""
861+
sc = SparkContext._active_spark_context
862+
return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
863+
864+
865+
@since(1.5)
866+
def add_months(start, months):
867+
"""
868+
Returns the date that is `months` months after `start`
869+
870+
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
871+
>>> df.select(add_months(df.d, 1).alias('d')).collect()
872+
[Row(d=datetime.date(2015, 5, 8))]
873+
"""
874+
sc = SparkContext._active_spark_context
875+
return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
876+
877+
878+
@since(1.5)
879+
def months_between(date1, date2):
880+
"""
881+
Returns the number of months between date1 and date2.
882+
883+
>>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
884+
>>> df.select(months_between(df.t, df.d).alias('months')).collect()
885+
[Row(months=3.94959677)]
886+
"""
887+
sc = SparkContext._active_spark_context
888+
return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
837889

838890

839891
@since(1.5)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ object HiveTypeCoercion {
4747
Division ::
4848
PropagateTypes ::
4949
ImplicitTypeCasts ::
50+
DateTimeOperations ::
5051
Nil
5152

5253
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -638,6 +639,24 @@ object HiveTypeCoercion {
638639
}
639640
}
640641

642+
/**
643+
* Turns Add/Subtract of DateType/TimestampType and IntervalType to TimeAdd/TimeSub
644+
*/
645+
object DateTimeOperations extends Rule[LogicalPlan] {
646+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
647+
// Skip nodes who's children have not been resolved yet.
648+
case e if !e.childrenResolved => e
649+
650+
case Add(left, right) if left.dataType == IntervalType =>
651+
Add(right, left) // switch the order
652+
653+
case Add(left, right) if right.dataType == IntervalType =>
654+
Cast(TimeAdd(Cast(left, TimestampType), right), left.dataType)
655+
case Subtract(left, right) if right.dataType == IntervalType =>
656+
Cast(TimeSub(Cast(left, TimestampType), right), left.dataType)
657+
}
658+
}
659+
641660
/**
642661
* Casts types according to the expected input types for [[Expression]]s.
643662
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -379,111 +379,95 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
379379
}
380380

381381
/**
382-
* Time Adds Interval.
382+
* Adds an interval to timestamp.
383383
*/
384-
case class TimeAdd(left: Expression, right: Expression)
384+
case class TimeAdd(start: Expression, interval: Expression)
385385
extends BinaryExpression with ExpectsInputTypes {
386386

387+
override def left: Expression = start
388+
override def right: Expression = interval
389+
387390
override def toString: String = s"$left + $right"
388-
override def inputTypes: Seq[AbstractDataType] =
389-
Seq(TypeCollection(DateType, TimestampType), IntervalType)
391+
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, IntervalType)
390392

391393
override def dataType: DataType = TimestampType
392394

393395
override def nullSafeEval(start: Any, interval: Any): Any = {
394396
val itvl = interval.asInstanceOf[Interval]
395-
left.dataType match {
396-
case DateType =>
397-
DateTimeUtils.dateAddFullInterval(
398-
start.asInstanceOf[Int], itvl.months, itvl.microseconds)
399-
case TimestampType =>
400-
DateTimeUtils.timestampAddFullInterval(
401-
start.asInstanceOf[Long], itvl.months, itvl.microseconds)
402-
}
397+
DateTimeUtils.timestampAddInterval(
398+
start.asInstanceOf[Long], itvl.months, itvl.microseconds)
403399
}
404400

405401
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
406402
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
407-
left.dataType match {
408-
case DateType =>
409-
defineCodeGen(ctx, ev, (sd, i) => {
410-
s"""$dtu.dateAddFullInterval($sd, $i.months, $i.microseconds)"""
411-
})
412-
case TimestampType => // TimestampType
413-
defineCodeGen(ctx, ev, (sd, i) => {
414-
s"""$dtu.timestampAddFullInterval($sd, $i.months, $i.microseconds)"""
415-
})
416-
}
403+
defineCodeGen(ctx, ev, (sd, i) => {
404+
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
405+
})
417406
}
418407
}
419408

420409
/**
421-
* Time Subtracts Interval.
410+
* Subtracts an interval from timestamp.
422411
*/
423-
case class TimeSub(left: Expression, right: Expression)
412+
case class TimeSub(start: Expression, interval: Expression)
424413
extends BinaryExpression with ExpectsInputTypes {
425414

415+
override def left: Expression = start
416+
override def right: Expression = interval
417+
426418
override def toString: String = s"$left - $right"
427-
override def inputTypes: Seq[AbstractDataType] =
428-
Seq(TypeCollection(DateType, TimestampType), IntervalType)
419+
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, IntervalType)
429420

430421
override def dataType: DataType = TimestampType
431422

432423
override def nullSafeEval(start: Any, interval: Any): Any = {
433424
val itvl = interval.asInstanceOf[Interval]
434-
left.dataType match {
435-
case DateType =>
436-
DateTimeUtils.dateAddFullInterval(
437-
start.asInstanceOf[Int], 0 - itvl.months, 0 - itvl.microseconds)
438-
case TimestampType =>
439-
DateTimeUtils.timestampAddFullInterval(
440-
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
441-
}
425+
DateTimeUtils.timestampAddInterval(
426+
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
442427
}
443428

444429
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
445430
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
446-
left.dataType match {
447-
case DateType =>
448-
defineCodeGen(ctx, ev, (sd, i) => {
449-
s"""$dtu.dateAddFullInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
450-
})
451-
case TimestampType => // TimestampType
452-
defineCodeGen(ctx, ev, (sd, i) => {
453-
s"""$dtu.timestampAddFullInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
454-
})
455-
}
431+
defineCodeGen(ctx, ev, (sd, i) => {
432+
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
433+
})
456434
}
457435
}
458436

459437
/**
460438
* Returns the date that is num_months after start_date.
461439
*/
462-
case class AddMonths(left: Expression, right: Expression)
440+
case class AddMonths(startDate: Expression, numMonths: Expression)
463441
extends BinaryExpression with ImplicitCastInputTypes {
464442

443+
override def left: Expression = startDate
444+
override def right: Expression = numMonths
445+
465446
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
466447

467448
override def dataType: DataType = DateType
468449

469450
override def nullSafeEval(start: Any, months: Any): Any = {
470-
DateTimeUtils.dateAddYearMonthInterval(start.asInstanceOf[Int], months.asInstanceOf[Int])
451+
DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])
471452
}
472453

473454
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
474455
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
475456
defineCodeGen(ctx, ev, (sd, m) => {
476-
s"""$dtu.dateAddYearMonthInterval($sd, $m)"""
457+
s"""$dtu.dateAddMonths($sd, $m)"""
477458
})
478459
}
479460
}
480461

481462
/**
482463
* Returns number of months between dates date1 and date2.
483464
*/
484-
case class MonthsBetween(left: Expression, right: Expression)
465+
case class MonthsBetween(date1: Expression, date2: Expression)
485466
extends BinaryExpression with ImplicitCastInputTypes {
486467

468+
override def left: Expression = date1
469+
override def right: Expression = date2
470+
487471
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)
488472

489473
override def dataType: DataType = DoubleType

0 commit comments

Comments
 (0)