Skip to content

Commit 87c4b77

Browse files
committed
function add_months, months_between and some fixes
1 parent 1a68e03 commit 87c4b77

File tree

6 files changed

+313
-67
lines changed

6 files changed

+313
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ object FunctionRegistry {
181181
expression[Upper]("upper"),
182182

183183
// datetime functions
184+
expression[AddMonths]("add_months"),
184185
expression[CurrentDate]("current_date"),
185186
expression[CurrentTimestamp]("current_timestamp"),
186187
expression[DateAdd]("date_add"),
@@ -191,6 +192,7 @@ object FunctionRegistry {
191192
expression[DayOfMonth]("dayofmonth"),
192193
expression[Hour]("hour"),
193194
expression[Month]("month"),
195+
expression[MonthsBetween]("months_between"),
194196
expression[Minute]("minute"),
195197
expression[Quarter]("quarter"),
196198
expression[Second]("second"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -309,21 +309,18 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
309309
/**
310310
* Time Adds Interval.
311311
*/
312-
case class TimeAdd(start: Expression, interval: Expression)
313-
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
314-
315-
override def left: Expression = start
316-
override def right: Expression = interval
312+
case class TimeAdd(left: Expression, right: Expression)
313+
extends BinaryExpression with ExpectsInputTypes {
317314

318315
override def toString: String = s"$left + $right"
319316
override def inputTypes: Seq[AbstractDataType] =
320317
Seq(TypeCollection(DateType, TimestampType), IntervalType)
321318

322319
override def dataType: DataType = TimestampType
323320

324-
override def nullSafeEval(start: Any, inter: Any): Any = {
325-
val itvl = inter.asInstanceOf[Interval]
326-
dataType match {
321+
override def nullSafeEval(start: Any, interval: Any): Any = {
322+
val itvl = interval.asInstanceOf[Interval]
323+
left.dataType match {
327324
case DateType =>
328325
DateTimeUtils.dateAddFullInterval(
329326
start.asInstanceOf[Int], itvl.months, itvl.microseconds)
@@ -332,26 +329,37 @@ case class TimeAdd(start: Expression, interval: Expression)
332329
start.asInstanceOf[Long], itvl.months, itvl.microseconds)
333330
}
334331
}
332+
333+
/**
 * Generates code that adds the interval to the start value by calling the DateTimeUtils
 * helper that matches the data type of the left operand.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val utilsClass = DateTimeUtils.getClass.getName.stripSuffix("$")
  // Date and timestamp inputs use different helpers; both take (start, months, microseconds).
  val helper = left.dataType match {
    case DateType => "dateAddFullInterval"
    case TimestampType => "timestampAddFullInterval"
  }
  defineCodeGen(ctx, ev, (startCode, intervalCode) =>
    s"$utilsClass.$helper($startCode, $intervalCode.months, $intervalCode.microseconds)")
}
335346
}
336347

337348
/**
338349
* Time Subtracts Interval.
339350
*/
340-
case class TimeSub(start: Expression, interval: Expression)
341-
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
342-
343-
override def left: Expression = start
344-
override def right: Expression = interval
351+
case class TimeSub(left: Expression, right: Expression)
352+
extends BinaryExpression with ExpectsInputTypes {
345353

346354
override def toString: String = s"$left - $right"
347355
override def inputTypes: Seq[AbstractDataType] =
348356
Seq(TypeCollection(DateType, TimestampType), IntervalType)
349357

350358
override def dataType: DataType = TimestampType
351359

352-
override def nullSafeEval(start: Any, inter: Any): Any = {
353-
val itvl = inter.asInstanceOf[Interval]
354-
dataType match {
360+
override def nullSafeEval(start: Any, interval: Any): Any = {
361+
val itvl = interval.asInstanceOf[Interval]
362+
left.dataType match {
355363
case DateType =>
356364
DateTimeUtils.dateAddFullInterval(
357365
start.asInstanceOf[Int], 0 - itvl.months, 0 - itvl.microseconds)
@@ -360,4 +368,62 @@ case class TimeSub(start: Expression, interval: Expression)
360368
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
361369
}
362370
}
371+
372+
/**
 * Generates code that subtracts the interval from the start value by calling the
 * DateTimeUtils helper that matches the data type of the left operand, with both interval
 * components negated.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val utilsClass = DateTimeUtils.getClass.getName.stripSuffix("$")
  // Date and timestamp inputs use different helpers; both take (start, months, microseconds).
  val helper = left.dataType match {
    case DateType => "dateAddFullInterval"
    case TimestampType => "timestampAddFullInterval"
  }
  defineCodeGen(ctx, ev, (startCode, intervalCode) =>
    s"$utilsClass.$helper($startCode, 0 - $intervalCode.months, 0 - $intervalCode.microseconds)")
}
385+
}
386+
387+
/**
 * Returns the date that is `right` months after the date `left`.
 *
 * Inputs are implicitly cast to (DateType, IntegerType). Both the interpreted and the
 * generated code path delegate to DateTimeUtils.dateAddYearMonthInterval.
 */
case class AddMonths(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = DateType

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

  override def nullSafeEval(start: Any, months: Any): Any = {
    val startDays = start.asInstanceOf[Int]
    val monthCount = months.asInstanceOf[Int]
    DateTimeUtils.dateAddYearMonthInterval(startDays, monthCount)
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val utilsClass = DateTimeUtils.getClass.getName.stripSuffix("$")
    defineCodeGen(ctx, ev, (startCode, monthsCode) =>
      s"$utilsClass.dateAddYearMonthInterval($startCode, $monthsCode)")
  }
}
408+
409+
/**
 * Returns the number of months between the timestamps `left` and `right`, as a double.
 *
 * Inputs are implicitly cast to TimestampType. Both the interpreted and the generated code
 * path delegate to DateTimeUtils.monthsBetween, passing `left` first.
 */
case class MonthsBetween(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = DoubleType

  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)

  override def nullSafeEval(t1: Any, t2: Any): Any = {
    val micros1 = t1.asInstanceOf[Long]
    val micros2 = t2.asInstanceOf[Long]
    DateTimeUtils.monthsBetween(micros1, micros2)
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val utilsClass = DateTimeUtils.getClass.getName.stripSuffix("$")
    defineCodeGen(ctx, ev, (leftCode, rightCode) =>
      s"$utilsClass.monthsBetween($leftCode, $rightCode)")
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ object DateTimeUtils {
584584
val daysFromYears = getDaysFromYears(yearSinceEpoch)
585585
val febDays = if (isLeapYear(1970 + yearSinceEpoch)) 29 else 28
586586
val daysForMonths = Seq(31, febDays, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
587-
daysForMonths.slice(0, monthInYearFromZero + 1).sum + daysFromYears
587+
daysForMonths.slice(0, monthInYearFromZero).sum + daysFromYears
588588
}
589589

590590
/**
@@ -611,23 +611,75 @@ object DateTimeUtils {
611611
* Returns a date value, expressed in days since 1.1.1970.
612612
*/
613613
def dateAddYearMonthInterval(days: Int, months: Int): Int = {
  // Zero-based month index counted from January 1970; negative for months before 1970.
  val monthsSinceEpoch = (getYear(days) - 1970) * 12 + getMonth(days) - 1 + months
  // Floor-style modulo: always in 0..11, for negative month indices as well.
  val monthInYear = ((monthsSinceEpoch % 12) + 12) % 12
  // Integer division truncates toward zero, so a negative index needs one extra year
  // subtracted, except when it is an exact multiple of 12 (January of an earlier year).
  val yearsSinceEpoch = if (monthsSinceEpoch < 0 && monthInYear != 0) {
    monthsSinceEpoch / 12 - 1
  } else {
    monthsSinceEpoch / 12
  }
  val febDays = if (isLeapYear(1970 + yearsSinceEpoch)) 29 else 28
  val daysForMonths = Seq(31, febDays, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
  // Clamp the day of month to the length of the target month, so that e.g. Jan 30 plus one
  // month yields the last day of February.
  val dayOfMonth = math.min(getDayOfMonth(days), daysForMonths(monthInYear))
  // NOTE(review): assumes getDaysFromMonths returns the day number of the first day of the
  // given month, which is why the one-based day of month is reduced by one - confirm.
  getDaysFromMonths(monthsSinceEpoch) + dayOfMonth - 1
}
616626

617627
/**
 * Adds a full interval (months plus microseconds) to a date expressed in days since
 * 1.1.1970.
 * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
 */
def dateAddFullInterval(days: Int, months: Int, microseconds: Long): Long = {
  // Apply the month part on the date first, then convert to microseconds and add the rest.
  val shiftedDays = dateAddYearMonthInterval(days, months)
  val shiftedMicros = daysToMillis(shiftedDays) * 1000L
  shiftedMicros + microseconds
}
624634

625635
/**
 * Adds a full interval (months plus microseconds) to a timestamp expressed in microseconds
 * since 1.1.1970 00:00:00.
 * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
 */
def timestampAddFullInterval(micros: Long, months: Int, microseconds: Long): Long = {
  val startDays = millisToDays(micros / 1000L)
  // Part of `micros` that lies beyond the day boundary reported by daysToMillis; it is
  // carried over unchanged onto the shifted date.
  val microsIntoDay = micros - daysToMillis(startDays) * 1000L
  dateAddFullInterval(startDays, months, microseconds) + microsIntoDay
}
643+
644+
/**
 * Returns the day-of-month (starting from 1) of the last day of the month that `date`
 * falls in, i.e. the length of that month. `date` is expressed in days since 1.1.1970.
 */
def getLastDayInMonthOfMonth(date: Int): Int = {
  val monthIndex = getMonth(date) - 1
  // Only February depends on the year.
  val februaryLength = if (isLeapYear(getYear(date))) 29 else 28
  val monthLengths = Seq(31, februaryLength, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
  monthLengths(monthIndex)
}
655+
656+
/**
 * Returns the number of months between time1 and time2, computed as time1 minus time2.
 * time1 and time2 are expressed in microseconds since 1.1.1970.
 *
 * The result is a whole number when both timestamps fall on the same day of month, or when
 * both fall on the last day of their respective months; the time of day is ignored in that
 * case. Otherwise the fractional part is derived from the difference in day of month and
 * time of day, based on a 31-day month.
 */
def monthsBetween(time1: Long, time2: Long): Double = {
  val date1 = millisToDays(time1 / 1000L)
  val date2 = millisToDays(time2 / 1000L)
  val dayInMonth1 = getDayOfMonth(date1)
  val dayInMonth2 = getDayOfMonth(date2)
  val months1 = getYear(date1) * 12 + getMonth(date1) - 1
  val months2 = getYear(date2) * 12 + getMonth(date2) - 1
  val wholeMonths = (months1 - months2).toDouble

  val bothLastDayOfMonth =
    dayInMonth1 == getLastDayInMonthOfMonth(date1) &&
      dayInMonth2 == getLastDayInMonthOfMonth(date2)

  if (dayInMonth1 == dayInMonth2 || bothLastDayOfMonth) {
    wholeMonths
  } else {
    val microsPerDay = MILLIS_PER_DAY * 1000L
    // Offsets of each timestamp from the day boundary reported by daysToMillis.
    val timeInDay1 = time1 - daysToMillis(date1) * 1000L
    val timeInDay2 = time2 - daysToMillis(date2) * 1000L
    // Time-of-day difference as a fraction of a day; it must be counted exactly once.
    val daysBetween = (timeInDay1 - timeInDay2).toDouble / microsPerDay
    wholeMonths + (dayInMonth1 - dayInMonth2 + daysBetween) / 31.0
  }
}
632684
}
633685
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import java.util.Calendar
2323

2424
import org.apache.spark.SparkFunSuite
2525
import org.apache.spark.sql.catalyst.util.DateTimeUtils
26-
import org.apache.spark.sql.types.{IntegerType, StringType, TimestampType, DateType}
26+
import org.apache.spark.sql.types.{StringType, TimestampType, DateType}
2727
import org.apache.spark.unsafe.types.Interval
2828

2929
class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -249,55 +249,56 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
249249
}
250250

251251
test("date_add") {
252-
checkResult(
252+
checkEvaluation(
253253
DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(1)),
254254
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
255-
checkResult(
256-
DateAdd(Literal("2016-03-01"), Literal(-1)),
257-
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
258-
checkResult(
259-
DateAdd(Literal(Timestamp.valueOf("2016-03-01 23:59:59")), Literal(-2)),
260-
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-28")))
261-
checkResult(
262-
DateAdd(Literal("2016-03-01 23:59:59"), Literal(-3)),
263-
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-27")))
264-
checkResult(DateAdd(Literal(null), Literal(-1)), null)
255+
checkEvaluation(
256+
DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(-365)),
257+
DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28")))
265258
}
266259

267260
test("date_sub") {
268-
checkResult(
269-
DateSub(Literal("2015-01-01"), Literal(1)),
261+
checkEvaluation(
262+
DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)),
270263
DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31")))
271-
checkResult(
264+
checkEvaluation(
272265
DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)),
273266
DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02")))
274-
checkResult(
275-
DateSub(Literal(Timestamp.valueOf("2015-01-01 01:00:00")), Literal(-1)),
276-
DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02")))
277-
checkResult(
278-
DateSub(Literal("2015-01-01 01:00:00"), Literal(0)),
279-
DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-01")))
280-
checkResult(
281-
DateSub(Literal("2015-01-01"), Literal.create(null, IntegerType)), null)
282267
}
283268

284269
test("time_add") {
285-
checkResult(
286-
TimeAdd(Literal(Date.valueOf("2016-02-28")), Literal(new Interval(1, 0))),
287-
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
288-
checkResult(
289-
TimeAdd(Literal(Date.valueOf("2016-03-01")), Literal(new Interval(1, 2000000.toLong))),
290-
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-03-02 00:00:02")))
270+
checkEvaluation(
271+
TimeAdd(Literal(Date.valueOf("2016-01-29")), Literal(new Interval(1, 0))),
272+
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 00:00:00")))
273+
checkEvaluation(
274+
TimeAdd(Literal(Date.valueOf("2016-01-31")), Literal(new Interval(1, 2000000.toLong))),
275+
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 00:00:02")))
291276
}
292277

293278
test("time_sub") {
294-
checkResult(
295-
TimeSub(Literal(Timestamp.valueOf("2016-02-28 10:00:00")), Literal(new Interval(1, 0))),
296-
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-27 10:00:00")))
297-
checkResult(
279+
checkEvaluation(
280+
TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")), Literal(new Interval(1, 0))),
281+
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00")))
282+
checkEvaluation(
298283
TimeSub(
299-
Literal(Timestamp.valueOf("2016-03-01 00:00:02")),
284+
Literal(Timestamp.valueOf("2016-03-30 00:00:01")),
300285
Literal(new Interval(1, 2000000.toLong))),
301286
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59")))
302287
}
288+
289+
test("add_months") {
290+
// Jan 30 plus one month: the day is clamped to Feb 28, the last day of February 2015.
checkEvaluation(AddMonths(Literal(Date.valueOf(
291+
"2015-01-30")), Literal(1)), DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28")))
292+
// Mar 30 minus one month: the day is clamped to Feb 29, the last day of February 2016.
checkEvaluation(AddMonths(Literal(Date.valueOf(
293+
"2016-03-30")), Literal(-1)), DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
294+
}
295+
296+
test("months_between") {
297+
// Same day of month: a whole number is expected and the time of day is ignored.
checkEvaluation(MonthsBetween(Literal(Timestamp.valueOf(
298+
"2015-01-30 11:52:00")), Literal(Timestamp.valueOf("2015-01-30 11:50:00"))), 0.0)
299+
// Same day of month two months apart with the earlier timestamp first: expects -2.0.
checkEvaluation(MonthsBetween(Literal(Timestamp.valueOf(
300+
"2015-01-31 00:00:00")), Literal(Timestamp.valueOf("2015-03-31 22:00:00"))), -2.0)
301+
// Mar 31 and Feb 28 are both the last day of their month in 2015: expects 1.0.
checkEvaluation(MonthsBetween(Literal(Timestamp.valueOf(
302+
"2015-03-31 22:00:00")), Literal(Timestamp.valueOf("2015-02-28 00:00:00"))), 1.0)
303+
}
303304
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2261,6 +2261,14 @@ object functions {
22612261
// DateTime functions
22622262
//////////////////////////////////////////////////////////////////////////////////////////////
22632263

2264+
/**
2265+
* Returns the date that is `numMonths` months after `startDate`.
2266+
* @group datetime_funcs
2267+
* @since 1.5.0
2268+
*/
2269+
def add_months(startDate: Column, numMonths: Column): Column =
2270+
AddMonths(startDate.expr, numMonths.expr)
2271+
22642272
/**
22652273
* Converts a date/timestamp/string to a value of string in the format specified by the date
22662274
* format given by the second argument.
@@ -2405,6 +2413,13 @@ object functions {
24052413
*/
24062414
def minute(columnName: String): Column = minute(Column(columnName))
24072415

2416+
/**
2417+
* Returns the number of months between the timestamps `l` and `r`, computed as `l` minus `r`.
2418+
* @group datetime_funcs
2419+
* @since 1.5.0
2420+
*/
2421+
def months_between(l: Column, r: Column): Column = MonthsBetween(l.expr, r.expr)
2422+
24082423
/**
24092424
* Extracts the seconds as an integer from a given date/timestamp/string.
24102425
* @group datetime_funcs
@@ -2419,6 +2434,20 @@ object functions {
24192434
*/
24202435
def second(columnName: String): Column = second(Column(columnName))
24212436

2437+
/**
2438+
* Adds the interval `r` to the date or timestamp `l`.
2439+
* @group datetime_funcs
2440+
* @since 1.5.0
2441+
*/
2442+
def time_add(l: Column, r: Column): Column = TimeAdd(l.expr, r.expr)
2443+
2444+
/**
2445+
* Subtracts the interval `r` from the date or timestamp `l`.
2446+
* @group datetime_funcs
2447+
* @since 1.5.0
2448+
*/
2449+
def time_sub(l: Column, r: Column): Column = TimeSub(l.expr, r.expr)
2450+
24222451
/**
24232452
* Extracts the week number as an integer from a given date/timestamp/string.
24242453
* @group datetime_funcs

0 commit comments

Comments
 (0)